//! SIMD-accelerated sparse matrix-vector multiply.
//!
//! Provides [`spmv_simd`], which dispatches to an architecture-specific
//! implementation when the `simd` feature is enabled and the CPU supports
//! AVX2 at runtime, falling back to a portable scalar loop otherwise.
use crate::types::CsrMatrix;
/// Sparse matrix-vector multiply with optional SIMD acceleration.
///
/// Computes `y = A * x` where `A` is a CSR matrix of `f32` values.
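///
/// # Example
///
/// A minimal sketch, assuming the `CsrMatrix` fields are constructible as in
/// the tests below and that the crate is imported as `ruvector_solver`:
///
/// ```ignore
/// use ruvector_solver::types::CsrMatrix;
///
/// // [2 0 1]   [1]   [ 5]
/// // [0 3 0] * [2] = [ 6]
/// // [1 0 4]   [3]   [13]
/// let a = CsrMatrix {
///     values: vec![2.0f32, 1.0, 3.0, 1.0, 4.0],
///     col_indices: vec![0, 2, 1, 0, 2],
///     row_ptr: vec![0, 2, 3, 5],
///     rows: 3,
///     cols: 3,
/// };
/// let x = vec![1.0f32, 2.0, 3.0];
/// let mut y = vec![0.0f32; 3];
/// ruvector_solver::simd::spmv_simd(&a, &x, &mut y);
/// assert_eq!(y, vec![5.0, 6.0, 13.0]);
/// ```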
pub fn spmv_simd(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: we have checked for AVX2 support at runtime.
unsafe {
spmv_avx2(matrix, x, y);
}
return;
}
}
spmv_scalar(matrix, x, y);
}
/// Scalar fallback implementation of SpMV.
pub fn spmv_scalar(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
for i in 0..matrix.rows {
let start = matrix.row_ptr[i];
let end = matrix.row_ptr[i + 1];
let mut sum = 0.0f32;
for idx in start..end {
let col = matrix.col_indices[idx];
sum += matrix.values[idx] * x[col];
}
y[i] = sum;
}
}
/// AVX2-accelerated SpMV for x86_64.
///
/// # Safety
///
/// - The caller must ensure AVX2 is supported on the current CPU (checked at
/// runtime via `is_x86_feature_detected!("avx2")` in [`spmv_simd`]).
/// - The caller must ensure `x.len() >= matrix.cols` and
/// `y.len() >= matrix.rows`. These are asserted in [`spmv_simd`] before
/// dispatching here.
/// - The CSR matrix must be structurally valid: `row_ptr[i] <= row_ptr[i+1]`,
///   all `col_indices[j] < matrix.cols`, and both `values.len()` and
///   `col_indices.len()` at least `row_ptr[rows]`.
///   Use [`crate::validation::validate_csr_matrix`] before calling the solver
///   to guarantee this.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
use std::arch::x86_64::*;
for i in 0..matrix.rows {
let start = matrix.row_ptr[i];
let end = matrix.row_ptr[i + 1];
let len = end - start;
let mut accum = _mm256_setzero_ps();
let chunks = len / 8;
let remainder = len % 8;
for chunk in 0..chunks {
let base = start + chunk * 8;
// SAFETY: `base + 7 < end <= values.len()` because
// `chunk < chunks` implies `base + 8 <= start + chunks * 8 <= end`.
let vals = _mm256_loadu_ps(matrix.values.as_ptr().add(base));
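            // Gather the eight x values through a stack buffer. AVX2's
            // `_mm256_i32gather_ps` takes i32 indices, while `col_indices`
            // are `usize`, so the scalar gather avoids a narrowing step.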
let mut x_buf = [0.0f32; 8];
for k in 0..8 {
// SAFETY: `base + k < end` so `col_indices[base + k]` is in
// bounds. `col < matrix.cols <= x.len()` by the CSR structural
// invariant (enforced by `validate_csr_matrix`).
let col = *matrix.col_indices.get_unchecked(base + k);
x_buf[k] = *x.get_unchecked(col);
}
let x_vec = _mm256_loadu_ps(x_buf.as_ptr());
accum = _mm256_add_ps(accum, _mm256_mul_ps(vals, x_vec));
}
let mut sum = horizontal_sum_f32x8(accum);
let tail_start = start + chunks * 8;
for idx in tail_start..(tail_start + remainder) {
// SAFETY: `idx < end <= values.len()` and `col < cols <= x.len()`
// by the same CSR structural invariant.
let col = *matrix.col_indices.get_unchecked(idx);
sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col);
}
// SAFETY: `i < matrix.rows <= y.len()` by the assert in `spmv_simd`.
*y.get_unchecked_mut(i) = sum;
}
}
/// Horizontal sum of an AVX2 register (8 x f32 -> 1 x f32).
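///
/// # Safety
///
/// The caller must ensure AVX2 is supported on the current CPU; this is
/// only invoked from functions that are themselves gated on AVX2.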
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f32x8(v: std::arch::x86_64::__m256) -> f32 {
use std::arch::x86_64::*;
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    // Fold the upper 128-bit lane onto the lower one: 8 lanes -> 4 lanes.
    let sum128 = _mm_add_ps(lo, hi);
    // Add odd-indexed lanes onto even-indexed ones: 4 lanes -> 2 lanes.
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    // Add the upper pair onto the lower pair: 2 lanes -> 1 lane.
    let shuf2 = _mm_movehl_ps(sums, sums);
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}
/// Sparse matrix-vector multiply with optional SIMD acceleration for f64.
///
/// Computes `y = A * x` where `A` is a CSR matrix of `f64` values.
pub fn spmv_simd_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
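            // SAFETY: we have checked for AVX2 support at runtime.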
unsafe {
spmv_avx2_f64(matrix, x, y);
}
return;
}
}
spmv_scalar_f64(matrix, x, y);
}
/// Scalar fallback for f64 SpMV.
pub fn spmv_scalar_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
for i in 0..matrix.rows {
let start = matrix.row_ptr[i];
let end = matrix.row_ptr[i + 1];
let mut sum = 0.0f64;
for idx in start..end {
let col = matrix.col_indices[idx];
sum += matrix.values[idx] * x[col];
}
y[i] = sum;
}
}
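/// AVX2-accelerated SpMV for f64 on x86_64.
///
/// # Safety
///
/// Same contract as [`spmv_avx2`]: the caller must verify AVX2 support at
/// runtime, the slice-length asserts in [`spmv_simd_f64`] must hold, and the
/// CSR matrix must be structurally valid.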
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
use std::arch::x86_64::*;
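    // Bounds reasoning mirrors `spmv_avx2`: every unchecked access below is
    // justified by the CSR structural invariants and the asserts in
    // `spmv_simd_f64`, with 4-wide f64 chunks in place of 8-wide f32 chunks.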
for i in 0..matrix.rows {
let start = matrix.row_ptr[i];
let end = matrix.row_ptr[i + 1];
let len = end - start;
let mut accum = _mm256_setzero_pd();
let chunks = len / 4;
let remainder = len % 4;
for chunk in 0..chunks {
let base = start + chunk * 4;
let vals = _mm256_loadu_pd(matrix.values.as_ptr().add(base));
let mut x_buf = [0.0f64; 4];
for k in 0..4 {
let col = *matrix.col_indices.get_unchecked(base + k);
x_buf[k] = *x.get_unchecked(col);
}
let x_vec = _mm256_loadu_pd(x_buf.as_ptr());
accum = _mm256_add_pd(accum, _mm256_mul_pd(vals, x_vec));
}
let mut sum = horizontal_sum_f64x4(accum);
let tail_start = start + chunks * 4;
for idx in tail_start..(tail_start + remainder) {
let col = *matrix.col_indices.get_unchecked(idx);
sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col);
}
*y.get_unchecked_mut(i) = sum;
}
}
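/// Horizontal sum of an AVX2 register (4 x f64 -> 1 x f64).
///
/// # Safety
///
/// The caller must ensure AVX2 is supported on the current CPU; this is
/// only invoked from functions that are themselves gated on AVX2.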
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 {
use std::arch::x86_64::*;
    let hi = _mm256_extractf128_pd(v, 1);
    let lo = _mm256_castpd256_pd128(v);
    // Fold the upper 128-bit lane onto the lower one: 4 lanes -> 2 lanes.
    let sum128 = _mm_add_pd(lo, hi);
    // Add the upper lane onto the lower one: 2 lanes -> 1 lane.
    let hi64 = _mm_unpackhi_pd(sum128, sum128);
    let result = _mm_add_sd(sum128, hi64);
    _mm_cvtsd_f64(result)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::CsrMatrix;
fn make_test_matrix() -> (CsrMatrix<f32>, Vec<f32>) {
// [2 0 1] [1] [5]
// [0 3 0] * [2] = [6]
// [1 0 4] [3] [13]
let mat = CsrMatrix {
values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
col_indices: vec![0, 2, 1, 0, 2],
row_ptr: vec![0, 2, 3, 5],
rows: 3,
cols: 3,
};
let x = vec![1.0, 2.0, 3.0];
(mat, x)
}
#[test]
fn scalar_spmv_correctness() {
let (mat, x) = make_test_matrix();
let mut y = vec![0.0f32; 3];
spmv_scalar(&mat, &x, &mut y);
assert!((y[0] - 5.0).abs() < 1e-6);
assert!((y[1] - 6.0).abs() < 1e-6);
assert!((y[2] - 13.0).abs() < 1e-6);
}
#[test]
fn spmv_simd_dispatch() {
let (mat, x) = make_test_matrix();
let mut y = vec![0.0f32; 3];
spmv_simd(&mat, &x, &mut y);
assert!((y[0] - 5.0).abs() < 1e-6);
assert!((y[1] - 6.0).abs() < 1e-6);
assert!((y[2] - 13.0).abs() < 1e-6);
}
#[test]
fn spmv_simd_f64_correctness() {
let mat = CsrMatrix::<f64> {
values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
col_indices: vec![0, 2, 1, 0, 2],
row_ptr: vec![0, 2, 3, 5],
rows: 3,
cols: 3,
};
let x = vec![1.0, 2.0, 3.0];
let mut y = vec![0.0f64; 3];
spmv_simd_f64(&mat, &x, &mut y);
assert!((y[0] - 5.0).abs() < 1e-10);
assert!((y[1] - 6.0).abs() < 1e-10);
assert!((y[2] - 13.0).abs() < 1e-10);
}
#[test]
fn scalar_spmv_f64_correctness() {
let mat = CsrMatrix::<f64> {
values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
col_indices: vec![0, 2, 1, 0, 2],
row_ptr: vec![0, 2, 3, 5],
rows: 3,
cols: 3,
};
let x = vec![1.0, 2.0, 3.0];
let mut y = vec![0.0f64; 3];
spmv_scalar_f64(&mat, &x, &mut y);
assert!((y[0] - 5.0).abs() < 1e-10);
assert!((y[1] - 6.0).abs() < 1e-10);
assert!((y[2] - 13.0).abs() < 1e-10);
}
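    // Cross-checks the SIMD dispatch against the scalar fallback on a row
    // long enough to exercise both the full-chunk loop and the remainder
    // tail; the fixtures above never fill a single 8-wide chunk. Comparing
    // against `spmv_scalar` keeps the test valid on any architecture.
    #[test]
    fn spmv_simd_matches_scalar_on_long_rows() {
        // One row with 19 nonzeros: two full f32 chunks plus a 3-element tail.
        let nnz = 19;
        let mat = CsrMatrix {
            values: (0..nnz).map(|i| (i as f32) * 0.5 + 1.0).collect(),
            col_indices: (0..nnz).collect(),
            row_ptr: vec![0, nnz],
            rows: 1,
            cols: nnz,
        };
        let x: Vec<f32> = (0..nnz).map(|i| (i as f32) - 9.0).collect();
        let mut y_simd = vec![0.0f32; 1];
        let mut y_scalar = vec![0.0f32; 1];
        spmv_simd(&mat, &x, &mut y_simd);
        spmv_scalar(&mat, &x, &mut y_scalar);
        // SIMD and scalar sums associate differently, so allow a small slack.
        assert!((y_simd[0] - y_scalar[0]).abs() < 1e-3);
    }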
}