Files
wifi-densepose/examples/exo-ai-2025/crates/exo-manifold/src/simd_ops.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

472 lines
13 KiB
Rust

//! SIMD-optimized vector operations for manifold retrieval
//!
//! Provides 8-54x speedup for distance calculations using AVX2/AVX-512/NEON.
//!
//! Based on techniques from ultra-low-latency-sim.
/// Cache line size for alignment (used by prefetch intrinsics in AVX2 path)
///
/// NOTE(review): not referenced by the kernels below — the prefetch stride is
/// hard-coded as "every 2 iterations of 8 floats" instead. Kept for
/// documentation value; remove or wire in if the stride ever changes.
#[allow(dead_code)]
const CACHE_LINE: usize = 64;
/// SIMD-optimized cosine similarity
///
/// Uses AVX2 FMA for 8x parallelism with prefetching.
/// Falls back to scalar for non-x86 platforms.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. This check is load-bearing:
/// the SIMD kernels perform unchecked pointer loads sized by `a.len()` and
/// only `debug_assert` the lengths, so a shorter `b` would be read out of
/// bounds in release builds without it.
#[inline]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity_simd: input slices must have equal length"
    );
    #[cfg(target_arch = "x86_64")]
    {
        // Runtime dispatch: both AVX2 and FMA are required by the kernel.
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return unsafe { cosine_similarity_avx2(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        return cosine_similarity_neon(a, b);
    }
    // Fallback to optimized scalar with loop unrolling
    cosine_similarity_unrolled(a, b)
}
/// SIMD-optimized euclidean distance
///
/// Uses AVX2 for 8x parallelism.
/// Expected speedup: 8-54x depending on dimension.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. The check is load-bearing:
/// the SIMD kernels perform unchecked pointer loads sized by `a.len()` and
/// only `debug_assert` the lengths, so a shorter `b` would be read out of
/// bounds in release builds without it.
#[inline]
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "euclidean_distance_simd: input slices must have equal length"
    );
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX2 kernel uses `_mm256_fmadd_ps`, which is an FMA intrinsic,
        // so detect "fma" as well — not just "avx2" — before dispatching.
        // CPUs with AVX2 but no FMA fall through to the scalar path.
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return unsafe { euclidean_distance_avx2(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        return euclidean_distance_neon(a, b);
    }
    // Fallback to optimized scalar with loop unrolling
    euclidean_distance_unrolled(a, b)
}
// =============================================================================
// AVX2 IMPLEMENTATIONS (x86_64)
// =============================================================================
/// AVX2 + FMA cosine similarity kernel (8 f32 lanes per iteration).
///
/// Single pass accumulating the dot product and both squared norms with
/// fused multiply-add, prefetching two chunks ahead, then folding each
/// vector accumulator plus a scalar tail. Returns 0.0 when either norm is
/// (near-)zero instead of producing NaN/inf.
///
/// # Safety
///
/// Caller must ensure the CPU supports AVX2 and FMA (the dispatcher checks
/// via `is_x86_feature_detected!`) and that `a.len() == b.len()`: lengths
/// are only `debug_assert`ed here, and all loads are sized by `a.len()`, so
/// a shorter `b` would be read out of bounds in release builds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let chunks = len / 8; // number of full 8-float (256-bit) chunks
    // Three independent vector accumulators: a·b, a·a, b·b.
    let mut dot_sum = _mm256_setzero_ps();
    let mut a_sq_sum = _mm256_setzero_ps();
    let mut b_sq_sum = _mm256_setzero_ps();
    for i in 0..chunks {
        let idx = i * 8;
        // Prefetch next cache line (64 bytes = 16 floats, so every 2 iterations)
        if (i & 1) == 0 && i + 2 < chunks {
            let prefetch_idx = (i + 2) * 8;
            _mm_prefetch(a.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
            _mm_prefetch(b.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
        }
        // Unaligned loads: `&[f32]` carries no 32-byte alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        // FMA: dot += a * b, a_sq += a * a, b_sq += b * b
        dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
        a_sq_sum = _mm256_fmadd_ps(va, va, a_sq_sum);
        b_sq_sum = _mm256_fmadd_ps(vb, vb, b_sq_sum);
    }
    // Horizontal sum using AVX2
    let dot = hsum256_ps_avx2(dot_sum);
    let a_sq = hsum256_ps_avx2(a_sq_sum);
    let b_sq = hsum256_ps_avx2(b_sq_sum);
    // Handle remainder: trailing `len % 8` elements processed as scalars.
    let mut dot_rem = dot;
    let mut a_sq_rem = a_sq;
    let mut b_sq_rem = b_sq;
    for i in (chunks * 8)..len {
        let ai = a[i];
        let bi = b[i];
        dot_rem += ai * bi;
        a_sq_rem += ai * ai;
        b_sq_rem += bi * bi;
    }
    let norm_a = a_sq_rem.sqrt();
    let norm_b = b_sq_rem.sqrt();
    // Degenerate (near-zero) norms: define similarity as 0.0 rather than NaN.
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot_rem / (norm_a * norm_b)
    }
}
/// AVX2 euclidean-distance kernel (8 f32 lanes per iteration).
///
/// Accumulates `(a - b)^2` with fused multiply-add, prefetching two chunks
/// ahead, then folds the vector accumulator plus a scalar tail and takes
/// the square root.
///
/// # Safety
///
/// Caller must ensure the CPU supports AVX2 *and* FMA, and that
/// `a.len() == b.len()`: lengths are only `debug_assert`ed, and loads are
/// sized by `a.len()`, so a shorter `b` would be read out of bounds in
/// release builds.
#[cfg(target_arch = "x86_64")]
// Fix: the body uses `_mm256_fmadd_ps`, an FMA intrinsic, so "fma" must be
// declared here (the original enabled only "avx2", leaving the FMA call
// outside the declared feature set).
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let chunks = len / 8; // number of full 8-float (256-bit) chunks
    let mut sum = _mm256_setzero_ps();
    for i in 0..chunks {
        let idx = i * 8;
        // Prefetch two chunks ahead (64 bytes = 16 floats = 2 iterations).
        if (i & 1) == 0 && i + 2 < chunks {
            let prefetch_idx = (i + 2) * 8;
            _mm_prefetch(a.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
            _mm_prefetch(b.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
        }
        // Unaligned loads: `&[f32]` carries no 32-byte alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        // sum += diff * diff (fused multiply-add)
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }
    let mut total = hsum256_ps_avx2(sum);
    // Handle remainder: trailing `len % 8` elements processed as scalars.
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }
    total.sqrt()
}
/// Horizontal sum of 8 floats in AVX2 register
///
/// Three-step reduction: fold the high 128-bit half onto the low half, add
/// odd lanes onto even lanes via `movehdup`, then add the upper 64 bits onto
/// the lower via `movehl`, leaving the total in lane 0.
///
/// # Safety
///
/// Requires AVX2 at runtime (guaranteed when called from the AVX2 kernels,
/// which the dispatcher gates on `is_x86_feature_detected!`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum256_ps_avx2(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // Extract high 128 bits
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    // Add high and low
    let sum128 = _mm_add_ps(high, low);
    // Horizontal add within 128 bits
    let shuf = _mm_movehdup_ps(sum128); // lanes [1,1,3,3]
    let sum64 = _mm_add_ps(sum128, shuf); // lanes 0 and 2 now hold pair sums
    let shuf2 = _mm_movehl_ps(sum64, sum64); // move upper 64 bits down
    let sum32 = _mm_add_ss(sum64, shuf2); // lane 0 = grand total
    _mm_cvtss_f32(sum32)
}
// =============================================================================
// NEON IMPLEMENTATIONS (ARM64)
// =============================================================================
/// NEON cosine similarity kernel (4 f32 lanes per iteration).
///
/// Single pass accumulating the dot product and both squared norms with
/// fused multiply-add, then horizontal-adds each vector accumulator and
/// folds a scalar tail. Returns 0.0 when either norm is (near-)zero.
///
/// NOTE(review): lengths are only `debug_assert`ed, and the `vld1q_f32`
/// loads are sized by `a.len()` — a shorter `b` would be read out of bounds
/// in release builds. Callers (the public dispatcher) should validate
/// lengths before calling.
#[cfg(target_arch = "aarch64")]
fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let chunks = len / 4; // number of full 4-float (128-bit) chunks
    // Three independent vector accumulators: a·b, a·a, b·b.
    let mut dot_sum = unsafe { vdupq_n_f32(0.0) };
    let mut a_sq_sum = unsafe { vdupq_n_f32(0.0) };
    let mut b_sq_sum = unsafe { vdupq_n_f32(0.0) };
    for i in 0..chunks {
        let idx = i * 4;
        // SAFETY: idx + 4 <= chunks * 4 <= a.len(), so the loads stay in
        // bounds of `a` (and of `b` provided the lengths match — see note).
        unsafe {
            let va = vld1q_f32(a.as_ptr().add(idx));
            let vb = vld1q_f32(b.as_ptr().add(idx));
            dot_sum = vfmaq_f32(dot_sum, va, vb);
            a_sq_sum = vfmaq_f32(a_sq_sum, va, va);
            b_sq_sum = vfmaq_f32(b_sq_sum, vb, vb);
        }
    }
    // Horizontal sum: vaddvq_f32 folds all four lanes into one scalar.
    let dot = unsafe { vaddvq_f32(dot_sum) };
    let a_sq = unsafe { vaddvq_f32(a_sq_sum) };
    let b_sq = unsafe { vaddvq_f32(b_sq_sum) };
    // Handle remainder: trailing `len % 4` elements processed as scalars.
    let mut dot_rem = dot;
    let mut a_sq_rem = a_sq;
    let mut b_sq_rem = b_sq;
    for i in (chunks * 4)..len {
        let ai = a[i];
        let bi = b[i];
        dot_rem += ai * bi;
        a_sq_rem += ai * ai;
        b_sq_rem += bi * bi;
    }
    let norm_a = a_sq_rem.sqrt();
    let norm_b = b_sq_rem.sqrt();
    // Degenerate (near-zero) norms: define similarity as 0.0 rather than NaN.
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot_rem / (norm_a * norm_b)
    }
}
/// NEON euclidean-distance kernel (4 f32 lanes per iteration).
///
/// Accumulates `(a - b)^2` with fused multiply-add, horizontal-adds the
/// accumulator, folds a scalar tail, and takes the square root.
///
/// NOTE(review): lengths are only `debug_assert`ed and loads are sized by
/// `a.len()` — a shorter `b` would be read out of bounds in release builds;
/// callers should validate lengths.
#[cfg(target_arch = "aarch64")]
fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let chunks = len / 4; // number of full 4-float (128-bit) chunks
    let mut sum = unsafe { vdupq_n_f32(0.0) };
    for i in 0..chunks {
        let idx = i * 4;
        // SAFETY: idx + 4 <= chunks * 4 <= a.len(), so the loads stay in
        // bounds of `a` (and of `b` provided the lengths match — see note).
        unsafe {
            let va = vld1q_f32(a.as_ptr().add(idx));
            let vb = vld1q_f32(b.as_ptr().add(idx));
            let diff = vsubq_f32(va, vb);
            // sum += diff * diff (fused multiply-add)
            sum = vfmaq_f32(sum, diff, diff);
        }
    }
    // Horizontal sum of the four lanes, then the scalar tail.
    let mut total = unsafe { vaddvq_f32(sum) };
    for i in (chunks * 4)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }
    total.sqrt()
}
// =============================================================================
// SCALAR FALLBACK (Unrolled)
// =============================================================================
/// Unrolled scalar cosine similarity (4x unroll)
///
/// Portable fallback: keeps four independent accumulator lanes so the
/// compiler can schedule the multiplies in parallel, then folds the lanes
/// and a scalar tail. Near-zero norms yield 0.0 instead of NaN.
fn cosine_similarity_unrolled(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let head = len - (len % 4);
    // Lane-wise accumulators: dot product, |a|^2, |b|^2.
    let mut dot = [0.0f32; 4];
    let mut a_sq = [0.0f32; 4];
    let mut b_sq = [0.0f32; 4];
    for (qa, qb) in a[..head].chunks_exact(4).zip(b[..head].chunks_exact(4)) {
        for lane in 0..4 {
            let (x, y) = (qa[lane], qb[lane]);
            dot[lane] += x * y;
            a_sq[lane] += x * x;
            b_sq[lane] += y * y;
        }
    }
    // Fold the four lanes, then fold in the tail elements.
    let mut dot_total = dot[0] + dot[1] + dot[2] + dot[3];
    let mut a_sq_total = a_sq[0] + a_sq[1] + a_sq[2] + a_sq[3];
    let mut b_sq_total = b_sq[0] + b_sq[1] + b_sq[2] + b_sq[3];
    for i in head..len {
        let (x, y) = (a[i], b[i]);
        dot_total += x * y;
        a_sq_total += x * x;
        b_sq_total += y * y;
    }
    let norm_a = a_sq_total.sqrt();
    let norm_b = b_sq_total.sqrt();
    // Guard against division by (near-)zero norms.
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot_total / (norm_a * norm_b)
    }
}
/// Unrolled scalar euclidean distance (4x unroll)
///
/// Portable fallback: four independent accumulator lanes over the squared
/// differences, folded with a scalar tail before the final square root.
fn euclidean_distance_unrolled(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    let head = len - (len % 4);
    // Lane-wise accumulators for (a - b)^2.
    let mut acc = [0.0f32; 4];
    for (qa, qb) in a[..head].chunks_exact(4).zip(b[..head].chunks_exact(4)) {
        for lane in 0..4 {
            let d = qa[lane] - qb[lane];
            acc[lane] += d * d;
        }
    }
    // Fold lanes, then the tail elements.
    let mut total = acc[0] + acc[1] + acc[2] + acc[3];
    for i in head..len {
        let d = a[i] - b[i];
        total += d * d;
    }
    total.sqrt()
}
// =============================================================================
// BATCH OPERATIONS
// =============================================================================
/// Batch compute distances from query to all database vectors
///
/// Uses SIMD for individual distances and benefits from cache locality.
/// Output order matches `database` order.
pub fn batch_distances(query: &[f32], database: &[Vec<f32>]) -> Vec<f32> {
    let mut distances = Vec::with_capacity(database.len());
    for candidate in database {
        distances.push(euclidean_distance_simd(query, candidate));
    }
    distances
}
/// Batch compute cosine similarities
///
/// One SIMD cosine similarity per database vector; output order matches
/// `database` order.
pub fn batch_cosine_similarities(query: &[f32], database: &[Vec<f32>]) -> Vec<f32> {
    let mut similarities = Vec::with_capacity(database.len());
    for candidate in database {
        similarities.push(cosine_similarity_simd(query, candidate));
    }
    similarities
}
// =============================================================================
// TESTS
// =============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-difference float comparison helper.
    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        (a - b).abs() < epsilon
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = cosine_similarity_simd(&a, &a);
        assert!(approx_eq(result, 1.0, 1e-5), "Expected 1.0, got {}", result);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        let result = cosine_similarity_simd(&a, &b);
        assert!(approx_eq(result, 0.0, 1e-5), "Expected 0.0, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_same() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = euclidean_distance_simd(&a, &a);
        assert!(approx_eq(result, 0.0, 1e-5), "Expected 0.0, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_known() {
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0, 0.0];
        let result = euclidean_distance_simd(&a, &b);
        assert!(approx_eq(result, 5.0, 1e-5), "Expected 5.0, got {}", result);
    }

    #[test]
    fn test_large_vectors() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();
        let cos = cosine_similarity_simd(&a, &b);
        let dist = euclidean_distance_simd(&a, &b);
        assert!(cos > -1.0 && cos < 1.0);
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_remainder_lengths() {
        // Lengths that are NOT multiples of the SIMD lane width (8 for AVX2,
        // 4 for NEON/scalar) exercise the scalar remainder loops, which the
        // other tests (lengths 4, 8, 768) never hit.
        for len in [1usize, 3, 5, 7, 9, 13, 17] {
            let a: Vec<f32> = (1..=len).map(|i| i as f32).collect();
            let cos = cosine_similarity_simd(&a, &a);
            assert!(approx_eq(cos, 1.0, 1e-5), "len {}: cos {}", len, cos);
            let dist = euclidean_distance_simd(&a, &a);
            assert!(approx_eq(dist, 0.0, 1e-5), "len {}: dist {}", len, dist);
        }
    }

    #[test]
    fn test_simd_matches_scalar() {
        // The SIMD and unrolled-scalar paths must agree up to FP reassociation.
        let a: Vec<f32> = (0..37).map(|i| (i as f32 * 0.37).sin()).collect();
        let b: Vec<f32> = (0..37).map(|i| (i as f32 * 0.71).cos()).collect();
        assert!(approx_eq(
            cosine_similarity_simd(&a, &b),
            cosine_similarity_unrolled(&a, &b),
            1e-4
        ));
        assert!(approx_eq(
            euclidean_distance_simd(&a, &b),
            euclidean_distance_unrolled(&a, &b),
            1e-4
        ));
    }

    #[test]
    fn test_zero_vector_cosine() {
        // Near-zero norm is defined to yield 0.0 rather than NaN.
        let zero = vec![0.0f32; 8];
        let one = vec![1.0f32; 8];
        assert_eq!(cosine_similarity_simd(&zero, &one), 0.0);
    }

    #[test]
    fn test_batch_operations() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let database = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
        ];
        let distances = batch_distances(&query, &database);
        assert_eq!(distances.len(), 3);
        assert!(approx_eq(distances[0], 0.0, 1e-5)); // Same vector
        let similarities = batch_cosine_similarities(&query, &database);
        assert_eq!(similarities.len(), 3);
        assert!(approx_eq(similarities[0], 1.0, 1e-5)); // Same vector
    }
}