Files
wifi-densepose/docs/research/latent-space/implementation-plans/agents/14-simd-optimizations.md
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

34 KiB

Agent 14: SIMD Optimizations - Implementation Plan

Overview

This implementation plan covers SIMD-optimized kernels for critical operations in the GNN latent space system, targeting 4-8x speedup for vector operations and 2-4x for attention mechanisms through platform-specific optimizations.

Architecture

src/simd/
├── mod.rs                 # Public API and feature detection
├── kernels.rs             # High-level kernel interface
├── x86/
│   ├── mod.rs
│   ├── avx2.rs           # AVX2 optimizations
│   ├── sse.rs            # SSE fallback
│   └── fma.rs            # FMA instructions
├── arm/
│   ├── mod.rs
│   └── neon.rs           # ARM NEON optimizations
├── wasm/
│   ├── mod.rs
│   └── simd128.rs        # WASM SIMD128
└── fallback/
    └── scalar.rs          # Portable fallback

1. AVX2/SSE Optimizations

1.1 SIMD Dot Product (AVX2)

// src/simd/x86/avx2.rs

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// AVX2-optimized dot product for f32 vectors
///
/// Processes 8 lanes per iteration with fused multiply-add, reduces the
/// vector accumulator horizontally, then finishes the `len % 8` tail with
/// scalar code.
///
/// # Safety
/// The caller must ensure AVX2 and FMA are available (the safe dispatch
/// wrapper performs this runtime check before calling) and that
/// `a.len() == b.len()` — lengths are only `debug_assert!`-checked, so a
/// mismatch in release builds reads out of bounds.
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn simd_dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time (256-bit registers)
    let chunks = len / 8;
    let remainder = len % 8;

    for i in 0..chunks {
        let offset = i * 8;

        // Load 8 floats from each array (unaligned load: no alignment requirement)
        let va = _mm256_loadu_ps(a.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.as_ptr().add(offset));

        // Fused multiply-add: sum += a * b
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum of 8 lanes
    let mut result = horizontal_sum_avx2(sum);

    // Handle remainder with scalar operations
    for i in (len - remainder)..len {
        result += a[i] * b[i];
    }

    result
}

/// Horizontal sum of AVX2 register
///
/// Reduces the 8 f32 lanes of `v` to one scalar: fold the high 128-bit half
/// onto the low half, then pairwise-add down to lane 0.
/// (`_mm_movehdup_ps` is SSE3; fine here because AVX2 implies SSE3.)
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // Sum high and low 128-bit lanes
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(high, low);

    // Sum 4 lanes to 2: [a,b,c,d] + [b,b,d,d] -> lane0 = a+b, lane2 = c+d
    let shuf = _mm_movehdup_ps(sum128);
    let sum2 = _mm_add_ps(sum128, shuf);

    // Sum 2 lanes to 1: bring lane2 (c+d) down and add to lane0 (a+b)
    let shuf = _mm_movehl_ps(shuf, sum2);
    let sum1 = _mm_add_ss(sum2, shuf);

    _mm_cvtss_f32(sum1)
}

/// AVX2-optimized weighted sum
///
/// Computes `output[i] = Σ_k weights[k] * vectors[k][i]`, overwriting any
/// previous contents of `output`.
///
/// # Safety
/// Caller must guarantee AVX2 + FMA support and that every vector has
/// `output.len()` elements (only `debug_assert!`-checked in release builds).
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn simd_weighted_sum_avx2(
    vectors: &[&[f32]],
    weights: &[f32],
    output: &mut [f32],
) {
    debug_assert_eq!(vectors.len(), weights.len());

    let dim = output.len();

    // Zero the *entire* output, including the tail past the last full 8-lane
    // chunk. The previous SIMD-only zeroing loop left the remainder elements
    // stale, and the scalar tail below accumulated onto that garbage.
    output.fill(0.0);

    // Accumulate each vector scaled by its weight.
    for (vec, &weight) in vectors.iter().zip(weights.iter()) {
        debug_assert_eq!(vec.len(), dim);

        // Broadcast weight to all 8 lanes
        let vweight = _mm256_set1_ps(weight);

        // Process 8 floats at a time
        for i in 0..(dim / 8) {
            let offset = i * 8;

            let vvec = _mm256_loadu_ps(vec.as_ptr().add(offset));
            let vout = _mm256_loadu_ps(output.as_ptr().add(offset));

            // output += vec * weight
            let result = _mm256_fmadd_ps(vvec, vweight, vout);
            _mm256_storeu_ps(output.as_mut_ptr().add(offset), result);
        }
    }

    // Scalar tail for the final dim % 8 elements.
    let remainder_start = (dim / 8) * 8;
    for i in remainder_start..dim {
        for (vec, &weight) in vectors.iter().zip(weights.iter()) {
            output[i] += vec[i] * weight;
        }
    }
}

