//! Tangent Space Operations for HNSW Pruning Optimization //! //! This module implements the key optimization for hyperbolic HNSW: //! - Precompute tangent space coordinates at shard centroids //! - Use cheap Euclidean distance in tangent space for pruning //! - Only compute exact Poincaré distance for final ranking //! //! # HNSW Speed Trick //! //! The core insight is that for points near a centroid c: //! 1. Map points to tangent space: u = log_c(x) //! 2. Euclidean distance ||u_q - u_p|| approximates hyperbolic distance //! 3. Prune candidates using fast Euclidean comparisons //! 4. Rank final top-N candidates with exact Poincaré distance use crate::error::{HyperbolicError, HyperbolicResult}; use crate::poincare::{ conformal_factor, frechet_mean, log_map, norm, norm_squared, poincare_distance, project_to_ball, PoincareConfig, EPS, }; use serde::{Deserialize, Serialize}; /// Tangent space cache for a shard /// /// Stores precomputed tangent coordinates for fast pruning. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TangentCache { /// Centroid point (base of tangent space) pub centroid: Vec, /// Precomputed tangent coordinates for all points in shard pub tangent_coords: Vec>, /// Original point indices pub point_indices: Vec, /// Curvature parameter pub curvature: f32, /// Cached conformal factor at centroid conformal: f32, } impl TangentCache { /// Create a new tangent cache for a shard /// /// # Arguments /// * `points` - Points in the shard (Poincaré ball coordinates) /// * `indices` - Original indices of the points /// * `curvature` - Curvature parameter pub fn new(points: &[Vec], indices: &[usize], curvature: f32) -> HyperbolicResult { if points.is_empty() { return Err(HyperbolicError::EmptyCollection); } let config = PoincareConfig::with_curvature(curvature)?; // Compute centroid as Fréchet mean let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect(); let centroid = frechet_mean(&point_refs, None, &config)?; // Precompute tangent coordinates let tangent_coords: Vec> = points .iter() .map(|p| log_map(p, ¢roid, curvature)) .collect(); let conformal = conformal_factor(¢roid, curvature); Ok(Self { centroid, tangent_coords, point_indices: indices.to_vec(), curvature, conformal, }) } /// Create from centroid directly (for incremental updates) pub fn from_centroid( centroid: Vec, points: &[Vec], indices: &[usize], curvature: f32, ) -> HyperbolicResult { let tangent_coords: Vec> = points .iter() .map(|p| log_map(p, ¢roid, curvature)) .collect(); let conformal = conformal_factor(¢roid, curvature); Ok(Self { centroid, tangent_coords, point_indices: indices.to_vec(), curvature, conformal, }) } /// Get tangent coordinates for a query point pub fn query_tangent(&self, query: &[f32]) -> Vec { log_map(query, &self.centroid, self.curvature) } /// Fast Euclidean distance in tangent space (for pruning) #[inline] pub fn tangent_distance_squared(&self, query_tangent: &[f32], idx: usize) -> f32 { if idx >= self.tangent_coords.len() { return f32::MAX; } let p = &self.tangent_coords[idx]; query_tangent .iter() .zip(p.iter()) .map(|(&a, &b)| (a - b) * (a - b)) .sum() } /// Exact Poincaré distance for final ranking pub fn exact_distance(&self, query: &[f32], idx: usize, points: &[Vec]) -> f32 { if idx >= points.len() { return f32::MAX; } poincare_distance(query, &points[idx], self.curvature) } /// Add a new point to the cache (for incremental updates) pub fn add_point(&mut self, point: &[f32], index: usize) { let tangent = log_map(point, &self.centroid, self.curvature); self.tangent_coords.push(tangent); self.point_indices.push(index); } /// Update centroid and recompute all tangent coordinates pub fn recompute_centroid(&mut self, points: &[Vec]) -> HyperbolicResult<()> { if points.is_empty() { return Err(HyperbolicError::EmptyCollection); } let config = PoincareConfig::with_curvature(self.curvature)?; let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect(); self.centroid = frechet_mean(&point_refs, None, &config)?; self.tangent_coords = points .iter() .map(|p| log_map(p, &self.centroid, self.curvature)) .collect(); self.conformal = conformal_factor(&self.centroid, self.curvature); Ok(()) } /// Get number of points in cache pub fn len(&self) -> usize { self.tangent_coords.len() } /// Check if cache is empty pub fn is_empty(&self) -> bool { self.tangent_coords.is_empty() } /// Get the dimension of the tangent space pub fn dim(&self) -> usize { self.centroid.len() } } /// Tangent space pruning result #[derive(Debug, Clone)] pub struct PrunedCandidate { /// Original index pub index: usize, /// Tangent space distance (for initial ranking) pub tangent_dist: f32, /// Exact Poincaré distance (computed lazily) pub exact_dist: Option, } /// Tangent space pruner for HNSW neighbor selection /// /// Implements the two-phase search: /// 1. Fast pruning using Euclidean distance in tangent space /// 2. Exact ranking using Poincaré distance for top candidates pub struct TangentPruner { /// Tangent caches for each shard caches: Vec, /// Number of candidates to consider in exact phase top_n: usize, /// Pruning factor (how many candidates to keep from tangent phase) prune_factor: usize, } impl TangentPruner { /// Create a new pruner /// /// # Arguments /// * `top_n` - Number of final results /// * `prune_factor` - Multiplier for candidates to consider (e.g., 10 means consider 10*top_n) pub fn new(top_n: usize, prune_factor: usize) -> Self { Self { caches: Vec::new(), top_n, prune_factor, } } /// Add a shard cache pub fn add_cache(&mut self, cache: TangentCache) { self.caches.push(cache); } /// Get shard caches pub fn caches(&self) -> &[TangentCache] { &self.caches } /// Get mutable shard caches pub fn caches_mut(&mut self) -> &mut [TangentCache] { &mut self.caches } /// Search across all shards with tangent pruning /// /// Returns top_n candidates sorted by exact Poincaré distance. pub fn search( &self, query: &[f32], points: &[Vec], curvature: f32, ) -> Vec { let num_prune = self.top_n * self.prune_factor; let mut candidates: Vec = Vec::with_capacity(num_prune); // Phase 1: Tangent space pruning across all shards for cache in &self.caches { let query_tangent = cache.query_tangent(query); for (local_idx, &global_idx) in cache.point_indices.iter().enumerate() { let tangent_dist = cache.tangent_distance_squared(&query_tangent, local_idx); candidates.push(PrunedCandidate { index: global_idx, tangent_dist, exact_dist: None, }); } } // Sort by tangent distance and keep top prune_factor * top_n candidates.sort_by(|a, b| a.tangent_dist.partial_cmp(&b.tangent_dist).unwrap()); candidates.truncate(num_prune); // Phase 2: Exact Poincaré distance for finalists for candidate in &mut candidates { if candidate.index < points.len() { candidate.exact_dist = Some(poincare_distance(query, &points[candidate.index], curvature)); } } // Sort by exact distance and return top_n candidates.sort_by(|a, b| { a.exact_dist .unwrap_or(f32::MAX) .partial_cmp(&b.exact_dist.unwrap_or(f32::MAX)) .unwrap() }); candidates.truncate(self.top_n); candidates } } /// Compute micro tangent update for incremental operations /// /// For small updates (reflex loop), compute tangent-space delta /// that keeps the point inside the ball. pub fn tangent_micro_update( point: &[f32], delta: &[f32], centroid: &[f32], curvature: f32, max_step: f32, ) -> Vec { // Get current tangent coordinates let tangent = log_map(point, centroid, curvature); // Apply bounded delta in tangent space let delta_norm = norm(delta); let scale = if delta_norm > max_step { max_step / delta_norm } else { 1.0 }; let new_tangent: Vec = tangent .iter() .zip(delta.iter()) .map(|(&t, &d)| t + scale * d) .collect(); // Map back to ball and project let new_point = crate::poincare::exp_map(&new_tangent, centroid, curvature); project_to_ball(&new_point, curvature, EPS) } #[cfg(test)] mod tests { use super::*; #[test] fn test_tangent_cache_creation() { let points = vec![ vec![0.1, 0.2, 0.1], vec![-0.1, 0.15, 0.05], vec![0.2, -0.1, 0.1], ]; let indices: Vec = (0..3).collect(); let cache = TangentCache::new(&points, &indices, 1.0).unwrap(); assert_eq!(cache.len(), 3); assert_eq!(cache.dim(), 3); } #[test] fn test_tangent_pruning() { let points = vec![ vec![0.1, 0.2], vec![-0.1, 0.15], vec![0.2, -0.1], vec![0.05, 0.05], ]; let indices: Vec = (0..4).collect(); let cache = TangentCache::new(&points, &indices, 1.0).unwrap(); let mut pruner = TangentPruner::new(2, 2); pruner.add_cache(cache); let query = vec![0.08, 0.1]; let results = pruner.search(&query, &points, 1.0); assert_eq!(results.len(), 2); // Results should be sorted by exact distance assert!(results[0].exact_dist.unwrap() <= results[1].exact_dist.unwrap()); } }