//! # SIMD Matrix Operations
//!
//! High-performance matrix operations using 8-lane `wide::f32x8` vectors.
//! Optimized for small to medium matrices common in coherence computation.
//!
//! ## Matrix Layout
//!
//! All matrices are stored in **row-major** order:
//! - `A[i][j]` is at index `i * cols + j`
//! - This matches Rust's natural 2D array layout
//!
//! ## Supported Operations
//!
//! | Operation | Description | Complexity |
//! |-----------|-------------|------------|
//! | `matmul_simd` | Matrix-matrix multiplication | O(m*k*n) |
//! | `matvec_simd` | Matrix-vector multiplication | O(m*n) |
//! | `transpose_simd` | Matrix transpose | O(m*n) |
//! | `outer_product_simd` | Vector outer product | O(m*n) |
//! | `matadd_simd` | Element-wise matrix addition | O(len) |
//! | `matscale_simd` | Matrix scaling by a scalar | O(len) |
//!
//! ## Performance Notes
//!
//! - Uses blocking/tiling for cache-friendly access patterns
//! - Processes eight `f32` lanes per step, with a scalar loop for leftovers
//! - Falls back to plain scalar code for small matrices

use wide::f32x8;

/// Block size for tiled matrix operations (cache optimization).
const BLOCK_SIZE: usize = 64;

/// Compute matrix-matrix multiplication: C = A * B
///
/// Problems with fewer than 256 output elements, or with an inner dimension
/// `k` below 8, are handled by the scalar kernel; everything else goes
/// through the blocked SIMD kernel. Any previous contents of `c` are
/// overwritten.
///
/// # Arguments
///
/// * `a` - First matrix (m x k), row-major, length = m * k
/// * `b` - Second matrix (k x n), row-major, length = k * n
/// * `c` - Output matrix (m x n), row-major, length = m * n
/// * `m` - Number of rows in A
/// * `k` - Number of columns in A (= rows in B)
/// * `n` - Number of columns in B
///
/// # Panics
///
/// Panics in debug mode if buffer sizes don't match dimensions.
///
/// # Example
///
/// ```rust,ignore
/// use prime_radiant::simd::matrix::matmul_simd;
///
/// // 2x3 * 3x2 = 2x2
/// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
/// let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
/// let mut c = [0.0f32; 4]; // 2x2
///
/// matmul_simd(&a, &b, &mut c, 2, 3, 2);
/// // c = [22, 28, 49, 64]
/// ```
#[inline]
pub fn matmul_simd(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    debug_assert_eq!(a.len(), m * k, "Matrix A size mismatch");
    debug_assert_eq!(b.len(), k * n, "Matrix B size mismatch");
    debug_assert_eq!(c.len(), m * n, "Matrix C size mismatch");

    if m * n < 256 || k < 8 {
        // The scalar kernel assigns every output element, so zeroing `c`
        // first would only be a wasted pass over the buffer.
        matmul_scalar(a, b, c, m, k, n);
    } else {
        // The blocked kernel accumulates partial products into `c`, so it
        // has to start from an all-zero output.
        c.fill(0.0);
        matmul_blocked(a, b, c, m, k, n);
    }
}