/// AVX2-optimized softmax
///
/// Standard max-shifted softmax: `output[i] = exp(input[i] - max) / Σ exp`.
/// The 8-lane chunks use `simd_exp_avx2`; the `len % 8` tail uses the exact
/// scalar `f32::exp`.
///
/// NOTE(review): correctness of the SIMD lanes depends on `simd_exp_avx2`
/// being accurate over the shifted inputs `input[i] - max` (always <= 0,
/// possibly very negative). Confirm its approximation handles that range,
/// otherwise SIMD lanes and the scalar tail disagree.
///
/// # Safety
/// Caller must guarantee AVX2 + FMA support and equal slice lengths
/// (lengths are only `debug_assert!`-checked).
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn simd_softmax_avx2(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    let len = input.len();

    // Find max value for numerical stability
    let max_val = simd_max_avx2(input);
    let vmax = _mm256_set1_ps(max_val);

    // Compute exp(x - max) and sum
    let mut sum = _mm256_setzero_ps();

    for i in 0..(len / 8) {
        let offset = i * 8;
        let vx = _mm256_loadu_ps(input.as_ptr().add(offset));

        // x - max
        let vx_shifted = _mm256_sub_ps(vx, vmax);

        // exp(x - max)
        let vexp = simd_exp_avx2(vx_shifted);
        _mm256_storeu_ps(output.as_mut_ptr().add(offset), vexp);

        sum = _mm256_add_ps(sum, vexp);
    }

    // Sum remainder (scalar tail uses libm exp, not the approximation)
    let mut sum_scalar = horizontal_sum_avx2(sum);
    let remainder_start = (len / 8) * 8;
    for i in remainder_start..len {
        let exp_val = (input[i] - max_val).exp();
        output[i] = exp_val;
        sum_scalar += exp_val;
    }

    // Divide by sum
    let vsum = _mm256_set1_ps(sum_scalar);
    for i in 0..(len / 8) {
        let offset = i * 8;
        let vexp = _mm256_loadu_ps(output.as_ptr().add(offset));
        let result = _mm256_div_ps(vexp, vsum);
        _mm256_storeu_ps(output.as_mut_ptr().add(offset), result);
    }

    for i in remainder_start..len {
        output[i] /= sum_scalar;
    }
}

/// Fast AVX2 exponential approximation
///
/// Uses range reduction — `exp(x) = 2^n * exp(r)` with `|r| <= ln(2)/2` —
/// followed by a degree-5 polynomial in `r`. The previous version evaluated
/// the Taylor series directly in `x`, which diverges (and can even go
/// negative) for the large negative arguments softmax feeds it.
/// Inputs are clamped to `[-87, 88]` so the `2^n` scaling stays a normal f32.
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn simd_exp_avx2(x: __m256) -> __m256 {
    // Clamp to the range where 2^n is representable without over/underflow.
    let x = _mm256_min_ps(_mm256_set1_ps(88.0), _mm256_max_ps(x, _mm256_set1_ps(-87.0)));

    // n = round(x / ln 2); cvtps_epi32 rounds to nearest under default MXCSR.
    let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let n_i = _mm256_cvtps_epi32(_mm256_mul_ps(x, log2e));
    let n = _mm256_cvtepi32_ps(n_i);

    // r = x - n*ln2, with ln2 split into hi/lo parts for extra precision.
    let ln2_hi = _mm256_set1_ps(0.693_359_375);
    let ln2_lo = _mm256_set1_ps(-2.121_944_4e-4);
    let r = _mm256_fnmadd_ps(n, ln2_hi, x);
    let r = _mm256_fnmadd_ps(n, ln2_lo, r);

    // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24 + r⁵/120, Horner form.
    // With |r| <= 0.347 the truncation error is ~1e-6 relative.
    let c5 = _mm256_set1_ps(1.0 / 120.0);
    let c4 = _mm256_set1_ps(1.0 / 24.0);
    let c3 = _mm256_set1_ps(1.0 / 6.0);
    let c2 = _mm256_set1_ps(0.5);
    let one = _mm256_set1_ps(1.0);

    let mut p = _mm256_fmadd_ps(c5, r, c4);
    p = _mm256_fmadd_ps(p, r, c3);
    p = _mm256_fmadd_ps(p, r, c2);
    p = _mm256_fmadd_ps(p, r, one);
    p = _mm256_fmadd_ps(p, r, one);

    // Scale by 2^n by constructing the f32 exponent field directly.
    let pow2 = _mm256_castsi256_ps(_mm256_slli_epi32::<23>(_mm256_add_epi32(
        n_i,
        _mm256_set1_epi32(127),
    )));
    _mm256_mul_ps(p, pow2)
}

/// Find max value with AVX2
///
/// Returns `f32::NEG_INFINITY` for an empty slice: the SIMD accumulator is
/// seeded with negative infinity so untouched lanes never bias the result.
#[target_feature(enable = "avx2")]
unsafe fn simd_max_avx2(values: &[f32]) -> f32 {
    let len = values.len();
    let mut vmax = _mm256_set1_ps(f32::NEG_INFINITY);

    for i in 0..(len / 8) {
        let offset = i * 8;
        let v = _mm256_loadu_ps(values.as_ptr().add(offset));
        vmax = _mm256_max_ps(vmax, v);
    }

    // Horizontal max across the 8 accumulated lanes
    let mut max_val = horizontal_max_avx2(vmax);

    // Check remainder (len % 8 scalar tail)
    let remainder_start = (len / 8) * 8;
    for i in remainder_start..len {
        max_val = max_val.max(values[i]);
    }

    max_val
}

/// Horizontal max of an AVX2 register: fold the high 128-bit half onto the
/// low half, then pairwise-max down to lane 0 (mirrors `horizontal_sum_avx2`).
#[target_feature(enable = "avx2")]
unsafe fn horizontal_max_avx2(v: __m256) -> f32 {
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let max128 = _mm_max_ps(high, low);

    // 4 lanes -> 2
    let shuf = _mm_movehdup_ps(max128);
    let max2 = _mm_max_ps(max128, shuf);

    // 2 lanes -> 1
    let shuf = _mm_movehl_ps(shuf, max2);
    let max1 = _mm_max_ss(max2, shuf);

    _mm_cvtss_f32(max1)
}

1.2 SSE Fallback

// src/simd/x86/sse.rs

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SSE-optimized dot product (fallback for older CPUs)
///
/// Processes 4 lanes per iteration using separate multiply and add (no FMA
/// on SSE-era hardware), then handles the `len % 4` tail with scalar code.
///
/// # Safety
/// Caller must guarantee SSE support (the safe dispatch wrapper checks this
/// at runtime) and `a.len() == b.len()` — lengths are only
/// `debug_assert!`-checked in release builds.
#[target_feature(enable = "sse")]
pub unsafe fn simd_dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let len = a.len();
    let mut sum = _mm_setzero_ps();

    // Process 4 floats at a time (128-bit registers)
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let va = _mm_loadu_ps(a.as_ptr().add(offset));
        let vb = _mm_loadu_ps(b.as_ptr().add(offset));
        let prod = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, prod);
    }

    // Horizontal sum
    let mut result = horizontal_sum_sse(sum);

    // Handle remainder
    let remainder = len % 4;
    for i in (len - remainder)..len {
        result += a[i] * b[i];
    }

    result
}

