Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,260 @@
//! Component Quantization for Mixed-Curvature Attention
//!
//! Different precision for each geometric component:
//! - Euclidean: 7-8 bit (needs precision)
//! - Hyperbolic tangent: 5 bit (tolerates noise)
//! - Spherical: 5 bit (only direction matters)
use serde::{Deserialize, Serialize};
/// Quantization configuration
///
/// Bit widths for each geometric component of a mixed-curvature vector.
/// NOTE(review): codes are stored as `i8` downstream, so widths are
/// expected to fit in 8 bits — confirm callers never exceed that.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Bits for Euclidean component (default 8: needs precision)
    pub euclidean_bits: u8,
    /// Bits for Hyperbolic component (default 5: tolerates noise)
    pub hyperbolic_bits: u8,
    /// Bits for Spherical component (default 5: only direction matters)
    pub spherical_bits: u8,
}
impl Default for QuantizationConfig {
    /// Module-documented defaults: 8-bit Euclidean, 5-bit hyperbolic,
    /// 5-bit spherical.
    fn default() -> Self {
        let (euclidean_bits, hyperbolic_bits, spherical_bits) = (8, 5, 5);
        Self {
            euclidean_bits,
            hyperbolic_bits,
            spherical_bits,
        }
    }
}
/// Quantized vector representation
///
/// Each geometric component is stored as `i8` codes plus a per-component
/// `f32` scale; the approximate original entry is `code as f32 * scale`.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Quantized Euclidean component
    pub euclidean: Vec<i8>,
    /// Euclidean scale factor
    pub euclidean_scale: f32,
    /// Quantized Hyperbolic component
    pub hyperbolic: Vec<i8>,
    /// Hyperbolic scale factor
    pub hyperbolic_scale: f32,
    /// Quantized Spherical component
    pub spherical: Vec<i8>,
    /// Spherical scale factor
    pub spherical_scale: f32,
}
/// Component quantizer for efficient storage and compute
///
/// Levels are precomputed from the config so quantization is a single
/// multiply-round-clamp per element.
#[derive(Debug, Clone)]
pub struct ComponentQuantizer {
    config: QuantizationConfig,
    /// Max symmetric code magnitude: (1 << (bits - 1)) - 1
    euclidean_levels: i32,
    /// Max symmetric code magnitude for the hyperbolic component
    hyperbolic_levels: i32,
    /// Max symmetric code magnitude for the spherical component
    spherical_levels: i32,
}
impl ComponentQuantizer {
    /// Create new quantizer.
    ///
    /// # Panics
    /// Panics if any configured bit width is outside `2..=8`. Codes are
    /// stored as `i8` (at most 8 bits), and fewer than 2 bits leaves zero
    /// quantization levels, which would make the scale factor infinite and
    /// dequantization produce NaN. A width of 0 would also underflow the
    /// `bits - 1` computation below.
    pub fn new(config: QuantizationConfig) -> Self {
        for bits in [
            config.euclidean_bits,
            config.hyperbolic_bits,
            config.spherical_bits,
        ] {
            assert!(
                (2..=8).contains(&bits),
                "quantization bit width must be in 2..=8, got {}",
                bits
            );
        }
        Self {
            euclidean_levels: (1 << (config.euclidean_bits - 1)) - 1,
            hyperbolic_levels: (1 << (config.hyperbolic_bits - 1)) - 1,
            spherical_levels: (1 << (config.spherical_bits - 1)) - 1,
            config,
        }
    }
    /// Quantize a component vector with symmetric absmax scaling.
    ///
    /// Returns `(codes, scale)` where `value ≈ code as f32 * scale`.
    fn quantize_component(&self, values: &[f32], levels: i32) -> (Vec<i8>, f32) {
        if values.is_empty() {
            return (vec![], 1.0);
        }
        // Find absmax for scale; floor at 1e-8 so all-zero input doesn't
        // divide by zero.
        let absmax = values
            .iter()
            .map(|v| v.abs())
            .fold(0.0f32, f32::max)
            .max(1e-8);
        let scale = absmax / levels as f32;
        let inv_scale = levels as f32 / absmax;
        // Clamp to the symmetric level range (±levels, which fits i8 since
        // `new` enforces bits <= 8) rather than the full i8 range, keeping
        // codes consistent with `scale` for every configured bit width.
        let max_level = levels as f32;
        let quantized: Vec<i8> = values
            .iter()
            .map(|v| (v * inv_scale).round().clamp(-max_level, max_level) as i8)
            .collect();
        (quantized, scale)
    }
    /// Dequantize a component: multiply each code by its scale.
    fn dequantize_component(&self, quantized: &[i8], scale: f32) -> Vec<f32> {
        quantized.iter().map(|&q| q as f32 * scale).collect()
    }
    /// Quantize full vector with component ranges.
    ///
    /// # Panics
    /// Panics if any range falls outside `vector`.
    pub fn quantize(
        &self,
        vector: &[f32],
        e_range: std::ops::Range<usize>,
        h_range: std::ops::Range<usize>,
        s_range: std::ops::Range<usize>,
    ) -> QuantizedVector {
        let (euclidean, euclidean_scale) =
            self.quantize_component(&vector[e_range], self.euclidean_levels);
        let (hyperbolic, hyperbolic_scale) =
            self.quantize_component(&vector[h_range], self.hyperbolic_levels);
        let (spherical, spherical_scale) =
            self.quantize_component(&vector[s_range], self.spherical_levels);
        QuantizedVector {
            euclidean,
            euclidean_scale,
            hyperbolic,
            hyperbolic_scale,
            spherical,
            spherical_scale,
        }
    }
    /// Compute dot product between quantized vectors (integer arithmetic),
    /// rescaling each component and mixing with `weights` = [e, h, s].
    #[inline]
    pub fn quantized_dot_product(
        &self,
        a: &QuantizedVector,
        b: &QuantizedVector,
        weights: &[f32; 3],
    ) -> f32 {
        // Integer dot products per component.
        let dot_e = Self::int_dot(&a.euclidean, &b.euclidean);
        let dot_h = Self::int_dot(&a.hyperbolic, &b.hyperbolic);
        let dot_s = Self::int_dot(&a.spherical, &b.spherical);
        // Rescale back to float space (scale_a * scale_b) and mix.
        let sim_e = dot_e as f32 * a.euclidean_scale * b.euclidean_scale;
        let sim_h = dot_h as f32 * a.hyperbolic_scale * b.hyperbolic_scale;
        let sim_s = dot_s as f32 * a.spherical_scale * b.spherical_scale;
        weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s
    }
    /// Integer dot product over the common prefix of `a` and `b`.
    /// Four independent accumulators expose ILP/auto-vectorization.
    #[inline(always)]
    fn int_dot(a: &[i8], b: &[i8]) -> i32 {
        let len = a.len().min(b.len());
        let chunks = len / 4;
        let remainder = len % 4;
        let mut sum0 = 0i32;
        let mut sum1 = 0i32;
        let mut sum2 = 0i32;
        let mut sum3 = 0i32;
        for i in 0..chunks {
            let base = i * 4;
            sum0 += a[base] as i32 * b[base] as i32;
            sum1 += a[base + 1] as i32 * b[base + 1] as i32;
            sum2 += a[base + 2] as i32 * b[base + 2] as i32;
            sum3 += a[base + 3] as i32 * b[base + 3] as i32;
        }
        let base = chunks * 4;
        for i in 0..remainder {
            sum0 += a[base + i] as i32 * b[base + i] as i32;
        }
        sum0 + sum1 + sum2 + sum3
    }
    /// Dequantize to a full vector of `total_dim`, laying components out
    /// contiguously as [euclidean | hyperbolic | spherical].
    ///
    /// # Panics
    /// Panics if the component lengths exceed `total_dim`.
    pub fn dequantize(&self, quant: &QuantizedVector, total_dim: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; total_dim];
        let e_vec = self.dequantize_component(&quant.euclidean, quant.euclidean_scale);
        let h_vec = self.dequantize_component(&quant.hyperbolic, quant.hyperbolic_scale);
        let s_vec = self.dequantize_component(&quant.spherical, quant.spherical_scale);
        let e_end = e_vec.len();
        let h_end = e_end + h_vec.len();
        result[0..e_end].copy_from_slice(&e_vec);
        result[e_end..h_end].copy_from_slice(&h_vec);
        result[h_end..h_end + s_vec.len()].copy_from_slice(&s_vec);
        result
    }
    /// Get memory savings ratio: f32 storage vs quantized storage
    /// (component bits plus the three f32 scale factors).
    pub fn compression_ratio(&self, dim: usize, e_dim: usize, h_dim: usize, s_dim: usize) -> f32 {
        let original_bits = dim as f32 * 32.0;
        let quantized_bits = e_dim as f32 * self.config.euclidean_bits as f32
            + h_dim as f32 * self.config.hyperbolic_bits as f32
            + s_dim as f32 * self.config.spherical_bits as f32
            + 3.0 * 32.0; // 3 scale factors
        original_bits / quantized_bits
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_dequantize() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());
        let vector = vec![0.5f32; 64];
        let quantized = quantizer.quantize(&vector, 0..32, 32..48, 48..64);
        assert_eq!(quantized.euclidean.len(), 32);
        assert_eq!(quantized.hyperbolic.len(), 16);
        assert_eq!(quantized.spherical.len(), 16);
        // Round-trip error stays small for every element.
        let restored = quantizer.dequantize(&quantized, 64);
        for (&before, &after) in vector.iter().zip(restored.iter()) {
            assert!((before - after).abs() < 0.1);
        }
    }

    #[test]
    fn test_quantized_dot_product() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());
        let a = vec![1.0f32; 64];
        let b = vec![1.0f32; 64];
        let qa = quantizer.quantize(&a, 0..32, 32..48, 48..64);
        let qb = quantizer.quantize(&b, 0..32, 32..48, 48..64);
        // Identical vectors must yield a positive weighted similarity.
        let dot = quantizer.quantized_dot_product(&qa, &qb, &[0.5, 0.3, 0.2]);
        assert!(dot > 0.0);
    }

    #[test]
    fn test_compression_ratio() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());
        // 8/5/5-bit components vs 32-bit floats: roughly 4-5x compression.
        let ratio = quantizer.compression_ratio(512, 256, 192, 64);
        assert!(ratio > 3.0 && ratio < 7.0);
    }
}

