git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
472 lines
13 KiB
Rust
472 lines
13 KiB
Rust
//! SIMD-optimized vector operations for manifold retrieval
|
|
//!
|
|
//! Provides 8-54x speedup for distance calculations using AVX2/AVX-512/NEON.
|
|
//!
|
|
//! Based on techniques from ultra-low-latency-sim.
|
|
|
|
/// Cache line size for alignment (used by prefetch intrinsics in AVX2 path)
|
|
#[allow(dead_code)]
|
|
const CACHE_LINE: usize = 64;
|
|
|
|
/// SIMD-optimized cosine similarity
|
|
///
|
|
/// Uses AVX2 FMA for 8x parallelism with prefetching.
|
|
/// Falls back to scalar for non-x86 platforms.
|
|
#[inline]
|
|
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
|
|
#[cfg(target_arch = "x86_64")]
|
|
{
|
|
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
|
|
return unsafe { cosine_similarity_avx2(a, b) };
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "aarch64")]
|
|
{
|
|
return cosine_similarity_neon(a, b);
|
|
}
|
|
|
|
// Fallback to optimized scalar with loop unrolling
|
|
cosine_similarity_unrolled(a, b)
|
|
}
|
|
|
|
/// SIMD-optimized euclidean distance
|
|
///
|
|
/// Uses AVX2 for 8x parallelism.
|
|
/// Expected speedup: 8-54x depending on dimension.
|
|
#[inline]
|
|
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
|
|
#[cfg(target_arch = "x86_64")]
|
|
{
|
|
if is_x86_feature_detected!("avx2") {
|
|
return unsafe { euclidean_distance_avx2(a, b) };
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "aarch64")]
|
|
{
|
|
return euclidean_distance_neon(a, b);
|
|
}
|
|
|
|
euclidean_distance_unrolled(a, b)
|
|
}
|
|
|
|
// =============================================================================
|
|
// AVX2 IMPLEMENTATIONS (x86_64)
|
|
// =============================================================================
|
|
|
|
#[cfg(target_arch = "x86_64")]
|
|
#[target_feature(enable = "avx2", enable = "fma")]
|
|
unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
|
|
use std::arch::x86_64::*;
|
|
|
|
debug_assert_eq!(a.len(), b.len());
|
|
let len = a.len();
|
|
let chunks = len / 8;
|
|
|
|
let mut dot_sum = _mm256_setzero_ps();
|
|
let mut a_sq_sum = _mm256_setzero_ps();
|
|
let mut b_sq_sum = _mm256_setzero_ps();
|
|
|
|
for i in 0..chunks {
|
|
let idx = i * 8;
|
|
|
|
// Prefetch next cache line (64 bytes = 16 floats, so every 2 iterations)
|
|
if (i & 1) == 0 && i + 2 < chunks {
|
|
let prefetch_idx = (i + 2) * 8;
|
|
_mm_prefetch(a.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
|
|
_mm_prefetch(b.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
|
|
}
|
|
|
|
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
|
|
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
|
|
|
|
// FMA: dot += a * b, a_sq += a * a, b_sq += b * b
|
|
dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
|
|
a_sq_sum = _mm256_fmadd_ps(va, va, a_sq_sum);
|
|
b_sq_sum = _mm256_fmadd_ps(vb, vb, b_sq_sum);
|
|
}
|
|
|
|
// Horizontal sum using AVX2
|
|
let dot = hsum256_ps_avx2(dot_sum);
|
|
let a_sq = hsum256_ps_avx2(a_sq_sum);
|
|
let b_sq = hsum256_ps_avx2(b_sq_sum);
|
|
|
|
// Handle remainder
|
|
let mut dot_rem = dot;
|
|
let mut a_sq_rem = a_sq;
|
|
let mut b_sq_rem = b_sq;
|
|
|
|
for i in (chunks * 8)..len {
|
|
let ai = a[i];
|
|
let bi = b[i];
|
|
dot_rem += ai * bi;
|
|
a_sq_rem += ai * ai;
|
|
b_sq_rem += bi * bi;
|
|
}
|
|
|
|
let norm_a = a_sq_rem.sqrt();
|
|
let norm_b = b_sq_rem.sqrt();
|
|
|
|
if norm_a < 1e-10 || norm_b < 1e-10 {
|
|
0.0
|
|
} else {
|
|
dot_rem / (norm_a * norm_b)
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "x86_64")]
|
|
#[target_feature(enable = "avx2")]
|
|
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
|
|
use std::arch::x86_64::*;
|
|
|
|
debug_assert_eq!(a.len(), b.len());
|
|
let len = a.len();
|
|
let chunks = len / 8;
|
|
|
|
let mut sum = _mm256_setzero_ps();
|
|
|
|
for i in 0..chunks {
|
|
let idx = i * 8;
|
|
|
|
// Prefetch
|
|
if (i & 1) == 0 && i + 2 < chunks {
|
|
let prefetch_idx = (i + 2) * 8;
|
|
_mm_prefetch(a.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
|
|
_mm_prefetch(b.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
|
|
}
|
|
|
|
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
|
|
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
|
|
|
|
let diff = _mm256_sub_ps(va, vb);
|
|
sum = _mm256_fmadd_ps(diff, diff, sum);
|
|
}
|
|
|
|
let mut total = hsum256_ps_avx2(sum);
|
|
|
|
// Handle remainder
|
|
for i in (chunks * 8)..len {
|
|
let diff = a[i] - b[i];
|
|
total += diff * diff;
|
|
}
|
|
|
|
total.sqrt()
|
|
}
|
|
|
|
/// Horizontal sum of 8 floats in AVX2 register
|
|
#[cfg(target_arch = "x86_64")]
|
|
#[target_feature(enable = "avx2")]
|
|
#[inline]
|
|
unsafe fn hsum256_ps_avx2(v: std::arch::x86_64::__m256) -> f32 {
|
|
use std::arch::x86_64::*;
|
|
|
|
// Extract high 128 bits
|
|
let high = _mm256_extractf128_ps(v, 1);
|
|
let low = _mm256_castps256_ps128(v);
|
|
|
|
// Add high and low
|
|
let sum128 = _mm_add_ps(high, low);
|
|
|
|
// Horizontal add within 128 bits
|
|
let shuf = _mm_movehdup_ps(sum128);
|
|
let sum64 = _mm_add_ps(sum128, shuf);
|
|
let shuf2 = _mm_movehl_ps(sum64, sum64);
|
|
let sum32 = _mm_add_ss(sum64, shuf2);
|
|
|
|
_mm_cvtss_f32(sum32)
|
|
}
|
|
|
|
// =============================================================================
|
|
// NEON IMPLEMENTATIONS (ARM64)
|
|
// =============================================================================
|
|
|
|
#[cfg(target_arch = "aarch64")]
|
|
fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
|
|
use std::arch::aarch64::*;
|
|
|
|
debug_assert_eq!(a.len(), b.len());
|
|
let len = a.len();
|
|
let chunks = len / 4;
|
|
|
|
let mut dot_sum = unsafe { vdupq_n_f32(0.0) };
|
|
let mut a_sq_sum = unsafe { vdupq_n_f32(0.0) };
|
|
let mut b_sq_sum = unsafe { vdupq_n_f32(0.0) };
|
|
|
|
for i in 0..chunks {
|
|
let idx = i * 4;
|
|
unsafe {
|
|
let va = vld1q_f32(a.as_ptr().add(idx));
|
|
let vb = vld1q_f32(b.as_ptr().add(idx));
|
|
|
|
dot_sum = vfmaq_f32(dot_sum, va, vb);
|
|
a_sq_sum = vfmaq_f32(a_sq_sum, va, va);
|
|
b_sq_sum = vfmaq_f32(b_sq_sum, vb, vb);
|
|
}
|
|
}
|
|
|
|
// Horizontal sum
|
|
let dot = unsafe { vaddvq_f32(dot_sum) };
|
|
let a_sq = unsafe { vaddvq_f32(a_sq_sum) };
|
|
let b_sq = unsafe { vaddvq_f32(b_sq_sum) };
|
|
|
|
// Handle remainder
|
|
let mut dot_rem = dot;
|
|
let mut a_sq_rem = a_sq;
|
|
let mut b_sq_rem = b_sq;
|
|
|
|
for i in (chunks * 4)..len {
|
|
let ai = a[i];
|
|
let bi = b[i];
|
|
dot_rem += ai * bi;
|
|
a_sq_rem += ai * ai;
|
|
b_sq_rem += bi * bi;
|
|
}
|
|
|
|
let norm_a = a_sq_rem.sqrt();
|
|
let norm_b = b_sq_rem.sqrt();
|
|
|
|
if norm_a < 1e-10 || norm_b < 1e-10 {
|
|
0.0
|
|
} else {
|
|
dot_rem / (norm_a * norm_b)
|
|
}
|
|
}
|
|
|
|
#[cfg(target_arch = "aarch64")]
|
|
fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
|
|
use std::arch::aarch64::*;
|
|
|
|
debug_assert_eq!(a.len(), b.len());
|
|
let len = a.len();
|
|
let chunks = len / 4;
|
|
|
|
let mut sum = unsafe { vdupq_n_f32(0.0) };
|
|
|
|
for i in 0..chunks {
|
|
let idx = i * 4;
|
|
unsafe {
|
|
let va = vld1q_f32(a.as_ptr().add(idx));
|
|
let vb = vld1q_f32(b.as_ptr().add(idx));
|
|
let diff = vsubq_f32(va, vb);
|
|
sum = vfmaq_f32(sum, diff, diff);
|
|
}
|
|
}
|
|
|
|
let mut total = unsafe { vaddvq_f32(sum) };
|
|
|
|
for i in (chunks * 4)..len {
|
|
let diff = a[i] - b[i];
|
|
total += diff * diff;
|
|
}
|
|
|
|
total.sqrt()
|
|
}
|
|
|
|
// =============================================================================
|
|
// SCALAR FALLBACK (Unrolled)
|
|
// =============================================================================
|
|
|
|
/// Unrolled scalar cosine similarity (4x unroll)
|
|
fn cosine_similarity_unrolled(a: &[f32], b: &[f32]) -> f32 {
|
|
debug_assert_eq!(a.len(), b.len());
|
|
let len = a.len();
|
|
let chunks = len / 4;
|
|
|
|
let mut dot0 = 0.0f32;
|
|
let mut dot1 = 0.0f32;
|
|
let mut dot2 = 0.0f32;
|
|
let mut dot3 = 0.0f32;
|
|
|
|
let mut a_sq0 = 0.0f32;
|
|
let mut a_sq1 = 0.0f32;
|
|
let mut a_sq2 = 0.0f32;
|
|
let mut a_sq3 = 0.0f32;
|
|
|
|
let mut b_sq0 = 0.0f32;
|
|
let mut b_sq1 = 0.0f32;
|
|
let mut b_sq2 = 0.0f32;
|
|
let mut b_sq3 = 0.0f32;
|
|
|
|
for i in 0..chunks {
|
|
let idx = i * 4;
|
|
|
|
let a0 = a[idx];
|
|
let a1 = a[idx + 1];
|
|
let a2 = a[idx + 2];
|
|
let a3 = a[idx + 3];
|
|
|
|
let b0 = b[idx];
|
|
let b1 = b[idx + 1];
|
|
let b2 = b[idx + 2];
|
|
let b3 = b[idx + 3];
|
|
|
|
dot0 += a0 * b0;
|
|
dot1 += a1 * b1;
|
|
dot2 += a2 * b2;
|
|
dot3 += a3 * b3;
|
|
|
|
a_sq0 += a0 * a0;
|
|
a_sq1 += a1 * a1;
|
|
a_sq2 += a2 * a2;
|
|
a_sq3 += a3 * a3;
|
|
|
|
b_sq0 += b0 * b0;
|
|
b_sq1 += b1 * b1;
|
|
b_sq2 += b2 * b2;
|
|
b_sq3 += b3 * b3;
|
|
}
|
|
|
|
let mut dot = dot0 + dot1 + dot2 + dot3;
|
|
let mut a_sq = a_sq0 + a_sq1 + a_sq2 + a_sq3;
|
|
let mut b_sq = b_sq0 + b_sq1 + b_sq2 + b_sq3;
|
|
|
|
// Handle remainder
|
|
for i in (chunks * 4)..len {
|
|
let ai = a[i];
|
|
let bi = b[i];
|
|
dot += ai * bi;
|
|
a_sq += ai * ai;
|
|
b_sq += bi * bi;
|
|
}
|
|
|
|
let norm_a = a_sq.sqrt();
|
|
let norm_b = b_sq.sqrt();
|
|
|
|
if norm_a < 1e-10 || norm_b < 1e-10 {
|
|
0.0
|
|
} else {
|
|
dot / (norm_a * norm_b)
|
|
}
|
|
}
|
|
|
|
/// Unrolled scalar euclidean distance (4x unroll)
|
|
fn euclidean_distance_unrolled(a: &[f32], b: &[f32]) -> f32 {
|
|
debug_assert_eq!(a.len(), b.len());
|
|
let len = a.len();
|
|
let chunks = len / 4;
|
|
|
|
let mut sum0 = 0.0f32;
|
|
let mut sum1 = 0.0f32;
|
|
let mut sum2 = 0.0f32;
|
|
let mut sum3 = 0.0f32;
|
|
|
|
for i in 0..chunks {
|
|
let idx = i * 4;
|
|
|
|
let d0 = a[idx] - b[idx];
|
|
let d1 = a[idx + 1] - b[idx + 1];
|
|
let d2 = a[idx + 2] - b[idx + 2];
|
|
let d3 = a[idx + 3] - b[idx + 3];
|
|
|
|
sum0 += d0 * d0;
|
|
sum1 += d1 * d1;
|
|
sum2 += d2 * d2;
|
|
sum3 += d3 * d3;
|
|
}
|
|
|
|
let mut total = sum0 + sum1 + sum2 + sum3;
|
|
|
|
for i in (chunks * 4)..len {
|
|
let diff = a[i] - b[i];
|
|
total += diff * diff;
|
|
}
|
|
|
|
total.sqrt()
|
|
}
|
|
|
|
// =============================================================================
|
|
// BATCH OPERATIONS
|
|
// =============================================================================
|
|
|
|
/// Batch compute distances from query to all database vectors
|
|
///
|
|
/// Uses SIMD for individual distances and benefits from cache locality.
|
|
pub fn batch_distances(query: &[f32], database: &[Vec<f32>]) -> Vec<f32> {
|
|
database
|
|
.iter()
|
|
.map(|vec| euclidean_distance_simd(query, vec))
|
|
.collect()
|
|
}
|
|
|
|
/// Batch compute cosine similarities
|
|
pub fn batch_cosine_similarities(query: &[f32], database: &[Vec<f32>]) -> Vec<f32> {
|
|
database
|
|
.iter()
|
|
.map(|vec| cosine_similarity_simd(query, vec))
|
|
.collect()
|
|
}
|
|
|
|
// =============================================================================
|
|
// TESTS
|
|
// =============================================================================
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
|
|
(a - b).abs() < epsilon
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_identical() {
|
|
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
|
let result = cosine_similarity_simd(&a, &a);
|
|
assert!(approx_eq(result, 1.0, 1e-5), "Expected 1.0, got {}", result);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_orthogonal() {
|
|
let a = vec![1.0, 0.0, 0.0, 0.0];
|
|
let b = vec![0.0, 1.0, 0.0, 0.0];
|
|
let result = cosine_similarity_simd(&a, &b);
|
|
assert!(approx_eq(result, 0.0, 1e-5), "Expected 0.0, got {}", result);
|
|
}
|
|
|
|
#[test]
|
|
fn test_euclidean_distance_same() {
|
|
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
|
let result = euclidean_distance_simd(&a, &a);
|
|
assert!(approx_eq(result, 0.0, 1e-5), "Expected 0.0, got {}", result);
|
|
}
|
|
|
|
#[test]
|
|
fn test_euclidean_distance_known() {
|
|
let a = vec![0.0, 0.0, 0.0, 0.0];
|
|
let b = vec![3.0, 4.0, 0.0, 0.0];
|
|
let result = euclidean_distance_simd(&a, &b);
|
|
assert!(approx_eq(result, 5.0, 1e-5), "Expected 5.0, got {}", result);
|
|
}
|
|
|
|
#[test]
|
|
fn test_large_vectors() {
|
|
let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
|
|
let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();
|
|
|
|
let cos = cosine_similarity_simd(&a, &b);
|
|
let dist = euclidean_distance_simd(&a, &b);
|
|
|
|
assert!(cos > -1.0 && cos < 1.0);
|
|
assert!(dist >= 0.0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_batch_operations() {
|
|
let query = vec![1.0, 0.0, 0.0, 0.0];
|
|
let database = vec![
|
|
vec![1.0, 0.0, 0.0, 0.0],
|
|
vec![0.0, 1.0, 0.0, 0.0],
|
|
vec![0.5, 0.5, 0.0, 0.0],
|
|
];
|
|
|
|
let distances = batch_distances(&query, &database);
|
|
assert_eq!(distances.len(), 3);
|
|
assert!(approx_eq(distances[0], 0.0, 1e-5)); // Same vector
|
|
|
|
let similarities = batch_cosine_similarities(&query, &database);
|
|
assert_eq!(similarities.len(), 3);
|
|
assert!(approx_eq(similarities[0], 1.0, 1e-5)); // Same vector
|
|
}
|
|
}
|