/// Horizontal sum of an SSE register (4 f32 lanes to one scalar).
///
/// Uses only SSE1 instructions. The previous version used `_mm_movehdup_ps`,
/// which is SSE3 — undefined behavior on the pre-SSE3 "older CPUs" this
/// fallback explicitly targets, since the function only enables `sse`.
#[target_feature(enable = "sse")]
unsafe fn horizontal_sum_sse(v: __m128) -> f32 {
    // Swap adjacent pairs: [a,b,c,d] -> [b,a,d,c] (SSE1 shuffle).
    let shuf = _mm_shuffle_ps::<0b10_11_00_01>(v, v);
    // [a+b, a+b, c+d, c+d]
    let sums = _mm_add_ps(v, shuf);
    // Move the high pair down so lane 0 holds c+d.
    let high = _mm_movehl_ps(sums, sums);
    // Lane 0: (a+b) + (c+d)
    let total = _mm_add_ss(sums, high);
    _mm_cvtss_f32(total)
}

2. ARM NEON Optimizations

2.1 NEON Implementations

// src/simd/arm/neon.rs

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// NEON-optimized dot product for ARM64
///
/// Processes 4 lanes per iteration with fused multiply-add (`vfmaq_f32`),
/// reduces with the across-vector add `vaddvq_f32`, and finishes the
/// `len % 4` tail with scalar code.
///
/// # Safety
/// Caller must guarantee NEON support (the safe dispatch wrapper checks this
/// at runtime) and `a.len() == b.len()` — lengths are only
/// `debug_assert!`-checked in release builds.
#[target_feature(enable = "neon")]
pub unsafe fn simd_dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    // Process 4 floats at a time (128-bit registers)
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;

        let va = vld1q_f32(a.as_ptr().add(offset));
        let vb = vld1q_f32(b.as_ptr().add(offset));

        // Fused multiply-add
        sum = vfmaq_f32(sum, va, vb);
    }

    // Horizontal sum
    let mut result = vaddvq_f32(sum);

    // Handle remainder
    let remainder = len % 4;
    for i in (len - remainder)..len {
        result += a[i] * b[i];
    }

    result
}

/// NEON-optimized weighted sum
///
/// Computes `output[i] = Σ_k weights[k] * vectors[k][i]`, overwriting any
/// previous contents of `output`.
///
/// # Safety
/// Caller must guarantee NEON support and that every vector has
/// `output.len()` elements (only `debug_assert!`-checked in release builds).
#[target_feature(enable = "neon")]
pub unsafe fn simd_weighted_sum_neon(
    vectors: &[&[f32]],
    weights: &[f32],
    output: &mut [f32],
) {
    debug_assert_eq!(vectors.len(), weights.len());

    let dim = output.len();

    // Zero the *entire* output, including the tail past the last full 4-lane
    // chunk. The previous SIMD-only zeroing loop left the remainder elements
    // stale, and the scalar tail below accumulated onto that garbage.
    output.fill(0.0);

    // Process each vector
    for (vec, &weight) in vectors.iter().zip(weights.iter()) {
        debug_assert_eq!(vec.len(), dim);

        let vweight = vdupq_n_f32(weight);

        for i in 0..(dim / 4) {
            let offset = i * 4;

            let vvec = vld1q_f32(vec.as_ptr().add(offset));
            let vout = vld1q_f32(output.as_ptr().add(offset));

            // output += vec * weight
            let result = vfmaq_f32(vout, vvec, vweight);
            vst1q_f32(output.as_mut_ptr().add(offset), result);
        }
    }

    // Scalar tail for the final dim % 4 elements.
    let remainder_start = (dim / 4) * 4;
    for i in remainder_start..dim {
        for (vec, &weight) in vectors.iter().zip(weights.iter()) {
            output[i] += vec[i] * weight;
        }
    }
}

/// NEON-optimized softmax
///
/// Max-shifted softmax. The subtraction, accumulation, and final division
/// run on 4-lane NEON; the exponential is computed per lane with scalar
/// `f32::exp` (NEON has no exp instruction), so SIMD lanes and the scalar
/// tail agree exactly.
///
/// # Safety
/// Caller must guarantee NEON support and equal slice lengths (lengths are
/// only `debug_assert!`-checked).
#[target_feature(enable = "neon")]
pub unsafe fn simd_softmax_neon(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    let len = input.len();

    // Find max
    let max_val = simd_max_neon(input);
    let vmax = vdupq_n_f32(max_val);

    // Compute exp(x - max) and sum
    let mut sum = vdupq_n_f32(0.0);

    for i in 0..(len / 4) {
        let offset = i * 4;
        let vx = vld1q_f32(input.as_ptr().add(offset));
        let vx_shifted = vsubq_f32(vx, vmax);

        // Use scalar exp for now (NEON doesn't have native exp)
        let mut exp_vals = [0.0f32; 4];
        vst1q_f32(exp_vals.as_mut_ptr(), vx_shifted);
        for val in &mut exp_vals {
            *val = val.exp();
        }

        let vexp = vld1q_f32(exp_vals.as_ptr());
        vst1q_f32(output.as_mut_ptr().add(offset), vexp);
        sum = vaddq_f32(sum, vexp);
    }

    let mut sum_scalar = vaddvq_f32(sum);

    // Handle remainder (len % 4 scalar tail)
    let remainder_start = (len / 4) * 4;
    for i in remainder_start..len {
        let exp_val = (input[i] - max_val).exp();
        output[i] = exp_val;
        sum_scalar += exp_val;
    }

    // Divide by sum (vdivq_f32 is AArch64-only, which this module targets)
    let vsum = vdupq_n_f32(sum_scalar);
    for i in 0..(len / 4) {
        let offset = i * 4;
        let vexp = vld1q_f32(output.as_ptr().add(offset));
        let result = vdivq_f32(vexp, vsum);
        vst1q_f32(output.as_mut_ptr().add(offset), result);
    }

    for i in remainder_start..len {
        output[i] /= sum_scalar;
    }
}

