#![allow(
clippy::all,
unused_imports,
unused_variables,
dead_code,
unused_mut,
unused_assignments,
non_camel_case_types,
clippy::approx_constant,
unexpected_cfgs,
unused_must_use,
unused_parens
)]
//! Matrix Multiplication Benchmarks for M4 Pro
//!
//! Benchmarks for GEMV, GEMM (square, non-square, and transposed-B), batched
//! GEMM, and dot-product kernels, with optional rayon-parallel variants.
//!
//! ## Running Benchmarks
//!
//! Single-threaded baseline:
//! ```bash
//! cargo bench -p ruvllm --features candle --bench matmul_bench -- gemm/512
//! ```
//!
//! Parallel (with rayon):
//! ```bash
//! cargo bench -p ruvllm --features candle,parallel --bench matmul_bench -- gemm/512
//! ```
//!
//! ## Performance Targets for M4 Pro
//!
//! | Operation | Size | Single-thread | Parallel (10 cores) |
//! |-----------|------|---------------|---------------------|
//! | GEMV | 4096x4096 | <500us | <150us |
//! | GEMM | 1024x1024 | <2ms | <500us |
//! | GEMM | 2048x2048 | <15ms | <3ms |
//! | Batched | 32x128x128 | <2ms | <500us |
//!
//! Target speedup: 4-6x on 10-core M4 Pro for large matrices.
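//!
//! To compare the two configurations, one option (assuming criterion's standard
//! `--save-baseline` / `--baseline` flags) is to record the single-threaded run
//! as a named baseline and diff the parallel run against it:
//! ```bash
//! cargo bench -p ruvllm --features candle --bench matmul_bench -- gemm/512 --save-baseline single
//! cargo bench -p ruvllm --features candle,parallel --bench matmul_bench -- gemm/512 --baseline single
//! ```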
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use rand::Rng;
/// NEON f32 vector width: four 32-bit lanes per 128-bit register.
const NEON_LANE_WIDTH: usize = 4;
/// Number of independent NEON accumulators in the unrolled dot-product loop.
const UNROLL_FACTOR: usize = 4;
/// Cache-blocking tile sizes for the GEMM kernel (rows, columns, reduction depth).
const TILE_M: usize = 64;
const TILE_N: usize = 64;
const TILE_K: usize = 64;
/// Row-blocking factor for the GEMV kernel (rows accumulated per outer iteration).
const MR: usize = 4;
/// General Matrix-Vector multiplication (y = A * x) with NEON.
///
/// Processes `MR` rows at a time with four-lane FMA accumulation, handling
/// leftover rows and columns with scalar tails; falls back to `gemv_scalar`
/// on non-aarch64 targets.
#[inline(always)]
fn gemv_neon(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
debug_assert_eq!(a.len(), m * n);
debug_assert_eq!(x.len(), n);
debug_assert_eq!(y.len(), m);
#[cfg(target_arch = "aarch64")]
unsafe {
gemv_neon_impl(a, x, y, m, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemv_scalar(a, x, y, m, n);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemv_neon_impl(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
use std::arch::aarch64::*;
let a_ptr = a.as_ptr();
let x_ptr = x.as_ptr();
let y_ptr = y.as_mut_ptr();
let row_chunks = m / MR;
for rc in 0..row_chunks {
let row_base = rc * MR;
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let col_chunks = n / NEON_LANE_WIDTH;
let mut col = 0usize;
for _ in 0..col_chunks {
let x_v = vld1q_f32(x_ptr.add(col));
let a0 = vld1q_f32(a_ptr.add((row_base + 0) * n + col));
sum0 = vfmaq_f32(sum0, a0, x_v);
let a1 = vld1q_f32(a_ptr.add((row_base + 1) * n + col));
sum1 = vfmaq_f32(sum1, a1, x_v);
let a2 = vld1q_f32(a_ptr.add((row_base + 2) * n + col));
sum2 = vfmaq_f32(sum2, a2, x_v);
let a3 = vld1q_f32(a_ptr.add((row_base + 3) * n + col));
sum3 = vfmaq_f32(sum3, a3, x_v);
col += 4;
}
let mut y0 = vaddvq_f32(sum0);
let mut y1 = vaddvq_f32(sum1);
let mut y2 = vaddvq_f32(sum2);
let mut y3 = vaddvq_f32(sum3);
for c in col..n {
let x_val = *x_ptr.add(c);
y0 += *a_ptr.add((row_base + 0) * n + c) * x_val;
y1 += *a_ptr.add((row_base + 1) * n + c) * x_val;
y2 += *a_ptr.add((row_base + 2) * n + c) * x_val;
y3 += *a_ptr.add((row_base + 3) * n + c) * x_val;
}
*y_ptr.add(row_base + 0) = y0;
*y_ptr.add(row_base + 1) = y1;
*y_ptr.add(row_base + 2) = y2;
*y_ptr.add(row_base + 3) = y3;
}
for row in (row_chunks * MR)..m {
let mut sum = vdupq_n_f32(0.0);
let col_chunks = n / NEON_LANE_WIDTH;
let mut col = 0usize;
for _ in 0..col_chunks {
let x_v = vld1q_f32(x_ptr.add(col));
let a_v = vld1q_f32(a_ptr.add(row * n + col));
sum = vfmaq_f32(sum, a_v, x_v);
col += 4;
}
let mut y_val = vaddvq_f32(sum);
for c in col..n {
y_val += *a_ptr.add(row * n + c) * *x_ptr.add(c);
}
*y_ptr.add(row) = y_val;
}
}
#[allow(dead_code)]
fn gemv_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
for row in 0..m {
let mut sum = 0.0f32;
for col in 0..n {
sum += a[row * n + col] * x[col];
}
y[row] = sum;
}
}
/// General Matrix-Matrix multiplication (C = A * B) with NEON.
///
/// Uses cache blocking over `TILE_M` x `TILE_N` x `TILE_K` tiles with a
/// broadcast-and-FMA inner loop; falls back to `gemm_scalar` on non-aarch64
/// targets.
#[inline(always)]
fn gemm_neon(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(c.len(), m * n);
c.fill(0.0);
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_neon_impl(a, b, c, m, k, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemm_scalar(a, b, c, m, k, n);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_neon_impl(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use std::arch::aarch64::*;
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_mut_ptr();
let mut i = 0usize;
while i < m {
let i_end = (i + TILE_M).min(m);
let mut j = 0usize;
while j < n {
let j_end = (j + TILE_N).min(n);
let mut kk = 0usize;
while kk < k {
let kk_end = (kk + TILE_K).min(k);
for ii in i..i_end {
for jj in (j..j_end).step_by(NEON_LANE_WIDTH) {
let j_remaining = (j_end - jj).min(NEON_LANE_WIDTH);
if j_remaining == NEON_LANE_WIDTH {
let mut acc = vld1q_f32(c_ptr.add(ii * n + jj));
for kkk in kk..kk_end {
let a_val = vdupq_n_f32(*a_ptr.add(ii * k + kkk));
let b_v = vld1q_f32(b_ptr.add(kkk * n + jj));
acc = vfmaq_f32(acc, a_val, b_v);
}
vst1q_f32(c_ptr.add(ii * n + jj), acc);
} else {
for jjj in jj..j_end {
let mut sum = *c_ptr.add(ii * n + jjj);
for kkk in kk..kk_end {
sum += *a_ptr.add(ii * k + kkk) * *b_ptr.add(kkk * n + jjj);
}
*c_ptr.add(ii * n + jjj) = sum;
}
}
}
}
kk = kk_end;
}
j = j_end;
}
i = i_end;
}
}
#[allow(dead_code)]
fn gemm_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
}
/// Batched GEMM for attention computation: applies `gemm_neon` to each batch
/// slice of A, B, and C independently.
#[inline(always)]
fn batched_gemm_neon(
a: &[f32],
b: &[f32],
c: &mut [f32],
batch_size: usize,
m: usize,
k: usize,
n: usize,
) {
debug_assert_eq!(a.len(), batch_size * m * k);
debug_assert_eq!(b.len(), batch_size * k * n);
debug_assert_eq!(c.len(), batch_size * m * n);
let a_batch_stride = m * k;
let b_batch_stride = k * n;
let c_batch_stride = m * n;
for batch in 0..batch_size {
let a_offset = batch * a_batch_stride;
let b_offset = batch * b_batch_stride;
let c_offset = batch * c_batch_stride;
gemm_neon(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
&mut c[c_offset..c_offset + c_batch_stride],
m,
k,
n,
);
}
}
/// GEMM with transposed B matrix (C = A * B^T, as in Q * K^T).
///
/// `b_t` is stored row-major as an `n x k` matrix, so each output element is a
/// dot product of two contiguous rows, keeping all loads sequential.
fn gemm_nt_neon(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b_t.len(), n * k);
debug_assert_eq!(c.len(), m * n);
c.fill(0.0);
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_nt_neon_impl(a, b_t, c, m, k, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemm_nt_scalar(a, b_t, c, m, k, n);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_nt_neon_impl(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use std::arch::aarch64::*;
let a_ptr = a.as_ptr();
let b_ptr = b_t.as_ptr();
let c_ptr = c.as_mut_ptr();
for i in 0..m {
let n_chunks = n / NEON_LANE_WIDTH;
for nc in 0..n_chunks {
let j_base = nc * NEON_LANE_WIDTH;
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
let k_chunks = k / NEON_LANE_WIDTH;
let mut kk = 0usize;
for _ in 0..k_chunks {
let a_v = vld1q_f32(a_ptr.add(i * k + kk));
let b0 = vld1q_f32(b_ptr.add((j_base + 0) * k + kk));
let b1 = vld1q_f32(b_ptr.add((j_base + 1) * k + kk));
let b2 = vld1q_f32(b_ptr.add((j_base + 2) * k + kk));
let b3 = vld1q_f32(b_ptr.add((j_base + 3) * k + kk));
acc0 += vaddvq_f32(vmulq_f32(a_v, b0));
acc1 += vaddvq_f32(vmulq_f32(a_v, b1));
acc2 += vaddvq_f32(vmulq_f32(a_v, b2));
acc3 += vaddvq_f32(vmulq_f32(a_v, b3));
kk += 4;
}
for kkk in kk..k {
let a_val = *a_ptr.add(i * k + kkk);
acc0 += a_val * *b_ptr.add((j_base + 0) * k + kkk);
acc1 += a_val * *b_ptr.add((j_base + 1) * k + kkk);
acc2 += a_val * *b_ptr.add((j_base + 2) * k + kkk);
acc3 += a_val * *b_ptr.add((j_base + 3) * k + kkk);
}
*c_ptr.add(i * n + j_base + 0) = acc0;
*c_ptr.add(i * n + j_base + 1) = acc1;
*c_ptr.add(i * n + j_base + 2) = acc2;
*c_ptr.add(i * n + j_base + 3) = acc3;
}
for j in (n_chunks * NEON_LANE_WIDTH)..n {
let mut acc = vdupq_n_f32(0.0);
let k_chunks = k / NEON_LANE_WIDTH;
let mut kk = 0usize;
for _ in 0..k_chunks {
let a_v = vld1q_f32(a_ptr.add(i * k + kk));
let b_v = vld1q_f32(b_ptr.add(j * k + kk));
acc = vfmaq_f32(acc, a_v, b_v);
kk += 4;
}
let mut sum = vaddvq_f32(acc);
for kkk in kk..k {
sum += *a_ptr.add(i * k + kkk) * *b_ptr.add(j * k + kkk);
}
*c_ptr.add(i * n + j) = sum;
}
}
}
#[allow(dead_code)]
fn gemm_nt_scalar(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b_t[j * k + kk];
}
c[i * n + j] = sum;
}
}
}
/// Dot product of two vectors, unrolled across four independent NEON accumulators.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
use std::arch::aarch64::*;
debug_assert_eq!(a.len(), b.len());
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let chunks = len / (NEON_LANE_WIDTH * UNROLL_FACTOR);
let mut idx = 0usize;
for _ in 0..chunks {
let a0 = vld1q_f32(a_ptr.add(idx));
let b0 = vld1q_f32(b_ptr.add(idx));
sum0 = vfmaq_f32(sum0, a0, b0);
let a1 = vld1q_f32(a_ptr.add(idx + 4));
let b1 = vld1q_f32(b_ptr.add(idx + 4));
sum1 = vfmaq_f32(sum1, a1, b1);
let a2 = vld1q_f32(a_ptr.add(idx + 8));
let b2 = vld1q_f32(b_ptr.add(idx + 8));
sum2 = vfmaq_f32(sum2, a2, b2);
let a3 = vld1q_f32(a_ptr.add(idx + 12));
let b3 = vld1q_f32(b_ptr.add(idx + 12));
sum3 = vfmaq_f32(sum3, a3, b3);
idx += 16;
}
let sum01 = vaddq_f32(sum0, sum1);
let sum23 = vaddq_f32(sum2, sum3);
let sum = vaddq_f32(sum01, sum23);
let remaining = (len - idx) / NEON_LANE_WIDTH;
let mut final_sum = sum;
for _ in 0..remaining {
let a_v = vld1q_f32(a_ptr.add(idx));
let b_v = vld1q_f32(b_ptr.add(idx));
final_sum = vfmaq_f32(final_sum, a_v, b_v);
idx += 4;
}
let mut result = vaddvq_f32(final_sum);
for i in idx..len {
result += *a_ptr.add(i) * *b_ptr.add(i);
}
result
}
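// Portable scalar reference for the NEON dot product above. It is not wired
// into any benchmark group; like `gemv_scalar` and `gemm_scalar`, it exists as
// a cross-check and is used by the sanity tests at the bottom of this file.
// The name `dot_product_scalar` is an addition here, not an existing kernel.
#[allow(dead_code)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}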
// Helper function to generate random tensor data
fn random_tensor(size: usize) -> Vec<f32> {
let mut rng = rand::thread_rng();
(0..size).map(|_| rng.gen_range(-1.0..1.0)).collect()
}
// === Benchmark Functions ===
fn bench_gemv(c: &mut Criterion) {
let mut group = c.benchmark_group("gemv");
group.sample_size(50);
for (m, n) in [
(256, 256),
(512, 512),
(1024, 1024),
(2048, 2048),
(4096, 4096),
] {
let a = random_tensor(m * n);
let x = random_tensor(n);
let mut y = vec![0.0; m];
let flops = 2 * m * n; // multiply + add per element
let id = BenchmarkId::new(format!("{}x{}", m, n), m * n);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |b| {
b.iter(|| {
gemv_neon(black_box(&a), black_box(&x), black_box(&mut y), m, n);
})
});
}
group.finish();
}
fn bench_gemm(c: &mut Criterion) {
let mut group = c.benchmark_group("gemm");
group.sample_size(30);
for size in [128, 256, 512, 1024, 2048] {
let m = size;
let k = size;
let n = size;
let mat_a = random_tensor(m * k);
let mat_b = random_tensor(k * n);
let mut c_out = vec![0.0; m * n];
let flops = 2 * m * k * n; // multiply + add per output element
let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m * k * n);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
}
group.finish();
}
fn bench_gemm_non_square(c: &mut Criterion) {
let mut group = c.benchmark_group("gemm_non_square");
group.sample_size(30);
// Common shapes in LLM inference
let shapes = [
(1, 4096, 4096), // Single token projection
(32, 4096, 4096), // Batch projection
(128, 4096, 4096), // Larger batch
(1, 4096, 11008), // MLP up projection (Llama2 7B)
(1, 11008, 4096), // MLP down projection
(32, 128, 4096), // Attention output
];
for (m, k, n) in shapes {
let mat_a = random_tensor(m * k);
let mat_b = random_tensor(k * n);
let mut c_out = vec![0.0; m * n];
let flops = 2 * m * k * n;
let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
}
group.finish();
}
fn bench_batched_gemm(c: &mut Criterion) {
let mut group = c.benchmark_group("batched_gemm");
group.sample_size(30);
for batch_size in [1, 8, 16, 32] {
for (m, k, n) in [(64, 64, 64), (128, 128, 128), (256, 256, 256)] {
let mat_a = random_tensor(batch_size * m * k);
let mat_b = random_tensor(batch_size * k * n);
let mut c_out = vec![0.0; batch_size * m * n];
let flops = 2 * batch_size * m * k * n;
let id = BenchmarkId::new(
format!("batch_{}_{}x{}x{}", batch_size, m, k, n),
batch_size,
);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
batched_gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
batch_size,
m,
k,
n,
);
})
});
}
}
group.finish();
}
fn bench_gemm_nt(c: &mut Criterion) {
let mut group = c.benchmark_group("gemm_nt");
group.sample_size(30);
// Q * K^T shapes in attention
let shapes = [
(128, 128, 128), // seq=128
(256, 128, 256), // seq=256
(512, 128, 512), // seq=512
(1024, 128, 1024), // seq=1024
];
for (m, k, n) in shapes {
let a = random_tensor(m * k);
let b_t = random_tensor(n * k); // Transposed
let mut c_out = vec![0.0; m * n];
let flops = 2 * m * k * n;
let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m * n);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |b| {
b.iter(|| {
gemm_nt_neon(
black_box(&a),
black_box(&b_t),
black_box(&mut c_out),
m,
k,
n,
);
})
});
}
group.finish();
}
#[cfg(target_arch = "aarch64")]
fn bench_dot_product(c: &mut Criterion) {
let mut group = c.benchmark_group("dot_product");
group.sample_size(100);
for size in [64, 128, 256, 512, 1024, 2048, 4096] {
let a = random_tensor(size);
let b = random_tensor(size);
let id = BenchmarkId::new(format!("dim_{}", size), size);
group.throughput(Throughput::Elements((2 * size) as u64)); // multiply + add
group.bench_function(id, |b_iter| {
b_iter.iter(|| unsafe { dot_product_neon(black_box(&a), black_box(&b)) })
});
}
group.finish();
}
fn bench_tiling_efficiency(c: &mut Criterion) {
let mut group = c.benchmark_group("tiling_efficiency");
group.sample_size(20);
// Measure tiling behavior at sizes just below, at, and just above tile boundaries
for size in [63, 64, 65, 127, 128, 129, 255, 256, 257] {
let mat_a = random_tensor(size * size);
let mat_b = random_tensor(size * size);
let mut c_out = vec![0.0; size * size];
let flops = 2 * size * size * size;
let id = BenchmarkId::new(format!("size_{}", size), size);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
size,
size,
size,
);
})
});
}
group.finish();
}
fn bench_memory_bandwidth(c: &mut Criterion) {
let mut group = c.benchmark_group("memory_bandwidth");
group.sample_size(30);
// Test memory-bound vs compute-bound behavior
for (m, k, n) in [
(1, 4096, 4096), // Very memory bound (GEMV-like)
(32, 4096, 4096), // More compute
(128, 4096, 4096), // Compute bound
] {
let mat_a = random_tensor(m * k);
let mat_b = random_tensor(k * n);
let mut c_out = vec![0.0; m * n];
// Memory: A (m*k*4), B (k*n*4), C (m*n*4)
let memory_bytes = ((m * k) + (k * n) + (m * n)) * 4;
let flops = 2 * m * k * n;
let id = BenchmarkId::new(
format!(
"{}x{}x{}_ratio_{:.2}",
m,
k,
n,
flops as f64 / memory_bytes as f64
),
m,
);
group.throughput(Throughput::Bytes(memory_bytes as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
}
group.finish();
}
fn bench_llm_projection_sizes(c: &mut Criterion) {
let mut group = c.benchmark_group("llm_projections");
group.sample_size(20);
// Real LLM projection sizes (single token)
let configs = [
("llama2_7b_qkv", 1, 4096, 4096),
("llama2_7b_mlp_up", 1, 4096, 11008),
("llama2_7b_mlp_down", 1, 11008, 4096),
("llama2_13b_qkv", 1, 5120, 5120),
("llama2_70b_qkv", 1, 8192, 8192),
("mistral_7b_qkv", 1, 4096, 4096),
];
for (name, m, k, n) in configs {
let mat_a = random_tensor(m * k);
let mat_b = random_tensor(k * n);
let mut c_out = vec![0.0; m * n];
let flops = 2 * m * k * n;
let id = BenchmarkId::new(name, flops);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
}
group.finish();
}
// ============================================================================
// Parallel benchmarks (enabled with `parallel` feature)
// ============================================================================
#[cfg(feature = "parallel")]
mod parallel_benches {
use super::*;
/// Worker count for the thread pool: `available_parallelism` (logical cores),
/// used here as an approximation of the physical core count.
fn get_physical_cores() -> usize {
std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4)
}
/// Configure the global rayon thread pool once, before the first parallel benchmark runs.
fn init_thread_pool() {
use std::sync::Once;
static INIT: Once = Once::new();
INIT.call_once(|| {
rayon::ThreadPoolBuilder::new()
.num_threads(get_physical_cores())
.thread_name(|i| format!("bench-gemm-{}", i))
.build_global()
.ok();
});
}
// ========================================================================
// Parallel GEMM implementations (these mirror the single-threaded kernels above)
// ========================================================================
/// Parallel GEMM: splits C (and the matching rows of A) into contiguous row
/// chunks processed by rayon workers; small problems fall back to `gemm_neon`.
fn gemm_parallel(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use rayon::prelude::*;
const MIN_ROWS_PER_THREAD: usize = 32;
const PARALLEL_THRESHOLD: usize = 128;
if m < PARALLEL_THRESHOLD || (m * k * n) < 1_000_000 {
return gemm_neon(a, b, c, m, k, n);
}
c.fill(0.0);
let num_threads = get_physical_cores();
let chunk_size = (m / num_threads).max(MIN_ROWS_PER_THREAD);
c.par_chunks_mut(chunk_size * n)
.enumerate()
.for_each(|(chunk_idx, c_chunk)| {
let row_start = chunk_idx * chunk_size;
let actual_rows = c_chunk.len() / n;
let row_end = row_start + actual_rows;
let a_start = row_start * k;
let a_end = row_end * k;
let a_chunk = &a[a_start..a_end];
gemm_chunk(a_chunk, b, c_chunk, actual_rows, k, n);
});
}
fn gemm_chunk(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_chunk_neon(a, b, c, m, k, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
}
}
#[cfg(target_arch = "aarch64")]
unsafe fn gemm_chunk_neon(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use std::arch::aarch64::*;
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_mut_ptr();
let mut i = 0usize;
while i + 4 <= m {
let mut j = 0usize;
while j + 8 <= n {
let mut c00 = vdupq_n_f32(0.0);
let mut c01 = vdupq_n_f32(0.0);
let mut c10 = vdupq_n_f32(0.0);
let mut c11 = vdupq_n_f32(0.0);
let mut c20 = vdupq_n_f32(0.0);
let mut c21 = vdupq_n_f32(0.0);
let mut c30 = vdupq_n_f32(0.0);
let mut c31 = vdupq_n_f32(0.0);
for kk in 0..k {
let b0 = vld1q_f32(b_ptr.add(kk * n + j));
let b1 = vld1q_f32(b_ptr.add(kk * n + j + 4));
let a0 = vdupq_n_f32(*a_ptr.add(i * k + kk));
let a1 = vdupq_n_f32(*a_ptr.add((i + 1) * k + kk));
let a2 = vdupq_n_f32(*a_ptr.add((i + 2) * k + kk));
let a3 = vdupq_n_f32(*a_ptr.add((i + 3) * k + kk));
c00 = vfmaq_f32(c00, a0, b0);
c01 = vfmaq_f32(c01, a0, b1);
c10 = vfmaq_f32(c10, a1, b0);
c11 = vfmaq_f32(c11, a1, b1);
c20 = vfmaq_f32(c20, a2, b0);
c21 = vfmaq_f32(c21, a2, b1);
c30 = vfmaq_f32(c30, a3, b0);
c31 = vfmaq_f32(c31, a3, b1);
}
vst1q_f32(c_ptr.add(i * n + j), c00);
vst1q_f32(c_ptr.add(i * n + j + 4), c01);
vst1q_f32(c_ptr.add((i + 1) * n + j), c10);
vst1q_f32(c_ptr.add((i + 1) * n + j + 4), c11);
vst1q_f32(c_ptr.add((i + 2) * n + j), c20);
vst1q_f32(c_ptr.add((i + 2) * n + j + 4), c21);
vst1q_f32(c_ptr.add((i + 3) * n + j), c30);
vst1q_f32(c_ptr.add((i + 3) * n + j + 4), c31);
j += 8;
}
while j + 4 <= n {
let mut c0 = vdupq_n_f32(0.0);
let mut c1 = vdupq_n_f32(0.0);
let mut c2 = vdupq_n_f32(0.0);
let mut c3 = vdupq_n_f32(0.0);
for kk in 0..k {
let b_v = vld1q_f32(b_ptr.add(kk * n + j));
c0 = vfmaq_f32(c0, vdupq_n_f32(*a_ptr.add(i * k + kk)), b_v);
c1 = vfmaq_f32(c1, vdupq_n_f32(*a_ptr.add((i + 1) * k + kk)), b_v);
c2 = vfmaq_f32(c2, vdupq_n_f32(*a_ptr.add((i + 2) * k + kk)), b_v);
c3 = vfmaq_f32(c3, vdupq_n_f32(*a_ptr.add((i + 3) * k + kk)), b_v);
}
vst1q_f32(c_ptr.add(i * n + j), c0);
vst1q_f32(c_ptr.add((i + 1) * n + j), c1);
vst1q_f32(c_ptr.add((i + 2) * n + j), c2);
vst1q_f32(c_ptr.add((i + 3) * n + j), c3);
j += 4;
}
while j < n {
for row in i..i + 4 {
let mut sum = 0.0f32;
for kk in 0..k {
sum += *a_ptr.add(row * k + kk) * *b_ptr.add(kk * n + j);
}
*c_ptr.add(row * n + j) = sum;
}
j += 1;
}
i += 4;
}
while i < m {
let mut j = 0usize;
while j + 4 <= n {
let mut acc = vdupq_n_f32(0.0);
for kk in 0..k {
let a_val = vdupq_n_f32(*a_ptr.add(i * k + kk));
let b_v = vld1q_f32(b_ptr.add(kk * n + j));
acc = vfmaq_f32(acc, a_val, b_v);
}
vst1q_f32(c_ptr.add(i * n + j), acc);
j += 4;
}
while j < n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += *a_ptr.add(i * k + kk) * *b_ptr.add(kk * n + j);
}
*c_ptr.add(i * n + j) = sum;
j += 1;
}
i += 1;
}
}
/// Parallel GEMV: partitions the rows of A (and the output y) across rayon
/// workers; small problems fall back to `gemv_neon`.
fn gemv_parallel(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
use rayon::prelude::*;
const MIN_ROWS_PER_THREAD: usize = 32;
const PARALLEL_THRESHOLD: usize = 256;
if m < PARALLEL_THRESHOLD {
return gemv_neon(a, x, y, m, n);
}
let num_threads = get_physical_cores();
let chunk_size = (m / num_threads).max(MIN_ROWS_PER_THREAD);
y.par_chunks_mut(chunk_size)
.enumerate()
.for_each(|(chunk_idx, y_chunk)| {
let row_start = chunk_idx * chunk_size;
let row_end = (row_start + y_chunk.len()).min(m);
let chunk_rows = row_end - row_start;
let a_start = row_start * n;
let a_end = row_end * n;
let a_chunk = &a[a_start..a_end];
gemv_neon(a_chunk, x, y_chunk, chunk_rows, n);
});
}
/// Parallel batched GEMM: with few large batches, parallelize within each GEMM;
/// otherwise parallelize across batches with one single-threaded GEMM per batch.
fn batched_gemm_parallel(
a: &[f32],
b: &[f32],
c: &mut [f32],
batch_size: usize,
m: usize,
k: usize,
n: usize,
) {
use rayon::prelude::*;
const PARALLEL_THRESHOLD: usize = 128;
let a_batch_stride = m * k;
let b_batch_stride = k * n;
let c_batch_stride = m * n;
if batch_size <= 4 && m >= PARALLEL_THRESHOLD {
for batch in 0..batch_size {
let a_offset = batch * a_batch_stride;
let b_offset = batch * b_batch_stride;
let c_offset = batch * c_batch_stride;
gemm_parallel(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
&mut c[c_offset..c_offset + c_batch_stride],
m,
k,
n,
);
}
} else {
c.par_chunks_mut(c_batch_stride)
.enumerate()
.for_each(|(batch, c_batch)| {
let a_offset = batch * a_batch_stride;
let b_offset = batch * b_batch_stride;
gemm_neon(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
c_batch,
m,
k,
n,
);
});
}
}
// ========================================================================
// Benchmark functions
// ========================================================================
pub fn bench_gemm_parallel(c: &mut Criterion) {
init_thread_pool();
let mut group = c.benchmark_group("gemm_parallel");
group.sample_size(30);
for size in [256, 512, 1024, 2048] {
let m = size;
let k = size;
let n = size;
let mat_a = random_tensor(m * k);
let mat_b = random_tensor(k * n);
let mut c_out = vec![0.0; m * n];
let flops = 2 * m * k * n;
let id = BenchmarkId::new(format!("{}x{}x{}", m, k, n), m * k * n);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
gemm_parallel(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
}
group.finish();
}
pub fn bench_gemv_parallel(c: &mut Criterion) {
init_thread_pool();
let mut group = c.benchmark_group("gemv_parallel");
group.sample_size(50);
for (m, n) in [(512, 512), (1024, 1024), (2048, 2048), (4096, 4096)] {
let a = random_tensor(m * n);
let x = random_tensor(n);
let mut y = vec![0.0; m];
let flops = 2 * m * n;
let id = BenchmarkId::new(format!("{}x{}", m, n), m * n);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |b| {
b.iter(|| {
gemv_parallel(black_box(&a), black_box(&x), black_box(&mut y), m, n);
})
});
}
group.finish();
}
pub fn bench_batched_gemm_parallel(c: &mut Criterion) {
init_thread_pool();
let mut group = c.benchmark_group("batched_gemm_parallel");
group.sample_size(30);
for batch_size in [8, 16, 32] {
for (m, k, n) in [(128, 128, 128), (256, 256, 256)] {
let mat_a = random_tensor(batch_size * m * k);
let mat_b = random_tensor(batch_size * k * n);
let mut c_out = vec![0.0; batch_size * m * n];
let flops = 2 * batch_size * m * k * n;
let id = BenchmarkId::new(
format!("batch_{}_{}x{}x{}", batch_size, m, k, n),
batch_size,
);
group.throughput(Throughput::Elements(flops as u64));
group.bench_function(id, |bencher| {
bencher.iter(|| {
batched_gemm_parallel(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
batch_size,
m,
k,
n,
);
})
});
}
}
group.finish();
}
/// Compare single-threaded vs parallel GEMM head-to-head at a fixed 512x512x512 size.
pub fn bench_parallel_speedup(c: &mut Criterion) {
init_thread_pool();
let mut group = c.benchmark_group("parallel_speedup");
group.sample_size(20);
let size = 512;
let m = size;
let k = size;
let n = size;
let mat_a = random_tensor(m * k);
let mat_b = random_tensor(k * n);
let mut c_out = vec![0.0; m * n];
let flops = 2 * m * k * n;
group.throughput(Throughput::Elements(flops as u64));
group.bench_function("single_thread", |bencher| {
bencher.iter(|| {
gemm_neon(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
group.bench_function("parallel", |bencher| {
bencher.iter(|| {
gemm_parallel(
black_box(&mat_a),
black_box(&mat_b),
black_box(&mut c_out),
m,
k,
n,
);
})
});
group.finish();
}
}
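// Minimal correctness sanity checks, comparing the NEON kernels against the
// scalar references defined above on small, deliberately unaligned sizes. This
// is a sketch: whether it runs under `cargo test` depends on this bench
// target's harness configuration, and the `sanity_tests` module name and
// tolerances are assumptions rather than project conventions.
#[cfg(test)]
mod sanity_tests {
    use super::*;

    /// Assert element-wise closeness with a mixed absolute/relative tolerance.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() <= tol * (1.0 + e.abs()),
                "mismatch at index {}: {} vs {}",
                i,
                a,
                e
            );
        }
    }

    #[test]
    fn gemv_matches_scalar() {
        // Sizes avoid multiples of MR and NEON_LANE_WIDTH to exercise the tail paths.
        let (m, n) = (37, 53);
        let a = random_tensor(m * n);
        let x = random_tensor(n);
        let mut y = vec![0.0; m];
        let mut y_ref = vec![0.0; m];
        gemv_neon(&a, &x, &mut y, m, n);
        gemv_scalar(&a, &x, &mut y_ref, m, n);
        assert_close(&y, &y_ref, 1e-4);
    }

    #[test]
    fn gemm_matches_scalar() {
        let (m, k, n) = (33, 29, 31);
        let a = random_tensor(m * k);
        let b = random_tensor(k * n);
        let mut c = vec![0.0; m * n];
        let mut c_ref = vec![0.0; m * n];
        gemm_neon(&a, &b, &mut c, m, k, n);
        gemm_scalar(&a, &b, &mut c_ref, m, k, n);
        assert_close(&c, &c_ref, 1e-4);
    }

    #[test]
    fn gemm_nt_matches_scalar() {
        let (m, k, n) = (17, 19, 23);
        let a = random_tensor(m * k);
        let b_t = random_tensor(n * k);
        let mut c = vec![0.0; m * n];
        let mut c_ref = vec![0.0; m * n];
        gemm_nt_neon(&a, &b_t, &mut c, m, k, n);
        gemm_nt_scalar(&a, &b_t, &mut c_ref, m, k, n);
        assert_close(&c, &c_ref, 1e-4);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn dot_product_matches_scalar() {
        // 1023 is not a multiple of the 16-element unrolled chunk, so both tail loops run.
        let a = random_tensor(1023);
        let b = random_tensor(1023);
        let got = unsafe { dot_product_neon(&a, &b) };
        let want = dot_product_scalar(&a, &b);
        assert!((got - want).abs() <= 1e-3 * (1.0 + want.abs()));
    }
}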
#[cfg(feature = "parallel")]
use parallel_benches::*;
#[cfg(all(target_arch = "aarch64", not(feature = "parallel")))]
criterion_group!(
benches,
bench_gemv,
bench_gemm,
bench_gemm_non_square,
bench_batched_gemm,
bench_gemm_nt,
bench_dot_product,
bench_tiling_efficiency,
bench_memory_bandwidth,
bench_llm_projection_sizes,
);
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
criterion_group!(
benches,
bench_gemv,
bench_gemm,
bench_gemm_non_square,
bench_batched_gemm,
bench_gemm_nt,
bench_dot_product,
bench_tiling_efficiency,
bench_memory_bandwidth,
bench_llm_projection_sizes,
bench_gemm_parallel,
bench_gemv_parallel,
bench_batched_gemm_parallel,
bench_parallel_speedup,
);
#[cfg(all(not(target_arch = "aarch64"), not(feature = "parallel")))]
criterion_group!(
benches,
bench_gemv,
bench_gemm,
bench_gemm_non_square,
bench_batched_gemm,
bench_gemm_nt,
bench_tiling_efficiency,
bench_memory_bandwidth,
bench_llm_projection_sizes,
);
#[cfg(all(not(target_arch = "aarch64"), feature = "parallel"))]
criterion_group!(
benches,
bench_gemv,
bench_gemm,
bench_gemm_non_square,
bench_batched_gemm,
bench_gemm_nt,
bench_tiling_efficiency,
bench_memory_bandwidth,
bench_llm_projection_sizes,
bench_gemm_parallel,
bench_gemv_parallel,
bench_batched_gemm_parallel,
bench_parallel_speedup,
);
criterion_main!(benches);