#![allow(
clippy::all,
unused_imports,
unused_variables,
dead_code,
unused_mut,
unused_assignments,
non_camel_case_types,
clippy::approx_constant,
unexpected_cfgs,
unused_must_use,
unused_parens
)]
//! Cross-platform tests for scalar fallback implementations
//!
//! These tests verify that the scalar fallback implementations produce
//! correct results and work on all platforms (including non-NEON and WASM).
use ruvllm::kernels::{flash_attention_neon, gemm_neon, gemv_neon, layer_norm_neon, rms_norm_neon};
// ========== Scalar Reference Implementations ==========
/// Scalar reference GEMV: computes y = A * x for a row-major m x n matrix `a`
fn gemv_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
for row in 0..m {
let mut sum = 0.0f32;
for col in 0..n {
sum += a[row * n + col] * x[col];
}
y[row] = sum;
}
}
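// Sketch of how a caller might pair the NEON entry point with the scalar
// fallback at compile time. This is a hypothetical illustration only; the
// actual dispatch lives inside `ruvllm::kernels` and may differ.
fn gemv_dispatch(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        // NEON is a baseline feature on aarch64, so the SIMD path is taken directly.
        gemv_neon(a, x, y, m, n);
        return;
    }
    #[cfg(not(target_arch = "aarch64"))]
    gemv_scalar(a, x, y, m, n);
}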
/// Scalar reference GEMM: computes C = A * B for row-major A (m x k) and B (k x n)
fn gemm_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
c.fill(0.0);
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
}
/// Scalar reference single-query attention: softmax of the scaled q.k scores,
/// followed by a weighted sum of the value rows
fn attention_scalar(
query: &[f32],
key: &[f32],
value: &[f32],
head_dim: usize,
kv_len: usize,
scale: f32,
) -> Vec<f32> {
// Compute attention scores
let mut scores = Vec::with_capacity(kv_len);
for t in 0..kv_len {
let k_offset = t * head_dim;
let score: f32 = query
.iter()
.zip(&key[k_offset..k_offset + head_dim])
.map(|(q, k)| q * k * scale)
.sum();
scores.push(score);
}
// Softmax
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
// Weighted sum of values
let mut output = vec![0.0; head_dim];
for (t, weight) in attn_weights.iter().enumerate() {
let v_offset = t * head_dim;
for (i, v) in value[v_offset..v_offset + head_dim].iter().enumerate() {
output[i] += weight * v;
}
}
output
}
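// Sanity check on the reference itself (added for exposition): with a single
// key/value row the softmax weight is exactly 1.0, so the output must equal
// that value row regardless of query or scale.
#[test]
fn test_attention_scalar_single_row_identity() {
    let query = vec![0.3, -0.1, 0.7, 0.2];
    let key = vec![0.5, -0.5, 0.25, 0.0];
    let value = vec![1.0, 2.0, 3.0, 4.0];
    let output = attention_scalar(&query, &key, &value, 4, 1, 0.5);
    for (o, v) in output.iter().zip(&value) {
        assert!((o - v).abs() < 1e-6, "expected {}, got {}", v, o);
    }
}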
/// Scalar reference RMSNorm: x[i] <- x[i] / sqrt(mean(x^2) + eps) * weight[i]
fn rms_norm_scalar(x: &mut [f32], weight: &[f32], eps: f32) {
let len = x.len();
let sum_sq: f32 = x.iter().map(|v| v * v).sum();
let inv_rms = 1.0 / (sum_sq / len as f32 + eps).sqrt();
for (i, w) in weight.iter().enumerate() {
x[i] = x[i] * inv_rms * w;
}
}
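// Illustration of the RMSNorm invariant (a hypothetical extra check, not part
// of the original suite): with unit weight, the root-mean-square of the
// normalized vector is ~1 by construction.
#[test]
fn test_rms_norm_scalar_unit_rms() {
    let mut x: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let weight = vec![1.0; 8];
    rms_norm_scalar(&mut x, &weight, 1e-6);
    let rms = (x.iter().map(|v| v * v).sum::<f32>() / 8.0).sqrt();
    assert!((rms - 1.0).abs() < 1e-4, "RMS after normalization: {}", rms);
}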
/// Scalar reference LayerNorm: x[i] <- (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]
fn layer_norm_scalar(x: &mut [f32], weight: &[f32], bias: &[f32], eps: f32) {
let len = x.len();
let mean: f32 = x.iter().sum::<f32>() / len as f32;
let var: f32 = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / len as f32;
let inv_std = 1.0 / (var + eps).sqrt();
for i in 0..len {
x[i] = (x[i] - mean) * inv_std * weight[i] + bias[i];
}
}
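// Illustration of the LayerNorm invariant (a hypothetical extra check): with
// unit weight and zero bias the result has ~zero mean and ~unit variance.
#[test]
fn test_layer_norm_scalar_zero_mean_unit_var() {
    let mut x: Vec<f32> = (0..16).map(|i| i as f32 * 0.5 - 3.0).collect();
    let weight = vec![1.0; 16];
    let bias = vec![0.0; 16];
    layer_norm_scalar(&mut x, &weight, &bias, 1e-6);
    let mean: f32 = x.iter().sum::<f32>() / 16.0;
    let var: f32 = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / 16.0;
    assert!(mean.abs() < 1e-5, "mean after LayerNorm: {}", mean);
    assert!((var - 1.0).abs() < 1e-3, "variance after LayerNorm: {}", var);
}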
// ========== Cross-Platform Tests ==========
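// Each test below inlines the same mixed relative/absolute tolerance check.
// Factored out, it would look like this (a hypothetical helper shown for
// clarity; the tests keep their inline form):
fn approx_eq(actual: f32, expected: f32, rel_tol: f32, abs_tol: f32) -> bool {
    let abs_error = (actual - expected).abs();
    let rel_error = abs_error / expected.abs().max(1e-6);
    rel_error < rel_tol || abs_error < abs_tol
}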
#[test]
fn test_cross_platform_gemv() {
let test_cases = [
(4, 4),
(8, 16),
(16, 32),
(32, 64),
(64, 128),
(100, 50),
(7, 13), // Non-aligned
];
for (m, n) in test_cases {
let a: Vec<f32> = (0..m * n)
.map(|i| ((i % 100) as f32 - 50.0) / 50.0)
.collect();
let x: Vec<f32> = (0..n).map(|i| ((i % 20) as f32 - 10.0) / 10.0).collect();
let mut y_neon = vec![0.0; m];
let mut y_scalar = vec![0.0; m];
gemv_neon(&a, &x, &mut y_neon, m, n);
gemv_scalar(&a, &x, &mut y_scalar, m, n);
for i in 0..m {
let abs_error = (y_neon[i] - y_scalar[i]).abs();
let rel_error = abs_error / y_scalar[i].abs().max(1e-6);
assert!(
rel_error < 0.001 || abs_error < 1e-5,
"Cross-platform GEMV mismatch at ({},{}) index {}: {} vs {} (rel: {:.6})",
m,
n,
i,
y_neon[i],
y_scalar[i],
rel_error
);
}
}
}
#[test]
fn test_cross_platform_gemm() {
let test_cases = [
(4, 4, 4),
(8, 16, 8),
(16, 32, 16),
(32, 64, 32),
(7, 11, 13), // Non-aligned
];
for (m, k, n) in test_cases {
let a: Vec<f32> = (0..m * k)
.map(|i| ((i % 100) as f32 - 50.0) / 100.0)
.collect();
let b: Vec<f32> = (0..k * n)
.map(|i| ((i % 50) as f32 - 25.0) / 50.0)
.collect();
let mut c_neon = vec![0.0; m * n];
let mut c_scalar = vec![0.0; m * n];
gemm_neon(&a, &b, &mut c_neon, m, k, n);
gemm_scalar(&a, &b, &mut c_scalar, m, k, n);
for i in 0..(m * n) {
let abs_error = (c_neon[i] - c_scalar[i]).abs();
let rel_error = abs_error / c_scalar[i].abs().max(1e-6);
assert!(
rel_error < 0.01 || abs_error < 0.001,
"Cross-platform GEMM mismatch at ({},{},{}) index {}: {} vs {} (rel: {:.6})",
m,
k,
n,
i,
c_neon[i],
c_scalar[i],
rel_error
);
}
}
}
#[test]
fn test_cross_platform_attention() {
let test_cases = [(16, 4), (32, 8), (64, 16), (128, 32)];
for (head_dim, kv_len) in test_cases {
let scale = 1.0 / (head_dim as f32).sqrt();
let query: Vec<f32> = (0..head_dim)
.map(|i| ((i % 7) as f32 - 3.0) / 10.0)
.collect();
let key: Vec<f32> = (0..kv_len * head_dim)
.map(|i| ((i % 11) as f32 - 5.0) / 20.0)
.collect();
let value: Vec<f32> = (0..kv_len * head_dim)
.map(|i| ((i % 13) as f32 - 6.0) / 15.0)
.collect();
let output_neon = flash_attention_neon(&query, &key, &value, scale, false);
let output_scalar = attention_scalar(&query, &key, &value, head_dim, kv_len, scale);
assert_eq!(output_neon.len(), output_scalar.len());
for i in 0..head_dim {
let abs_error = (output_neon[i] - output_scalar[i]).abs();
let rel_error = abs_error / output_scalar[i].abs().max(1e-6);
assert!(
rel_error < 0.01 || abs_error < 1e-4,
"Cross-platform attention mismatch at head_dim={}, kv_len={}, index {}: {} vs {} (rel: {:.6})",
head_dim, kv_len, i, output_neon[i], output_scalar[i], rel_error
);
}
}
}
#[test]
fn test_cross_platform_rms_norm() {
let test_cases = [8, 16, 32, 64, 128];
for dim in test_cases {
let mut x_neon: Vec<f32> = (0..dim)
.map(|i| (i as f32 - dim as f32 / 2.0) / 10.0)
.collect();
let mut x_scalar = x_neon.clone();
let weight: Vec<f32> = (0..dim).map(|i| 0.5 + (i as f32) * 0.01).collect();
let eps = 1e-6;
rms_norm_neon(&mut x_neon, &weight, eps);
rms_norm_scalar(&mut x_scalar, &weight, eps);
for i in 0..dim {
let abs_error = (x_neon[i] - x_scalar[i]).abs();
assert!(
abs_error < 1e-4,
"Cross-platform RMSNorm mismatch at dim={}, index {}: {} vs {} (abs: {:.6})",
dim,
i,
x_neon[i],
x_scalar[i],
abs_error
);
}
}
}
#[test]
fn test_cross_platform_layer_norm() {
let test_cases = [8, 16, 32, 64, 128];
for dim in test_cases {
let mut x_neon: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1 - 5.0).collect();
let mut x_scalar = x_neon.clone();
let weight: Vec<f32> = vec![1.0; dim];
let bias: Vec<f32> = vec![0.0; dim];
let eps = 1e-6;
layer_norm_neon(&mut x_neon, &weight, &bias, eps);
layer_norm_scalar(&mut x_scalar, &weight, &bias, eps);
for i in 0..dim {
let abs_error = (x_neon[i] - x_scalar[i]).abs();
assert!(
abs_error < 1e-4,
"Cross-platform LayerNorm mismatch at dim={}, index {}: {} vs {} (abs: {:.6})",
dim,
i,
x_neon[i],
x_scalar[i],
abs_error
);
}
}
}
// ========== Edge Case Tests ==========
#[test]
fn test_scalar_fallback_edge_cases() {
// Zero vectors
let a_zero = vec![0.0f32; 16];
let x_zero = vec![0.0f32; 4];
let mut y = vec![0.0f32; 4];
gemv_neon(&a_zero, &x_zero, &mut y, 4, 4);
assert!(
y.iter().all(|&v| v == 0.0),
"Zero input should give zero output"
);
// Single element
let a_single = vec![3.0f32];
let x_single = vec![4.0f32];
let mut y_single = vec![0.0f32];
gemv_neon(&a_single, &x_single, &mut y_single, 1, 1);
assert!((y_single[0] - 12.0).abs() < 1e-5, "1x1 GEMV failed");
// Negative values
let a_neg: Vec<f32> = (0..16)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let x_neg: Vec<f32> = (0..4)
.map(|i| if i % 2 == 0 { -1.0 } else { 1.0 })
.collect();
let mut y_neg = vec![0.0f32; 4];
gemv_neon(&a_neg, &x_neg, &mut y_neg, 4, 4);
assert!(
y_neg.iter().all(|&v| v.is_finite()),
"Negative values should produce finite output"
);
}
#[test]
fn test_scalar_fallback_numerical_stability() {
// Very small values
let a_small: Vec<f32> = vec![1e-20; 64];
let x_small: Vec<f32> = vec![1e-20; 8];
let mut y_small = vec![0.0f32; 8];
gemv_neon(&a_small, &x_small, &mut y_small, 8, 8);
assert!(
y_small.iter().all(|&v| v.is_finite()),
"Very small values should produce finite output"
);
// Large values (but not overflow)
let a_large: Vec<f32> = vec![1e10; 64];
let x_large: Vec<f32> = vec![1e-10; 8]; // Scale x to avoid overflow
let mut y_large = vec![0.0f32; 8];
gemv_neon(&a_large, &x_large, &mut y_large, 8, 8);
assert!(
y_large.iter().all(|&v| v.is_finite()),
"Large values with small x should produce finite output"
);
// Mixed magnitudes
let a_mixed: Vec<f32> = (0..64)
.map(|i| if i % 2 == 0 { 1e5 } else { 1e-5 })
.collect();
let x_mixed: Vec<f32> = vec![1.0; 8];
let mut y_mixed = vec![0.0f32; 8];
gemv_neon(&a_mixed, &x_mixed, &mut y_mixed, 8, 8);
assert!(
y_mixed.iter().all(|&v| v.is_finite()),
"Mixed magnitude values should produce finite output"
);
}
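// Why the tiny-value case above stays finite: f32 underflow produces
// subnormals (down to ~1.4e-45) and then zero, never NaN or infinity. A short
// illustration (hypothetical test, added for exposition):
#[test]
fn test_f32_underflow_is_finite() {
    let subnormal = 1e-20f32 * 1e-20f32; // ~1e-40: representable as a subnormal
    assert!(subnormal > 0.0 && subnormal.is_finite());
    let flushed = 1e-30f32 * 1e-30f32; // ~1e-60: below the smallest subnormal
    assert_eq!(flushed, 0.0);
}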
#[test]
fn test_scalar_fallback_determinism() {
let m = 32;
let n = 64;
let a: Vec<f32> = (0..m * n).map(|i| ((i as f32) * 0.1).sin()).collect();
let x: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.2).cos()).collect();
// Run multiple times and verify same result
let mut results = Vec::new();
for _ in 0..5 {
let mut y = vec![0.0f32; m];
gemv_neon(&a, &x, &mut y, m, n);
results.push(y);
}
for i in 1..results.len() {
for j in 0..m {
assert_eq!(
results[0][j], results[i][j],
"GEMV should be deterministic: run 0 vs run {} differ at index {}",
i, j
);
}
}
}
// ========== WASM Compatibility Tests ==========
#[test]
fn test_wasm_compatible_operations() {
// These operations should work on WASM (no NEON)
// Test with dimensions that don't require SIMD
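// (Assumption: when this crate is compiled for a WASM target, e.g.
// `cargo test --target wasm32-wasip1` with a WASM-capable test runner
// configured, the same entry points take the scalar path exercised here.)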
// Small GEMV
let a = vec![1.0, 2.0, 3.0, 4.0];
let x = vec![1.0, 1.0];
let mut y = vec![0.0; 2];
gemv_neon(&a, &x, &mut y, 2, 2);
assert!((y[0] - 3.0).abs() < 1e-5); // 1*1 + 2*1 = 3
assert!((y[1] - 7.0).abs() < 1e-5); // 3*1 + 4*1 = 7
// Small GEMM
let a_gemm = vec![1.0, 2.0, 3.0, 4.0];
let b_gemm = vec![1.0, 0.0, 0.0, 1.0]; // Identity
let mut c_gemm = vec![0.0; 4];
gemm_neon(&a_gemm, &b_gemm, &mut c_gemm, 2, 2, 2);
// A * I = A
for i in 0..4 {
assert!(
(c_gemm[i] - a_gemm[i]).abs() < 1e-5,
"GEMM with identity failed"
);
}
// Small attention
let query = vec![0.1, 0.2, 0.3, 0.4];
let key = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
let value = vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0];
let scale = 0.5;
let output = flash_attention_neon(&query, &key, &value, scale, false);
assert_eq!(output.len(), 4);
assert!(output.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_scalar_path_verification() {
// Test that scalar fallback path produces correct results
// for small inputs that might not trigger SIMD optimizations
// Verify GEMV with small non-aligned dimensions
let a = vec![1.0, 2.0, 3.0];
let x = vec![1.0, 2.0, 3.0];
let mut y = vec![0.0; 1];
gemv_neon(&a, &x, &mut y, 1, 3);
let expected = 1.0 + 4.0 + 9.0; // 1*1 + 2*2 + 3*3 = 14
assert!(
(y[0] - expected).abs() < 1e-5,
"Scalar GEMV expected {}, got {}",
expected,
y[0]
);
// Verify GEMM with 1x1
let a1 = vec![5.0f32];
let b1 = vec![3.0f32];
let mut c1 = vec![0.0f32];
gemm_neon(&a1, &b1, &mut c1, 1, 1, 1);
assert!(
(c1[0] - 15.0).abs() < 1e-5,
"1x1 GEMM expected 15, got {}",
c1[0]
);
// Verify normalization with small vector
let mut x_norm = vec![3.0, 4.0];
let weight = vec![1.0, 1.0];
rms_norm_neon(&mut x_norm, &weight, 1e-6);
// RMS = sqrt((9 + 16) / 2) = sqrt(12.5) ≈ 3.536
// Normalized: [3 / 3.536, 4 / 3.536] ≈ [0.849, 1.131]
assert!((x_norm[0] - 0.8485).abs() < 1e-3, "RMSNorm x[0]: got {}", x_norm[0]);
assert!((x_norm[1] - 1.1314).abs() < 1e-3, "RMSNorm x[1]: got {}", x_norm[1]);
}