/// Find the maximum element with NEON.
///
/// Returns `f32::NEG_INFINITY` for an empty slice: the accumulator is seeded
/// with negative infinity so untouched lanes never bias the result.
#[target_feature(enable = "neon")]
unsafe fn simd_max_neon(values: &[f32]) -> f32 {
    let len = values.len();
    let mut vmax = vdupq_n_f32(f32::NEG_INFINITY);

    for i in 0..(len / 4) {
        let offset = i * 4;
        let v = vld1q_f32(values.as_ptr().add(offset));
        vmax = vmaxq_f32(vmax, v);
    }

    // Across-vector max, then scalar tail for len % 4 elements
    let mut max_val = vmaxvq_f32(vmax);

    let remainder_start = (len / 4) * 4;
    for i in remainder_start..len {
        max_val = max_val.max(values[i]);
    }

    max_val
}

3. WASM SIMD Support

3.1 WASM SIMD128

// src/simd/wasm/simd128.rs

#[cfg(target_arch = "wasm32")]
use std::arch::wasm32::*;

/// WASM SIMD128-optimized dot product
///
/// Processes 4 lanes per iteration (multiply then add — SIMD128 has no FMA),
/// extracts and sums the four lanes, then finishes the `len % 4` tail with
/// scalar code.
///
/// # Safety
/// Requires the `simd128` target feature (enabled at compile time for the
/// wasm build) and `a.len() == b.len()` — lengths are only
/// `debug_assert!`-checked in release builds.
#[target_feature(enable = "simd128")]
pub unsafe fn simd_dot_product_wasm(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let len = a.len();
    let mut sum = f32x4_splat(0.0);

    // Process 4 floats at a time
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;

        let va = v128_load(a.as_ptr().add(offset) as *const v128);
        let vb = v128_load(b.as_ptr().add(offset) as *const v128);

        let prod = f32x4_mul(va, vb);
        sum = f32x4_add(sum, prod);
    }

    // Extract and sum 4 lanes
    let mut result = f32x4_extract_lane::<0>(sum)
        + f32x4_extract_lane::<1>(sum)
        + f32x4_extract_lane::<2>(sum)
        + f32x4_extract_lane::<3>(sum);

    // Handle remainder
    let remainder = len % 4;
    for i in (len - remainder)..len {
        result += a[i] * b[i];
    }

    result
}

/// WASM SIMD128-optimized weighted sum
///
/// Computes `output[i] = Σ_k weights[k] * vectors[k][i]`, overwriting any
/// previous contents of `output`.
///
/// # Safety
/// Requires the `simd128` target feature; every vector must have
/// `output.len()` elements (only `debug_assert!`-checked in release builds).
#[target_feature(enable = "simd128")]
pub unsafe fn simd_weighted_sum_wasm(
    vectors: &[&[f32]],
    weights: &[f32],
    output: &mut [f32],
) {
    debug_assert_eq!(vectors.len(), weights.len());

    let dim = output.len();

    // Zero the *entire* output, including the tail past the last full 4-lane
    // chunk. The previous SIMD-only zeroing loop left the remainder elements
    // stale, and the scalar tail below accumulated onto that garbage.
    output.fill(0.0);

    // Process each vector
    for (vec, &weight) in vectors.iter().zip(weights.iter()) {
        debug_assert_eq!(vec.len(), dim);

        let vweight = f32x4_splat(weight);

        for i in 0..(dim / 4) {
            let offset = i * 4;

            let vvec = v128_load(vec.as_ptr().add(offset) as *const v128);
            let vout = v128_load(output.as_ptr().add(offset) as *const v128);

            let weighted = f32x4_mul(vvec, vweight);
            let result = f32x4_add(vout, weighted);

            v128_store(output.as_mut_ptr().add(offset) as *mut v128, result);
        }
    }

    // Scalar tail for the final dim % 4 elements.
    let remainder_start = (dim / 4) * 4;
    for i in remainder_start..dim {
        for (vec, &weight) in vectors.iter().zip(weights.iter()) {
            output[i] += vec[i] * weight;
        }
    }
}

4. Cross-Platform Abstraction Layer

4.1 Runtime Feature Detection

// src/simd/mod.rs

use std::sync::OnceLock;

/// SIMD instruction-set tier selected for kernel dispatch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdCapability {
    /// x86_64 with AVX2 and FMA (8-lane f32 kernels).
    Avx2Fma,
    /// x86_64 with SSE only (4-lane f32 kernels).
    Sse,
    /// AArch64 NEON (4-lane f32 kernels).
    Neon,
    /// WebAssembly SIMD128 (4-lane f32 kernels).
    WasmSimd128,
    /// No SIMD available; portable scalar fallback.
    Scalar,
}

// Probed once and memoized; all kernel dispatch reads this cached value.
static SIMD_CAPABILITY: OnceLock<SimdCapability> = OnceLock::new();

/// Detect available SIMD instructions at runtime
///
/// x86_64 and AArch64 probe CPU features at runtime; WASM SIMD128 can only
/// be known at compile time via `target_feature = "simd128"`. The result is
/// cached in a `OnceLock`, so repeated calls are cheap.
pub fn detect_simd_capability() -> SimdCapability {
    *SIMD_CAPABILITY.get_or_init(|| {
        #[cfg(target_arch = "x86_64")]
        {
            // Prefer AVX2+FMA (8-lane FMA kernels), then plain SSE.
            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                return SimdCapability::Avx2Fma;
            }
            if is_x86_feature_detected!("sse") {
                return SimdCapability::Sse;
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                return SimdCapability::Neon;
            }
        }

        #[cfg(target_arch = "wasm32")]
        {
            // Compile-time check: wasm has no runtime feature probe.
            #[cfg(target_feature = "simd128")]
            return SimdCapability::WasmSimd128;
        }

        SimdCapability::Scalar
    })
}