View File

@@ -0,0 +1,441 @@
//! Fused Mixed-Curvature Attention
//!
//! Single kernel that computes Euclidean, Hyperbolic (tangent), and Spherical
//! similarities in one pass for maximum cache efficiency.
//!
//! logit(q,k) = a * dot(q_E, k_E) + b * dot(q_H_tan, k_H_tan) + c * dot(q_S, k_S)
use super::tangent_space::{TangentSpaceConfig, TangentSpaceMapper};
use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
use serde::{Deserialize, Serialize};
/// Configuration for fused mixed-curvature attention
///
/// The three component dimensions must sum to `dim`; see [`Self::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedCurvatureConfig {
    /// Total dimension
    pub dim: usize,
    /// Euclidean component dimension
    pub euclidean_dim: usize,
    /// Hyperbolic component dimension
    pub hyperbolic_dim: usize,
    /// Spherical component dimension
    pub spherical_dim: usize,
    /// Mixing weight for Euclidean component
    pub weight_e: f32,
    /// Mixing weight for Hyperbolic component
    pub weight_h: f32,
    /// Mixing weight for Spherical component
    pub weight_s: f32,
    /// Hyperbolic curvature (negative, e.g. -1.0)
    pub hyperbolic_curvature: f32,
    /// Temperature for softmax (logits are divided by this)
    pub temperature: f32,
    /// Number of attention heads
    pub num_heads: usize,
    /// Per-head weight variation (low-rank) applied around the base weights
    pub per_head_variation: f32,
}
impl Default for FusedCurvatureConfig {
    fn default() -> Self {
        // 512 dims split 256 Euclidean / 192 hyperbolic / 64 spherical,
        // with mixing weights favoring the Euclidean part.
        Self {
            dim: 512,
            euclidean_dim: 256,
            hyperbolic_dim: 192,
            spherical_dim: 64,
            weight_e: 0.50,
            weight_h: 0.35,
            weight_s: 0.15,
            hyperbolic_curvature: -1.0,
            num_heads: 8,
            per_head_variation: 0.1,
            temperature: 1.0,
        }
    }
}
impl FusedCurvatureConfig {
    /// Validate config.
    ///
    /// # Errors
    /// Returns an error when the component dimensions do not sum to `dim`,
    /// when `num_heads` is zero (the per-head weight table would be empty,
    /// making head lookup a modulo-by-zero panic), or when `temperature`
    /// is not a positive finite number (logits are divided by it).
    pub fn validate(&self) -> Result<(), String> {
        if self.euclidean_dim + self.hyperbolic_dim + self.spherical_dim != self.dim {
            return Err("Component dimensions must sum to total dim".into());
        }
        if self.num_heads == 0 {
            return Err("num_heads must be at least 1".into());
        }
        if !(self.temperature.is_finite() && self.temperature > 0.0) {
            return Err("temperature must be positive and finite".into());
        }
        Ok(())
    }
    /// Get half-open index ranges `(euclidean, hyperbolic, spherical)`
    /// into a flat vector of length `dim`.
    pub fn component_ranges(
        &self,
    ) -> (
        std::ops::Range<usize>,
        std::ops::Range<usize>,
        std::ops::Range<usize>,
    ) {
        let e_end = self.euclidean_dim;
        let h_end = e_end + self.hyperbolic_dim;
        let s_end = h_end + self.spherical_dim;
        (0..e_end, e_end..h_end, h_end..s_end)
    }
}
/// Window cache for mixed-curvature attention
///
/// Pre-computed per-key transforms (tangent-space mapping and spherical
/// normalization) so repeated queries over the same key window skip the
/// expensive per-key work.
#[derive(Debug, Clone)]
pub struct MixedCurvatureCache {
    /// Tangent-space mapped hyperbolic components [N × h_dim]
    pub keys_hyperbolic_tangent: Vec<Vec<f32>>,
    /// Normalized spherical components [N × s_dim]
    pub keys_spherical_normalized: Vec<Vec<f32>>,
    /// Number of keys the cache was built from
    pub num_keys: usize,
}
/// Fused mixed-curvature attention
///
/// Computes attention with Euclidean, Hyperbolic, and Spherical
/// similarities in a single fused kernel.
#[derive(Debug, Clone)]
pub struct MixedCurvatureFusedAttention {
    config: FusedCurvatureConfig,
    /// Maps hyperbolic components into the tangent space at the origin
    tangent_mapper: TangentSpaceMapper,
    /// Per-head weight modifiers [num_heads × 3]; built once in `new`
    head_weights: Vec<[f32; 3]>,
}
impl MixedCurvatureFusedAttention {
    /// Create new fused attention.
    ///
    /// Builds the tangent-space mapper for the hyperbolic component and a
    /// per-head table of mixing weights derived from the base weights.
    pub fn new(config: FusedCurvatureConfig) -> Self {
        let tangent_config = TangentSpaceConfig {
            hyperbolic_dim: config.hyperbolic_dim,
            curvature: config.hyperbolic_curvature,
            learnable_origin: true,
        };
        let tangent_mapper = TangentSpaceMapper::new(tangent_config);
        // Initialize per-head weights with a small linear variation around
        // the base weights so heads are not all identical.
        let head_weights: Vec<[f32; 3]> = (0..config.num_heads)
            .map(|h| {
                let var = config.per_head_variation;
                let h_factor = h as f32 / config.num_heads as f32 - 0.5;
                [
                    config.weight_e + h_factor * var,
                    config.weight_h - h_factor * var * 0.5,
                    config.weight_s + h_factor * var * 0.5,
                ]
            })
            .collect();
        Self {
            config,
            tangent_mapper,
            head_weights,
        }
    }
    /// Create with balanced dimensions: half Euclidean, a quarter
    /// hyperbolic, and the remainder spherical.
    pub fn with_dim(dim: usize) -> Self {
        let e_dim = dim / 2;
        let h_dim = dim / 4;
        let s_dim = dim - e_dim - h_dim;
        let config = FusedCurvatureConfig {
            dim,
            euclidean_dim: e_dim,
            hyperbolic_dim: h_dim,
            spherical_dim: s_dim,
            ..Default::default()
        };
        Self::new(config)
    }
    /// Build cache for keys (pre-compute expensive operations).
    ///
    /// # Panics
    /// Panics if any key slice is shorter than `config.dim`.
    pub fn build_cache(&self, keys: &[&[f32]]) -> MixedCurvatureCache {
        let (_e_range, h_range, s_range) = self.config.component_ranges();
        // Pre-map hyperbolic components to tangent space (the expensive
        // atanh-based transform) once per key.
        let keys_hyperbolic_tangent: Vec<Vec<f32>> = keys
            .iter()
            .map(|k| {
                let h_part = &k[h_range.clone()];
                self.tangent_mapper.log_map(h_part)
            })
            .collect();
        // Pre-normalize spherical components so similarity is a plain dot.
        let keys_spherical_normalized: Vec<Vec<f32>> = keys
            .iter()
            .map(|k| {
                let s_part = &k[s_range.clone()];
                Self::normalize(s_part)
            })
            .collect();
        MixedCurvatureCache {
            keys_hyperbolic_tangent,
            keys_spherical_normalized,
            num_keys: keys.len(),
        }
    }
    /// Compute attention with cache (fast path).
    ///
    /// # Errors
    /// Returns an error when the cache is empty, or when `keys`/`values`
    /// lengths disagree with the cache — previously a mismatch could panic
    /// on indexing or silently drop trailing values in the weighted sum.
    pub fn compute_with_cache(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        cache: &MixedCurvatureCache,
        head_idx: usize,
    ) -> AttentionResult<Vec<f32>> {
        let num_keys = cache.num_keys;
        if num_keys == 0 {
            return Err(AttentionError::InvalidConfig("No keys".into()));
        }
        if keys.len() != num_keys || values.len() != num_keys {
            return Err(AttentionError::InvalidConfig(
                "keys/values length does not match cache".into(),
            ));
        }
        let (e_range, h_range, s_range) = self.config.component_ranges();
        // Wrap head_idx so any index selects some head's weights.
        let weights = &self.head_weights[head_idx % self.head_weights.len()];
        // Extract query components
        let q_e = &query[e_range.clone()];
        let q_h = &query[h_range.clone()];
        let q_s = &query[s_range.clone()];
        // Map query hyperbolic to tangent space
        let q_h_tangent = self.tangent_mapper.log_map(q_h);
        // Normalize query spherical
        let q_s_normalized = Self::normalize(q_s);
        // Compute fused logits: weighted mix of the three similarities.
        let logits: Vec<f32> = (0..num_keys)
            .map(|i| {
                let k = keys[i];
                // Euclidean similarity (dot product)
                let sim_e = Self::dot_product_simd(q_e, &k[e_range.clone()]);
                // Hyperbolic similarity (tangent space dot product)
                let sim_h = Self::dot_product_simd(&q_h_tangent, &cache.keys_hyperbolic_tangent[i]);
                // Spherical similarity (normalized dot product)
                let sim_s =
                    Self::dot_product_simd(&q_s_normalized, &cache.keys_spherical_normalized[i]);
                // Fused logit
                (weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s)
                    / self.config.temperature
            })
            .collect();
        // Softmax
        let attention_weights = Self::stable_softmax(&logits);
        // Weighted sum
        self.weighted_sum(&attention_weights, values)
    }
    /// Fused similarity computation (single pass through all components).
    /// This is the hot path - maximize SIMD utilization.
    #[inline]
    pub fn fused_similarity(
        &self,
        query: &[f32],
        key: &[f32],
        key_h_tangent: &[f32],
        key_s_normalized: &[f32],
        query_h_tangent: &[f32],
        query_s_normalized: &[f32],
        weights: &[f32; 3],
    ) -> f32 {
        let (e_range, _, _) = self.config.component_ranges();
        // Euclidean: direct dot product on original vectors
        let sim_e = Self::dot_product_simd(&query[e_range.clone()], &key[e_range.clone()]);
        // Hyperbolic: dot product in tangent space
        let sim_h = Self::dot_product_simd(query_h_tangent, key_h_tangent);
        // Spherical: dot product of normalized vectors
        let sim_s = Self::dot_product_simd(query_s_normalized, key_s_normalized);
        weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s
    }
    /// Normalize vector to unit length; near-zero-norm vectors are
    /// returned unchanged to avoid dividing by ~0.
    #[inline]
    fn normalize(v: &[f32]) -> Vec<f32> {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }
    /// SIMD-friendly dot product over the common prefix of `a` and `b`.
    /// Four independent accumulators expose ILP/auto-vectorization.
    #[inline(always)]
    fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let chunks = len / 4;
        let remainder = len % 4;
        let mut sum0 = 0.0f32;
        let mut sum1 = 0.0f32;
        let mut sum2 = 0.0f32;
        let mut sum3 = 0.0f32;
        for i in 0..chunks {
            let base = i * 4;
            sum0 += a[base] * b[base];
            sum1 += a[base + 1] * b[base + 1];
            sum2 += a[base + 2] * b[base + 2];
            sum3 += a[base + 3] * b[base + 3];
        }
        let base = chunks * 4;
        for i in 0..remainder {
            sum0 += a[base + i] * b[base + i];
        }
        sum0 + sum1 + sum2 + sum3
    }
    /// Numerically stable softmax: subtracts the max logit before
    /// exponentiating so large logits don't overflow.
    fn stable_softmax(logits: &[f32]) -> Vec<f32> {
        if logits.is_empty() {
            return vec![];
        }
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
        let sum: f32 = exp_logits.iter().sum();
        exp_logits.iter().map(|&e| e / sum).collect()
    }
    /// Weighted sum of `values` rows with the given attention weights.
    /// Output dimension follows the first value row.
    ///
    /// # Errors
    /// Returns an error for empty inputs.
    fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
        if weights.is_empty() || values.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty inputs".into()));
        }
        let dim = values[0].len();
        let mut output = vec![0.0f32; dim];
        for (weight, value) in weights.iter().zip(values.iter()) {
            for (o, &v) in output.iter_mut().zip(value.iter()) {
                *o += weight * v;
            }
        }
        Ok(output)
    }
}
impl Attention for MixedCurvatureFusedAttention {
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        // Build the per-key cache once, then run head 0.
        let cache = self.build_cache(keys);
        self.compute_with_cache(query, keys, values, &cache, 0)
    }
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        match mask {
            None => self.compute(query, keys, values),
            Some(mask) => {
                // Keep only positions whose mask entry is true; positions
                // beyond the mask's length default to kept.
                let mut kept_keys: Vec<&[f32]> = Vec::new();
                let mut kept_values: Vec<&[f32]> = Vec::new();
                for (i, (&k, &v)) in keys.iter().zip(values.iter()).enumerate() {
                    if mask.get(i).copied().unwrap_or(true) {
                        kept_keys.push(k);
                        kept_values.push(v);
                    }
                }
                self.compute(query, &kept_keys, &kept_values)
            }
        }
    }
    fn dim(&self) -> usize {
        self.config.dim
    }
    fn num_heads(&self) -> usize {
        self.config.num_heads
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 64-dim config split 32/24/8, defaults elsewhere.
    fn small_config() -> FusedCurvatureConfig {
        FusedCurvatureConfig {
            dim: 64,
            euclidean_dim: 32,
            hyperbolic_dim: 24,
            spherical_dim: 8,
            ..Default::default()
        }
    }

    #[test]
    fn test_fused_attention_config() {
        assert!(small_config().validate().is_ok());
    }

    #[test]
    fn test_fused_attention() {
        let attention = MixedCurvatureFusedAttention::new(small_config());
        let query = vec![0.5f32; 64];
        let keys: Vec<Vec<f32>> = (0..20).map(|i| vec![0.1 + i as f32 * 0.02; 64]).collect();
        let values: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32; 64]).collect();
        let key_slices: Vec<&[f32]> = keys.iter().map(|k| &k[..]).collect();
        let value_slices: Vec<&[f32]> = values.iter().map(|v| &v[..]).collect();
        let output = attention
            .compute(&query, &key_slices, &value_slices)
            .unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_cache_reuse() {
        let attention = MixedCurvatureFusedAttention::with_dim(32);
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 32]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32; 32]).collect();
        let key_slices: Vec<&[f32]> = keys.iter().map(|k| &k[..]).collect();
        let value_slices: Vec<&[f32]> = values.iter().map(|v| &v[..]).collect();
        let cache = attention.build_cache(&key_slices);
        let query = vec![0.5f32; 32];
        // The same cache serves queries across multiple heads.
        for head in 0..4 {
            let out = attention
                .compute_with_cache(&query, &key_slices, &value_slices, &cache, head)
                .unwrap();
            assert_eq!(out.len(), 32);
        }
    }
}

