#![allow(
    clippy::all,
    unused_imports,
    unused_variables,
    dead_code,
    unused_mut,
    unused_assignments,
    non_camel_case_types,
    clippy::approx_constant,
    unexpected_cfgs,
    unused_must_use,
    unused_parens
)]

//! Matrix Multiplication Benchmarks for M4 Pro
//!
//! Benchmarks for GEMV, GEMM, and batched GEMM implementations.
//!
//! ## Running Benchmarks
//!
//! Single-threaded baseline:
//! ```bash
//! cargo bench -p ruvllm --features candle --bench matmul_bench -- gemm/512
//! ```
//!
//! Parallel (with rayon):
//! ```bash
//! cargo bench -p ruvllm --features candle,parallel --bench matmul_bench -- gemm/512
//! ```
//!
//! ## Performance Targets for M4 Pro
//!
//! | Operation | Size       | Single-thread | Parallel (10 cores) |
//! |-----------|------------|---------------|---------------------|
//! | GEMV      | 4096x4096  | <500us        | <150us              |
//! | GEMM      | 1024x1024  | <2ms          | <500us              |
//! | GEMM      | 2048x2048  | <15ms         | <3ms                |
//! | Batched   | 32x128x128 | <2ms          | <500us              |
//!
//! Target speedup: 4-6x on 10-core M4 Pro for large matrices.

use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use rand::Rng;

const NEON_LANE_WIDTH: usize = 4;
const UNROLL_FACTOR: usize = 4;
const TILE_M: usize = 64;
const TILE_N: usize = 64;
const TILE_K: usize = 64;
const MR: usize = 4;

/// General Matrix-Vector multiplication with NEON
#[inline(always)]
fn gemv_neon(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
    debug_assert_eq!(a.len(), m * n);
    debug_assert_eq!(x.len(), n);
    debug_assert_eq!(y.len(), m);

    #[cfg(target_arch = "aarch64")]
    unsafe {
        gemv_neon_impl(a, x, y, m, n);
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        gemv_scalar(a, x, y, m, n);
    }
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemv_neon_impl(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
    use std::arch::aarch64::*;

    let a_ptr = a.as_ptr();
    let x_ptr = x.as_ptr();
    let y_ptr = y.as_mut_ptr();

    // Process MR (4) rows at a time, each row with its own 4-lane accumulator.
    let row_chunks = m / MR;
    for rc in 0..row_chunks {
        let row_base = rc * MR;
        let mut sum0 = vdupq_n_f32(0.0);
        let mut sum1 = vdupq_n_f32(0.0);
        let mut sum2 = vdupq_n_f32(0.0);
        let mut sum3 = vdupq_n_f32(0.0);

        let col_chunks = n / NEON_LANE_WIDTH;
        let mut col = 0usize;
        for _ in 0..col_chunks {
            let x_v = vld1q_f32(x_ptr.add(col));
            let a0 = vld1q_f32(a_ptr.add((row_base + 0) * n + col));
            sum0 = vfmaq_f32(sum0, a0, x_v);
            let a1 = vld1q_f32(a_ptr.add((row_base + 1) * n + col));
            sum1 = vfmaq_f32(sum1, a1, x_v);
            let a2 = vld1q_f32(a_ptr.add((row_base + 2) * n + col));
            sum2 = vfmaq_f32(sum2, a2, x_v);
            let a3 = vld1q_f32(a_ptr.add((row_base + 3) * n + col));
            sum3 = vfmaq_f32(sum3, a3, x_v);
            col += 4;
        }

        let mut y0 = vaddvq_f32(sum0);
        let mut y1 = vaddvq_f32(sum1);
        let mut y2 = vaddvq_f32(sum2);
        let mut y3 = vaddvq_f32(sum3);

        // Scalar tail for columns not divisible by the lane width.
        for c in col..n {
            let x_val = *x_ptr.add(c);
            y0 += *a_ptr.add((row_base + 0) * n + c) * x_val;
            y1 += *a_ptr.add((row_base + 1) * n + c) * x_val;
            y2 += *a_ptr.add((row_base + 2) * n + c) * x_val;
            y3 += *a_ptr.add((row_base + 3) * n + c) * x_val;
        }

        *y_ptr.add(row_base + 0) = y0;
        *y_ptr.add(row_base + 1) = y1;
        *y_ptr.add(row_base + 2) = y2;
        *y_ptr.add(row_base + 3) = y3;
    }

    // Remaining rows (fewer than MR).
    for row in (row_chunks * MR)..m {
        let mut sum = vdupq_n_f32(0.0);
        let col_chunks = n / NEON_LANE_WIDTH;
        let mut col = 0usize;
        for _ in 0..col_chunks {
            let x_v = vld1q_f32(x_ptr.add(col));
            let a_v = vld1q_f32(a_ptr.add(row * n + col));
            sum = vfmaq_f32(sum, a_v, x_v);
            col += 4;
        }
        let mut y_val = vaddvq_f32(sum);
        for c in col..n {
            y_val += *a_ptr.add(row * n + c) * *x_ptr.add(c);
        }
        *y_ptr.add(row) = y_val;
    }
}
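// The benchmark groups below report throughput as flop counts: 2*m*n for GEMV and
// 2*m*k*n for GEMM (one multiply plus one add per accumulation). As a rough guide to
// the targets table in the module docs, here is a small illustrative helper
// (hypothetical, not used by any benchmark) for converting a measured wall time
// into GFLOP/s:
#[allow(dead_code)]
fn gflops(flops: usize, seconds: f64) -> f64 {
    flops as f64 / seconds / 1e9
}
// Example: the 4096x4096 GEMV target of <500us corresponds to roughly
// gflops(2 * 4096 * 4096, 500e-6) ≈ 67 GFLOP/s sustained.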
#[allow(dead_code)]
fn gemv_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
    for row in 0..m {
        let mut sum = 0.0f32;
        for col in 0..n {
            sum += a[row * n + col] * x[col];
        }
        y[row] = sum;
    }
}

/// General Matrix-Matrix multiplication with NEON
#[inline(always)]
fn gemm_neon(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    debug_assert_eq!(a.len(), m * k);
    debug_assert_eq!(b.len(), k * n);
    debug_assert_eq!(c.len(), m * n);

    c.fill(0.0);

    #[cfg(target_arch = "aarch64")]
    unsafe {
        gemm_neon_impl(a, b, c, m, k, n);
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        gemm_scalar(a, b, c, m, k, n);
    }
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_neon_impl(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use std::arch::aarch64::*;

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let c_ptr = c.as_mut_ptr();

    let mut i = 0usize;
    while i < m {
        let i_end = (i + TILE_M).min(m);
        let mut j = 0usize;
        while j < n {
            let j_end = (j + TILE_N).min(n);
            let mut kk = 0usize;
            while kk < k {
                let kk_end = (kk + TILE_K).min(k);
                for ii in i..i_end {
                    for jj in (j..j_end).step_by(NEON_LANE_WIDTH) {
                        let j_remaining = (j_end - jj).min(NEON_LANE_WIDTH);
                        if j_remaining == NEON_LANE_WIDTH {
                            let mut acc = vld1q_f32(c_ptr.add(ii * n + jj));
                            for kkk in kk..kk_end {
                                let a_val = vdupq_n_f32(*a_ptr.add(ii * k + kkk));
                                let b_v = vld1q_f32(b_ptr.add(kkk * n + jj));
                                acc = vfmaq_f32(acc, a_val, b_v);
                            }
                            vst1q_f32(c_ptr.add(ii * n + jj), acc);
                        } else {
                            for jjj in jj..j_end {
                                let mut sum = *c_ptr.add(ii * n + jjj);
                                for kkk in kk..kk_end {
                                    sum += *a_ptr.add(ii * k + kkk) * *b_ptr.add(kkk * n + jjj);
                                }
                                *c_ptr.add(ii * n + jjj) = sum;
                            }
                        }
                    }
                }
                kk = kk_end;
            }
            j = j_end;
        }
        i = i_end;
    }
}

#[allow(dead_code)]
fn gemm_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for kk in 0..k {
                sum += a[i * k + kk] * b[kk * n + j];
            }
            c[i * n + j] = sum;
        }
    }
}

/// Batched GEMM for attention computation
#[inline(always)]
fn batched_gemm_neon(
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    batch_size: usize,
    m: usize,
    k: usize,
    n: usize,
) {
    debug_assert_eq!(a.len(), batch_size * m * k);
    debug_assert_eq!(b.len(), batch_size * k * n);
    debug_assert_eq!(c.len(), batch_size * m * n);

    let a_batch_stride = m * k;
    let b_batch_stride = k * n;
    let c_batch_stride = m * n;

    for batch in 0..batch_size {
        let a_offset = batch * a_batch_stride;
        let b_offset = batch * b_batch_stride;
        let c_offset = batch * c_batch_stride;
        gemm_neon(
            &a[a_offset..a_offset + a_batch_stride],
            &b[b_offset..b_offset + b_batch_stride],
            &mut c[c_offset..c_offset + c_batch_stride],
            m,
            k,
            n,
        );
    }
}

/// GEMM with transposed B matrix (for Q * K^T)
fn gemm_nt_neon(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    debug_assert_eq!(a.len(), m * k);
    debug_assert_eq!(b_t.len(), n * k);
    debug_assert_eq!(c.len(), m * n);

    c.fill(0.0);

    #[cfg(target_arch = "aarch64")]
    unsafe {
        gemm_nt_neon_impl(a, b_t, c, m, k, n);
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        gemm_nt_scalar(a, b_t, c, m, k, n);
    }
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_nt_neon_impl(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use std::arch::aarch64::*;

    let a_ptr = a.as_ptr();
    let b_ptr = b_t.as_ptr();
    let c_ptr = c.as_mut_ptr();

    for i in 0..m {
        let n_chunks = n / NEON_LANE_WIDTH;
        for nc in 0..n_chunks {
            let j_base = nc * NEON_LANE_WIDTH;
            let mut acc0 = 0.0f32;
            let mut acc1 = 0.0f32;
            let mut acc2 = 0.0f32;
            let mut acc3 = 0.0f32;

            let k_chunks = k / NEON_LANE_WIDTH;
            let mut kk = 0usize;
            for _ in 0..k_chunks {
                let a_v = vld1q_f32(a_ptr.add(i * k + kk));
                let b0 = vld1q_f32(b_ptr.add((j_base + 0) * k + kk));
                let b1 = vld1q_f32(b_ptr.add((j_base + 1) * k + kk));
                let b2 = vld1q_f32(b_ptr.add((j_base + 2) * k + kk));
                let b3 = vld1q_f32(b_ptr.add((j_base + 3) * k + kk));
                acc0 += vaddvq_f32(vmulq_f32(a_v, b0));
                acc1 += vaddvq_f32(vmulq_f32(a_v, b1));
                acc2 += vaddvq_f32(vmulq_f32(a_v, b2));
                acc3 += vaddvq_f32(vmulq_f32(a_v, b3));
                kk += 4;
            }
            for kkk in kk..k {
                let a_val = *a_ptr.add(i * k + kkk);
                acc0 += a_val * *b_ptr.add((j_base + 0) * k + kkk);
                acc1 += a_val * *b_ptr.add((j_base + 1) * k + kkk);
                acc2 += a_val * *b_ptr.add((j_base + 2) * k + kkk);
                acc3 += a_val * *b_ptr.add((j_base + 3) * k + kkk);
            }
            *c_ptr.add(i * n + j_base + 0) = acc0;
            *c_ptr.add(i * n + j_base + 1) = acc1;
            *c_ptr.add(i * n + j_base + 2) = acc2;
            *c_ptr.add(i * n + j_base + 3) = acc3;
        }

        for j in (n_chunks * NEON_LANE_WIDTH)..n {
            let mut acc = vdupq_n_f32(0.0);
            let k_chunks = k / NEON_LANE_WIDTH;
            let mut kk = 0usize;
            for _ in 0..k_chunks {
                let a_v = vld1q_f32(a_ptr.add(i * k + kk));
                let b_v = vld1q_f32(b_ptr.add(j * k + kk));
                acc = vfmaq_f32(acc, a_v, b_v);
                kk += 4;
            }
            let mut sum = vaddvq_f32(acc);
            for kkk in kk..k {
                sum += *a_ptr.add(i * k + kkk) * *b_ptr.add(j * k + kkk);
            }
            *c_ptr.add(i * n + j) = sum;
        }
    }
}

#[allow(dead_code)]
fn gemm_nt_scalar(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for kk in 0..k {
                sum += a[i * k + kk] * b_t[j * k + kk];
            }
            c[i * n + j] = sum;
        }
    }
}

/// Dot product of two vectors
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum0 = vdupq_n_f32(0.0);
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);
    let mut sum3 = vdupq_n_f32(0.0);

    let chunks = len / (NEON_LANE_WIDTH * UNROLL_FACTOR);
    let mut idx = 0usize;
    for _ in 0..chunks {
        let a0 = vld1q_f32(a_ptr.add(idx));
        let b0 = vld1q_f32(b_ptr.add(idx));
        sum0 = vfmaq_f32(sum0, a0, b0);
        let a1 = vld1q_f32(a_ptr.add(idx + 4));
        let b1 = vld1q_f32(b_ptr.add(idx + 4));
        sum1 = vfmaq_f32(sum1, a1, b1);
        let a2 = vld1q_f32(a_ptr.add(idx + 8));
        let b2 = vld1q_f32(b_ptr.add(idx + 8));
        sum2 = vfmaq_f32(sum2, a2, b2);
        let a3 = vld1q_f32(a_ptr.add(idx + 12));
        let b3 = vld1q_f32(b_ptr.add(idx + 12));
        sum3 = vfmaq_f32(sum3, a3, b3);
        idx += 16;
    }

    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum = vaddq_f32(sum01, sum23);

    let remaining = (len - idx) / NEON_LANE_WIDTH;
    let mut final_sum = sum;
    for _ in 0..remaining {
        let a_v = vld1q_f32(a_ptr.add(idx));
        let b_v = vld1q_f32(b_ptr.add(idx));
        final_sum = vfmaq_f32(final_sum, a_v, b_v);
        idx += 4;
    }

    let mut result = vaddvq_f32(final_sum);
    for i in idx..len {
        result += *a_ptr.add(i) * *b_ptr.add(i);
    }
    result
}

// Helper function to generate random tensor data
fn random_tensor(size: usize) -> Vec<f32> {
    let mut rng = rand::thread_rng();
    (0..size).map(|_| rng.gen_range(-1.0..1.0)).collect()
}

// === Benchmark Functions ===

fn bench_gemv(c: &mut Criterion) {
    let mut group = c.benchmark_group("gemv");
    group.sample_size(50);

    for (m, n) in [
        (256, 256),
        (512, 512),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
    ] {
        let a = random_tensor(m * n);
        let x = random_tensor(n);
        let mut y = vec![0.0; m];

        let flops = 2 * m * n; // multiply + add per element
        let id = BenchmarkId::new(format!("{}x{}", m, n), m * n);
        group.throughput(Throughput::Elements(flops as u64));
        group.bench_function(id, |b| {
            b.iter(|| {
                gemv_neon(black_box(&a), black_box(&x), black_box(&mut y), m, n);
            })
        });
    }
    group.finish();
}

fn bench_gemm(c: &mut Criterion) {
    let mut group = c.benchmark_group("gemm");
    group.sample_size(30);

    for size in [128, 256, 512, 1024, 2048] {
        let m = size;
        let k = size;
        let n = size;
        let mat_a = random_tensor(m * k);
        let mat_b = random_tensor(k * n);
        let mut c_out = vec![0.0; m * n];

        let flops = 2 * m * k * n; // multiply + add per output element
        let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m * k * n);
        group.throughput(Throughput::Elements(flops as u64));
        group.bench_function(id, |bencher| {
            bencher.iter(|| {
                gemm_neon(
                    black_box(&mat_a),
                    black_box(&mat_b),
                    black_box(&mut c_out),
                    m,
                    k,
                    n,
                );
            })
        });
    }
    group.finish();
}

fn bench_gemm_non_square(c: &mut Criterion) {
    let mut group = c.benchmark_group("gemm_non_square");
    group.sample_size(30);

    // Common shapes in LLM inference
    let shapes = [
        (1, 4096, 4096),   // Single token projection
        (32, 4096, 4096),  // Batch projection
        (128, 4096, 4096), // Larger batch
        (1, 4096, 11008),  // MLP up projection (Llama2 7B)
        (1, 11008, 4096),  // MLP down projection
        (32, 128, 4096),   // Attention output
    ];

    for (m, k, n) in shapes {
        let mat_a = random_tensor(m * k);
        let mat_b = random_tensor(k * n);
        let mut c_out = vec![0.0; m * n];

        let flops = 2 * m * k * n;
        let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m);
        group.throughput(Throughput::Elements(flops as u64));
        group.bench_function(id, |bencher| {
            bencher.iter(|| {
                gemm_neon(
                    black_box(&mat_a),
                    black_box(&mat_b),
                    black_box(&mut c_out),
                    m,
                    k,
                    n,
                );
            })
        });
    }
    group.finish();
}

fn bench_batched_gemm(c: &mut Criterion) {
    let mut group = c.benchmark_group("batched_gemm");
    group.sample_size(30);

    for batch_size in [1, 8, 16, 32] {
        for (m, k, n) in [(64, 64, 64), (128, 128, 128), (256, 256, 256)] {
            let mat_a = random_tensor(batch_size * m * k);
            let mat_b = random_tensor(batch_size * k * n);
            let mut c_out = vec![0.0; batch_size * m * n];

            let flops = 2 * batch_size * m * k * n;
            let id = BenchmarkId::new(
                format!("batch_{}_{}x{}x{}", batch_size, m, k, n),
                batch_size,
            );
            group.throughput(Throughput::Elements(flops as u64));
            group.bench_function(id, |bencher| {
                bencher.iter(|| {
                    batched_gemm_neon(
                        black_box(&mat_a),
                        black_box(&mat_b),
                        black_box(&mut c_out),
                        batch_size,
                        m,
                        k,
                        n,
                    );
                })
            });
        }
    }
    group.finish();
}

fn bench_gemm_nt(c: &mut Criterion) {
    let mut group = c.benchmark_group("gemm_nt");
    group.sample_size(30);

    // Q * K^T shapes in attention
    let shapes = [
        (128, 128, 128),   // seq=128
        (256, 128, 256),   // seq=256
        (512, 128, 512),   // seq=512
        (1024, 128, 1024), // seq=1024
    ];

    for (m, k, n) in shapes {
        let a = random_tensor(m * k);
        let b_t = random_tensor(n * k); // Transposed
        let mut c_out = vec![0.0; m * n];

        let flops = 2 * m * k * n;
        let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m * n);
        group.throughput(Throughput::Elements(flops as u64));
        group.bench_function(id, |b| {
            b.iter(|| {
                gemm_nt_neon(
                    black_box(&a),
                    black_box(&b_t),
                    black_box(&mut c_out),
                    m,
                    k,
                    n,
                );
            })
        });
    }
    group.finish();
}

#[cfg(target_arch = "aarch64")]
fn bench_dot_product(c: &mut Criterion) {
    let mut group = c.benchmark_group("dot_product");
    group.sample_size(100);

    for size in [64, 128, 256, 512, 1024, 2048, 4096] {
        let a = random_tensor(size);
        let b = random_tensor(size);

        let id = BenchmarkId::new(format!("dim_{}", size), size);
        group.throughput(Throughput::Elements((2 * size) as u64)); // multiply + add
        group.bench_function(id, |b_iter| {
            b_iter.iter(|| unsafe { dot_product_neon(black_box(&a), black_box(&b)) })
        });
    }
    group.finish();
}

fn bench_tiling_efficiency(c: &mut Criterion) {
    let mut group = c.benchmark_group("tiling_efficiency");
    group.sample_size(20);

    // Test how well tiling works at various sizes
    for size in [63, 64, 65, 127, 128, 129, 255, 256, 257] {
        let mat_a = random_tensor(size * size);
        let mat_b = random_tensor(size * size);
        let mut c_out = vec![0.0; size * size];

        let flops = 2 * size * size * size;
        let id = BenchmarkId::new(format!("size_{}", size), size);
        group.throughput(Throughput::Elements(flops as u64));
        group.bench_function(id, |bencher| {
            bencher.iter(|| {
                gemm_neon(
                    black_box(&mat_a),
                    black_box(&mat_b),
                    black_box(&mut c_out),
                    size,
                    size,
                    size,
                );
            })
        });
    }
    group.finish();
}

fn bench_memory_bandwidth(c: &mut Criterion) {
    let mut group = c.benchmark_group("memory_bandwidth");
    group.sample_size(30);

    // Test memory-bound vs compute-bound behavior
    for (m, k, n) in [
        (1, 4096, 4096),   // Very memory bound (GEMV-like)
        (32, 4096, 4096),  // More compute
        (128, 4096, 4096), // Compute bound
    ] {
        let mat_a = random_tensor(m * k);
        let mat_b = random_tensor(k * n);
        let mut c_out = vec![0.0; m * n];

        // Memory: A (m*k*4), B (k*n*4), C (m*n*4)
        let memory_bytes = ((m * k) + (k * n) + (m * n)) * 4;
        let flops = 2 * m * k * n;
        // e.g. 1x4096x4096: ~33.5 Mflop over ~67 MB touched, a ratio of ~0.5 flops/byte.
        let id = BenchmarkId::new(
            format!(
                "{}x{}x{}_ratio_{:.2}",
                m,
                k,
                n,
                flops as f64 / memory_bytes as f64
            ),
            m,
        );
        group.throughput(Throughput::Bytes(memory_bytes as u64));
        group.bench_function(id, |bencher| {
            bencher.iter(|| {
                gemm_neon(
                    black_box(&mat_a),
                    black_box(&mat_b),
                    black_box(&mut c_out),
                    m,
                    k,
                    n,
                );
            })
        });
    }
    group.finish();
}

fn bench_llm_projection_sizes(c: &mut Criterion) {
    let mut group = c.benchmark_group("llm_projections");
    group.sample_size(20);

    // Real LLM projection sizes (single token)
    let configs = [
        ("llama2_7b_qkv", 1, 4096, 4096),
        ("llama2_7b_mlp_up", 1, 4096, 11008),
        ("llama2_7b_mlp_down", 1, 11008, 4096),
        ("llama2_13b_qkv", 1, 5120, 5120),
        ("llama2_70b_qkv", 1, 8192, 8192),
        ("mistral_7b_qkv", 1, 4096, 4096),
    ];

    for (name, m, k, n) in configs {
        let mat_a = random_tensor(m * k);
        let mat_b = random_tensor(k * n);
        let mut c_out = vec![0.0; m * n];

        let flops = 2 * m * k * n;
        let id = BenchmarkId::new(name, flops);
        group.throughput(Throughput::Elements(flops as u64));
        group.bench_function(id, |bencher| {
            bencher.iter(|| {
                gemm_neon(
                    black_box(&mat_a),
                    black_box(&mat_b),
                    black_box(&mut c_out),
                    m,
                    k,
                    n,
                );
            })
        });
    }
    group.finish();
}

// ============================================================================
// Parallel benchmarks (enabled with `parallel` feature)
// ============================================================================

#[cfg(feature = "parallel")]
mod parallel_benches {
    use super::*;

    /// Get the available core count (logical cores; matches physical cores on Apple Silicon)
    fn get_physical_cores() -> usize {
        std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4)
    }

    /// Configure thread pool once at start
    fn init_thread_pool() {
        use std::sync::Once;
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            rayon::ThreadPoolBuilder::new()
                .num_threads(get_physical_cores())
                .thread_name(|i| format!("bench-gemm-{}", i))
                .build_global()
                .ok();
        });
    }

    // ========================================================================
    // Parallel GEMM implementations (mirror the single-threaded versions)
    // ========================================================================

    fn gemm_parallel(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
        use rayon::prelude::*;

        const MIN_ROWS_PER_THREAD: usize = 32;
        const PARALLEL_THRESHOLD: usize = 128;

        // Small problems are not worth the thread dispatch overhead.
        if m < PARALLEL_THRESHOLD || (m * k * n) < 1_000_000 {
            return gemm_neon(a, b, c, m, k, n);
        }

        c.fill(0.0);

        let num_threads = get_physical_cores();
        let chunk_size = (m / num_threads).max(MIN_ROWS_PER_THREAD);

        // Split C (and the matching rows of A) into row blocks, one rayon task per block.
        c.par_chunks_mut(chunk_size * n)
            .enumerate()
            .for_each(|(chunk_idx, c_chunk)| {
                let row_start = chunk_idx * chunk_size;
                let actual_rows = c_chunk.len() / n;
                let row_end = row_start + actual_rows;
                let a_start = row_start * k;
                let a_end = row_end * k;
                let a_chunk = &a[a_start..a_end];
                gemm_chunk(a_chunk, b, c_chunk, actual_rows, k, n);
            });
    }

    fn gemm_chunk(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
        #[cfg(target_arch = "aarch64")]
        unsafe {
            gemm_chunk_neon(a, b, c, m, k, n);
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            for i in 0..m {
                for j in 0..n {
                    let mut sum = 0.0f32;
                    for kk in 0..k {
                        sum += a[i * k + kk] * b[kk * n + j];
                    }
                    c[i * n + j] = sum;
                }
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe fn gemm_chunk_neon(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
        use std::arch::aarch64::*;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let c_ptr = c.as_mut_ptr();

        // 4x8 register-blocked micro-kernel: 4 rows of A against 8 columns of B.
        let mut i = 0usize;
        while i + 4 <= m {
            let mut j = 0usize;
            while j + 8 <= n {
                let mut c00 = vdupq_n_f32(0.0);
                let mut c01 = vdupq_n_f32(0.0);
                let mut c10 = vdupq_n_f32(0.0);
                let mut c11 = vdupq_n_f32(0.0);
                let mut c20 = vdupq_n_f32(0.0);
                let mut c21 = vdupq_n_f32(0.0);
                let mut c30 = vdupq_n_f32(0.0);
                let mut c31 = vdupq_n_f32(0.0);

                for kk in 0..k {
                    let b0 = vld1q_f32(b_ptr.add(kk * n + j));
                    let b1 = vld1q_f32(b_ptr.add(kk * n + j + 4));
                    let a0 = vdupq_n_f32(*a_ptr.add(i * k + kk));
                    let a1 = vdupq_n_f32(*a_ptr.add((i + 1) * k + kk));
                    let a2 = vdupq_n_f32(*a_ptr.add((i + 2) * k + kk));
                    let a3 = vdupq_n_f32(*a_ptr.add((i + 3) * k + kk));
                    c00 = vfmaq_f32(c00, a0, b0);
                    c01 = vfmaq_f32(c01, a0, b1);
                    c10 = vfmaq_f32(c10, a1, b0);
                    c11 = vfmaq_f32(c11, a1, b1);
                    c20 = vfmaq_f32(c20, a2, b0);
                    c21 = vfmaq_f32(c21, a2, b1);
                    c30 = vfmaq_f32(c30, a3, b0);
                    c31 = vfmaq_f32(c31, a3, b1);
                }

                vst1q_f32(c_ptr.add(i * n + j), c00);
                vst1q_f32(c_ptr.add(i * n + j + 4), c01);
                vst1q_f32(c_ptr.add((i + 1) * n + j), c10);
                vst1q_f32(c_ptr.add((i + 1) * n + j + 4), c11);
                vst1q_f32(c_ptr.add((i + 2) * n + j), c20);
                vst1q_f32(c_ptr.add((i + 2) * n + j + 4), c21);
                vst1q_f32(c_ptr.add((i + 3) * n + j), c30);
                vst1q_f32(c_ptr.add((i + 3) * n + j + 4), c31);
                j += 8;
            }

            // 4-wide column tail.
            while j + 4 <= n {
                let mut c0 = vdupq_n_f32(0.0);
                let mut c1 = vdupq_n_f32(0.0);
                let mut c2 = vdupq_n_f32(0.0);
                let mut c3 = vdupq_n_f32(0.0);
                for kk in 0..k {
                    let b_v = vld1q_f32(b_ptr.add(kk * n + j));
                    c0 = vfmaq_f32(c0, vdupq_n_f32(*a_ptr.add(i * k + kk)), b_v);
                    c1 = vfmaq_f32(c1, vdupq_n_f32(*a_ptr.add((i + 1) * k + kk)), b_v);
                    c2 = vfmaq_f32(c2, vdupq_n_f32(*a_ptr.add((i + 2) * k + kk)), b_v);
                    c3 = vfmaq_f32(c3, vdupq_n_f32(*a_ptr.add((i + 3) * k + kk)), b_v);
                }
                vst1q_f32(c_ptr.add(i * n + j), c0);
                vst1q_f32(c_ptr.add((i + 1) * n + j), c1);
                vst1q_f32(c_ptr.add((i + 2) * n + j), c2);
                vst1q_f32(c_ptr.add((i + 3) * n + j), c3);
                j += 4;
            }

            // Scalar column tail.
            while j < n {
                for row in i..i + 4 {
                    let mut sum = 0.0f32;
                    for kk in 0..k {
                        sum += *a_ptr.add(row * k + kk) * *b_ptr.add(kk * n + j);
                    }
                    *c_ptr.add(row * n + j) = sum;
                }
                j += 1;
            }
            i += 4;
        }

        // Remaining rows (fewer than 4).
        while i < m {
            let mut j = 0usize;
            while j + 4 <= n {
                let mut acc = vdupq_n_f32(0.0);
                for kk in 0..k {
                    let a_val = vdupq_n_f32(*a_ptr.add(i * k + kk));
                    let b_v = vld1q_f32(b_ptr.add(kk * n + j));
                    acc = vfmaq_f32(acc, a_val, b_v);
                }
                vst1q_f32(c_ptr.add(i * n + j), acc);
                j += 4;
            }
            while j < n {
                let mut sum = 0.0f32;
                for kk in 0..k {
                    sum += *a_ptr.add(i * k + kk) * *b_ptr.add(kk * n + j);
                }
                *c_ptr.add(i * n + j) = sum;
                j += 1;
            }
            i += 1;
        }
    }

    fn gemv_parallel(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
        use rayon::prelude::*;

        const MIN_ROWS_PER_THREAD: usize = 32;
        const PARALLEL_THRESHOLD: usize = 256;

        if m < PARALLEL_THRESHOLD {
            return gemv_neon(a, x, y, m, n);
        }

        let num_threads = get_physical_cores();
        let chunk_size = (m / num_threads).max(MIN_ROWS_PER_THREAD);

        y.par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(chunk_idx, y_chunk)| {
                let row_start = chunk_idx * chunk_size;
                let row_end = (row_start + y_chunk.len()).min(m);
                let chunk_rows = row_end - row_start;
                let a_start = row_start * n;
                let a_end = row_end * n;
                let a_chunk = &a[a_start..a_end];
                gemv_neon(a_chunk, x, y_chunk, chunk_rows, n);
            });
    }

    fn batched_gemm_parallel(
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        batch_size: usize,
        m: usize,
        k: usize,
        n: usize,
    ) {
        use rayon::prelude::*;

        const PARALLEL_THRESHOLD: usize = 128;

        let a_batch_stride = m * k;
        let b_batch_stride = k * n;
        let c_batch_stride = m * n;

        // Few large batches: parallelize inside each GEMM.
        // Many batches: give each batch its own rayon task.
        if batch_size <= 4 && m >= PARALLEL_THRESHOLD {
            for batch in 0..batch_size {
                let a_offset = batch * a_batch_stride;
                let b_offset = batch * b_batch_stride;
                let c_offset = batch * c_batch_stride;
                gemm_parallel(
                    &a[a_offset..a_offset + a_batch_stride],
                    &b[b_offset..b_offset + b_batch_stride],
                    &mut c[c_offset..c_offset + c_batch_stride],
                    m,
                    k,
                    n,
                );
            }
        } else {
            c.par_chunks_mut(c_batch_stride)
                .enumerate()
                .for_each(|(batch, c_batch)| {
                    let a_offset = batch * a_batch_stride;
                    let b_offset = batch * b_batch_stride;
                    gemm_neon(
                        &a[a_offset..a_offset + a_batch_stride],
                        &b[b_offset..b_offset + b_batch_stride],
                        c_batch,
                        m,
                        k,
                        n,
                    );
                });
        }
    }

    // ========================================================================
    // Benchmark functions
    // ========================================================================

    pub fn bench_gemm_parallel(c: &mut Criterion) {
        init_thread_pool();
        let mut group = c.benchmark_group("gemm_parallel");
        group.sample_size(30);

        for size in [256, 512, 1024, 2048] {
            let m = size;
            let k = size;
            let n = size;
            let mat_a = random_tensor(m * k);
            let mat_b = random_tensor(k * n);
            let mut c_out = vec![0.0; m * n];

            let flops = 2 * m * k * n;
            let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m * k * n);
            group.throughput(Throughput::Elements(flops as u64));
            group.bench_function(id, |bencher| {
                bencher.iter(|| {
                    gemm_parallel(
                        black_box(&mat_a),
                        black_box(&mat_b),
                        black_box(&mut c_out),
                        m,
                        k,
                        n,
                    );
                })
            });
        }
        group.finish();
    }

    pub fn bench_gemv_parallel(c: &mut Criterion) {
        init_thread_pool();
        let mut group = c.benchmark_group("gemv_parallel");
        group.sample_size(50);

        for (m, n) in [(512, 512), (1024, 1024), (2048, 2048), (4096, 4096)] {
            let a = random_tensor(m * n);
            let x = random_tensor(n);
            let mut y = vec![0.0; m];

            let flops = 2 * m * n;
            let id = BenchmarkId::new(format!("{}x{}", m, n), m * n);
            group.throughput(Throughput::Elements(flops as u64));
            group.bench_function(id, |b| {
                b.iter(|| {
                    gemv_parallel(black_box(&a), black_box(&x), black_box(&mut y), m, n);
                })
            });
        }
        group.finish();
    }

    pub fn bench_batched_gemm_parallel(c: &mut Criterion) {
        init_thread_pool();
        let mut group = c.benchmark_group("batched_gemm_parallel");
        group.sample_size(30);

        for batch_size in [8, 16, 32] {
            for (m, k, n) in [(128, 128, 128), (256, 256, 256)] {
                let mat_a = random_tensor(batch_size * m * k);
                let mat_b = random_tensor(batch_size * k * n);
                let mut c_out = vec![0.0; batch_size * m * n];

                let flops = 2 * batch_size * m * k * n;
                let id = BenchmarkId::new(
format!("batch_{}_{}x{}x{}", batch_size, m, k, n), batch_size, ); group.throughput(Throughput::Elements(flops as u64)); group.bench_function(id, |bencher| { bencher.iter(|| { batched_gemm_parallel( black_box(&mat_a), black_box(&mat_b), black_box(&mut c_out), batch_size, m, k, n, ); }) }); } } group.finish(); } /// Compare single-threaded vs parallel for large matrices pub fn bench_parallel_speedup(c: &mut Criterion) { init_thread_pool(); let mut group = c.benchmark_group("parallel_speedup"); group.sample_size(20); let size = 512; let m = size; let k = size; let n = size; let mat_a = random_tensor(m * k); let mat_b = random_tensor(k * n); let mut c_out = vec![0.0; m * n]; let flops = 2 * m * k * n; group.throughput(Throughput::Elements(flops as u64)); group.bench_function("single_thread", |bencher| { bencher.iter(|| { gemm_neon( black_box(&mat_a), black_box(&mat_b), black_box(&mut c_out), m, k, n, ); }) }); group.bench_function("parallel", |bencher| { bencher.iter(|| { gemm_parallel( black_box(&mat_a), black_box(&mat_b), black_box(&mut c_out), m, k, n, ); }) }); group.finish(); } } #[cfg(feature = "parallel")] use parallel_benches::*; #[cfg(all(target_arch = "aarch64", not(feature = "parallel")))] criterion_group!( benches, bench_gemv, bench_gemm, bench_gemm_non_square, bench_batched_gemm, bench_gemm_nt, bench_dot_product, bench_tiling_efficiency, bench_memory_bandwidth, bench_llm_projection_sizes, ); #[cfg(all(target_arch = "aarch64", feature = "parallel"))] criterion_group!( benches, bench_gemv, bench_gemm, bench_gemm_non_square, bench_batched_gemm, bench_gemm_nt, bench_dot_product, bench_tiling_efficiency, bench_memory_bandwidth, bench_llm_projection_sizes, bench_gemm_parallel, bench_gemv_parallel, bench_batched_gemm_parallel, bench_parallel_speedup, ); #[cfg(all(not(target_arch = "aarch64"), not(feature = "parallel")))] criterion_group!( benches, bench_gemv, bench_gemm, bench_gemm_non_square, bench_batched_gemm, bench_gemm_nt, bench_tiling_efficiency, bench_memory_bandwidth, bench_llm_projection_sizes, ); #[cfg(all(not(target_arch = "aarch64"), feature = "parallel"))] criterion_group!( benches, bench_gemv, bench_gemm, bench_gemm_non_square, bench_batched_gemm, bench_gemm_nt, bench_tiling_efficiency, bench_memory_bandwidth, bench_llm_projection_sizes, bench_gemm_parallel, bench_gemv_parallel, bench_batched_gemm_parallel, bench_parallel_speedup, ); criterion_main!(benches);