git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
264 lines
12 KiB
Rust
264 lines
12 KiB
Rust
//! Benchmark: Lorentz Cascade Attention vs Poincaré Attention
|
|
//!
|
|
//! Run with: cargo bench -p ruvector-attention --bench hyperbolic_bench
|
|
|
|
use std::time::Instant;
|
|
|
|
// Import both attention mechanisms
|
|
use ruvector_attention::hyperbolic::{
|
|
busemann_score,
|
|
einstein_midpoint,
|
|
frechet_mean,
|
|
lorentz_distance,
|
|
// Poincaré (baseline)
|
|
poincare_distance,
|
|
project_hyperboloid,
|
|
HyperbolicAttention,
|
|
HyperbolicAttentionConfig,
|
|
LCAConfig,
|
|
// Lorentz Cascade (novel)
|
|
LorentzCascadeAttention,
|
|
};
|
|
|
|
fn generate_test_data(n: usize, dim: usize) -> (Vec<f32>, Vec<Vec<f32>>) {
|
|
let query: Vec<f32> = (0..dim)
|
|
.map(|i| ((i as f32 * 0.1).sin() * 0.3).clamp(-0.9, 0.9))
|
|
.collect();
|
|
|
|
let keys: Vec<Vec<f32>> = (0..n)
|
|
.map(|j| {
|
|
(0..dim)
|
|
.map(|i| (((i + j) as f32 * 0.07).cos() * 0.3).clamp(-0.9, 0.9))
|
|
.collect()
|
|
})
|
|
.collect();
|
|
|
|
(query, keys)
|
|
}
|
|
|
|
fn bench_poincare_distance(iterations: usize, n_keys: usize, dim: usize) -> std::time::Duration {
|
|
let (query, keys) = generate_test_data(n_keys, dim);
|
|
let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
for key in &keys_refs {
|
|
let _d = poincare_distance(&query, key, 1.0);
|
|
}
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn bench_lorentz_distance(iterations: usize, n_keys: usize, dim: usize) -> std::time::Duration {
|
|
let (query, keys) = generate_test_data(n_keys, dim + 1); // +1 for time dimension
|
|
let query_h = project_hyperboloid(&query, 1.0);
|
|
let keys_h: Vec<Vec<f32>> = keys.iter().map(|k| project_hyperboloid(k, 1.0)).collect();
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
for key in &keys_h {
|
|
let _d = lorentz_distance(&query_h, key, 1.0);
|
|
}
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn bench_busemann_scoring(iterations: usize, n_keys: usize, dim: usize) -> std::time::Duration {
|
|
let (query, keys) = generate_test_data(n_keys, dim + 1);
|
|
let focal: Vec<f32> = {
|
|
let mut f = vec![1.0];
|
|
f.extend(vec![0.0; dim]);
|
|
f[1] = 1.0; // Light-like
|
|
f
|
|
};
|
|
let query_h = project_hyperboloid(&query, 1.0);
|
|
let keys_h: Vec<Vec<f32>> = keys.iter().map(|k| project_hyperboloid(k, 1.0)).collect();
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
for key in &keys_h {
|
|
let _score = busemann_score(key, &focal) - busemann_score(&query_h, &focal);
|
|
}
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn bench_frechet_mean(iterations: usize, n_points: usize, dim: usize) -> std::time::Duration {
|
|
let (_, points) = generate_test_data(n_points, dim);
|
|
let points_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
let _mean = frechet_mean(&points_refs, None, 1.0, 50, 1e-5);
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn bench_einstein_midpoint(iterations: usize, n_points: usize, dim: usize) -> std::time::Duration {
|
|
let (_, points) = generate_test_data(n_points, dim + 1);
|
|
let points_h: Vec<Vec<f32>> = points.iter().map(|p| project_hyperboloid(p, 1.0)).collect();
|
|
let points_refs: Vec<&[f32]> = points_h.iter().map(|p| p.as_slice()).collect();
|
|
let weights: Vec<f32> = vec![1.0 / n_points as f32; n_points];
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
let _mid = einstein_midpoint(&points_refs, &weights, 1.0);
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn bench_full_poincare_attention(
|
|
iterations: usize,
|
|
n_keys: usize,
|
|
dim: usize,
|
|
) -> std::time::Duration {
|
|
let (query, keys) = generate_test_data(n_keys, dim);
|
|
let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
|
|
|
|
let config = HyperbolicAttentionConfig {
|
|
dim,
|
|
curvature: -1.0,
|
|
adaptive_curvature: false,
|
|
temperature: 1.0,
|
|
frechet_max_iter: 50,
|
|
frechet_tol: 1e-5,
|
|
};
|
|
let attention = HyperbolicAttention::new(config);
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
let _result = attention.compute_weights(&query, &keys_refs);
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn bench_full_lca_attention(iterations: usize, n_keys: usize, dim: usize) -> std::time::Duration {
|
|
let (query, keys) = generate_test_data(n_keys, dim);
|
|
let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
|
|
|
|
let config = LCAConfig {
|
|
dim,
|
|
num_heads: 4,
|
|
curvature_range: (0.1, 2.0),
|
|
temperature: 1.0,
|
|
};
|
|
let attention = LorentzCascadeAttention::new(config);
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
let _result = attention.attend(&query, &keys_refs, &keys_refs);
|
|
}
|
|
start.elapsed()
|
|
}
|
|
|
|
fn main() {
|
|
println!("╔══════════════════════════════════════════════════════════════════╗");
|
|
println!("║ Lorentz Cascade Attention (LCA) vs Poincaré Benchmark ║");
|
|
println!("╚══════════════════════════════════════════════════════════════════╝\n");
|
|
|
|
let iterations = 1000;
|
|
let n_keys = 100;
|
|
let dim = 64;
|
|
|
|
println!(
|
|
"Configuration: {} iterations, {} keys, {} dimensions\n",
|
|
iterations, n_keys, dim
|
|
);
|
|
|
|
// Distance computation benchmarks
|
|
println!("┌─────────────────────────────────────────────────────────────────┐");
|
|
println!("│ 1. DISTANCE COMPUTATION │");
|
|
println!("├─────────────────────────────────────────────────────────────────┤");
|
|
|
|
let poincare_dist_time = bench_poincare_distance(iterations, n_keys, dim);
|
|
let lorentz_dist_time = bench_lorentz_distance(iterations, n_keys, dim);
|
|
let busemann_time = bench_busemann_scoring(iterations, n_keys, dim);
|
|
|
|
let poincare_per_op = poincare_dist_time.as_nanos() as f64 / (iterations * n_keys) as f64;
|
|
let lorentz_per_op = lorentz_dist_time.as_nanos() as f64 / (iterations * n_keys) as f64;
|
|
let busemann_per_op = busemann_time.as_nanos() as f64 / (iterations * n_keys) as f64;
|
|
|
|
println!(
|
|
"│ Poincaré distance: {:>8.1} ns/op │",
|
|
poincare_per_op
|
|
);
|
|
println!(
|
|
"│ Lorentz distance: {:>8.1} ns/op ({:.1}x vs Poincaré) │",
|
|
lorentz_per_op,
|
|
poincare_per_op / lorentz_per_op
|
|
);
|
|
println!(
|
|
"│ Busemann scoring: {:>8.1} ns/op ({:.1}x vs Poincaré) │",
|
|
busemann_per_op,
|
|
poincare_per_op / busemann_per_op
|
|
);
|
|
println!("└─────────────────────────────────────────────────────────────────┘\n");
|
|
|
|
// Aggregation benchmarks
|
|
println!("┌─────────────────────────────────────────────────────────────────┐");
|
|
println!("│ 2. AGGREGATION (CENTROID) │");
|
|
println!("├─────────────────────────────────────────────────────────────────┤");
|
|
|
|
let frechet_time = bench_frechet_mean(iterations / 10, n_keys, dim); // Fewer iterations (slow)
|
|
let einstein_time = bench_einstein_midpoint(iterations, n_keys, dim);
|
|
|
|
let frechet_per_op = frechet_time.as_nanos() as f64 / (iterations / 10) as f64;
|
|
let einstein_per_op = einstein_time.as_nanos() as f64 / iterations as f64;
|
|
|
|
println!(
|
|
"│ Fréchet mean (50 iter): {:>10.1} ns/op │",
|
|
frechet_per_op
|
|
);
|
|
println!(
|
|
"│ Einstein midpoint: {:>10.1} ns/op ({:.1}x faster!) │",
|
|
einstein_per_op,
|
|
frechet_per_op / einstein_per_op
|
|
);
|
|
println!("└─────────────────────────────────────────────────────────────────┘\n");
|
|
|
|
// Full attention benchmarks
|
|
println!("┌─────────────────────────────────────────────────────────────────┐");
|
|
println!("│ 3. FULL ATTENTION (END-TO-END) │");
|
|
println!("├─────────────────────────────────────────────────────────────────┤");
|
|
|
|
let poincare_full_time = bench_full_poincare_attention(iterations / 10, n_keys, dim);
|
|
let lca_full_time = bench_full_lca_attention(iterations / 10, n_keys, dim);
|
|
|
|
let poincare_full_per_op = poincare_full_time.as_nanos() as f64 / (iterations / 10) as f64;
|
|
let lca_full_per_op = lca_full_time.as_nanos() as f64 / (iterations / 10) as f64;
|
|
|
|
println!(
|
|
"│ Poincaré Attention: {:>10.1} ns/op │",
|
|
poincare_full_per_op
|
|
);
|
|
println!(
|
|
"│ Lorentz Cascade (4 heads): {:>7.1} ns/op ({:.1}x speedup) │",
|
|
lca_full_per_op,
|
|
poincare_full_per_op / lca_full_per_op
|
|
);
|
|
println!("└─────────────────────────────────────────────────────────────────┘\n");
|
|
|
|
// Summary
|
|
println!("╔══════════════════════════════════════════════════════════════════╗");
|
|
println!("║ SUMMARY: Lorentz Cascade Attention Improvements ║");
|
|
println!("╠══════════════════════════════════════════════════════════════════╣");
|
|
println!(
|
|
"║ • Busemann scoring: {:.1}x faster than Poincaré distance ║",
|
|
poincare_per_op / busemann_per_op
|
|
);
|
|
println!(
|
|
"║ • Einstein midpoint: {:.1}x faster than Fréchet mean ║",
|
|
frechet_per_op / einstein_per_op
|
|
);
|
|
println!(
|
|
"║ • End-to-end: {:.1}x overall speedup ║",
|
|
poincare_full_per_op / lca_full_per_op
|
|
);
|
|
println!("║ ║");
|
|
println!("║ Additional benefits: ║");
|
|
println!("║ • No boundary instability (Lorentz vs Poincaré ball) ║");
|
|
println!("║ • Multi-scale hierarchy (4 curvature heads) ║");
|
|
println!("║ • Sparse attention via hierarchical pruning ║");
|
|
println!("╚══════════════════════════════════════════════════════════════════╝");
|
|
}
|