/// Compute matrix-vector multiplication: y = A * x
///
/// # Arguments
///
/// * `a` - Matrix (m x n), row-major
/// * `x` - Input vector (length n)
/// * `y` - Output vector (length m)
/// * `m` - Number of rows
/// * `n` - Number of columns
///
/// # Panics
///
/// Panics in debug mode if buffer sizes don't match dimensions.
#[inline] pub fn matvec_simd(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) { debug_assert_eq!(a.len(), m * n, "Matrix A size mismatch"); debug_assert_eq!(x.len(), n, "Vector x size mismatch"); debug_assert_eq!(y.len(), m, "Vector y size mismatch"); // For small matrices, use scalar implementation if n < 16 { matvec_scalar(a, x, y, m, n); return; } // Process each row for i in 0..m { let row_start = i * n; let row = &a[row_start..row_start + n]; y[i] = dot_product_simd(row, x); } } /// Transpose a matrix: B = A^T /// /// # Arguments /// /// * `a` - Input matrix (m x n), row-major /// * `b` - Output matrix (n x m), row-major /// * `m` - Number of rows in A /// * `n` - Number of columns in A #[inline] pub fn transpose_simd(a: &[f32], b: &mut [f32], m: usize, n: usize) { debug_assert_eq!(a.len(), m * n); debug_assert_eq!(b.len(), m * n); // For small matrices, use scalar transpose if m < 8 || n < 8 { transpose_scalar(a, b, m, n); return; } // Block-based transpose for cache efficiency let block = 8; for ii in (0..m).step_by(block) { for jj in (0..n).step_by(block) { // Process block let i_end = (ii + block).min(m); let j_end = (jj + block).min(n); for i in ii..i_end { for j in jj..j_end { b[j * m + i] = a[i * n + j]; } } } } } /// Compute outer product: C = a * b^T /// /// # Arguments /// /// * `a` - Column vector (length m) /// * `b` - Row vector (length n) /// * `c` - Output matrix (m x n), row-major #[inline] pub fn outer_product_simd(a: &[f32], b: &[f32], c: &mut [f32]) { let m = a.len(); let n = b.len(); debug_assert_eq!(c.len(), m * n); if n < 16 { // Scalar fallback for i in 0..m { for j in 0..n { c[i * n + j] = a[i] * b[j]; } } return; } // SIMD version: each row of C is a[i] * b for i in 0..m { let scalar = a[i]; let scalar_vec = f32x8::splat(scalar); let row_start = i * n; let chunks_b = b.chunks_exact(8); let chunks_c = c[row_start..row_start + n].chunks_exact_mut(8); let remainder_b = chunks_b.remainder(); let offset = n - remainder_b.len(); for 
(cb, cc) in chunks_b.zip(chunks_c) { let vb = load_f32x8(cb); let result = vb * scalar_vec; store_f32x8(cc, result); } // Handle remainder for (j, &bj) in remainder_b.iter().enumerate() { c[row_start + offset + j] = scalar * bj; } } } /// Add two matrices element-wise: C = A + B #[inline] pub fn matadd_simd(a: &[f32], b: &[f32], c: &mut [f32]) { debug_assert_eq!(a.len(), b.len()); debug_assert_eq!(a.len(), c.len()); let n = a.len(); if n < 16 { for i in 0..n { c[i] = a[i] + b[i]; } return; } let chunks_a = a.chunks_exact(8); let chunks_b = b.chunks_exact(8); let chunks_c = c.chunks_exact_mut(8); let remainder_a = chunks_a.remainder(); let remainder_b = chunks_b.remainder(); let offset = n - remainder_a.len(); for ((ca, cb), cc) in chunks_a.zip(chunks_b).zip(chunks_c) { let va = load_f32x8(ca); let vb = load_f32x8(cb); let result = va + vb; store_f32x8(cc, result); } for (i, (&va, &vb)) in remainder_a.iter().zip(remainder_b.iter()).enumerate() { c[offset + i] = va + vb; } } /// Scale a matrix by a scalar: B = alpha * A #[inline] pub fn matscale_simd(a: &[f32], alpha: f32, b: &mut [f32]) { debug_assert_eq!(a.len(), b.len()); let n = a.len(); if n < 16 { for i in 0..n { b[i] = alpha * a[i]; } return; } let alpha_vec = f32x8::splat(alpha); let chunks_a = a.chunks_exact(8); let chunks_b = b.chunks_exact_mut(8); let remainder_a = chunks_a.remainder(); let offset = n - remainder_a.len(); for (ca, cb) in chunks_a.zip(chunks_b) { let va = load_f32x8(ca); let result = va * alpha_vec; store_f32x8(cb, result); } for (i, &va) in remainder_a.iter().enumerate() { b[offset + i] = alpha * va; } } // ============================================================================ // Internal Implementations // ============================================================================ /// Blocked matrix multiplication for cache efficiency. 
fn matmul_blocked(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) { // Use smaller block size for k dimension to keep data in L1 cache let bk = BLOCK_SIZE.min(k); let bn = BLOCK_SIZE.min(n); for kk in (0..k).step_by(bk) { let k_end = (kk + bk).min(k); for jj in (0..n).step_by(bn) { let j_end = (jj + bn).min(n); for i in 0..m { let c_row = i * n; let a_row = i * k; // Process this block of the output row for kc in kk..k_end { let a_val = a[a_row + kc]; let a_vec = f32x8::splat(a_val); let b_row = kc * n; // SIMD inner loop let mut j = jj; while j + 8 <= j_end { let b_chunk = &b[b_row + j..b_row + j + 8]; let c_chunk = &mut c[c_row + j..c_row + j + 8]; let vb = load_f32x8(b_chunk); let vc = load_f32x8(c_chunk); let result = a_vec.mul_add(vb, vc); store_f32x8(c_chunk, result); j += 8; } // Scalar cleanup while j < j_end { c[c_row + j] += a_val * b[b_row + j]; j += 1; } } } } } } /// Simple scalar matrix multiplication for small matrices. fn matmul_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) { for i in 0..m { for j in 0..n { let mut sum = 0.0f32; for kc in 0..k { sum += a[i * k + kc] * b[kc * n + j]; } c[i * n + j] = sum; } } } /// Scalar matrix-vector multiplication. fn matvec_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) { for i in 0..m { let mut sum = 0.0f32; let row_start = i * n; for j in 0..n { sum += a[row_start + j] * x[j]; } y[i] = sum; } } /// Scalar matrix transpose. fn transpose_scalar(a: &[f32], b: &mut [f32], m: usize, n: usize) { for i in 0..m { for j in 0..n { b[j * m + i] = a[i * n + j]; } } } /// SIMD dot product (copied from vectors module to avoid circular dep). 
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 { let n = a.len(); if n < 16 { let mut sum = 0.0f32; for i in 0..n { sum += a[i] * b[i]; } return sum; } let chunks_a = a.chunks_exact(8); let chunks_b = b.chunks_exact(8); let remainder_a = chunks_a.remainder(); let remainder_b = chunks_b.remainder(); let mut acc = f32x8::ZERO; for (ca, cb) in chunks_a.zip(chunks_b) { let va = load_f32x8(ca); let vb = load_f32x8(cb); acc = va.mul_add(vb, acc); } let mut sum = acc.reduce_add(); for (&va, &vb) in remainder_a.iter().zip(remainder_b.iter()) { sum += va * vb; } sum } // ============================================================================ // Helper Functions // ============================================================================ #[inline(always)] fn load_f32x8(slice: &[f32]) -> f32x8 { debug_assert!(slice.len() >= 8); // Use try_into for direct memory copy instead of element-by-element let arr: [f32; 8] = slice[..8].try_into().unwrap(); f32x8::from(arr) } #[inline(always)] fn store_f32x8(slice: &mut [f32], v: f32x8) { debug_assert!(slice.len() >= 8); let arr: [f32; 8] = v.into(); slice[..8].copy_from_slice(&arr); } #[cfg(test)] mod tests { use super::*; const EPSILON: f32 = 1e-3; fn approx_eq(a: f32, b: f32) -> bool { // Use relative error for larger values let max_abs = a.abs().max(b.abs()); if max_abs > 1.0 { (a - b).abs() / max_abs < EPSILON } else { (a - b).abs() < EPSILON } } fn matrices_approx_eq(a: &[f32], b: &[f32]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(&x, &y)| approx_eq(x, y)) } #[test] fn test_matmul_small() { // 2x3 * 3x2 = 2x2 let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2 let mut c = [0.0f32; 4]; // 2x2 matmul_simd(&a, &b, &mut c, 2, 3, 2); // Row 0: [1,2,3] * [1,3,5; 2,4,6] = [1*1+2*3+3*5, 1*2+2*4+3*6] = [22, 28] // Row 1: [4,5,6] * [1,3,5; 2,4,6] = [4*1+5*3+6*5, 4*2+5*4+6*6] = [49, 64] let expected = [22.0, 28.0, 49.0, 64.0]; assert!(matrices_approx_eq(&c, &expected), 
"got {:?}", c); } #[test] fn test_matmul_identity() { // I * A = A let n = 64; let mut identity = vec![0.0f32; n * n]; for i in 0..n { identity[i * n + i] = 1.0; } let a: Vec = (0..n * n).map(|i| i as f32).collect(); let mut c = vec![0.0f32; n * n]; matmul_simd(&identity, &a, &mut c, n, n, n); assert!(matrices_approx_eq(&c, &a)); } #[test] fn test_matvec_small() { // 2x3 matrix * 3-vector let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 let x = [1.0, 2.0, 3.0]; // 3 let mut y = [0.0f32; 2]; // 2 matvec_simd(&a, &x, &mut y, 2, 3); // y[0] = 1*1 + 2*2 + 3*3 = 14 // y[1] = 4*1 + 5*2 + 6*3 = 32 let expected = [14.0, 32.0]; assert!(matrices_approx_eq(&y, &expected), "got {:?}", y); } #[test] fn test_matvec_large() { let m = 64; let n = 128; let a: Vec = (0..m * n).map(|i| (i as f32) * 0.01).collect(); let x: Vec = (0..n).map(|i| i as f32).collect(); let mut y_simd = vec![0.0f32; m]; let mut y_scalar = vec![0.0f32; m]; matvec_simd(&a, &x, &mut y_simd, m, n); matvec_scalar(&a, &x, &mut y_scalar, m, n); assert!(matrices_approx_eq(&y_simd, &y_scalar)); } #[test] fn test_transpose_small() { // 2x3 -> 3x2 let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 let mut b = [0.0f32; 6]; // 3x2 transpose_simd(&a, &mut b, 2, 3); // Transposed: [[1,4], [2,5], [3,6]] let expected = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; assert_eq!(b, expected); } #[test] fn test_transpose_large() { let m = 32; let n = 64; let a: Vec = (0..m * n).map(|i| i as f32).collect(); let mut b = vec![0.0f32; m * n]; transpose_simd(&a, &mut b, m, n); // Verify transpose property for i in 0..m { for j in 0..n { assert!( approx_eq(a[i * n + j], b[j * m + i]), "mismatch at ({}, {})", i, j ); } } } #[test] fn test_outer_product() { let a = [1.0, 2.0, 3.0]; let b = [4.0, 5.0]; let mut c = [0.0f32; 6]; outer_product_simd(&a, &b, &mut c); // c[i,j] = a[i] * b[j] let expected = [4.0, 5.0, 8.0, 10.0, 12.0, 15.0]; assert!(matrices_approx_eq(&c, &expected)); } #[test] fn test_matadd() { let a = [1.0, 2.0, 3.0, 4.0]; let b = [5.0, 6.0, 
7.0, 8.0]; let mut c = [0.0f32; 4]; matadd_simd(&a, &b, &mut c); assert_eq!(c, [6.0, 8.0, 10.0, 12.0]); } #[test] fn test_matscale() { let a = [1.0, 2.0, 3.0, 4.0]; let mut b = [0.0f32; 4]; matscale_simd(&a, 2.5, &mut b); assert!(matrices_approx_eq(&b, &[2.5, 5.0, 7.5, 10.0])); } #[test] fn test_matmul_large() { // Test with sizes that exercise the blocked algorithm let m = 128; let k = 96; let n = 64; let a: Vec = (0..m * k).map(|i| (i as f32) * 0.001).collect(); let b: Vec = (0..k * n).map(|i| (i as f32) * 0.001).collect(); let mut c_simd = vec![0.0f32; m * n]; let mut c_scalar = vec![0.0f32; m * n]; matmul_simd(&a, &b, &mut c_simd, m, k, n); matmul_scalar(&a, &b, &mut c_scalar, m, k, n); // Allow slightly more tolerance for larger matrices due to accumulation for i in 0..m * n { assert!( (c_simd[i] - c_scalar[i]).abs() < 0.01, "mismatch at {}: {} vs {}", i, c_simd[i], c_scalar[i] ); } } }