/// Get human-readable SIMD capability name
pub fn simd_capability_name() -> &'static str {
    match detect_simd_capability() {
        SimdCapability::Avx2Fma => "AVX2 + FMA",
        SimdCapability::Sse => "SSE",
        SimdCapability::Neon => "NEON",
        SimdCapability::WasmSimd128 => "WASM SIMD128",
        SimdCapability::Scalar => "Scalar (no SIMD)",
    }
}

4.2 Safe Public API

// src/simd/kernels.rs

use super::*;

/// Safe wrapper for SIMD dot product with runtime dispatch
///
/// Validates slice lengths, then dispatches to the best kernel for the
/// detected CPU capability. The `unsafe` calls are sound because the
/// matching capability was verified by `detect_simd_capability`.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "dot_product: vectors must have same length"
    );

    if a.is_empty() {
        return 0.0;
    }

    match detect_simd_capability() {
        #[cfg(target_arch = "x86_64")]
        SimdCapability::Avx2Fma => unsafe {
            // SAFETY: AVX2 + FMA presence confirmed by runtime detection.
            x86::avx2::simd_dot_product_avx2(a, b)
        },

        #[cfg(target_arch = "x86_64")]
        SimdCapability::Sse => unsafe {
            // SAFETY: SSE presence confirmed by runtime detection.
            x86::sse::simd_dot_product_sse(a, b)
        },

        #[cfg(target_arch = "aarch64")]
        SimdCapability::Neon => unsafe {
            // SAFETY: NEON presence confirmed by runtime detection.
            arm::neon::simd_dot_product_neon(a, b)
        },

        #[cfg(target_arch = "wasm32")]
        SimdCapability::WasmSimd128 => unsafe {
            // SAFETY: SIMD128 enabled at compile time on this target.
            wasm::simd128::simd_dot_product_wasm(a, b)
        },

        // Scalar capability, plus any variant not compiled for this target.
        _ => fallback::scalar::dot_product_scalar(a, b),
    }
}

/// Safe wrapper for SIMD weighted sum
///
/// Computes `output = Σ_k weights[k] * vectors[k]`, overwriting `output`,
/// and dispatches to the best kernel for the detected CPU capability.
///
/// # Panics
/// Panics if `vectors` and `weights` differ in length, or if any vector's
/// length differs from `output.len()`.
pub fn weighted_sum(vectors: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    assert_eq!(
        vectors.len(),
        weights.len(),
        "weighted_sum: vectors and weights must have same length"
    );

    if vectors.is_empty() {
        // An empty weighted sum is the zero vector. Clear any stale contents
        // so this case matches the scalar fallback, which zeroes `output`
        // before accumulating; the old bare `return` left `output` untouched.
        output.fill(0.0);
        return;
    }

    let dim = output.len();
    for vec in vectors {
        assert_eq!(
            vec.len(),
            dim,
            "weighted_sum: all vectors must match output dimension"
        );
    }

    match detect_simd_capability() {
        #[cfg(target_arch = "x86_64")]
        SimdCapability::Avx2Fma => unsafe {
            // SAFETY: AVX2 + FMA presence confirmed by runtime detection.
            x86::avx2::simd_weighted_sum_avx2(vectors, weights, output)
        },

        #[cfg(target_arch = "aarch64")]
        SimdCapability::Neon => unsafe {
            // SAFETY: NEON presence confirmed by runtime detection.
            arm::neon::simd_weighted_sum_neon(vectors, weights, output)
        },

        #[cfg(target_arch = "wasm32")]
        SimdCapability::WasmSimd128 => unsafe {
            // SAFETY: SIMD128 enabled at compile time on this target.
            wasm::simd128::simd_weighted_sum_wasm(vectors, weights, output)
        },

        // No SSE weighted-sum kernel exists; SSE-only CPUs also land here.
        _ => fallback::scalar::weighted_sum_scalar(vectors, weights, output),
    }
}

/// Safe wrapper for SIMD softmax
///
/// Dispatches to the best available kernel. There is no SSE softmax kernel,
/// so SSE-only x86 CPUs take the scalar fallback via the `_` arm.
///
/// # Panics
/// Panics if `input` and `output` have different lengths.
pub fn softmax(input: &[f32], output: &mut [f32]) {
    assert_eq!(
        input.len(),
        output.len(),
        "softmax: input and output must have same length"
    );

    if input.is_empty() {
        return;
    }

    match detect_simd_capability() {
        #[cfg(target_arch = "x86_64")]
        SimdCapability::Avx2Fma => unsafe {
            x86::avx2::simd_softmax_avx2(input, output)
        },

        #[cfg(target_arch = "aarch64")]
        SimdCapability::Neon => unsafe {
            arm::neon::simd_softmax_neon(input, output)
        },

        _ => {
            fallback::scalar::softmax_scalar(input, output)
        }
    }
}

