Files
wifi-densepose/crates/ruvector-math/benches/optimal_transport.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

95 lines
2.7 KiB
Rust

//! Benchmarks for optimal transport algorithms
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use rand::prelude::*;
use rand_distr::StandardNormal;
use ruvector_math::optimal_transport::{OptimalTransport, SinkhornSolver, SlicedWasserstein};
/// Builds `n` points of dimension `dim` whose coordinates are sampled from a
/// standard normal distribution.
///
/// The generator is seeded from `seed`, so the same arguments reproduce the
/// same point cloud. Coordinates are drawn point by point, in order.
fn generate_points(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
    let mut generator = StdRng::seed_from_u64(seed);
    let mut points = Vec::with_capacity(n);
    for _ in 0..n {
        let point: Vec<f64> = (0..dim)
            .map(|_| generator.sample(StandardNormal))
            .collect();
        points.push(point);
    }
    points
}
/// Benchmarks `SlicedWasserstein::distance` between two differently seeded
/// Gaussian point clouds in 128 dimensions.
///
/// Every combination of cloud size `n` and projection count is measured.
/// Throughput is reported as `n` elements per iteration.
///
/// The benchmark id uses the projection count as the function name and `n` as
/// the parameter (e.g. `sliced_wasserstein/proj=50/1000`). Criterion can then
/// group each projection setting and compare it across input sizes.
fn bench_sliced_wasserstein(c: &mut Criterion) {
    let mut group = c.benchmark_group("sliced_wasserstein");
    for n in [100, 500, 1000, 5000] {
        // Set before registering the benchmarks for this size.
        group.throughput(Throughput::Elements(n as u64));
        // Different seeds so source and target are distinct clouds.
        let source = generate_points(n, 128, 42);
        let target = generate_points(n, 128, 43);
        // Vary number of projections
        for projections in [50, 100, 200] {
            let sw = SlicedWasserstein::new(projections).with_seed(42);
            // `n` is already the parameter. Repeating it in the function name
            // would leave each function with a single data point and prevent
            // cross-size comparison.
            group.bench_with_input(
                BenchmarkId::new(format!("proj={}", projections), n),
                &(&source, &target),
                |b, (s, t)| {
                    b.iter(|| sw.distance(black_box(s), black_box(t)));
                },
            );
        }
    }
    group.finish();
}
/// Benchmarks `SinkhornSolver::distance` between two differently seeded
/// Gaussian point clouds in 32 dimensions.
///
/// Every combination of cloud size `n` and regularization `reg` is measured.
/// Throughput is counted as `n * n` elements, one per source/target pair.
///
/// The benchmark id uses the regularization as the function name and `n` as
/// the parameter (e.g. `sinkhorn/reg=0.05/200`). Criterion can then group each
/// regularization setting and compare it across input sizes.
fn bench_sinkhorn(c: &mut Criterion) {
    let mut group = c.benchmark_group("sinkhorn");
    for n in [50, 100, 200, 500] {
        // Set before registering the benchmarks for this size.
        group.throughput(Throughput::Elements((n * n) as u64));
        // Different seeds so source and target are distinct clouds.
        let source = generate_points(n, 32, 42);
        let target = generate_points(n, 32, 43);
        for reg in [0.01, 0.05, 0.1] {
            // NOTE(review): the second argument (100) is presumably an
            // iteration cap; confirm against `SinkhornSolver::new`.
            let solver = SinkhornSolver::new(reg, 100);
            // `n` is already the parameter; keep it out of the function name.
            group.bench_with_input(
                BenchmarkId::new(format!("reg={}", reg), n),
                &(&source, &target),
                |b, (s, t)| {
                    b.iter(|| solver.distance(black_box(s), black_box(t)));
                },
            );
        }
    }
    group.finish();
}
/// Measures how `SlicedWasserstein::distance` cost changes with the point
/// dimension.
///
/// The cloud size (500 points) and projection count (100) stay fixed. Each
/// benchmark is registered as `scaling/sw_dim_scaling/<dim>`.
fn bench_scaling(c: &mut Criterion) {
    // Only the dimension varies between runs.
    const POINTS: usize = 500;

    let mut group = c.benchmark_group("scaling");
    for dim in [32, 64, 128, 256, 512] {
        let source = generate_points(POINTS, dim, 42);
        let target = generate_points(POINTS, dim, 43);
        let solver = SlicedWasserstein::new(100).with_seed(42);

        let id = BenchmarkId::new("sw_dim_scaling", dim);
        let inputs = (&source, &target);
        group.bench_with_input(id, &inputs, |bencher, (lhs, rhs)| {
            bencher.iter(|| solver.distance(black_box(lhs), black_box(rhs)));
        });
    }
    group.finish();
}
// Registers every benchmark in this file under the default Criterion
// configuration and generates the bench binary's entry point.
criterion_group! {
    name = benches;
    config = Criterion::default();
    targets = bench_sliced_wasserstein, bench_sinkhorn, bench_scaling
}
criterion_main!(benches);