Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
473
crates/ruvector-math/src/optimal_transport/sinkhorn.rs
Normal file
473
crates/ruvector-math/src/optimal_transport/sinkhorn.rs
Normal file
@@ -0,0 +1,473 @@
|
||||
//! Log-Stabilized Sinkhorn Algorithm
|
||||
//!
|
||||
//! The Sinkhorn algorithm computes the entropic-regularized optimal transport:
|
||||
//!
|
||||
//! min_{γ ∈ Π(a,b)} ⟨γ, C⟩ - ε H(γ)
|
||||
//!
|
||||
//! where H(γ) = -Σ γ_ij log(γ_ij) is the entropy and ε is the regularization.
|
||||
//!
|
||||
//! ## Log-Stabilization
|
||||
//!
|
||||
//! We work in log-domain to prevent numerical overflow/underflow:
|
||||
//! - Store log(u) and log(v) instead of u, v
|
||||
//! - Use log-sum-exp for stable normalization
|
||||
//!
|
||||
//! ## Complexity
|
||||
//!
|
||||
//! - O(n² × iterations) for dense cost matrix
|
||||
//! - Typically converges in 50-200 iterations
|
||||
//! - ~1000x faster than linear programming for exact OT
|
||||
|
||||
use crate::error::{MathError, Result};
|
||||
use crate::utils::{log_sum_exp, EPS, LOG_MIN};
|
||||
|
||||
/// Result of Sinkhorn algorithm
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportPlan {
|
||||
/// Transport plan matrix γ[i,j] (n × m)
|
||||
pub plan: Vec<Vec<f64>>,
|
||||
/// Total transport cost
|
||||
pub cost: f64,
|
||||
/// Number of iterations to convergence
|
||||
pub iterations: usize,
|
||||
/// Final marginal error (||Pγ - a||₁ + ||γᵀ1 - b||₁)
|
||||
pub marginal_error: f64,
|
||||
/// Whether the algorithm converged
|
||||
pub converged: bool,
|
||||
}
|
||||
|
||||
/// Log-stabilized Sinkhorn solver for entropic optimal transport
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SinkhornSolver {
|
||||
/// Regularization parameter ε
|
||||
regularization: f64,
|
||||
/// Maximum iterations
|
||||
max_iterations: usize,
|
||||
/// Convergence threshold
|
||||
threshold: f64,
|
||||
}
|
||||
|
||||
impl SinkhornSolver {
|
||||
/// Create a new Sinkhorn solver
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `regularization` - Entropy regularization ε (0.01-0.1 typical)
|
||||
/// * `max_iterations` - Maximum Sinkhorn iterations (100-1000 typical)
|
||||
pub fn new(regularization: f64, max_iterations: usize) -> Self {
|
||||
Self {
|
||||
regularization: regularization.max(1e-6),
|
||||
max_iterations: max_iterations.max(1),
|
||||
threshold: 1e-6,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set convergence threshold
|
||||
pub fn with_threshold(mut self, threshold: f64) -> Self {
|
||||
self.threshold = threshold.max(1e-12);
|
||||
self
|
||||
}
|
||||
|
||||
/// Compute the cost matrix for squared Euclidean distance
|
||||
/// Uses SIMD-friendly 4-way unrolled accumulator for better performance
|
||||
#[inline]
|
||||
pub fn compute_cost_matrix(source: &[Vec<f64>], target: &[Vec<f64>]) -> Vec<Vec<f64>> {
|
||||
source
|
||||
.iter()
|
||||
.map(|s| {
|
||||
target
|
||||
.iter()
|
||||
.map(|t| Self::squared_euclidean(s, t))
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// SIMD-friendly squared Euclidean distance
|
||||
#[inline(always)]
|
||||
fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
|
||||
let len = a.len();
|
||||
let chunks = len / 4;
|
||||
let remainder = len % 4;
|
||||
|
||||
let mut sum0 = 0.0f64;
|
||||
let mut sum1 = 0.0f64;
|
||||
let mut sum2 = 0.0f64;
|
||||
let mut sum3 = 0.0f64;
|
||||
|
||||
for i in 0..chunks {
|
||||
let base = i * 4;
|
||||
let d0 = a[base] - b[base];
|
||||
let d1 = a[base + 1] - b[base + 1];
|
||||
let d2 = a[base + 2] - b[base + 2];
|
||||
let d3 = a[base + 3] - b[base + 3];
|
||||
sum0 += d0 * d0;
|
||||
sum1 += d1 * d1;
|
||||
sum2 += d2 * d2;
|
||||
sum3 += d3 * d3;
|
||||
}
|
||||
|
||||
let base = chunks * 4;
|
||||
for i in 0..remainder {
|
||||
let d = a[base + i] - b[base + i];
|
||||
sum0 += d * d;
|
||||
}
|
||||
|
||||
sum0 + sum1 + sum2 + sum3
|
||||
}
|
||||
|
||||
/// Solve optimal transport using log-stabilized Sinkhorn
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cost_matrix` - C[i,j] = cost to move from source[i] to target[j]
|
||||
/// * `source_weights` - Marginal distribution a (sum to 1)
|
||||
/// * `target_weights` - Marginal distribution b (sum to 1)
|
||||
pub fn solve(
|
||||
&self,
|
||||
cost_matrix: &[Vec<f64>],
|
||||
source_weights: &[f64],
|
||||
target_weights: &[f64],
|
||||
) -> Result<TransportPlan> {
|
||||
let n = source_weights.len();
|
||||
let m = target_weights.len();
|
||||
|
||||
if n == 0 || m == 0 {
|
||||
return Err(MathError::empty_input("weights"));
|
||||
}
|
||||
|
||||
if cost_matrix.len() != n || cost_matrix.iter().any(|row| row.len() != m) {
|
||||
return Err(MathError::dimension_mismatch(n, cost_matrix.len()));
|
||||
}
|
||||
|
||||
// Normalize weights
|
||||
let sum_a: f64 = source_weights.iter().sum();
|
||||
let sum_b: f64 = target_weights.iter().sum();
|
||||
let a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
|
||||
let b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
|
||||
|
||||
// Initialize log-domain Gibbs kernel: K = exp(-C/ε)
|
||||
// Store log(K) = -C/ε
|
||||
let log_k: Vec<Vec<f64>> = cost_matrix
|
||||
.iter()
|
||||
.map(|row| row.iter().map(|&c| -c / self.regularization).collect())
|
||||
.collect();
|
||||
|
||||
// Initialize log scaling vectors
|
||||
let mut log_u = vec![0.0; n];
|
||||
let mut log_v = vec![0.0; m];
|
||||
|
||||
let log_a: Vec<f64> = a.iter().map(|&ai| ai.ln().max(LOG_MIN)).collect();
|
||||
let log_b: Vec<f64> = b.iter().map(|&bi| bi.ln().max(LOG_MIN)).collect();
|
||||
|
||||
let mut converged = false;
|
||||
let mut iterations = 0;
|
||||
let mut marginal_error = f64::INFINITY;
|
||||
|
||||
// Pre-allocate buffers for log-sum-exp computation (reduces allocations per iteration)
|
||||
let mut log_terms_row = vec![0.0; m];
|
||||
let mut log_terms_col = vec![0.0; n];
|
||||
|
||||
// Sinkhorn iterations in log domain
|
||||
for iter in 0..self.max_iterations {
|
||||
iterations = iter + 1;
|
||||
|
||||
// Update log_u: log_u = log_a - log_sum_exp_j(log_v[j] + log_K[i,j])
|
||||
let mut max_u_change: f64 = 0.0;
|
||||
for i in 0..n {
|
||||
let old_log_u = log_u[i];
|
||||
// Compute into pre-allocated buffer
|
||||
for j in 0..m {
|
||||
log_terms_row[j] = log_v[j] + log_k[i][j];
|
||||
}
|
||||
let lse = log_sum_exp(&log_terms_row);
|
||||
log_u[i] = log_a[i] - lse;
|
||||
max_u_change = max_u_change.max((log_u[i] - old_log_u).abs());
|
||||
}
|
||||
|
||||
// Update log_v: log_v = log_b - log_sum_exp_i(log_u[i] + log_K[i,j])
|
||||
let mut max_v_change: f64 = 0.0;
|
||||
for j in 0..m {
|
||||
let old_log_v = log_v[j];
|
||||
// Compute into pre-allocated buffer
|
||||
for i in 0..n {
|
||||
log_terms_col[i] = log_u[i] + log_k[i][j];
|
||||
}
|
||||
let lse = log_sum_exp(&log_terms_col);
|
||||
log_v[j] = log_b[j] - lse;
|
||||
max_v_change = max_v_change.max((log_v[j] - old_log_v).abs());
|
||||
}
|
||||
|
||||
// Check convergence
|
||||
let max_change = max_u_change.max(max_v_change);
|
||||
|
||||
// Compute marginal error every 10 iterations
|
||||
if iter % 10 == 0 || max_change < self.threshold {
|
||||
marginal_error = self.compute_marginal_error(&log_u, &log_v, &log_k, &a, &b);
|
||||
|
||||
if max_change < self.threshold && marginal_error < self.threshold * 10.0 {
|
||||
converged = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute transport plan: γ[i,j] = exp(log_u[i] + log_K[i,j] + log_v[j])
|
||||
let plan: Vec<Vec<f64>> = (0..n)
|
||||
.map(|i| {
|
||||
(0..m)
|
||||
.map(|j| {
|
||||
let log_gamma = log_u[i] + log_k[i][j] + log_v[j];
|
||||
log_gamma.exp().max(0.0)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Compute transport cost: ⟨γ, C⟩
|
||||
let cost = plan
|
||||
.iter()
|
||||
.zip(cost_matrix.iter())
|
||||
.map(|(gamma_row, cost_row)| {
|
||||
gamma_row
|
||||
.iter()
|
||||
.zip(cost_row.iter())
|
||||
.map(|(&g, &c)| g * c)
|
||||
.sum::<f64>()
|
||||
})
|
||||
.sum();
|
||||
|
||||
Ok(TransportPlan {
|
||||
plan,
|
||||
cost,
|
||||
iterations,
|
||||
marginal_error,
|
||||
converged,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute marginal constraint error
|
||||
fn compute_marginal_error(
|
||||
&self,
|
||||
log_u: &[f64],
|
||||
log_v: &[f64],
|
||||
log_k: &[Vec<f64>],
|
||||
a: &[f64],
|
||||
b: &[f64],
|
||||
) -> f64 {
|
||||
let n = log_u.len();
|
||||
let m = log_v.len();
|
||||
|
||||
// Compute row sums (γ1 should equal a)
|
||||
let mut row_error = 0.0;
|
||||
for i in 0..n {
|
||||
let log_row_sum = log_sum_exp(
|
||||
&(0..m)
|
||||
.map(|j| log_u[i] + log_k[i][j] + log_v[j])
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
row_error += (log_row_sum.exp() - a[i]).abs();
|
||||
}
|
||||
|
||||
// Compute column sums (γᵀ1 should equal b)
|
||||
let mut col_error = 0.0;
|
||||
for j in 0..m {
|
||||
let log_col_sum = log_sum_exp(
|
||||
&(0..n)
|
||||
.map(|i| log_u[i] + log_k[i][j] + log_v[j])
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
col_error += (log_col_sum.exp() - b[j]).abs();
|
||||
}
|
||||
|
||||
row_error + col_error
|
||||
}
|
||||
|
||||
/// Compute Sinkhorn distance (optimal transport cost) between point clouds
|
||||
pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
|
||||
let cost_matrix = Self::compute_cost_matrix(source, target);
|
||||
|
||||
// Uniform weights
|
||||
let n = source.len();
|
||||
let m = target.len();
|
||||
let source_weights = vec![1.0 / n as f64; n];
|
||||
let target_weights = vec![1.0 / m as f64; m];
|
||||
|
||||
let result = self.solve(&cost_matrix, &source_weights, &target_weights)?;
|
||||
Ok(result.cost)
|
||||
}
|
||||
|
||||
/// Compute Wasserstein barycenter of multiple distributions
|
||||
///
|
||||
/// Returns the barycenter (mean distribution) in transport space
|
||||
pub fn barycenter(
|
||||
&self,
|
||||
distributions: &[&[Vec<f64>]],
|
||||
weights: Option<&[f64]>,
|
||||
support_size: usize,
|
||||
dim: usize,
|
||||
) -> Result<Vec<Vec<f64>>> {
|
||||
if distributions.is_empty() {
|
||||
return Err(MathError::empty_input("distributions"));
|
||||
}
|
||||
|
||||
let k = distributions.len();
|
||||
let barycenter_weights = match weights {
|
||||
Some(w) => {
|
||||
let sum: f64 = w.iter().sum();
|
||||
w.iter().map(|&wi| wi / sum).collect()
|
||||
}
|
||||
None => vec![1.0 / k as f64; k],
|
||||
};
|
||||
|
||||
// Initialize barycenter as mean of first distribution
|
||||
let mut barycenter: Vec<Vec<f64>> = (0..support_size)
|
||||
.map(|i| {
|
||||
let t = i as f64 / (support_size - 1).max(1) as f64;
|
||||
vec![t; dim]
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Fixed-point iteration to find barycenter
|
||||
for _outer in 0..20 {
|
||||
// For each input distribution, compute transport to barycenter
|
||||
let mut displacements = vec![vec![0.0; dim]; support_size];
|
||||
|
||||
for (dist_idx, &distribution) in distributions.iter().enumerate() {
|
||||
let cost_matrix = Self::compute_cost_matrix(distribution, &barycenter);
|
||||
|
||||
let n = distribution.len();
|
||||
let source_w = vec![1.0 / n as f64; n];
|
||||
let target_w = vec![1.0 / support_size as f64; support_size];
|
||||
|
||||
if let Ok(plan) = self.solve(&cost_matrix, &source_w, &target_w) {
|
||||
// Compute displacement from plan
|
||||
for j in 0..support_size {
|
||||
for i in 0..n {
|
||||
let weight = plan.plan[i][j] * support_size as f64;
|
||||
for d in 0..dim {
|
||||
displacements[j][d] += barycenter_weights[dist_idx]
|
||||
* weight
|
||||
* (distribution[i][d] - barycenter[j][d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update barycenter
|
||||
let mut max_update: f64 = 0.0;
|
||||
for j in 0..support_size {
|
||||
for d in 0..dim {
|
||||
let delta = displacements[j][d] * 0.5; // Step size
|
||||
barycenter[j][d] += delta;
|
||||
max_update = max_update.max(delta.abs());
|
||||
}
|
||||
}
|
||||
|
||||
if max_update < EPS {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(barycenter)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sinkhorn_identity() {
|
||||
let solver = SinkhornSolver::new(0.1, 100);
|
||||
|
||||
let source = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
|
||||
let target = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
|
||||
|
||||
let cost = solver.distance(&source, &target).unwrap();
|
||||
assert!(cost < 0.1, "Identity should have near-zero cost: {}", cost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sinkhorn_translation() {
|
||||
let solver = SinkhornSolver::new(0.05, 200);
|
||||
|
||||
let source = vec![
|
||||
vec![0.0, 0.0],
|
||||
vec![1.0, 0.0],
|
||||
vec![0.0, 1.0],
|
||||
vec![1.0, 1.0],
|
||||
];
|
||||
|
||||
// Translate by (1, 0)
|
||||
let target: Vec<Vec<f64>> = source.iter().map(|p| vec![p[0] + 1.0, p[1]]).collect();
|
||||
|
||||
let cost = solver.distance(&source, &target).unwrap();
|
||||
|
||||
// Expected cost for unit translation: each point moves distance 1
|
||||
// With squared Euclidean: cost ≈ 1.0
|
||||
assert!(
|
||||
cost > 0.5 && cost < 2.0,
|
||||
"Translation cost should be ~1.0: {}",
|
||||
cost
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sinkhorn_convergence() {
|
||||
let solver = SinkhornSolver::new(0.1, 100).with_threshold(1e-6);
|
||||
|
||||
let cost_matrix = vec![
|
||||
vec![0.0, 1.0, 2.0],
|
||||
vec![1.0, 0.0, 1.0],
|
||||
vec![2.0, 1.0, 0.0],
|
||||
];
|
||||
|
||||
let a = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
|
||||
let b = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
|
||||
|
||||
let result = solver.solve(&cost_matrix, &a, &b).unwrap();
|
||||
|
||||
assert!(result.converged, "Should converge");
|
||||
assert!(
|
||||
result.marginal_error < 0.01,
|
||||
"Marginal error too high: {}",
|
||||
result.marginal_error
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transport_plan_marginals() {
|
||||
let solver = SinkhornSolver::new(0.1, 100);
|
||||
|
||||
let cost_matrix = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
|
||||
|
||||
let a = vec![0.3, 0.7];
|
||||
let b = vec![0.6, 0.4];
|
||||
|
||||
let result = solver.solve(&cost_matrix, &a, &b).unwrap();
|
||||
|
||||
// Check row marginals
|
||||
for (i, &ai) in a.iter().enumerate() {
|
||||
let row_sum: f64 = result.plan[i].iter().sum();
|
||||
assert!(
|
||||
(row_sum - ai).abs() < 0.05,
|
||||
"Row {} sum {} != {}",
|
||||
i,
|
||||
row_sum,
|
||||
ai
|
||||
);
|
||||
}
|
||||
|
||||
// Check column marginals
|
||||
for (j, &bj) in b.iter().enumerate() {
|
||||
let col_sum: f64 = result.plan.iter().map(|row| row[j]).sum();
|
||||
assert!(
|
||||
(col_sum - bj).abs() < 0.05,
|
||||
"Col {} sum {} != {}",
|
||||
j,
|
||||
col_sum,
|
||||
bj
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user