/// Batched attention computation with SIMD
///
/// Scaled dot-product attention: for each query row,
/// `scores[k] = (q · k_k) / sqrt(dim)`, `weights = softmax(scores)`, and
/// `output_row = Σ_k weights[k] * values[k]`.
///
/// # Panics
/// Panics if any slice length is inconsistent with the given dimensions.
pub fn attention_forward(
    queries: &[f32],    // [num_queries, dim]
    keys: &[f32],       // [num_keys, dim]
    values: &[f32],     // [num_keys, value_dim]
    num_queries: usize,
    num_keys: usize,
    dim: usize,
    value_dim: usize,
    output: &mut [f32], // [num_queries, value_dim]
) {
    assert_eq!(queries.len(), num_queries * dim);
    assert_eq!(keys.len(), num_keys * dim);
    assert_eq!(values.len(), num_keys * value_dim);
    assert_eq!(output.len(), num_queries * value_dim);

    let scale = 1.0 / (dim as f32).sqrt();

    // Scratch buffers reused across all queries.
    let mut scores = vec![0.0f32; num_keys];
    let mut attn_weights = vec![0.0f32; num_keys];

    // The value-row slices are identical for every query; build them once
    // instead of reallocating this Vec on every loop iteration.
    let value_refs: Vec<&[f32]> = (0..num_keys)
        .map(|k_idx| {
            let v_start = k_idx * value_dim;
            &values[v_start..v_start + value_dim]
        })
        .collect();

    for q_idx in 0..num_queries {
        let q_start = q_idx * dim;
        let query = &queries[q_start..q_start + dim];

        // Scaled dot-product scores against every key.
        for k_idx in 0..num_keys {
            let k_start = k_idx * dim;
            let key = &keys[k_start..k_start + dim];
            scores[k_idx] = dot_product(query, key) * scale;
        }

        // Normalize scores into attention weights.
        softmax(&scores, &mut attn_weights);

        // Weighted sum of value rows into this query's output row.
        let out_start = q_idx * value_dim;
        weighted_sum(
            &value_refs,
            &attn_weights,
            &mut output[out_start..out_start + value_dim],
        );
    }
}

4.3 Scalar Fallback

// src/simd/fallback/scalar.rs

/// Portable scalar dot product
///
/// Reference implementation used when no SIMD instruction set is available;
/// accumulates left-to-right, pairing elements positionally.
pub fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}

/// Portable scalar weighted sum
///
/// Overwrites `output` with `Σ_k weights[k] * vectors[k]`, pairing vectors
/// with weights positionally.
pub fn weighted_sum_scalar(
    vectors: &[&[f32]],
    weights: &[f32],
    output: &mut [f32],
) {
    // Start from a zeroed accumulator, discarding any stale contents.
    for slot in output.iter_mut() {
        *slot = 0.0;
    }

    // Accumulate each vector scaled by its weight.
    for (values, &w) in vectors.iter().zip(weights) {
        for (slot, &v) in output.iter_mut().zip(values.iter()) {
            *slot += v * w;
        }
    }
}

/// Portable scalar softmax
///
/// Numerically stable: shifts by the maximum before exponentiating.
/// Each exponential is computed once (written into `output` while the sum is
/// accumulated), then normalized in place — the previous version evaluated
/// `exp` twice per element.
pub fn softmax_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    let max_val = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Single pass: store exp(x - max) and accumulate the normalizer.
    let mut sum = 0.0f32;
    for (out, &x) in output.iter_mut().zip(input.iter()) {
        let e = (x - max_val).exp();
        *out = e;
        sum += e;
    }

    // Normalize in place.
    for out in output.iter_mut() {
        *out /= sum;
    }
}

5. Performance Targets & Benchmarks

5.1 Target Metrics

// benches/simd_benchmarks.rs

use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId};
use ruvector::simd;

/// Benchmark the dispatching SIMD dot product against the scalar baseline
/// over a range of vector sizes.
fn bench_dot_product(c: &mut Criterion) {
    let mut group = c.benchmark_group("dot_product");

    for size in [64, 128, 256, 512, 1024, 2048, 4096].iter() {
        let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..*size).map(|i| (i * 2) as f32).collect();

        group.bench_with_input(BenchmarkId::new("simd", size), size, |bencher, _| {
            bencher.iter(|| {
                simd::kernels::dot_product(black_box(&a), black_box(&b))
            });
        });

        group.bench_with_input(BenchmarkId::new("scalar", size), size, |bencher, _| {
            bencher.iter(|| {
                simd::fallback::scalar::dot_product_scalar(black_box(&a), black_box(&b))
            });
        });
    }

    group.finish();
}

/// Benchmark the weighted sum (16 vectors of dimension 512): SIMD dispatch
/// vs. the scalar fallback.
fn bench_weighted_sum(c: &mut Criterion) {
    let mut group = c.benchmark_group("weighted_sum");

    let dim = 512;
    let n_vectors = 16;

    let vectors: Vec<Vec<f32>> = (0..n_vectors)
        .map(|_| (0..dim).map(|i| i as f32).collect())
        .collect();
    let vector_refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
    let weights: Vec<f32> = (0..n_vectors).map(|i| 1.0 / (i + 1) as f32).collect();
    let mut output = vec![0.0f32; dim];

    group.bench_function("simd", |bencher| {
        bencher.iter(|| {
            simd::kernels::weighted_sum(
                black_box(&vector_refs),
                black_box(&weights),
                black_box(&mut output),
            );
        });
    });

    group.bench_function("scalar", |bencher| {
        bencher.iter(|| {
            simd::fallback::scalar::weighted_sum_scalar(
                black_box(&vector_refs),
                black_box(&weights),
                black_box(&mut output),
            );
        });
    });

    group.finish();
}

/// Benchmark softmax across several input sizes: SIMD dispatch vs. scalar.
fn bench_softmax(c: &mut Criterion) {
    let mut group = c.benchmark_group("softmax");

    for size in [64, 128, 256, 512].iter() {
        let input: Vec<f32> = (0..*size).map(|i| (i as f32) * 0.1).collect();
        let mut output = vec![0.0f32; *size];

        group.bench_with_input(BenchmarkId::new("simd", size), size, |bencher, _| {
            bencher.iter(|| {
                simd::kernels::softmax(black_box(&input), black_box(&mut output));
            });
        });

        group.bench_with_input(BenchmarkId::new("scalar", size), size, |bencher, _| {
            bencher.iter(|| {
                simd::fallback::scalar::softmax_scalar(
                    black_box(&input),
                    black_box(&mut output),
                );
            });
        });
    }

    group.finish();
}

