//! SIMD-Optimized CPU Inference Engine
//!
//! Implements a minimal transformer architecture with native SIMD operations
//! for efficient CPU inference. Uses direct SIMD intrinsics when available.
//!
//! ## Optimized Kernels (v2.0)
//!
//! This module now integrates with `ruvllm_lib::kernels` for optimized operations:
//! - **Flash Attention 2**: Use `flash_attention_neon` for 3-6x speedup
//! - **GEMM/GEMV**: Use `gemm_neon`/`gemv_neon` for optimized matrix ops
//! - **Parallel**: Enable the `parallel` feature for multi-threaded inference
//!
//! ## Example: Using Optimized Kernels
//!
//! ```rust,ignore
//! use ruvllm::kernels::{flash_attention_neon, gemv_neon, gemm_neon};
//! use ruvllm::simd_inference::SimdOps;
//!
//! // Use optimized attention (falls back to local impl on non-aarch64)
//! let output = SimdOps::attention(&query, &key, &value, scale, causal);
//!
//! // Use optimized GEMV (row-major `m x n` matrix)
//! let y = SimdOps::gemv(&matrix, &vector, m, n);
//! ```

use crate::error::{Error, InferenceError, Result};
use crate::types::ModelSize;
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use parking_lot::RwLock;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// Import optimized kernels from ruvllm when available on aarch64
#[cfg(target_arch = "aarch64")]
use ruvllm_lib::kernels::{
    flash_attention_neon as optimized_attention, gemv_neon as optimized_gemv,
    rms_norm_neon as optimized_rms_norm, AttentionConfig as OptimizedAttentionConfig,
};

#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
use ruvllm_lib::kernels::{
    gemv_parallel as optimized_gemv_parallel, multi_query_attention_parallel,
};

/// SIMD-optimized matrix operations
pub struct SimdOps;

impl SimdOps {
    // =========================================================================
    // Optimized operations using ruvllm kernels (v2.0)
    // =========================================================================

    /// Flash Attention 2 using optimized NEON kernels (aarch64) or fallback (x86_64)
    ///
    /// This method uses the highly optimized Flash Attention 2 implementation from
    /// `ruvllm_lib::kernels` on Apple Silicon, with automatic fallback
    /// to the local implementation on other architectures.
    ///
    /// # Performance
    /// - aarch64 (M4 Pro): 3-6x speedup with online softmax rescaling
    /// - x86_64 (AVX2): Uses local AVX2 implementation
    #[inline]
    pub fn attention(query: &[f32], key: &[f32], value: &[f32], scale: f32, causal: bool) -> Vec<f32> {
        #[cfg(target_arch = "aarch64")]
        {
            // Use optimized Flash Attention 2 from ruvllm
            optimized_attention(query, key, value, scale, causal)
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            // Fallback to local implementation
            Self::attention_fallback(query, key, value, scale, causal)
        }
    }
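
    // Illustrative sketch (added example, not part of the original API): shows
    // the flattened, row-major layout `attention` expects. `query` holds one
    // head of `head_dim` floats; `key`/`value` hold `kv_len * head_dim` floats.
    // The expected-output comment assumes the optimized kernel matches the
    // local fallback's single-query semantics.
    #[cfg(test)]
    #[allow(dead_code)]
    fn attention_layout_example() -> Vec<f32> {
        let head_dim = 4;
        let kv_len = 3;
        let query = vec![1.0f32; head_dim];
        // Three cached key/value rows, flattened row-major.
        let key = vec![0.5f32; kv_len * head_dim];
        let value = vec![2.0f32; kv_len * head_dim];
        let scale = 1.0 / (head_dim as f32).sqrt();
        // Identical keys give uniform attention weights, so the output should
        // equal the (identical) value rows: [2.0; head_dim].
        Self::attention(&query, &key, &value, scale, false)
    }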

    /// GEMV using optimized NEON kernels with automatic parallel dispatch
    ///
    /// Uses the 12-row micro-kernel from `ruvllm_lib` on aarch64.
    /// Automatically dispatches to the parallel version when the `parallel` feature is enabled.
    ///
    /// # Performance
    /// - Single-threaded: ~8 GFLOPS on M4 Pro
    /// - Multi-threaded: ~15 GFLOPS on M4 Pro (`parallel` feature)
    #[inline]
    pub fn gemv(matrix: &[f32], vector: &[f32], m: usize, n: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; m];
        #[cfg(target_arch = "aarch64")]
        {
            optimized_gemv(matrix, vector, &mut result, m, n);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            // Fallback: use matmul_vec
            let mat = Array2::from_shape_vec((m, n), matrix.to_vec()).unwrap();
            let vec = Array1::from_vec(vector.to_vec());
            result = Self::matmul_vec(&mat, &vec).to_vec();
        }
        result
    }

    /// GEMV with explicit parallel dispatch (requires the `parallel` feature)
    #[cfg(feature = "parallel")]
    #[inline]
    pub fn gemv_parallel(matrix: &[f32], vector: &[f32], m: usize, n: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; m];
        #[cfg(target_arch = "aarch64")]
        unsafe {
            optimized_gemv_parallel(matrix, vector, &mut result, m, n);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            // Parallel fallback using rayon
            result.par_iter_mut().enumerate().for_each(|(i, out)| {
                *out = (0..n).map(|j| matrix[i * n + j] * vector[j]).sum();
            });
        }
        result
    }

    /// RMSNorm using optimized NEON kernels
    ///
    /// Uses vectorized sum-of-squares and normalization from `ruvllm_lib`.
    #[inline]
    pub fn rms_norm_optimized(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
        #[cfg(target_arch = "aarch64")]
        {
            let mut result = input.to_vec();
            optimized_rms_norm(&mut result, weight, eps);
            result
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::rms_norm(input, weight, eps)
        }
    }

    // =========================================================================
    // Local implementations (backward compatibility)
    // =========================================================================

    /// SIMD dot product for f32 vectors
    #[inline]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                return unsafe { Self::dot_product_avx2(a, b) };
            } else if is_x86_feature_detected!("sse4.1") {
                return unsafe { Self::dot_product_sse(a, b) };
            }
        }
        // Fallback scalar implementation
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Attention fallback for non-aarch64 architectures
    ///
    /// Single-query path: computes softmax(q·Kᵀ · scale) · V over the provided
    /// history. The `_causal` flag is unused here; this path scores the entire
    /// key/value history it is given.
    #[allow(dead_code)]
    fn attention_fallback(query: &[f32], key: &[f32], value: &[f32], scale: f32, _causal: bool) -> Vec<f32> {
        let head_dim = query.len();
        let kv_len = key.len() / head_dim;
        if kv_len == 0 {
            return vec![0.0; head_dim];
        }

        // Compute attention scores
        let mut scores = Vec::with_capacity(kv_len);
        for t in 0..kv_len {
            let k_offset = t * head_dim;
            let score: f32 = query
                .iter()
                .zip(&key[k_offset..k_offset + head_dim])
                .map(|(q, k)| q * k * scale)
                .sum();
            scores.push(score);
        }

        // Softmax
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum();
        let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();

        // Weighted sum of values
        let mut output = vec![0.0; head_dim];
        for (t, weight) in attn_weights.iter().enumerate() {
            let v_offset = t * head_dim;
            for (i, v) in value[v_offset..v_offset + head_dim].iter().enumerate() {
                output[i] += weight * v;
            }
        }
        output
    }
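
    // Illustrative sketch (added example): `gemv` computes y = A·x for a
    // row-major `m x n` matrix flattened into a slice of length m*n.
    #[cfg(test)]
    #[allow(dead_code)]
    fn gemv_layout_example() -> Vec<f32> {
        let (m, n) = (2, 3);
        // A = [[1, 2, 3],
        //      [4, 5, 6]] stored row-major.
        let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = vec![1.0, 1.0, 1.0];
        // Expected result: [6.0, 15.0] (the row sums).
        Self::gemv(&matrix, &x, m, n)
    }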

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
        unsafe {
            let mut sum = _mm256_setzero_ps();
            let chunks = a.len() / 8;
            for i in 0..chunks {
                let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
                let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
                sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
            }
            // Horizontal sum
            let high = _mm256_extractf128_ps(sum, 1);
            let low = _mm256_castps256_ps128(sum);
            let sum128 = _mm_add_ps(high, low);
            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
            let mut result = _mm_cvtss_f32(sum32);
            // Handle remainder
            for i in (chunks * 8)..a.len() {
                result += a[i] * b[i];
            }
            result
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.1")]
    unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
        unsafe {
            let mut sum = _mm_setzero_ps();
            let chunks = a.len() / 4;
            for i in 0..chunks {
                let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
                let b_vec = _mm_loadu_ps(b.as_ptr().add(i * 4));
                sum = _mm_add_ps(sum, _mm_mul_ps(a_vec, b_vec));
            }
            // Horizontal sum
            let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
            let sums = _mm_add_ps(sum, shuf);
            let shuf = _mm_movehl_ps(sums, sums);
            let sums = _mm_add_ss(sums, shuf);
            let mut result = _mm_cvtss_f32(sums);
            // Handle remainder
            for i in (chunks * 4)..a.len() {
                result += a[i] * b[i];
            }
            result
        }
    }

    /// SIMD matrix-vector multiplication
    #[inline]
    pub fn matmul_vec(matrix: &Array2<f32>, vec: &Array1<f32>) -> Array1<f32> {
        let rows = matrix.nrows();
        let mut result = Array1::zeros(rows);
        result
            .as_slice_mut()
            .unwrap()
            .par_iter_mut()
            .enumerate()
            .for_each(|(i, out)| {
                let row = matrix.row(i);
                *out = Self::dot_product(row.as_slice().unwrap(), vec.as_slice().unwrap());
            });
        result
    }

    /// SIMD-optimized softmax with vectorized max/sum
    #[inline]
    pub fn softmax(input: &mut [f32]) {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                unsafe { Self::softmax_avx2(input) };
                return;
            }
        }
        // Scalar fallback
        let max = input.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for x in input.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        let inv_sum = 1.0 / sum;
        for x in input.iter_mut() {
            *x *= inv_sum;
        }
    }
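
    // Worked example (added, illustrative): the max-subtraction in `softmax`
    // keeps exp() finite for large logits; [1000, 1001] normalizes to the same
    // distribution as [0, 1] (≈ [0.269, 0.731]) instead of overflowing.
    #[cfg(test)]
    #[allow(dead_code)]
    fn softmax_stability_example() -> Vec<f32> {
        let mut logits = vec![1000.0f32, 1001.0];
        Self::softmax(&mut logits);
        logits
    }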

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn softmax_avx2(input: &mut [f32]) {
        let len = input.len();
        let chunks = len / 8;

        // Find max using AVX2
        let mut max_vec = unsafe { _mm256_set1_ps(f32::NEG_INFINITY) };
        for i in 0..chunks {
            unsafe {
                let v = _mm256_loadu_ps(input.as_ptr().add(i * 8));
                max_vec = _mm256_max_ps(max_vec, v);
            }
        }

        // Horizontal max reduction
        let mut max_val = unsafe {
            let high = _mm256_extractf128_ps(max_vec, 1);
            let low = _mm256_castps256_ps128(max_vec);
            let max128 = _mm_max_ps(high, low);
            let max64 = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
            let max32 = _mm_max_ss(max64, _mm_shuffle_ps(max64, max64, 1));
            _mm_cvtss_f32(max32)
        };

        // Handle remainder for max
        for i in (chunks * 8)..len {
            max_val = max_val.max(input[i]);
        }

        let max_broadcast = unsafe { _mm256_set1_ps(max_val) };

        // Subtract max and compute exp (approximate with fast exp)
        let mut sum = 0.0f32;
        for i in 0..chunks {
            unsafe {
                let ptr = input.as_mut_ptr().add(i * 8);
                let v = _mm256_loadu_ps(ptr);
                let shifted = _mm256_sub_ps(v, max_broadcast);
                // Fast exp approximation for AVX2 using polynomial
                let exp_v = Self::fast_exp_avx2(shifted);
                _mm256_storeu_ps(ptr, exp_v);

                // Sum reduction
                let high = _mm256_extractf128_ps(exp_v, 1);
                let low = _mm256_castps256_ps128(exp_v);
                let sum128 = _mm_add_ps(high, low);
                let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
                let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
                sum += _mm_cvtss_f32(sum32);
            }
        }

        // Handle remainder
        for i in (chunks * 8)..len {
            input[i] = (input[i] - max_val).exp();
            sum += input[i];
        }

        // Divide by sum
        let inv_sum = 1.0 / sum;
        let inv_sum_vec = unsafe { _mm256_set1_ps(inv_sum) };
        for i in 0..chunks {
            unsafe {
                let ptr = input.as_mut_ptr().add(i * 8);
                let v = _mm256_loadu_ps(ptr);
                _mm256_storeu_ps(ptr, _mm256_mul_ps(v, inv_sum_vec));
            }
        }
        for i in (chunks * 8)..len {
            input[i] *= inv_sum;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[inline]
    unsafe fn fast_exp_avx2(x: __m256) -> __m256 {
        // Third-order Taylor approximation: exp(x) ≈ 1 + x + x²/2 + x³/6.
        // Only accurate for x near zero; the zero floor below keeps the result
        // non-negative where the cubic would otherwise go negative (x < ≈ -1.6),
        // which would produce negative softmax weights.
        unsafe {
            let one = _mm256_set1_ps(1.0);
            let half = _mm256_set1_ps(0.5);
            let sixth = _mm256_set1_ps(1.0 / 6.0);

            // Clamp to avoid overflow
            let min_val = _mm256_set1_ps(-88.0);
            let max_val = _mm256_set1_ps(88.0);
            let x = _mm256_max_ps(_mm256_min_ps(x, max_val), min_val);

            let x2 = _mm256_mul_ps(x, x);
            let x3 = _mm256_mul_ps(x2, x);

            // 1 + x + x²/2 + x³/6, floored at zero
            let poly =
                _mm256_fmadd_ps(x3, sixth, _mm256_fmadd_ps(x2, half, _mm256_add_ps(one, x)));
            _mm256_max_ps(poly, _mm256_setzero_ps())
        }
    }

    /// SIMD-optimized RMSNorm with AVX2 acceleration
    #[inline]
    pub fn rms_norm(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                return unsafe { Self::rms_norm_avx2(input, weight, eps) };
            }
        }
        // Scalar fallback
        let sum_sq: f32 = input.iter().map(|x| x * x).sum();
        let rms = (sum_sq / input.len() as f32 + eps).sqrt();
        let inv_rms = 1.0 / rms;
        input
            .iter()
            .zip(weight.iter())
            .map(|(x, w)| x * inv_rms * w)
            .collect()
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn rms_norm_avx2(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
        let len = input.len();
        let chunks = len / 8;
        let mut result = vec![0.0f32; len];

        // Compute sum of squares using AVX2
        let mut sum_sq_vec = unsafe { _mm256_setzero_ps() };
        for i in 0..chunks {
            unsafe {
                let v = _mm256_loadu_ps(input.as_ptr().add(i * 8));
                sum_sq_vec = _mm256_fmadd_ps(v, v, sum_sq_vec);
            }
        }

        // Horizontal sum
        let mut sum_sq = unsafe {
            let high = _mm256_extractf128_ps(sum_sq_vec, 1);
            let low = _mm256_castps256_ps128(sum_sq_vec);
            let sum128 = _mm_add_ps(high, low);
            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
            _mm_cvtss_f32(sum32)
        };

        // Handle remainder
        for i in (chunks * 8)..len {
            sum_sq += input[i] * input[i];
        }

        let inv_rms = 1.0 / (sum_sq / len as f32 + eps).sqrt();
        let inv_rms_vec = unsafe { _mm256_set1_ps(inv_rms) };

        // Apply normalization and weight
        for i in 0..chunks {
            unsafe {
                let x = _mm256_loadu_ps(input.as_ptr().add(i * 8));
                let w = _mm256_loadu_ps(weight.as_ptr().add(i * 8));
                let normalized = _mm256_mul_ps(_mm256_mul_ps(x, inv_rms_vec), w);
                _mm256_storeu_ps(result.as_mut_ptr().add(i * 8), normalized);
            }
        }

        // Handle remainder
        for i in (chunks * 8)..len {
            result[i] = input[i] * inv_rms * weight[i];
        }

        result
    }

    /// GELU activation (scalar, tanh approximation)
    #[inline]
    pub fn gelu(x: f32) -> f32 {
        // Approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        let sqrt_2_pi = 0.7978845608028654f32;
        let coef = 0.044715f32;
        let inner = sqrt_2_pi * (x + coef * x * x * x);
        0.5 * x * (1.0 + inner.tanh())
    }

    /// SiLU activation (scalar): x * sigmoid(x)
    #[inline]
    pub fn silu(x: f32) -> f32 {
        x / (1.0 + (-x).exp())
    }
}
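
// Worked example (added, illustrative): RMSNorm with unit weights scales each
// element by 1 / sqrt(mean(x^2) + eps).
#[cfg(test)]
mod rms_norm_worked_example {
    use super::*;

    #[test]
    fn matches_scalar_formula() {
        let input = [3.0f32, 4.0];
        let weight = [1.0f32, 1.0];
        let out = SimdOps::rms_norm(&input, &weight, 0.0);
        // mean(x^2) = (9 + 16) / 2 = 12.5, so rms = sqrt(12.5) ≈ 3.5355.
        let rms = 12.5f32.sqrt();
        assert!((out[0] - 3.0 / rms).abs() < 1e-5);
        assert!((out[1] - 4.0 / rms).abs() < 1e-5);
    }
}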
/// Quantized weight storage (Q4_0 format)
#[derive(Clone)]
pub struct Q4Weights {
    /// Quantized data (two 4-bit values packed per byte)
    data: Vec<u8>,
    /// Scale factors per block
    scales: Vec<f32>,
    /// Block size (typically 32)
    block_size: usize,
    /// Original dimensions
    rows: usize,
    cols: usize,
}

impl Q4Weights {
    /// Create from f32 weights with quantization
    pub fn from_f32(weights: &Array2<f32>, block_size: usize) -> Self {
        let rows = weights.nrows();
        let cols = weights.ncols();
        let total = rows * cols;
        let num_blocks = (total + block_size - 1) / block_size;

        let mut data = Vec::with_capacity(total / 2);
        let mut scales = Vec::with_capacity(num_blocks);

        let flat: Vec<f32> = weights.iter().cloned().collect();
        for block in flat.chunks(block_size) {
            // Find max absolute value for scale
            let max_abs = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
            let scale = max_abs / 7.0; // Q4 range is -8 to 7
            scales.push(scale);

            // Quantize
            let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
            for pair in block.chunks(2) {
                let q0 = ((pair[0] * inv_scale).round() as i8).clamp(-8, 7) as u8 & 0x0F;
                let q1 = if pair.len() > 1 {
                    ((pair[1] * inv_scale).round() as i8).clamp(-8, 7) as u8 & 0x0F
                } else {
                    0
                };
                data.push((q1 << 4) | q0);
            }
        }

        Self {
            data,
            scales,
            block_size,
            rows,
            cols,
        }
    }

    /// Dequantize and multiply with vector - optimized with block processing
    pub fn matmul_vec(&self, vec: &[f32]) -> Vec<f32> {
        let mut result = vec![0.0f32; self.rows];
        result.par_iter_mut().enumerate().for_each(|(row, out)| {
            *out = self.matmul_row_optimized(row, vec);
        });
        result
    }

    /// Optimized single row multiplication with block-level dequantization
    ///
    /// Note: quantization blocks are laid out over the flattened weight array,
    /// so the row/block index math below assumes rows align with block
    /// boundaries (`cols` a multiple of `block_size`, or vice versa).
    #[inline]
    fn matmul_row_optimized(&self, row: usize, vec: &[f32]) -> f32 {
        let row_start = row * self.cols;
        let mut sum = 0.0f32;

        // Process by blocks for better cache locality
        let blocks_per_row = (self.cols + self.block_size - 1) / self.block_size;
        let first_block = row_start / self.block_size;

        for block_offset in 0..blocks_per_row {
            let block_idx = first_block + block_offset;
            let scale = self.scales.get(block_idx).copied().unwrap_or(1.0);
            let block_start_in_row = block_offset * self.block_size;
            let block_end_in_row = (block_start_in_row + self.block_size).min(self.cols);

            // Process 8 elements at a time within the block
            let mut col = block_start_in_row;
            while col + 8 <= block_end_in_row {
                let idx = row_start + col;
                let byte_start = idx / 2;

                // Unpack 8 values (4 bytes)
                let mut weights = [0.0f32; 8];
                for i in 0..4 {
                    let byte = self.data.get(byte_start + i).copied().unwrap_or(0);
                    let q0 = (byte & 0x0F) as i8;
                    let q1 = ((byte >> 4) & 0x0F) as i8;
                    // Sign-extend the 4-bit values
                    let q0 = if q0 > 7 { q0 - 16 } else { q0 };
                    let q1 = if q1 > 7 { q1 - 16 } else { q1 };
                    weights[i * 2] = q0 as f32 * scale;
                    weights[i * 2 + 1] = q1 as f32 * scale;
                }

                // SIMD dot product for this block of 8
                sum += SimdOps::dot_product(&weights, &vec[col..col + 8]);
                col += 8;
            }

            // Handle remainder within block
            while col < block_end_in_row {
                let idx = row_start + col;
                let byte_idx = idx / 2;
                let byte = self.data.get(byte_idx).copied().unwrap_or(0);
                let q = if idx % 2 == 0 {
                    (byte & 0x0F) as i8
                } else {
                    ((byte >> 4) & 0x0F) as i8
                };
                let q = if q > 7 { q - 16 } else { q };
                let w = q as f32 * scale;
                sum += w * vec[col];
                col += 1;
            }
        }

        sum
    }
}
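
// Illustrative sketch (added example): Q4_0 packing stores two 4-bit values
// per byte plus one f32 scale per block, so a block of 32 weights costs
// 16 bytes + 4 scale bytes (~5 bits/weight). Round-tripping loses at most
// about half a quantization step (scale/2) per weight.
#[cfg(test)]
mod q4_roundtrip_example {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn dequantized_matvec_is_close() {
        let weights =
            Array2::from_shape_fn((2, 4), |(i, j)| (i as f32 + 1.0) * 0.1 * (j as f32 + 1.0));
        let q4 = Q4Weights::from_f32(&weights, 4);
        let x = vec![1.0f32; 4];
        let y = q4.matmul_vec(&x);
        for (row, out) in y.iter().enumerate() {
            // With x = 1, the exact product is the row sum of the weights.
            let exact: f32 = weights.row(row).sum();
            assert!((out - exact).abs() < 0.2, "row {row}: {out} vs {exact}");
        }
    }
}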
/// Minimal transformer layer
pub struct TransformerLayer {
    /// Query projection
    wq: Q4Weights,
    /// Key projection
    wk: Q4Weights,
    /// Value projection
    wv: Q4Weights,
    /// Output projection
    wo: Q4Weights,
    /// FFN gate
    w1: Q4Weights,
    /// FFN down
    w2: Q4Weights,
    /// FFN up
    w3: Q4Weights,
    /// Attention norm weights
    attn_norm: Vec<f32>,
    /// FFN norm weights
    ffn_norm: Vec<f32>,
    /// Hidden dimension
    hidden_dim: usize,
    /// Number of heads
    num_heads: usize,
    /// Head dimension
    head_dim: usize,
}

impl TransformerLayer {
    pub fn new_random(hidden_dim: usize, num_heads: usize, ffn_dim: usize) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let head_dim = hidden_dim / num_heads;

        // Xavier-style init: uniform in [-scale, scale)
        let mut init_weight = |rows: usize, cols: usize| -> Q4Weights {
            let scale = (2.0 / (rows + cols) as f32).sqrt();
            let weights: Array2<f32> =
                Array2::from_shape_fn((rows, cols), |_| rng.gen::<f32>() * scale * 2.0 - scale);
            Q4Weights::from_f32(&weights, 32)
        };

        Self {
            wq: init_weight(hidden_dim, hidden_dim),
            wk: init_weight(hidden_dim, hidden_dim),
            wv: init_weight(hidden_dim, hidden_dim),
            wo: init_weight(hidden_dim, hidden_dim),
            w1: init_weight(ffn_dim, hidden_dim),
            w2: init_weight(hidden_dim, ffn_dim),
            w3: init_weight(ffn_dim, hidden_dim),
            attn_norm: vec![1.0; hidden_dim],
            ffn_norm: vec![1.0; hidden_dim],
            hidden_dim,
            num_heads,
            head_dim,
        }
    }

    pub fn forward(&self, x: &[f32], kv_cache: Option<&mut KvCache>, pos: usize) -> Vec<f32> {
        // RMS Norm
        let normed = SimdOps::rms_norm(x, &self.attn_norm, 1e-6);

        // QKV projections
        let q = self.wq.matmul_vec(&normed);
        let k = self.wk.matmul_vec(&normed);
        let v = self.wv.matmul_vec(&normed);

        // Update KV cache if provided
        let (k, v) = if let Some(cache) = kv_cache {
            cache.append(&k, &v);
            // Cloning the cache each step is O(seq_len) but keeps the demo simple.
            (cache.keys.clone(), cache.values.clone())
        } else {
            (vec![k], vec![v])
        };

        // Multi-head attention
        let mut attn_out = vec![0.0f32; self.hidden_dim];
        let seq_len = k.len();

        for h in 0..self.num_heads {
            let head_start = h * self.head_dim;
            let head_end = head_start + self.head_dim;
            let q_head: Vec<f32> = q[head_start..head_end].to_vec();

            // Compute attention scores
            let mut scores = vec![0.0f32; seq_len];
            for (i, k_vec) in k.iter().enumerate() {
                let k_head: Vec<f32> = k_vec[head_start..head_end].to_vec();
                scores[i] =
                    SimdOps::dot_product(&q_head, &k_head) / (self.head_dim as f32).sqrt();
            }

            // Causal mask (only attend to past)
            for i in (pos + 1)..seq_len {
                scores[i] = f32::NEG_INFINITY;
            }

            // Softmax
            SimdOps::softmax(&mut scores);

            // Weighted sum of values (masked positions have zero weight)
            for (score, v_vec) in scores.iter().zip(v.iter()) {
                if *score > 0.0 {
                    for j in 0..self.head_dim {
                        attn_out[head_start + j] += score * v_vec[head_start + j];
                    }
                }
            }
        }

        // Output projection
        let attn_out = self.wo.matmul_vec(&attn_out);

        // Residual
        let mut hidden: Vec<f32> = x.iter().zip(attn_out.iter()).map(|(a, b)| a + b).collect();

        // FFN
        let normed = SimdOps::rms_norm(&hidden, &self.ffn_norm, 1e-6);
        let gate = self.w1.matmul_vec(&normed);
        let up = self.w3.matmul_vec(&normed);

        // SiLU(gate) * up
        let ffn_hidden: Vec<f32> = gate
            .iter()
            .zip(up.iter())
            .map(|(g, u)| SimdOps::silu(*g) * u)
            .collect();

        let ffn_out = self.w2.matmul_vec(&ffn_hidden);

        // Residual
        for (h, f) in hidden.iter_mut().zip(ffn_out.iter()) {
            *h += f;
        }

        hidden
    }
}

/// KV Cache for efficient generation
#[derive(Default)]
pub struct KvCache {
    pub keys: Vec<Vec<f32>>,
    pub values: Vec<Vec<f32>>,
}

impl KvCache {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn append(&mut self, k: &[f32], v: &[f32]) {
        self.keys.push(k.to_vec());
        self.values.push(v.to_vec());
    }

    pub fn len(&self) -> usize {
        self.keys.len()
    }

    pub fn clear(&mut self) {
        self.keys.clear();
        self.values.clear();
    }
}
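
// Usage sketch (added, illustrative): each decode step appends one K row and
// one V row per layer, so after `t` steps the cache holds `t` entries and
// attention scores are computed over all of them.
#[cfg(test)]
mod kv_cache_example {
    use super::KvCache;

    #[test]
    fn grows_one_entry_per_step() {
        let mut cache = KvCache::new();
        for step in 0..3 {
            let k = vec![step as f32; 8];
            let v = vec![step as f32; 8];
            cache.append(&k, &v);
        }
        assert_eq!(cache.len(), 3);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }
}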
/// Small transformer model for CPU inference
pub struct SmallTransformer {
    /// Embedding table
    embeddings: Array2<f32>,
    /// Transformer layers
    layers: Vec<TransformerLayer>,
    /// Output norm
    output_norm: Vec<f32>,
    /// LM head (output projection)
    lm_head: Q4Weights,
    /// Vocabulary size
    vocab_size: usize,
    /// Hidden dimension
    hidden_dim: usize,
}

impl SmallTransformer {
    /// Create a small model with random weights (for testing/demo)
    pub fn new_random(
        vocab_size: usize,
        hidden_dim: usize,
        num_layers: usize,
        num_heads: usize,
        ffn_dim: usize,
    ) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        // Initialize embeddings
        let scale = (1.0 / hidden_dim as f32).sqrt();
        let embeddings = Array2::from_shape_fn((vocab_size, hidden_dim), |_| {
            rng.gen::<f32>() * scale * 2.0 - scale
        });

        // Initialize layers
        let layers: Vec<TransformerLayer> = (0..num_layers)
            .map(|_| TransformerLayer::new_random(hidden_dim, num_heads, ffn_dim))
            .collect();

        // Output norm
        let output_norm = vec![1.0; hidden_dim];

        // LM head
        let lm_head_weights = Array2::from_shape_fn((vocab_size, hidden_dim), |_| {
            rng.gen::<f32>() * scale * 2.0 - scale
        });
        let lm_head = Q4Weights::from_f32(&lm_head_weights, 32);

        Self {
            embeddings,
            layers,
            output_norm,
            lm_head,
            vocab_size,
            hidden_dim,
        }
    }

    /// Forward pass for a single token
    pub fn forward(&self, token: u32, kv_caches: &mut [KvCache], pos: usize) -> Vec<f32> {
        // Get embedding
        let mut hidden: Vec<f32> = self.embeddings.row(token as usize).to_vec();

        // Run through layers
        for (layer, cache) in self.layers.iter().zip(kv_caches.iter_mut()) {
            hidden = layer.forward(&hidden, Some(cache), pos);
        }

        // Output norm
        let normed = SimdOps::rms_norm(&hidden, &self.output_norm, 1e-6);

        // LM head to get logits
        self.lm_head.matmul_vec(&normed)
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

/// Simple tokenizer (BPE-style for demo)
pub struct SimpleTokenizer {
    vocab: HashMap<String, u32>,
    id_to_token: HashMap<u32, String>,
    unk_token: u32,
    bos_token: u32,
    eos_token: u32,
}

impl SimpleTokenizer {
    pub fn new_basic(vocab_size: usize) -> Self {
        let mut vocab = HashMap::new();
        let mut id_to_token = HashMap::new();

        // Special tokens
        vocab.insert("<unk>".to_string(), 0);
        vocab.insert("<s>".to_string(), 1);
        vocab.insert("</s>".to_string(), 2);
        vocab.insert("<pad>".to_string(), 3);
        id_to_token.insert(0, "<unk>".to_string());
        id_to_token.insert(1, "<s>".to_string());
        id_to_token.insert(2, "</s>".to_string());
        id_to_token.insert(3, "<pad>".to_string());

        // Basic ASCII characters and common tokens
        let mut id = 4u32;
        for c in ' '..='~' {
            if id as usize >= vocab_size {
                break;
            }
            let s = c.to_string();
            vocab.insert(s.clone(), id);
            id_to_token.insert(id, s);
            id += 1;
        }

        // Common word pieces
        let common_tokens = [
            "the", "and", "is", "of", "to", "in", "that", "it", "for", "was", "on", "are", "as",
            "with", "be", "at", "by", "this", "have", "from", "or", "had", "not", "but", "what",
            "all", "were", "we", "when", "your", "can", "said", "there", "use", "an", "each",
            "which", "she", "do", "how", "their", "if", "will", "up", "other", "about", "out",
            "many", "then", "them", "##ing", "##ed", "##s", "##er", "##ly", "##tion", "##al",
            "##ness",
        ];
        for token in common_tokens.iter() {
            if id as usize >= vocab_size {
                break;
            }
            if !vocab.contains_key(*token) {
                vocab.insert(token.to_string(), id);
                id_to_token.insert(id, token.to_string());
                id += 1;
            }
        }

        Self {
            vocab,
            id_to_token,
            unk_token: 0,
            bos_token: 1,
            eos_token: 2,
        }
    }

    pub fn encode(&self, text: &str) -> Vec<u32> {
        let mut tokens = vec![self.bos_token];

        // Character-level tokenization (the word pieces above are kept in the
        // vocab but are not matched by this simple per-character encoder)
        for c in text.chars() {
            let s = c.to_string();
            let id = self.vocab.get(&s).copied().unwrap_or(self.unk_token);
            tokens.push(id);
        }

        tokens
    }

    pub fn decode(&self, tokens: &[u32]) -> String {
        tokens
            .iter()
            .filter_map(|&id| self.id_to_token.get(&id))
            .filter(|s| !s.starts_with('<') || !s.ends_with('>'))
            .cloned()
            .collect()
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    pub fn eos_token(&self) -> u32 {
        self.eos_token
    }
}
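
// Round-trip sketch (added, illustrative): the demo tokenizer is character
// level for printable ASCII, so decode(encode(s)) reproduces `s` once the
// angle-bracketed special tokens are filtered back out.
#[cfg(test)]
mod tokenizer_roundtrip_example {
    use super::SimpleTokenizer;

    #[test]
    fn ascii_round_trip() {
        let tok = SimpleTokenizer::new_basic(256);
        let ids = tok.encode("hi there");
        assert_eq!(ids[0], 1); // leading BOS token
        assert_eq!(tok.decode(&ids), "hi there");
    }
}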
/// Generation configuration
#[derive(Debug, Clone)]
pub struct SimdGenerationConfig {
    pub max_tokens: usize,
    pub temperature: f32,
    pub top_p: f32,
    pub top_k: usize,
    pub repeat_penalty: f32,
}

impl Default for SimdGenerationConfig {
    fn default() -> Self {
        Self {
            max_tokens: 128,
            temperature: 0.8,
            top_p: 0.9,
            top_k: 40,
            repeat_penalty: 1.1,
        }
    }
}

/// SIMD-optimized inference engine
pub struct SimdInferenceEngine {
    model: SmallTransformer,
    tokenizer: SimpleTokenizer,
    kv_caches: RwLock<HashMap<String, Vec<KvCache>>>,
    /// Whether this is a demo model with random weights (not a real trained model)
    is_demo_model: bool,
}

impl SimdInferenceEngine {
    /// Create engine with a small random model (for demo/testing)
    ///
    /// WARNING: This creates a model with RANDOM weights for demonstration purposes.
    /// It will produce a placeholder response, not actual LLM inference.
    /// For real inference, load a trained model using `load_model()`.
    pub fn new_demo() -> Self {
        let vocab_size = 256;
        let hidden_dim = 256;
        let num_layers = 4;
        let num_heads = 4;
        let ffn_dim = 512;

        let model =
            SmallTransformer::new_random(vocab_size, hidden_dim, num_layers, num_heads, ffn_dim);
        let tokenizer = SimpleTokenizer::new_basic(vocab_size);

        Self {
            model,
            tokenizer,
            kv_caches: RwLock::new(HashMap::new()),
            is_demo_model: true,
        }
    }

    /// Check if this is a demo model (random weights, not trained)
    pub fn is_demo(&self) -> bool {
        self.is_demo_model
    }

    /// Sample next token
    fn sample(&self, logits: &[f32], config: &SimdGenerationConfig, history: &[u32]) -> u32 {
        let mut probs = logits.to_vec();

        // Apply repeat penalty (sign-aware: dividing a negative logit by a
        // penalty > 1 would *raise* its probability, so multiply instead)
        for &token in history {
            if (token as usize) < probs.len() {
                let p = &mut probs[token as usize];
                if *p > 0.0 {
                    *p /= config.repeat_penalty;
                } else {
                    *p *= config.repeat_penalty;
                }
            }
        }

        // Temperature
        if config.temperature > 0.0 {
            for p in &mut probs {
                *p /= config.temperature;
            }
        }

        // Softmax
        SimdOps::softmax(&mut probs);

        // Sort token indices by descending probability for top-k filtering
        let mut indices: Vec<usize> = (0..probs.len()).collect();
        indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());

        // Top-p (nucleus) sampling
        let mut cumsum = 0.0;
        let mut cutoff = indices.len();
        for (i, &idx) in indices.iter().enumerate() {
            cumsum += probs[idx];
            if cumsum > config.top_p {
                cutoff = (i + 1).min(config.top_k);
                break;
            }
        }
        cutoff = cutoff.min(config.top_k);

        // Renormalize
        let valid_indices = &indices[..cutoff];
        let sum: f32 = valid_indices.iter().map(|&i| probs[i]).sum();

        // Sample
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let r: f32 = rng.gen();
        let mut cumsum = 0.0;
        for &idx in valid_indices {
            cumsum += probs[idx] / sum;
            if r < cumsum {
                return idx as u32;
            }
        }
        valid_indices[0] as u32
    }
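
    // Illustrative sketch (added example): with `top_k = 1` this sampler is
    // effectively greedy, since the candidate set is truncated to the single
    // highest-probability token before drawing.
    #[cfg(test)]
    #[allow(dead_code)]
    fn greedy_sample_example(&self, logits: &[f32]) -> u32 {
        let config = SimdGenerationConfig {
            top_k: 1,
            ..Default::default()
        };
        self.sample(logits, &config, &[])
    }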

    /// Generate text
    ///
    /// If this is a demo model (random weights), returns a placeholder response
    /// explaining that no trained model is loaded.
    pub fn generate(
        &self,
        prompt: &str,
        config: &SimdGenerationConfig,
        session_id: Option<&str>,
    ) -> (String, usize, f64) {
        let start = std::time::Instant::now();

        // Demo model returns a helpful message instead of garbled output
        if self.is_demo_model {
            let elapsed = start.elapsed().as_secs_f64() * 1000.0;
            let response = format!(
                "[RuvLLM Demo Mode]\n\
                 No trained model is currently loaded. This is a demonstration engine.\n\n\
                 Your prompt: \"{}\"\n\n\
                 To get actual LLM inference:\n\
                 1. Load a GGUF model file\n\
                 2. Or connect to an external LLM API\n\
                 3. Or use RuvLLM with a trained checkpoint\n\n\
                 The SIMD inference pipeline is operational with {} layers.\n\
                 Config: temp={:.2}, top_p={:.2}, max_tokens={}",
                prompt.chars().take(100).collect::<String>(),
                self.model.num_layers(),
                config.temperature,
                config.top_p,
                config.max_tokens,
            );
            return (response, 0, elapsed);
        }

        // Tokenize
        let input_tokens = self.tokenizer.encode(prompt);

        // Get or create KV cache
        let session = session_id
            .map(|s| s.to_string())
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        let mut caches_guard = self.kv_caches.write();
        let kv_caches = caches_guard.entry(session).or_insert_with(|| {
            (0..self.model.num_layers())
                .map(|_| KvCache::new())
                .collect()
        });

        // Process input tokens, keeping the logits of the final prompt token so
        // the first sampled token does not require re-feeding (and double-caching)
        // the last prompt token.
        let mut all_tokens = input_tokens.clone();
        let start_pos = kv_caches[0].len();
        let mut logits = Vec::new();
        for (i, &token) in input_tokens.iter().enumerate() {
            logits = self.model.forward(token, kv_caches, start_pos + i);
        }

        // Generate
        let mut generated = Vec::new();
        let eos = self.tokenizer.eos_token();

        for i in 0..config.max_tokens {
            let next_token = self.sample(&logits, config, &all_tokens);

            if next_token == eos {
                break;
            }

            generated.push(next_token);
            all_tokens.push(next_token);

            let pos = start_pos + input_tokens.len() + i;
            logits = self.model.forward(next_token, kv_caches, pos);
        }

        let output = self.tokenizer.decode(&generated);
        let elapsed = start.elapsed().as_secs_f64() * 1000.0;

        (output, generated.len(), elapsed)
    }

    /// Get model info
    pub fn model_info(&self) -> (usize, usize) {
        (self.tokenizer.vocab_size(), self.model.num_layers())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = SimdOps::dot_product(&a, &b);
        assert!((result - 36.0).abs() < 1e-5);
    }

    #[test]
    fn test_softmax() {
        let mut values = vec![1.0, 2.0, 3.0];
        SimdOps::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn test_q4_quantization() {
        let weights = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32 * 0.1);
        let q4 = Q4Weights::from_f32(&weights, 8);
        let input = vec![1.0, 0.5, 0.25, 0.125];
        let result = q4.matmul_vec(&input);
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_inference_engine() {
        let engine = SimdInferenceEngine::new_demo();
        let (vocab_size, num_layers) = engine.model_info();
        assert!(vocab_size > 0);
        assert!(num_layers > 0);
    }

    #[test]
    fn test_generation() {
        let engine = SimdInferenceEngine::new_demo();
        let config = SimdGenerationConfig {
            max_tokens: 10,
            ..Default::default()
        };
        let (_output, tokens, time_ms) = engine.generate("Hello", &config, None);
        assert!(tokens <= 10);
        assert!(time_ms > 0.0);
    }
}
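
// Additional check (added, illustrative): with a zero query every score is
// equal, so the local attention fallback reduces to a plain average of the
// value rows.
#[cfg(test)]
mod attention_fallback_example {
    use super::*;

    #[test]
    fn uniform_scores_average_values() {
        let query = vec![0.0f32; 2]; // head_dim = 2; zero query -> equal scores
        let key = vec![1.0f32; 4]; // two key rows
        let value = vec![1.0, 2.0, 3.0, 4.0]; // rows [1, 2] and [3, 4]
        let out = SimdOps::attention_fallback(&query, &key, &value, 1.0, false);
        // Softmax over equal scores is uniform: output = ([1,2] + [3,4]) / 2.
        assert!((out[0] - 2.0).abs() < 1e-5);
        assert!((out[1] - 3.0).abs() < 1e-5);
    }
}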