View File

@@ -0,0 +1,28 @@
//! Mixed Curvature Attention
//!
//! Attention in product spaces: E^e × H^h × S^s
//!
//! ## Key Optimizations
//!
//! 1. **Tangent Space Mapping**: Map hyperbolic to tangent space at origin
//! 2. **Fused Dot Kernel**: Single vectorized loop for all three similarities
//! 3. **Per-Head Mixing**: Low-rank learned weights per head
//! 4. **Quantization-Friendly**: Different precision for each component
mod component_quantizer;
mod fused_attention;
mod tangent_space;
pub use component_quantizer::{ComponentQuantizer, QuantizationConfig, QuantizedVector};
pub use fused_attention::{
FusedCurvatureConfig, MixedCurvatureCache, MixedCurvatureFusedAttention,
};
pub use tangent_space::{TangentSpaceConfig, TangentSpaceMapper};
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the re-exported quantization config is reachable from
    /// this module and carries the documented 8/5/5-bit defaults.
    /// (Replaces an `assert!(true)` that asserted nothing.)
    #[test]
    fn test_module_exists() {
        let config = QuantizationConfig::default();
        assert_eq!(config.euclidean_bits, 8);
        assert_eq!(config.hyperbolic_bits, 5);
        assert_eq!(config.spherical_bits, 5);
    }
}

