Files
wifi-densepose/vendor/ruvector/crates/ruvllm-wasm/src/webgpu/compute.rs

883 lines
29 KiB
Rust

//! WebGPU Compute Context and Pipelines
//!
//! This module provides the core WebGPU compute functionality for WASM,
//! including context initialization, pipeline creation, and kernel execution.
//!
//! Note: WebGPU bindings use JavaScript interop via js_sys/Reflect since
//! web-sys WebGPU bindings are still unstable.
use js_sys::{Array, Float32Array, Object, Promise, Reflect};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use super::{shaders, AdapterInfo, AttentionConfig};
/// Report whether a usable `navigator.gpu` object can be reached.
///
/// On wasm32 this is true exactly when `get_gpu_object` yields a value; on
/// every other target it is always `false`. The function is `async` for API
/// symmetry but never suspends.
pub async fn is_webgpu_available() -> bool {
    #[cfg(target_arch = "wasm32")]
    {
        // `get_gpu_object` already turns an undefined/null `navigator.gpu`
        // into `None`, so its presence is the whole answer.
        get_gpu_object().is_some()
    }
    #[cfg(not(target_arch = "wasm32"))]
    false
}
/// Get GPU adapter information if available
///
/// Requests a `"high-performance"` adapter from `navigator.gpu` and describes
/// it as an [`AdapterInfo`].
///
/// Returns `None` when WebGPU is unreachable, when the adapter request fails
/// or resolves to null/undefined, when `adapter.limits` cannot be read, and
/// always on non-wasm targets.
///
/// The identity strings are read from the `adapter.info` attribute. Browsers
/// that predate that attribute are served through the legacy
/// `adapter.requestAdapterInfo()` method. If neither source is usable the
/// strings fall back to their defaults rather than discarding the limits.
pub async fn get_gpu_info() -> Option<AdapterInfo> {
    #[cfg(target_arch = "wasm32")]
    {
        let gpu = get_gpu_object()?;
        // Request adapter
        let options = Object::new();
        let _ = Reflect::set(
            &options,
            &"powerPreference".into(),
            &"high-performance".into(),
        );
        let adapter_promise = call_method(&gpu, "requestAdapter", &[options.into()]).ok()?;
        let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>().ok()?)
            .await
            .ok()?;
        if adapter.is_null() || adapter.is_undefined() {
            return None;
        }
        // Prefer the synchronous `info` attribute; `requestAdapterInfo()` has
        // been dropped by current browsers and is only kept as a fallback.
        let mut info =
            Reflect::get(&adapter, &"info".into()).unwrap_or_else(|_| JsValue::undefined());
        if info.is_undefined() || info.is_null() {
            info = match call_method(&adapter, "requestAdapterInfo", &[]) {
                Ok(pending) => match pending.dyn_into::<Promise>() {
                    Ok(promise) => JsFuture::from(promise)
                        .await
                        .unwrap_or_else(|_| JsValue::undefined()),
                    Err(_) => JsValue::undefined(),
                },
                Err(_) => JsValue::undefined(),
            };
        }
        if info.is_undefined() || info.is_null() {
            // No identity data at all: an empty object makes every lookup
            // below resolve to its default value.
            info = Object::new().into();
        }
        // Extract limits
        let limits = Reflect::get(&adapter, &"limits".into()).ok()?;
        Some(AdapterInfo {
            vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
            architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
            device_type: get_string_prop(&info, "device").unwrap_or_else(|| "unknown".to_string()),
            backend: "WebGPU".to_string(),
            max_buffer_size: get_number_prop(&limits, "maxBufferSize")
                .unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
            max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
                .unwrap_or(256.0) as u32,
        })
    }
    #[cfg(not(target_arch = "wasm32"))]
    None
}
// ============================================================================
// Helper Functions
// ============================================================================
/// Locate the `navigator.gpu` entry point of the current JavaScript realm.
///
/// The lookup starts from the global object instead of `window`, so it also
/// succeeds inside Web Workers, where no `window` exists but
/// `navigator.gpu` may still be exposed.
///
/// Returns `None` when `navigator` or `navigator.gpu` is missing, null or
/// undefined, or when reading either property throws.
#[cfg(target_arch = "wasm32")]
fn get_gpu_object() -> Option<JsValue> {
    let global = js_sys::global();
    let navigator = Reflect::get(&global, &"navigator".into()).ok()?;
    if navigator.is_undefined() || navigator.is_null() {
        return None;
    }
    let gpu = Reflect::get(&navigator, &"gpu".into()).ok()?;
    if gpu.is_undefined() || gpu.is_null() {
        None
    } else {
        Some(gpu)
    }
}
/// Read property `key` from `obj` and return it when it is a JS string.
///
/// Yields `None` if the property read fails or the value is not a string.
#[cfg(target_arch = "wasm32")]
fn get_string_prop(obj: &JsValue, key: &str) -> Option<String> {
    match Reflect::get(obj, &JsValue::from_str(key)) {
        Ok(value) => value.as_string(),
        Err(_) => None,
    }
}
/// Read property `key` from `obj` and return it when it is a JS number.
///
/// Yields `None` if the property read fails or the value is not a number.
#[cfg(target_arch = "wasm32")]
fn get_number_prop(obj: &JsValue, key: &str) -> Option<f64> {
    let value = Reflect::get(obj, &JsValue::from_str(key)).ok()?;
    value.as_f64()
}
/// Invoke the method named `method` on `obj` with `args`, using `obj` as `this`.
///
/// Returns the call's result, or the JS error raised while looking the member
/// up or calling it. If the member is not a function, the member value itself
/// is returned as the error.
#[cfg(target_arch = "wasm32")]
fn call_method(obj: &JsValue, method: &str, args: &[JsValue]) -> Result<JsValue, JsValue> {
    let member = Reflect::get(obj, &JsValue::from_str(method))?;
    let callable: js_sys::Function = member.dyn_into()?;
    // `Reflect::apply` expects the arguments packed into a JS array.
    let arguments = Array::new();
    args.iter().for_each(|value| {
        arguments.push(value);
    });
    Reflect::apply(&callable, obj, &arguments)
}
// ============================================================================
// WebGPU Context
// ============================================================================
/// WebGPU context holding device and queue references
///
/// The JavaScript handles only exist on wasm32 builds; other targets carry a
/// zero-sized placeholder instead so the type still compiles.
#[wasm_bindgen]
pub struct WebGpuContext {
/// GPU device object (JsValue wrapper)
#[cfg(target_arch = "wasm32")]
device: JsValue,
/// Command queue object
#[cfg(target_arch = "wasm32")]
queue: JsValue,
/// Placeholder for non-wasm builds
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData<()>,
/// Adapter information
adapter_info: AdapterInfo,
}
#[wasm_bindgen]
impl WebGpuContext {
/// Initialize WebGPU context
///
/// Requests a `"high-performance"` adapter, captures its description and
/// limits, then requests a device labelled `"ruvllm-wasm"` and stores that
/// device together with its `queue`.
///
/// Adapter identity strings come from the `adapter.info` attribute, with the
/// legacy `adapter.requestAdapterInfo()` method as a fallback for older
/// browsers. Missing identity data is not an error: the strings fall back to
/// their defaults so device creation can still proceed.
///
/// # Errors
///
/// Returns a JS error when WebGPU is not available, when no adapter is
/// offered, when `adapter.limits` or `device.queue` cannot be read, or when
/// the adapter/device requests reject. On non-wasm targets it always fails
/// with `"WebGPU only available in WASM"`.
#[wasm_bindgen(js_name = init)]
pub async fn init() -> Result<WebGpuContext, JsValue> {
    #[cfg(target_arch = "wasm32")]
    {
        let gpu = get_gpu_object().ok_or_else(|| JsValue::from_str("WebGPU not available"))?;
        // Request adapter with high performance preference
        let adapter_options = Object::new();
        Reflect::set(
            &adapter_options,
            &"powerPreference".into(),
            &"high-performance".into(),
        )?;
        let adapter_promise = call_method(&gpu, "requestAdapter", &[adapter_options.into()])?;
        let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>()?).await?;
        if adapter.is_null() || adapter.is_undefined() {
            return Err(JsValue::from_str("No suitable GPU adapter found"));
        }
        // Get adapter info. The `info` attribute is preferred because
        // `requestAdapterInfo()` has been dropped by current browsers; the
        // legacy call is only attempted when the attribute is absent.
        let mut info =
            Reflect::get(&adapter, &"info".into()).unwrap_or_else(|_| JsValue::undefined());
        if info.is_undefined() || info.is_null() {
            info = match call_method(&adapter, "requestAdapterInfo", &[]) {
                Ok(pending) => match pending.dyn_into::<Promise>() {
                    Ok(promise) => JsFuture::from(promise)
                        .await
                        .unwrap_or_else(|_| JsValue::undefined()),
                    Err(_) => JsValue::undefined(),
                },
                Err(_) => JsValue::undefined(),
            };
        }
        if info.is_undefined() || info.is_null() {
            // Purely descriptive data: an empty object makes every lookup
            // below resolve to its default instead of aborting init.
            info = Object::new().into();
        }
        let limits = Reflect::get(&adapter, &"limits".into())?;
        let adapter_info = AdapterInfo {
            vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
            architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
            device_type: get_string_prop(&info, "device")
                .unwrap_or_else(|| "unknown".to_string()),
            backend: "WebGPU".to_string(),
            max_buffer_size: get_number_prop(&limits, "maxBufferSize")
                .unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
            max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
                .unwrap_or(256.0) as u32,
        };
        // Request device
        let device_descriptor = Object::new();
        Reflect::set(&device_descriptor, &"label".into(), &"ruvllm-wasm".into())?;
        let device_promise =
            call_method(&adapter, "requestDevice", &[device_descriptor.into()])?;
        let device = JsFuture::from(device_promise.dyn_into::<Promise>()?).await?;
        // Get queue
        let queue = Reflect::get(&device, &"queue".into())?;
        Ok(WebGpuContext {
            device,
            queue,
            adapter_info,
        })
    }
    #[cfg(not(target_arch = "wasm32"))]
    Err(JsValue::from_str("WebGPU only available in WASM"))
}
/// Get adapter information
///
/// Exposed to JavaScript as the `adapterInfo` getter; returns a clone of the
/// description stored in this context.
#[wasm_bindgen(getter, js_name = adapterInfo)]
pub fn adapter_info(&self) -> AdapterInfo {
self.adapter_info.clone()
}
/// Whether this context still holds a device handle.
///
/// Exposed to JavaScript as the `isValid` getter. On wasm32 it is true when
/// the stored device is neither `undefined` nor `null`; on other targets it
/// is always `false`.
#[wasm_bindgen(getter, js_name = isValid)]
pub fn is_valid(&self) -> bool {
    #[cfg(target_arch = "wasm32")]
    {
        let missing = self.device.is_undefined() || self.device.is_null();
        !missing
    }
    #[cfg(not(target_arch = "wasm32"))]
    false
}
/// Allocate a GPU buffer on this context's device.
///
/// Builds a descriptor with `size` (bytes), the `usage` bit flags and, when
/// given, a debug `label`, then forwards it to `device.createBuffer`.
/// Returns the buffer handle or the JS error raised along the way.
#[cfg(target_arch = "wasm32")]
fn create_buffer_internal(
    &self,
    size: usize,
    usage: u32,
    label: Option<&str>,
) -> Result<JsValue, JsValue> {
    let descriptor = Object::new();
    let numeric_fields = [("size", size as f64), ("usage", usage as f64)];
    for (field, value) in numeric_fields.iter() {
        Reflect::set(
            &descriptor,
            &JsValue::from_str(field),
            &JsValue::from_f64(*value),
        )?;
    }
    match label {
        Some(text) => {
            Reflect::set(&descriptor, &"label".into(), &JsValue::from_str(text))?;
        }
        None => {}
    }
    call_method(&self.device, "createBuffer", &[descriptor.into()])
}
/// Upload `data` to the start of `buffer` through `queue.writeBuffer`.
///
/// The slice is first copied into a fresh `Float32Array`; that array's
/// backing `ArrayBuffer` is what gets handed to the queue, at offset 0.
#[cfg(target_arch = "wasm32")]
fn write_buffer_internal(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
    let payload: JsValue = Float32Array::from(data).buffer().into();
    let destination_offset = JsValue::from_f64(0.0);
    call_method(
        &self.queue,
        "writeBuffer",
        &[buffer.clone(), destination_offset, payload],
    )
    .map(|_| ())
}
}
// ============================================================================
// Compute Pipeline
// ============================================================================
/// Compute pipeline handle
///
/// Pairs the JavaScript pipeline object and its bind group layout (wasm32
/// only) with the entry point name and workgroup size that the accessors on
/// this type report.
// NOTE(review): nothing in this file constructs a `ComputePipeline`;
// presumably it is built elsewhere in the crate - confirm before relying on it.
#[wasm_bindgen]
pub struct ComputePipeline {
// Pipeline object handle; only present on wasm32 builds.
#[cfg(target_arch = "wasm32")]
pipeline: JsValue,
// Bind group layout handle; only present on wasm32 builds.
#[cfg(target_arch = "wasm32")]
bind_group_layout: JsValue,
// Zero-sized placeholder so the struct still compiles off wasm32.
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData<()>,
// Shader entry point name, returned by the `entryPoint` getter.
entry_point: String,
// Three-component workgroup size, returned by the `workgroupSize` getter.
workgroup_size: [u32; 3],
}
#[wasm_bindgen]
impl ComputePipeline {
    /// Name of the shader entry point, exposed to JavaScript as `entryPoint`.
    ///
    /// Returns an owned copy of the stored name.
    #[wasm_bindgen(getter, js_name = entryPoint)]
    pub fn entry_point(&self) -> String {
        String::from(self.entry_point.as_str())
    }

