Files
wifi-densepose/vendor/ruvector/examples/ruvLLM/src/simd_inference.rs

1232 lines
38 KiB
Rust

//! SIMD-Optimized CPU Inference Engine
//!
//! Implements a minimal transformer architecture with native SIMD operations
//! for efficient CPU inference. Uses direct SIMD intrinsics when available.
//!
//! ## Optimized Kernels (v2.0)
//!
//! This module now integrates with `ruvllm_lib::kernels` for optimized operations:
//! - **Flash Attention 2**: Use `flash_attention_neon` for 3-6x speedup
//! - **GEMM/GEMV**: Use `gemm_neon`/`gemv_neon` for optimized matrix ops
//! - **Parallel**: Enable `parallel` feature for multi-threaded inference
//!
//! ## Example: Using Optimized Kernels
//!
//! ```rust,ignore
//! use ruvllm::kernels::{flash_attention_neon, gemv_neon, gemm_neon};
//! use ruvllm::simd_inference::SimdOps;
//!
//! // Use optimized attention (falls back to local impl on non-aarch64)
//! let output = SimdOps::attention(&query, &key, &value, scale, causal);
//!
//! // Use optimized GEMV
//! let y = SimdOps::gemv(&matrix, &vector);
//! ```
use crate::error::{Error, InferenceError, Result};
use crate::types::ModelSize;
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use parking_lot::RwLock;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
// Import optimized kernels from ruvllm when available on aarch64
#[cfg(target_arch = "aarch64")]
use ruvllm_lib::kernels::{
flash_attention_neon as optimized_attention,
gemv_neon as optimized_gemv,
rms_norm_neon as optimized_rms_norm,
AttentionConfig as OptimizedAttentionConfig,
};
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
use ruvllm_lib::kernels::{
gemv_parallel as optimized_gemv_parallel,
multi_query_attention_parallel,
};
/// SIMD-optimized matrix operations
pub struct SimdOps;
impl SimdOps {
// =========================================================================
// Optimized operations using ruvllm kernels (v2.0)
// =========================================================================
/// Flash Attention 2 using optimized NEON kernels (aarch64) or fallback (x86_64)
///
/// This method uses the highly optimized Flash Attention 2 implementation from
/// `ruvllm_lib::kernels` on Apple Silicon, with automatic fallback
/// to the local implementation on other architectures.
///
/// # Performance
/// - aarch64 (M4 Pro): 3-6x speedup with online softmax rescaling
/// - x86_64 (AVX2): Uses local AVX2 implementation
#[inline]
pub fn attention(query: &[f32], key: &[f32], value: &[f32], scale: f32, causal: bool) -> Vec<f32> {
#[cfg(target_arch = "aarch64")]
{
// Use optimized Flash Attention 2 from ruvllm
optimized_attention(query, key, value, scale, causal)
}
#[cfg(not(target_arch = "aarch64"))]
{
// Fallback to local implementation
Self::attention_fallback(query, key, value, scale, causal)
}
}
/// GEMV using optimized NEON kernels with automatic parallel dispatch
///
/// Uses the 12-row micro-kernel from `ruvllm_lib` on aarch64.
/// Automatically dispatches to parallel version when `parallel` feature is enabled.
///
/// # Performance
/// - Single-threaded: ~8 GFLOPS on M4 Pro
/// - Multi-threaded: ~15 GFLOPS on M4 Pro (parallel feature)
#[inline]
pub fn gemv(matrix: &[f32], vector: &[f32], m: usize, n: usize) -> Vec<f32> {
let mut result = vec![0.0f32; m];
#[cfg(target_arch = "aarch64")]
{
optimized_gemv(matrix, vector, &mut result, m, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
// Fallback: use matmul_vec
let mat = Array2::from_shape_vec((m, n), matrix.to_vec()).unwrap();
let vec = Array1::from_vec(vector.to_vec());
result = Self::matmul_vec(&mat, &vec).to_vec();
}
result
}
/// GEMV with explicit parallel dispatch (requires `parallel` feature)
#[cfg(feature = "parallel")]
#[inline]
pub fn gemv_parallel(matrix: &[f32], vector: &[f32], m: usize, n: usize) -> Vec<f32> {
let mut result = vec![0.0f32; m];
#[cfg(target_arch = "aarch64")]
unsafe {
optimized_gemv_parallel(matrix, vector, &mut result, m, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
// Parallel fallback using rayon
result.par_iter_mut().enumerate().for_each(|(i, out)| {
*out = (0..n).map(|j| matrix[i * n + j] * vector[j]).sum();
});
}
result
}
/// RMSNorm using optimized NEON kernels
///
/// Uses vectorized sum-of-squares and normalization from `ruvllm_lib`.
#[inline]
pub fn rms_norm_optimized(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
#[cfg(target_arch = "aarch64")]
{
let mut result = input.to_vec();
optimized_rms_norm(&mut result, weight, eps);
result
}
#[cfg(not(target_arch = "aarch64"))]
{
Self::rms_norm(input, weight, eps)
}
}
// =========================================================================
// Local implementations (backward compatibility)
// =========================================================================
/// SIMD dot product for f32 vectors
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
return unsafe { Self::dot_product_avx2(a, b) };
} else if is_x86_feature_detected!("sse4.1") {
return unsafe { Self::dot_product_sse(a, b) };
}
}
// Fallback scalar implementation
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
/// Attention fallback for non-aarch64 architectures
#[allow(dead_code)]
fn attention_fallback(query: &[f32], key: &[f32], value: &[f32], scale: f32, _causal: bool) -> Vec<f32> {
let head_dim = query.len();
let kv_len = key.len() / head_dim;
if kv_len == 0 {
return vec![0.0; head_dim];
}
// Compute attention scores
let mut scores = Vec::with_capacity(kv_len);
for t in 0..kv_len {
let k_offset = t * head_dim;
let score: f32 = query.iter()
.zip(&key[k_offset..k_offset + head_dim])
.map(|(q, k)| q * k * scale)
.sum();
scores.push(score);
}
// Softmax
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
// Weighted sum of values
let mut output = vec![0.0; head_dim];
for (t, weight) in attn_weights.iter().enumerate() {
let v_offset = t * head_dim;
for (i, v) in value[v_offset..v_offset + head_dim].iter().enumerate() {
output[i] += weight * v;
}
}
output
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
unsafe {
let mut sum = _mm256_setzero_ps();
let chunks = a.len() / 8;
for i in 0..chunks {
let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
}
// Horizontal sum
let high = _mm256_extractf128_ps(sum, 1);
let low = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(high, low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
// Handle remainder
for i in (chunks * 8)..a.len() {
result += a[i] * b[i];
}
result
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
unsafe {
let mut sum = _mm_setzero_ps();
let chunks = a.len() / 4;
for i in 0..chunks {
let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
let b_vec = _mm_loadu_ps(b.as_ptr().add(i * 4));
sum = _mm_add_ps(sum, _mm_mul_ps(a_vec, b_vec));
}
// Horizontal sum
let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
let sums = _mm_add_ps(sum, shuf);
let shuf = _mm_movehl_ps(sums, sums);
let sums = _mm_add_ss(sums, shuf);
let mut result = _mm_cvtss_f32(sums);
// Handle remainder
for i in (chunks * 4)..a.len() {
result += a[i] * b[i];
}
result
}
}
/// SIMD matrix-vector multiplication
#[inline]
pub fn matmul_vec(matrix: &Array2<f32>, vec: &Array1<f32>) -> Array1<f32> {
let rows = matrix.nrows();
let mut result = Array1::zeros(rows);
result
.as_slice_mut()
.unwrap()
.par_iter_mut()
.enumerate()
.for_each(|(i, out)| {
let row = matrix.row(i);
*out = Self::dot_product(row.as_slice().unwrap(), vec.as_slice().unwrap());
});
result
}
/// SIMD-optimized softmax with vectorized max/sum
#[inline]
pub fn softmax(input: &mut [f32]) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { Self::softmax_avx2(input) };
return;
}
}
// Scalar fallback
let max = input.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for x in input.iter_mut() {
*x = (*x - max).exp();
sum += *x;
}
let inv_sum = 1.0 / sum;
for x in input.iter_mut() {
*x *= inv_sum;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn softmax_avx2(input: &mut [f32]) {
let len = input.len();
let chunks = len / 8;
// Find max using AVX2
let mut max_vec = unsafe { _mm256_set1_ps(f32::NEG_INFINITY) };
for i in 0..chunks {
unsafe {
let v = _mm256_loadu_ps(input.as_ptr().add(i * 8));
max_vec = _mm256_max_ps(max_vec, v);
}
}
// Horizontal max reduction
let mut max_val = unsafe {
let high = _mm256_extractf128_ps(max_vec, 1);
let low = _mm256_castps256_ps128(max_vec);
let max128 = _mm_max_ps(high, low);
let max64 = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
let max32 = _mm_max_ss(max64, _mm_shuffle_ps(max64, max64, 1));
_mm_cvtss_f32(max32)
};
// Handle remainder for max
for i in (chunks * 8)..len {
max_val = max_val.max(input[i]);
}
let max_broadcast = unsafe { _mm256_set1_ps(max_val) };
// Subtract max and compute exp (approximate with fast exp)
let mut sum = 0.0f32;
for i in 0..chunks {
unsafe {
let ptr = input.as_mut_ptr().add(i * 8);
let v = _mm256_loadu_ps(ptr);
let shifted = _mm256_sub_ps(v, max_broadcast);
// Fast exp approximation for AVX2 using polynomial
let exp_v = Self::fast_exp_avx2(shifted);
_mm256_storeu_ps(ptr, exp_v);
// Sum reduction
let high = _mm256_extractf128_ps(exp_v, 1);
let low = _mm256_castps256_ps128(exp_v);
let sum128 = _mm_add_ps(high, low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
sum += _mm_cvtss_f32(sum32);
}
}
// Handle remainder
for i in (chunks * 8)..len {
input[i] = (input[i] - max_val).exp();
sum += input[i];
}
// Divide by sum
let inv_sum = 1.0 / sum;
let inv_sum_vec = unsafe { _mm256_set1_ps(inv_sum) };
for i in 0..chunks {
unsafe {
let ptr = input.as_mut_ptr().add(i * 8);
let v = _mm256_loadu_ps(ptr);
_mm256_storeu_ps(ptr, _mm256_mul_ps(v, inv_sum_vec));
}
}
for i in (chunks * 8)..len {
input[i] *= inv_sum;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn fast_exp_avx2(x: __m256) -> __m256 {
// Fast exp approximation: exp(x) ≈ (1 + x/256)^256 simplified
// Using polynomial: exp(x) ≈ 1 + x + x²/2 + x³/6 for small x
unsafe {
let one = _mm256_set1_ps(1.0);
let half = _mm256_set1_ps(0.5);
let sixth = _mm256_set1_ps(1.0 / 6.0);
// Clamp to avoid overflow
let min_val = _mm256_set1_ps(-88.0);
let max_val = _mm256_set1_ps(88.0);
let x = _mm256_max_ps(_mm256_min_ps(x, max_val), min_val);
let x2 = _mm256_mul_ps(x, x);
let x3 = _mm256_mul_ps(x2, x);
// 1 + x + x²/2 + x³/6
_mm256_fmadd_ps(x3, sixth, _mm256_fmadd_ps(x2, half, _mm256_add_ps(one, x)))
}
}
/// SIMD-optimized RMSNorm with AVX2 acceleration
#[inline]
pub fn rms_norm(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
return unsafe { Self::rms_norm_avx2(input, weight, eps) };
}
}
// Scalar fallback
let sum_sq: f32 = input.iter().map(|x| x * x).sum();
let rms = (sum_sq / input.len() as f32 + eps).sqrt();
let inv_rms = 1.0 / rms;
input
.iter()
.zip(weight.iter())
.map(|(x, w)| x * inv_rms * w)
.collect()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn rms_norm_avx2(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
let len = input.len();
let chunks = len / 8;
let mut result = vec![0.0f32; len];
// Compute sum of squares using AVX2
let mut sum_sq_vec = unsafe { _mm256_setzero_ps() };
for i in 0..chunks {
unsafe {
let v = _mm256_loadu_ps(input.as_ptr().add(i * 8));
sum_sq_vec = _mm256_fmadd_ps(v, v, sum_sq_vec);
}
}
// Horizontal sum
let mut sum_sq = unsafe {
let high = _mm256_extractf128_ps(sum_sq_vec, 1);
let low = _mm256_castps256_ps128(sum_sq_vec);
let sum128 = _mm_add_ps(high, low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
_mm_cvtss_f32(sum32)
};
// Handle remainder
for i in (chunks * 8)..len {
sum_sq += input[i] * input[i];
}
let inv_rms = 1.0 / (sum_sq / len as f32 + eps).sqrt();
let inv_rms_vec = unsafe { _mm256_set1_ps(inv_rms) };
// Apply normalization and weight
for i in 0..chunks {
unsafe {
let x = _mm256_loadu_ps(input.as_ptr().add(i * 8));
let w = _mm256_loadu_ps(weight.as_ptr().add(i * 8));
let normalized = _mm256_mul_ps(_mm256_mul_ps(x, inv_rms_vec), w);
_mm256_storeu_ps(result.as_mut_ptr().add(i * 8), normalized);
}
}
// Handle remainder
for i in (chunks * 8)..len {
result[i] = input[i] * inv_rms * weight[i];
}
result
}
/// SIMD-optimized GELU activation
#[inline]
pub fn gelu(x: f32) -> f32 {
// Approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
let sqrt_2_pi = 0.7978845608028654f32;
let coef = 0.044715f32;
let inner = sqrt_2_pi * (x + coef * x * x * x);
0.5 * x * (1.0 + inner.tanh())
}
/// SIMD-optimized SiLU activation
#[inline]
pub fn silu(x: f32) -> f32 {
x / (1.0 + (-x).exp())
}
}
/// Quantized weight storage (Q4_0 format)
#[derive(Clone)]
pub struct Q4Weights {
/// Quantized data (4-bit packed)
data: Vec<u8>,
/// Scale factors per block
scales: Vec<f32>,
/// Block size (typically 32)
block_size: usize,
/// Original dimensions
rows: usize,
cols: usize,
}
impl Q4Weights {
/// Create from f32 weights with quantization
pub fn from_f32(weights: &Array2<f32>, block_size: usize) -> Self {
let rows = weights.nrows();
let cols = weights.ncols();
let total = rows * cols;
let num_blocks = (total + block_size - 1) / block_size;
let mut data = Vec::with_capacity(total / 2);
let mut scales = Vec::with_capacity(num_blocks);
let flat: Vec<f32> = weights.iter().cloned().collect();
for block in flat.chunks(block_size) {
// Find max absolute value for scale
let max_abs = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let scale = max_abs / 7.0; // Q4 range is -8 to 7
scales.push(scale);
// Quantize
let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
for pair in block.chunks(2) {
let q0 = ((pair[0] * inv_scale).round() as i8).clamp(-8, 7) as u8 & 0x0F;
let q1 = if pair.len() > 1 {
((pair[1] * inv_scale).round() as i8).clamp(-8, 7) as u8 & 0x0F
} else {
0
};
data.push((q1 << 4) | q0);
}
}
Self {
data,
scales,
block_size,
rows,
cols,
}
}
/// Dequantize and multiply with vector - optimized with block processing
pub fn matmul_vec(&self, vec: &[f32]) -> Vec<f32> {
let mut result = vec![0.0f32; self.rows];
result.par_iter_mut().enumerate().for_each(|(row, out)| {
*out = self.matmul_row_optimized(row, vec);
});
result
}
/// Optimized single row multiplication with block-level dequantization
#[inline]
fn matmul_row_optimized(&self, row: usize, vec: &[f32]) -> f32 {
let row_start = row * self.cols;
let mut sum = 0.0f32;
// Process by blocks for better cache locality
let blocks_per_row = (self.cols + self.block_size - 1) / self.block_size;
let first_block = row_start / self.block_size;
for block_offset in 0..blocks_per_row {
let block_idx = first_block + block_offset;
let scale = self.scales.get(block_idx).copied().unwrap_or(1.0);
let block_start_in_row = block_offset * self.block_size;
let block_end_in_row = (block_start_in_row + self.block_size).min(self.cols);
// Process 8 elements at a time within the block
let mut col = block_start_in_row;
while col + 8 <= block_end_in_row {
let idx = row_start + col;
let byte_start = idx / 2;
// Unpack 8 values (4 bytes)
let mut weights = [0.0f32; 8];
for i in 0..4 {
let byte = self.data.get(byte_start + i).copied().unwrap_or(0);
let q0 = (byte & 0x0F) as i8;
let q1 = ((byte >> 4) & 0x0F) as i8;
let q0 = if q0 > 7 { q0 - 16 } else { q0 };
let q1 = if q1 > 7 { q1 - 16 } else { q1 };
weights[i * 2] = q0 as f32 * scale;
weights[i * 2 + 1] = q1 as f32 * scale;
}
// SIMD dot product for this block of 8
sum += SimdOps::dot_product(&weights, &vec[col..col + 8]);
col += 8;
}
// Handle remainder within block
while col < block_end_in_row {
let idx = row_start + col;
let byte_idx = idx / 2;
let byte = self.data.get(byte_idx).copied().unwrap_or(0);
let q = if idx % 2 == 0 {
(byte & 0x0F) as i8
} else {
((byte >> 4) & 0x0F) as i8
};
let q = if q > 7 { q - 16 } else { q };
let w = q as f32 * scale;
sum += w * vec[col];
col += 1;
}
}
sum
}
}
/// Minimal transformer layer
pub struct TransformerLayer {
/// Query projection
wq: Q4Weights,
/// Key projection
wk: Q4Weights,
/// Value projection
wv: Q4Weights,
/// Output projection
wo: Q4Weights,
/// FFN gate
w1: Q4Weights,
/// FFN down
w2: Q4Weights,
/// FFN up
w3: Q4Weights,
/// Attention norm weights
attn_norm: Vec<f32>,
/// FFN norm weights
ffn_norm: Vec<f32>,
/// Hidden dimension
hidden_dim: usize,
/// Number of heads
num_heads: usize,
/// Head dimension
head_dim: usize,
}
impl TransformerLayer {
pub fn new_random(hidden_dim: usize, num_heads: usize, ffn_dim: usize) -> Self {
use rand::Rng;
let mut rng = rand::thread_rng();
let head_dim = hidden_dim / num_heads;
let mut init_weight = |rows: usize, cols: usize| -> Q4Weights {
let scale = (2.0 / (rows + cols) as f32).sqrt();
let weights: Array2<f32> =
Array2::from_shape_fn((rows, cols), |_| rng.gen::<f32>() * scale * 2.0 - scale);
Q4Weights::from_f32(&weights, 32)
};
Self {
wq: init_weight(hidden_dim, hidden_dim),
wk: init_weight(hidden_dim, hidden_dim),
wv: init_weight(hidden_dim, hidden_dim),
wo: init_weight(hidden_dim, hidden_dim),
w1: init_weight(ffn_dim, hidden_dim),
w2: init_weight(hidden_dim, ffn_dim),
w3: init_weight(ffn_dim, hidden_dim),
attn_norm: vec![1.0; hidden_dim],
ffn_norm: vec![1.0; hidden_dim],
hidden_dim,
num_heads,
head_dim,
}
}
pub fn forward(&self, x: &[f32], kv_cache: Option<&mut KvCache>, pos: usize) -> Vec<f32> {
// RMS Norm
let normed = SimdOps::rms_norm(x, &self.attn_norm, 1e-6);
// QKV projections
let q = self.wq.matmul_vec(&normed);
let k = self.wk.matmul_vec(&normed);
let v = self.wv.matmul_vec(&normed);
// Update KV cache if provided
let (k, v) = if let Some(cache) = kv_cache {
cache.append(&k, &v);
(cache.keys.clone(), cache.values.clone())
} else {
(vec![k], vec![v])
};
// Multi-head attention
let mut attn_out = vec![0.0f32; self.hidden_dim];
let seq_len = k.len();
for h in 0..self.num_heads {
let head_start = h * self.head_dim;
let head_end = head_start + self.head_dim;
let q_head: Vec<f32> = q[head_start..head_end].to_vec();
// Compute attention scores
let mut scores = vec![0.0f32; seq_len];
for (i, k_vec) in k.iter().enumerate() {
let k_head: Vec<f32> = k_vec[head_start..head_end].to_vec();
scores[i] = SimdOps::dot_product(&q_head, &k_head) / (self.head_dim as f32).sqrt();
}
// Causal mask (only attend to past)
for i in (pos + 1)..seq_len {
scores[i] = f32::NEG_INFINITY;
}
// Softmax
SimdOps::softmax(&mut scores);
// Weighted sum of values
for (i, (score, v_vec)) in scores.iter().zip(v.iter()).enumerate() {
if *score > 0.0 {
for j in 0..self.head_dim {
attn_out[head_start + j] += score * v_vec[head_start + j];
}
}
}
}
// Output projection
let attn_out = self.wo.matmul_vec(&attn_out);
// Residual
let mut hidden: Vec<f32> = x.iter().zip(attn_out.iter()).map(|(a, b)| a + b).collect();
// FFN
let normed = SimdOps::rms_norm(&hidden, &self.ffn_norm, 1e-6);
let gate = self.w1.matmul_vec(&normed);
let up = self.w3.matmul_vec(&normed);
// SiLU(gate) * up
let ffn_hidden: Vec<f32> = gate
.iter()
.zip(up.iter())
.map(|(g, u)| SimdOps::silu(*g) * u)
.collect();
let ffn_out = self.w2.matmul_vec(&ffn_hidden);
// Residual
for (h, f) in hidden.iter_mut().zip(ffn_out.iter()) {
*h += f;
}
hidden
}
}
/// KV Cache for efficient generation
#[derive(Default)]
pub struct KvCache {
pub keys: Vec<Vec<f32>>,
pub values: Vec<Vec<f32>>,
}
impl KvCache {
pub fn new() -> Self {
Self::default()
}
pub fn append(&mut self, k: &[f32], v: &[f32]) {
self.keys.push(k.to_vec());
self.values.push(v.to_vec());
}
pub fn len(&self) -> usize {
self.keys.len()
}
pub fn clear(&mut self) {
self.keys.clear();
self.values.clear();
}
}
/// Small transformer model for CPU inference
pub struct SmallTransformer {
/// Embedding table
embeddings: Array2<f32>,
/// Transformer layers
layers: Vec<TransformerLayer>,
/// Output norm
output_norm: Vec<f32>,
/// LM head (output projection)
lm_head: Q4Weights,
/// Vocabulary size
vocab_size: usize,
/// Hidden dimension
hidden_dim: usize,
}
impl SmallTransformer {
/// Create a small model with random weights (for testing/demo)
pub fn new_random(
vocab_size: usize,
hidden_dim: usize,
num_layers: usize,
num_heads: usize,
ffn_dim: usize,
) -> Self {
use rand::Rng;
let mut rng = rand::thread_rng();
// Initialize embeddings
let scale = (1.0 / hidden_dim as f32).sqrt();
let embeddings = Array2::from_shape_fn((vocab_size, hidden_dim), |_| {
rng.gen::<f32>() * scale * 2.0 - scale
});
// Initialize layers
let layers: Vec<TransformerLayer> = (0..num_layers)
.map(|_| TransformerLayer::new_random(hidden_dim, num_heads, ffn_dim))
.collect();
// Output norm
let output_norm = vec![1.0; hidden_dim];
// LM head
let lm_head_weights = Array2::from_shape_fn((vocab_size, hidden_dim), |_| {
rng.gen::<f32>() * scale * 2.0 - scale
});
let lm_head = Q4Weights::from_f32(&lm_head_weights, 32);
Self {
embeddings,
layers,
output_norm,
lm_head,
vocab_size,
hidden_dim,
}
}
/// Forward pass for a single token
pub fn forward(&self, token: u32, kv_caches: &mut [KvCache], pos: usize) -> Vec<f32> {
// Get embedding
let mut hidden: Vec<f32> = self.embeddings.row(token as usize).to_vec();
// Run through layers
for (layer, cache) in self.layers.iter().zip(kv_caches.iter_mut()) {
hidden = layer.forward(&hidden, Some(cache), pos);
}
// Output norm
let normed = SimdOps::rms_norm(&hidden, &self.output_norm, 1e-6);
// LM head to get logits
self.lm_head.matmul_vec(&normed)
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
}
/// Simple tokenizer (BPE-style for demo)
pub struct SimpleTokenizer {
vocab: HashMap<String, u32>,
id_to_token: HashMap<u32, String>,
unk_token: u32,
bos_token: u32,
eos_token: u32,
}
impl SimpleTokenizer {
pub fn new_basic(vocab_size: usize) -> Self {
let mut vocab = HashMap::new();
let mut id_to_token = HashMap::new();
// Special tokens
vocab.insert("<unk>".to_string(), 0);
vocab.insert("<s>".to_string(), 1);
vocab.insert("</s>".to_string(), 2);
vocab.insert("<pad>".to_string(), 3);
id_to_token.insert(0, "<unk>".to_string());
id_to_token.insert(1, "<s>".to_string());
id_to_token.insert(2, "</s>".to_string());
id_to_token.insert(3, "<pad>".to_string());
// Basic ASCII characters and common tokens
let mut id = 4u32;
for c in ' '..='~' {
if id as usize >= vocab_size {
break;
}
let s = c.to_string();
vocab.insert(s.clone(), id);
id_to_token.insert(id, s);
id += 1;
}
// Common word pieces
let common_tokens = [
"the", "and", "is", "of", "to", "in", "that", "it", "for", "was", "on", "are", "as",
"with", "be", "at", "by", "this", "have", "from", "or", "had", "not", "but", "what",
"all", "were", "we", "when", "your", "can", "said", "there", "use", "an", "each",
"which", "she", "do", "how", "their", "if", "will", "up", "other", "about", "out",
"many", "then", "them", "##ing", "##ed", "##s", "##er", "##ly", "##tion", "##al",
"##ness",
];
for token in common_tokens.iter() {
if id as usize >= vocab_size {
break;
}
if !vocab.contains_key(*token) {
vocab.insert(token.to_string(), id);
id_to_token.insert(id, token.to_string());
id += 1;
}
}
Self {
vocab,
id_to_token,
unk_token: 0,
bos_token: 1,
eos_token: 2,
}
}
pub fn encode(&self, text: &str) -> Vec<u32> {
let mut tokens = vec![self.bos_token];
// Simple character-level tokenization with word piece fallback
for c in text.chars() {
let s = c.to_string();
let id = self.vocab.get(&s).copied().unwrap_or(self.unk_token);
tokens.push(id);
}
tokens
}
pub fn decode(&self, tokens: &[u32]) -> String {
tokens
.iter()
.filter_map(|&id| self.id_to_token.get(&id))
.filter(|s| !s.starts_with('<') || !s.ends_with('>'))
.cloned()
.collect()
}
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
pub fn eos_token(&self) -> u32 {
self.eos_token
}
}
/// Generation configuration
#[derive(Debug, Clone)]
pub struct SimdGenerationConfig {
pub max_tokens: usize,
pub temperature: f32,
pub top_p: f32,
pub top_k: usize,
pub repeat_penalty: f32,
}
impl Default for SimdGenerationConfig {
fn default() -> Self {
Self {
max_tokens: 128,
temperature: 0.8,
top_p: 0.9,
top_k: 40,
repeat_penalty: 1.1,
}
}
}
/// SIMD-optimized inference engine
pub struct SimdInferenceEngine {
model: SmallTransformer,
tokenizer: SimpleTokenizer,
kv_caches: RwLock<HashMap<String, Vec<KvCache>>>,
/// Whether this is a demo model with random weights (not a real trained model)
is_demo_model: bool,
}
impl SimdInferenceEngine {
/// Create engine with a small random model (for demo/testing)
///
/// WARNING: This creates a model with RANDOM weights for demonstration purposes.
/// It will produce a placeholder response, not actual LLM inference.
/// For real inference, load a trained model using `load_model()`.
pub fn new_demo() -> Self {
let vocab_size = 256;
let hidden_dim = 256;
let num_layers = 4;
let num_heads = 4;
let ffn_dim = 512;
let model =
SmallTransformer::new_random(vocab_size, hidden_dim, num_layers, num_heads, ffn_dim);
let tokenizer = SimpleTokenizer::new_basic(vocab_size);
Self {
model,
tokenizer,
kv_caches: RwLock::new(HashMap::new()),
is_demo_model: true,
}
}
/// Check if this is a demo model (random weights, not trained)
pub fn is_demo(&self) -> bool {
self.is_demo_model
}
/// Sample next token
fn sample(&self, logits: &[f32], config: &SimdGenerationConfig, history: &[u32]) -> u32 {
let mut probs = logits.to_vec();
// Apply repeat penalty
for &token in history {
if (token as usize) < probs.len() {
probs[token as usize] /= config.repeat_penalty;
}
}
// Temperature
if config.temperature > 0.0 {
for p in &mut probs {
*p /= config.temperature;
}
}
// Softmax
SimdOps::softmax(&mut probs);
// Top-k filtering
let mut indices: Vec<usize> = (0..probs.len()).collect();
indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
// Top-p (nucleus) sampling
let mut cumsum = 0.0;
let mut cutoff = indices.len();
for (i, &idx) in indices.iter().enumerate() {
cumsum += probs[idx];
if cumsum > config.top_p {
cutoff = (i + 1).min(config.top_k);
break;
}
}
cutoff = cutoff.min(config.top_k);
// Renormalize
let valid_indices = &indices[..cutoff];
let sum: f32 = valid_indices.iter().map(|&i| probs[i]).sum();
// Sample
use rand::Rng;
let mut rng = rand::thread_rng();
let r: f32 = rng.gen();
let mut cumsum = 0.0;
for &idx in valid_indices {
cumsum += probs[idx] / sum;
if r < cumsum {
return idx as u32;
}
}
valid_indices[0] as u32
}
/// Generate text
///
/// If this is a demo model (random weights), returns a placeholder response
/// explaining that no trained model is loaded.
pub fn generate(
&self,
prompt: &str,
config: &SimdGenerationConfig,
session_id: Option<&str>,
) -> (String, usize, f64) {
let start = std::time::Instant::now();
// Demo model returns a helpful message instead of garbled output
if self.is_demo_model {
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
let response = format!(
"[RuvLLM Demo Mode]\n\
No trained model is currently loaded. This is a demonstration engine.\n\n\
Your prompt: \"{}\"\n\n\
To get actual LLM inference:\n\
1. Load a GGUF model file\n\
2. Or connect to an external LLM API\n\
3. Or use RuvLLM with a trained checkpoint\n\n\
The SIMD inference pipeline is operational with {} layers.\n\
Config: temp={:.2}, top_p={:.2}, max_tokens={}",
prompt.chars().take(100).collect::<String>(),
self.model.num_layers(),
config.temperature,
config.top_p,
config.max_tokens,
);
return (response, 0, elapsed);
}
// Tokenize
let input_tokens = self.tokenizer.encode(prompt);
// Get or create KV cache
let session = session_id
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let mut caches_guard = self.kv_caches.write();
let kv_caches = caches_guard.entry(session).or_insert_with(|| {
(0..self.model.num_layers())
.map(|_| KvCache::new())
.collect()
});
// Process input tokens
let mut all_tokens = input_tokens.clone();
let start_pos = kv_caches[0].len();
for (i, &token) in input_tokens.iter().enumerate() {
let _ = self.model.forward(token, kv_caches, start_pos + i);
}
// Generate
let mut generated = Vec::new();
let eos = self.tokenizer.eos_token();
for i in 0..config.max_tokens {
let pos = start_pos + input_tokens.len() + i;
let last_token = *all_tokens.last().unwrap_or(&0);
let logits = self.model.forward(last_token, kv_caches, pos);
let next_token = self.sample(&logits, config, &all_tokens);
if next_token == eos {
break;
}
generated.push(next_token);
all_tokens.push(next_token);
}
let output = self.tokenizer.decode(&generated);
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
(output, generated.len(), elapsed)
}
/// Get model info
pub fn model_info(&self) -> (usize, usize) {
(self.tokenizer.vocab_size(), self.model.num_layers())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_dot_product() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
let result = SimdOps::dot_product(&a, &b);
assert!((result - 36.0).abs() < 1e-5);
}
#[test]
fn test_softmax() {
let mut values = vec![1.0, 2.0, 3.0];
SimdOps::softmax(&mut values);
let sum: f32 = values.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
assert!(values[2] > values[1]);
assert!(values[1] > values[0]);
}
#[test]
fn test_q4_quantization() {
let weights = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32 * 0.1);
let q4 = Q4Weights::from_f32(&weights, 8);
let input = vec![1.0, 0.5, 0.25, 0.125];
let result = q4.matmul_vec(&input);
assert_eq!(result.len(), 4);
}
#[test]
fn test_inference_engine() {
let engine = SimdInferenceEngine::new_demo();
let (vocab_size, num_layers) = engine.model_info();
assert!(vocab_size > 0);
assert!(num_layers > 0);
}
#[test]
fn test_generation() {
let engine = SimdInferenceEngine::new_demo();
let config = SimdGenerationConfig {
max_tokens: 10,
..Default::default()
};
let (output, tokens, time_ms) = engine.generate("Hello", &config, None);
assert!(tokens <= 10);
assert!(time_ms > 0.0);
}
}