Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

File diff suppressed because it is too large


@@ -0,0 +1,55 @@
[package]
name = "ruvector-hyperbolic-hnsw"
version = "0.1.0"
edition = "2021"
rust-version = "1.77"
license = "MIT OR Apache-2.0"
authors = ["RuVector Team <team@ruvector.dev>"]
repository = "https://github.com/ruvnet/ruvector"
homepage = "https://ruv.io/ruvector"
documentation = "https://docs.rs/ruvector-hyperbolic-hnsw"
description = "Hyperbolic (Poincare ball) embeddings with HNSW integration for hierarchy-aware vector search, enabling efficient similarity search in non-Euclidean spaces for taxonomies, ontologies, and hierarchical data"
keywords = ["hyperbolic", "poincare", "hnsw", "vector-search", "hierarchy"]
categories = ["algorithms", "science", "mathematics"]
readme = "README.md"

[lib]
crate-type = ["rlib"]

[features]
default = ["simd", "parallel"]
simd = []
parallel = ["rayon"]
wasm = []

[dependencies]
# Math and numerics (exact versions as specified)
nalgebra = "0.34.1"
ndarray = "0.17.1"

# Parallel processing
rayon = { version = "1.10", optional = true }

# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

# Error handling
thiserror = "2.0"

# Random number generation
rand = "0.8"
rand_distr = "0.4"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
approx = "0.5"
proptest = "1.5"

[[bench]]
name = "hyperbolic_bench"
harness = false

[[test]]
name = "math_tests"
path = "tests/math_tests.rs"


@@ -0,0 +1,242 @@
# ruvector-hyperbolic-hnsw
Hyperbolic (Poincaré ball) embeddings with HNSW integration for hierarchy-aware vector search.
## Why Hyperbolic Space?
Hierarchies compress naturally in hyperbolic space. Taxonomies, catalogs, ICD trees, product facets, org charts, and long-tail tags all fit better than in Euclidean space, which means higher recall on deep leaves without blowing up memory or latency.
## Key Features
- **Poincaré Ball Model**: Store vectors in the Poincaré ball with norms clamped to `r < 1/sqrt(c) - eps`
- **HNSW Speed Trick**: Prune with cheap tangent-space proxy, rank with true hyperbolic distance
- **Per-Shard Curvature**: Different parts of the hierarchy can have different optimal curvatures
- **Dual-Space Index**: Keep a synchronized Euclidean ANN for fallback and mutual-ranking fusion
- **Production Guardrails**: Numerical stability, canary testing, hot curvature reload
## Installation
### Rust
```toml
[dependencies]
ruvector-hyperbolic-hnsw = "0.1.0"
```
### WebAssembly
```bash
cd crates/ruvector-hyperbolic-hnsw-wasm
wasm-pack build --target web --release
```
### TypeScript/JavaScript
```typescript
import init, {
HyperbolicIndex,
poincareDistance,
mobiusAdd,
expMap,
logMap
} from 'ruvector-hyperbolic-hnsw-wasm';
await init();
const index = new HyperbolicIndex(16, 1.0);
index.insert(new Float32Array([0.1, 0.2, 0.3]));
const results = index.search(new Float32Array([0.15, 0.1, 0.2]), 5);
```
## Quick Start
```rust
use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig};
// Create index with default settings
let mut index = HyperbolicHnsw::default_config();
// Insert vectors (automatically projected to Poincaré ball)
index.insert(vec![0.1, 0.2, 0.3]).unwrap();
index.insert(vec![-0.1, 0.15, 0.25]).unwrap();
index.insert(vec![0.2, -0.1, 0.1]).unwrap();
// Search for nearest neighbors
let results = index.search(&[0.15, 0.1, 0.2], 2).unwrap();
for r in results {
println!("ID: {}, Distance: {:.4}", r.id, r.distance);
}
```
## HNSW Speed Trick
The core optimization is:
1. Precompute `u = log_c(x)` at a shard centroid `c`
2. During neighbor selection, use Euclidean `||u_q - u_p||` to prune
3. Run exact Poincaré distance only on top N candidates before final ranking
```rust
use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig};
let mut config = HyperbolicHnswConfig::default();
config.use_tangent_pruning = true;
config.prune_factor = 10; // Consider 10x candidates in tangent space
let mut index = HyperbolicHnsw::new(config);
// ... insert vectors ...
// Build tangent cache for pruning optimization
index.build_tangent_cache().unwrap();
// Search with pruning (faster!)
let results = index.search_with_pruning(&[0.1, 0.15], 5).unwrap();
```
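Steps 1 and 2 can be illustrated standalone. The sketch below takes the log map at the origin rather than at a shard centroid (a simplifying assumption; the function name is ours, not the crate's API):

```rust
/// Log map at the origin of the Poincaré ball: maps a point into the
/// tangent space, where plain Euclidean distance serves as the cheap
/// pruning proxy before exact hyperbolic ranking.
fn log_map_origin(x: &[f32], c: f32) -> Vec<f32> {
    let norm = x.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm < 1e-12 {
        return x.to_vec();
    }
    // log_0(x) = artanh(sqrt(c)*||x||) / (sqrt(c)*||x||) * x,
    // with the artanh argument clamped away from 1 for stability
    let arg = (c.sqrt() * norm).min(1.0 - 1e-5);
    let scale = arg.atanh() / (c.sqrt() * norm);
    x.iter().map(|v| v * scale).collect()
}
```

Candidates whose tangent images lie far from the query's are discarded cheaply; only the survivors pay for an exact Poincaré distance.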
## Core Mathematical Operations
```rust
use ruvector_hyperbolic_hnsw::poincare::{
mobius_add, exp_map, log_map, poincare_distance, project_to_ball
};
let x = vec![0.3, 0.2];
let y = vec![-0.1, 0.4];
let c = 1.0; // Curvature
// Möbius addition (hyperbolic vector addition)
let z = mobius_add(&x, &y, c);
// Geodesic distance in hyperbolic space
let d = poincare_distance(&x, &y, c);
// Map to tangent space at x
let v = log_map(&y, &x, c);
// Map back to manifold
let y_recovered = exp_map(&v, &x, c);
```
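For reference, the two core operations above can be sketched standalone using the textbook Poincaré-ball formulas (this mirrors, but is not, the crate's implementation):

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Möbius addition: x ⊕_c y
fn mobius_add(x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
    let (x2, y2, xy) = (dot(x, x), dot(y, y), dot(x, y));
    let denom = 1.0 + 2.0 * c * xy + c * c * x2 * y2;
    x.iter()
        .zip(y)
        .map(|(xi, yi)| ((1.0 + 2.0 * c * xy + c * y2) * xi + (1.0 - c * x2) * yi) / denom)
        .collect()
}

/// Geodesic distance: d(x, y) = (2/sqrt(c)) * artanh(sqrt(c) * ||(-x) ⊕ y||)
fn poincare_distance(x: &[f32], y: &[f32], c: f32) -> f32 {
    let neg_x: Vec<f32> = x.iter().map(|v| -v).collect();
    let diff = mobius_add(&neg_x, y, c);
    // Clamp the artanh argument away from 1 for numerical safety
    let arg = (c.sqrt() * dot(&diff, &diff).sqrt()).min(1.0 - 1e-5);
    (2.0 / c.sqrt()) * arg.atanh()
}
```

Note that `(-x) ⊕ x = 0`, so the distance from a point to itself is exactly zero, and the formula is symmetric in `x` and `y`.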
## Sharded Index with Per-Shard Curvature
```rust
use ruvector_hyperbolic_hnsw::{ShardedHyperbolicHnsw, ShardStrategy};
let mut manager = ShardedHyperbolicHnsw::new(1.0);
// Insert with hierarchy depth information
manager.insert(vec![0.1, 0.2], Some(0)).unwrap(); // Root level
manager.insert(vec![0.3, 0.1], Some(3)).unwrap(); // Deeper level
// Update curvature for specific shard
manager.update_curvature("radius_1", 0.5).unwrap();
// Canary testing for new curvature
manager.registry.set_canary("radius_1", 0.3, 10); // 10% traffic
// Search across all shards
let results = manager.search(&[0.2, 0.15], 5).unwrap();
```
## Numerical Stability
All operations include numerical safeguards:
- **Norm clamping**: Points projected with `eps = 1e-5`
- **Projection after updates**: All operations keep points inside the ball
- **Stable acosh**: Uses `log1p` expansions for safety
- **Clamped arguments**: `atanh` arguments bounded away from ±1
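The first two guardrails amount to a projection step along these lines (an assumed sketch; the ball radius `1/sqrt(c) - eps` matches the crate's error messages):

```rust
/// Rescale a point so its norm stays strictly inside the ball of radius
/// 1/sqrt(c) - eps; points already inside are returned unchanged.
fn project_to_ball(v: &[f32], c: f32, eps: f32) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let max_norm = 1.0 / c.sqrt() - eps;
    if norm > max_norm {
        let scale = max_norm / norm;
        v.iter().map(|x| x * scale).collect()
    } else {
        v.to_vec()
    }
}
```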
## Evaluation Protocol
### Datasets
- WordNet
- DBpedia slices
- Synthetic scale-free tree
- Domain taxonomy
### Primary Metrics
- **recall@k** (1, 5, 10)
- **Mean rank**
- **NDCG**
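Here recall@k is the fraction of ground-truth neighbors recovered among the top k retrieved ids; a minimal sketch (function name illustrative):

```rust
/// Fraction of ground-truth neighbor ids present in the top-k retrieved ids.
fn recall_at_k(retrieved: &[usize], ground_truth: &[usize], k: usize) -> f32 {
    let top_k = &retrieved[..k.min(retrieved.len())];
    let hits = ground_truth.iter().filter(|id| top_k.contains(id)).count();
    hits as f32 / ground_truth.len() as f32
}
```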
### Hierarchy Metrics
- **Radius vs depth Spearman correlation**
- **Distance distortion**
- **Ancestor AUPRC**
### Baselines
- Euclidean HNSW
- OPQ/PQ compressed
- Simple mutual-ranking fusion
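The mutual-ranking fusion baseline can be sketched as weighted reciprocal-rank fusion of the two indexes' result lists (this mirrors the scoring used by `DualSpaceIndex` in this crate; the helper name is ours):

```rust
use std::collections::HashMap;

/// Weighted reciprocal-rank fusion:
/// score(id) = w / (rank_hyp + 1) + (1 - w) / (rank_euc + 1)
fn fuse_rankings(hyp: &[usize], euc: &[usize], w: f32) -> Vec<usize> {
    let mut scores: HashMap<usize, f32> = HashMap::new();
    for (rank, &id) in hyp.iter().enumerate() {
        *scores.entry(id).or_insert(0.0) += w / (rank as f32 + 1.0);
    }
    for (rank, &id) in euc.iter().enumerate() {
        *scores.entry(id).or_insert(0.0) += (1.0 - w) / (rank as f32 + 1.0);
    }
    // Higher combined score ranks first
    let mut ranked: Vec<(usize, f32)> = scores.into_iter().collect();
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    ranked.into_iter().map(|(id, _)| id).collect()
}
```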
### Ablations
- Tangent proxy vs full hyperbolic
- Fixed vs learnable curvature c
- Global vs shard centroids
## Production Integration
### Reflex Loop (on writes)
Small Möbius deltas and tangent-space micro updates that never push points outside the ball.
```rust
use ruvector_hyperbolic_hnsw::tangent_micro_update;
let updated = tangent_micro_update(
&point,
&delta,
&centroid,
curvature,
0.1 // max step size
);
```
### Habit (nightly)
Riemannian SGD passes to clean neighborhoods and optionally relearn per-shard curvature. Run canary first.
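One such pass boils down to steps of roughly this shape (a sketch under standard Poincaré-ball assumptions: the Euclidean gradient is rescaled by the inverse metric, then the point is projected back inside the ball; not the crate's exact routine):

```rust
/// One Riemannian SGD step in the Poincaré ball. The Euclidean gradient is
/// scaled by (1 - c*||x||^2)^2 / 4 (the inverse squared conformal factor),
/// and the updated point is projected back inside radius 1/sqrt(c) - eps.
fn rsgd_step(x: &mut [f32], grad: &[f32], lr: f32, c: f32, eps: f32) {
    let norm_sq: f32 = x.iter().map(|v| v * v).sum();
    let scale = (1.0 - c * norm_sq).powi(2) / 4.0;
    for (xi, gi) in x.iter_mut().zip(grad) {
        *xi -= lr * scale * gi;
    }
    // Projection keeps the update from leaving the ball
    let new_norm = x.iter().map(|v| v * v).sum::<f32>().sqrt();
    let max_norm = 1.0 / c.sqrt() - eps;
    if new_norm > max_norm {
        let s = max_norm / new_norm;
        for xi in x.iter_mut() {
            *xi *= s;
        }
    }
}
```

Because the scale factor shrinks quadratically near the boundary, points deep in the hierarchy (large norm) move conservatively, which is what makes nightly passes safe to canary.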
### Structural (periodic)
Rebuild of HNSW with true hyperbolic metric, curvature retune, and shard reshuffle if hierarchy preservation drops below SLO.
## Dependencies (Exact Versions)
```toml
nalgebra = "0.34.1"
ndarray = "0.17.1"
wasm-bindgen = "0.2.106"
```
## Benchmarks
```bash
cd crates/ruvector-hyperbolic-hnsw
cargo bench
```
Benchmark suite includes:
- Poincaré distance computation
- Möbius addition
- exp/log map operations
- HNSW insert and search
- Tangent cache building
- Search with vs without pruning
## License
MIT OR Apache-2.0
## Related
- [ruvector-attention](../ruvector-attention) - Hyperbolic attention mechanisms
- [micro-hnsw-wasm](../micro-hnsw-wasm) - Minimal HNSW for WASM
- [ruvector-math](../ruvector-math) - General math primitives


@@ -0,0 +1,178 @@
//! Benchmarks for hyperbolic HNSW operations
//!
//! Metrics as specified in evaluation protocol:
//! - p50 and p95 latency
//! - Memory overhead
//! - Search recall@k
use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId};
use ruvector_hyperbolic_hnsw::*;
fn bench_poincare_distance(c: &mut Criterion) {
let dims = [8, 32, 128, 512];
let mut group = c.benchmark_group("poincare_distance");
for dim in dims {
let x: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.01) % 0.9).collect();
let y: Vec<f32> = (0..dim).map(|i| ((i as f32 * 0.02) + 0.1) % 0.9).collect();
group.bench_with_input(BenchmarkId::new("dim", dim), &dim, |b, _| {
b.iter(|| poincare_distance(black_box(&x), black_box(&y), 1.0))
});
}
group.finish();
}
fn bench_mobius_add(c: &mut Criterion) {
let dims = [8, 32, 128];
let mut group = c.benchmark_group("mobius_add");
for dim in dims {
let x: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.01) % 0.5).collect();
let y: Vec<f32> = (0..dim).map(|i| ((i as f32 * 0.02) + 0.1) % 0.5).collect();
group.bench_with_input(BenchmarkId::new("dim", dim), &dim, |b, _| {
b.iter(|| mobius_add(black_box(&x), black_box(&y), 1.0))
});
}
group.finish();
}
fn bench_exp_log_map(c: &mut Criterion) {
let dim = 32;
let p: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.01) % 0.3).collect();
let v: Vec<f32> = (0..dim).map(|i| ((i as f32 * 0.005) - 0.1) % 0.2).collect();
let q: Vec<f32> = (0..dim).map(|i| ((i as f32 * 0.02) + 0.1) % 0.4).collect();
let mut group = c.benchmark_group("exp_log_map");
group.bench_function("exp_map", |b| {
b.iter(|| exp_map(black_box(&v), black_box(&p), 1.0))
});
group.bench_function("log_map", |b| {
b.iter(|| log_map(black_box(&q), black_box(&p), 1.0))
});
group.finish();
}
fn bench_hnsw_insert(c: &mut Criterion) {
let sizes = [100, 500, 1000];
let mut group = c.benchmark_group("hnsw_insert");
group.sample_size(20);
for size in sizes {
let vectors: Vec<Vec<f32>> = (0..size)
.map(|i| vec![
(i as f32 * 0.01) % 0.8,
((i as f32 * 0.02) + 0.1) % 0.8,
])
.collect();
group.bench_with_input(BenchmarkId::new("n", size), &vectors, |b, vecs| {
b.iter(|| {
let mut hnsw = HyperbolicHnsw::default_config();
for v in vecs {
hnsw.insert(v.clone()).unwrap();
}
})
});
}
group.finish();
}
fn bench_hnsw_search(c: &mut Criterion) {
let ks = [1, 5, 10, 50];
// Build index once
let mut hnsw = HyperbolicHnsw::default_config();
for i in 0..1000 {
let v = vec![
(i as f32 * 0.01) % 0.8,
((i as f32 * 0.02) + 0.1) % 0.8,
];
hnsw.insert(v).unwrap();
}
let query = vec![0.4, 0.4];
let mut group = c.benchmark_group("hnsw_search");
for k in ks {
group.bench_with_input(BenchmarkId::new("k", k), &k, |b, &k| {
b.iter(|| hnsw.search(black_box(&query), k))
});
}
group.finish();
}
fn bench_tangent_cache(c: &mut Criterion) {
let sizes = [100, 500, 1000];
let mut group = c.benchmark_group("tangent_cache");
group.sample_size(20);
for size in sizes {
let points: Vec<Vec<f32>> = (0..size)
.map(|i| vec![
(i as f32 * 0.01) % 0.8,
((i as f32 * 0.02) + 0.1) % 0.8,
])
.collect();
let indices: Vec<usize> = (0..size).collect();
group.bench_with_input(BenchmarkId::new("build", size), &(points.clone(), indices.clone()), |b, (p, i)| {
b.iter(|| TangentCache::new(black_box(p), black_box(i), 1.0))
});
}
group.finish();
}
fn bench_search_with_pruning(c: &mut Criterion) {
// Build index with tangent cache
let mut hnsw = HyperbolicHnsw::default_config();
for i in 0..1000 {
let v = vec![
(i as f32 * 0.01) % 0.8,
((i as f32 * 0.02) + 0.1) % 0.8,
];
hnsw.insert(v).unwrap();
}
hnsw.build_tangent_cache().unwrap();
let query = vec![0.4, 0.4];
let mut group = c.benchmark_group("search_comparison");
group.bench_function("standard_search", |b| {
b.iter(|| hnsw.search(black_box(&query), 10))
});
group.bench_function("pruning_search", |b| {
b.iter(|| hnsw.search_with_pruning(black_box(&query), 10))
});
group.finish();
}
criterion_group!(
benches,
bench_poincare_distance,
bench_mobius_add,
bench_exp_log_map,
bench_hnsw_insert,
bench_hnsw_search,
bench_tangent_cache,
bench_search_with_pruning,
);
criterion_main!(benches);


@@ -0,0 +1,42 @@
//! Error types for hyperbolic HNSW operations
use thiserror::Error;
/// Errors that can occur during hyperbolic operations
#[derive(Error, Debug, Clone)]
pub enum HyperbolicError {
/// Vector is outside the Poincaré ball
#[error("Vector norm {norm} exceeds ball radius (1/sqrt(c) - eps) for curvature c={curvature}")]
OutsideBall { norm: f32, curvature: f32 },
/// Invalid curvature parameter
#[error("Invalid curvature: {0}. Must be positive.")]
InvalidCurvature(f32),
/// Dimension mismatch between vectors
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch { expected: usize, got: usize },
/// Numerical instability detected
#[error("Numerical instability: {0}")]
NumericalInstability(String),
/// Shard not found
#[error("Shard not found: {0}")]
ShardNotFound(String),
/// Index out of bounds
#[error("Index {index} out of bounds for size {size}")]
IndexOutOfBounds { index: usize, size: usize },
/// Empty collection
#[error("Cannot perform operation on empty collection")]
EmptyCollection,
/// Search failed
#[error("Search failed: {0}")]
SearchFailed(String),
}
/// Result type for hyperbolic operations
pub type HyperbolicResult<T> = Result<T, HyperbolicError>;


@@ -0,0 +1,650 @@
//! HNSW Adapter with Hyperbolic Distance Support
//!
//! This module provides HNSW (Hierarchical Navigable Small World) graph
//! implementation optimized for hyperbolic space using the Poincaré ball model.
//!
//! # Key Features
//!
//! - Hyperbolic distance metric for neighbor selection
//! - Tangent space pruning for accelerated search
//! - Configurable curvature per index
//! - Dual-space search (Euclidean fallback)
use crate::error::{HyperbolicError, HyperbolicResult};
use crate::poincare::{fused_norms, norm_squared, poincare_distance, poincare_distance_from_norms, project_to_ball, EPS};
use crate::tangent::TangentCache;
use serde::{Deserialize, Serialize};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Distance metric type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
/// Poincaré ball hyperbolic distance
Poincare,
/// Standard Euclidean distance
Euclidean,
/// Cosine similarity (converted to distance)
Cosine,
/// Hybrid: Euclidean for pruning, Poincaré for ranking
Hybrid,
}
/// HNSW configuration for hyperbolic space
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperbolicHnswConfig {
/// Maximum number of connections per node (M parameter)
pub max_connections: usize,
/// Maximum connections for layer 0 (M0 = 2*M typically)
pub max_connections_0: usize,
/// Size of dynamic candidate list during construction (ef_construction)
pub ef_construction: usize,
/// Size of dynamic candidate list during search (ef)
pub ef_search: usize,
/// Level multiplier for layer selection (ml = 1/ln(M))
pub level_mult: f32,
/// Curvature parameter for Poincaré ball
pub curvature: f32,
/// Distance metric
pub metric: DistanceMetric,
/// Pruning factor for tangent space optimization
pub prune_factor: usize,
/// Whether to use tangent space pruning
pub use_tangent_pruning: bool,
}
impl Default for HyperbolicHnswConfig {
fn default() -> Self {
Self {
max_connections: 16,
max_connections_0: 32,
ef_construction: 200,
ef_search: 50,
level_mult: 1.0 / (16.0_f32).ln(),
curvature: 1.0,
metric: DistanceMetric::Poincare,
prune_factor: 10,
use_tangent_pruning: true,
}
}
}
/// A node in the HNSW graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswNode {
/// Node ID
pub id: usize,
/// Vector in Poincaré ball
pub vector: Vec<f32>,
/// Connections at each level (level -> neighbor ids)
pub connections: Vec<Vec<usize>>,
/// Maximum level this node appears in
pub level: usize,
}
impl HnswNode {
pub fn new(id: usize, vector: Vec<f32>, max_level: usize) -> Self {
let connections = (0..=max_level).map(|_| Vec::new()).collect();
Self {
id,
vector,
connections,
level: max_level,
}
}
}
/// Search result with distance
#[derive(Debug, Clone)]
pub struct SearchResult {
pub id: usize,
pub distance: f32,
}
impl PartialEq for SearchResult {
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
impl Eq for SearchResult {}
impl PartialOrd for SearchResult {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SearchResult {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// total_cmp avoids the NaN panic that partial_cmp().unwrap() allows
self.distance.total_cmp(&other.distance)
}
}
/// Hyperbolic HNSW Index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperbolicHnsw {
/// Configuration
pub config: HyperbolicHnswConfig,
/// All nodes in the graph
nodes: Vec<HnswNode>,
/// Entry point node ID
entry_point: Option<usize>,
/// Maximum level in the graph
max_level: usize,
/// Tangent cache for pruning (not serialized)
#[serde(skip)]
tangent_cache: Option<TangentCache>,
}
impl HyperbolicHnsw {
/// Create a new empty HNSW index
pub fn new(config: HyperbolicHnswConfig) -> Self {
Self {
config,
nodes: Vec::new(),
entry_point: None,
max_level: 0,
tangent_cache: None,
}
}
/// Create with default configuration
pub fn default_config() -> Self {
Self::new(HyperbolicHnswConfig::default())
}
/// Get the number of nodes in the index
pub fn len(&self) -> usize {
self.nodes.len()
}
/// Check if the index is empty
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
/// Get the dimension of vectors
pub fn dim(&self) -> Option<usize> {
self.nodes.first().map(|n| n.vector.len())
}
/// Compute distance between two vectors (optimized with fused norms)
#[inline]
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
match self.config.metric {
DistanceMetric::Poincare | DistanceMetric::Hybrid => {
// Use fused_norms for single-pass computation
let (diff_sq, norm_a_sq, norm_b_sq) = fused_norms(a, b);
poincare_distance_from_norms(diff_sq, norm_a_sq, norm_b_sq, self.config.curvature)
}
DistanceMetric::Euclidean => {
let (diff_sq, _, _) = fused_norms(a, b);
diff_sq.sqrt()
}
DistanceMetric::Cosine => {
let len = a.len().min(b.len());
let mut dot_ab = 0.0f32;
let mut norm_a_sq = 0.0f32;
let mut norm_b_sq = 0.0f32;
// Fused computation
for i in 0..len {
let ai = a[i];
let bi = b[i];
dot_ab += ai * bi;
norm_a_sq += ai * ai;
norm_b_sq += bi * bi;
}
let norm_prod = (norm_a_sq * norm_b_sq).sqrt();
1.0 - dot_ab / (norm_prod + EPS)
}
}
}
/// Compute distance with pre-computed query norm (for batch search)
#[inline]
fn distance_with_query_norm(&self, query: &[f32], query_norm_sq: f32, point: &[f32]) -> f32 {
match self.config.metric {
DistanceMetric::Poincare | DistanceMetric::Hybrid => {
let (diff_sq, _, point_norm_sq) = fused_norms(query, point);
poincare_distance_from_norms(diff_sq, query_norm_sq, point_norm_sq, self.config.curvature)
}
_ => self.distance(query, point)
}
}
/// Generate random level for a new node
fn random_level(&self) -> usize {
// Guard against r == 0.0, for which -ln(r) is infinite
let r: f32 = rand::random::<f32>().max(f32::MIN_POSITIVE);
(-r.ln() * self.config.level_mult) as usize
}
/// Insert a vector into the index
pub fn insert(&mut self, vector: Vec<f32>) -> HyperbolicResult<usize> {
// Project to ball for safety
let vector = project_to_ball(&vector, self.config.curvature, EPS);
let id = self.nodes.len();
let level = self.random_level();
// Create new node
let node = HnswNode::new(id, vector.clone(), level);
self.nodes.push(node);
if self.entry_point.is_none() {
self.entry_point = Some(id);
self.max_level = level;
return Ok(id);
}
let entry_id = self.entry_point.unwrap();
// Search for entry point at top levels
let mut current = entry_id;
for l in (level + 1..=self.max_level).rev() {
current = self.search_layer_single(&vector, current, l)?;
}
// Insert at levels [0, min(level, max_level)]
let insert_level = level.min(self.max_level);
for l in (0..=insert_level).rev() {
let neighbors = self.search_layer(&vector, current, self.config.ef_construction, l)?;
// Select best neighbors
let max_conn = if l == 0 {
self.config.max_connections_0
} else {
self.config.max_connections
};
let selected: Vec<usize> = neighbors.iter().take(max_conn).map(|r| r.id).collect();
// Add bidirectional connections
self.nodes[id].connections[l] = selected.clone();
for &neighbor_id in &selected {
self.nodes[neighbor_id].connections[l].push(id);
// Prune if too many connections
if self.nodes[neighbor_id].connections[l].len() > max_conn {
self.prune_connections(neighbor_id, l, max_conn)?;
}
}
if !neighbors.is_empty() {
current = neighbors[0].id;
}
}
// Update entry point if new node has higher level
if level > self.max_level {
self.entry_point = Some(id);
self.max_level = level;
}
// Invalidate tangent cache
self.tangent_cache = None;
Ok(id)
}
/// Insert batch of vectors
pub fn insert_batch(&mut self, vectors: Vec<Vec<f32>>) -> HyperbolicResult<Vec<usize>> {
let mut ids = Vec::with_capacity(vectors.len());
for vector in vectors {
ids.push(self.insert(vector)?);
}
Ok(ids)
}
/// Search for single nearest neighbor at a layer (greedy)
fn search_layer_single(&self, query: &[f32], entry: usize, level: usize) -> HyperbolicResult<usize> {
let mut current = entry;
let mut current_dist = self.distance(query, &self.nodes[current].vector);
loop {
let mut changed = false;
for &neighbor in &self.nodes[current].connections[level] {
let dist = self.distance(query, &self.nodes[neighbor].vector);
if dist < current_dist {
current_dist = dist;
current = neighbor;
changed = true;
}
}
if !changed {
break;
}
}
Ok(current)
}
/// Search layer with ef candidates
fn search_layer(
&self,
query: &[f32],
entry: usize,
ef: usize,
level: usize,
) -> HyperbolicResult<Vec<SearchResult>> {
use std::collections::{BinaryHeap, HashSet};
let entry_dist = self.distance(query, &self.nodes[entry].vector);
let mut visited = HashSet::new();
visited.insert(entry);
// Candidates (min-heap by distance)
let mut candidates: BinaryHeap<std::cmp::Reverse<SearchResult>> = BinaryHeap::new();
candidates.push(std::cmp::Reverse(SearchResult {
id: entry,
distance: entry_dist,
}));
// Results (max-heap by distance for easy pruning)
let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
results.push(SearchResult {
id: entry,
distance: entry_dist,
});
while let Some(std::cmp::Reverse(current)) = candidates.pop() {
// Check if we can stop early
if let Some(furthest) = results.peek() {
if current.distance > furthest.distance && results.len() >= ef {
break;
}
}
// Explore neighbors
for &neighbor in &self.nodes[current.id].connections[level] {
if visited.contains(&neighbor) {
continue;
}
visited.insert(neighbor);
let dist = self.distance(query, &self.nodes[neighbor].vector);
let should_add = results.len() < ef
|| results
.peek()
.map(|r| dist < r.distance)
.unwrap_or(true);
if should_add {
candidates.push(std::cmp::Reverse(SearchResult {
id: neighbor,
distance: dist,
}));
results.push(SearchResult {
id: neighbor,
distance: dist,
});
if results.len() > ef {
results.pop();
}
}
}
}
let mut result_vec: Vec<SearchResult> = results.into_iter().collect();
result_vec.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
Ok(result_vec)
}
/// Prune connections to keep only the best
fn prune_connections(
&mut self,
node_id: usize,
level: usize,
max_conn: usize,
) -> HyperbolicResult<()> {
let node_vector = self.nodes[node_id].vector.clone();
let connections = &self.nodes[node_id].connections[level];
let mut scored: Vec<(usize, f32)> = connections
.iter()
.map(|&id| (id, self.distance(&node_vector, &self.nodes[id].vector)))
.collect();
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
self.nodes[node_id].connections[level] =
scored.into_iter().take(max_conn).map(|(id, _)| id).collect();
Ok(())
}
/// Search for k nearest neighbors
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
if self.is_empty() {
return Ok(Vec::new());
}
let query = project_to_ball(query, self.config.curvature, EPS);
let entry = self.entry_point.unwrap();
// Navigate to lowest level from top
let mut current = entry;
for l in (1..=self.max_level).rev() {
current = self.search_layer_single(&query, current, l)?;
}
// Search at layer 0 with ef_search candidates
let ef = self.config.ef_search.max(k);
let mut results = self.search_layer(&query, current, ef, 0)?;
results.truncate(k);
Ok(results)
}
/// Search with tangent space pruning (optimized for hyperbolic)
pub fn search_with_pruning(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
// Fall back to regular search if no tangent cache
if self.tangent_cache.is_none() || !self.config.use_tangent_pruning {
return self.search(query, k);
}
let cache = self.tangent_cache.as_ref().unwrap();
let query = project_to_ball(query, self.config.curvature, EPS);
// Phase 1: Fast tangent space filtering
let query_tangent = cache.query_tangent(&query);
let mut candidates: Vec<(usize, f32)> = (0..cache.len())
.map(|i| {
let tangent_dist = cache.tangent_distance_squared(&query_tangent, i);
(cache.point_indices[i], tangent_dist)
})
.collect();
// Sort by tangent distance
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// Keep top prune_factor * k candidates
let num_candidates = (k * self.config.prune_factor).min(candidates.len());
candidates.truncate(num_candidates);
// Phase 2: Exact Poincaré distance for finalists
let mut results: Vec<SearchResult> = candidates
.into_iter()
.map(|(id, _)| {
let dist = self.distance(&query, &self.nodes[id].vector);
SearchResult { id, distance: dist }
})
.collect();
results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
results.truncate(k);
Ok(results)
}
/// Build tangent cache for all points
pub fn build_tangent_cache(&mut self) -> HyperbolicResult<()> {
if self.is_empty() {
return Ok(());
}
let vectors: Vec<Vec<f32>> = self.nodes.iter().map(|n| n.vector.clone()).collect();
let indices: Vec<usize> = (0..self.nodes.len()).collect();
self.tangent_cache = Some(TangentCache::new(&vectors, &indices, self.config.curvature)?);
Ok(())
}
/// Get a reference to a node's vector
pub fn get_vector(&self, id: usize) -> Option<&[f32]> {
self.nodes.get(id).map(|n| n.vector.as_slice())
}
/// Update curvature and rebuild tangent cache
pub fn set_curvature(&mut self, curvature: f32) -> HyperbolicResult<()> {
if curvature <= 0.0 {
return Err(HyperbolicError::InvalidCurvature(curvature));
}
self.config.curvature = curvature;
// Reproject all vectors
for node in &mut self.nodes {
node.vector = project_to_ball(&node.vector, curvature, EPS);
}
// Rebuild tangent cache
if self.tangent_cache.is_some() {
self.build_tangent_cache()?;
}
Ok(())
}
/// Get all vectors as a slice
pub fn vectors(&self) -> Vec<&[f32]> {
self.nodes.iter().map(|n| n.vector.as_slice()).collect()
}
}
/// Dual-space index for fallback and mutual ranking fusion
#[derive(Debug)]
pub struct DualSpaceIndex {
/// Hyperbolic index (primary)
pub hyperbolic: HyperbolicHnsw,
/// Euclidean index (fallback)
pub euclidean: HyperbolicHnsw,
/// Fusion weight for hyperbolic results (0-1)
pub fusion_weight: f32,
}
impl DualSpaceIndex {
/// Create a new dual-space index
pub fn new(curvature: f32, fusion_weight: f32) -> Self {
let mut hyp_config = HyperbolicHnswConfig::default();
hyp_config.curvature = curvature;
hyp_config.metric = DistanceMetric::Poincare;
let mut euc_config = HyperbolicHnswConfig::default();
euc_config.metric = DistanceMetric::Euclidean;
Self {
hyperbolic: HyperbolicHnsw::new(hyp_config),
euclidean: HyperbolicHnsw::new(euc_config),
fusion_weight: fusion_weight.clamp(0.0, 1.0),
}
}
/// Insert into both indices
pub fn insert(&mut self, vector: Vec<f32>) -> HyperbolicResult<usize> {
self.euclidean.insert(vector.clone())?;
self.hyperbolic.insert(vector)
}
/// Search with mutual ranking fusion
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
let hyp_results = self.hyperbolic.search(query, k * 2)?;
let euc_results = self.euclidean.search(query, k * 2)?;
// Combine and re-rank using fusion
use std::collections::HashMap;
let mut scores: HashMap<usize, f32> = HashMap::new();
// Add hyperbolic scores
for (rank, r) in hyp_results.iter().enumerate() {
let score = self.fusion_weight * (1.0 / (rank as f32 + 1.0));
*scores.entry(r.id).or_insert(0.0) += score;
}
// Add Euclidean scores
for (rank, r) in euc_results.iter().enumerate() {
let score = (1.0 - self.fusion_weight) * (1.0 / (rank as f32 + 1.0));
*scores.entry(r.id).or_insert(0.0) += score;
}
// Sort by combined score (higher is better)
let mut combined: Vec<(usize, f32)> = scores.into_iter().collect();
combined.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
// Return top k with hyperbolic distances
Ok(combined
.into_iter()
.take(k)
.map(|(id, _)| {
let dist = self.hyperbolic.distance(
query,
self.hyperbolic.get_vector(id).unwrap_or(&[]),
);
SearchResult { id, distance: dist }
})
.collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hnsw_insert_search() {
let mut hnsw = HyperbolicHnsw::default_config();
// Insert some vectors
for i in 0..10 {
let v = vec![0.1 * i as f32, 0.05 * i as f32];
hnsw.insert(v).unwrap();
}
assert_eq!(hnsw.len(), 10);
// Search
let query = vec![0.3, 0.15];
let results = hnsw.search(&query, 3).unwrap();
assert_eq!(results.len(), 3);
assert!(results[0].distance <= results[1].distance);
}
#[test]
fn test_dual_space() {
let mut dual = DualSpaceIndex::new(1.0, 0.5);
for i in 0..10 {
let v = vec![0.1 * i as f32, 0.05 * i as f32];
dual.insert(v).unwrap();
}
let query = vec![0.3, 0.15];
let results = dual.search(&query, 3).unwrap();
assert_eq!(results.len(), 3);
}
}


@@ -0,0 +1,210 @@
//! Hyperbolic Embeddings with HNSW Integration for RuVector
//!
//! This crate provides hyperbolic (Poincaré ball) embeddings integrated with
//! HNSW (Hierarchical Navigable Small World) graphs for hierarchy-aware
//! vector search.
//!
//! # Overview
//!
//! Hierarchies compress naturally in hyperbolic space. Taxonomies, catalogs,
//! ICD trees, product facets, org charts, and long-tail tags all fit better
//! than in Euclidean space, which means higher recall on deep leaves without
//! blowing up memory or latency.
//!
//! # Key Features
//!
//! - **Poincaré Ball Model**: Store vectors in the Poincaré ball with proper
//! geometric operations (Möbius addition, exp/log maps)
//! - **Tangent Space Pruning**: Prune HNSW candidates with cheap Euclidean
//! distance in tangent space before exact hyperbolic ranking
//! - **Per-Shard Curvature**: Different parts of the hierarchy can have
//! different optimal curvatures
//! - **Dual-Space Index**: Keep a synchronized Euclidean index for fallback
//! and mutual ranking fusion
//!
//! # Quick Start
//!
//! ```rust
//! use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig};
//!
//! // Create index with default settings
//! let mut index = HyperbolicHnsw::default_config();
//!
//! // Insert vectors (automatically projected to Poincaré ball)
//! index.insert(vec![0.1, 0.2, 0.3]).unwrap();
//! index.insert(vec![-0.1, 0.15, 0.25]).unwrap();
//! index.insert(vec![0.2, -0.1, 0.1]).unwrap();
//!
//! // Search for nearest neighbors
//! let results = index.search(&[0.15, 0.1, 0.2], 2).unwrap();
//! for r in results {
//! println!("ID: {}, Distance: {:.4}", r.id, r.distance);
//! }
//! ```
//!
//! # HNSW Speed Trick
//!
//! The core optimization is:
//! 1. Precompute `u = log_c(x)` at a shard centroid `c`
//! 2. During neighbor selection, use Euclidean `||u_q - u_p||` to prune
//! 3. Run exact Poincaré distance only on top N candidates before final ranking
//!
//! ```rust
//! use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig};
//!
//! let mut config = HyperbolicHnswConfig::default();
//! config.use_tangent_pruning = true;
//! config.prune_factor = 10; // Consider 10x candidates in tangent space
//!
//! let mut index = HyperbolicHnsw::new(config);
//! // ... insert vectors ...
//!
//! // Build tangent cache for pruning optimization
//! # index.insert(vec![0.1, 0.2]).unwrap();
//! index.build_tangent_cache().unwrap();
//!
//! // Search with pruning
//! let results = index.search_with_pruning(&[0.1, 0.15], 5).unwrap();
//! ```
//!
//! # Sharded Index with Per-Shard Curvature
//!
//! ```rust
//! use ruvector_hyperbolic_hnsw::{ShardedHyperbolicHnsw, ShardStrategy};
//!
//! let mut manager = ShardedHyperbolicHnsw::new(1.0);
//!
//! // Insert with hierarchy depth information
//! manager.insert(vec![0.1, 0.2], Some(0)).unwrap(); // Root level
//! manager.insert(vec![0.3, 0.1], Some(3)).unwrap(); // Deeper level
//!
//! // Update curvature for specific shard
//! manager.update_curvature("radius_1", 0.5).unwrap();
//!
//! // Search across all shards
//! let results = manager.search(&[0.2, 0.15], 5).unwrap();
//! ```
//!
//! # Mathematical Operations
//!
//! The `poincare` module provides low-level hyperbolic geometry operations:
//!
//! ```rust
//! use ruvector_hyperbolic_hnsw::poincare::{
//! mobius_add, exp_map, log_map, poincare_distance, project_to_ball
//! };
//!
//! let x = vec![0.3, 0.2];
//! let y = vec![-0.1, 0.4];
//! let c = 1.0; // Curvature
//!
//! // Möbius addition (hyperbolic vector addition)
//! let z = mobius_add(&x, &y, c);
//!
//! // Geodesic distance in hyperbolic space
//! let d = poincare_distance(&x, &y, c);
//!
//! // Map to tangent space at x
//! let v = log_map(&y, &x, c);
//!
//! // Map back to manifold
//! let y_recovered = exp_map(&v, &x, c);
//! ```
//!
//! # Numerical Stability
//!
//! All operations include numerical safeguards:
//! - Norm clamping with `eps = 1e-5`
//! - Projection after every update
//! - Stable `acosh` and `log1p` implementations
//!
//! # Feature Flags
//!
//! - `simd`: Enable SIMD acceleration (default)
//! - `parallel`: Enable parallel processing with rayon (default)
//! - `wasm`: Enable WebAssembly compatibility
pub mod error;
pub mod hnsw;
pub mod poincare;
pub mod shard;
pub mod tangent;
// Re-exports
pub use error::{HyperbolicError, HyperbolicResult};
pub use hnsw::{
DistanceMetric, DualSpaceIndex, HnswNode, HyperbolicHnsw, HyperbolicHnswConfig, SearchResult,
};
pub use poincare::{
conformal_factor, conformal_factor_from_norm_sq, dot, exp_map, frechet_mean, fused_norms,
hyperbolic_midpoint, log_map, log_map_at_centroid, mobius_add, mobius_add_inplace,
mobius_scalar_mult, norm, norm_squared, parallel_transport, poincare_distance,
poincare_distance_batch, poincare_distance_from_norms, poincare_distance_squared,
project_to_ball, project_to_ball_inplace, PoincareConfig, DEFAULT_CURVATURE, EPS,
};
pub use shard::{
CurvatureRegistry, HierarchyMetrics, HyperbolicShard, ShardCurvature, ShardStrategy,
ShardedHyperbolicHnsw,
};
pub use tangent::{tangent_micro_update, PrunedCandidate, TangentCache, TangentPruner};
/// Library version
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Prelude for common imports
pub mod prelude {
pub use crate::error::{HyperbolicError, HyperbolicResult};
pub use crate::hnsw::{HyperbolicHnsw, HyperbolicHnswConfig, SearchResult};
pub use crate::poincare::{exp_map, log_map, mobius_add, poincare_distance, project_to_ball};
pub use crate::shard::{ShardedHyperbolicHnsw, ShardStrategy};
pub use crate::tangent::{TangentCache, TangentPruner};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_workflow() {
// Create index
let mut index = HyperbolicHnsw::default_config();
// Insert vectors
for i in 0..10 {
let v = vec![0.1 * i as f32, 0.05 * i as f32, 0.02 * i as f32];
index.insert(v).unwrap();
}
// Search
let query = vec![0.35, 0.175, 0.07];
let results = index.search(&query, 3).unwrap();
assert_eq!(results.len(), 3);
// Results should be sorted by distance
for i in 1..results.len() {
assert!(results[i - 1].distance <= results[i].distance);
}
}
#[test]
fn test_hierarchy_preservation() {
// Create points at different "depths"
let points: Vec<Vec<f32>> = (0..20)
.map(|i| {
// Points further from origin represent deeper hierarchy
let depth = i / 4;
let radius = 0.1 + 0.15 * depth as f32;
let angle = (i % 4) as f32 * std::f32::consts::PI / 2.0;
vec![radius * angle.cos(), radius * angle.sin()]
})
.collect();
let depths: Vec<usize> = (0..20).map(|i| i / 4).collect();
// Compute metrics
let metrics = HierarchyMetrics::compute(&points, &depths, 1.0).unwrap();
// Radius should correlate positively with depth
assert!(metrics.radius_depth_correlation > 0.5);
}
}


@@ -0,0 +1,627 @@
//! Poincaré Ball Model Operations for Hyperbolic Geometry
//!
//! This module implements core operations in the Poincaré ball model of hyperbolic space,
//! providing mathematically correct implementations with numerical stability guarantees.
//!
//! # Mathematical Background
//!
//! The Poincaré ball model represents hyperbolic space as the interior of a ball
//! of radius 1/√c in Euclidean space, where c > 0 parameterizes the sectional
//! curvature -c. Points are constrained to satisfy ||x|| < 1/√c.
//!
//! # Key Operations
//!
//! - **Möbius Addition**: The hyperbolic analog of vector addition
//! - **Exponential Map**: Maps tangent vectors to the manifold
//! - **Logarithmic Map**: Maps manifold points to tangent space
//! - **Poincaré Distance**: The geodesic distance in hyperbolic space
use crate::error::{HyperbolicError, HyperbolicResult};
use serde::{Deserialize, Serialize};
/// Small epsilon for numerical stability (as specified: eps=1e-5)
pub const EPS: f32 = 1e-5;
/// Default curvature parameter (negative curvature, c > 0)
pub const DEFAULT_CURVATURE: f32 = 1.0;
/// Configuration for Poincaré ball operations
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct PoincareConfig {
/// Curvature parameter (c > 0 for hyperbolic space)
pub curvature: f32,
/// Numerical stability epsilon
pub eps: f32,
/// Maximum iterations for iterative algorithms (e.g., Fréchet mean)
pub max_iter: usize,
/// Convergence tolerance
pub tol: f32,
}
impl Default for PoincareConfig {
fn default() -> Self {
Self {
curvature: DEFAULT_CURVATURE,
eps: EPS,
max_iter: 100,
tol: 1e-6,
}
}
}
impl PoincareConfig {
/// Create configuration with custom curvature
pub fn with_curvature(curvature: f32) -> HyperbolicResult<Self> {
if curvature <= 0.0 {
return Err(HyperbolicError::InvalidCurvature(curvature));
}
Ok(Self {
curvature,
..Default::default()
})
}
/// Maximum allowed norm for points in the ball
#[inline]
pub fn max_norm(&self) -> f32 {
(1.0 / self.curvature.sqrt()) - self.eps
}
}
// ============================================================================
// Optimized Core Operations (SIMD-friendly)
// ============================================================================
/// Compute the squared Euclidean norm of a slice (optimized with unrolling)
#[inline]
pub fn norm_squared(x: &[f32]) -> f32 {
let len = x.len();
let mut sum = 0.0f32;
// Process 4 elements at a time for better SIMD utilization
let chunks = len / 4;
let remainder = len % 4;
let mut i = 0;
for _ in 0..chunks {
let a = x[i];
let b = x[i + 1];
let c = x[i + 2];
let d = x[i + 3];
sum += a * a + b * b + c * c + d * d;
i += 4;
}
// Handle remainder
for j in 0..remainder {
let v = x[i + j];
sum += v * v;
}
sum
}
/// Compute the Euclidean norm of a slice
#[inline]
pub fn norm(x: &[f32]) -> f32 {
norm_squared(x).sqrt()
}
/// Compute the dot product of two slices (optimized with unrolling)
#[inline]
pub fn dot(x: &[f32], y: &[f32]) -> f32 {
let len = x.len().min(y.len());
let mut sum = 0.0f32;
// Process 4 elements at a time
let chunks = len / 4;
let remainder = len % 4;
let mut i = 0;
for _ in 0..chunks {
sum += x[i] * y[i] + x[i+1] * y[i+1] + x[i+2] * y[i+2] + x[i+3] * y[i+3];
i += 4;
}
for j in 0..remainder {
sum += x[i + j] * y[i + j];
}
sum
}
/// Fused computation of ||u-v||², ||u||², ||v||² in a single pass (one data traversal instead of three)
#[inline]
pub fn fused_norms(u: &[f32], v: &[f32]) -> (f32, f32, f32) {
let len = u.len().min(v.len());
let mut diff_sq = 0.0f32;
let mut norm_u_sq = 0.0f32;
let mut norm_v_sq = 0.0f32;
// Process 4 elements at a time
let chunks = len / 4;
let remainder = len % 4;
let mut i = 0;
for _ in 0..chunks {
let (u0, u1, u2, u3) = (u[i], u[i+1], u[i+2], u[i+3]);
let (v0, v1, v2, v3) = (v[i], v[i+1], v[i+2], v[i+3]);
let (d0, d1, d2, d3) = (u0 - v0, u1 - v1, u2 - v2, u3 - v3);
diff_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
norm_u_sq += u0 * u0 + u1 * u1 + u2 * u2 + u3 * u3;
norm_v_sq += v0 * v0 + v1 * v1 + v2 * v2 + v3 * v3;
i += 4;
}
for j in 0..remainder {
let ui = u[i + j];
let vi = v[i + j];
let di = ui - vi;
diff_sq += di * di;
norm_u_sq += ui * ui;
norm_v_sq += vi * vi;
}
(diff_sq, norm_u_sq, norm_v_sq)
}
/// Project a point back into the Poincaré ball
///
/// Ensures ||x|| < 1/√c - eps for numerical stability
#[inline]
pub fn project_to_ball(x: &[f32], c: f32, eps: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let norm_sq = norm_squared(x);
let max_norm = (1.0 / c.sqrt()) - eps;
let max_norm_sq = max_norm * max_norm;
if norm_sq < max_norm_sq || norm_sq < eps * eps {
x.to_vec()
} else {
let scale = max_norm / norm_sq.sqrt();
x.iter().map(|&xi| scale * xi).collect()
}
}
/// Project in-place (avoids allocation when possible)
#[inline]
pub fn project_to_ball_inplace(x: &mut [f32], c: f32, eps: f32) {
let c = c.abs().max(EPS);
let norm_sq = norm_squared(x);
let max_norm = (1.0 / c.sqrt()) - eps;
let max_norm_sq = max_norm * max_norm;
if norm_sq >= max_norm_sq && norm_sq >= eps * eps {
let scale = max_norm / norm_sq.sqrt();
for xi in x.iter_mut() {
*xi *= scale;
}
}
}
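The clamping rule above can be sketched standalone (same formula as `project_to_ball`; the helper name `project` and the `main` driver are illustrative only):

```rust
// Minimal sketch of the projection rule: points outside the ball are pulled
// back to radius 1/sqrt(c) - eps, points inside pass through unchanged.
fn project(x: &[f32], c: f32, eps: f32) -> Vec<f32> {
    let norm = x.iter().map(|v| v * v).sum::<f32>().sqrt();
    let max_norm = 1.0 / c.sqrt() - eps;
    if norm <= max_norm {
        x.to_vec()
    } else {
        x.iter().map(|v| v * max_norm / norm).collect()
    }
}

fn main() {
    // A point far outside the ball lands exactly on radius 1/sqrt(c) - eps.
    let p = project(&[3.0, 4.0], 1.0, 1e-5);
    let norm = p.iter().map(|v| v * v).sum::<f32>().sqrt();
    assert!(norm < 1.0);
    // A point already inside the ball is returned unchanged.
    assert_eq!(project(&[0.1, 0.2], 1.0, 1e-5), vec![0.1, 0.2]);
    println!("projected norm = {norm}");
}
```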
/// Compute the conformal factor λ_x at point x
///
/// λ_x = 2 / (1 - c||x||²)
#[inline]
pub fn conformal_factor(x: &[f32], c: f32) -> f32 {
let norm_sq = norm_squared(x);
2.0 / (1.0 - c * norm_sq).max(EPS)
}
/// Conformal factor from pre-computed norm squared
#[inline]
pub fn conformal_factor_from_norm_sq(norm_sq: f32, c: f32) -> f32 {
2.0 / (1.0 - c * norm_sq).max(EPS)
}
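A quick standalone check of what the conformal factor encodes: it is exactly 2 at the origin and diverges toward the boundary, which is why equal Euclidean steps near the rim cover large hyperbolic distances (the `lambda` helper re-implements the formula above for illustration):

```rust
// λ_x = 2 / (1 - c‖x‖²): equals 2 at the origin, diverges toward the boundary.
fn lambda(x: &[f32], c: f32) -> f32 {
    let norm_sq: f32 = x.iter().map(|v| v * v).sum();
    2.0 / (1.0 - c * norm_sq)
}

fn main() {
    assert_eq!(lambda(&[0.0, 0.0], 1.0), 2.0);
    // The same Euclidean step near the boundary covers far more hyperbolic
    // distance, which is why deep hierarchy levels fit near the rim.
    assert!(lambda(&[0.9, 0.0], 1.0) > lambda(&[0.1, 0.0], 1.0));
    println!("λ at 0.9 = {}", lambda(&[0.9, 0.0], 1.0));
}
```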
// ============================================================================
// Poincaré Distance (Optimized)
// ============================================================================
/// Poincaré distance between two points (optimized with fused norms)
///
/// Uses the formula:
/// d(u, v) = (1/√c) acosh(1 + 2c ||u - v||² / ((1 - c||u||²)(1 - c||v||²)))
#[inline]
pub fn poincare_distance(u: &[f32], v: &[f32], c: f32) -> f32 {
let c = c.abs().max(EPS);
// Fused computation: single pass for all three norms
let (diff_sq, norm_u_sq, norm_v_sq) = fused_norms(u, v);
poincare_distance_from_norms(diff_sq, norm_u_sq, norm_v_sq, c)
}
/// Poincaré distance from pre-computed norms (for batch operations)
#[inline]
pub fn poincare_distance_from_norms(diff_sq: f32, norm_u_sq: f32, norm_v_sq: f32, c: f32) -> f32 {
let sqrt_c = c.sqrt();
let lambda_u = (1.0 - c * norm_u_sq).max(EPS);
let lambda_v = (1.0 - c * norm_v_sq).max(EPS);
let numerator = 2.0 * c * diff_sq;
let denominator = lambda_u * lambda_v;
let arg = 1.0 + numerator / denominator;
if arg <= 1.0 {
return 0.0;
}
// Stable acosh computation
(1.0 / sqrt_c) * fast_acosh(arg)
}
/// Fast acosh with numerical stability
#[inline]
fn fast_acosh(x: f32) -> f32 {
if x <= 1.0 {
return 0.0;
}
let delta = x - 1.0;
if delta < 1e-4 {
// Taylor expansion for small delta: acosh(1+δ) ≈ √(2δ)
(2.0 * delta).sqrt()
} else if x < 1e6 {
// Standard formula: acosh(x) = ln(x + √(x²-1))
(x + (x * x - 1.0).sqrt()).ln()
} else {
// For very large x: acosh(x) ≈ ln(2x)
(2.0 * x).ln()
}
}
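Since `fast_acosh` is private to this module, here is a standalone re-implementation of the same three regimes, checked against `f64::acosh` as a reference:

```rust
// The three regimes of the stabilized acosh, re-implemented for illustration.
fn stable_acosh(x: f32) -> f32 {
    if x <= 1.0 {
        0.0
    } else if x - 1.0 < 1e-4 {
        (2.0 * (x - 1.0)).sqrt() // Taylor: acosh(1+δ) ≈ √(2δ)
    } else if x < 1e6 {
        (x + (x * x - 1.0).sqrt()).ln() // standard formula
    } else {
        (2.0 * x).ln() // avoids overflow in x² for huge x
    }
}

fn main() {
    // Each branch stays within 1% of the f64 reference.
    for &x in &[1.000_01f32, 2.0, 1e7] {
        let reference = (x as f64).acosh() as f32;
        assert!((stable_acosh(x) - reference).abs() <= 0.01 * reference);
    }
    // The naive formula loses all precision this close to 1; Taylor does not.
    assert!(stable_acosh(1.000_000_1) >= 0.0);
}
```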
/// Squared Poincaré distance (faster for comparisons)
#[inline]
pub fn poincare_distance_squared(u: &[f32], v: &[f32], c: f32) -> f32 {
let d = poincare_distance(u, v, c);
d * d
}
/// Batch distance computation (processes multiple pairs efficiently)
pub fn poincare_distance_batch(
query: &[f32],
points: &[&[f32]],
c: f32,
) -> Vec<f32> {
let c = c.abs().max(EPS);
let query_norm_sq = norm_squared(query);
points
.iter()
.map(|point| {
let (diff_sq, _, point_norm_sq) = fused_norms(query, point);
poincare_distance_from_norms(diff_sq, query_norm_sq, point_norm_sq, c)
})
.collect()
}
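The distance formula can be verified standalone against the closed form d(0, v) = (2/√c)·arctanh(√c‖v‖) that holds when one argument is the origin (the `poincare_dist` helper re-implements the formula for illustration):

```rust
// Poincaré distance, checked against its closed form at the origin.
fn poincare_dist(u: &[f32], v: &[f32], c: f32) -> f32 {
    let diff_sq: f32 = u.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();
    let nu: f32 = u.iter().map(|a| a * a).sum();
    let nv: f32 = v.iter().map(|a| a * a).sum();
    let arg = 1.0 + 2.0 * c * diff_sq / ((1.0 - c * nu) * (1.0 - c * nv));
    (1.0 / c.sqrt()) * ((arg as f64).acosh() as f32)
}

fn main() {
    let origin = [0.0f32, 0.0];
    let v = [0.3f32, 0.4]; // ‖v‖ = 0.5
    let d = poincare_dist(&origin, &v, 1.0);
    // Closed form: d(0, v) = 2·arctanh(0.5) = ln(3)
    let closed_form = 2.0 * 0.5f32.atanh();
    assert!((d - closed_form).abs() < 1e-5);
    // Symmetry: d(u, v) = d(v, u)
    assert!((d - poincare_dist(&v, &origin, 1.0)).abs() < 1e-6);
    println!("d(0, v) = {d:.4}");
}
```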
// ============================================================================
// Möbius Operations (Optimized)
// ============================================================================
/// Möbius addition in the Poincaré ball (optimized)
///
/// x ⊕_c y = ((1 + 2c⟨x,y⟩ + c||y||²)x + (1 - c||x||²)y) / (1 + 2c⟨x,y⟩ + c²||x||²||y||²)
#[inline]
pub fn mobius_add(x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
// Fused computation of norms and dot product
let len = x.len().min(y.len());
let mut norm_x_sq = 0.0f32;
let mut norm_y_sq = 0.0f32;
let mut dot_xy = 0.0f32;
// Process 4 elements at a time
let chunks = len / 4;
let remainder = len % 4;
let mut i = 0;
for _ in 0..chunks {
let (x0, x1, x2, x3) = (x[i], x[i+1], x[i+2], x[i+3]);
let (y0, y1, y2, y3) = (y[i], y[i+1], y[i+2], y[i+3]);
norm_x_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
norm_y_sq += y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
dot_xy += x0 * y0 + x1 * y1 + x2 * y2 + x3 * y3;
i += 4;
}
for j in 0..remainder {
let xi = x[i + j];
let yi = y[i + j];
norm_x_sq += xi * xi;
norm_y_sq += yi * yi;
dot_xy += xi * yi;
}
// Compute coefficients
let coef_x = 1.0 + 2.0 * c * dot_xy + c * norm_y_sq;
let coef_y = 1.0 - c * norm_x_sq;
let denom = (1.0 + 2.0 * c * dot_xy + c * c * norm_x_sq * norm_y_sq).max(EPS);
let inv_denom = 1.0 / denom;
// Compute result
let mut result = Vec::with_capacity(len);
for j in 0..len {
result.push((coef_x * x[j] + coef_y * y[j]) * inv_denom);
}
// Project back into ball
project_to_ball_inplace(&mut result, c, EPS);
result
}
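A standalone 2-D sketch of the same formula makes the gyrogroup structure concrete: 0 is the identity, −x is a left inverse, and the operation is non-commutative in general:

```rust
// Möbius addition in 2-D (same formula as above, projection omitted).
fn mobius_add(x: [f32; 2], y: [f32; 2], c: f32) -> [f32; 2] {
    let nx = x[0] * x[0] + x[1] * x[1];
    let ny = y[0] * y[0] + y[1] * y[1];
    let xy = x[0] * y[0] + x[1] * y[1];
    let cx = 1.0 + 2.0 * c * xy + c * ny;
    let cy = 1.0 - c * nx;
    let denom = 1.0 + 2.0 * c * xy + c * c * nx * ny;
    [(cx * x[0] + cy * y[0]) / denom, (cx * x[1] + cy * y[1]) / denom]
}

fn main() {
    let x = [0.3f32, 0.2];
    // Identity: x ⊕ 0 = x
    assert_eq!(mobius_add(x, [0.0, 0.0], 1.0), x);
    // Left inverse: (−x) ⊕ x = 0
    let z = mobius_add([-x[0], -x[1]], x, 1.0);
    assert!(z[0].abs() < 1e-6 && z[1].abs() < 1e-6);
    // Non-commutative for non-collinear points: x ⊕ y ≠ y ⊕ x
    let y = [0.1f32, -0.4];
    assert_ne!(mobius_add(x, y, 1.0), mobius_add(y, x, 1.0));
}
```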
/// Möbius addition in-place (modifies first argument)
#[inline]
pub fn mobius_add_inplace(x: &mut [f32], y: &[f32], c: f32) {
let c = c.abs().max(EPS);
let len = x.len().min(y.len());
let norm_x_sq = norm_squared(x);
let norm_y_sq = norm_squared(y);
let dot_xy = dot(x, y);
let coef_x = 1.0 + 2.0 * c * dot_xy + c * norm_y_sq;
let coef_y = 1.0 - c * norm_x_sq;
let denom = (1.0 + 2.0 * c * dot_xy + c * c * norm_x_sq * norm_y_sq).max(EPS);
let inv_denom = 1.0 / denom;
for j in 0..len {
x[j] = (coef_x * x[j] + coef_y * y[j]) * inv_denom;
}
project_to_ball_inplace(x, c, EPS);
}
/// Möbius scalar multiplication
///
/// r ⊗_c x = (1/√c) tanh(r · arctanh(√c ||x||)) · (x / ||x||)
pub fn mobius_scalar_mult(r: f32, x: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let sqrt_c = c.sqrt();
let norm_x = norm(x);
if norm_x < EPS {
return x.to_vec();
}
let arctanh_arg = (sqrt_c * norm_x).min(1.0 - EPS);
let arctanh_val = arctanh_arg.atanh();
let scale = (1.0 / sqrt_c) * (r * arctanh_val).tanh() / norm_x;
x.iter().map(|&xi| scale * xi).collect()
}
// ============================================================================
// Exp/Log Maps (Optimized)
// ============================================================================
/// Exponential map at point p
///
/// exp_p(v) = p ⊕_c (tanh(√c λ_p ||v|| / 2) · v / (√c ||v||))
pub fn exp_map(v: &[f32], p: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let sqrt_c = c.sqrt();
let norm_p_sq = norm_squared(p);
let lambda_p = conformal_factor_from_norm_sq(norm_p_sq, c);
let norm_v = norm(v);
if norm_v < EPS {
return p.to_vec();
}
let scaled_norm = sqrt_c * lambda_p * norm_v / 2.0;
let coef = scaled_norm.tanh() / (sqrt_c * norm_v);
let transported: Vec<f32> = v.iter().map(|&vi| coef * vi).collect();
mobius_add(p, &transported, c)
}
/// Logarithmic map at point p
///
/// log_p(y) = (2 / (√c λ_p)) arctanh(√c ||(-p) ⊕_c y||) · ((-p) ⊕_c y) / ||(-p) ⊕_c y||
pub fn log_map(y: &[f32], p: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let sqrt_c = c.sqrt();
// Compute -p ⊕_c y
let neg_p: Vec<f32> = p.iter().map(|&pi| -pi).collect();
let diff = mobius_add(&neg_p, y, c);
let norm_diff = norm(&diff);
if norm_diff < EPS {
return vec![0.0; y.len()];
}
let norm_p_sq = norm_squared(p);
let lambda_p = conformal_factor_from_norm_sq(norm_p_sq, c);
let arctanh_arg = (sqrt_c * norm_diff).min(1.0 - EPS);
let coef = (2.0 / (sqrt_c * lambda_p)) * arctanh_arg.atanh() / norm_diff;
diff.iter().map(|&di| coef * di).collect()
}
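At the base point p = 0 the Möbius addition drops out and λ_0 = 2, so both maps reduce to radial tanh/arctanh scalings. A minimal sketch of the exp/log round trip at the origin (assumes c = 1 and nonzero inputs; helper names are illustrative):

```rust
// exp and log at the origin: purely radial tanh / arctanh scalings.
fn exp0(v: [f32; 2], c: f32) -> [f32; 2] {
    let n = (v[0] * v[0] + v[1] * v[1]).sqrt();
    let s = (c.sqrt() * n).tanh() / (c.sqrt() * n);
    [s * v[0], s * v[1]]
}

fn log0(y: [f32; 2], c: f32) -> [f32; 2] {
    let n = (y[0] * y[0] + y[1] * y[1]).sqrt();
    let s = (c.sqrt() * n).atanh() / (c.sqrt() * n);
    [s * y[0], s * y[1]]
}

fn main() {
    let v = [0.4f32, -0.3];
    let y = exp0(v, 1.0); // lands strictly inside the ball
    assert!(y[0] * y[0] + y[1] * y[1] < 1.0);
    let back = log0(y, 1.0); // log inverts exp
    assert!((back[0] - v[0]).abs() < 1e-5 && (back[1] - v[1]).abs() < 1e-5);
}
```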
/// Logarithmic map at a shard centroid for tangent space coordinates
pub fn log_map_at_centroid(x: &[f32], centroid: &[f32], c: f32) -> Vec<f32> {
log_map(x, centroid, c)
}
// ============================================================================
// Fréchet Mean & Utilities
// ============================================================================
/// Compute the Fréchet mean (hyperbolic centroid) of points
pub fn frechet_mean(
points: &[&[f32]],
weights: Option<&[f32]>,
config: &PoincareConfig,
) -> HyperbolicResult<Vec<f32>> {
if points.is_empty() {
return Err(HyperbolicError::EmptyCollection);
}
let dim = points[0].len();
let c = config.curvature;
// Validate dimensions
for p in points.iter() {
if p.len() != dim {
return Err(HyperbolicError::DimensionMismatch {
expected: dim,
got: p.len(),
});
}
}
// Set up weights
let uniform_weights: Vec<f32>;
let w = if let Some(weights) = weights {
if weights.len() != points.len() {
return Err(HyperbolicError::DimensionMismatch {
expected: points.len(),
got: weights.len(),
});
}
weights
} else {
uniform_weights = vec![1.0 / points.len() as f32; points.len()];
&uniform_weights
};
// Initialize with Euclidean weighted mean, projected to ball
let mut mean = vec![0.0; dim];
for (point, &weight) in points.iter().zip(w) {
for (i, &val) in point.iter().enumerate() {
mean[i] += weight * val;
}
}
project_to_ball_inplace(&mut mean, c, config.eps);
// Riemannian gradient descent
let learning_rate = 0.1;
let mut grad = vec![0.0; dim];
for _ in 0..config.max_iter {
// Reset gradient
for g in grad.iter_mut() {
*g = 0.0;
}
// Compute Riemannian gradient
for (point, &weight) in points.iter().zip(w) {
let log_result = log_map(point, &mean, c);
for (i, &val) in log_result.iter().enumerate() {
grad[i] += weight * val;
}
}
// Check convergence
if norm(&grad) < config.tol {
break;
}
// Update step
let update: Vec<f32> = grad.iter().map(|&g| learning_rate * g).collect();
mean = exp_map(&update, &mean, c);
}
Ok(mean)
}
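The gradient loop above can be sketched in one dimension, where c = 1 makes the ball the interval (−1, 1) and x ⊕ y = (x + y)/(1 + xy). The helper names and the 0.5 step size are illustrative, not the crate's values (the crate uses 0.1):

```rust
// 1-D Riemannian gradient descent for the Fréchet mean (c = 1).
fn log_p(y: f32, p: f32) -> f32 {
    (1.0 - p * p) * ((y - p) / (1.0 - p * y)).atanh()
}

fn exp_p(v: f32, p: f32) -> f32 {
    let t = (v / (1.0 - p * p)).tanh();
    (p + t) / (1.0 + p * t) // Möbius addition p ⊕ t
}

fn frechet_mean_1d(points: &[f32]) -> f32 {
    // Initialize at the Euclidean mean, then follow the Riemannian gradient.
    let mut mean = points.iter().sum::<f32>() / points.len() as f32;
    for _ in 0..100 {
        let grad: f32 =
            points.iter().map(|&x| log_p(x, mean)).sum::<f32>() / points.len() as f32;
        if grad.abs() < 1e-7 {
            break;
        }
        mean = exp_p(0.5 * grad, mean); // step along the gradient on the manifold
    }
    mean
}

fn main() {
    // Symmetric points: the hyperbolic centroid is the origin.
    assert!(frechet_mean_1d(&[-0.5, 0.5]).abs() < 1e-4);
    // The mean is the geodesic midpoint tanh((arctanh(0.2) + arctanh(0.6)) / 2),
    // not the Euclidean midpoint 0.4.
    let m = frechet_mean_1d(&[0.2, 0.6]);
    assert!((m - 0.4202).abs() < 1e-3);
}
```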
/// Hyperbolic midpoint between two points
pub fn hyperbolic_midpoint(x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
let log_y = log_map(y, x, c);
let half_log: Vec<f32> = log_y.iter().map(|&v| 0.5 * v).collect();
exp_map(&half_log, x, c)
}
/// Parallel transport a tangent vector from p to q
///
/// Uses the conformal-factor rescaling λ_p / λ_q. This is exact for transport
/// along geodesics through the origin and a first-order approximation
/// elsewhere (the full formula also applies a gyration).
pub fn parallel_transport(v: &[f32], p: &[f32], q: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let lambda_p = conformal_factor(p, c);
let lambda_q = conformal_factor(q, c);
let scale = lambda_p / lambda_q;
v.iter().map(|&vi| scale * vi).collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_project_to_ball() {
let x = vec![0.5, 0.5, 0.5];
let projected = project_to_ball(&x, 1.0, EPS);
assert!(norm(&projected) < 1.0 - EPS);
}
#[test]
fn test_mobius_add_identity() {
let x = vec![0.3, 0.2, 0.1];
let zero = vec![0.0, 0.0, 0.0];
let result = mobius_add(&x, &zero, 1.0);
for (a, b) in x.iter().zip(result.iter()) {
assert!((a - b).abs() < 1e-5);
}
}
#[test]
fn test_exp_log_inverse() {
let p = vec![0.1, 0.2, 0.1];
let v = vec![0.1, -0.1, 0.05];
let q = exp_map(&v, &p, 1.0);
let v_recovered = log_map(&q, &p, 1.0);
for (a, b) in v.iter().zip(v_recovered.iter()) {
assert!((a - b).abs() < 1e-4);
}
}
#[test]
fn test_poincare_distance_symmetry() {
let u = vec![0.3, 0.2];
let v = vec![-0.1, 0.4];
let d1 = poincare_distance(&u, &v, 1.0);
let d2 = poincare_distance(&v, &u, 1.0);
assert!((d1 - d2).abs() < 1e-6);
}
#[test]
fn test_poincare_distance_origin() {
let origin = vec![0.0, 0.0];
let d = poincare_distance(&origin, &origin, 1.0);
assert!(d.abs() < 1e-6);
}
#[test]
fn test_fused_norms() {
let u = vec![0.3, 0.2, 0.1];
let v = vec![0.1, 0.4, 0.2];
let (diff_sq, norm_u_sq, norm_v_sq) = fused_norms(&u, &v);
let expected_diff_sq: f32 = u.iter().zip(v.iter())
.map(|(a, b)| (a - b) * (a - b)).sum();
let expected_norm_u_sq = norm_squared(&u);
let expected_norm_v_sq = norm_squared(&v);
assert!((diff_sq - expected_diff_sq).abs() < 1e-6);
assert!((norm_u_sq - expected_norm_u_sq).abs() < 1e-6);
assert!((norm_v_sq - expected_norm_v_sq).abs() < 1e-6);
}
}


@@ -0,0 +1,575 @@
//! Shard Management with Curvature Registry
//!
//! This module implements per-shard curvature management for hierarchical data.
//! Different parts of the hierarchy may have different optimal curvatures.
//!
//! # Features
//!
//! - Per-shard curvature configuration
//! - Hot reload of curvature parameters
//! - Canary testing for curvature updates
//! - Hierarchy preservation metrics
use crate::error::{HyperbolicError, HyperbolicResult};
use crate::hnsw::{HyperbolicHnsw, HyperbolicHnswConfig, SearchResult};
use crate::poincare::{frechet_mean, poincare_distance, project_to_ball, PoincareConfig, EPS};
use crate::tangent::TangentCache;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Curvature configuration for a shard
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardCurvature {
/// Current active curvature
pub current: f32,
/// Canary curvature (for testing)
pub canary: Option<f32>,
/// Traffic percentage for canary (0-100)
pub canary_traffic: u8,
/// Learned curvature from data
pub learned: Option<f32>,
/// Last update timestamp
pub updated_at: i64,
}
impl Default for ShardCurvature {
fn default() -> Self {
Self {
current: 1.0,
canary: None,
canary_traffic: 0,
learned: None,
updated_at: 0,
}
}
}
impl ShardCurvature {
/// Get the effective curvature (considering canary traffic)
pub fn effective(&self, use_canary: bool) -> f32 {
if use_canary && self.canary.is_some() && self.canary_traffic > 0 {
self.canary.unwrap()
} else {
self.current
}
}
/// Promote canary to current
pub fn promote_canary(&mut self) {
if let Some(c) = self.canary {
self.current = c;
self.canary = None;
self.canary_traffic = 0;
}
}
/// Rollback canary
pub fn rollback_canary(&mut self) {
self.canary = None;
self.canary_traffic = 0;
}
}
/// Curvature registry for managing per-shard curvatures
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CurvatureRegistry {
/// Shard curvatures by shard ID
pub shards: HashMap<String, ShardCurvature>,
/// Global default curvature
pub default_curvature: f32,
/// Registry version (for hot reload)
pub version: u64,
}
impl CurvatureRegistry {
/// Create a new registry with default curvature
pub fn new(default_curvature: f32) -> Self {
Self {
shards: HashMap::new(),
default_curvature,
version: 0,
}
}
/// Get curvature for a shard
pub fn get(&self, shard_id: &str) -> f32 {
self.shards
.get(shard_id)
.map(|s| s.current)
.unwrap_or(self.default_curvature)
}
/// Get curvature with canary consideration
pub fn get_effective(&self, shard_id: &str, use_canary: bool) -> f32 {
self.shards
.get(shard_id)
.map(|s| s.effective(use_canary))
.unwrap_or(self.default_curvature)
}
/// Set curvature for a shard
pub fn set(&mut self, shard_id: &str, curvature: f32) {
let entry = self.shards.entry(shard_id.to_string()).or_default();
entry.current = curvature;
entry.updated_at = chrono_timestamp();
self.version += 1;
}
/// Set canary curvature
pub fn set_canary(&mut self, shard_id: &str, curvature: f32, traffic: u8) {
let entry = self.shards.entry(shard_id.to_string()).or_default();
entry.canary = Some(curvature);
entry.canary_traffic = traffic.min(100);
entry.updated_at = chrono_timestamp();
self.version += 1;
}
/// Promote all canaries
pub fn promote_all_canaries(&mut self) {
        for shard in self.shards.values_mut() {
shard.promote_canary();
}
self.version += 1;
}
/// Rollback all canaries
pub fn rollback_all_canaries(&mut self) {
        for shard in self.shards.values_mut() {
shard.rollback_canary();
}
self.version += 1;
}
/// Record learned curvature
pub fn set_learned(&mut self, shard_id: &str, curvature: f32) {
let entry = self.shards.entry(shard_id.to_string()).or_default();
entry.learned = Some(curvature);
entry.updated_at = chrono_timestamp();
}
}
fn chrono_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0)
}
/// A single shard in the sharded HNSW system
#[derive(Debug)]
pub struct HyperbolicShard {
/// Shard ID
pub id: String,
/// HNSW index for this shard
pub index: HyperbolicHnsw,
/// Tangent cache
pub tangent_cache: Option<TangentCache>,
/// Shard centroid
pub centroid: Vec<f32>,
/// Hierarchy depth range (min, max)
pub depth_range: (usize, usize),
/// Number of vectors in shard
pub count: usize,
}
impl HyperbolicShard {
/// Create a new shard
pub fn new(id: String, curvature: f32) -> Self {
let mut config = HyperbolicHnswConfig::default();
config.curvature = curvature;
Self {
id,
index: HyperbolicHnsw::new(config),
tangent_cache: None,
centroid: Vec::new(),
depth_range: (0, 0),
count: 0,
}
}
/// Insert a vector
pub fn insert(&mut self, vector: Vec<f32>) -> HyperbolicResult<usize> {
let id = self.index.insert(vector)?;
self.count += 1;
// Invalidate tangent cache
self.tangent_cache = None;
Ok(id)
}
/// Build tangent cache
pub fn build_cache(&mut self) -> HyperbolicResult<()> {
if self.count == 0 {
return Ok(());
}
let vectors: Vec<Vec<f32>> = self
.index
.vectors()
.iter()
.map(|v| v.to_vec())
.collect();
let indices: Vec<usize> = (0..vectors.len()).collect();
self.tangent_cache = Some(TangentCache::new(
&vectors,
&indices,
self.index.config.curvature,
)?);
if let Some(cache) = &self.tangent_cache {
self.centroid = cache.centroid.clone();
}
Ok(())
}
/// Search with tangent pruning
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
self.index.search(query, k)
}
/// Update curvature
pub fn set_curvature(&mut self, curvature: f32) -> HyperbolicResult<()> {
self.index.set_curvature(curvature)?;
// Rebuild cache with new curvature
if self.tangent_cache.is_some() {
self.build_cache()?;
}
Ok(())
}
}
/// Sharded hyperbolic HNSW manager
#[derive(Debug)]
pub struct ShardedHyperbolicHnsw {
/// Shards by ID
pub shards: HashMap<String, HyperbolicShard>,
/// Curvature registry
pub registry: CurvatureRegistry,
/// Global ID to shard mapping
pub id_to_shard: Vec<(String, usize)>,
/// Shard assignment strategy
pub strategy: ShardStrategy,
}
/// Strategy for assigning vectors to shards
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardStrategy {
/// Assign by hash
Hash,
/// Assign by hierarchy depth
Depth,
/// Assign by radius (distance from origin)
Radius,
/// Round-robin
RoundRobin,
}
impl Default for ShardStrategy {
fn default() -> Self {
Self::Radius
}
}
impl ShardedHyperbolicHnsw {
/// Create a new sharded manager
pub fn new(default_curvature: f32) -> Self {
Self {
shards: HashMap::new(),
registry: CurvatureRegistry::new(default_curvature),
id_to_shard: Vec::new(),
strategy: ShardStrategy::default(),
}
}
/// Create or get a shard
pub fn get_or_create_shard(&mut self, shard_id: &str) -> &mut HyperbolicShard {
let curvature = self.registry.get(shard_id);
self.shards
.entry(shard_id.to_string())
.or_insert_with(|| HyperbolicShard::new(shard_id.to_string(), curvature))
}
/// Determine shard for a vector
pub fn assign_shard(&self, vector: &[f32], depth: Option<usize>) -> String {
match self.strategy {
ShardStrategy::Hash => {
let hash: u64 = vector.iter().fold(0u64, |acc, &v| {
acc.wrapping_add((v.to_bits() as u64).wrapping_mul(31))
});
format!("shard_{}", hash % (self.shards.len().max(1) as u64))
}
ShardStrategy::Depth => {
let d = depth.unwrap_or(0);
format!("depth_{}", d / 10) // Group by depth buckets
}
ShardStrategy::Radius => {
let radius: f32 = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
let bucket = (radius * 10.0) as usize;
format!("radius_{}", bucket)
}
ShardStrategy::RoundRobin => {
let idx = self.id_to_shard.len() % self.shards.len().max(1);
self.shards
.keys()
.nth(idx)
.cloned()
.unwrap_or_else(|| "default".to_string())
}
}
}
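For the default `Radius` strategy, the bucketing rule is easy to trace by hand (standalone sketch of the same `(‖x‖ · 10) as usize` rule; `radius_shard` is an illustrative name):

```rust
// Radius bucketing: deeper hierarchy levels sit farther from the origin,
// so they land in higher-numbered shards.
fn radius_shard(v: &[f32]) -> String {
    let radius: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    format!("radius_{}", (radius * 10.0) as usize)
}

fn main() {
    // ‖(0.3, 0.1)‖ ≈ 0.316 → bucket 3
    assert_eq!(radius_shard(&[0.3, 0.1]), "radius_3");
    // Points near the origin (shallow levels) fall into bucket 0.
    assert_eq!(radius_shard(&[0.05, 0.0]), "radius_0");
    println!("{}", radius_shard(&[0.3, 0.1]));
}
```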
/// Insert vector with automatic shard assignment
pub fn insert(&mut self, vector: Vec<f32>, depth: Option<usize>) -> HyperbolicResult<usize> {
let shard_id = self.assign_shard(&vector, depth);
let shard = self.get_or_create_shard(&shard_id);
let local_id = shard.insert(vector)?;
let global_id = self.id_to_shard.len();
self.id_to_shard.push((shard_id, local_id));
Ok(global_id)
}
/// Insert into specific shard
pub fn insert_to_shard(
&mut self,
shard_id: &str,
vector: Vec<f32>,
) -> HyperbolicResult<usize> {
let shard = self.get_or_create_shard(shard_id);
let local_id = shard.insert(vector)?;
let global_id = self.id_to_shard.len();
self.id_to_shard.push((shard_id.to_string(), local_id));
Ok(global_id)
}
/// Search across all shards
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<(usize, SearchResult)>> {
let mut all_results: Vec<(usize, SearchResult)> = Vec::new();
for (shard_id, shard) in &self.shards {
let results = shard.search(query, k)?;
for result in results {
// Map local ID to global ID
                if let Some((global_id, _)) = self
                    .id_to_shard
                    .iter()
                    .enumerate()
                    .find(|(_, (s, l))| s == shard_id && *l == result.id)
                {
                    all_results.push((global_id, result));
                }
}
}
// Sort by distance and take top k
        all_results.sort_by(|a, b| a.1.distance.total_cmp(&b.1.distance));
all_results.truncate(k);
Ok(all_results)
}
/// Build all tangent caches
pub fn build_caches(&mut self) -> HyperbolicResult<()> {
for shard in self.shards.values_mut() {
shard.build_cache()?;
}
Ok(())
}
/// Update curvature for a shard
pub fn update_curvature(&mut self, shard_id: &str, curvature: f32) -> HyperbolicResult<()> {
self.registry.set(shard_id, curvature);
if let Some(shard) = self.shards.get_mut(shard_id) {
shard.set_curvature(curvature)?;
}
Ok(())
}
/// Hot reload curvatures from registry
pub fn reload_curvatures(&mut self) -> HyperbolicResult<()> {
for (shard_id, shard) in self.shards.iter_mut() {
let curvature = self.registry.get(shard_id);
shard.set_curvature(curvature)?;
}
Ok(())
}
/// Get total vector count
pub fn len(&self) -> usize {
self.id_to_shard.len()
}
/// Check if empty
pub fn is_empty(&self) -> bool {
self.id_to_shard.is_empty()
}
/// Get number of shards
pub fn num_shards(&self) -> usize {
self.shards.len()
}
}
/// Metrics for hierarchy preservation
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HierarchyMetrics {
/// Spearman correlation between radius and depth
pub radius_depth_correlation: f32,
/// Average distance distortion
pub distance_distortion: f32,
/// Ancestor preservation (AUPRC)
pub ancestor_auprc: f32,
/// Mean rank
pub mean_rank: f32,
/// NDCG scores
pub ndcg: HashMap<String, f32>,
}
impl HierarchyMetrics {
/// Compute hierarchy metrics
pub fn compute(
points: &[Vec<f32>],
depths: &[usize],
curvature: f32,
) -> HyperbolicResult<Self> {
        if points.is_empty() {
            return Err(HyperbolicError::EmptyCollection);
        }
        if points.len() != depths.len() {
            return Err(HyperbolicError::DimensionMismatch {
                expected: points.len(),
                got: depths.len(),
            });
        }
// Compute radii
let radii: Vec<f32> = points
.iter()
.map(|p| p.iter().map(|v| v * v).sum::<f32>().sqrt())
.collect();
// Spearman correlation between radius and depth
let radius_depth_correlation = spearman_correlation(&radii, depths);
// Distance distortion (sample-based for efficiency)
let sample_size = points.len().min(100);
let mut distortion_sum = 0.0;
let mut distortion_count = 0;
for i in 0..sample_size {
for j in (i + 1)..sample_size {
let hyp_dist = poincare_distance(&points[i], &points[j], curvature);
let depth_diff = (depths[i] as f32 - depths[j] as f32).abs();
if depth_diff > 0.0 {
distortion_sum += (hyp_dist - depth_diff).abs() / depth_diff;
distortion_count += 1;
}
}
}
let distance_distortion = if distortion_count > 0 {
distortion_sum / distortion_count as f32
} else {
0.0
};
Ok(Self {
radius_depth_correlation,
distance_distortion,
ancestor_auprc: 0.0, // Requires ground truth
mean_rank: 0.0, // Requires ground truth
ndcg: HashMap::new(),
})
}
}
/// Compute Spearman rank correlation
fn spearman_correlation(x: &[f32], y: &[usize]) -> f32 {
if x.len() != y.len() || x.is_empty() {
return 0.0;
}
let n = x.len();
// Compute ranks for x
let mut x_indexed: Vec<(usize, f32)> = x.iter().cloned().enumerate().collect();
x_indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let mut x_ranks = vec![0.0; n];
for (rank, (idx, _)) in x_indexed.iter().enumerate() {
x_ranks[*idx] = rank as f32;
}
// Compute ranks for y
let mut y_indexed: Vec<(usize, usize)> = y.iter().cloned().enumerate().collect();
y_indexed.sort_by_key(|a| a.1);
let mut y_ranks = vec![0.0; n];
for (rank, (idx, _)) in y_indexed.iter().enumerate() {
y_ranks[*idx] = rank as f32;
}
// Compute Spearman correlation
let mean_x: f32 = x_ranks.iter().sum::<f32>() / n as f32;
let mean_y: f32 = y_ranks.iter().sum::<f32>() / n as f32;
let mut cov = 0.0;
let mut var_x = 0.0;
let mut var_y = 0.0;
for i in 0..n {
let dx = x_ranks[i] - mean_x;
let dy = y_ranks[i] - mean_y;
cov += dx * dy;
var_x += dx * dx;
var_y += dy * dy;
}
if var_x == 0.0 || var_y == 0.0 {
return 0.0;
}
cov / (var_x * var_y).sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_curvature_registry() {
let mut registry = CurvatureRegistry::new(1.0);
registry.set("shard_1", 0.5);
assert_eq!(registry.get("shard_1"), 0.5);
assert_eq!(registry.get("shard_2"), 1.0); // Default
registry.set_canary("shard_1", 0.3, 50);
assert_eq!(registry.get_effective("shard_1", false), 0.5);
assert_eq!(registry.get_effective("shard_1", true), 0.3);
}
#[test]
fn test_sharded_hnsw() {
let mut manager = ShardedHyperbolicHnsw::new(1.0);
for i in 0..20 {
let v = vec![0.1 * i as f32, 0.05 * i as f32];
manager.insert(v, Some(i / 5)).unwrap();
}
assert_eq!(manager.len(), 20);
let query = vec![0.3, 0.15];
let results = manager.search(&query, 5).unwrap();
assert!(!results.is_empty());
}
#[test]
fn test_spearman() {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![1, 2, 3, 4, 5];
let corr = spearman_correlation(&x, &y);
assert!((corr - 1.0).abs() < 0.01);
let y_rev = vec![5, 4, 3, 2, 1];
let corr_rev = spearman_correlation(&x, &y_rev);
assert!((corr_rev + 1.0).abs() < 0.01);
}
}


@@ -0,0 +1,348 @@
//! Tangent Space Operations for HNSW Pruning Optimization
//!
//! This module implements the key optimization for hyperbolic HNSW:
//! - Precompute tangent space coordinates at shard centroids
//! - Use cheap Euclidean distance in tangent space for pruning
//! - Only compute exact Poincaré distance for final ranking
//!
//! # HNSW Speed Trick
//!
//! The core insight is that for points near a centroid c:
//! 1. Map points to tangent space: u = log_c(x)
//! 2. Euclidean distance ||u_q - u_p|| approximates hyperbolic distance
//! 3. Prune candidates using fast Euclidean comparisons
//! 4. Rank final top-N candidates with exact Poincaré distance
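The two-phase idea above can be sketched in self-contained Rust. This is an illustrative toy for curvature c = 1 with the tangent space taken at the origin (the module itself uses per-shard centroids); `log0`, `poincare_dist`, and the search loop are hypothetical helper names, not this crate's API. With the standard convention, the Euclidean norm of origin-tangent coordinates equals d/2, so tangent distances preserve ordering, which is all the pruning phase needs.

```rust
// Toy two-phase search, curvature c = 1, tangent space at the origin.
// Illustrative only: log0 / poincare_dist are not this crate's API.

fn norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

// Exact distance: d(x,y) = arcosh(1 + 2||x-y||^2 / ((1-||x||^2)(1-||y||^2)))
fn poincare_dist(x: &[f32], y: &[f32]) -> f32 {
    let diff2: f32 = x.iter().zip(y).map(|(a, b)| (a - b) * (a - b)).sum();
    let denom = (1.0 - norm(x).powi(2)) * (1.0 - norm(y).powi(2));
    (1.0 + 2.0 * diff2 / denom).acosh()
}

// Log map at the origin: log_0(x) = artanh(||x||) * x / ||x||.
// Note ||log_0(x)|| = d(0, x) / 2, so ranking by tangent distance is safe.
fn log0(x: &[f32]) -> Vec<f32> {
    let n = norm(x);
    if n < 1e-9 {
        return x.to_vec();
    }
    let scale = n.atanh() / n;
    x.iter().map(|v| v * scale).collect()
}

fn main() {
    let points = [vec![0.1f32, 0.2], vec![0.5, 0.1], vec![-0.3, 0.4]];
    let query = vec![0.12f32, 0.18];

    // Phase 1: cheap squared-Euclidean ranking in tangent space.
    let qt = log0(&query);
    let mut pruned: Vec<(usize, f32)> = points
        .iter()
        .enumerate()
        .map(|(i, p)| {
            let pt = log0(p);
            let d2 = qt.iter().zip(&pt).map(|(a, b)| (a - b) * (a - b)).sum::<f32>();
            (i, d2)
        })
        .collect();
    pruned.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    pruned.truncate(2); // keep 2 of 3 candidates for the exact phase

    // Phase 2: exact Poincare distance only for the survivors.
    let mut exact: Vec<(usize, f32)> = pruned
        .iter()
        .map(|&(i, _)| (i, poincare_dist(&query, &points[i])))
        .collect();
    exact.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    assert_eq!(exact[0].0, 0); // (0.1, 0.2) is nearest to (0.12, 0.18)

    // Sanity check of the d/2 relation against the origin:
    let q = vec![0.3f32, 0.0];
    let d = poincare_dist(&[0.0, 0.0], &q);
    assert!((d - 2.0 * norm(&log0(&q))).abs() < 1e-4);
    println!("nearest id = {}, d = {:.4}", exact[0].0, exact[0].1);
}
```

The sketch mirrors `TangentPruner::search` below: rank everything by a cheap tangent-space distance, truncate, then pay for exact Poincaré distances only on the finalists.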
use crate::error::{HyperbolicError, HyperbolicResult};
use crate::poincare::{
conformal_factor, frechet_mean, log_map, norm, norm_squared, poincare_distance,
project_to_ball, PoincareConfig, EPS,
};
use serde::{Deserialize, Serialize};
/// Tangent space cache for a shard
///
/// Stores precomputed tangent coordinates for fast pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TangentCache {
/// Centroid point (base of tangent space)
pub centroid: Vec<f32>,
/// Precomputed tangent coordinates for all points in shard
pub tangent_coords: Vec<Vec<f32>>,
/// Original point indices
pub point_indices: Vec<usize>,
/// Curvature parameter
pub curvature: f32,
/// Cached conformal factor at centroid
conformal: f32,
}
impl TangentCache {
/// Create a new tangent cache for a shard
///
/// # Arguments
/// * `points` - Points in the shard (Poincaré ball coordinates)
/// * `indices` - Original indices of the points
/// * `curvature` - Curvature parameter
pub fn new(points: &[Vec<f32>], indices: &[usize], curvature: f32) -> HyperbolicResult<Self> {
if points.is_empty() {
return Err(HyperbolicError::EmptyCollection);
}
let config = PoincareConfig::with_curvature(curvature)?;
// Compute centroid as Fréchet mean
let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
let centroid = frechet_mean(&point_refs, None, &config)?;
// Precompute tangent coordinates
let tangent_coords: Vec<Vec<f32>> = points
.iter()
.map(|p| log_map(p, &centroid, curvature))
.collect();
let conformal = conformal_factor(&centroid, curvature);
Ok(Self {
centroid,
tangent_coords,
point_indices: indices.to_vec(),
curvature,
conformal,
})
}
/// Create from centroid directly (for incremental updates)
pub fn from_centroid(
centroid: Vec<f32>,
points: &[Vec<f32>],
indices: &[usize],
curvature: f32,
) -> HyperbolicResult<Self> {
let tangent_coords: Vec<Vec<f32>> = points
.iter()
.map(|p| log_map(p, &centroid, curvature))
.collect();
let conformal = conformal_factor(&centroid, curvature);
Ok(Self {
centroid,
tangent_coords,
point_indices: indices.to_vec(),
curvature,
conformal,
})
}
/// Get tangent coordinates for a query point
pub fn query_tangent(&self, query: &[f32]) -> Vec<f32> {
log_map(query, &self.centroid, self.curvature)
}
/// Fast Euclidean distance in tangent space (for pruning)
#[inline]
pub fn tangent_distance_squared(&self, query_tangent: &[f32], idx: usize) -> f32 {
if idx >= self.tangent_coords.len() {
return f32::MAX;
}
let p = &self.tangent_coords[idx];
query_tangent
.iter()
.zip(p.iter())
.map(|(&a, &b)| (a - b) * (a - b))
.sum()
}
/// Exact Poincaré distance for final ranking
pub fn exact_distance(&self, query: &[f32], idx: usize, points: &[Vec<f32>]) -> f32 {
if idx >= points.len() {
return f32::MAX;
}
poincare_distance(query, &points[idx], self.curvature)
}
/// Add a new point to the cache (for incremental updates)
pub fn add_point(&mut self, point: &[f32], index: usize) {
let tangent = log_map(point, &self.centroid, self.curvature);
self.tangent_coords.push(tangent);
self.point_indices.push(index);
}
/// Update centroid and recompute all tangent coordinates
pub fn recompute_centroid(&mut self, points: &[Vec<f32>]) -> HyperbolicResult<()> {
if points.is_empty() {
return Err(HyperbolicError::EmptyCollection);
}
let config = PoincareConfig::with_curvature(self.curvature)?;
let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
self.centroid = frechet_mean(&point_refs, None, &config)?;
self.tangent_coords = points
.iter()
.map(|p| log_map(p, &self.centroid, self.curvature))
.collect();
self.conformal = conformal_factor(&self.centroid, self.curvature);
Ok(())
}
/// Get number of points in cache
pub fn len(&self) -> usize {
self.tangent_coords.len()
}
/// Check if cache is empty
pub fn is_empty(&self) -> bool {
self.tangent_coords.is_empty()
}
/// Get the dimension of the tangent space
pub fn dim(&self) -> usize {
self.centroid.len()
}
}
/// Tangent space pruning result
#[derive(Debug, Clone)]
pub struct PrunedCandidate {
/// Original index
pub index: usize,
/// Tangent space distance (for initial ranking)
pub tangent_dist: f32,
/// Exact Poincaré distance (computed lazily)
pub exact_dist: Option<f32>,
}
/// Tangent space pruner for HNSW neighbor selection
///
/// Implements the two-phase search:
/// 1. Fast pruning using Euclidean distance in tangent space
/// 2. Exact ranking using Poincaré distance for top candidates
pub struct TangentPruner {
/// Tangent caches for each shard
caches: Vec<TangentCache>,
/// Number of candidates to consider in exact phase
top_n: usize,
/// Pruning factor (how many candidates to keep from tangent phase)
prune_factor: usize,
}
impl TangentPruner {
/// Create a new pruner
///
/// # Arguments
/// * `top_n` - Number of final results
/// * `prune_factor` - Multiplier for candidates to consider (e.g., 10 means consider 10*top_n)
pub fn new(top_n: usize, prune_factor: usize) -> Self {
Self {
caches: Vec::new(),
top_n,
prune_factor,
}
}
/// Add a shard cache
pub fn add_cache(&mut self, cache: TangentCache) {
self.caches.push(cache);
}
/// Get shard caches
pub fn caches(&self) -> &[TangentCache] {
&self.caches
}
/// Get mutable shard caches
pub fn caches_mut(&mut self) -> &mut [TangentCache] {
&mut self.caches
}
/// Search across all shards with tangent pruning
///
/// Returns top_n candidates sorted by exact Poincaré distance.
pub fn search(
&self,
query: &[f32],
points: &[Vec<f32>],
curvature: f32,
) -> Vec<PrunedCandidate> {
let num_prune = self.top_n * self.prune_factor;
let mut candidates: Vec<PrunedCandidate> = Vec::with_capacity(num_prune);
// Phase 1: Tangent space pruning across all shards
for cache in &self.caches {
let query_tangent = cache.query_tangent(query);
for (local_idx, &global_idx) in cache.point_indices.iter().enumerate() {
let tangent_dist = cache.tangent_distance_squared(&query_tangent, local_idx);
candidates.push(PrunedCandidate {
index: global_idx,
tangent_dist,
exact_dist: None,
});
}
}
// Sort by tangent distance and keep top prune_factor * top_n
candidates.sort_by(|a, b| a.tangent_dist.partial_cmp(&b.tangent_dist).unwrap());
candidates.truncate(num_prune);
// Phase 2: Exact Poincaré distance for finalists
for candidate in &mut candidates {
if candidate.index < points.len() {
candidate.exact_dist =
Some(poincare_distance(query, &points[candidate.index], curvature));
}
}
// Sort by exact distance and return top_n
candidates.sort_by(|a, b| {
a.exact_dist
.unwrap_or(f32::MAX)
.partial_cmp(&b.exact_dist.unwrap_or(f32::MAX))
.unwrap()
});
candidates.truncate(self.top_n);
candidates
}
}
/// Compute micro tangent update for incremental operations
///
/// For small updates (reflex loop), compute tangent-space delta
/// that keeps the point inside the ball.
pub fn tangent_micro_update(
point: &[f32],
delta: &[f32],
centroid: &[f32],
curvature: f32,
max_step: f32,
) -> Vec<f32> {
// Get current tangent coordinates
let tangent = log_map(point, centroid, curvature);
// Apply bounded delta in tangent space
let delta_norm = norm(delta);
let scale = if delta_norm > max_step {
max_step / delta_norm
} else {
1.0
};
let new_tangent: Vec<f32> = tangent
.iter()
.zip(delta.iter())
.map(|(&t, &d)| t + scale * d)
.collect();
// Map back to ball and project
let new_point = crate::poincare::exp_map(&new_tangent, centroid, curvature);
project_to_ball(&new_point, curvature, EPS)
}
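As a hedged illustration of the bounded-step idea (again at the origin for c = 1, with hypothetical helper names rather than this crate's API): clamp the delta's norm to `max_step` in tangent space, then map back with exp_0(v) = tanh(||v||) * v / ||v||, whose image always has norm tanh(||v||) < 1, so the updated point cannot escape the ball even for a huge delta.

```rust
// Bounded tangent-space update at the origin (c = 1). Illustrative only.

fn norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

// log_0(x) = artanh(||x||) * x / ||x||
fn log0(x: &[f32]) -> Vec<f32> {
    let n = norm(x);
    if n < 1e-9 { return x.to_vec(); }
    let s = n.atanh() / n;
    x.iter().map(|v| v * s).collect()
}

// exp_0(v) = tanh(||v||) * v / ||v||  (image norm = tanh(||v||) < 1)
fn exp0(v: &[f32]) -> Vec<f32> {
    let n = norm(v);
    if n < 1e-9 { return v.to_vec(); }
    let s = n.tanh() / n;
    v.iter().map(|x| x * s).collect()
}

fn micro_update(point: &[f32], delta: &[f32], max_step: f32) -> Vec<f32> {
    // Clamp the delta so the tangent-space step is at most max_step.
    let dn = norm(delta);
    let scale = if dn > max_step { max_step / dn } else { 1.0 };
    let new_tangent: Vec<f32> = log0(point)
        .iter()
        .zip(delta)
        .map(|(t, d)| t + scale * d)
        .collect();
    exp0(&new_tangent)
}

fn main() {
    // A huge delta is clamped to max_step = 0.05; exp_0 keeps us in the ball.
    let p = vec![0.9f32, 0.0];
    let updated = micro_update(&p, &[100.0, 100.0], 0.05);
    assert!(norm(&updated) < 1.0);
    println!("updated = {:?}, norm = {}", updated, norm(&updated));
}
```

The crate's `tangent_micro_update` additionally works at an arbitrary centroid and applies `project_to_ball` as a final numerical safeguard; the origin case above just makes the containment argument easy to see.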
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tangent_cache_creation() {
let points = vec![
vec![0.1, 0.2, 0.1],
vec![-0.1, 0.15, 0.05],
vec![0.2, -0.1, 0.1],
];
let indices: Vec<usize> = (0..3).collect();
let cache = TangentCache::new(&points, &indices, 1.0).unwrap();
assert_eq!(cache.len(), 3);
assert_eq!(cache.dim(), 3);
}
#[test]
fn test_tangent_pruning() {
let points = vec![
vec![0.1, 0.2],
vec![-0.1, 0.15],
vec![0.2, -0.1],
vec![0.05, 0.05],
];
let indices: Vec<usize> = (0..4).collect();
let cache = TangentCache::new(&points, &indices, 1.0).unwrap();
let mut pruner = TangentPruner::new(2, 2);
pruner.add_cache(cache);
let query = vec![0.08, 0.1];
let results = pruner.search(&query, &points, 1.0);
assert_eq!(results.len(), 2);
// Results should be sorted by exact distance
assert!(results[0].exact_dist.unwrap() <= results[1].exact_dist.unwrap());
}
}


@@ -0,0 +1,531 @@
//! Comprehensive Mathematical Correctness Tests for Hyperbolic Operations
//!
//! These tests verify the mathematical properties of Poincaré ball operations
//! as specified in the evaluation protocol.
use ruvector_hyperbolic_hnsw::poincare::*;
use ruvector_hyperbolic_hnsw::tangent::*;
use ruvector_hyperbolic_hnsw::hnsw::*;
use ruvector_hyperbolic_hnsw::shard::*;
// ============================================================================
// Poincaré Ball Properties
// ============================================================================
#[test]
fn test_mobius_add_identity() {
// x ⊕ 0 = x (right identity)
let x = vec![0.3, 0.2, 0.1];
let zero = vec![0.0, 0.0, 0.0];
let result = mobius_add(&x, &zero, 1.0);
for (a, b) in x.iter().zip(result.iter()) {
assert!((a - b).abs() < 1e-5, "Right identity failed");
}
}
#[test]
fn test_mobius_add_inverse() {
// x ⊕ (-x) ≈ 0 (inverse element)
let x = vec![0.3, 0.2];
let neg_x: Vec<f32> = x.iter().map(|v| -v).collect();
let result = mobius_add(&x, &neg_x, 1.0);
let result_norm = norm(&result);
// Result should be close to zero
assert!(result_norm < 0.1, "Inverse element failed: norm = {}", result_norm);
}
#[test]
fn test_mobius_add_gyrocommutative() {
    // Gyrocommutative law: x ⊕ y = gyr[x,y](y ⊕ x). For small vectors,
    // gyr[x,y] is close to the identity, so x ⊕ y ≈ y ⊕ x there.
let x = vec![0.1, 0.05];
let y = vec![0.08, -0.03];
let xy = mobius_add(&x, &y, 1.0);
let yx = mobius_add(&y, &x, 1.0);
// For small vectors, these should be similar
let diff: f32 = xy.iter().zip(yx.iter()).map(|(a, b)| (a - b).abs()).sum();
assert!(diff < 0.5, "Gyrocommutative property check: diff = {}", diff);
}
#[test]
fn test_exp_log_inverse() {
// log_p(exp_p(v)) = v (inverse relationship)
let p = vec![0.1, 0.2, 0.1];
let v = vec![0.1, -0.1, 0.05];
let q = exp_map(&v, &p, 1.0);
let v_recovered = log_map(&q, &p, 1.0);
for (a, b) in v.iter().zip(v_recovered.iter()) {
assert!((a - b).abs() < 1e-4, "exp-log inverse failed: expected {}, got {}", a, b);
}
}
#[test]
fn test_log_exp_inverse() {
// exp_p(log_p(q)) = q (inverse relationship)
let p = vec![0.1, 0.15];
let q = vec![0.2, -0.1];
let v = log_map(&q, &p, 1.0);
let q_recovered = exp_map(&v, &p, 1.0);
for (a, b) in q.iter().zip(q_recovered.iter()) {
assert!((a - b).abs() < 1e-4, "log-exp inverse failed: expected {}, got {}", a, b);
}
}
#[test]
fn test_distance_symmetry() {
// d(x, y) = d(y, x)
let x = vec![0.3, 0.2, 0.1];
let y = vec![-0.1, 0.4, 0.2];
let d1 = poincare_distance(&x, &y, 1.0);
let d2 = poincare_distance(&y, &x, 1.0);
assert!((d1 - d2).abs() < 1e-6, "Symmetry failed: {} vs {}", d1, d2);
}
#[test]
fn test_distance_identity() {
// d(x, x) = 0
let x = vec![0.3, 0.2, 0.1];
let d = poincare_distance(&x, &x, 1.0);
assert!(d.abs() < 1e-6, "Identity of indiscernibles failed: d = {}", d);
}
#[test]
fn test_distance_non_negative() {
// d(x, y) >= 0
let x = vec![0.3, 0.2];
let y = vec![-0.1, 0.4];
let d = poincare_distance(&x, &y, 1.0);
assert!(d >= 0.0, "Non-negativity failed: d = {}", d);
}
#[test]
fn test_distance_triangle_inequality() {
// d(x, z) <= d(x, y) + d(y, z)
let x = vec![0.1, 0.2];
let y = vec![0.3, 0.1];
let z = vec![-0.1, 0.35];
let dxz = poincare_distance(&x, &z, 1.0);
let dxy = poincare_distance(&x, &y, 1.0);
let dyz = poincare_distance(&y, &z, 1.0);
assert!(dxz <= dxy + dyz + 1e-5,
"Triangle inequality failed: {} > {} + {}", dxz, dxy, dyz);
}
// ============================================================================
// Numerical Stability
// ============================================================================
#[test]
fn test_projection_keeps_points_inside() {
// All projected points should satisfy ||x|| < 1/sqrt(c) - eps
let test_points = vec![
vec![0.5, 0.5, 0.5],
vec![0.9, 0.9],
vec![10.0, 10.0, 10.0],
vec![-5.0, 3.0],
];
for point in test_points {
let projected = project_to_ball(&point, 1.0, EPS);
let n = norm(&projected);
// Use <= with small tolerance for floating point
assert!(n <= 1.0 - EPS + 1e-7,
"Projection failed: norm {} >= max {}", n, 1.0 - EPS);
}
}
#[test]
fn test_near_boundary_stability() {
// Operations near the boundary should remain stable
let near_boundary = vec![0.99 - EPS, 0.0];
let small_vec = vec![0.01, 0.01];
// Should not panic or produce NaN/Inf
let result = mobius_add(&near_boundary, &small_vec, 1.0);
assert!(!result.iter().any(|v| v.is_nan() || v.is_infinite()),
"Near boundary operation produced NaN/Inf");
let n = norm(&result);
assert!(n < 1.0 - EPS, "Result escaped ball boundary");
}
#[test]
fn test_zero_vector_handling() {
// Operations with zero vector should be stable
let zero = vec![0.0, 0.0, 0.0];
let x = vec![0.3, 0.2, 0.1];
// exp_map with zero tangent should return base point
let result = exp_map(&zero, &x, 1.0);
for (a, b) in x.iter().zip(result.iter()) {
assert!((a - b).abs() < 1e-5, "exp_map with zero failed");
}
// log_map of same point should be zero
let log_result = log_map(&x, &x, 1.0);
assert!(norm(&log_result) < 1e-5, "log_map of same point should be zero");
}
#[test]
fn test_small_curvature_stability() {
// Small curvatures should work (approaches Euclidean)
let x = vec![0.3, 0.2];
let y = vec![0.1, 0.4];
let d_small_c = poincare_distance(&x, &y, 0.01);
    let _d_euclidean: f32 = x.iter().zip(y.iter())
        .map(|(a, b)| (a - b) * (a - b))
        .sum::<f32>()
        .sqrt();
    // As curvature -> 0 the metric approaches a (scaled) Euclidean metric;
    // here we only assert the computation stays finite.
    assert!(!d_small_c.is_nan() && !d_small_c.is_infinite(),
        "Small curvature produced invalid result");
}
#[test]
fn test_large_curvature_stability() {
// Large curvatures should work (stronger hyperbolic effect)
let x = vec![0.1, 0.1];
let y = vec![0.2, 0.1];
let d_large_c = poincare_distance(&x, &y, 10.0);
assert!(!d_large_c.is_nan() && !d_large_c.is_infinite(),
"Large curvature produced invalid result: {}", d_large_c);
}
// ============================================================================
// Frechet Mean Properties
// ============================================================================
#[test]
fn test_frechet_mean_single_point() {
// Frechet mean of single point is that point
let points = vec![vec![0.3, 0.2]];
let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
let config = PoincareConfig::default();
let mean = frechet_mean(&point_refs, None, &config).unwrap();
for (a, b) in points[0].iter().zip(mean.iter()) {
assert!((a - b).abs() < 1e-4, "Single point mean failed");
}
}
#[test]
fn test_frechet_mean_symmetric() {
// Mean of symmetric points should be near origin
let points = vec![
vec![0.3, 0.0],
vec![-0.3, 0.0],
];
let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
let config = PoincareConfig::default();
let mean = frechet_mean(&point_refs, None, &config).unwrap();
// Mean should be close to origin
let mean_norm = norm(&mean);
assert!(mean_norm < 0.1, "Symmetric mean not near origin: {}", mean_norm);
}
// ============================================================================
// Tangent Space Operations
// ============================================================================
#[test]
fn test_tangent_cache_creation() {
let points = vec![
vec![0.1, 0.2, 0.1],
vec![-0.1, 0.15, 0.05],
vec![0.2, -0.1, 0.1],
];
let indices: Vec<usize> = (0..3).collect();
let cache = TangentCache::new(&points, &indices, 1.0).unwrap();
assert_eq!(cache.len(), 3);
assert_eq!(cache.dim(), 3);
// Centroid should be inside ball
let centroid_norm = norm(&cache.centroid);
assert!(centroid_norm < 1.0 - EPS, "Centroid outside ball");
}
#[test]
fn test_tangent_distance_ordering() {
// Tangent distance should roughly preserve hyperbolic distance ordering
let points = vec![
vec![0.1, 0.1],
vec![0.2, 0.1],
vec![0.5, 0.3],
];
let indices: Vec<usize> = (0..3).collect();
let cache = TangentCache::new(&points, &indices, 1.0).unwrap();
let query = vec![0.12, 0.11];
let query_tangent = cache.query_tangent(&query);
let mut tangent_dists: Vec<(usize, f32)> = (0..3)
.map(|i| (i, cache.tangent_distance_squared(&query_tangent, i)))
.collect();
tangent_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let mut hyp_dists: Vec<(usize, f32)> = (0..3)
.map(|i| (i, poincare_distance(&query, &points[i], 1.0)))
.collect();
hyp_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// First nearest neighbor should match
assert_eq!(tangent_dists[0].0, hyp_dists[0].0,
"First neighbor mismatch: tangent says {}, hyperbolic says {}",
tangent_dists[0].0, hyp_dists[0].0);
}
// ============================================================================
// HNSW Integration
// ============================================================================
#[test]
fn test_hnsw_insert_and_search() {
let mut hnsw = HyperbolicHnsw::default_config();
// Insert points
for i in 0..20 {
let v = vec![0.1 * (i as f32 % 5.0), 0.05 * (i as f32 / 5.0)];
hnsw.insert(v).unwrap();
}
assert_eq!(hnsw.len(), 20);
// Search
let query = vec![0.25, 0.125];
let results = hnsw.search(&query, 5).unwrap();
assert_eq!(results.len(), 5);
// Results should be sorted by distance
for i in 1..results.len() {
assert!(results[i-1].distance <= results[i].distance,
"Results not sorted at index {}: {} > {}",
i, results[i-1].distance, results[i].distance);
}
}
#[test]
fn test_hnsw_nearest_is_correct() {
let mut hnsw = HyperbolicHnsw::default_config();
let points = vec![
vec![0.0, 0.0],
vec![0.5, 0.0],
vec![0.0, 0.5],
vec![0.3, 0.3],
];
for p in &points {
hnsw.insert(p.clone()).unwrap();
}
// Query near origin
let query = vec![0.05, 0.05];
let results = hnsw.search(&query, 1).unwrap();
// Should find point at origin (id 0)
assert_eq!(results[0].id, 0, "Expected nearest to be origin");
}
#[test]
fn test_hnsw_curvature_update() {
let mut hnsw = HyperbolicHnsw::default_config();
hnsw.insert(vec![0.1, 0.2]).unwrap();
hnsw.insert(vec![0.3, 0.1]).unwrap();
// Update curvature
hnsw.set_curvature(2.0).unwrap();
assert!((hnsw.config.curvature - 2.0).abs() < 1e-6);
// Search should still work
let results = hnsw.search(&[0.2, 0.15], 2).unwrap();
assert_eq!(results.len(), 2);
}
// ============================================================================
// Shard Management
// ============================================================================
#[test]
fn test_curvature_registry() {
let mut registry = CurvatureRegistry::new(1.0);
registry.set("shard_1", 0.5);
assert!((registry.get("shard_1") - 0.5).abs() < 1e-6);
assert!((registry.get("unknown") - 1.0).abs() < 1e-6); // Default
// Canary testing
registry.set_canary("shard_1", 0.3, 50);
assert!((registry.get_effective("shard_1", false) - 0.5).abs() < 1e-6);
assert!((registry.get_effective("shard_1", true) - 0.3).abs() < 1e-6);
// Promote canary
if let Some(shard) = registry.shards.get_mut("shard_1") {
shard.promote_canary();
}
assert!((registry.get("shard_1") - 0.3).abs() < 1e-6);
}
#[test]
fn test_sharded_hnsw() {
let mut manager = ShardedHyperbolicHnsw::new(1.0);
for i in 0..30 {
let v = vec![0.1 * (i as f32 % 6.0), 0.05 * (i as f32 / 6.0)];
manager.insert(v, Some(i / 10)).unwrap();
}
assert_eq!(manager.len(), 30);
assert!(manager.num_shards() > 0);
// Search
let results = manager.search(&[0.25, 0.125], 5).unwrap();
assert!(!results.is_empty());
}
// ============================================================================
// Hierarchy Metrics
// ============================================================================
#[test]
fn test_hierarchy_metrics_radius_correlation() {
// Points with radius proportional to depth should have positive correlation
let points: Vec<Vec<f32>> = (0..20).map(|i| {
let depth = i / 4;
let radius = 0.1 + 0.15 * depth as f32;
let angle = (i % 4) as f32 * std::f32::consts::PI / 2.0;
vec![radius * angle.cos(), radius * angle.sin()]
}).collect();
let depths: Vec<usize> = (0..20).map(|i| i / 4).collect();
let metrics = HierarchyMetrics::compute(&points, &depths, 1.0).unwrap();
assert!(metrics.radius_depth_correlation > 0.5,
"Expected positive correlation, got {}", metrics.radius_depth_correlation);
}
// ============================================================================
// Dual Space Index
// ============================================================================
#[test]
fn test_dual_space_index() {
let mut dual = DualSpaceIndex::new(1.0, 0.5);
for i in 0..15 {
let v = vec![0.1 * i as f32, 0.05 * i as f32];
dual.insert(v).unwrap();
}
let results = dual.search(&[0.35, 0.175], 5).unwrap();
assert_eq!(results.len(), 5);
// Results should be sorted
for i in 1..results.len() {
assert!(results[i-1].distance <= results[i].distance);
}
}
// ============================================================================
// Edge Cases
// ============================================================================
#[test]
fn test_empty_index_search() {
let hnsw = HyperbolicHnsw::default_config();
let results = hnsw.search(&[0.1, 0.2], 5).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_single_element_search() {
let mut hnsw = HyperbolicHnsw::default_config();
hnsw.insert(vec![0.3, 0.2]).unwrap();
let results = hnsw.search(&[0.1, 0.2], 5).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, 0);
}
#[test]
fn test_k_larger_than_index() {
let mut hnsw = HyperbolicHnsw::default_config();
for i in 0..3 {
hnsw.insert(vec![0.1 * i as f32, 0.1]).unwrap();
}
let results = hnsw.search(&[0.15, 0.1], 10).unwrap();
assert_eq!(results.len(), 3);
}
// ============================================================================
// Performance Characteristics
// ============================================================================
#[test]
fn test_insert_performance() {
let mut hnsw = HyperbolicHnsw::default_config();
// Should handle 100 insertions without panic
for i in 0..100 {
let v = vec![
0.05 * (i % 10) as f32,
0.05 * (i / 10) as f32,
];
hnsw.insert(v).unwrap();
}
assert_eq!(hnsw.len(), 100);
}
#[test]
fn test_search_performance() {
let mut hnsw = HyperbolicHnsw::default_config();
for i in 0..100 {
let v = vec![
0.05 * (i % 10) as f32,
0.05 * (i / 10) as f32,
];
hnsw.insert(v).unwrap();
}
// Should handle multiple searches
for _ in 0..10 {
let query = vec![0.25, 0.25];
let results = hnsw.search(&query, 10).unwrap();
assert_eq!(results.len(), 10);
}
}