View File

@@ -0,0 +1,246 @@
//! Tangent Space Mapping for Fast Hyperbolic Operations
//!
//! Instead of computing full geodesic distances in hyperbolic space,
//! we map points to the tangent space at a learned origin and use
//! dot products. This is 10-100x faster while preserving hierarchy.
use serde::{Deserialize, Serialize};
/// Configuration for tangent space mapping
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TangentSpaceConfig {
    /// Dimension of hyperbolic component
    pub hyperbolic_dim: usize,
    /// Curvature (negative, e.g., -1.0); the mapper works with the
    /// positive ball parameter `c = -curvature`
    pub curvature: f32,
    /// Whether to learn the origin
    // NOTE(review): stored but never read within this module — confirm
    // intended use by training/calling code.
    pub learnable_origin: bool,
}
impl Default for TangentSpaceConfig {
    fn default() -> Self {
        // Unit-curvature Poincaré ball with a 32-dim hyperbolic part.
        let hyperbolic_dim = 32;
        let curvature = -1.0;
        let learnable_origin = true;
        Self {
            hyperbolic_dim,
            curvature,
            learnable_origin,
        }
    }
}
/// Tangent space mapper for hyperbolic geometry
///
/// Maps points from Poincaré ball to tangent space at origin,
/// enabling fast dot-product similarity instead of geodesic distance.
#[derive(Debug, Clone)]
pub struct TangentSpaceMapper {
    config: TangentSpaceConfig,
    /// Origin point in Poincaré ball (all zeros until `set_origin`)
    origin: Vec<f32>,
    /// Conformal factor at origin, λ_o = 2 / (1 - c‖o‖²); cached so
    /// `log_map` does not recompute it per call
    lambda_origin: f32,
}
impl TangentSpaceMapper {
/// Create new mapper with config
pub fn new(config: TangentSpaceConfig) -> Self {
let origin = vec![0.0f32; config.hyperbolic_dim];
let c = -config.curvature;
let origin_norm_sq: f32 = origin.iter().map(|x| x * x).sum();
let lambda_origin = 2.0 / (1.0 - c * origin_norm_sq).max(1e-8);
Self {
config,
origin,
lambda_origin,
}
}
/// Set custom origin (for learned origins)
pub fn set_origin(&mut self, origin: Vec<f32>) {
let c = -self.config.curvature;
let origin_norm_sq: f32 = origin.iter().map(|x| x * x).sum();
self.lambda_origin = 2.0 / (1.0 - c * origin_norm_sq).max(1e-8);
self.origin = origin;
}
/// Map point from Poincaré ball to tangent space at origin
///
/// log_o(x) = (2 / λ_o) * arctanh(√c ||o ⊕ x||) * (o ⊕ x) / ||o ⊕ x||
///
/// For origin at 0, this simplifies to:
/// log_0(x) = 2 * arctanh(√c ||x||) * x / (√c ||x||)
#[inline]
pub fn log_map(&self, point: &[f32]) -> Vec<f32> {
let c = -self.config.curvature;
let sqrt_c = c.sqrt();
// For origin at 0, Möbius addition o ⊕ x = x
if self.origin.iter().all(|&x| x.abs() < 1e-8) {
return self.log_map_at_origin(point, sqrt_c);
}
// General case: compute -origin ⊕ point
let neg_origin: Vec<f32> = self.origin.iter().map(|x| -x).collect();
let diff = self.mobius_add(&neg_origin, point, c);
let diff_norm: f32 = diff.iter().map(|x| x * x).sum::<f32>().sqrt();
if diff_norm < 1e-8 {
return vec![0.0f32; point.len()];
}
let scale =
(2.0 / self.lambda_origin) * (sqrt_c * diff_norm).atanh() / (sqrt_c * diff_norm);
diff.iter().map(|&d| scale * d).collect()
}
/// Fast log map at origin (most common case)
#[inline]
fn log_map_at_origin(&self, point: &[f32], sqrt_c: f32) -> Vec<f32> {
let norm: f32 = point.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm < 1e-8 {
return vec![0.0f32; point.len()];
}
// Clamp to avoid infinity
let arg = (sqrt_c * norm).min(0.99);
let scale = 2.0 * arg.atanh() / (sqrt_c * norm);
point.iter().map(|&p| scale * p).collect()
}
/// Möbius addition in Poincaré ball
fn mobius_add(&self, x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
let x_norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
let y_norm_sq: f32 = y.iter().map(|yi| yi * yi).sum();
let xy_dot: f32 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
let num_coef = 1.0 + 2.0 * c * xy_dot + c * y_norm_sq;
let denom = 1.0 + 2.0 * c * xy_dot + c * c * x_norm_sq * y_norm_sq;
if denom.abs() < 1e-8 {
return x.to_vec();
}
let y_coef = 1.0 - c * x_norm_sq;
x.iter()
.zip(y.iter())
.map(|(&xi, &yi)| (num_coef * xi + y_coef * yi) / denom)
.collect()
}
/// Compute tangent space similarity (dot product in tangent space)
///
/// This approximates hyperbolic distance but is much faster.
#[inline]
pub fn tangent_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
// Map both to tangent space
let ta = self.log_map(a);
let tb = self.log_map(b);
// Dot product
ta.iter().zip(tb.iter()).map(|(&ai, &bi)| ai * bi).sum()
}
/// Batch map points to tangent space (cache for window)
pub fn batch_log_map(&self, points: &[&[f32]]) -> Vec<Vec<f32>> {
points.iter().map(|p| self.log_map(p)).collect()
}
/// Compute similarities in tangent space (all pairwise with query)
pub fn batch_tangent_similarity(
&self,
query_tangent: &[f32],
keys_tangent: &[&[f32]],
) -> Vec<f32> {
keys_tangent
.iter()
.map(|k| Self::dot_product_simd(query_tangent, k))
.collect()
}
/// SIMD-friendly dot product
#[inline(always)]
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
let len = a.len().min(b.len());
let chunks = len / 4;
let remainder = len % 4;
let mut sum0 = 0.0f32;
let mut sum1 = 0.0f32;
let mut sum2 = 0.0f32;
let mut sum3 = 0.0f32;
for i in 0..chunks {
let base = i * 4;
sum0 += a[base] * b[base];
sum1 += a[base + 1] * b[base + 1];
sum2 += a[base + 2] * b[base + 2];
sum3 += a[base + 3] * b[base + 3];
}
let base = chunks * 4;
for i in 0..remainder {
sum0 += a[base + i] * b[base + i];
}
sum0 + sum1 + sum2 + sum3
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper over a unit-curvature ball with a fixed origin.
    fn unit_mapper(dim: usize) -> TangentSpaceMapper {
        TangentSpaceMapper::new(TangentSpaceConfig {
            hyperbolic_dim: dim,
            curvature: -1.0,
            learnable_origin: false,
        })
    }

    #[test]
    fn test_log_map_at_origin() {
        let mapper = unit_mapper(4);
        // The origin itself maps to the zero tangent vector.
        let at_origin = mapper.log_map(&[0.0f32; 4]);
        assert!(at_origin.iter().all(|&x| x.abs() < 1e-6));
        // A non-zero point keeps its dimensionality.
        let tangent = mapper.log_map(&[0.1, 0.2, 0.0, 0.0]);
        assert_eq!(tangent.len(), 4);
    }

    #[test]
    fn test_tangent_similarity() {
        let mapper = unit_mapper(4);
        let p = [0.1, 0.1, 0.0, 0.0];
        // A point compared with itself yields a positive similarity.
        assert!(mapper.tangent_similarity(&p, &p) > 0.0);
    }

    #[test]
    fn test_batch_operations() {
        let mapper = TangentSpaceMapper::new(TangentSpaceConfig::default());
        let points: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.05; 32]).collect();
        let point_slices: Vec<&[f32]> = points.iter().map(|p| &p[..]).collect();
        assert_eq!(mapper.batch_log_map(&point_slices).len(), 10);
    }
}