git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
474 lines
15 KiB
Rust
474 lines
15 KiB
Rust
//! Log-Stabilized Sinkhorn Algorithm
|
||
//!
|
||
//! The Sinkhorn algorithm computes the entropic-regularized optimal transport:
|
||
//!
|
||
//! min_{γ ∈ Π(a,b)} ⟨γ, C⟩ - ε H(γ)
|
||
//!
|
||
//! where H(γ) = -Σ γ_ij log(γ_ij) is the entropy and ε is the regularization.
|
||
//!
|
||
//! ## Log-Stabilization
|
||
//!
|
||
//! We work in log-domain to prevent numerical overflow/underflow:
|
||
//! - Store log(u) and log(v) instead of u, v
|
||
//! - Use log-sum-exp for stable normalization
|
||
//!
|
||
//! ## Complexity
|
||
//!
|
||
//! - O(n² × iterations) for dense cost matrix
|
||
//! - Typically converges in 50-200 iterations
|
||
//! - ~1000x faster than linear programming for exact OT
|
||
|
||
use crate::error::{MathError, Result};
|
||
use crate::utils::{log_sum_exp, EPS, LOG_MIN};
|
||
|
||
/// Result of Sinkhorn algorithm
|
||
#[derive(Debug, Clone)]
|
||
pub struct TransportPlan {
|
||
/// Transport plan matrix γ[i,j] (n × m)
|
||
pub plan: Vec<Vec<f64>>,
|
||
/// Total transport cost
|
||
pub cost: f64,
|
||
/// Number of iterations to convergence
|
||
pub iterations: usize,
|
||
/// Final marginal error (||Pγ - a||₁ + ||γᵀ1 - b||₁)
|
||
pub marginal_error: f64,
|
||
/// Whether the algorithm converged
|
||
pub converged: bool,
|
||
}
|
||
|
||
/// Log-stabilized Sinkhorn solver for entropic optimal transport
|
||
#[derive(Debug, Clone)]
|
||
pub struct SinkhornSolver {
|
||
/// Regularization parameter ε
|
||
regularization: f64,
|
||
/// Maximum iterations
|
||
max_iterations: usize,
|
||
/// Convergence threshold
|
||
threshold: f64,
|
||
}
|
||
|
||
impl SinkhornSolver {
|
||
/// Create a new Sinkhorn solver
|
||
///
|
||
/// # Arguments
|
||
/// * `regularization` - Entropy regularization ε (0.01-0.1 typical)
|
||
/// * `max_iterations` - Maximum Sinkhorn iterations (100-1000 typical)
|
||
pub fn new(regularization: f64, max_iterations: usize) -> Self {
|
||
Self {
|
||
regularization: regularization.max(1e-6),
|
||
max_iterations: max_iterations.max(1),
|
||
threshold: 1e-6,
|
||
}
|
||
}
|
||
|
||
/// Set convergence threshold
|
||
pub fn with_threshold(mut self, threshold: f64) -> Self {
|
||
self.threshold = threshold.max(1e-12);
|
||
self
|
||
}
|
||
|
||
/// Compute the cost matrix for squared Euclidean distance
|
||
/// Uses SIMD-friendly 4-way unrolled accumulator for better performance
|
||
#[inline]
|
||
pub fn compute_cost_matrix(source: &[Vec<f64>], target: &[Vec<f64>]) -> Vec<Vec<f64>> {
|
||
source
|
||
.iter()
|
||
.map(|s| {
|
||
target
|
||
.iter()
|
||
.map(|t| Self::squared_euclidean(s, t))
|
||
.collect()
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// SIMD-friendly squared Euclidean distance
|
||
#[inline(always)]
|
||
fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
|
||
let len = a.len();
|
||
let chunks = len / 4;
|
||
let remainder = len % 4;
|
||
|
||
let mut sum0 = 0.0f64;
|
||
let mut sum1 = 0.0f64;
|
||
let mut sum2 = 0.0f64;
|
||
let mut sum3 = 0.0f64;
|
||
|
||
for i in 0..chunks {
|
||
let base = i * 4;
|
||
let d0 = a[base] - b[base];
|
||
let d1 = a[base + 1] - b[base + 1];
|
||
let d2 = a[base + 2] - b[base + 2];
|
||
let d3 = a[base + 3] - b[base + 3];
|
||
sum0 += d0 * d0;
|
||
sum1 += d1 * d1;
|
||
sum2 += d2 * d2;
|
||
sum3 += d3 * d3;
|
||
}
|
||
|
||
let base = chunks * 4;
|
||
for i in 0..remainder {
|
||
let d = a[base + i] - b[base + i];
|
||
sum0 += d * d;
|
||
}
|
||
|
||
sum0 + sum1 + sum2 + sum3
|
||
}
|
||
|
||
/// Solve optimal transport using log-stabilized Sinkhorn
|
||
///
|
||
/// # Arguments
|
||
/// * `cost_matrix` - C[i,j] = cost to move from source[i] to target[j]
|
||
/// * `source_weights` - Marginal distribution a (sum to 1)
|
||
/// * `target_weights` - Marginal distribution b (sum to 1)
|
||
pub fn solve(
|
||
&self,
|
||
cost_matrix: &[Vec<f64>],
|
||
source_weights: &[f64],
|
||
target_weights: &[f64],
|
||
) -> Result<TransportPlan> {
|
||
let n = source_weights.len();
|
||
let m = target_weights.len();
|
||
|
||
if n == 0 || m == 0 {
|
||
return Err(MathError::empty_input("weights"));
|
||
}
|
||
|
||
if cost_matrix.len() != n || cost_matrix.iter().any(|row| row.len() != m) {
|
||
return Err(MathError::dimension_mismatch(n, cost_matrix.len()));
|
||
}
|
||
|
||
// Normalize weights
|
||
let sum_a: f64 = source_weights.iter().sum();
|
||
let sum_b: f64 = target_weights.iter().sum();
|
||
let a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
|
||
let b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
|
||
|
||
// Initialize log-domain Gibbs kernel: K = exp(-C/ε)
|
||
// Store log(K) = -C/ε
|
||
let log_k: Vec<Vec<f64>> = cost_matrix
|
||
.iter()
|
||
.map(|row| row.iter().map(|&c| -c / self.regularization).collect())
|
||
.collect();
|
||
|
||
// Initialize log scaling vectors
|
||
let mut log_u = vec![0.0; n];
|
||
let mut log_v = vec![0.0; m];
|
||
|
||
let log_a: Vec<f64> = a.iter().map(|&ai| ai.ln().max(LOG_MIN)).collect();
|
||
let log_b: Vec<f64> = b.iter().map(|&bi| bi.ln().max(LOG_MIN)).collect();
|
||
|
||
let mut converged = false;
|
||
let mut iterations = 0;
|
||
let mut marginal_error = f64::INFINITY;
|
||
|
||
// Pre-allocate buffers for log-sum-exp computation (reduces allocations per iteration)
|
||
let mut log_terms_row = vec![0.0; m];
|
||
let mut log_terms_col = vec![0.0; n];
|
||
|
||
// Sinkhorn iterations in log domain
|
||
for iter in 0..self.max_iterations {
|
||
iterations = iter + 1;
|
||
|
||
// Update log_u: log_u = log_a - log_sum_exp_j(log_v[j] + log_K[i,j])
|
||
let mut max_u_change: f64 = 0.0;
|
||
for i in 0..n {
|
||
let old_log_u = log_u[i];
|
||
// Compute into pre-allocated buffer
|
||
for j in 0..m {
|
||
log_terms_row[j] = log_v[j] + log_k[i][j];
|
||
}
|
||
let lse = log_sum_exp(&log_terms_row);
|
||
log_u[i] = log_a[i] - lse;
|
||
max_u_change = max_u_change.max((log_u[i] - old_log_u).abs());
|
||
}
|
||
|
||
// Update log_v: log_v = log_b - log_sum_exp_i(log_u[i] + log_K[i,j])
|
||
let mut max_v_change: f64 = 0.0;
|
||
for j in 0..m {
|
||
let old_log_v = log_v[j];
|
||
// Compute into pre-allocated buffer
|
||
for i in 0..n {
|
||
log_terms_col[i] = log_u[i] + log_k[i][j];
|
||
}
|
||
let lse = log_sum_exp(&log_terms_col);
|
||
log_v[j] = log_b[j] - lse;
|
||
max_v_change = max_v_change.max((log_v[j] - old_log_v).abs());
|
||
}
|
||
|
||
// Check convergence
|
||
let max_change = max_u_change.max(max_v_change);
|
||
|
||
// Compute marginal error every 10 iterations
|
||
if iter % 10 == 0 || max_change < self.threshold {
|
||
marginal_error = self.compute_marginal_error(&log_u, &log_v, &log_k, &a, &b);
|
||
|
||
if max_change < self.threshold && marginal_error < self.threshold * 10.0 {
|
||
converged = true;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
|
||
// Compute transport plan: γ[i,j] = exp(log_u[i] + log_K[i,j] + log_v[j])
|
||
let plan: Vec<Vec<f64>> = (0..n)
|
||
.map(|i| {
|
||
(0..m)
|
||
.map(|j| {
|
||
let log_gamma = log_u[i] + log_k[i][j] + log_v[j];
|
||
log_gamma.exp().max(0.0)
|
||
})
|
||
.collect()
|
||
})
|
||
.collect();
|
||
|
||
// Compute transport cost: ⟨γ, C⟩
|
||
let cost = plan
|
||
.iter()
|
||
.zip(cost_matrix.iter())
|
||
.map(|(gamma_row, cost_row)| {
|
||
gamma_row
|
||
.iter()
|
||
.zip(cost_row.iter())
|
||
.map(|(&g, &c)| g * c)
|
||
.sum::<f64>()
|
||
})
|
||
.sum();
|
||
|
||
Ok(TransportPlan {
|
||
plan,
|
||
cost,
|
||
iterations,
|
||
marginal_error,
|
||
converged,
|
||
})
|
||
}
|
||
|
||
/// Compute marginal constraint error
|
||
fn compute_marginal_error(
|
||
&self,
|
||
log_u: &[f64],
|
||
log_v: &[f64],
|
||
log_k: &[Vec<f64>],
|
||
a: &[f64],
|
||
b: &[f64],
|
||
) -> f64 {
|
||
let n = log_u.len();
|
||
let m = log_v.len();
|
||
|
||
// Compute row sums (γ1 should equal a)
|
||
let mut row_error = 0.0;
|
||
for i in 0..n {
|
||
let log_row_sum = log_sum_exp(
|
||
&(0..m)
|
||
.map(|j| log_u[i] + log_k[i][j] + log_v[j])
|
||
.collect::<Vec<_>>(),
|
||
);
|
||
row_error += (log_row_sum.exp() - a[i]).abs();
|
||
}
|
||
|
||
// Compute column sums (γᵀ1 should equal b)
|
||
let mut col_error = 0.0;
|
||
for j in 0..m {
|
||
let log_col_sum = log_sum_exp(
|
||
&(0..n)
|
||
.map(|i| log_u[i] + log_k[i][j] + log_v[j])
|
||
.collect::<Vec<_>>(),
|
||
);
|
||
col_error += (log_col_sum.exp() - b[j]).abs();
|
||
}
|
||
|
||
row_error + col_error
|
||
}
|
||
|
||
/// Compute Sinkhorn distance (optimal transport cost) between point clouds
|
||
pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
|
||
let cost_matrix = Self::compute_cost_matrix(source, target);
|
||
|
||
// Uniform weights
|
||
let n = source.len();
|
||
let m = target.len();
|
||
let source_weights = vec![1.0 / n as f64; n];
|
||
let target_weights = vec![1.0 / m as f64; m];
|
||
|
||
let result = self.solve(&cost_matrix, &source_weights, &target_weights)?;
|
||
Ok(result.cost)
|
||
}
|
||
|
||
/// Compute Wasserstein barycenter of multiple distributions
|
||
///
|
||
/// Returns the barycenter (mean distribution) in transport space
|
||
pub fn barycenter(
|
||
&self,
|
||
distributions: &[&[Vec<f64>]],
|
||
weights: Option<&[f64]>,
|
||
support_size: usize,
|
||
dim: usize,
|
||
) -> Result<Vec<Vec<f64>>> {
|
||
if distributions.is_empty() {
|
||
return Err(MathError::empty_input("distributions"));
|
||
}
|
||
|
||
let k = distributions.len();
|
||
let barycenter_weights = match weights {
|
||
Some(w) => {
|
||
let sum: f64 = w.iter().sum();
|
||
w.iter().map(|&wi| wi / sum).collect()
|
||
}
|
||
None => vec![1.0 / k as f64; k],
|
||
};
|
||
|
||
// Initialize barycenter as mean of first distribution
|
||
let mut barycenter: Vec<Vec<f64>> = (0..support_size)
|
||
.map(|i| {
|
||
let t = i as f64 / (support_size - 1).max(1) as f64;
|
||
vec![t; dim]
|
||
})
|
||
.collect();
|
||
|
||
// Fixed-point iteration to find barycenter
|
||
for _outer in 0..20 {
|
||
// For each input distribution, compute transport to barycenter
|
||
let mut displacements = vec![vec![0.0; dim]; support_size];
|
||
|
||
for (dist_idx, &distribution) in distributions.iter().enumerate() {
|
||
let cost_matrix = Self::compute_cost_matrix(distribution, &barycenter);
|
||
|
||
let n = distribution.len();
|
||
let source_w = vec![1.0 / n as f64; n];
|
||
let target_w = vec![1.0 / support_size as f64; support_size];
|
||
|
||
if let Ok(plan) = self.solve(&cost_matrix, &source_w, &target_w) {
|
||
// Compute displacement from plan
|
||
for j in 0..support_size {
|
||
for i in 0..n {
|
||
let weight = plan.plan[i][j] * support_size as f64;
|
||
for d in 0..dim {
|
||
displacements[j][d] += barycenter_weights[dist_idx]
|
||
* weight
|
||
* (distribution[i][d] - barycenter[j][d]);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Update barycenter
|
||
let mut max_update: f64 = 0.0;
|
||
for j in 0..support_size {
|
||
for d in 0..dim {
|
||
let delta = displacements[j][d] * 0.5; // Step size
|
||
barycenter[j][d] += delta;
|
||
max_update = max_update.max(delta.abs());
|
||
}
|
||
}
|
||
|
||
if max_update < EPS {
|
||
break;
|
||
}
|
||
}
|
||
|
||
Ok(barycenter)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
|
||
#[test]
|
||
fn test_sinkhorn_identity() {
|
||
let solver = SinkhornSolver::new(0.1, 100);
|
||
|
||
let source = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
|
||
let target = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
|
||
|
||
let cost = solver.distance(&source, &target).unwrap();
|
||
assert!(cost < 0.1, "Identity should have near-zero cost: {}", cost);
|
||
}
|
||
|
||
#[test]
|
||
fn test_sinkhorn_translation() {
|
||
let solver = SinkhornSolver::new(0.05, 200);
|
||
|
||
let source = vec![
|
||
vec![0.0, 0.0],
|
||
vec![1.0, 0.0],
|
||
vec![0.0, 1.0],
|
||
vec![1.0, 1.0],
|
||
];
|
||
|
||
// Translate by (1, 0)
|
||
let target: Vec<Vec<f64>> = source.iter().map(|p| vec![p[0] + 1.0, p[1]]).collect();
|
||
|
||
let cost = solver.distance(&source, &target).unwrap();
|
||
|
||
// Expected cost for unit translation: each point moves distance 1
|
||
// With squared Euclidean: cost ≈ 1.0
|
||
assert!(
|
||
cost > 0.5 && cost < 2.0,
|
||
"Translation cost should be ~1.0: {}",
|
||
cost
|
||
);
|
||
}
|
||
|
||
#[test]
|
||
fn test_sinkhorn_convergence() {
|
||
let solver = SinkhornSolver::new(0.1, 100).with_threshold(1e-6);
|
||
|
||
let cost_matrix = vec![
|
||
vec![0.0, 1.0, 2.0],
|
||
vec![1.0, 0.0, 1.0],
|
||
vec![2.0, 1.0, 0.0],
|
||
];
|
||
|
||
let a = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
|
||
let b = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
|
||
|
||
let result = solver.solve(&cost_matrix, &a, &b).unwrap();
|
||
|
||
assert!(result.converged, "Should converge");
|
||
assert!(
|
||
result.marginal_error < 0.01,
|
||
"Marginal error too high: {}",
|
||
result.marginal_error
|
||
);
|
||
}
|
||
|
||
#[test]
|
||
fn test_transport_plan_marginals() {
|
||
let solver = SinkhornSolver::new(0.1, 100);
|
||
|
||
let cost_matrix = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
|
||
|
||
let a = vec![0.3, 0.7];
|
||
let b = vec![0.6, 0.4];
|
||
|
||
let result = solver.solve(&cost_matrix, &a, &b).unwrap();
|
||
|
||
// Check row marginals
|
||
for (i, &ai) in a.iter().enumerate() {
|
||
let row_sum: f64 = result.plan[i].iter().sum();
|
||
assert!(
|
||
(row_sum - ai).abs() < 0.05,
|
||
"Row {} sum {} != {}",
|
||
i,
|
||
row_sum,
|
||
ai
|
||
);
|
||
}
|
||
|
||
// Check column marginals
|
||
for (j, &bj) in b.iter().enumerate() {
|
||
let col_sum: f64 = result.plan.iter().map(|row| row[j]).sum();
|
||
assert!(
|
||
(col_sum - bj).abs() < 0.05,
|
||
"Col {} sum {} != {}",
|
||
j,
|
||
col_sum,
|
||
bj
|
||
);
|
||
}
|
||
}
|
||
}
|