/// Benchmark the full attention forward pass (32 queries, 64 keys, dim 128),
/// which composes dot product, softmax, and weighted sum.
fn bench_attention_forward(c: &mut Criterion) {
    let mut group = c.benchmark_group("attention_forward");

    let num_queries = 32;
    let num_keys = 64;
    let dim = 128;
    let value_dim = 128;

    let queries: Vec<f32> = (0..num_queries * dim).map(|i| i as f32 * 0.01).collect();
    let keys: Vec<f32> = (0..num_keys * dim).map(|i| i as f32 * 0.01).collect();
    let values: Vec<f32> = (0..num_keys * value_dim).map(|i| i as f32 * 0.01).collect();
    let mut output = vec![0.0f32; num_queries * value_dim];

    group.bench_function("simd", |bencher| {
        bencher.iter(|| {
            simd::kernels::attention_forward(
                black_box(&queries),
                black_box(&keys),
                black_box(&values),
                black_box(num_queries),
                black_box(num_keys),
                black_box(dim),
                black_box(value_dim),
                black_box(&mut output),
            );
        });
    });

    group.finish();
}

// Register all SIMD benchmark groups and generate the benchmark `main`.
criterion_group!(
    benches,
    bench_dot_product,
    bench_weighted_sum,
    bench_softmax,
    bench_attention_forward
);
criterion_main!(benches);

5.2 Expected Performance

| Operation    | Dimension | Scalar | AVX2    | Speedup | Target  |
|--------------|-----------|--------|---------|---------|---------|
| Dot Product  | 512       | 100 ns | 12.5 ns | 8x      | 4-8x ✓  |
| Dot Product  | 1024      | 200 ns | 25 ns   | 8x      | 4-8x ✓  |
| Weighted Sum | 512x16    | 2.5 µs | 400 ns  | 6.25x   | 4-8x ✓  |
| Softmax      | 256       | 800 ns | 200 ns  | 4x      | 2-4x ✓  |
| Softmax      | 512       | 1.6 µs | 400 ns  | 4x      | 2-4x ✓  |
| Attention    | 32x64x128 | 150 µs | 50 µs   | 3x      | 2-4x ✓  |

5.3 ARM NEON Performance

| Operation    | Dimension | Scalar | NEON   | Speedup |
|--------------|-----------|--------|--------|---------|
| Dot Product  | 512       | 120 ns | 20 ns  | 6x      |
| Weighted Sum | 512x16    | 2.8 µs | 500 ns | 5.6x    |
| Softmax      | 256       | 900 ns | 250 ns | 3.6x    |

5.4 WASM SIMD Performance

| Operation    | Dimension | Scalar | SIMD128 | Speedup |
|--------------|-----------|--------|---------|---------|
| Dot Product  | 512       | 150 ns | 40 ns   | 3.75x   |
| Weighted Sum | 512x16    | 3.2 µs | 800 ns  | 4x      |

6. Testing Strategy

6.1 Correctness Tests

// tests/simd_correctness.rs

use ruvector::simd;
use approx::assert_relative_eq;

/// The dispatched SIMD dot product must agree with a plain scalar reduction.
#[test]
fn test_dot_product_correctness() {
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

    let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let result = simd::kernels::dot_product(&a, &b);

    assert_relative_eq!(result, expected, epsilon = 1e-5);
}

/// The dispatched weighted sum must match a hand-computed linear combination.
#[test]
fn test_weighted_sum_correctness() {
    let v1 = vec![1.0, 2.0, 3.0, 4.0];
    let v2 = vec![5.0, 6.0, 7.0, 8.0];
    let vectors = vec![v1.as_slice(), v2.as_slice()];
    let weights = vec![0.3, 0.7];
    let mut output = vec![0.0; 4];

    simd::kernels::weighted_sum(&vectors, &weights, &mut output);

    // Element-wise: 0.3 * v1[i] + 0.7 * v2[i]
    let expected = vec![
        1.0 * 0.3 + 5.0 * 0.7,
        2.0 * 0.3 + 6.0 * 0.7,
        3.0 * 0.3 + 7.0 * 0.7,
        4.0 * 0.3 + 8.0 * 0.7,
    ];

    for (out, exp) in output.iter().zip(expected.iter()) {
        assert_relative_eq!(out, exp, epsilon = 1e-5);
    }
}

/// Softmax output must sum to 1 and preserve the ordering of a strictly
/// increasing input (softmax is monotonic).
#[test]
fn test_softmax_correctness() {
    let input = vec![1.0, 2.0, 3.0, 4.0];
    let mut output = vec![0.0; 4];

    simd::kernels::softmax(&input, &mut output);

    // Check sum is 1.0
    let sum: f32 = output.iter().sum();
    assert_relative_eq!(sum, 1.0, epsilon = 1e-5);

    // Check monotonicity (valid because the input is strictly increasing)
    for i in 0..output.len() - 1 {
        assert!(output[i] < output[i + 1]);
    }
}

/// With all-ones queries, keys, and values, the attention weights are uniform
/// and every value row is identical, so each output element must be 1.0.
#[test]
fn test_attention_forward_correctness() {
    let num_queries = 2;
    let num_keys = 3;
    let dim = 4;
    let value_dim = 4;

    let queries = vec![1.0; num_queries * dim];
    let keys = vec![1.0; num_keys * dim];
    let values = vec![1.0; num_keys * value_dim];
    let mut output = vec![0.0; num_queries * value_dim];

    simd::kernels::attention_forward(
        &queries,
        &keys,
        &values,
        num_queries,
        num_keys,
        dim,
        value_dim,
        &mut output,
    );

    // With uniform inputs, attention should average values
    for val in output.iter() {
        assert_relative_eq!(val, 1.0, epsilon = 1e-4);
    }
}

7. Integration with GNN System

7.1 GNN Attention Layer with SIMD

// src/gnn/attention.rs

use crate::simd;

/// Multi-head graph attention layer that scores edges with SIMD dot products.
pub struct GNNAttentionLayer {
    dim: usize,        // total feature dimension per node
    num_heads: usize,  // number of attention heads
    head_dim: usize,   // features per head: dim / num_heads
}

