//! SIMD-accelerated sparse matrix-vector multiply. //! //! Provides [`spmv_simd`], which dispatches to an architecture-specific //! implementation when the `simd` feature is enabled, and falls back to a //! portable scalar loop otherwise. use crate::types::CsrMatrix; /// Sparse matrix-vector multiply with optional SIMD acceleration. /// /// Computes `y = A * x` where `A` is a CSR matrix of `f32` values. pub fn spmv_simd(matrix: &CsrMatrix, x: &[f32], y: &mut [f32]) { assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols"); assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows"); #[cfg(all(feature = "simd", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { // SAFETY: we have checked for AVX2 support at runtime. unsafe { spmv_avx2(matrix, x, y); } return; } } spmv_scalar(matrix, x, y); } /// Scalar fallback implementation of SpMV. pub fn spmv_scalar(matrix: &CsrMatrix, x: &[f32], y: &mut [f32]) { for i in 0..matrix.rows { let start = matrix.row_ptr[i]; let end = matrix.row_ptr[i + 1]; let mut sum = 0.0f32; for idx in start..end { let col = matrix.col_indices[idx]; sum += matrix.values[idx] * x[col]; } y[i] = sum; } } /// AVX2-accelerated SpMV for x86_64. /// /// # Safety /// /// - The caller must ensure AVX2 is supported on the current CPU (checked at /// runtime via `is_x86_feature_detected!("avx2")` in [`spmv_simd`]). /// - The caller must ensure `x.len() >= matrix.cols` and /// `y.len() >= matrix.rows`. These are asserted in [`spmv_simd`] before /// dispatching here. /// - The CSR matrix must be structurally valid: `row_ptr[i] <= row_ptr[i+1]`, /// all `col_indices[j] < matrix.cols`, and `values.len() >= row_ptr[rows]`. /// Use [`crate::validation::validate_csr_matrix`] before calling the solver /// to guarantee this. #[cfg(all(feature = "simd", target_arch = "x86_64"))] #[target_feature(enable = "avx2")] unsafe fn spmv_avx2(matrix: &CsrMatrix, x: &[f32], y: &mut [f32]) { use std::arch::x86_64::*; for i in 0..matrix.rows { let start = matrix.row_ptr[i]; let end = matrix.row_ptr[i + 1]; let len = end - start; let mut accum = _mm256_setzero_ps(); let chunks = len / 8; let remainder = len % 8; for chunk in 0..chunks { let base = start + chunk * 8; // SAFETY: `base + 7 < end <= values.len()` because // `chunk < chunks` implies `base + 8 <= start + chunks * 8 <= end`. let vals = _mm256_loadu_ps(matrix.values.as_ptr().add(base)); let mut x_buf = [0.0f32; 8]; for k in 0..8 { // SAFETY: `base + k < end` so `col_indices[base + k]` is in // bounds. `col < matrix.cols <= x.len()` by the CSR structural // invariant (enforced by `validate_csr_matrix`). let col = *matrix.col_indices.get_unchecked(base + k); x_buf[k] = *x.get_unchecked(col); } let x_vec = _mm256_loadu_ps(x_buf.as_ptr()); accum = _mm256_add_ps(accum, _mm256_mul_ps(vals, x_vec)); } let mut sum = horizontal_sum_f32x8(accum); let tail_start = start + chunks * 8; for idx in tail_start..(tail_start + remainder) { // SAFETY: `idx < end <= values.len()` and `col < cols <= x.len()` // by the same CSR structural invariant. let col = *matrix.col_indices.get_unchecked(idx); sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col); } // SAFETY: `i < matrix.rows <= y.len()` by the assert in `spmv_simd`. *y.get_unchecked_mut(i) = sum; } } /// Horizontal sum of an AVX2 register (8 x f32 -> 1 x f32). #[cfg(all(feature = "simd", target_arch = "x86_64"))] #[target_feature(enable = "avx2")] unsafe fn horizontal_sum_f32x8(v: std::arch::x86_64::__m256) -> f32 { use std::arch::x86_64::*; let hi = _mm256_extractf128_ps(v, 1); let lo = _mm256_castps256_ps128(v); let sum128 = _mm_add_ps(lo, hi); let shuf = _mm_movehdup_ps(sum128); let sums = _mm_add_ps(sum128, shuf); let shuf2 = _mm_movehl_ps(sums, sums); let result = _mm_add_ss(sums, shuf2); _mm_cvtss_f32(result) } /// Sparse matrix-vector multiply with optional SIMD acceleration for f64. /// /// Computes `y = A * x` where `A` is a CSR matrix of `f64` values. pub fn spmv_simd_f64(matrix: &CsrMatrix, x: &[f64], y: &mut [f64]) { assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols"); assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows"); #[cfg(all(feature = "simd", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { unsafe { spmv_avx2_f64(matrix, x, y); } return; } } spmv_scalar_f64(matrix, x, y); } /// Scalar fallback for f64 SpMV. pub fn spmv_scalar_f64(matrix: &CsrMatrix, x: &[f64], y: &mut [f64]) { for i in 0..matrix.rows { let start = matrix.row_ptr[i]; let end = matrix.row_ptr[i + 1]; let mut sum = 0.0f64; for idx in start..end { let col = matrix.col_indices[idx]; sum += matrix.values[idx] * x[col]; } y[i] = sum; } } #[cfg(all(feature = "simd", target_arch = "x86_64"))] #[target_feature(enable = "avx2")] unsafe fn spmv_avx2_f64(matrix: &CsrMatrix, x: &[f64], y: &mut [f64]) { use std::arch::x86_64::*; for i in 0..matrix.rows { let start = matrix.row_ptr[i]; let end = matrix.row_ptr[i + 1]; let len = end - start; let mut accum = _mm256_setzero_pd(); let chunks = len / 4; let remainder = len % 4; for chunk in 0..chunks { let base = start + chunk * 4; let vals = _mm256_loadu_pd(matrix.values.as_ptr().add(base)); let mut x_buf = [0.0f64; 4]; for k in 0..4 { let col = *matrix.col_indices.get_unchecked(base + k); x_buf[k] = *x.get_unchecked(col); } let x_vec = _mm256_loadu_pd(x_buf.as_ptr()); accum = _mm256_add_pd(accum, _mm256_mul_pd(vals, x_vec)); } let mut sum = horizontal_sum_f64x4(accum); let tail_start = start + chunks * 4; for idx in tail_start..(tail_start + remainder) { let col = *matrix.col_indices.get_unchecked(idx); sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col); } *y.get_unchecked_mut(i) = sum; } } #[cfg(all(feature = "simd", target_arch = "x86_64"))] #[target_feature(enable = "avx2")] unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 { use std::arch::x86_64::*; let hi = _mm256_extractf128_pd(v, 1); let lo = _mm256_castpd256_pd128(v); let sum128 = _mm_add_pd(lo, hi); let hi64 = _mm_unpackhi_pd(sum128, sum128); let result = _mm_add_sd(sum128, hi64); _mm_cvtsd_f64(result) } #[cfg(test)] mod tests { use super::*; use crate::types::CsrMatrix; fn make_test_matrix() -> (CsrMatrix, Vec) { // [2 0 1] [1] [5] // [0 3 0] * [2] = [6] // [1 0 4] [3] [13] let mat = CsrMatrix { values: vec![2.0, 1.0, 3.0, 1.0, 4.0], col_indices: vec![0, 2, 1, 0, 2], row_ptr: vec![0, 2, 3, 5], rows: 3, cols: 3, }; let x = vec![1.0, 2.0, 3.0]; (mat, x) } #[test] fn scalar_spmv_correctness() { let (mat, x) = make_test_matrix(); let mut y = vec![0.0f32; 3]; spmv_scalar(&mat, &x, &mut y); assert!((y[0] - 5.0).abs() < 1e-6); assert!((y[1] - 6.0).abs() < 1e-6); assert!((y[2] - 13.0).abs() < 1e-6); } #[test] fn spmv_simd_dispatch() { let (mat, x) = make_test_matrix(); let mut y = vec![0.0f32; 3]; spmv_simd(&mat, &x, &mut y); assert!((y[0] - 5.0).abs() < 1e-6); assert!((y[1] - 6.0).abs() < 1e-6); assert!((y[2] - 13.0).abs() < 1e-6); } #[test] fn spmv_simd_f64_correctness() { let mat = CsrMatrix:: { values: vec![2.0, 1.0, 3.0, 1.0, 4.0], col_indices: vec![0, 2, 1, 0, 2], row_ptr: vec![0, 2, 3, 5], rows: 3, cols: 3, }; let x = vec![1.0, 2.0, 3.0]; let mut y = vec![0.0f64; 3]; spmv_simd_f64(&mat, &x, &mut y); assert!((y[0] - 5.0).abs() < 1e-10); assert!((y[1] - 6.0).abs() < 1e-10); assert!((y[2] - 13.0).abs() < 1e-10); } #[test] fn scalar_spmv_f64_correctness() { let mat = CsrMatrix:: { values: vec![2.0, 1.0, 3.0, 1.0, 4.0], col_indices: vec![0, 2, 1, 0, 2], row_ptr: vec![0, 2, 3, 5], rows: 3, cols: 3, }; let x = vec![1.0, 2.0, 3.0]; let mut y = vec![0.0f64; 3]; spmv_scalar_f64(&mat, &x, &mut y); assert!((y[0] - 5.0).abs() < 1e-10); assert!((y[1] - 6.0).abs() < 1e-10); assert!((y[2] - 13.0).abs() < 1e-10); } }