Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
703
vendor/ruvector/crates/prime-radiant/src/simd/energy.rs
vendored
Normal file
703
vendor/ruvector/crates/prime-radiant/src/simd/energy.rs
vendored
Normal file
@@ -0,0 +1,703 @@
|
||||
//! # SIMD Energy Computation
|
||||
//!
|
||||
//! High-performance coherence energy computation using SIMD intrinsics.
|
||||
//! These operations are critical for the hot path of coherence evaluation.
|
||||
//!
|
||||
//! ## Key Operations
|
||||
//!
|
||||
//! | Operation | Description | Use Case |
|
||||
//! |-----------|-------------|----------|
|
||||
//! | `batch_residuals_simd` | Compute residuals for multiple edges | Bulk energy update |
|
||||
//! | `batch_residual_norms_simd` | Compute squared norms of residuals | Energy aggregation |
|
||||
//! | `weighted_energy_sum_simd` | Sum residual energies with weights | Total energy |
|
||||
//! | `batch_lane_assignment_simd` | Branchless lane routing | Gate evaluation |
|
||||
//!
|
||||
//! ## Performance Characteristics
|
||||
//!
|
||||
//! The batch operations are designed to process multiple edges in parallel,
|
||||
//! achieving near-optimal memory bandwidth utilization when vector dimensions
|
||||
//! align with SIMD register widths.
|
||||
|
||||
use wide::{f32x8, CmpGe};
|
||||
|
||||
use crate::execution::ComputeLane;
|
||||
|
||||
/// Compute residuals for multiple edges in parallel.
///
/// Given flattened source and target state vectors, computes the residual
/// for each edge: `residual[i] = source[i] - target[i]`
///
/// # Arguments
///
/// * `sources` - Flattened source states: `[s0_0, s0_1, ..., s1_0, s1_1, ...]`
/// * `targets` - Flattened target states: `[t0_0, t0_1, ..., t1_0, t1_1, ...]`
/// * `residuals` - Output buffer for residuals (same layout as inputs)
/// * `dim` - Dimension of each state vector
/// * `count` - Number of edges to process
///
/// # Layout
///
/// For `count` edges with `dim`-dimensional states:
/// - Total elements = `count * dim`
/// - Edge `i` starts at index `i * dim`
///
/// Because subtraction is purely element-wise, edge boundaries are
/// irrelevant here: the three buffers are processed as one flat stream.
///
/// # Panics
///
/// Panics in debug mode if buffer sizes don't match `dim * count`.
#[inline]
pub fn batch_residuals_simd(
    sources: &[f32],
    targets: &[f32],
    residuals: &mut [f32],
    dim: usize,
    count: usize,
) {
    let total = dim * count;
    debug_assert_eq!(sources.len(), total);
    debug_assert_eq!(targets.len(), total);
    debug_assert_eq!(residuals.len(), total);

    // For small batches, use scalar: the SIMD setup cost is not amortized.
    if total < 32 {
        batch_residuals_scalar(sources, targets, residuals);
        return;
    }

    // SIMD subtraction: walk all three buffers in lockstep, 8 lanes at a time.
    let chunks_s = sources.chunks_exact(8);
    let chunks_t = targets.chunks_exact(8);
    let chunks_r = residuals.chunks_exact_mut(8);

    // `remainder()` is the trailing `total % 8` elements the chunked loop
    // will not touch; `offset` is where that tail starts in the flat buffers.
    let remainder_s = chunks_s.remainder();
    let remainder_t = chunks_t.remainder();
    let offset = total - remainder_s.len();

    for ((cs, ct), cr) in chunks_s.zip(chunks_t).zip(chunks_r) {
        let vs = load_f32x8(cs);
        let vt = load_f32x8(ct);
        let result = vs - vt;
        store_f32x8(cr, result);
    }

    // Handle remainder (the final < 8 elements) with scalar subtraction.
    for (i, (&vs, &vt)) in remainder_s.iter().zip(remainder_t.iter()).enumerate() {
        residuals[offset + i] = vs - vt;
    }
}
|
||||
|
||||
/// Compute squared norms of residuals for multiple edges.
|
||||
///
|
||||
/// This operation computes `||residual_i||^2` for each edge without
|
||||
/// storing the full residual vectors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `sources` - Flattened source states
|
||||
/// * `targets` - Flattened target states
|
||||
/// * `norms` - Output buffer for squared norms (length = `count`)
|
||||
/// * `dim` - Dimension of each state vector
|
||||
/// * `count` - Number of edges
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use prime_radiant::simd::energy::batch_residual_norms_simd;
|
||||
///
|
||||
/// let sources = [1.0, 0.0, 0.0, 0.0]; // 2 edges, dim=2
|
||||
/// let targets = [0.0, 0.0, 1.0, 0.0];
|
||||
/// let mut norms = [0.0f32; 2];
|
||||
///
|
||||
/// batch_residual_norms_simd(&sources, &targets, &mut norms, 2, 2);
|
||||
/// // norms[0] = 1.0 (||[1,0] - [0,0]||^2)
|
||||
/// // norms[1] = 1.0 (||[0,0] - [1,0]||^2)
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn batch_residual_norms_simd(
|
||||
sources: &[f32],
|
||||
targets: &[f32],
|
||||
norms: &mut [f32],
|
||||
dim: usize,
|
||||
count: usize,
|
||||
) {
|
||||
debug_assert_eq!(sources.len(), dim * count);
|
||||
debug_assert_eq!(targets.len(), dim * count);
|
||||
debug_assert_eq!(norms.len(), count);
|
||||
|
||||
// For small dimensions, process edges directly
|
||||
if dim < 16 {
|
||||
for i in 0..count {
|
||||
let offset = i * dim;
|
||||
norms[i] = compute_residual_norm_sq_scalar(
|
||||
&sources[offset..offset + dim],
|
||||
&targets[offset..offset + dim],
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// For larger dimensions, use SIMD per-edge
|
||||
for i in 0..count {
|
||||
let offset = i * dim;
|
||||
norms[i] = compute_residual_norm_sq_simd(
|
||||
&sources[offset..offset + dim],
|
||||
&targets[offset..offset + dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute residual norm squared for a single edge using SIMD.
///
/// # Arguments
///
/// * `source` - Source state vector
/// * `target` - Target state vector
///
/// # Returns
///
/// `||source - target||^2`
///
/// Note: the SIMD path sums in a different order than the scalar path, so
/// results may differ by normal floating-point rounding.
#[inline]
pub fn compute_residual_norm_sq_simd(source: &[f32], target: &[f32]) -> f32 {
    debug_assert_eq!(source.len(), target.len());

    let len = source.len();

    // Short vectors: SIMD setup is not worth it.
    if len < 16 {
        return compute_residual_norm_sq_scalar(source, target);
    }

    let chunks_s = source.chunks_exact(8);
    let chunks_t = target.chunks_exact(8);
    // Tail elements (len % 8) handled by the scalar loop at the end.
    let remainder_s = chunks_s.remainder();
    let remainder_t = chunks_t.remainder();

    // Two independent accumulators break the FMA dependency chain so the
    // CPU can overlap consecutive iterations.
    let mut acc0 = f32x8::ZERO;
    let mut acc1 = f32x8::ZERO;

    let mut chunks_s_iter = chunks_s;
    let mut chunks_t_iter = chunks_t;

    // Unroll 2x: each outer iteration consumes one chunk into acc0 and,
    // if present, a second chunk into acc1 (odd chunk counts are handled
    // by the inner `if let` simply not firing on the last iteration).
    while let (Some(cs0), Some(ct0)) = (chunks_s_iter.next(), chunks_t_iter.next()) {
        let vs0 = load_f32x8(cs0);
        let vt0 = load_f32x8(ct0);
        let diff0 = vs0 - vt0;
        acc0 = diff0.mul_add(diff0, acc0);

        if let (Some(cs1), Some(ct1)) = (chunks_s_iter.next(), chunks_t_iter.next()) {
            let vs1 = load_f32x8(cs1);
            let vt1 = load_f32x8(ct1);
            let diff1 = vs1 - vt1;
            acc1 = diff1.mul_add(diff1, acc1);
        }
    }

    // Merge the two lanes-wise accumulators, then reduce across lanes.
    let combined = acc0 + acc1;
    let mut sum = combined.reduce_add();

    // Handle remainder
    for (&vs, &vt) in remainder_s.iter().zip(remainder_t.iter()) {
        let diff = vs - vt;
        sum += diff * diff;
    }

    sum
}
|
||||
|
||||
/// Compute weighted energy sum using SIMD horizontal reduction.
///
/// # Arguments
///
/// * `residual_norms` - Squared norms of residuals: `||r_e||^2`
/// * `weights` - Edge weights: `w_e`
///
/// # Returns
///
/// Total energy: `E(S) = sum(w_e * ||r_e||^2)`
///
/// # Example
///
/// ```rust,ignore
/// use prime_radiant::simd::energy::weighted_energy_sum_simd;
///
/// let norms = [1.0, 4.0, 9.0, 16.0];
/// let weights = [1.0, 0.5, 0.25, 0.125];
/// let energy = weighted_energy_sum_simd(&norms, &weights);
/// // energy = 1*1 + 0.5*4 + 0.25*9 + 0.125*16 = 1 + 2 + 2.25 + 2 = 7.25
/// ```
#[inline]
pub fn weighted_energy_sum_simd(residual_norms: &[f32], weights: &[f32]) -> f32 {
    debug_assert_eq!(residual_norms.len(), weights.len());

    let len = residual_norms.len();

    // Short inputs: plain scalar dot product is faster.
    if len < 16 {
        return weighted_energy_sum_scalar(residual_norms, weights);
    }

    let chunks_n = residual_norms.chunks_exact(8);
    let chunks_w = weights.chunks_exact(8);
    // Tail elements (len % 8) are folded in scalar at the end.
    let remainder_n = chunks_n.remainder();
    let remainder_w = chunks_w.remainder();

    // Two accumulators break the FMA dependency chain (see
    // `compute_residual_norm_sq_simd` for the same pattern).
    let mut acc0 = f32x8::ZERO;
    let mut acc1 = f32x8::ZERO;

    let mut chunks_n_iter = chunks_n;
    let mut chunks_w_iter = chunks_w;

    // Unroll 2x: odd chunk counts are handled by the inner `if let`
    // not firing on the final iteration.
    while let (Some(cn0), Some(cw0)) = (chunks_n_iter.next(), chunks_w_iter.next()) {
        let vn0 = load_f32x8(cn0);
        let vw0 = load_f32x8(cw0);
        acc0 = vn0.mul_add(vw0, acc0);

        if let (Some(cn1), Some(cw1)) = (chunks_n_iter.next(), chunks_w_iter.next()) {
            let vn1 = load_f32x8(cn1);
            let vw1 = load_f32x8(cw1);
            acc1 = vn1.mul_add(vw1, acc1);
        }
    }

    // Merge accumulators and reduce across the 8 lanes.
    let combined = acc0 + acc1;
    let mut sum = combined.reduce_add();

    // Handle remainder
    for (&n, &w) in remainder_n.iter().zip(remainder_w.iter()) {
        sum += n * w;
    }

    sum
}
|
||||
|
||||
/// Batch lane assignment using branchless SIMD comparison.
///
/// Assigns each energy value to a compute lane based on threshold comparison.
/// Uses branchless operations for consistent performance regardless of data.
///
/// # Arguments
///
/// * `energies` - Array of energy values to route
/// * `thresholds` - `[reflex, retrieval, heavy, human]` thresholds
///   (the fourth entry is not consulted: anything at or above `heavy`
///   routes to the Human lane)
/// * `lanes` - Output buffer for lane assignments (as `u8`)
///
/// # Lane Assignment Logic
///
/// - `energy < reflex` -> Lane 0 (Reflex)
/// - `reflex <= energy < retrieval` -> Lane 1 (Retrieval)
/// - `retrieval <= energy < heavy` -> Lane 2 (Heavy)
/// - `energy >= heavy` -> Lane 3 (Human)
///
/// The lane index is simply the count of thresholds the energy meets,
/// which is what both the SIMD and scalar paths compute.
///
/// # Example
///
/// ```rust,ignore
/// use prime_radiant::simd::energy::batch_lane_assignment_simd;
///
/// let energies = [0.1, 0.25, 0.6, 0.9];
/// let thresholds = [0.2, 0.5, 0.8, 1.0];
/// let mut lanes = [0u8; 4];
///
/// batch_lane_assignment_simd(&energies, thresholds, &mut lanes);
/// // lanes = [0, 1, 2, 3] (Reflex, Retrieval, Heavy, Human)
/// ```
#[inline]
pub fn batch_lane_assignment_simd(energies: &[f32], thresholds: [f32; 4], lanes: &mut [u8]) {
    debug_assert_eq!(energies.len(), lanes.len());

    let len = energies.len();

    // Thresholds for lane boundaries
    let t_reflex = thresholds[0];
    let t_retrieval = thresholds[1];
    let t_heavy = thresholds[2];

    // Small batches: the scalar path is cheaper than SIMD setup.
    if len < 16 {
        batch_lane_assignment_scalar(energies, thresholds, lanes);
        return;
    }

    // SIMD thresholds, broadcast to all 8 lanes.
    let vt_reflex = f32x8::splat(t_reflex);
    let vt_retrieval = f32x8::splat(t_retrieval);
    let vt_heavy = f32x8::splat(t_heavy);

    let chunks_e = energies.chunks_exact(8);
    let chunks_l = lanes.chunks_exact_mut(8);

    // Tail elements (len % 8) are classified in scalar below.
    let remainder_e = chunks_e.remainder();
    let offset = len - remainder_e.len();

    let v_one = f32x8::splat(1.0);
    let v_zero = f32x8::ZERO;

    for (ce, cl) in chunks_e.zip(chunks_l) {
        let ve = load_f32x8(ce);

        // Branchless comparison using SIMD masks
        let mask_reflex = ve.cmp_ge(vt_reflex);
        let mask_retrieval = ve.cmp_ge(vt_retrieval);
        let mask_heavy = ve.cmp_ge(vt_heavy);

        // Convert masks to 1.0/0.0 using blend, then sum: the sum counts
        // how many thresholds each element met, i.e. its lane index (0..=3).
        let add_reflex = mask_reflex.blend(v_one, v_zero);
        let add_retrieval = mask_retrieval.blend(v_one, v_zero);
        let add_heavy = mask_heavy.blend(v_one, v_zero);

        let lane_floats = add_reflex + add_retrieval + add_heavy;
        let lane_arr: [f32; 8] = lane_floats.into();

        // Convert to u8 (branchless); `.min(3)` is a defensive clamp —
        // the sum of three 0/1 indicators cannot exceed 3.
        for i in 0..8 {
            cl[i] = (lane_arr[i] as u8).min(3);
        }
    }

    // Handle remainder: same threshold-count trick, in scalar.
    for (i, &e) in remainder_e.iter().enumerate() {
        let lane = (e >= t_reflex) as u8 + (e >= t_retrieval) as u8 + (e >= t_heavy) as u8;
        lanes[offset + i] = lane.min(3);
    }
}
|
||||
|
||||
/// Convert lane assignments to ComputeLane enum values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lane_bytes` - Raw lane assignments (0-3)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vector of `ComputeLane` values
|
||||
pub fn lanes_to_enum(lane_bytes: &[u8]) -> Vec<ComputeLane> {
|
||||
lane_bytes
|
||||
.iter()
|
||||
.map(|&b| ComputeLane::from_u8(b).unwrap_or(ComputeLane::Human))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute total energy for a graph with batched operations.
|
||||
///
|
||||
/// This is the main entry point for efficient energy computation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `sources` - Flattened source states
|
||||
/// * `targets` - Flattened target states
|
||||
/// * `weights` - Edge weights
|
||||
/// * `dim` - State vector dimension
|
||||
/// * `count` - Number of edges
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Total coherence energy: `E(S) = sum(w_e * ||r_e||^2)`
|
||||
#[inline]
|
||||
pub fn compute_total_energy_simd(
|
||||
sources: &[f32],
|
||||
targets: &[f32],
|
||||
weights: &[f32],
|
||||
dim: usize,
|
||||
count: usize,
|
||||
) -> f32 {
|
||||
debug_assert_eq!(sources.len(), dim * count);
|
||||
debug_assert_eq!(targets.len(), dim * count);
|
||||
debug_assert_eq!(weights.len(), count);
|
||||
|
||||
// Compute residual norms
|
||||
let mut norms = vec![0.0f32; count];
|
||||
batch_residual_norms_simd(sources, targets, &mut norms, dim, count);
|
||||
|
||||
// Compute weighted sum
|
||||
weighted_energy_sum_simd(&norms, weights)
|
||||
}
|
||||
|
||||
/// Compute per-edge energies for a graph.
///
/// Writes `energies[i] = weights[i] * ||source_i - target_i||^2` for each edge.
///
/// # Arguments
///
/// * `sources` - Flattened source states
/// * `targets` - Flattened target states
/// * `weights` - Edge weights
/// * `energies` - Output buffer for per-edge energies
/// * `dim` - State vector dimension
/// * `count` - Number of edges
#[inline]
pub fn compute_edge_energies_simd(
    sources: &[f32],
    targets: &[f32],
    weights: &[f32],
    energies: &mut [f32],
    dim: usize,
    count: usize,
) {
    debug_assert_eq!(sources.len(), dim * count);
    debug_assert_eq!(targets.len(), dim * count);
    debug_assert_eq!(weights.len(), count);
    debug_assert_eq!(energies.len(), count);

    // Compute residual norms directly into the output buffer; the weight
    // multiply below then happens in-place with no extra allocation.
    batch_residual_norms_simd(sources, targets, energies, dim, count);

    // Multiply by weights in-place
    if count < 16 {
        for i in 0..count {
            energies[i] *= weights[i];
        }
        return;
    }

    let chunks_e = energies.chunks_exact_mut(8);
    let chunks_w = weights.chunks_exact(8);

    // Tail elements (count % 8) are multiplied in scalar after the SIMD loop.
    let remainder_w = chunks_w.remainder();
    let offset = count - remainder_w.len();

    for (ce, cw) in chunks_e.zip(chunks_w) {
        let ve = load_f32x8(ce);
        let vw = load_f32x8(cw);
        let result = ve * vw;
        store_f32x8(ce, result);
    }

    // Handle the remainder in scalar.
    for (i, &w) in remainder_w.iter().enumerate() {
        energies[offset + i] *= w;
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Scalar Fallback Implementations
|
||||
// ============================================================================
|
||||
|
||||
#[inline(always)]
fn batch_residuals_scalar(sources: &[f32], targets: &[f32], residuals: &mut [f32]) {
    // Plain element-wise subtraction: residual = source - target.
    for (out, (s, t)) in residuals.iter_mut().zip(sources.iter().zip(targets.iter())) {
        *out = *s - *t;
    }
}
|
||||
|
||||
#[inline(always)]
fn compute_residual_norm_sq_scalar(source: &[f32], target: &[f32]) -> f32 {
    // Sum of squared element-wise differences, accumulated left to right.
    source
        .iter()
        .zip(target.iter())
        .map(|(&s, &t)| {
            let d = s - t;
            d * d
        })
        .sum()
}
|
||||
|
||||
#[inline(always)]
fn weighted_energy_sum_scalar(norms: &[f32], weights: &[f32]) -> f32 {
    // Weighted dot product, accumulated left to right.
    norms.iter().zip(weights.iter()).map(|(&n, &w)| n * w).sum()
}
|
||||
|
||||
#[inline(always)]
fn batch_lane_assignment_scalar(energies: &[f32], thresholds: [f32; 4], lanes: &mut [u8]) {
    // The lane index is the count of the first three thresholds the energy
    // meets (clamped to 3 defensively; the count cannot actually exceed 3).
    let [t_reflex, t_retrieval, t_heavy, _] = thresholds;

    for (slot, &e) in lanes.iter_mut().zip(energies.iter()) {
        let mut lane = 0u8;
        if e >= t_reflex {
            lane += 1;
        }
        if e >= t_retrieval {
            lane += 1;
        }
        if e >= t_heavy {
            lane += 1;
        }
        *slot = lane.min(3);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
#[inline(always)]
|
||||
fn load_f32x8(slice: &[f32]) -> f32x8 {
|
||||
debug_assert!(slice.len() >= 8);
|
||||
// Use try_into for direct memory copy instead of element-by-element
|
||||
let arr: [f32; 8] = slice[..8].try_into().unwrap();
|
||||
f32x8::from(arr)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn store_f32x8(slice: &mut [f32], v: f32x8) {
|
||||
debug_assert!(slice.len() >= 8);
|
||||
let arr: [f32; 8] = v.into();
|
||||
slice[..8].copy_from_slice(&arr);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Tolerance used by `approx_eq`; relative for large values, absolute otherwise.
    const EPSILON: f32 = 1e-4;

    // Approximate float comparison: relative error when either operand
    // exceeds 1.0 in magnitude, absolute error otherwise.
    fn approx_eq(a: f32, b: f32) -> bool {
        // Use relative error for larger values
        let max_abs = a.abs().max(b.abs());
        if max_abs > 1.0 {
            (a - b).abs() / max_abs < EPSILON
        } else {
            (a - b).abs() < EPSILON
        }
    }

    // Small input exercises the scalar fallback path (total < 32).
    #[test]
    fn test_batch_residuals_small() {
        let sources = [1.0, 2.0, 3.0, 4.0];
        let targets = [0.5, 1.5, 2.5, 3.5];
        let mut residuals = [0.0f32; 4];

        batch_residuals_simd(&sources, &targets, &mut residuals, 2, 2);

        let expected = [0.5, 0.5, 0.5, 0.5];
        for (i, (&r, &e)) in residuals.iter().zip(expected.iter()).enumerate() {
            assert!(approx_eq(r, e), "at {} got {} expected {}", i, r, e);
        }
    }

    // Large input exercises the SIMD path; checked against the scalar reference.
    #[test]
    fn test_batch_residuals_large() {
        let n = 1024;
        let sources: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..n).map(|i| i as f32 * 0.5).collect();
        let mut residuals_simd = vec![0.0f32; n];
        let mut residuals_scalar = vec![0.0f32; n];

        batch_residuals_simd(&sources, &targets, &mut residuals_simd, 64, 16);
        batch_residuals_scalar(&sources, &targets, &mut residuals_scalar);

        for (i, (&s, &sc)) in residuals_simd
            .iter()
            .zip(residuals_scalar.iter())
            .enumerate()
        {
            assert!(approx_eq(s, sc), "at {} got {} expected {}", i, s, sc);
        }
    }

    #[test]
    fn test_batch_residual_norms() {
        // 2 edges, dim=2
        let sources = [1.0, 0.0, 0.0, 1.0];
        let targets = [0.0, 0.0, 1.0, 0.0];
        let mut norms = [0.0f32; 2];

        batch_residual_norms_simd(&sources, &targets, &mut norms, 2, 2);

        // Edge 0: ||(1,0) - (0,0)||^2 = 1
        // Edge 1: ||(0,1) - (1,0)||^2 = 1 + 1 = 2
        assert!(approx_eq(norms[0], 1.0), "got {}", norms[0]);
        assert!(approx_eq(norms[1], 2.0), "got {}", norms[1]);
    }

    #[test]
    fn test_weighted_energy_sum() {
        let norms = [1.0, 4.0, 9.0, 16.0];
        let weights = [1.0, 0.5, 0.25, 0.125];

        let result = weighted_energy_sum_simd(&norms, &weights);
        // 1*1 + 0.5*4 + 0.25*9 + 0.125*16 = 1 + 2 + 2.25 + 2 = 7.25
        assert!(approx_eq(result, 7.25), "got {}", result);
    }

    // SIMD vs scalar consistency on a large weighted sum.
    #[test]
    fn test_weighted_energy_sum_large() {
        let n = 1024;
        let norms: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let weights: Vec<f32> = (0..n).map(|_| 0.5).collect();

        let result = weighted_energy_sum_simd(&norms, &weights);
        let expected = weighted_energy_sum_scalar(&norms, &weights);
        assert!(
            approx_eq(result, expected),
            "got {} expected {}",
            result,
            expected
        );
    }

    // Small input exercises the scalar routing path (len < 16).
    #[test]
    fn test_batch_lane_assignment() {
        let energies = [0.1, 0.25, 0.6, 0.9];
        let thresholds = [0.2, 0.5, 0.8, 1.0];
        let mut lanes = [0u8; 4];

        batch_lane_assignment_simd(&energies, thresholds, &mut lanes);

        // 0.1 < 0.2 -> Lane 0
        // 0.2 <= 0.25 < 0.5 -> Lane 1
        // 0.5 <= 0.6 < 0.8 -> Lane 2
        // 0.8 <= 0.9 < 1.0 -> Lane 3
        assert_eq!(lanes, [0, 1, 2, 3]);
    }

    // SIMD vs scalar consistency across the whole 0..1 energy range.
    #[test]
    fn test_batch_lane_assignment_large() {
        let n = 1024;
        let energies: Vec<f32> = (0..n).map(|i| (i as f32) / (n as f32)).collect();
        let thresholds = [0.2, 0.5, 0.8, 1.0];
        let mut lanes_simd = vec![0u8; n];
        let mut lanes_scalar = vec![0u8; n];

        batch_lane_assignment_simd(&energies, thresholds, &mut lanes_simd);
        batch_lane_assignment_scalar(&energies, thresholds, &mut lanes_scalar);

        assert_eq!(lanes_simd, lanes_scalar);
    }

    #[test]
    fn test_compute_total_energy() {
        // 2 edges, dim=2
        let sources = [1.0, 0.0, 0.0, 1.0];
        let targets = [0.0, 0.0, 1.0, 0.0];
        let weights = [1.0, 2.0];

        let energy = compute_total_energy_simd(&sources, &targets, &weights, 2, 2);

        // Edge 0: w=1, ||r||^2 = 1 -> energy = 1
        // Edge 1: w=2, ||r||^2 = 2 -> energy = 4
        // Total = 5
        assert!(approx_eq(energy, 5.0), "got {}", energy);
    }

    #[test]
    fn test_compute_edge_energies() {
        let sources = [1.0, 0.0, 0.0, 1.0];
        let targets = [0.0, 0.0, 1.0, 0.0];
        let weights = [1.0, 2.0];
        let mut energies = [0.0f32; 2];

        compute_edge_energies_simd(&sources, &targets, &weights, &mut energies, 2, 2);

        assert!(approx_eq(energies[0], 1.0), "got {}", energies[0]);
        assert!(approx_eq(energies[1], 4.0), "got {}", energies[1]);
    }

    #[test]
    fn test_lanes_to_enum() {
        let bytes = [0u8, 1, 2, 3, 0];
        let lanes = lanes_to_enum(&bytes);

        assert_eq!(lanes[0], ComputeLane::Reflex);
        assert_eq!(lanes[1], ComputeLane::Retrieval);
        assert_eq!(lanes[2], ComputeLane::Heavy);
        assert_eq!(lanes[3], ComputeLane::Human);
        assert_eq!(lanes[4], ComputeLane::Reflex);
    }

    #[test]
    fn test_residual_norm_consistency() {
        // Verify SIMD and scalar produce same results
        let n = 128;
        let source: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let target: Vec<f32> = (0..n).map(|i| (i as f32) * 0.2).collect();

        let simd_result = compute_residual_norm_sq_simd(&source, &target);
        let scalar_result = compute_residual_norm_sq_scalar(&source, &target);

        assert!(
            approx_eq(simd_result, scalar_result),
            "simd={} scalar={}",
            simd_result,
            scalar_result
        );
    }
}
|
||||
580
vendor/ruvector/crates/prime-radiant/src/simd/matrix.rs
vendored
Normal file
580
vendor/ruvector/crates/prime-radiant/src/simd/matrix.rs
vendored
Normal file
@@ -0,0 +1,580 @@
|
||||
//! # SIMD Matrix Operations
|
||||
//!
|
||||
//! High-performance matrix operations using SIMD intrinsics.
|
||||
//! Optimized for small to medium matrices common in coherence computation.
|
||||
//!
|
||||
//! ## Matrix Layout
|
||||
//!
|
||||
//! All matrices are stored in **row-major** order:
|
||||
//! - `A[i][j]` is at index `i * cols + j`
|
||||
//! - This matches Rust's natural 2D array layout
|
||||
//!
|
||||
//! ## Supported Operations
|
||||
//!
|
||||
//! | Operation | Description | Complexity |
|
||||
//! |-----------|-------------|------------|
|
||||
//! | `matmul_simd` | Matrix-matrix multiplication | O(m*k*n) |
|
||||
//! | `matvec_simd` | Matrix-vector multiplication | O(m*n) |
|
||||
//! | `transpose_simd` | Matrix transpose | O(m*n) |
|
||||
//!
|
||||
//! ## Performance Notes
|
||||
//!
|
||||
//! - Uses blocking/tiling for cache-friendly access patterns
|
||||
//! - Prefetches data for next iteration where beneficial
|
||||
//! - Falls back to highly optimized scalar code for small matrices
|
||||
|
||||
use wide::f32x8;
|
||||
|
||||
/// Block size for tiled matrix operations (cache optimization).
|
||||
const BLOCK_SIZE: usize = 64;
|
||||
|
||||
/// Compute matrix-matrix multiplication: C = A * B
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - First matrix (m x k), row-major, length = m * k
|
||||
/// * `b` - Second matrix (k x n), row-major, length = k * n
|
||||
/// * `c` - Output matrix (m x n), row-major, length = m * n
|
||||
/// * `m` - Number of rows in A
|
||||
/// * `k` - Number of columns in A (= rows in B)
|
||||
/// * `n` - Number of columns in B
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug mode if buffer sizes don't match dimensions.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use prime_radiant::simd::matrix::matmul_simd;
|
||||
///
|
||||
/// // 2x3 * 3x2 = 2x2
|
||||
/// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
|
||||
/// let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
|
||||
/// let mut c = [0.0f32; 4]; // 2x2
|
||||
///
|
||||
/// matmul_simd(&a, &b, &mut c, 2, 3, 2);
|
||||
/// // c = [22, 28, 49, 64]
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn matmul_simd(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
|
||||
debug_assert_eq!(a.len(), m * k, "Matrix A size mismatch");
|
||||
debug_assert_eq!(b.len(), k * n, "Matrix B size mismatch");
|
||||
debug_assert_eq!(c.len(), m * n, "Matrix C size mismatch");
|
||||
|
||||
// Clear output
|
||||
c.fill(0.0);
|
||||
|
||||
// For small matrices, use simple implementation
|
||||
if m * n < 256 || k < 8 {
|
||||
matmul_scalar(a, b, c, m, k, n);
|
||||
return;
|
||||
}
|
||||
|
||||
// Blocked/tiled multiplication for cache efficiency
|
||||
matmul_blocked(a, b, c, m, k, n);
|
||||
}
|
||||
|
||||
/// Compute matrix-vector multiplication: y = A * x
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Matrix (m x n), row-major
|
||||
/// * `x` - Input vector (length n)
|
||||
/// * `y` - Output vector (length m)
|
||||
/// * `m` - Number of rows
|
||||
/// * `n` - Number of columns
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug mode if buffer sizes don't match dimensions.
|
||||
#[inline]
|
||||
pub fn matvec_simd(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
|
||||
debug_assert_eq!(a.len(), m * n, "Matrix A size mismatch");
|
||||
debug_assert_eq!(x.len(), n, "Vector x size mismatch");
|
||||
debug_assert_eq!(y.len(), m, "Vector y size mismatch");
|
||||
|
||||
// For small matrices, use scalar implementation
|
||||
if n < 16 {
|
||||
matvec_scalar(a, x, y, m, n);
|
||||
return;
|
||||
}
|
||||
|
||||
// Process each row
|
||||
for i in 0..m {
|
||||
let row_start = i * n;
|
||||
let row = &a[row_start..row_start + n];
|
||||
y[i] = dot_product_simd(row, x);
|
||||
}
|
||||
}
|
||||
|
||||
/// Transpose a matrix: B = A^T
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Input matrix (m x n), row-major
|
||||
/// * `b` - Output matrix (n x m), row-major
|
||||
/// * `m` - Number of rows in A
|
||||
/// * `n` - Number of columns in A
|
||||
#[inline]
|
||||
pub fn transpose_simd(a: &[f32], b: &mut [f32], m: usize, n: usize) {
|
||||
debug_assert_eq!(a.len(), m * n);
|
||||
debug_assert_eq!(b.len(), m * n);
|
||||
|
||||
// For small matrices, use scalar transpose
|
||||
if m < 8 || n < 8 {
|
||||
transpose_scalar(a, b, m, n);
|
||||
return;
|
||||
}
|
||||
|
||||
// Block-based transpose for cache efficiency
|
||||
let block = 8;
|
||||
|
||||
for ii in (0..m).step_by(block) {
|
||||
for jj in (0..n).step_by(block) {
|
||||
// Process block
|
||||
let i_end = (ii + block).min(m);
|
||||
let j_end = (jj + block).min(n);
|
||||
|
||||
for i in ii..i_end {
|
||||
for j in jj..j_end {
|
||||
b[j * m + i] = a[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute outer product: C = a * b^T
///
/// # Arguments
///
/// * `a` - Column vector (length m)
/// * `b` - Row vector (length n)
/// * `c` - Output matrix (m x n), row-major
#[inline]
pub fn outer_product_simd(a: &[f32], b: &[f32], c: &mut [f32]) {
    let m = a.len();
    let n = b.len();
    debug_assert_eq!(c.len(), m * n);

    // Narrow rows: scalar fallback avoids SIMD setup cost.
    if n < 16 {
        // Scalar fallback
        for i in 0..m {
            for j in 0..n {
                c[i * n + j] = a[i] * b[j];
            }
        }
        return;
    }

    // SIMD version: each row of C is a[i] * b
    for i in 0..m {
        let scalar = a[i];
        // Broadcast a[i] to all 8 lanes once per row.
        let scalar_vec = f32x8::splat(scalar);
        let row_start = i * n;

        let chunks_b = b.chunks_exact(8);
        let chunks_c = c[row_start..row_start + n].chunks_exact_mut(8);
        // Tail columns (n % 8) are filled in scalar below; `offset` is the
        // column index where that tail begins.
        let remainder_b = chunks_b.remainder();
        let offset = n - remainder_b.len();

        for (cb, cc) in chunks_b.zip(chunks_c) {
            let vb = load_f32x8(cb);
            let result = vb * scalar_vec;
            store_f32x8(cc, result);
        }

        // Handle remainder
        for (j, &bj) in remainder_b.iter().enumerate() {
            c[row_start + offset + j] = scalar * bj;
        }
    }
}
|
||||
|
||||
/// Add two matrices element-wise: C = A + B
|
||||
#[inline]
|
||||
pub fn matadd_simd(a: &[f32], b: &[f32], c: &mut [f32]) {
|
||||
debug_assert_eq!(a.len(), b.len());
|
||||
debug_assert_eq!(a.len(), c.len());
|
||||
|
||||
let n = a.len();
|
||||
|
||||
if n < 16 {
|
||||
for i in 0..n {
|
||||
c[i] = a[i] + b[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let chunks_a = a.chunks_exact(8);
|
||||
let chunks_b = b.chunks_exact(8);
|
||||
let chunks_c = c.chunks_exact_mut(8);
|
||||
|
||||
let remainder_a = chunks_a.remainder();
|
||||
let remainder_b = chunks_b.remainder();
|
||||
let offset = n - remainder_a.len();
|
||||
|
||||
for ((ca, cb), cc) in chunks_a.zip(chunks_b).zip(chunks_c) {
|
||||
let va = load_f32x8(ca);
|
||||
let vb = load_f32x8(cb);
|
||||
let result = va + vb;
|
||||
store_f32x8(cc, result);
|
||||
}
|
||||
|
||||
for (i, (&va, &vb)) in remainder_a.iter().zip(remainder_b.iter()).enumerate() {
|
||||
c[offset + i] = va + vb;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scale a matrix by a scalar: B = alpha * A
|
||||
#[inline]
|
||||
pub fn matscale_simd(a: &[f32], alpha: f32, b: &mut [f32]) {
|
||||
debug_assert_eq!(a.len(), b.len());
|
||||
|
||||
let n = a.len();
|
||||
|
||||
if n < 16 {
|
||||
for i in 0..n {
|
||||
b[i] = alpha * a[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let alpha_vec = f32x8::splat(alpha);
|
||||
|
||||
let chunks_a = a.chunks_exact(8);
|
||||
let chunks_b = b.chunks_exact_mut(8);
|
||||
|
||||
let remainder_a = chunks_a.remainder();
|
||||
let offset = n - remainder_a.len();
|
||||
|
||||
for (ca, cb) in chunks_a.zip(chunks_b) {
|
||||
let va = load_f32x8(ca);
|
||||
let result = va * alpha_vec;
|
||||
store_f32x8(cb, result);
|
||||
}
|
||||
|
||||
for (i, &va) in remainder_a.iter().enumerate() {
|
||||
b[offset + i] = alpha * va;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Internal Implementations
|
||||
// ============================================================================
|
||||
|
||||
/// Blocked matrix multiplication for cache efficiency.
///
/// Computes `C += A * B` where A is `m x k`, B is `k x n`, C is `m x n`,
/// all row-major. Note this ACCUMULATES into `c` (`+=` / fused multiply-add),
/// so the caller is expected to zero-initialize `c` first — TODO confirm
/// against `matmul_simd`, which is not visible here.
fn matmul_blocked(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    // Use smaller block size for k dimension to keep data in L1 cache
    let bk = BLOCK_SIZE.min(k);
    let bn = BLOCK_SIZE.min(n);

    // Loop order: k-blocks outermost, then n-blocks, then full rows of C.
    for kk in (0..k).step_by(bk) {
        let k_end = (kk + bk).min(k);

        for jj in (0..n).step_by(bn) {
            let j_end = (jj + bn).min(n);

            for i in 0..m {
                let c_row = i * n;
                let a_row = i * k;

                // Process this block of the output row
                for kc in kk..k_end {
                    // Broadcast a single A element across all 8 lanes and
                    // stream along the corresponding row of B.
                    let a_val = a[a_row + kc];
                    let a_vec = f32x8::splat(a_val);
                    let b_row = kc * n;

                    // SIMD inner loop
                    let mut j = jj;
                    while j + 8 <= j_end {
                        let b_chunk = &b[b_row + j..b_row + j + 8];
                        let c_chunk = &mut c[c_row + j..c_row + j + 8];

                        let vb = load_f32x8(b_chunk);
                        let vc = load_f32x8(c_chunk);
                        // c_chunk += a_val * b_chunk (fused multiply-add).
                        let result = a_vec.mul_add(vb, vc);
                        store_f32x8(c_chunk, result);

                        j += 8;
                    }

                    // Scalar cleanup
                    while j < j_end {
                        c[c_row + j] += a_val * b[b_row + j];
                        j += 1;
                    }
                }
            }
        }
    }
}
|
||||
|
||||
/// Simple scalar matrix multiplication for small matrices.
///
/// Computes `C = A * B` (row-major): A is `m x k`, B is `k x n`, C is `m x n`.
/// Overwrites `c`; no accumulation.
fn matmul_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    for (i, out_row) in c.chunks_exact_mut(n).take(m).enumerate() {
        let lhs_row = &a[i * k..i * k + k];
        for (j, out) in out_row.iter_mut().enumerate() {
            // Inner product of row i of A with column j of B,
            // accumulated in ascending-k order (left fold from 0.0).
            *out = lhs_row
                .iter()
                .enumerate()
                .map(|(kc, &av)| av * b[kc * n + j])
                .sum();
        }
    }
}
|
||||
|
||||
/// Scalar matrix-vector multiplication.
///
/// Computes `y = A * x` for a row-major `m x n` matrix `a`.
fn matvec_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
    for (i, out) in y.iter_mut().enumerate().take(m) {
        // Row i of A dotted with x, summed in ascending-column order.
        let row = &a[i * n..(i + 1) * n];
        *out = row.iter().zip(x[..n].iter()).map(|(&av, &xv)| av * xv).sum();
    }
}
|
||||
|
||||
/// Scalar matrix transpose.
///
/// Writes the transpose of the row-major `m x n` matrix `a` into `b`
/// (row-major `n x m`): `b[j][i] = a[i][j]`.
fn transpose_scalar(a: &[f32], b: &mut [f32], m: usize, n: usize) {
    for i in 0..m {
        let row = &a[i * n..(i + 1) * n];
        for (j, &val) in row.iter().enumerate() {
            b[j * m + i] = val;
        }
    }
}
|
||||
|
||||
/// SIMD dot product (copied from vectors module to avoid circular dep).
|
||||
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
|
||||
let n = a.len();
|
||||
|
||||
if n < 16 {
|
||||
let mut sum = 0.0f32;
|
||||
for i in 0..n {
|
||||
sum += a[i] * b[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
let chunks_a = a.chunks_exact(8);
|
||||
let chunks_b = b.chunks_exact(8);
|
||||
let remainder_a = chunks_a.remainder();
|
||||
let remainder_b = chunks_b.remainder();
|
||||
|
||||
let mut acc = f32x8::ZERO;
|
||||
|
||||
for (ca, cb) in chunks_a.zip(chunks_b) {
|
||||
let va = load_f32x8(ca);
|
||||
let vb = load_f32x8(cb);
|
||||
acc = va.mul_add(vb, acc);
|
||||
}
|
||||
|
||||
let mut sum = acc.reduce_add();
|
||||
|
||||
for (&va, &vb) in remainder_a.iter().zip(remainder_b.iter()) {
|
||||
sum += va * vb;
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
#[inline(always)]
|
||||
fn load_f32x8(slice: &[f32]) -> f32x8 {
|
||||
debug_assert!(slice.len() >= 8);
|
||||
// Use try_into for direct memory copy instead of element-by-element
|
||||
let arr: [f32; 8] = slice[..8].try_into().unwrap();
|
||||
f32x8::from(arr)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn store_f32x8(slice: &mut [f32], v: f32x8) {
|
||||
debug_assert!(slice.len() >= 8);
|
||||
let arr: [f32; 8] = v.into();
|
||||
slice[..8].copy_from_slice(&arr);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Loose tolerance: f32 SIMD accumulation reorders sums vs. scalar code.
    const EPSILON: f32 = 1e-3;

    // Approximate equality: relative error for magnitudes > 1, absolute below.
    fn approx_eq(a: f32, b: f32) -> bool {
        // Use relative error for larger values
        let max_abs = a.abs().max(b.abs());
        if max_abs > 1.0 {
            (a - b).abs() / max_abs < EPSILON
        } else {
            (a - b).abs() < EPSILON
        }
    }

    // Element-wise approximate equality of two flattened matrices.
    fn matrices_approx_eq(a: &[f32], b: &[f32]) -> bool {
        a.len() == b.len() && a.iter().zip(b.iter()).all(|(&x, &y)| approx_eq(x, y))
    }

    // Hand-computed 2x3 * 3x2 product.
    #[test]
    fn test_matmul_small() {
        // 2x3 * 3x2 = 2x2
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
        let mut c = [0.0f32; 4]; // 2x2

        matmul_simd(&a, &b, &mut c, 2, 3, 2);

        // Row 0: [1,2,3] * [1,3,5; 2,4,6] = [1*1+2*3+3*5, 1*2+2*4+3*6] = [22, 28]
        // Row 1: [4,5,6] * [1,3,5; 2,4,6] = [4*1+5*3+6*5, 4*2+5*4+6*6] = [49, 64]
        let expected = [22.0, 28.0, 49.0, 64.0];
        assert!(matrices_approx_eq(&c, &expected), "got {:?}", c);
    }

    // Identity matrix leaves A unchanged; size 64 exercises the SIMD path.
    #[test]
    fn test_matmul_identity() {
        // I * A = A
        let n = 64;
        let mut identity = vec![0.0f32; n * n];
        for i in 0..n {
            identity[i * n + i] = 1.0;
        }

        let a: Vec<f32> = (0..n * n).map(|i| i as f32).collect();
        let mut c = vec![0.0f32; n * n];

        matmul_simd(&identity, &a, &mut c, n, n, n);

        assert!(matrices_approx_eq(&c, &a));
    }

    // Hand-computed 2x3 matrix times 3-vector.
    #[test]
    fn test_matvec_small() {
        // 2x3 matrix * 3-vector
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let x = [1.0, 2.0, 3.0]; // 3
        let mut y = [0.0f32; 2]; // 2

        matvec_simd(&a, &x, &mut y, 2, 3);

        // y[0] = 1*1 + 2*2 + 3*3 = 14
        // y[1] = 4*1 + 5*2 + 6*3 = 32
        let expected = [14.0, 32.0];
        assert!(matrices_approx_eq(&y, &expected), "got {:?}", y);
    }

    // SIMD matvec must agree with the scalar reference on a large input.
    #[test]
    fn test_matvec_large() {
        let m = 64;
        let n = 128;

        let a: Vec<f32> = (0..m * n).map(|i| (i as f32) * 0.01).collect();
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let mut y_simd = vec![0.0f32; m];
        let mut y_scalar = vec![0.0f32; m];

        matvec_simd(&a, &x, &mut y_simd, m, n);
        matvec_scalar(&a, &x, &mut y_scalar, m, n);

        assert!(matrices_approx_eq(&y_simd, &y_scalar));
    }

    // Hand-computed 2x3 -> 3x2 transpose; exact equality (no arithmetic).
    #[test]
    fn test_transpose_small() {
        // 2x3 -> 3x2
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let mut b = [0.0f32; 6]; // 3x2

        transpose_simd(&a, &mut b, 2, 3);

        // Transposed: [[1,4], [2,5], [3,6]]
        let expected = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        assert_eq!(b, expected);
    }

    // Spot-check the transpose property b[j][i] == a[i][j] on a 32x64 matrix.
    #[test]
    fn test_transpose_large() {
        let m = 32;
        let n = 64;

        let a: Vec<f32> = (0..m * n).map(|i| i as f32).collect();
        let mut b = vec![0.0f32; m * n];

        transpose_simd(&a, &mut b, m, n);

        // Verify transpose property
        for i in 0..m {
            for j in 0..n {
                assert!(
                    approx_eq(a[i * n + j], b[j * m + i]),
                    "mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    // Outer product of a 3-vector and a 2-vector: c[i][j] = a[i]*b[j].
    #[test]
    fn test_outer_product() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0];
        let mut c = [0.0f32; 6];

        outer_product_simd(&a, &b, &mut c);

        // c[i,j] = a[i] * b[j]
        let expected = [4.0, 5.0, 8.0, 10.0, 12.0, 15.0];
        assert!(matrices_approx_eq(&c, &expected));
    }

    // Element-wise addition; size 4 takes the scalar fast path.
    #[test]
    fn test_matadd() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let mut c = [0.0f32; 4];

        matadd_simd(&a, &b, &mut c);

        assert_eq!(c, [6.0, 8.0, 10.0, 12.0]);
    }

    // Scalar multiply; size 4 takes the scalar fast path.
    #[test]
    fn test_matscale() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let mut b = [0.0f32; 4];

        matscale_simd(&a, 2.5, &mut b);

        assert!(matrices_approx_eq(&b, &[2.5, 5.0, 7.5, 10.0]));
    }

    // Cross-check SIMD vs. scalar matmul on sizes that hit the blocked path.
    #[test]
    fn test_matmul_large() {
        // Test with sizes that exercise the blocked algorithm
        let m = 128;
        let k = 96;
        let n = 64;

        let a: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i as f32) * 0.001).collect();
        let mut c_simd = vec![0.0f32; m * n];
        let mut c_scalar = vec![0.0f32; m * n];

        matmul_simd(&a, &b, &mut c_simd, m, k, n);
        matmul_scalar(&a, &b, &mut c_scalar, m, k, n);

        // Allow slightly more tolerance for larger matrices due to accumulation
        for i in 0..m * n {
            assert!(
                (c_simd[i] - c_scalar[i]).abs() < 0.01,
                "mismatch at {}: {} vs {}",
                i,
                c_simd[i],
                c_scalar[i]
            );
        }
    }
}
|
||||
332
vendor/ruvector/crates/prime-radiant/src/simd/mod.rs
vendored
Normal file
332
vendor/ruvector/crates/prime-radiant/src/simd/mod.rs
vendored
Normal file
@@ -0,0 +1,332 @@
|
||||
//! # SIMD Optimizations for Prime-Radiant
|
||||
//!
|
||||
//! This module provides explicit SIMD (Single Instruction, Multiple Data) intrinsics
|
||||
//! for high-performance coherence computation. The implementation supports multiple
|
||||
//! SIMD widths with automatic runtime detection.
|
||||
//!
|
||||
//! ## Architecture Support
|
||||
//!
|
||||
//! | Architecture | SIMD Extension | Width | Features |
|
||||
//! |--------------|----------------|-------|----------|
|
||||
//! | x86_64 | SSE4.2 | 128-bit | Baseline vector support |
|
||||
//! | x86_64 | AVX2 | 256-bit | 8x f32 parallel ops |
|
||||
//! | x86_64 | AVX-512 | 512-bit | 16x f32 parallel ops |
|
||||
//! | aarch64 | NEON | 128-bit | ARM vector support |
|
||||
//!
|
||||
//! ## Implementation Strategy
|
||||
//!
|
||||
//! 1. **Primary**: `std::simd` with `portable_simd` feature (nightly)
|
||||
//! 2. **Fallback**: `wide` crate for stable Rust compatibility
|
||||
//! 3. **Scalar**: Auto-vectorizable fallback for unsupported platforms
|
||||
//!
|
||||
//! ## Performance Targets
|
||||
//!
|
||||
//! | Operation | Scalar | SIMD (AVX2) | Speedup |
|
||||
//! |-----------|--------|-------------|---------|
|
||||
//! | `dot_product` (1024-dim) | 1.2us | 0.15us | ~8x |
|
||||
//! | `norm_squared` (1024-dim) | 0.8us | 0.10us | ~8x |
|
||||
//! | `batch_residuals` (256 edges) | 50us | 6.5us | ~7.7x |
|
||||
//! | `batch_lane_assignment` (1024) | 4us | 0.5us | ~8x |
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use prime_radiant::simd::{SimdWidth, best_simd_width, vectors, energy};
|
||||
//!
|
||||
//! // Auto-detect best SIMD width at runtime
|
||||
//! let width = best_simd_width();
|
||||
//! println!("Using {:?}", width);
|
||||
//!
|
||||
//! // SIMD dot product
|
||||
//! let a = [1.0f32; 256];
|
||||
//! let b = [2.0f32; 256];
|
||||
//! let result = vectors::dot_product_simd(&a, &b);
|
||||
//! ```
|
||||
|
||||
pub mod energy;
|
||||
pub mod matrix;
|
||||
pub mod vectors;
|
||||
|
||||
// Re-export key types
|
||||
pub use energy::{
|
||||
batch_lane_assignment_simd, batch_residual_norms_simd, batch_residuals_simd,
|
||||
weighted_energy_sum_simd,
|
||||
};
|
||||
pub use matrix::{matmul_simd, matvec_simd};
|
||||
pub use vectors::{dot_product_simd, norm_squared_simd, scale_simd, subtract_simd};
|
||||
|
||||
/// Available SIMD instruction set widths.
///
/// The actual width available depends on the CPU and detected features.
///
/// Note: the discriminant ordering (Scalar < Sse42 < Avx2 < Avx512) follows
/// increasing x86 capability; `Neon = 4` sits outside that ordering since it
/// is an ARM-only width.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SimdWidth {
    /// No SIMD available, use scalar operations
    Scalar = 0,
    /// SSE4.2: 128-bit (4x f32)
    Sse42 = 1,
    /// AVX2: 256-bit (8x f32)
    Avx2 = 2,
    /// AVX-512: 512-bit (16x f32)
    Avx512 = 3,
    /// ARM NEON: 128-bit (4x f32)
    Neon = 4,
}
|
||||
|
||||
impl SimdWidth {
    /// Number of f32 values that can be processed in parallel.
    ///
    /// Register width in bits divided by 32 (e.g. AVX2: 256 / 32 = 8).
    #[inline]
    pub const fn lanes_f32(self) -> usize {
        match self {
            SimdWidth::Scalar => 1,
            SimdWidth::Sse42 | SimdWidth::Neon => 4,
            SimdWidth::Avx2 => 8,
            SimdWidth::Avx512 => 16,
        }
    }

    /// Number of f64 values that can be processed in parallel.
    ///
    /// Half the f32 lane count for every real SIMD width; Scalar stays 1.
    #[inline]
    pub const fn lanes_f64(self) -> usize {
        match self {
            SimdWidth::Scalar => 1,
            SimdWidth::Sse42 | SimdWidth::Neon => 2,
            SimdWidth::Avx2 => 4,
            SimdWidth::Avx512 => 8,
        }
    }

    /// Whether this SIMD width is supported on the current CPU.
    ///
    /// Combines a compile-time architecture gate (`cfg!`) with the
    /// runtime feature probes; the `cfg!` guard lets the probe calls
    /// short-circuit on the wrong architecture.
    #[inline]
    pub fn is_supported(self) -> bool {
        match self {
            SimdWidth::Scalar => true,
            SimdWidth::Sse42 => cfg!(target_arch = "x86_64") && is_sse42_supported(),
            SimdWidth::Avx2 => cfg!(target_arch = "x86_64") && is_avx2_supported(),
            SimdWidth::Avx512 => cfg!(target_arch = "x86_64") && is_avx512_supported(),
            SimdWidth::Neon => cfg!(target_arch = "aarch64") && is_neon_supported(),
        }
    }

    /// Get a human-readable name for this SIMD width.
    pub const fn name(self) -> &'static str {
        match self {
            SimdWidth::Scalar => "Scalar",
            SimdWidth::Sse42 => "SSE4.2",
            SimdWidth::Avx2 => "AVX2",
            SimdWidth::Avx512 => "AVX-512",
            SimdWidth::Neon => "NEON",
        }
    }
}
|
||||
|
||||
impl Default for SimdWidth {
    // Default is the best width detected on the running CPU, not Scalar.
    fn default() -> Self {
        best_simd_width()
    }
}
|
||||
|
||||
/// Detect the best available SIMD width for the current CPU.
///
/// This function performs runtime CPU feature detection and returns
/// the highest-performance SIMD instruction set available.
///
/// Probes are ordered widest-first (AVX-512 > AVX2 > SSE4.2), so the
/// first hit wins; everything else falls through to `Scalar`.
///
/// # Example
///
/// ```rust,ignore
/// use prime_radiant::simd::best_simd_width;
///
/// let width = best_simd_width();
/// match width {
///     SimdWidth::Avx512 => println!("AVX-512 available!"),
///     SimdWidth::Avx2 => println!("AVX2 available"),
///     _ => println!("Using {:?}", width),
/// }
/// ```
#[inline]
pub fn best_simd_width() -> SimdWidth {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_supported() {
            return SimdWidth::Avx512;
        }
        if is_avx2_supported() {
            return SimdWidth::Avx2;
        }
        if is_sse42_supported() {
            return SimdWidth::Sse42;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if is_neon_supported() {
            return SimdWidth::Neon;
        }
    }

    SimdWidth::Scalar
}
|
||||
|
||||
/// Check if SSE4.2 is supported (x86_64).
#[cfg(target_arch = "x86_64")]
#[inline]
fn is_sse42_supported() -> bool {
    // A compile-time target-feature flag short-circuits the runtime probe.
    cfg!(target_feature = "sse4.2") || std::arch::is_x86_feature_detected!("sse4.2")
}

#[cfg(not(target_arch = "x86_64"))]
#[inline]
fn is_sse42_supported() -> bool {
    false
}
|
||||
|
||||
/// Check if AVX2 is supported (x86_64).
#[cfg(target_arch = "x86_64")]
#[inline]
fn is_avx2_supported() -> bool {
    // A compile-time target-feature flag short-circuits the runtime probe.
    cfg!(target_feature = "avx2") || std::arch::is_x86_feature_detected!("avx2")
}

#[cfg(not(target_arch = "x86_64"))]
#[inline]
fn is_avx2_supported() -> bool {
    false
}
|
||||
|
||||
/// Check if AVX-512 is supported (x86_64).
#[cfg(target_arch = "x86_64")]
#[inline]
fn is_avx512_supported() -> bool {
    // Only the foundation subset (avx512f) is probed, matching the original.
    cfg!(target_feature = "avx512f") || std::arch::is_x86_feature_detected!("avx512f")
}

#[cfg(not(target_arch = "x86_64"))]
#[inline]
fn is_avx512_supported() -> bool {
    false
}
|
||||
|
||||
/// Check if NEON is supported (aarch64).
#[cfg(target_arch = "aarch64")]
#[inline]
fn is_neon_supported() -> bool {
    // NEON is mandatory on aarch64
    true
}

// Non-aarch64 fallback: NEON is an ARM-only extension.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn is_neon_supported() -> bool {
    false
}
|
||||
|
||||
/// SIMD runtime context for operation dispatch.
///
/// Caches the detected SIMD width to avoid repeated feature detection.
/// The lane counts are cached copies of `width.lanes_f32()` /
/// `width.lanes_f64()`; they are plain fields, so keep them in sync if
/// `width` is ever mutated directly.
#[derive(Debug, Clone)]
pub struct SimdContext {
    /// The detected SIMD width for this CPU.
    pub width: SimdWidth,
    /// Number of f32 lanes available.
    pub f32_lanes: usize,
    /// Number of f64 lanes available.
    pub f64_lanes: usize,
}
|
||||
|
||||
impl SimdContext {
|
||||
/// Create a new SIMD context with auto-detection.
|
||||
pub fn new() -> Self {
|
||||
let width = best_simd_width();
|
||||
Self {
|
||||
width,
|
||||
f32_lanes: width.lanes_f32(),
|
||||
f64_lanes: width.lanes_f64(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a context with a specific SIMD width (for testing).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the requested width is not supported on this CPU.
|
||||
pub fn with_width(width: SimdWidth) -> Self {
|
||||
assert!(width.is_supported(), "SIMD width {:?} not supported", width);
|
||||
Self {
|
||||
width,
|
||||
f32_lanes: width.lanes_f32(),
|
||||
f64_lanes: width.lanes_f64(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the global SIMD context.
|
||||
///
|
||||
/// This is lazily initialized on first access.
|
||||
pub fn global() -> &'static SimdContext {
|
||||
use once_cell::sync::Lazy;
|
||||
static CONTEXT: Lazy<SimdContext> = Lazy::new(SimdContext::new);
|
||||
&CONTEXT
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SimdContext {
    // Default context is freshly auto-detected (same as `SimdContext::new`).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Detection must always return a width the CPU actually supports.
    #[test]
    fn test_simd_width_detection() {
        let width = best_simd_width();
        println!("Detected SIMD width: {:?}", width);
        assert!(width.is_supported());
    }

    // Lane counts are fixed per width: register bits / 32.
    #[test]
    fn test_simd_lanes() {
        assert_eq!(SimdWidth::Scalar.lanes_f32(), 1);
        assert_eq!(SimdWidth::Sse42.lanes_f32(), 4);
        assert_eq!(SimdWidth::Avx2.lanes_f32(), 8);
        assert_eq!(SimdWidth::Avx512.lanes_f32(), 16);
        assert_eq!(SimdWidth::Neon.lanes_f32(), 4);
    }

    // A fresh context caches lane counts consistent with its width.
    #[test]
    fn test_simd_context() {
        let ctx = SimdContext::new();
        assert!(ctx.width.is_supported());
        assert_eq!(ctx.f32_lanes, ctx.width.lanes_f32());
    }

    // Repeated global() calls must observe the same detected width.
    #[test]
    fn test_global_context() {
        let ctx1 = SimdContext::global();
        let ctx2 = SimdContext::global();
        assert_eq!(ctx1.width, ctx2.width);
    }
}
|
||||
681
vendor/ruvector/crates/prime-radiant/src/simd/vectors.rs
vendored
Normal file
681
vendor/ruvector/crates/prime-radiant/src/simd/vectors.rs
vendored
Normal file
@@ -0,0 +1,681 @@
|
||||
//! # SIMD Vector Operations
|
||||
//!
|
||||
//! High-performance vector operations using explicit SIMD intrinsics.
|
||||
//! All operations fall back to optimized scalar code when SIMD is unavailable.
|
||||
//!
|
||||
//! ## Supported Operations
|
||||
//!
|
||||
//! | Operation | Description | Complexity |
|
||||
//! |-----------|-------------|------------|
|
||||
//! | `dot_product_simd` | Inner product of two vectors | O(n) |
|
||||
//! | `norm_squared_simd` | Squared L2 norm | O(n) |
|
||||
//! | `subtract_simd` | Element-wise subtraction | O(n) |
|
||||
//! | `scale_simd` | Scalar multiplication | O(n) |
|
||||
//!
|
||||
//! ## Performance Notes
|
||||
//!
|
||||
//! - Vectors should be aligned to cache line boundaries for best performance
|
||||
//! - Processing 8 elements at a time with AVX2 achieves ~8x throughput
|
||||
//! - Small vectors (<32 elements) may not benefit from SIMD overhead
|
||||
|
||||
use wide::f32x8;
|
||||
|
||||
/// Compute the dot product of two f32 slices using SIMD.
///
/// # Arguments
///
/// * `a` - First input vector
/// * `b` - Second input vector (must have same length as `a`)
///
/// # Returns
///
/// The dot product: sum(a[i] * b[i])
///
/// Note: SIMD accumulation reorders the f32 sum, so results can differ
/// from a naive scalar loop in the last bits.
///
/// # Panics
///
/// Panics in debug mode if vectors have different lengths.
///
/// # Example
///
/// ```rust,ignore
/// use prime_radiant::simd::vectors::dot_product_simd;
///
/// let a = [1.0, 2.0, 3.0, 4.0];
/// let b = [4.0, 3.0, 2.0, 1.0];
/// let result = dot_product_simd(&a, &b);
/// assert!((result - 20.0).abs() < 1e-6);
/// ```
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have equal length");

    let len = a.len();

    // Fast path for small vectors - avoid SIMD overhead
    if len < 16 {
        return dot_product_scalar(a, b);
    }

    // Process 8 elements at a time with AVX2/wide
    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    let remainder_a = chunks_a.remainder();
    let remainder_b = chunks_b.remainder();

    // Use 4 accumulators for better ILP (Instruction Level Parallelism)
    let mut acc0 = f32x8::ZERO;
    let mut acc1 = f32x8::ZERO;
    let mut acc2 = f32x8::ZERO;
    let mut acc3 = f32x8::ZERO;

    let mut chunks_a_iter = chunks_a;
    let mut chunks_b_iter = chunks_b;

    // Unroll 4x for better throughput
    // The nested `if let` ladder feeds up to four consecutive chunks into
    // acc0..acc3 per outer iteration; when fewer than four chunks remain,
    // the inner ifs simply fall through and the outer `while let` ends.
    while let (Some(ca0), Some(cb0)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
        let va0 = load_f32x8(ca0);
        let vb0 = load_f32x8(cb0);
        acc0 = va0.mul_add(vb0, acc0);

        if let (Some(ca1), Some(cb1)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
            let va1 = load_f32x8(ca1);
            let vb1 = load_f32x8(cb1);
            acc1 = va1.mul_add(vb1, acc1);

            if let (Some(ca2), Some(cb2)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
                let va2 = load_f32x8(ca2);
                let vb2 = load_f32x8(cb2);
                acc2 = va2.mul_add(vb2, acc2);

                if let (Some(ca3), Some(cb3)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
                    let va3 = load_f32x8(ca3);
                    let vb3 = load_f32x8(cb3);
                    acc3 = va3.mul_add(vb3, acc3);
                }
            }
        }
    }

    // Combine accumulators
    let combined = acc0 + acc1 + acc2 + acc3;
    let mut sum = combined.reduce_add();

    // Handle remainder
    for (&va, &vb) in remainder_a.iter().zip(remainder_b.iter()) {
        sum += va * vb;
    }

    sum
}
|
||||
|
||||
/// Compute the squared L2 norm of a vector using SIMD.
///
/// # Arguments
///
/// * `v` - Input vector
///
/// # Returns
///
/// The squared norm: sum(v[i]^2)
///
/// # Example
///
/// ```rust,ignore
/// use prime_radiant::simd::vectors::norm_squared_simd;
///
/// let v = [3.0, 4.0];
/// let result = norm_squared_simd(&v);
/// assert!((result - 25.0).abs() < 1e-6);
/// ```
#[inline]
pub fn norm_squared_simd(v: &[f32]) -> f32 {
    let len = v.len();

    // Fast path for small vectors
    if len < 16 {
        return norm_squared_scalar(v);
    }

    let chunks = v.chunks_exact(8);
    let remainder = chunks.remainder();

    // Use 4 accumulators for better ILP
    let mut acc0 = f32x8::ZERO;
    let mut acc1 = f32x8::ZERO;
    let mut acc2 = f32x8::ZERO;
    let mut acc3 = f32x8::ZERO;

    let mut chunks_iter = chunks;

    // Unroll 4x
    // Each outer iteration drains up to four 8-lane chunks into acc0..acc3;
    // trailing iterations with fewer chunks fall through the inner ifs.
    while let Some(c0) = chunks_iter.next() {
        let v0 = load_f32x8(c0);
        acc0 = v0.mul_add(v0, acc0);

        if let Some(c1) = chunks_iter.next() {
            let v1 = load_f32x8(c1);
            acc1 = v1.mul_add(v1, acc1);

            if let Some(c2) = chunks_iter.next() {
                let v2 = load_f32x8(c2);
                acc2 = v2.mul_add(v2, acc2);

                if let Some(c3) = chunks_iter.next() {
                    let v3 = load_f32x8(c3);
                    acc3 = v3.mul_add(v3, acc3);
                }
            }
        }
    }

    // Combine accumulators
    let combined = acc0 + acc1 + acc2 + acc3;
    let mut sum = combined.reduce_add();

    // Handle remainder
    for &val in remainder {
        sum += val * val;
    }

    sum
}
|
||||
|
||||
/// Subtract two vectors element-wise using SIMD: out = a - b
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Minuend vector
|
||||
/// * `b` - Subtrahend vector
|
||||
/// * `out` - Output buffer (must have same length as inputs)
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug mode if vectors have different lengths.
|
||||
#[inline]
|
||||
pub fn subtract_simd(a: &[f32], b: &[f32], out: &mut [f32]) {
|
||||
debug_assert_eq!(a.len(), b.len(), "Input vectors must have equal length");
|
||||
debug_assert_eq!(a.len(), out.len(), "Output must have same length as inputs");
|
||||
|
||||
let len = a.len();
|
||||
|
||||
// Fast path for small vectors
|
||||
if len < 16 {
|
||||
subtract_scalar(a, b, out);
|
||||
return;
|
||||
}
|
||||
|
||||
let chunks_a = a.chunks_exact(8);
|
||||
let chunks_b = b.chunks_exact(8);
|
||||
let chunks_out = out.chunks_exact_mut(8);
|
||||
|
||||
let remainder_a = chunks_a.remainder();
|
||||
let remainder_b = chunks_b.remainder();
|
||||
let offset = len - remainder_a.len();
|
||||
|
||||
for ((ca, cb), cout) in chunks_a.zip(chunks_b).zip(chunks_out) {
|
||||
let va = load_f32x8(ca);
|
||||
let vb = load_f32x8(cb);
|
||||
let result = va - vb;
|
||||
store_f32x8(cout, result);
|
||||
}
|
||||
|
||||
// Handle remainder
|
||||
for (i, (&va, &vb)) in remainder_a.iter().zip(remainder_b.iter()).enumerate() {
|
||||
out[offset + i] = va - vb;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scale a vector by a scalar using SIMD: out = v * scalar
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `v` - Input vector
|
||||
/// * `scalar` - Scaling factor
|
||||
/// * `out` - Output buffer (must have same length as input)
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics in debug mode if output has different length than input.
|
||||
#[inline]
|
||||
pub fn scale_simd(v: &[f32], scalar: f32, out: &mut [f32]) {
|
||||
debug_assert_eq!(v.len(), out.len(), "Output must have same length as input");
|
||||
|
||||
let len = v.len();
|
||||
|
||||
// Fast path for small vectors
|
||||
if len < 16 {
|
||||
scale_scalar(v, scalar, out);
|
||||
return;
|
||||
}
|
||||
|
||||
let scalar_vec = f32x8::splat(scalar);
|
||||
|
||||
let chunks_v = v.chunks_exact(8);
|
||||
let chunks_out = out.chunks_exact_mut(8);
|
||||
|
||||
let remainder_v = chunks_v.remainder();
|
||||
let offset = len - remainder_v.len();
|
||||
|
||||
for (cv, cout) in chunks_v.zip(chunks_out) {
|
||||
let vv = load_f32x8(cv);
|
||||
let result = vv * scalar_vec;
|
||||
store_f32x8(cout, result);
|
||||
}
|
||||
|
||||
// Handle remainder
|
||||
for (i, &val) in remainder_v.iter().enumerate() {
|
||||
out[offset + i] = val * scalar;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute element-wise sum of squares of differences: sum((a[i] - b[i])^2)
///
/// This is equivalent to `norm_squared_simd(subtract_simd(a, b))` but more efficient
/// as it avoids the intermediate allocation.
///
/// # Arguments
///
/// * `a` - First input vector
/// * `b` - Second input vector
///
/// # Returns
///
/// The squared distance between the vectors.
///
/// # Panics
///
/// Panics in debug mode if the vectors have different lengths.
#[inline]
pub fn squared_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have equal length");

    let len = a.len();

    // Fast path for small vectors
    if len < 16 {
        return squared_distance_scalar(a, b);
    }

    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    let remainder_a = chunks_a.remainder();
    let remainder_b = chunks_b.remainder();

    // Four accumulators for instruction-level parallelism, as in
    // dot_product_simd.
    let mut acc0 = f32x8::ZERO;
    let mut acc1 = f32x8::ZERO;
    let mut acc2 = f32x8::ZERO;
    let mut acc3 = f32x8::ZERO;

    let mut chunks_a_iter = chunks_a;
    let mut chunks_b_iter = chunks_b;

    // 4x-unrolled ladder: consecutive chunk pairs feed acc0..acc3; inner
    // ifs fall through when fewer than four chunk pairs remain.
    while let (Some(ca0), Some(cb0)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
        let va0 = load_f32x8(ca0);
        let vb0 = load_f32x8(cb0);
        let diff0 = va0 - vb0;
        acc0 = diff0.mul_add(diff0, acc0);

        if let (Some(ca1), Some(cb1)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
            let va1 = load_f32x8(ca1);
            let vb1 = load_f32x8(cb1);
            let diff1 = va1 - vb1;
            acc1 = diff1.mul_add(diff1, acc1);

            if let (Some(ca2), Some(cb2)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
                let va2 = load_f32x8(ca2);
                let vb2 = load_f32x8(cb2);
                let diff2 = va2 - vb2;
                acc2 = diff2.mul_add(diff2, acc2);

                if let (Some(ca3), Some(cb3)) = (chunks_a_iter.next(), chunks_b_iter.next()) {
                    let va3 = load_f32x8(ca3);
                    let vb3 = load_f32x8(cb3);
                    let diff3 = va3 - vb3;
                    acc3 = diff3.mul_add(diff3, acc3);
                }
            }
        }
    }

    let combined = acc0 + acc1 + acc2 + acc3;
    let mut sum = combined.reduce_add();

    // Handle remainder
    for (&va, &vb) in remainder_a.iter().zip(remainder_b.iter()) {
        let diff = va - vb;
        sum += diff * diff;
    }

    sum
}
|
||||
|
||||
/// Compute weighted sum: sum(a[i] * weights[i])
///
/// # Arguments
///
/// * `values` - Values to sum
/// * `weights` - Corresponding weights
///
/// # Returns
///
/// The weighted sum.
///
/// # Panics
///
/// Panics in debug mode if the slices have different lengths
/// (via [`dot_product_simd`]).
#[inline]
pub fn weighted_sum_simd(values: &[f32], weights: &[f32]) -> f32 {
    // This is just a dot product
    dot_product_simd(values, weights)
}
|
||||
|
||||
/// Fused multiply-add for vectors: out = a * b + c
///
/// Uses FMA instructions when available for better precision and performance.
///
/// # Arguments
///
/// * `a`, `b`, `c` - Input vectors (all the same length)
/// * `out` - Output buffer (same length as inputs)
///
/// # Panics
///
/// Panics in debug mode if any slice length differs.
#[inline]
pub fn fma_simd(a: &[f32], b: &[f32], c: &[f32], out: &mut [f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), c.len());
    debug_assert_eq!(a.len(), out.len());

    let len = a.len();

    // Small inputs: scalar fused multiply-add, no SIMD setup.
    if len < 16 {
        for i in 0..len {
            out[i] = a[i].mul_add(b[i], c[i]);
        }
        return;
    }

    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    let chunks_c = c.chunks_exact(8);
    let chunks_out = out.chunks_exact_mut(8);

    let remainder_a = chunks_a.remainder();
    let remainder_b = chunks_b.remainder();
    let remainder_c = chunks_c.remainder();
    // First index of the scalar tail (all inputs share one length).
    let offset = len - remainder_a.len();

    for (((ca, cb), cc), cout) in chunks_a.zip(chunks_b).zip(chunks_c).zip(chunks_out) {
        let va = load_f32x8(ca);
        let vb = load_f32x8(cb);
        let vc = load_f32x8(cc);
        let result = va.mul_add(vb, vc);
        store_f32x8(cout, result);
    }

    // Handle remainder
    for (i, ((&va, &vb), &vc)) in remainder_a
        .iter()
        .zip(remainder_b.iter())
        .zip(remainder_c.iter())
        .enumerate()
    {
        out[offset + i] = va.mul_add(vb, vc);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Scalar Fallback Implementations
|
||||
// ============================================================================
|
||||
|
||||
/// Scalar fallback dot product: `sum(a[i] * b[i])`.
///
/// Truncates to the shorter of the two slices (via `zip`).
#[inline(always)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Four independent accumulators give instruction-level parallelism
    // even without explicit SIMD; same accumulation order as before.
    let mut acc = [0.0f32; 4];

    for (qa, qb) in a.chunks_exact(4).zip(b.chunks_exact(4)) {
        for lane in 0..4 {
            acc[lane] += qa[lane] * qb[lane];
        }
    }

    let mut total = acc[0] + acc[1] + acc[2] + acc[3];

    // Leftover (< 4) elements.
    let tail_a = a.chunks_exact(4).remainder();
    let tail_b = b.chunks_exact(4).remainder();
    for (&x, &y) in tail_a.iter().zip(tail_b.iter()) {
        total += x * y;
    }
    total
}
|
||||
|
||||
/// Scalar fallback squared L2 norm: `sum(v[i]^2)`.
#[inline(always)]
fn norm_squared_scalar(v: &[f32]) -> f32 {
    // Four accumulators for instruction-level parallelism;
    // accumulation order matches the original plain-variable version.
    let mut acc = [0.0f32; 4];
    let quads = v.chunks_exact(4);
    let tail = quads.remainder();

    for q in quads {
        for lane in 0..4 {
            acc[lane] += q[lane] * q[lane];
        }
    }

    let mut total = acc[0] + acc[1] + acc[2] + acc[3];

    // Leftover (< 4) elements.
    for &x in tail {
        total += x * x;
    }
    total
}
|
||||
|
||||
/// Scalar fallback element-wise subtraction: `out[i] = a[i] - b[i]`.
///
/// Stops at the shortest of the three slices (via `zip`).
#[inline(always)]
fn subtract_scalar(a: &[f32], b: &[f32], out: &mut [f32]) {
    for (dst, (x, y)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
        *dst = x - y;
    }
}
|
||||
|
||||
/// Scalar fallback scaling: `out[i] = v[i] * scalar`.
///
/// Stops at the shorter of the two slices (via `zip`).
#[inline(always)]
fn scale_scalar(v: &[f32], scalar: f32, out: &mut [f32]) {
    for (dst, x) in out.iter_mut().zip(v.iter()) {
        *dst = x * scalar;
    }
}
|
||||
|
||||
/// Scalar fallback squared Euclidean distance: `sum((a[i] - b[i])^2)`.
///
/// Truncates to the shorter of the two slices (via `zip`).
#[inline(always)]
fn squared_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    // `sum()` left-folds from 0.0, matching the plain-loop accumulation order.
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let d = x - y;
            d * d
        })
        .sum()
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Load 8 f32 values into a SIMD register.
|
||||
#[inline(always)]
|
||||
fn load_f32x8(slice: &[f32]) -> f32x8 {
|
||||
debug_assert!(slice.len() >= 8);
|
||||
// Use try_into for direct memory copy instead of element-by-element
|
||||
let arr: [f32; 8] = slice[..8].try_into().unwrap();
|
||||
f32x8::from(arr)
|
||||
}
|
||||
|
||||
/// Store 8 f32 values from a SIMD register to a slice.
|
||||
#[inline(always)]
|
||||
fn store_f32x8(slice: &mut [f32], v: f32x8) {
|
||||
debug_assert!(slice.len() >= 8);
|
||||
let arr: [f32; 8] = v.into();
|
||||
slice[..8].copy_from_slice(&arr);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-4;

    /// Sizes used by the "large" tests: one multiple of the 8-lane SIMD
    /// width (main loop only) and one that is not (exercises the scalar
    /// remainder path of every SIMD kernel as well).
    const LARGE_SIZES: [usize; 2] = [1024, 1029];

    /// Approximate float comparison: relative error for magnitudes > 1,
    /// absolute error otherwise.
    fn approx_eq(a: f32, b: f32) -> bool {
        let max_abs = a.abs().max(b.abs());
        if max_abs > 1.0 {
            (a - b).abs() / max_abs < EPSILON
        } else {
            (a - b).abs() < EPSILON
        }
    }

    #[test]
    fn test_dot_product_small() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let result = dot_product_simd(&a, &b);
        assert!(approx_eq(result, 20.0), "got {}", result);
    }

    #[test]
    fn test_dot_product_large() {
        for n in LARGE_SIZES {
            let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..n).map(|i| (n - 1 - i) as f32).collect();

            let result = dot_product_simd(&a, &b);
            let expected = dot_product_scalar(&a, &b);
            assert!(
                approx_eq(result, expected),
                "n={}: got {} expected {}",
                n,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_norm_squared_small() {
        let v = [3.0, 4.0];
        let result = norm_squared_simd(&v);
        assert!(approx_eq(result, 25.0), "got {}", result);
    }

    #[test]
    fn test_norm_squared_large() {
        for n in LARGE_SIZES {
            let v: Vec<f32> = (0..n).map(|i| i as f32 * 0.01).collect();

            let result = norm_squared_simd(&v);
            let expected = norm_squared_scalar(&v);
            assert!(
                approx_eq(result, expected),
                "n={}: got {} expected {}",
                n,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_subtract_small() {
        let a = [5.0, 6.0, 7.0, 8.0];
        let b = [1.0, 2.0, 3.0, 4.0];
        let mut out = [0.0f32; 4];

        subtract_simd(&a, &b, &mut out);
        assert_eq!(out, [4.0, 4.0, 4.0, 4.0]);
    }

    #[test]
    fn test_subtract_large() {
        for n in LARGE_SIZES {
            let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..n).map(|i| i as f32 * 0.5).collect();
            let mut out = vec![0.0f32; n];

            subtract_simd(&a, &b, &mut out);

            for i in 0..n {
                let expected = a[i] - b[i];
                assert!(
                    approx_eq(out[i], expected),
                    "n={} at {} got {} expected {}",
                    n,
                    i,
                    out[i],
                    expected
                );
            }
        }
    }

    #[test]
    fn test_scale_small() {
        let v = [1.0, 2.0, 3.0, 4.0];
        let mut out = [0.0f32; 4];

        scale_simd(&v, 2.0, &mut out);
        assert_eq!(out, [2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_scale_large() {
        let scalar = 3.5;
        for n in LARGE_SIZES {
            let v: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let mut out = vec![0.0f32; n];

            scale_simd(&v, scalar, &mut out);

            for i in 0..n {
                let expected = v[i] * scalar;
                assert!(
                    approx_eq(out[i], expected),
                    "n={} at {} got {} expected {}",
                    n,
                    i,
                    out[i],
                    expected
                );
            }
        }
    }

    #[test]
    fn test_squared_distance() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = squared_distance_simd(&a, &b);
        // (4-1)^2 + (5-2)^2 + (6-3)^2 = 9 + 9 + 9 = 27
        assert!(approx_eq(result, 27.0), "got {}", result);
    }

    #[test]
    fn test_squared_distance_large() {
        for n in LARGE_SIZES {
            let a: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
            let b: Vec<f32> = (0..n).map(|i| i as f32 * 0.2).collect();

            let result = squared_distance_simd(&a, &b);
            let expected = squared_distance_scalar(&a, &b);
            assert!(
                approx_eq(result, expected),
                "n={}: got {} expected {}",
                n,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_fma() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [2.0, 2.0, 2.0, 2.0];
        let c = [1.0, 1.0, 1.0, 1.0];
        let mut out = [0.0f32; 4];

        fma_simd(&a, &b, &c, &mut out);
        // a * b + c = [3, 5, 7, 9]
        assert_eq!(out, [3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_fma_large() {
        // len >= 16 takes the SIMD branch of fma_simd (the small test above
        // only covers the scalar path); 1029 also covers its remainder loop.
        for n in LARGE_SIZES {
            let a: Vec<f32> = (0..n).map(|i| i as f32 * 0.25).collect();
            let b: Vec<f32> = (0..n).map(|i| (i % 7) as f32).collect();
            let c: Vec<f32> = (0..n).map(|i| i as f32 * -0.5).collect();
            let mut out = vec![0.0f32; n];

            fma_simd(&a, &b, &c, &mut out);

            for i in 0..n {
                let expected = a[i].mul_add(b[i], c[i]);
                assert!(
                    approx_eq(out[i], expected),
                    "n={} at {} got {} expected {}",
                    n,
                    i,
                    out[i],
                    expected
                );
            }
        }
    }

    #[test]
    fn test_edge_cases() {
        // Empty vectors
        assert!(approx_eq(dot_product_simd(&[], &[]), 0.0));
        assert!(approx_eq(norm_squared_simd(&[]), 0.0));

        // Single element
        assert!(approx_eq(dot_product_simd(&[3.0], &[4.0]), 12.0));
        assert!(approx_eq(norm_squared_simd(&[5.0]), 25.0));
    }
}
|
||||
Reference in New Issue
Block a user