git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
170 lines
5.0 KiB
Rust
170 lines
5.0 KiB
Rust
//! Benchmarks for product manifold operations
|
|
|
|
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
|
use rand::prelude::*;
|
|
use rand_distr::StandardNormal;
|
|
use ruvector_math::product_manifold::ProductManifold;
|
|
use ruvector_math::spherical::SphericalSpace;
|
|
|
|
fn generate_point(dim: usize, rng: &mut impl Rng) -> Vec<f64> {
|
|
(0..dim).map(|_| rng.sample(StandardNormal)).collect()
|
|
}
|
|
|
|
fn bench_product_manifold_distance(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("product_manifold_distance");
|
|
|
|
// Various configurations
|
|
let configs = [
|
|
(32, 0, 0, "euclidean_only"),
|
|
(0, 16, 0, "hyperbolic_only"),
|
|
(0, 0, 8, "spherical_only"),
|
|
(32, 16, 8, "mixed_small"),
|
|
(64, 32, 16, "mixed_medium"),
|
|
(128, 64, 32, "mixed_large"),
|
|
];
|
|
|
|
for (e, h, s, name) in configs.iter() {
|
|
let manifold = ProductManifold::new(*e, *h, *s);
|
|
let dim = manifold.dim();
|
|
|
|
let mut rng = StdRng::seed_from_u64(42);
|
|
let x = manifold.project(&generate_point(dim, &mut rng)).unwrap();
|
|
let y = manifold.project(&generate_point(dim, &mut rng)).unwrap();
|
|
|
|
group.throughput(Throughput::Elements(dim as u64));
|
|
|
|
group.bench_with_input(BenchmarkId::new(*name, dim), &(&x, &y), |b, (px, py)| {
|
|
b.iter(|| manifold.distance(black_box(px), black_box(py)));
|
|
});
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_product_manifold_exp_log(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("product_manifold_exp_log");
|
|
|
|
let manifold = ProductManifold::new(64, 32, 16);
|
|
let dim = manifold.dim();
|
|
|
|
let mut rng = StdRng::seed_from_u64(42);
|
|
let x = manifold.project(&generate_point(dim, &mut rng)).unwrap();
|
|
let y = manifold.project(&generate_point(dim, &mut rng)).unwrap();
|
|
let v = manifold.log_map(&x, &y).unwrap();
|
|
|
|
group.throughput(Throughput::Elements(dim as u64));
|
|
|
|
group.bench_function("exp_map", |b| {
|
|
b.iter(|| manifold.exp_map(black_box(&x), black_box(&v)));
|
|
});
|
|
|
|
group.bench_function("log_map", |b| {
|
|
b.iter(|| manifold.log_map(black_box(&x), black_box(&y)));
|
|
});
|
|
|
|
group.bench_function("geodesic", |b| {
|
|
b.iter(|| manifold.geodesic(black_box(&x), black_box(&y), 0.5));
|
|
});
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_frechet_mean(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("frechet_mean");
|
|
|
|
for n in [10, 50, 100, 200] {
|
|
let manifold = ProductManifold::new(32, 16, 8);
|
|
let dim = manifold.dim();
|
|
|
|
let mut rng = StdRng::seed_from_u64(42);
|
|
let points: Vec<Vec<f64>> = (0..n)
|
|
.map(|_| manifold.project(&generate_point(dim, &mut rng)).unwrap())
|
|
.collect();
|
|
|
|
group.throughput(Throughput::Elements((n * dim) as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("product_manifold", n),
|
|
&points,
|
|
|b, pts| {
|
|
b.iter(|| manifold.frechet_mean(black_box(pts), None));
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_spherical_operations(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("spherical");
|
|
|
|
for dim in [8, 16, 32, 64, 128] {
|
|
let sphere = SphericalSpace::new(dim);
|
|
|
|
let mut rng = StdRng::seed_from_u64(42);
|
|
let x = sphere.project(&generate_point(dim, &mut rng)).unwrap();
|
|
let y = sphere.project(&generate_point(dim, &mut rng)).unwrap();
|
|
|
|
group.throughput(Throughput::Elements(dim as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("distance", dim),
|
|
&(&x, &y),
|
|
|b, (px, py)| {
|
|
b.iter(|| sphere.distance(black_box(px), black_box(py)));
|
|
},
|
|
);
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("exp_map", dim),
|
|
&(&x, &y),
|
|
|b, (px, py)| {
|
|
if let Ok(v) = sphere.log_map(px, py) {
|
|
b.iter(|| sphere.exp_map(black_box(px), black_box(&v)));
|
|
}
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_knn(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("knn");
|
|
|
|
let manifold = ProductManifold::new(64, 32, 16);
|
|
let dim = manifold.dim();
|
|
|
|
for n in [100, 500, 1000] {
|
|
let mut rng = StdRng::seed_from_u64(42);
|
|
let points: Vec<Vec<f64>> = (0..n)
|
|
.map(|_| manifold.project(&generate_point(dim, &mut rng)).unwrap())
|
|
.collect();
|
|
let query = manifold.project(&generate_point(dim, &mut rng)).unwrap();
|
|
|
|
group.throughput(Throughput::Elements(n as u64));
|
|
|
|
for k in [5, 10, 20] {
|
|
group.bench_with_input(
|
|
BenchmarkId::new(format!("n={}_k={}", n, k), n),
|
|
&(&query, &points),
|
|
|b, (q, pts)| {
|
|
b.iter(|| manifold.knn(black_box(q), black_box(pts), k));
|
|
},
|
|
);
|
|
}
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
criterion_group!(
|
|
benches,
|
|
bench_product_manifold_distance,
|
|
bench_product_manifold_exp_log,
|
|
bench_frechet_mean,
|
|
bench_spherical_operations,
|
|
bench_knn,
|
|
);
|
|
criterion_main!(benches);
|