    /// Workgroup size, exposed to JavaScript as `workgroupSize`.
    ///
    /// Returns the three stored components as a freshly allocated vector.
    #[wasm_bindgen(getter, js_name = workgroupSize)]
    pub fn workgroup_size(&self) -> Vec<u32> {
        self.workgroup_size.iter().copied().collect()
    }
}
// ============================================================================
// WebGPU Inference Engine
// ============================================================================
/// WebGPU inference engine for LLM operations
///
/// Holds the device and queue handles taken over from a `WebGpuContext` in
/// `init` (wasm32 only), plus the adapter description. Other targets carry a
/// zero-sized placeholder in place of the JavaScript handles.
#[wasm_bindgen]
pub struct WebGpuInference {
// GPU device handle; only present on wasm32 builds.
#[cfg(target_arch = "wasm32")]
device: JsValue,
// Command queue handle; only present on wasm32 builds.
#[cfg(target_arch = "wasm32")]
queue: JsValue,
// Zero-sized placeholder so the struct still compiles off wasm32.
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData<()>,
// Adapter description returned by the `adapterInfo` getter.
adapter_info: AdapterInfo,
}
#[wasm_bindgen]
impl WebGpuInference {
/// Check if WebGPU is available
///
/// Exposed to JavaScript as `isAvailable`; simply forwards to the free
/// function `is_webgpu_available`.
#[wasm_bindgen(js_name = isAvailable)]
pub async fn is_available() -> bool {
is_webgpu_available().await
}
/// Create an inference engine backed by a freshly initialised WebGPU context.
///
/// Runs `WebGpuContext::init` and takes over its device, queue and adapter
/// description. Any error from the context initialisation is returned
/// unchanged.
#[wasm_bindgen(js_name = init)]
pub async fn init() -> Result<WebGpuInference, JsValue> {
    WebGpuContext::init()
        .await
        .map(|context| WebGpuInference {
            #[cfg(target_arch = "wasm32")]
            device: context.device,
            #[cfg(target_arch = "wasm32")]
            queue: context.queue,
            #[cfg(not(target_arch = "wasm32"))]
            _phantom: std::marker::PhantomData,
            adapter_info: context.adapter_info,
        })
}
/// Get adapter information
///
/// Exposed to JavaScript as the `adapterInfo` getter; returns a clone of the
/// description captured when the engine was initialised.
#[wasm_bindgen(getter, js_name = adapterInfo)]
pub fn adapter_info(&self) -> AdapterInfo {
self.adapter_info.clone()
}
/// Perform matrix multiplication: C = A * B
///
/// Args:
///   a: Matrix A as flat f32 array (M x K), row-major
///   b: Matrix B as flat f32 array (K x N), row-major
///   m: Number of rows in A
///   n: Number of columns in B
///   k: Shared dimension
///
/// Returns: Result matrix C as f32 array (M x N)
///
/// On wasm32 the product is computed on the GPU with `shaders::MATMUL_SHADER`;
/// on other targets a naive CPU triple loop is used. Degenerate shapes
/// (`m`, `n` or `k` equal to zero) are answered directly with an all-zero
/// result of `m * n` elements and never reach the GPU, because zero-sized
/// buffers cannot be bound.
///
/// # Errors
///
/// Returns a JS error when the dimension products do not fit in `usize`,
/// when `a` or `b` does not have the expected length, when the required
/// dispatch exceeds the device's `maxComputeWorkgroupsPerDimension`, or when
/// any WebGPU call throws.
#[wasm_bindgen]
pub async fn matmul(
    &self,
    a: &[f32],
    b: &[f32],
    m: u32,
    n: u32,
    k: u32,
) -> Result<Vec<f32>, JsValue> {
    fn too_large() -> JsValue {
        JsValue::from_str("Matrix dimensions are too large")
    }
    // `usize` is 32 bits wide on wasm32, so plain multiplication could wrap
    // and let mismatched inputs slip through the length checks below.
    let (rows, cols, inner) = (m as usize, n as usize, k as usize);
    let expected_a = rows.checked_mul(inner).ok_or_else(too_large)?;
    let expected_b = inner.checked_mul(cols).ok_or_else(too_large)?;
    let output_size = rows.checked_mul(cols).ok_or_else(too_large)?;
    // Validate dimensions
    if a.len() != expected_a {
        return Err(JsValue::from_str(&format!(
            "Matrix A dimension mismatch: expected {}, got {}",
            expected_a,
            a.len()
        )));
    }
    if b.len() != expected_b {
        return Err(JsValue::from_str(&format!(
            "Matrix B dimension mismatch: expected {}, got {}",
            expected_b,
            b.len()
        )));
    }
    // Empty output, or an empty inner dimension (every dot product is 0).
    if output_size == 0 || inner == 0 {
        return Ok(vec![0.0f32; output_size]);
    }
    #[cfg(target_arch = "wasm32")]
    {
        // GPU buffer usage flags
        const STORAGE: u32 = 0x80; // GPUBufferUsage.STORAGE
        const COPY_SRC: u32 = 0x04; // GPUBufferUsage.COPY_SRC
        const COPY_DST: u32 = 0x08; // GPUBufferUsage.COPY_DST
        const MAP_READ: u32 = 0x01; // GPUBufferUsage.MAP_READ
        const UNIFORM: u32 = 0x40; // GPUBufferUsage.UNIFORM
        // NOTE(review): assumes MATMUL_SHADER uses a 16x16 workgroup with x
        // covering rows and y covering columns - confirm against the shader.
        const TILE: u32 = 16;
        let output_bytes = output_size.checked_mul(4).ok_or_else(too_large)?;
        // `m` and `n` are non-zero here, so this ceiling division cannot
        // underflow, and unlike `(m + 15) / 16` it cannot overflow either.
        let workgroups_x = (m - 1) / TILE + 1;
        let workgroups_y = (n - 1) / TILE + 1;
        // WebGPU reports an oversized dispatch as an asynchronous validation
        // error instead of throwing, so check the limit up front.
        let max_workgroups = Reflect::get(&self.device, &"limits".into())
            .ok()
            .and_then(|limits| get_number_prop(&limits, "maxComputeWorkgroupsPerDimension"))
            .unwrap_or(65535.0);
        if f64::from(workgroups_x.max(workgroups_y)) > max_workgroups {
            return Err(JsValue::from_str(&format!(
                "Matrix too large for a single dispatch: needs {}x{} workgroups, limit is {}",
                workgroups_x, workgroups_y, max_workgroups
            )));
        }
        // Create buffers
        let buffer_a = self.create_buffer(
            std::mem::size_of_val(a),
            STORAGE | COPY_DST,
            Some("matmul_a"),
        )?;
        let buffer_b = self.create_buffer(
            std::mem::size_of_val(b),
            STORAGE | COPY_DST,
            Some("matmul_b"),
        )?;
        let buffer_c = self.create_buffer(output_bytes, STORAGE | COPY_SRC, Some("matmul_c"))?;
        // Create uniform buffer for dimensions
        // NOTE(review): M, N, K are uploaded as f32 values; confirm that
        // MATMUL_SHADER declares its uniform block with f32 members.
        let uniform_data: [f32; 4] = [m as f32, n as f32, k as f32, 1.0]; // M, N, K, alpha
        let uniform_buffer =
            self.create_buffer(16, UNIFORM | COPY_DST, Some("matmul_uniforms"))?;
        // Write data to buffers
        self.write_buffer(&buffer_a, a)?;
        self.write_buffer(&buffer_b, b)?;
        self.write_buffer(&uniform_buffer, &uniform_data)?;
        // Create shader module
        let shader_desc = Object::new();
        Reflect::set(&shader_desc, &"code".into(), &shaders::MATMUL_SHADER.into())?;
        let shader_module =
            call_method(&self.device, "createShaderModule", &[shader_desc.into()])?;
        // Create bind group layout
        let layout_entries = Array::new();
        // Storage buffer entries (A, B read-only; C writable)
        for i in 0..3u32 {
            let entry = Object::new();
            Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
            Reflect::set(&entry, &"visibility".into(), &JsValue::from_f64(4.0))?; // COMPUTE stage
            let buffer_layout = Object::new();
            Reflect::set(
                &buffer_layout,
                &"type".into(),
                &(if i < 2 {
                    "read-only-storage"
                } else {
                    "storage"
                })
                .into(),
            )?;
            Reflect::set(&entry, &"buffer".into(), &buffer_layout)?;
            layout_entries.push(&entry);
        }
        // Uniform buffer entry
        let uniform_entry = Object::new();
        Reflect::set(&uniform_entry, &"binding".into(), &JsValue::from_f64(3.0))?;
        Reflect::set(
            &uniform_entry,
            &"visibility".into(),
            &JsValue::from_f64(4.0),
        )?;
        let uniform_layout = Object::new();
        Reflect::set(&uniform_layout, &"type".into(), &"uniform".into())?;
        Reflect::set(&uniform_entry, &"buffer".into(), &uniform_layout)?;
        layout_entries.push(&uniform_entry);
        let layout_desc = Object::new();
        Reflect::set(&layout_desc, &"entries".into(), &layout_entries)?;
        let bind_group_layout =
            call_method(&self.device, "createBindGroupLayout", &[layout_desc.into()])?;
        // Create pipeline layout
        let layouts = Array::new();
        layouts.push(&bind_group_layout);
        let pipeline_layout_desc = Object::new();
        Reflect::set(&pipeline_layout_desc, &"bindGroupLayouts".into(), &layouts)?;
        let pipeline_layout = call_method(
            &self.device,
            "createPipelineLayout",
            &[pipeline_layout_desc.into()],
        )?;
        // Create compute pipeline
        let compute_stage = Object::new();
        Reflect::set(&compute_stage, &"module".into(), &shader_module)?;
        Reflect::set(&compute_stage, &"entryPoint".into(), &"main".into())?;
        let pipeline_desc = Object::new();
        Reflect::set(&pipeline_desc, &"layout".into(), &pipeline_layout)?;
        Reflect::set(&pipeline_desc, &"compute".into(), &compute_stage)?;
        let pipeline = call_method(
            &self.device,
            "createComputePipeline",
            &[pipeline_desc.into()],
        )?;
        // Create bind group (binding index follows the array order)
        let bind_entries = Array::new();
        for (i, buffer) in [&buffer_a, &buffer_b, &buffer_c, &uniform_buffer]
            .iter()
            .enumerate()
        {
            let entry = Object::new();
            Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
            let resource = Object::new();
            Reflect::set(&resource, &"buffer".into(), buffer)?;
            Reflect::set(&entry, &"resource".into(), &resource)?;
            bind_entries.push(&entry);
        }
        let bind_group_desc = Object::new();
        Reflect::set(&bind_group_desc, &"layout".into(), &bind_group_layout)?;
        Reflect::set(&bind_group_desc, &"entries".into(), &bind_entries)?;
        let bind_group =
            call_method(&self.device, "createBindGroup", &[bind_group_desc.into()])?;
        // Create command encoder
        let encoder_desc = Object::new();
        let encoder =
            call_method(&self.device, "createCommandEncoder", &[encoder_desc.into()])?;
        // Begin compute pass
        let pass_desc = Object::new();
        let pass = call_method(&encoder, "beginComputePass", &[pass_desc.into()])?;
        // Set pipeline and bind group
        call_method(&pass, "setPipeline", &[pipeline.clone()])?;
        call_method(
            &pass,
            "setBindGroup",
            &[JsValue::from_f64(0.0), bind_group.clone()],
        )?;
        // Dispatch workgroups
        call_method(
            &pass,
            "dispatchWorkgroups",
            &[
                JsValue::from_f64(workgroups_x as f64),
                JsValue::from_f64(workgroups_y as f64),
            ],
        )?;
        call_method(&pass, "end", &[])?;
        // Create staging buffer for readback
        let staging = self.create_buffer(output_bytes, MAP_READ | COPY_DST, Some("staging"))?;
        // Copy result to staging
        call_method(
            &encoder,
            "copyBufferToBuffer",
            &[
                buffer_c.clone(),
                JsValue::from_f64(0.0),
                staging.clone(),
                JsValue::from_f64(0.0),
                JsValue::from_f64(output_bytes as f64),
            ],
        )?;
        // Submit commands
        let command_buffer = call_method(&encoder, "finish", &[])?;
        let commands = Array::new();
        commands.push(&command_buffer);
        call_method(&self.queue, "submit", &[commands.into()])?;
        // Map staging buffer and read result
        let map_promise = call_method(&staging, "mapAsync", &[JsValue::from_f64(1.0)])?; // MAP_READ = 1
        JsFuture::from(map_promise.dyn_into::<Promise>()?).await?;
        let mapped_range = call_method(&staging, "getMappedRange", &[])?;
        // Copy out of the mapping before it is invalidated by `unmap`.
        let data = Float32Array::new(&mapped_range).to_vec();
        call_method(&staging, "unmap", &[])?;
        // Release GPU memory now instead of waiting for the JS garbage
        // collector. Best effort: a failure here must not discard the result.
        for buffer in [&buffer_a, &buffer_b, &buffer_c, &uniform_buffer, &staging].iter() {
            let _ = call_method(buffer, "destroy", &[]);
        }
        Ok(data)
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        // CPU fallback - naive implementation
        let mut c = vec![0.0f32; output_size];
        for (i, out_row) in c.chunks_exact_mut(cols).enumerate() {
            let a_row = &a[i * inner..(i + 1) * inner];
            for (j, out) in out_row.iter_mut().enumerate() {
                *out = a_row
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (l, &x)| acc + x * b[l * cols + j]);
            }
        }
        Ok(c)
    }
}
/// Scaled dot-product attention: Output = softmax(Q * K^T / sqrt(d_k)) * V
///
/// `q`, `k` and `v` must each contain `seq_len * hidden_dim` values as
/// described by `config`; otherwise a dimension-mismatch error is returned.
/// The computation itself currently always runs on the CPU via
/// `attention_cpu`.
#[wasm_bindgen]
pub async fn attention(
    &self,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    config: &AttentionConfig,
) -> Result<Vec<f32>, JsValue> {
    let hidden_dim = config.hidden_dim();
    let expected_size = (config.seq_len as usize) * (hidden_dim as usize);
    let lengths = [q.len(), k.len(), v.len()];
    let mismatch = lengths.iter().any(|&len| len != expected_size);
    if mismatch {
        let message = format!(
            "Attention tensor dimension mismatch: expected {}, got Q:{}, K:{}, V:{}",
            expected_size, lengths[0], lengths[1], lengths[2]
        );
        return Err(JsValue::from_str(&message));
    }
    // No GPU kernel for attention yet; every call takes the CPU path.
    self.attention_cpu(q, k, v, config)
}
/// Multi-head attention computed on the CPU.
///
/// Tensors are indexed as `position * (num_heads * head_dim) + head * head_dim + d`.
/// Heads are processed independently. With `config.causal` set, a query at
/// position `i` ignores every key position `j > i`.
fn attention_cpu(
    &self,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    config: &AttentionConfig,
) -> Result<Vec<f32>, JsValue> {
    let seq_len = config.seq_len as usize;
    let num_heads = config.num_heads as usize;
    let head_dim = config.head_dim as usize;
    let hidden_dim = num_heads * head_dim;
    let scale = config.scale();
    let mut output = vec![0.0f32; seq_len * hidden_dim];
    for head in 0..num_heads {
        let head_start = head * head_dim;
        for query_pos in 0..seq_len {
            let row_start = query_pos * hidden_dim + head_start;
            let query = &q[row_start..row_start + head_dim];
            // Scaled logits against every key; masked positions get -inf so
            // they contribute exp(-inf) = 0 after the softmax.
            let mut peak = f32::NEG_INFINITY;
            let mut weights: Vec<f32> = (0..seq_len)
                .map(|key_pos| {
                    if config.causal && key_pos > query_pos {
                        return f32::NEG_INFINITY;
                    }
                    let key_start = key_pos * hidden_dim + head_start;
                    let key = &k[key_start..key_start + head_dim];
                    let logit = query
                        .iter()
                        .zip(key)
                        .fold(0.0f32, |acc, (x, y)| acc + x * y)
                        * scale;
                    if logit > peak {
                        peak = logit;
                    }
                    logit
                })
                .collect();
            // Softmax, shifted by the largest logit for numerical stability.
            let mut total = 0.0f32;
            for weight in weights.iter_mut() {
                *weight = (*weight - peak).exp();
                total += *weight;
            }
            for weight in weights.iter_mut() {
                *weight /= total;
            }
            // Attention-weighted sum of the value vectors.
            let out_row = &mut output[row_start..row_start + head_dim];
            for (d, slot) in out_row.iter_mut().enumerate() {
                *slot = weights
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (key_pos, weight)| {
                        acc + weight * v[key_pos * hidden_dim + head_start + d]
                    });
            }
        }
    }
    Ok(output)
}
/// Perform RMS normalization
///
/// `input` is treated as consecutive rows of `hidden_dim` values. Each row is
/// divided by `sqrt(mean(x^2) + eps)` and multiplied element-wise by
/// `weight`. The computation runs on the CPU.
///
/// # Errors
///
/// Returns a JS error when `hidden_dim` is zero, when `weight` does not have
/// exactly `hidden_dim` elements, or when `input.len()` is not a multiple of
/// `hidden_dim`.
#[wasm_bindgen(js_name = rmsNorm)]
pub async fn rms_norm(
    &self,
    input: &[f32],
    weight: &[f32],
    hidden_dim: u32,
    eps: f32,
) -> Result<Vec<f32>, JsValue> {
    let width = hidden_dim as usize;
    // A zero width would otherwise reach `% 0` below and panic.
    if width == 0 {
        return Err(JsValue::from_str("hidden_dim must be greater than zero"));
    }
    if weight.len() != width {
        return Err(JsValue::from_str(&format!(
            "Weight dimension mismatch: expected {}, got {}",
            hidden_dim,
            weight.len()
        )));
    }
    if input.len() % width != 0 {
        return Err(JsValue::from_str(&format!(
            "Input size {} not divisible by hidden_dim {}",
            input.len(),
            hidden_dim
        )));
    }
    // CPU implementation
    let mut output = Vec::with_capacity(input.len());
    for row in input.chunks_exact(width) {
        // Compute sum of squares
        let sum_sq = row.iter().fold(0.0f32, |acc, &x| acc + x * x);
        // RMS scale
        let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
        // Normalize and scale
        output.extend(row.iter().zip(weight).map(|(&x, &w)| x / rms * w));
    }
    Ok(output)
}
/// Perform softmax
///
/// `input` is treated as consecutive rows of `dim` values. Each value is
/// divided by `temperature`, the row maximum is subtracted for numerical
/// stability, and the exponentials are normalised so that each row sums to 1.
/// The computation runs on the CPU.
///
/// # Errors
///
/// Returns a JS error when `dim` is zero, when `temperature` is zero or NaN
/// (every output would be NaN), or when `input.len()` is not a multiple of
/// `dim`.
#[wasm_bindgen]
pub async fn softmax(
    &self,
    input: &[f32],
    dim: u32,
    temperature: f32,
) -> Result<Vec<f32>, JsValue> {
    let width = dim as usize;
    // A zero width would otherwise reach `% 0` below and panic.
    if width == 0 {
        return Err(JsValue::from_str("dim must be greater than zero"));
    }
    if temperature == 0.0 || temperature.is_nan() {
        return Err(JsValue::from_str(
            "temperature must be a non-zero number",
        ));
    }
    if input.len() % width != 0 {
        return Err(JsValue::from_str(&format!(
            "Input size {} not divisible by dim {}",
            input.len(),
            dim
        )));
    }
    let mut output = vec![0.0f32; input.len()];
    for (row, out_row) in input
        .chunks_exact(width)
        .zip(output.chunks_exact_mut(width))
    {
        // Find max (for numerical stability)
        let max_val = row
            .iter()
            .map(|&x| x / temperature)
            .fold(f32::NEG_INFINITY, |best, x| if x > best { x } else { best });
        // Compute exp and sum
        let mut sum = 0.0f32;
        for (slot, &x) in out_row.iter_mut().zip(row) {
            let e = (x / temperature - max_val).exp();
            *slot = e;
            sum += e;
        }
        // Normalize
        for slot in out_row.iter_mut() {
            *slot /= sum;
        }
    }
    Ok(output)
}
// Helper methods for GPU buffer management

