Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,533 @@
//! Sliced Wasserstein Distance
//!
//! The Sliced Wasserstein distance projects high-dimensional distributions
//! onto random 1D lines and averages the 1D Wasserstein distances.
//!
//! ## Algorithm
//!
//! 1. Generate L random unit vectors (directions) in R^d
//! 2. For each direction θ:
//! a. Project all source and target points onto θ
//! b. Compute 1D Wasserstein distance (closed-form via sorted quantiles)
//! 3. Average over all directions
//!
//! ## Complexity
//!
//! - O(L × n log n) where L = number of projections, n = number of points
//! - Linear in dimension d (only dot products)
//!
//! ## Advantages
//!
//! - **Fast**: Near-linear scaling to millions of points
//! - **SIMD-friendly**: Projections are just dot products
//! - **Statistically consistent**: Converges to true W2 as L → ∞
use super::{OptimalTransport, WassersteinConfig};
use crate::utils::{argsort, EPS};
use rand::prelude::*;
use rand_distr::StandardNormal;
/// Sliced Wasserstein distance calculator
#[derive(Debug, Clone)]
pub struct SlicedWasserstein {
/// Number of random projection directions
num_projections: usize,
/// Power for Wasserstein-p (typically 1 or 2)
p: f64,
/// Random seed for reproducibility
seed: Option<u64>,
}
impl SlicedWasserstein {
/// Create a new Sliced Wasserstein calculator
///
/// # Arguments
/// * `num_projections` - Number of random 1D projections (100-1000 typical)
pub fn new(num_projections: usize) -> Self {
Self {
num_projections: num_projections.max(1),
p: 2.0,
seed: None,
}
}
/// Create from configuration
pub fn from_config(config: &WassersteinConfig) -> Self {
Self {
num_projections: config.num_projections.max(1),
p: config.p,
seed: config.seed,
}
}
/// Set the Wasserstein power (1 for W1, 2 for W2)
pub fn with_power(mut self, p: f64) -> Self {
self.p = p.max(1.0);
self
}
/// Set random seed for reproducibility
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
/// Generate random unit directions
fn generate_directions(&self, dim: usize) -> Vec<Vec<f64>> {
let mut rng = match self.seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_entropy(),
};
(0..self.num_projections)
.map(|_| {
let mut direction: Vec<f64> =
(0..dim).map(|_| rng.sample(StandardNormal)).collect();
// Normalize to unit vector
let norm: f64 = direction.iter().map(|&x| x * x).sum::<f64>().sqrt();
if norm > EPS {
for x in &mut direction {
*x /= norm;
}
}
direction
})
.collect()
}
/// Project points onto a direction (SIMD-friendly dot product)
#[inline(always)]
fn project(points: &[Vec<f64>], direction: &[f64]) -> Vec<f64> {
points
.iter()
.map(|p| Self::dot_product(p, direction))
.collect()
}
/// Project points into pre-allocated buffer (reduces allocations)
#[inline(always)]
fn project_into(points: &[Vec<f64>], direction: &[f64], out: &mut [f64]) {
for (i, p) in points.iter().enumerate() {
out[i] = Self::dot_product(p, direction);
}
}
/// SIMD-friendly dot product using fold pattern
/// Compiler can auto-vectorize this pattern effectively
#[inline(always)]
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
// Use 4-way unrolled accumulator for better SIMD utilization
let len = a.len();
let chunks = len / 4;
let remainder = len % 4;
let mut sum0 = 0.0f64;
let mut sum1 = 0.0f64;
let mut sum2 = 0.0f64;
let mut sum3 = 0.0f64;
// Process 4 elements at a time (helps SIMD vectorization)
for i in 0..chunks {
let base = i * 4;
sum0 += a[base] * b[base];
sum1 += a[base + 1] * b[base + 1];
sum2 += a[base + 2] * b[base + 2];
sum3 += a[base + 3] * b[base + 3];
}
// Handle remainder
let base = chunks * 4;
for i in 0..remainder {
sum0 += a[base + i] * b[base + i];
}
sum0 + sum1 + sum2 + sum3
}
/// Compute 1D Wasserstein distance between two sorted distributions
///
/// For uniform weights, this is simply the sum of |sorted_a[i] - sorted_b[i]|^p
#[inline]
fn wasserstein_1d_uniform(&self, mut proj_a: Vec<f64>, mut proj_b: Vec<f64>) -> f64 {
let n = proj_a.len();
let m = proj_b.len();
// Sort projections using fast f64 comparison
proj_a.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
proj_b.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
if n == m {
// Same size: direct comparison with SIMD-friendly accumulator
self.wasserstein_1d_equal_size(&proj_a, &proj_b)
} else {
// Different sizes: interpolate via quantiles
self.wasserstein_1d_quantile(&proj_a, &proj_b, n.max(m))
}
}
/// Optimized equal-size 1D Wasserstein with SIMD-friendly pattern
#[inline(always)]
fn wasserstein_1d_equal_size(&self, sorted_a: &[f64], sorted_b: &[f64]) -> f64 {
let n = sorted_a.len();
if n == 0 {
return 0.0;
}
// Use p=2 fast path (most common case)
if (self.p - 2.0).abs() < 1e-10 {
// L2 Wasserstein: sum of squared differences
let mut sum0 = 0.0f64;
let mut sum1 = 0.0f64;
let mut sum2 = 0.0f64;
let mut sum3 = 0.0f64;
let chunks = n / 4;
let remainder = n % 4;
for i in 0..chunks {
let base = i * 4;
let d0 = sorted_a[base] - sorted_b[base];
let d1 = sorted_a[base + 1] - sorted_b[base + 1];
let d2 = sorted_a[base + 2] - sorted_b[base + 2];
let d3 = sorted_a[base + 3] - sorted_b[base + 3];
sum0 += d0 * d0;
sum1 += d1 * d1;
sum2 += d2 * d2;
sum3 += d3 * d3;
}
let base = chunks * 4;
for i in 0..remainder {
let d = sorted_a[base + i] - sorted_b[base + i];
sum0 += d * d;
}
(sum0 + sum1 + sum2 + sum3) / n as f64
} else if (self.p - 1.0).abs() < 1e-10 {
// L1 Wasserstein: sum of absolute differences
let mut sum = 0.0f64;
for i in 0..n {
sum += (sorted_a[i] - sorted_b[i]).abs();
}
sum / n as f64
} else {
// General case
sorted_a
.iter()
.zip(sorted_b.iter())
.map(|(&a, &b)| (a - b).abs().powf(self.p))
.sum::<f64>()
/ n as f64
}
}
/// Compute 1D Wasserstein via quantile interpolation
fn wasserstein_1d_quantile(
&self,
sorted_a: &[f64],
sorted_b: &[f64],
num_samples: usize,
) -> f64 {
let mut total = 0.0;
for i in 0..num_samples {
let q = (i as f64 + 0.5) / num_samples as f64;
let val_a = quantile_sorted(sorted_a, q);
let val_b = quantile_sorted(sorted_b, q);
total += (val_a - val_b).abs().powf(self.p);
}
total / num_samples as f64
}
/// Compute 1D Wasserstein with weights
fn wasserstein_1d_weighted(
&self,
proj_a: &[f64],
weights_a: &[f64],
proj_b: &[f64],
weights_b: &[f64],
) -> f64 {
// Sort by projected values
let idx_a = argsort(proj_a);
let idx_b = argsort(proj_b);
let sorted_a: Vec<f64> = idx_a.iter().map(|&i| proj_a[i]).collect();
let sorted_w_a: Vec<f64> = idx_a.iter().map(|&i| weights_a[i]).collect();
let sorted_b: Vec<f64> = idx_b.iter().map(|&i| proj_b[i]).collect();
let sorted_w_b: Vec<f64> = idx_b.iter().map(|&i| weights_b[i]).collect();
// Compute cumulative weights
let cdf_a = compute_cdf(&sorted_w_a);
let cdf_b = compute_cdf(&sorted_w_b);
// Merge and compute
self.wasserstein_1d_from_cdfs(&sorted_a, &cdf_a, &sorted_b, &cdf_b)
}
/// Compute 1D Wasserstein from CDFs
fn wasserstein_1d_from_cdfs(
&self,
values_a: &[f64],
cdf_a: &[f64],
values_b: &[f64],
cdf_b: &[f64],
) -> f64 {
// Merge all CDF points
let mut events: Vec<(f64, f64, f64)> = Vec::new(); // (position, cdf_a, cdf_b)
let mut ia = 0;
let mut ib = 0;
let mut current_cdf_a = 0.0;
let mut current_cdf_b = 0.0;
while ia < values_a.len() || ib < values_b.len() {
let pos = match (ia < values_a.len(), ib < values_b.len()) {
(true, true) => {
if values_a[ia] <= values_b[ib] {
current_cdf_a = cdf_a[ia];
ia += 1;
values_a[ia - 1]
} else {
current_cdf_b = cdf_b[ib];
ib += 1;
values_b[ib - 1]
}
}
(true, false) => {
current_cdf_a = cdf_a[ia];
ia += 1;
values_a[ia - 1]
}
(false, true) => {
current_cdf_b = cdf_b[ib];
ib += 1;
values_b[ib - 1]
}
(false, false) => break,
};
events.push((pos, current_cdf_a, current_cdf_b));
}
// Integrate |F_a - F_b|^p
let mut total = 0.0;
for i in 1..events.len() {
let width = events[i].0 - events[i - 1].0;
let height = (events[i - 1].1 - events[i - 1].2).abs();
total += width * height.powf(self.p);
}
total
}
}
impl OptimalTransport for SlicedWasserstein {
fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> f64 {
if source.is_empty() || target.is_empty() {
return 0.0;
}
let dim = source[0].len();
if dim == 0 {
return 0.0;
}
let directions = self.generate_directions(dim);
let n_source = source.len();
let n_target = target.len();
// Pre-allocate projection buffers (reduces allocations per direction)
let mut proj_source = vec![0.0; n_source];
let mut proj_target = vec![0.0; n_target];
let total: f64 = directions
.iter()
.map(|dir| {
// Project into pre-allocated buffers
Self::project_into(source, dir, &mut proj_source);
Self::project_into(target, dir, &mut proj_target);
// Clone for sorting (wasserstein_1d_uniform sorts in place)
self.wasserstein_1d_uniform(proj_source.clone(), proj_target.clone())
})
.sum();
(total / self.num_projections as f64).powf(1.0 / self.p)
}
fn weighted_distance(
&self,
source: &[Vec<f64>],
source_weights: &[f64],
target: &[Vec<f64>],
target_weights: &[f64],
) -> f64 {
if source.is_empty() || target.is_empty() {
return 0.0;
}
let dim = source[0].len();
if dim == 0 {
return 0.0;
}
// Normalize weights
let sum_a: f64 = source_weights.iter().sum();
let sum_b: f64 = target_weights.iter().sum();
let weights_a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
let weights_b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
let directions = self.generate_directions(dim);
let total: f64 = directions
.iter()
.map(|dir| {
let proj_source = Self::project(source, dir);
let proj_target = Self::project(target, dir);
self.wasserstein_1d_weighted(&proj_source, &weights_a, &proj_target, &weights_b)
})
.sum();
(total / self.num_projections as f64).powf(1.0 / self.p)
}
}
/// Quantile of sorted data
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
if sorted.is_empty() {
return 0.0;
}
let q = q.clamp(0.0, 1.0);
let n = sorted.len();
if n == 1 {
return sorted[0];
}
let idx_f = q * (n - 1) as f64;
let idx_low = idx_f.floor() as usize;
let idx_high = (idx_low + 1).min(n - 1);
let frac = idx_f - idx_low as f64;
sorted[idx_low] * (1.0 - frac) + sorted[idx_high] * frac
}
/// Compute CDF from weights
fn compute_cdf(weights: &[f64]) -> Vec<f64> {
let total: f64 = weights.iter().sum();
let mut cdf = Vec::with_capacity(weights.len());
let mut cumsum = 0.0;
for &w in weights {
cumsum += w / total;
cdf.push(cumsum);
}
cdf
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sliced_wasserstein_identical() {
let sw = SlicedWasserstein::new(100).with_seed(42);
let points = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
];
// Distance to itself should be very small
let dist = sw.distance(&points, &points);
assert!(dist < 0.01, "Self-distance should be ~0, got {}", dist);
}
#[test]
fn test_sliced_wasserstein_translation() {
let sw = SlicedWasserstein::new(500).with_seed(42);
let source = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
];
// Translate by (1, 1)
let target: Vec<Vec<f64>> = source
.iter()
.map(|p| vec![p[0] + 1.0, p[1] + 1.0])
.collect();
let dist = sw.distance(&source, &target);
// For W2 translation by (1, 1), expected distance is sqrt(2) ≈ 1.414
// But Sliced Wasserstein is an approximation, so allow wider tolerance
assert!(
dist > 0.5 && dist < 2.0,
"Translation distance should be positive, got {:.3}",
dist
);
}
#[test]
fn test_sliced_wasserstein_scaling() {
let sw = SlicedWasserstein::new(500).with_seed(42);
let source = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
];
// Scale by 2
let target: Vec<Vec<f64>> = source
.iter()
.map(|p| vec![p[0] * 2.0, p[1] * 2.0])
.collect();
let dist = sw.distance(&source, &target);
// Should be positive for scaled distribution
assert!(dist > 0.0, "Scaling should produce positive distance");
}
#[test]
fn test_weighted_distance() {
let sw = SlicedWasserstein::new(100).with_seed(42);
let source = vec![vec![0.0], vec![1.0]];
let target = vec![vec![2.0], vec![3.0]];
// Uniform weights
let weights_s = vec![0.5, 0.5];
let weights_t = vec![0.5, 0.5];
let dist = sw.weighted_distance(&source, &weights_s, &target, &weights_t);
assert!(dist > 0.0);
}
#[test]
fn test_1d_projections() {
let sw = SlicedWasserstein::new(10);
let directions = sw.generate_directions(3);
assert_eq!(directions.len(), 10);
// Each direction should be unit length
for dir in &directions {
let norm: f64 = dir.iter().map(|&x| x * x).sum::<f64>().sqrt();
assert!((norm - 1.0).abs() < 1e-6, "Direction not unit: {}", norm);
}
}
}