Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
260
vendor/ruvector/crates/ruvector-attention/src/curvature/component_quantizer.rs
vendored
Normal file
260
vendor/ruvector/crates/ruvector-attention/src/curvature/component_quantizer.rs
vendored
Normal file
@@ -0,0 +1,260 @@
|
||||
//! Component Quantization for Mixed-Curvature Attention
|
||||
//!
|
||||
//! Different precision for each geometric component:
|
||||
//! - Euclidean: 7-8 bit (needs precision)
|
||||
//! - Hyperbolic tangent: 5 bit (tolerates noise)
|
||||
//! - Spherical: 5 bit (only direction matters)
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Quantization configuration
///
/// Per-component bit widths for symmetric (absmax) quantization. Codes are
/// stored as `i8`, so widths above 8 are not representable; see
/// `ComponentQuantizer::new` for how widths map to level counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Bits for Euclidean component (highest precision: raw dot products
    /// are the most rounding-sensitive).
    pub euclidean_bits: u8,
    /// Bits for Hyperbolic component (tangent-space values tolerate noise).
    pub hyperbolic_bits: u8,
    /// Bits for Spherical component (only direction matters).
    pub spherical_bits: u8,
}
|
||||
|
||||
impl Default for QuantizationConfig {
    /// 8-bit Euclidean, 5-bit hyperbolic/spherical — the split recommended
    /// in the module-level docs.
    fn default() -> Self {
        Self {
            euclidean_bits: 8,
            hyperbolic_bits: 5,
            spherical_bits: 5,
        }
    }
}
|
||||
|
||||
/// Quantized vector representation
///
/// Each geometric component is stored as signed codes plus one `f32` scale;
/// a value is reconstructed as `code as f32 * scale` (see
/// `ComponentQuantizer::dequantize`).
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Quantized Euclidean component
    pub euclidean: Vec<i8>,
    /// Euclidean scale factor
    pub euclidean_scale: f32,
    /// Quantized Hyperbolic component
    pub hyperbolic: Vec<i8>,
    /// Hyperbolic scale factor
    pub hyperbolic_scale: f32,
    /// Quantized Spherical component
    pub spherical: Vec<i8>,
    /// Spherical scale factor
    pub spherical_scale: f32,
}
|
||||
|
||||
/// Component quantizer for efficient storage and compute
///
/// Holds per-component level counts derived from the config's bit widths:
/// `levels = 2^(bits-1) - 1`, i.e. the largest positive code of a symmetric
/// signed range.
#[derive(Debug, Clone)]
pub struct ComponentQuantizer {
    config: QuantizationConfig,
    /// Max positive code for the Euclidean component.
    euclidean_levels: i32,
    /// Max positive code for the hyperbolic component.
    hyperbolic_levels: i32,
    /// Max positive code for the spherical component.
    spherical_levels: i32,
}
|
||||
|
||||
impl ComponentQuantizer {
|
||||
/// Create new quantizer
|
||||
pub fn new(config: QuantizationConfig) -> Self {
|
||||
Self {
|
||||
euclidean_levels: (1 << (config.euclidean_bits - 1)) - 1,
|
||||
hyperbolic_levels: (1 << (config.hyperbolic_bits - 1)) - 1,
|
||||
spherical_levels: (1 << (config.spherical_bits - 1)) - 1,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantize a component vector
|
||||
fn quantize_component(&self, values: &[f32], levels: i32) -> (Vec<i8>, f32) {
|
||||
if values.is_empty() {
|
||||
return (vec![], 1.0);
|
||||
}
|
||||
|
||||
// Find absmax for scale
|
||||
let absmax = values
|
||||
.iter()
|
||||
.map(|v| v.abs())
|
||||
.fold(0.0f32, f32::max)
|
||||
.max(1e-8);
|
||||
|
||||
let scale = absmax / levels as f32;
|
||||
let inv_scale = levels as f32 / absmax;
|
||||
|
||||
let quantized: Vec<i8> = values
|
||||
.iter()
|
||||
.map(|v| (v * inv_scale).round().clamp(-127.0, 127.0) as i8)
|
||||
.collect();
|
||||
|
||||
(quantized, scale)
|
||||
}
|
||||
|
||||
/// Dequantize a component
|
||||
fn dequantize_component(&self, quantized: &[i8], scale: f32) -> Vec<f32> {
|
||||
quantized.iter().map(|&q| q as f32 * scale).collect()
|
||||
}
|
||||
|
||||
/// Quantize full vector with component ranges
|
||||
pub fn quantize(
|
||||
&self,
|
||||
vector: &[f32],
|
||||
e_range: std::ops::Range<usize>,
|
||||
h_range: std::ops::Range<usize>,
|
||||
s_range: std::ops::Range<usize>,
|
||||
) -> QuantizedVector {
|
||||
let (euclidean, euclidean_scale) =
|
||||
self.quantize_component(&vector[e_range], self.euclidean_levels);
|
||||
|
||||
let (hyperbolic, hyperbolic_scale) =
|
||||
self.quantize_component(&vector[h_range], self.hyperbolic_levels);
|
||||
|
||||
let (spherical, spherical_scale) =
|
||||
self.quantize_component(&vector[s_range], self.spherical_levels);
|
||||
|
||||
QuantizedVector {
|
||||
euclidean,
|
||||
euclidean_scale,
|
||||
hyperbolic,
|
||||
hyperbolic_scale,
|
||||
spherical,
|
||||
spherical_scale,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute dot product between quantized vectors (integer arithmetic)
|
||||
#[inline]
|
||||
pub fn quantized_dot_product(
|
||||
&self,
|
||||
a: &QuantizedVector,
|
||||
b: &QuantizedVector,
|
||||
weights: &[f32; 3],
|
||||
) -> f32 {
|
||||
// Integer dot products
|
||||
let dot_e = Self::int_dot(&a.euclidean, &b.euclidean);
|
||||
let dot_h = Self::int_dot(&a.hyperbolic, &b.hyperbolic);
|
||||
let dot_s = Self::int_dot(&a.spherical, &b.spherical);
|
||||
|
||||
// Scale and weight
|
||||
let sim_e = dot_e as f32 * a.euclidean_scale * b.euclidean_scale;
|
||||
let sim_h = dot_h as f32 * a.hyperbolic_scale * b.hyperbolic_scale;
|
||||
let sim_s = dot_s as f32 * a.spherical_scale * b.spherical_scale;
|
||||
|
||||
weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s
|
||||
}
|
||||
|
||||
/// Integer dot product (SIMD-friendly)
|
||||
#[inline(always)]
|
||||
fn int_dot(a: &[i8], b: &[i8]) -> i32 {
|
||||
let len = a.len().min(b.len());
|
||||
let chunks = len / 4;
|
||||
let remainder = len % 4;
|
||||
|
||||
let mut sum0 = 0i32;
|
||||
let mut sum1 = 0i32;
|
||||
let mut sum2 = 0i32;
|
||||
let mut sum3 = 0i32;
|
||||
|
||||
for i in 0..chunks {
|
||||
let base = i * 4;
|
||||
sum0 += a[base] as i32 * b[base] as i32;
|
||||
sum1 += a[base + 1] as i32 * b[base + 1] as i32;
|
||||
sum2 += a[base + 2] as i32 * b[base + 2] as i32;
|
||||
sum3 += a[base + 3] as i32 * b[base + 3] as i32;
|
||||
}
|
||||
|
||||
let base = chunks * 4;
|
||||
for i in 0..remainder {
|
||||
sum0 += a[base + i] as i32 * b[base + i] as i32;
|
||||
}
|
||||
|
||||
sum0 + sum1 + sum2 + sum3
|
||||
}
|
||||
|
||||
/// Dequantize to full vector
|
||||
pub fn dequantize(&self, quant: &QuantizedVector, total_dim: usize) -> Vec<f32> {
|
||||
let mut result = vec![0.0f32; total_dim];
|
||||
|
||||
let e_vec = self.dequantize_component(&quant.euclidean, quant.euclidean_scale);
|
||||
let h_vec = self.dequantize_component(&quant.hyperbolic, quant.hyperbolic_scale);
|
||||
let s_vec = self.dequantize_component(&quant.spherical, quant.spherical_scale);
|
||||
|
||||
let e_end = e_vec.len();
|
||||
let h_end = e_end + h_vec.len();
|
||||
|
||||
result[0..e_end].copy_from_slice(&e_vec);
|
||||
result[e_end..h_end].copy_from_slice(&h_vec);
|
||||
result[h_end..h_end + s_vec.len()].copy_from_slice(&s_vec);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get memory savings ratio
|
||||
pub fn compression_ratio(&self, dim: usize, e_dim: usize, h_dim: usize, s_dim: usize) -> f32 {
|
||||
let original_bits = dim as f32 * 32.0;
|
||||
let quantized_bits = e_dim as f32 * self.config.euclidean_bits as f32
|
||||
+ h_dim as f32 * self.config.hyperbolic_bits as f32
|
||||
+ s_dim as f32 * self.config.spherical_bits as f32
|
||||
+ 3.0 * 32.0; // 3 scale factors
|
||||
|
||||
original_bits / quantized_bits
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip: quantize then dequantize should stay close to the input.
    #[test]
    fn test_quantize_dequantize() {
        let q = ComponentQuantizer::new(QuantizationConfig::default());

        let input = vec![0.5f32; 64];
        let (er, hr, sr) = (0..32, 32..48, 48..64);

        let packed = q.quantize(&input, er.clone(), hr.clone(), sr.clone());

        assert_eq!(packed.euclidean.len(), 32);
        assert_eq!(packed.hyperbolic.len(), 16);
        assert_eq!(packed.spherical.len(), 16);

        let restored = q.dequantize(&packed, 64);
        for (&before, &after) in input.iter().zip(restored.iter()) {
            assert!((before - after).abs() < 0.1);
        }
    }

    /// Identical vectors must have a positive weighted similarity.
    #[test]
    fn test_quantized_dot_product() {
        let q = ComponentQuantizer::new(QuantizationConfig::default());

        let lhs = vec![1.0f32; 64];
        let rhs = vec![1.0f32; 64];
        let (er, hr, sr) = (0..32, 32..48, 48..64);

        let qa = q.quantize(&lhs, er.clone(), hr.clone(), sr.clone());
        let qb = q.quantize(&rhs, er, hr, sr);

        let dot = q.quantized_dot_product(&qa, &qb, &[0.5, 0.3, 0.2]);
        assert!(dot > 0.0);
    }

    /// 8/5/5-bit components vs 32-bit floats should land around 4-5x.
    #[test]
    fn test_compression_ratio() {
        let q = ComponentQuantizer::new(QuantizationConfig::default());

        let ratio = q.compression_ratio(512, 256, 192, 64);

        assert!(ratio > 3.0);
        assert!(ratio < 7.0);
    }
}
|
||||
441
vendor/ruvector/crates/ruvector-attention/src/curvature/fused_attention.rs
vendored
Normal file
441
vendor/ruvector/crates/ruvector-attention/src/curvature/fused_attention.rs
vendored
Normal file
@@ -0,0 +1,441 @@
|
||||
//! Fused Mixed-Curvature Attention
|
||||
//!
|
||||
//! Single kernel that computes Euclidean, Hyperbolic (tangent), and Spherical
|
||||
//! similarities in one pass for maximum cache efficiency.
|
||||
//!
|
||||
//! logit(q,k) = a * dot(q_E, k_E) + b * dot(q_H_tan, k_H_tan) + c * dot(q_S, k_S)
|
||||
|
||||
use super::tangent_space::{TangentSpaceConfig, TangentSpaceMapper};
|
||||
use crate::error::{AttentionError, AttentionResult};
|
||||
use crate::traits::Attention;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for fused mixed-curvature attention
///
/// The three component dimensions must sum to `dim` (enforced by
/// `validate`); components are laid out [euclidean | hyperbolic | spherical]
/// within each flat vector (see `component_ranges`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedCurvatureConfig {
    /// Total dimension
    pub dim: usize,
    /// Euclidean component dimension
    pub euclidean_dim: usize,
    /// Hyperbolic component dimension
    pub hyperbolic_dim: usize,
    /// Spherical component dimension
    pub spherical_dim: usize,
    /// Mixing weight for Euclidean component
    pub weight_e: f32,
    /// Mixing weight for Hyperbolic component
    pub weight_h: f32,
    /// Mixing weight for Spherical component
    pub weight_s: f32,
    /// Hyperbolic curvature (negative, e.g. -1.0)
    pub hyperbolic_curvature: f32,
    /// Temperature for softmax; logits are divided by it, so a zero value
    /// would produce infinite logits.
    pub temperature: f32,
    /// Number of attention heads
    pub num_heads: usize,
    /// Per-head weight variation (low-rank), applied as an offset around the
    /// base mixing weights.
    pub per_head_variation: f32,
}
|
||||
|
||||
impl Default for FusedCurvatureConfig {
    /// 512-dim default split 256/192/64 with a 0.5/0.35/0.15 mix,
    /// unit negative curvature, unit temperature, and 8 heads.
    fn default() -> Self {
        Self {
            dim: 512,
            euclidean_dim: 256,
            hyperbolic_dim: 192,
            spherical_dim: 64,
            weight_e: 0.5,
            weight_h: 0.35,
            weight_s: 0.15,
            hyperbolic_curvature: -1.0,
            temperature: 1.0,
            num_heads: 8,
            per_head_variation: 0.1,
        }
    }
}
|
||||
|
||||
impl FusedCurvatureConfig {
|
||||
/// Validate config
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.euclidean_dim + self.hyperbolic_dim + self.spherical_dim != self.dim {
|
||||
return Err("Component dimensions must sum to total dim".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get component ranges
|
||||
pub fn component_ranges(
|
||||
&self,
|
||||
) -> (
|
||||
std::ops::Range<usize>,
|
||||
std::ops::Range<usize>,
|
||||
std::ops::Range<usize>,
|
||||
) {
|
||||
let e_end = self.euclidean_dim;
|
||||
let h_end = e_end + self.hyperbolic_dim;
|
||||
let s_end = h_end + self.spherical_dim;
|
||||
|
||||
(0..e_end, e_end..h_end, h_end..s_end)
|
||||
}
|
||||
}
|
||||
|
||||
/// Window cache for mixed-curvature attention
///
/// Holds the per-key pre-computations (tangent-space mapping and spherical
/// normalization) so repeated queries over the same key window only pay for
/// dot products. Built by `MixedCurvatureFusedAttention::build_cache`.
#[derive(Debug, Clone)]
pub struct MixedCurvatureCache {
    /// Tangent-space mapped hyperbolic components [N × h_dim]
    pub keys_hyperbolic_tangent: Vec<Vec<f32>>,
    /// Normalized spherical components [N × s_dim]
    pub keys_spherical_normalized: Vec<Vec<f32>>,
    /// Number of keys the cache was built from
    pub num_keys: usize,
}
|
||||
|
||||
/// Fused mixed-curvature attention
///
/// Computes attention with Euclidean, Hyperbolic, and Spherical
/// similarities in a single fused kernel.
#[derive(Debug, Clone)]
pub struct MixedCurvatureFusedAttention {
    config: FusedCurvatureConfig,
    /// Maps hyperbolic slices into the tangent space at the origin.
    tangent_mapper: TangentSpaceMapper,
    /// Per-head weight modifiers [num_heads × 3], derived once in `new`.
    head_weights: Vec<[f32; 3]>,
}
|
||||
|
||||
impl MixedCurvatureFusedAttention {
|
||||
/// Create new fused attention
|
||||
pub fn new(config: FusedCurvatureConfig) -> Self {
|
||||
let tangent_config = TangentSpaceConfig {
|
||||
hyperbolic_dim: config.hyperbolic_dim,
|
||||
curvature: config.hyperbolic_curvature,
|
||||
learnable_origin: true,
|
||||
};
|
||||
let tangent_mapper = TangentSpaceMapper::new(tangent_config);
|
||||
|
||||
// Initialize per-head weights with small variation
|
||||
let head_weights: Vec<[f32; 3]> = (0..config.num_heads)
|
||||
.map(|h| {
|
||||
let var = config.per_head_variation;
|
||||
let h_factor = h as f32 / config.num_heads as f32 - 0.5;
|
||||
[
|
||||
config.weight_e + h_factor * var,
|
||||
config.weight_h - h_factor * var * 0.5,
|
||||
config.weight_s + h_factor * var * 0.5,
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
config,
|
||||
tangent_mapper,
|
||||
head_weights,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with balanced weights
|
||||
pub fn with_dim(dim: usize) -> Self {
|
||||
let e_dim = dim / 2;
|
||||
let h_dim = dim / 4;
|
||||
let s_dim = dim - e_dim - h_dim;
|
||||
|
||||
let config = FusedCurvatureConfig {
|
||||
dim,
|
||||
euclidean_dim: e_dim,
|
||||
hyperbolic_dim: h_dim,
|
||||
spherical_dim: s_dim,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Build cache for keys (pre-compute expensive operations)
|
||||
pub fn build_cache(&self, keys: &[&[f32]]) -> MixedCurvatureCache {
|
||||
let (_e_range, h_range, s_range) = self.config.component_ranges();
|
||||
|
||||
// Pre-map hyperbolic components to tangent space
|
||||
let keys_hyperbolic_tangent: Vec<Vec<f32>> = keys
|
||||
.iter()
|
||||
.map(|k| {
|
||||
let h_part = &k[h_range.clone()];
|
||||
self.tangent_mapper.log_map(h_part)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Pre-normalize spherical components
|
||||
let keys_spherical_normalized: Vec<Vec<f32>> = keys
|
||||
.iter()
|
||||
.map(|k| {
|
||||
let s_part = &k[s_range.clone()];
|
||||
Self::normalize(s_part)
|
||||
})
|
||||
.collect();
|
||||
|
||||
MixedCurvatureCache {
|
||||
keys_hyperbolic_tangent,
|
||||
keys_spherical_normalized,
|
||||
num_keys: keys.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention with cache (fast path)
|
||||
pub fn compute_with_cache(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
cache: &MixedCurvatureCache,
|
||||
head_idx: usize,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
let num_keys = cache.num_keys;
|
||||
if num_keys == 0 {
|
||||
return Err(AttentionError::InvalidConfig("No keys".into()));
|
||||
}
|
||||
|
||||
let (e_range, h_range, s_range) = self.config.component_ranges();
|
||||
let weights = &self.head_weights[head_idx % self.head_weights.len()];
|
||||
|
||||
// Extract query components
|
||||
let q_e = &query[e_range.clone()];
|
||||
let q_h = &query[h_range.clone()];
|
||||
let q_s = &query[s_range.clone()];
|
||||
|
||||
// Map query hyperbolic to tangent space
|
||||
let q_h_tangent = self.tangent_mapper.log_map(q_h);
|
||||
|
||||
// Normalize query spherical
|
||||
let q_s_normalized = Self::normalize(q_s);
|
||||
|
||||
// Compute fused logits
|
||||
let logits: Vec<f32> = (0..num_keys)
|
||||
.map(|i| {
|
||||
let k = keys[i];
|
||||
|
||||
// Euclidean similarity (dot product)
|
||||
let sim_e = Self::dot_product_simd(&q_e, &k[e_range.clone()]);
|
||||
|
||||
// Hyperbolic similarity (tangent space dot product)
|
||||
let sim_h = Self::dot_product_simd(&q_h_tangent, &cache.keys_hyperbolic_tangent[i]);
|
||||
|
||||
// Spherical similarity (normalized dot product)
|
||||
let sim_s =
|
||||
Self::dot_product_simd(&q_s_normalized, &cache.keys_spherical_normalized[i]);
|
||||
|
||||
// Fused logit
|
||||
(weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s)
|
||||
/ self.config.temperature
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Softmax
|
||||
let attention_weights = Self::stable_softmax(&logits);
|
||||
|
||||
// Weighted sum
|
||||
self.weighted_sum(&attention_weights, values)
|
||||
}
|
||||
|
||||
/// Fused similarity computation (single pass through all components)
|
||||
/// This is the hot path - maximize SIMD utilization
|
||||
#[inline]
|
||||
pub fn fused_similarity(
|
||||
&self,
|
||||
query: &[f32],
|
||||
key: &[f32],
|
||||
key_h_tangent: &[f32],
|
||||
key_s_normalized: &[f32],
|
||||
query_h_tangent: &[f32],
|
||||
query_s_normalized: &[f32],
|
||||
weights: &[f32; 3],
|
||||
) -> f32 {
|
||||
let (e_range, _, _) = self.config.component_ranges();
|
||||
|
||||
// Euclidean: direct dot product on original vectors
|
||||
let sim_e = Self::dot_product_simd(&query[e_range.clone()], &key[e_range.clone()]);
|
||||
|
||||
// Hyperbolic: dot product in tangent space
|
||||
let sim_h = Self::dot_product_simd(query_h_tangent, key_h_tangent);
|
||||
|
||||
// Spherical: dot product of normalized vectors
|
||||
let sim_s = Self::dot_product_simd(query_s_normalized, key_s_normalized);
|
||||
|
||||
weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s
|
||||
}
|
||||
|
||||
/// Normalize vector to unit length
|
||||
#[inline]
|
||||
fn normalize(v: &[f32]) -> Vec<f32> {
|
||||
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 1e-8 {
|
||||
v.iter().map(|x| x / norm).collect()
|
||||
} else {
|
||||
v.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
/// SIMD-friendly dot product
|
||||
#[inline(always)]
|
||||
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
|
||||
let len = a.len().min(b.len());
|
||||
let chunks = len / 4;
|
||||
let remainder = len % 4;
|
||||
|
||||
let mut sum0 = 0.0f32;
|
||||
let mut sum1 = 0.0f32;
|
||||
let mut sum2 = 0.0f32;
|
||||
let mut sum3 = 0.0f32;
|
||||
|
||||
for i in 0..chunks {
|
||||
let base = i * 4;
|
||||
sum0 += a[base] * b[base];
|
||||
sum1 += a[base + 1] * b[base + 1];
|
||||
sum2 += a[base + 2] * b[base + 2];
|
||||
sum3 += a[base + 3] * b[base + 3];
|
||||
}
|
||||
|
||||
let base = chunks * 4;
|
||||
for i in 0..remainder {
|
||||
sum0 += a[base + i] * b[base + i];
|
||||
}
|
||||
|
||||
sum0 + sum1 + sum2 + sum3
|
||||
}
|
||||
|
||||
/// Stable softmax
|
||||
fn stable_softmax(logits: &[f32]) -> Vec<f32> {
|
||||
if logits.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
|
||||
let sum: f32 = exp_logits.iter().sum();
|
||||
|
||||
exp_logits.iter().map(|&e| e / sum).collect()
|
||||
}
|
||||
|
||||
/// Weighted sum
|
||||
fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
|
||||
if weights.is_empty() || values.is_empty() {
|
||||
return Err(AttentionError::InvalidConfig("Empty inputs".into()));
|
||||
}
|
||||
|
||||
let dim = values[0].len();
|
||||
let mut output = vec![0.0f32; dim];
|
||||
|
||||
for (weight, value) in weights.iter().zip(values.iter()) {
|
||||
for (o, &v) in output.iter_mut().zip(value.iter()) {
|
||||
*o += weight * v;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention for MixedCurvatureFusedAttention {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
let cache = self.build_cache(keys);
|
||||
self.compute_with_cache(query, keys, values, &cache, 0)
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if let Some(m) = mask {
|
||||
let filtered: Vec<(&[f32], &[f32])> = keys
|
||||
.iter()
|
||||
.zip(values.iter())
|
||||
.enumerate()
|
||||
.filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
|
||||
.map(|(_, (k, v))| (*k, *v))
|
||||
.collect();
|
||||
|
||||
let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
|
||||
let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
|
||||
|
||||
self.compute(query, &filtered_keys, &filtered_values)
|
||||
} else {
|
||||
self.compute(query, keys, values)
|
||||
}
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.config.dim
|
||||
}
|
||||
|
||||
fn num_heads(&self) -> usize {
|
||||
self.config.num_heads
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Component dims that sum to `dim` should validate.
    #[test]
    fn test_fused_attention_config() {
        let cfg = FusedCurvatureConfig {
            dim: 64,
            euclidean_dim: 32,
            hyperbolic_dim: 24,
            spherical_dim: 8,
            ..Default::default()
        };

        assert!(cfg.validate().is_ok());
    }

    /// End-to-end compute produces an output of the full dimension.
    #[test]
    fn test_fused_attention() {
        let cfg = FusedCurvatureConfig {
            dim: 64,
            euclidean_dim: 32,
            hyperbolic_dim: 24,
            spherical_dim: 8,
            ..Default::default()
        };
        let attn = MixedCurvatureFusedAttention::new(cfg);

        let q = vec![0.5f32; 64];
        let key_data: Vec<Vec<f32>> = (0..20).map(|i| vec![0.1 + i as f32 * 0.02; 64]).collect();
        let value_data: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32; 64]).collect();

        let ks: Vec<&[f32]> = key_data.iter().map(Vec::as_slice).collect();
        let vs: Vec<&[f32]> = value_data.iter().map(Vec::as_slice).collect();

        let out = attn.compute(&q, &ks, &vs).unwrap();
        assert_eq!(out.len(), 64);
    }

    /// A single cache can serve many queries across different heads.
    #[test]
    fn test_cache_reuse() {
        let attn = MixedCurvatureFusedAttention::with_dim(32);

        let key_data: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 32]).collect();
        let value_data: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32; 32]).collect();

        let ks: Vec<&[f32]> = key_data.iter().map(Vec::as_slice).collect();
        let vs: Vec<&[f32]> = value_data.iter().map(Vec::as_slice).collect();

        let cache = attn.build_cache(&ks);

        for head in 0..4 {
            let q = vec![0.5f32; 32];
            let out = attn
                .compute_with_cache(&q, &ks, &vs, &cache, head)
                .unwrap();
            assert_eq!(out.len(), 32);
        }
    }
}
|
||||
28
vendor/ruvector/crates/ruvector-attention/src/curvature/mod.rs
vendored
Normal file
28
vendor/ruvector/crates/ruvector-attention/src/curvature/mod.rs
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
//! Mixed Curvature Attention
|
||||
//!
|
||||
//! Attention in product spaces: E^e × H^h × S^s
|
||||
//!
|
||||
//! ## Key Optimizations
|
||||
//!
|
||||
//! 1. **Tangent Space Mapping**: Map hyperbolic to tangent space at origin
|
||||
//! 2. **Fused Dot Kernel**: Single vectorized loop for all three similarities
|
||||
//! 3. **Per-Head Mixing**: Low-rank learned weights per head
|
||||
//! 4. **Quantization-Friendly**: Different precision for each component
|
||||
|
||||
// Submodules are private; the public surface is re-exported below.
mod component_quantizer;
mod fused_attention;
mod tangent_space;

// Public API of the curvature module.
pub use component_quantizer::{ComponentQuantizer, QuantizationConfig, QuantizedVector};
pub use fused_attention::{
    FusedCurvatureConfig, MixedCurvatureCache, MixedCurvatureFusedAttention,
};
pub use tangent_space::{TangentSpaceConfig, TangentSpaceMapper};
|
||||
|
||||
#[cfg(test)]
mod tests {
    // NOTE(review): placeholder only — asserting a constant exercises
    // nothing (clippy::assertions_on_constants). Real coverage lives in the
    // submodules' own test modules; consider replacing this with a smoke
    // test that constructs the re-exported types.
    #[test]
    fn test_module_exists() {
        assert!(true);
    }
}
|
||||
246
vendor/ruvector/crates/ruvector-attention/src/curvature/tangent_space.rs
vendored
Normal file
246
vendor/ruvector/crates/ruvector-attention/src/curvature/tangent_space.rs
vendored
Normal file
@@ -0,0 +1,246 @@
|
||||
//! Tangent Space Mapping for Fast Hyperbolic Operations
|
||||
//!
|
||||
//! Instead of computing full geodesic distances in hyperbolic space,
|
||||
//! we map points to the tangent space at a learned origin and use
|
||||
//! dot products. This is 10-100x faster while preserving hierarchy.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for tangent space mapping
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TangentSpaceConfig {
    /// Dimension of hyperbolic component
    pub hyperbolic_dim: usize,
    /// Curvature (negative, e.g., -1.0). The mapper takes
    /// `sqrt(-curvature)`, so non-negative values would produce NaNs.
    pub curvature: f32,
    /// Whether to learn the origin.
    /// NOTE(review): stored but not consulted anywhere in this file —
    /// confirm whether callers use it or it is dead config.
    pub learnable_origin: bool,
}
|
||||
|
||||
impl Default for TangentSpaceConfig {
    /// 32-dim hyperbolic component with unit negative curvature and a
    /// learnable origin.
    fn default() -> Self {
        Self {
            hyperbolic_dim: 32,
            curvature: -1.0,
            learnable_origin: true,
        }
    }
}
|
||||
|
||||
/// Tangent space mapper for hyperbolic geometry
///
/// Maps points from Poincaré ball to tangent space at origin,
/// enabling fast dot-product similarity instead of geodesic distance.
#[derive(Debug, Clone)]
pub struct TangentSpaceMapper {
    config: TangentSpaceConfig,
    /// Origin point in Poincaré ball (all-zero unless `set_origin` is used).
    origin: Vec<f32>,
    /// Conformal factor λ_o = 2 / (1 - c‖o‖²) at the origin; kept in sync
    /// with `origin` by `new` / `set_origin`.
    lambda_origin: f32,
}
|
||||
|
||||
impl TangentSpaceMapper {
|
||||
/// Create new mapper with config
|
||||
pub fn new(config: TangentSpaceConfig) -> Self {
|
||||
let origin = vec![0.0f32; config.hyperbolic_dim];
|
||||
let c = -config.curvature;
|
||||
let origin_norm_sq: f32 = origin.iter().map(|x| x * x).sum();
|
||||
let lambda_origin = 2.0 / (1.0 - c * origin_norm_sq).max(1e-8);
|
||||
|
||||
Self {
|
||||
config,
|
||||
origin,
|
||||
lambda_origin,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set custom origin (for learned origins)
|
||||
pub fn set_origin(&mut self, origin: Vec<f32>) {
|
||||
let c = -self.config.curvature;
|
||||
let origin_norm_sq: f32 = origin.iter().map(|x| x * x).sum();
|
||||
self.lambda_origin = 2.0 / (1.0 - c * origin_norm_sq).max(1e-8);
|
||||
self.origin = origin;
|
||||
}
|
||||
|
||||
/// Map point from Poincaré ball to tangent space at origin
|
||||
///
|
||||
/// log_o(x) = (2 / λ_o) * arctanh(√c ||−o ⊕ x||) * (−o ⊕ x) / ||−o ⊕ x||
|
||||
///
|
||||
/// For origin at 0, this simplifies to:
|
||||
/// log_0(x) = 2 * arctanh(√c ||x||) * x / (√c ||x||)
|
||||
#[inline]
|
||||
pub fn log_map(&self, point: &[f32]) -> Vec<f32> {
|
||||
let c = -self.config.curvature;
|
||||
let sqrt_c = c.sqrt();
|
||||
|
||||
// For origin at 0, Möbius addition −o ⊕ x = x
|
||||
if self.origin.iter().all(|&x| x.abs() < 1e-8) {
|
||||
return self.log_map_at_origin(point, sqrt_c);
|
||||
}
|
||||
|
||||
// General case: compute -origin ⊕ point
|
||||
let neg_origin: Vec<f32> = self.origin.iter().map(|x| -x).collect();
|
||||
let diff = self.mobius_add(&neg_origin, point, c);
|
||||
|
||||
let diff_norm: f32 = diff.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if diff_norm < 1e-8 {
|
||||
return vec![0.0f32; point.len()];
|
||||
}
|
||||
|
||||
let scale =
|
||||
(2.0 / self.lambda_origin) * (sqrt_c * diff_norm).atanh() / (sqrt_c * diff_norm);
|
||||
|
||||
diff.iter().map(|&d| scale * d).collect()
|
||||
}
|
||||
|
||||
/// Fast log map at origin (most common case)
|
||||
#[inline]
|
||||
fn log_map_at_origin(&self, point: &[f32], sqrt_c: f32) -> Vec<f32> {
|
||||
let norm: f32 = point.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm < 1e-8 {
|
||||
return vec![0.0f32; point.len()];
|
||||
}
|
||||
|
||||
// Clamp to avoid infinity
|
||||
let arg = (sqrt_c * norm).min(0.99);
|
||||
let scale = 2.0 * arg.atanh() / (sqrt_c * norm);
|
||||
|
||||
point.iter().map(|&p| scale * p).collect()
|
||||
}
|
||||
|
||||
/// Möbius addition in Poincaré ball
|
||||
fn mobius_add(&self, x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
|
||||
let x_norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
|
||||
let y_norm_sq: f32 = y.iter().map(|yi| yi * yi).sum();
|
||||
let xy_dot: f32 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
|
||||
|
||||
let num_coef = 1.0 + 2.0 * c * xy_dot + c * y_norm_sq;
|
||||
let denom = 1.0 + 2.0 * c * xy_dot + c * c * x_norm_sq * y_norm_sq;
|
||||
|
||||
if denom.abs() < 1e-8 {
|
||||
return x.to_vec();
|
||||
}
|
||||
|
||||
let y_coef = 1.0 - c * x_norm_sq;
|
||||
|
||||
x.iter()
|
||||
.zip(y.iter())
|
||||
.map(|(&xi, &yi)| (num_coef * xi + y_coef * yi) / denom)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute tangent space similarity (dot product in tangent space)
|
||||
///
|
||||
/// This approximates hyperbolic distance but is much faster.
|
||||
#[inline]
|
||||
pub fn tangent_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
// Map both to tangent space
|
||||
let ta = self.log_map(a);
|
||||
let tb = self.log_map(b);
|
||||
|
||||
// Dot product
|
||||
ta.iter().zip(tb.iter()).map(|(&ai, &bi)| ai * bi).sum()
|
||||
}
|
||||
|
||||
/// Batch map points to tangent space (cache for window)
|
||||
pub fn batch_log_map(&self, points: &[&[f32]]) -> Vec<Vec<f32>> {
|
||||
points.iter().map(|p| self.log_map(p)).collect()
|
||||
}
|
||||
|
||||
/// Compute similarities in tangent space (all pairwise with query)
|
||||
pub fn batch_tangent_similarity(
|
||||
&self,
|
||||
query_tangent: &[f32],
|
||||
keys_tangent: &[&[f32]],
|
||||
) -> Vec<f32> {
|
||||
keys_tangent
|
||||
.iter()
|
||||
.map(|k| Self::dot_product_simd(query_tangent, k))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// SIMD-friendly dot product
|
||||
#[inline(always)]
|
||||
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
|
||||
let len = a.len().min(b.len());
|
||||
let chunks = len / 4;
|
||||
let remainder = len % 4;
|
||||
|
||||
let mut sum0 = 0.0f32;
|
||||
let mut sum1 = 0.0f32;
|
||||
let mut sum2 = 0.0f32;
|
||||
let mut sum3 = 0.0f32;
|
||||
|
||||
for i in 0..chunks {
|
||||
let base = i * 4;
|
||||
sum0 += a[base] * b[base];
|
||||
sum1 += a[base + 1] * b[base + 1];
|
||||
sum2 += a[base + 2] * b[base + 2];
|
||||
sum3 += a[base + 3] * b[base + 3];
|
||||
}
|
||||
|
||||
let base = chunks * 4;
|
||||
for i in 0..remainder {
|
||||
sum0 += a[base + i] * b[base + i];
|
||||
}
|
||||
|
||||
sum0 + sum1 + sum2 + sum3
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_map_at_origin() {
        let mapper = TangentSpaceMapper::new(TangentSpaceConfig {
            hyperbolic_dim: 4,
            curvature: -1.0,
            learnable_origin: false,
        });

        // The origin itself maps to the zero tangent vector.
        let zero = vec![0.0f32; 4];
        let mapped = mapper.log_map(&zero);
        assert!(mapped.iter().all(|&x| x.abs() < 1e-6));

        // A non-zero point maps to a tangent vector of the same length.
        let mapped = mapper.log_map(&[0.1, 0.2, 0.0, 0.0]);
        assert_eq!(mapped.len(), 4);
    }

    #[test]
    fn test_tangent_similarity() {
        let mapper = TangentSpaceMapper::new(TangentSpaceConfig {
            hyperbolic_dim: 4,
            curvature: -1.0,
            learnable_origin: false,
        });

        // A point compared against itself must have positive similarity.
        let p = [0.1f32, 0.1, 0.0, 0.0];
        assert!(mapper.tangent_similarity(&p, &p) > 0.0);
    }

    #[test]
    fn test_batch_operations() {
        let mapper = TangentSpaceMapper::new(TangentSpaceConfig::default());

        let data: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.05; 32]).collect();
        let refs: Vec<&[f32]> = data.iter().map(Vec::as_slice).collect();

        assert_eq!(mapper.batch_log_map(&refs).len(), 10);
    }
}
|
||||
Reference in New Issue
Block a user