/// Allocate a GPU buffer of `size` bytes with the given `usage` flags.
///
/// The optional `label` is attached to the descriptor when present. Returns
/// the handle produced by `device.createBuffer`, or the JS error raised while
/// building the descriptor or creating the buffer.
#[cfg(target_arch = "wasm32")]
fn create_buffer(
    &self,
    size: usize,
    usage: u32,
    label: Option<&str>,
) -> Result<JsValue, JsValue> {
    let descriptor = Object::new();
    let byte_size = JsValue::from_f64(size as f64);
    let usage_flags = JsValue::from_f64(usage as f64);
    Reflect::set(&descriptor, &JsValue::from_str("size"), &byte_size)?;
    Reflect::set(&descriptor, &JsValue::from_str("usage"), &usage_flags)?;
    if let Some(name) = label {
        Reflect::set(
            &descriptor,
            &JsValue::from_str("label"),
            &JsValue::from_str(name),
        )?;
    }
    let descriptor: JsValue = descriptor.into();
    call_method(&self.device, "createBuffer", &[descriptor])
}
/// Copy `data` into `buffer`, starting at byte offset 0, via `queue.writeBuffer`.
///
/// The slice is staged in a new `Float32Array` whose backing `ArrayBuffer`
/// is passed to the queue.
#[cfg(target_arch = "wasm32")]
fn write_buffer(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
    let staged = Float32Array::from(data);
    let arguments = [
        buffer.clone(),
        JsValue::from_f64(0.0),
        JsValue::from(staged.buffer()),
    ];
    call_method(&self.queue, "writeBuffer", &arguments).map(|_| ())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the row-major triple-loop formula used by the CPU matmul
    /// fallback on a 2x2 example.
    #[test]
    fn test_cpu_matmul_fallback() {
        let (m, n, k) = (2usize, 2usize, 2usize);
        let a = [1.0f32, 2.0, 3.0, 4.0]; // 2x2
        let b = [5.0f32, 6.0, 7.0, 8.0]; // 2x2
        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        //         = [[19, 22], [43, 50]]
        let mut c = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                c[i * n + j] = (0..k).map(|l| a[i * k + l] * b[l * n + j]).sum();
            }
        }
        assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// Checks the RMS-norm formula `x / sqrt(mean(x^2) + eps) * w` on one row.
    #[test]
    fn test_rms_norm_cpu() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let weight = [1.0f32, 1.0, 1.0, 1.0];
        let hidden_dim = input.len();
        let eps = 1e-5f32;
        // sum_sq = 1 + 4 + 9 + 16 = 30
        // rms = sqrt(30/4 + eps) = sqrt(7.5) ≈ 2.7386
        let sum_sq: f32 = input.iter().map(|&x| x * x).sum();
        assert_eq!(sum_sq, 30.0);
        let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
        let normalized: Vec<f32> = input
            .iter()
            .zip(weight.iter())
            .map(|(&x, &w)| x / rms * w)
            .collect();
        let expected = [0.3651f32, 0.7303, 1.0954, 1.4606];
        for (got, want) in normalized.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001, "got {}, want {}", got, want);
        }
        // With unit weights the normalized row has a mean square of about 1.
        let mean_sq = normalized.iter().map(|&y| y * y).sum::<f32>() / hidden_dim as f32;
        assert!((mean_sq - 1.0).abs() < 0.001);
    }

    /// Checks the max-shifted softmax formula on one row.
    #[test]
    fn test_softmax_cpu() {
        let input = [1.0f32, 2.0, 3.0];
        let temperature = 1.0f32;
        // max = 3
        // exp(1-3) = exp(-2), exp(2-3) = exp(-1), exp(3-3) = exp(0) = 1
        let max_val = input
            .iter()
            .map(|&x| x / temperature)
            .fold(f32::NEG_INFINITY, |best, x| if x > best { x } else { best });
        assert_eq!(max_val, 3.0);
        let exps: Vec<f32> = input
            .iter()
            .map(|&x| (x / temperature - max_val).exp())
            .collect();
        let sum: f32 = exps.iter().sum();
        let probabilities: Vec<f32> = exps.iter().map(|&e| e / sum).collect();
        // Verify softmax sums to 1
        let softmax_sum: f32 = probabilities.iter().sum();
        assert!((softmax_sum - 1.0).abs() < 0.001);
        // Verify the individual probabilities and that ordering is preserved.
        let expected = [0.0900f32, 0.2447, 0.6652];
        for (got, want) in probabilities.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001, "got {}, want {}", got, want);
        }
        assert!(probabilities[0] < probabilities[1] && probabilities[1] < probabilities[2]);
    }
}