impl GNNAttentionLayer {
    /// Create a layer with `num_heads` heads over a `dim`-wide feature space.
    ///
    /// # Panics
    /// Panics if `dim` is not divisible by `num_heads`.
    pub fn new(dim: usize, num_heads: usize) -> Self {
        assert_eq!(dim % num_heads, 0);
        Self {
            dim,
            num_heads,
            head_dim: dim / num_heads,
        }
    }

    /// Forward pass over all heads.
    ///
    /// `node_features` is row-major `[num_nodes, dim]`; the returned buffer
    /// has the same layout. Node indices in `edge_index` must be
    /// `< num_nodes`, otherwise the slice indexing in `forward_head` panics.
    pub fn forward(
        &self,
        node_features: &[f32],  // [num_nodes, dim]
        edge_index: &[(usize, usize)],
        num_nodes: usize,
    ) -> Vec<f32> {
        let mut output = vec![0.0f32; num_nodes * self.dim];

        // Process each head; each head writes only its own dim-slice.
        for head in 0..self.num_heads {
            self.forward_head(
                node_features,
                edge_index,
                num_nodes,
                head,
                &mut output,
            );
        }

        output
    }

    /// Accumulate one head's contribution into `output`.
    ///
    /// NOTE(review): scores are applied directly without softmax
    /// normalization over each destination's incoming edges (acknowledged as
    /// simplified below), so outputs are not convex combinations of
    /// neighbor features.
    fn forward_head(
        &self,
        node_features: &[f32],
        edge_index: &[(usize, usize)],
        num_nodes: usize,
        head: usize,
        output: &mut [f32],
    ) {
        let head_offset = head * self.head_dim;

        for &(src, dst) in edge_index {
            let src_start = src * self.dim + head_offset;
            let dst_start = dst * self.dim + head_offset;

            let src_feat = &node_features[src_start..src_start + self.head_dim];
            let dst_feat = &node_features[dst_start..dst_start + self.head_dim];

            // Use SIMD dot product for attention score
            let score = simd::kernels::dot_product(src_feat, dst_feat);

            // Update output (simplified - real implementation would use softmax)
            for i in 0..self.head_dim {
                output[dst_start + i] += src_feat[i] * score;
            }
        }
    }
}

8. Cargo Configuration

8.1 Cargo.toml

[package]
name = "ruvector-simd"
version = "0.1.0"
edition = "2021"

[features]
default = ["simd"]
simd = []
avx2 = []
neon = []
wasm-simd = []

[dependencies]

[dev-dependencies]
criterion = "0.5"
approx = "0.5"

[[bench]]
name = "simd_benchmarks"
harness = false

[profile.release]
opt-level = 3
lto = "fat"
codegen-units = 1

[profile.bench]
inherits = "release"

[target.'cfg(target_arch = "x86_64")'.dependencies]
# x86-specific dependencies if needed

[target.'cfg(target_arch = "aarch64")'.dependencies]
# ARM-specific dependencies if needed

[target.'cfg(target_arch = "wasm32")'.dependencies]
# WASM-specific dependencies if needed

8.2 Build Configuration

# .cargo/config.toml

# NOTE: `target-cpu=native` tunes the binary to the build machine and is not
# portable to older CPUs; distributable builds should omit it and rely on the
# runtime SIMD dispatch instead.
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]

[target.aarch64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]

[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128"]

9. Documentation

9.1 Usage Examples

// examples/simd_usage.rs

use ruvector::simd;

/// Demonstrates the safe SIMD kernel API: capability report, dot product,
/// weighted sum, and softmax.
fn main() {
    // Print detected SIMD capability
    println!("SIMD Capability: {}", simd::simd_capability_name());

    // Example 1: Dot product
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
    let result = simd::kernels::dot_product(&a, &b);
    println!("Dot product: {}", result);

    // Example 2: Weighted sum (0.2*v1 + 0.3*v2 + 0.5*v3)
    let v1 = vec![1.0; 128];
    let v2 = vec![2.0; 128];
    let v3 = vec![3.0; 128];
    let vectors = vec![v1.as_slice(), v2.as_slice(), v3.as_slice()];
    let weights = vec![0.2, 0.3, 0.5];
    let mut output = vec![0.0; 128];
    simd::kernels::weighted_sum(&vectors, &weights, &mut output);
    println!("Weighted sum first element: {}", output[0]);

    // Example 3: Softmax
    let input = vec![1.0, 2.0, 3.0, 4.0];
    let mut output = vec![0.0; 4];
    simd::kernels::softmax(&input, &mut output);
    println!("Softmax: {:?}", output);
}

10. Next Steps

  1. Implement base module structure (src/simd/mod.rs)
  2. Add AVX2 optimizations for x86_64
  3. Add NEON optimizations for ARM64
  4. Add WASM SIMD support
  5. Implement feature detection
  6. Create safe wrapper API
  7. Write comprehensive tests
  8. Add benchmarks and validate performance targets
  9. Integrate with GNN attention layers
  10. Document usage and optimization guidelines

Performance Validation Checklist

  • AVX2 dot product achieves 4-8x speedup
  • NEON dot product achieves 4-6x speedup
  • Weighted sum achieves 4-8x speedup
  • Softmax achieves 2-4x speedup
  • Attention forward achieves 2-4x speedup
  • All platforms have fallback implementations
  • Runtime feature detection works correctly
  • Safe API prevents undefined behavior
  • Benchmarks run on CI for all platforms
  • Documentation includes performance characteristics

Success Criteria

  • 4-8x speedup for dot product operations
  • 2-4x speedup for attention forward pass
  • Cross-platform support (x86_64, ARM64, WASM)
  • Safe abstractions over unsafe SIMD intrinsics
  • Runtime dispatch based on CPU capabilities
  • Zero-cost abstractions in release builds
  • Comprehensive testing for correctness
  • Production-ready code quality