//! WebGPU Compute Context and Pipelines //! //! This module provides the core WebGPU compute functionality for WASM, //! including context initialization, pipeline creation, and kernel execution. //! //! Note: WebGPU bindings use JavaScript interop via js_sys/Reflect since //! web-sys WebGPU bindings are still unstable. use js_sys::{Array, Float32Array, Object, Promise, Reflect}; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; use super::{shaders, AdapterInfo, AttentionConfig}; /// Check if WebGPU is available in this browser pub async fn is_webgpu_available() -> bool { #[cfg(target_arch = "wasm32")] { if let Some(gpu) = get_gpu_object() { return !gpu.is_undefined() && !gpu.is_null(); } false } #[cfg(not(target_arch = "wasm32"))] false } /// Get GPU adapter information if available pub async fn get_gpu_info() -> Option { #[cfg(target_arch = "wasm32")] { let gpu = get_gpu_object()?; // Request adapter let options = Object::new(); let _ = Reflect::set( &options, &"powerPreference".into(), &"high-performance".into(), ); let adapter_promise = call_method(&gpu, "requestAdapter", &[options.into()]).ok()?; let adapter = JsFuture::from(adapter_promise.dyn_into::().ok()?) .await .ok()?; if adapter.is_null() || adapter.is_undefined() { return None; } // Get adapter info via requestAdapterInfo() let info_promise = call_method(&adapter, "requestAdapterInfo", &[]).ok()?; let info = JsFuture::from(info_promise.dyn_into::().ok()?) .await .ok()?; // Extract limits let limits = Reflect::get(&adapter, &"limits".into()).ok()?; Some(AdapterInfo { vendor: get_string_prop(&info, "vendor").unwrap_or_default(), architecture: get_string_prop(&info, "architecture").unwrap_or_default(), device_type: get_string_prop(&info, "device").unwrap_or_else(|| "unknown".to_string()), backend: "WebGPU".to_string(), max_buffer_size: get_number_prop(&limits, "maxBufferSize") .unwrap_or(256.0 * 1024.0 * 1024.0) as u64, max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX") .unwrap_or(256.0) as u32, }) } #[cfg(not(target_arch = "wasm32"))] None } // ============================================================================ // Helper Functions // ============================================================================ #[cfg(target_arch = "wasm32")] fn get_gpu_object() -> Option { let window = web_sys::window()?; let navigator = Reflect::get(&window, &"navigator".into()).ok()?; let gpu = Reflect::get(&navigator, &"gpu".into()).ok()?; if gpu.is_undefined() || gpu.is_null() { None } else { Some(gpu) } } #[cfg(target_arch = "wasm32")] fn get_string_prop(obj: &JsValue, key: &str) -> Option { Reflect::get(obj, &key.into()) .ok() .and_then(|v| v.as_string()) } #[cfg(target_arch = "wasm32")] fn get_number_prop(obj: &JsValue, key: &str) -> Option { Reflect::get(obj, &key.into()).ok().and_then(|v| v.as_f64()) } #[cfg(target_arch = "wasm32")] fn call_method(obj: &JsValue, method: &str, args: &[JsValue]) -> Result { let func = Reflect::get(obj, &method.into())?.dyn_into::()?; let args_array = Array::new(); for arg in args { args_array.push(arg); } Reflect::apply(&func, obj, &args_array) } // ============================================================================ // WebGPU Context // ============================================================================ /// WebGPU context holding device and queue references #[wasm_bindgen] pub struct WebGpuContext { /// GPU device object (JsValue wrapper) #[cfg(target_arch = "wasm32")] device: JsValue, /// Command queue object #[cfg(target_arch = "wasm32")] queue: JsValue, /// Placeholder for non-wasm builds #[cfg(not(target_arch = "wasm32"))] _phantom: std::marker::PhantomData<()>, /// Adapter information adapter_info: AdapterInfo, } #[wasm_bindgen] impl WebGpuContext { /// Initialize WebGPU context #[wasm_bindgen(js_name = init)] pub async fn init() -> Result { #[cfg(target_arch = "wasm32")] { let gpu = get_gpu_object().ok_or_else(|| JsValue::from_str("WebGPU not available"))?; // Request adapter with high performance preference let adapter_options = Object::new(); Reflect::set( &adapter_options, &"powerPreference".into(), &"high-performance".into(), )?; let adapter_promise = call_method(&gpu, "requestAdapter", &[adapter_options.into()])?; let adapter = JsFuture::from(adapter_promise.dyn_into::()?).await?; if adapter.is_null() || adapter.is_undefined() { return Err(JsValue::from_str("No suitable GPU adapter found")); } // Get adapter info let info_promise = call_method(&adapter, "requestAdapterInfo", &[])?; let info = JsFuture::from(info_promise.dyn_into::()?).await?; let limits = Reflect::get(&adapter, &"limits".into())?; let adapter_info = AdapterInfo { vendor: get_string_prop(&info, "vendor").unwrap_or_default(), architecture: get_string_prop(&info, "architecture").unwrap_or_default(), device_type: get_string_prop(&info, "device") .unwrap_or_else(|| "unknown".to_string()), backend: "WebGPU".to_string(), max_buffer_size: get_number_prop(&limits, "maxBufferSize") .unwrap_or(256.0 * 1024.0 * 1024.0) as u64, max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX") .unwrap_or(256.0) as u32, }; // Request device let device_descriptor = Object::new(); Reflect::set(&device_descriptor, &"label".into(), &"ruvllm-wasm".into())?; let device_promise = call_method(&adapter, "requestDevice", &[device_descriptor.into()])?; let device = JsFuture::from(device_promise.dyn_into::()?).await?; // Get queue let queue = Reflect::get(&device, &"queue".into())?; Ok(WebGpuContext { device, queue, adapter_info, }) } #[cfg(not(target_arch = "wasm32"))] Err(JsValue::from_str("WebGPU only available in WASM")) } /// Get adapter information #[wasm_bindgen(getter, js_name = adapterInfo)] pub fn adapter_info(&self) -> AdapterInfo { self.adapter_info.clone() } /// Check if context is valid #[wasm_bindgen(getter, js_name = isValid)] pub fn is_valid(&self) -> bool { #[cfg(target_arch = "wasm32")] { !self.device.is_undefined() && !self.device.is_null() } #[cfg(not(target_arch = "wasm32"))] false } /// Create a GPU buffer #[cfg(target_arch = "wasm32")] fn create_buffer_internal( &self, size: usize, usage: u32, label: Option<&str>, ) -> Result { let descriptor = Object::new(); Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?; Reflect::set( &descriptor, &"usage".into(), &JsValue::from_f64(usage as f64), )?; if let Some(lbl) = label { Reflect::set(&descriptor, &"label".into(), &lbl.into())?; } call_method(&self.device, "createBuffer", &[descriptor.into()]) } /// Write data to GPU buffer #[cfg(target_arch = "wasm32")] fn write_buffer_internal(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> { let data_array = Float32Array::from(data); call_method( &self.queue, "writeBuffer", &[ buffer.clone(), JsValue::from_f64(0.0), data_array.buffer().into(), ], )?; Ok(()) } } // ============================================================================ // Compute Pipeline // ============================================================================ /// Compute pipeline handle #[wasm_bindgen] pub struct ComputePipeline { #[cfg(target_arch = "wasm32")] pipeline: JsValue, #[cfg(target_arch = "wasm32")] bind_group_layout: JsValue, #[cfg(not(target_arch = "wasm32"))] _phantom: std::marker::PhantomData<()>, entry_point: String, workgroup_size: [u32; 3], } #[wasm_bindgen] impl ComputePipeline { /// Get the entry point name #[wasm_bindgen(getter, js_name = entryPoint)] pub fn entry_point(&self) -> String { self.entry_point.clone() } /// Get the workgroup size #[wasm_bindgen(getter, js_name = workgroupSize)] pub fn workgroup_size(&self) -> Vec { self.workgroup_size.to_vec() } } // ============================================================================ // WebGPU Inference Engine // ============================================================================ /// WebGPU inference engine for LLM operations #[wasm_bindgen] pub struct WebGpuInference { #[cfg(target_arch = "wasm32")] device: JsValue, #[cfg(target_arch = "wasm32")] queue: JsValue, #[cfg(not(target_arch = "wasm32"))] _phantom: std::marker::PhantomData<()>, adapter_info: AdapterInfo, } #[wasm_bindgen] impl WebGpuInference { /// Check if WebGPU is available #[wasm_bindgen(js_name = isAvailable)] pub async fn is_available() -> bool { is_webgpu_available().await } /// Initialize WebGPU inference engine #[wasm_bindgen(js_name = init)] pub async fn init() -> Result { let ctx = WebGpuContext::init().await?; Ok(WebGpuInference { #[cfg(target_arch = "wasm32")] device: ctx.device, #[cfg(target_arch = "wasm32")] queue: ctx.queue, #[cfg(not(target_arch = "wasm32"))] _phantom: std::marker::PhantomData, adapter_info: ctx.adapter_info, }) } /// Get adapter information #[wasm_bindgen(getter, js_name = adapterInfo)] pub fn adapter_info(&self) -> AdapterInfo { self.adapter_info.clone() } /// Perform matrix multiplication: C = A * B /// /// Args: /// a: Matrix A as flat f32 array (M x K) /// b: Matrix B as flat f32 array (K x N) /// m: Number of rows in A /// n: Number of columns in B /// k: Shared dimension /// /// Returns: Result matrix C as f32 array (M x N) #[wasm_bindgen] pub async fn matmul( &self, a: &[f32], b: &[f32], m: u32, n: u32, k: u32, ) -> Result, JsValue> { // Validate dimensions let expected_a = (m as usize) * (k as usize); let expected_b = (k as usize) * (n as usize); if a.len() != expected_a { return Err(JsValue::from_str(&format!( "Matrix A dimension mismatch: expected {}, got {}", expected_a, a.len() ))); } if b.len() != expected_b { return Err(JsValue::from_str(&format!( "Matrix B dimension mismatch: expected {}, got {}", expected_b, b.len() ))); } #[cfg(target_arch = "wasm32")] { let output_size = (m as usize) * (n as usize); // GPU buffer usage flags const STORAGE: u32 = 0x80; // GPUBufferUsage.STORAGE const COPY_SRC: u32 = 0x04; // GPUBufferUsage.COPY_SRC const COPY_DST: u32 = 0x08; // GPUBufferUsage.COPY_DST const MAP_READ: u32 = 0x01; // GPUBufferUsage.MAP_READ const UNIFORM: u32 = 0x40; // GPUBufferUsage.UNIFORM // Create buffers let buffer_a = self.create_buffer(a.len() * 4, STORAGE | COPY_DST, Some("matmul_a"))?; let buffer_b = self.create_buffer(b.len() * 4, STORAGE | COPY_DST, Some("matmul_b"))?; let buffer_c = self.create_buffer(output_size * 4, STORAGE | COPY_SRC, Some("matmul_c"))?; // Create uniform buffer for dimensions let uniform_data: [f32; 4] = [m as f32, n as f32, k as f32, 1.0]; // M, N, K, alpha let uniform_buffer = self.create_buffer(16, UNIFORM | COPY_DST, Some("matmul_uniforms"))?; // Write data to buffers self.write_buffer(&buffer_a, a)?; self.write_buffer(&buffer_b, b)?; self.write_buffer(&uniform_buffer, &uniform_data)?; // Create shader module let shader_desc = Object::new(); Reflect::set(&shader_desc, &"code".into(), &shaders::MATMUL_SHADER.into())?; let shader_module = call_method(&self.device, "createShaderModule", &[shader_desc.into()])?; // Create bind group layout let layout_entries = Array::new(); // Storage buffer entries (A, B, C) for i in 0..3u32 { let entry = Object::new(); Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?; Reflect::set(&entry, &"visibility".into(), &JsValue::from_f64(4.0))?; // COMPUTE stage let buffer_layout = Object::new(); Reflect::set( &buffer_layout, &"type".into(), &(if i < 2 { "read-only-storage" } else { "storage" }) .into(), )?; Reflect::set(&entry, &"buffer".into(), &buffer_layout)?; layout_entries.push(&entry); } // Uniform buffer entry let uniform_entry = Object::new(); Reflect::set(&uniform_entry, &"binding".into(), &JsValue::from_f64(3.0))?; Reflect::set( &uniform_entry, &"visibility".into(), &JsValue::from_f64(4.0), )?; let uniform_layout = Object::new(); Reflect::set(&uniform_layout, &"type".into(), &"uniform".into())?; Reflect::set(&uniform_entry, &"buffer".into(), &uniform_layout)?; layout_entries.push(&uniform_entry); let layout_desc = Object::new(); Reflect::set(&layout_desc, &"entries".into(), &layout_entries)?; let bind_group_layout = call_method(&self.device, "createBindGroupLayout", &[layout_desc.into()])?; // Create pipeline layout let layouts = Array::new(); layouts.push(&bind_group_layout); let pipeline_layout_desc = Object::new(); Reflect::set(&pipeline_layout_desc, &"bindGroupLayouts".into(), &layouts)?; let pipeline_layout = call_method( &self.device, "createPipelineLayout", &[pipeline_layout_desc.into()], )?; // Create compute pipeline let compute_stage = Object::new(); Reflect::set(&compute_stage, &"module".into(), &shader_module)?; Reflect::set(&compute_stage, &"entryPoint".into(), &"main".into())?; let pipeline_desc = Object::new(); Reflect::set(&pipeline_desc, &"layout".into(), &pipeline_layout)?; Reflect::set(&pipeline_desc, &"compute".into(), &compute_stage)?; let pipeline = call_method( &self.device, "createComputePipeline", &[pipeline_desc.into()], )?; // Create bind group let bind_entries = Array::new(); for (i, buffer) in [&buffer_a, &buffer_b, &buffer_c, &uniform_buffer] .iter() .enumerate() { let entry = Object::new(); Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?; let resource = Object::new(); Reflect::set(&resource, &"buffer".into(), buffer)?; Reflect::set(&entry, &"resource".into(), &resource)?; bind_entries.push(&entry); } let bind_group_desc = Object::new(); Reflect::set(&bind_group_desc, &"layout".into(), &bind_group_layout)?; Reflect::set(&bind_group_desc, &"entries".into(), &bind_entries)?; let bind_group = call_method(&self.device, "createBindGroup", &[bind_group_desc.into()])?; // Create command encoder let encoder_desc = Object::new(); let encoder = call_method(&self.device, "createCommandEncoder", &[encoder_desc.into()])?; // Begin compute pass let pass_desc = Object::new(); let pass = call_method(&encoder, "beginComputePass", &[pass_desc.into()])?; // Set pipeline and bind group call_method(&pass, "setPipeline", &[pipeline.clone()])?; call_method( &pass, "setBindGroup", &[JsValue::from_f64(0.0), bind_group.clone()], )?; // Dispatch workgroups (16x16 tile size) let workgroups_x = (m + 15) / 16; let workgroups_y = (n + 15) / 16; call_method( &pass, "dispatchWorkgroups", &[ JsValue::from_f64(workgroups_x as f64), JsValue::from_f64(workgroups_y as f64), ], )?; call_method(&pass, "end", &[])?; // Create staging buffer for readback let staging = self.create_buffer(output_size * 4, MAP_READ | COPY_DST, Some("staging"))?; // Copy result to staging call_method( &encoder, "copyBufferToBuffer", &[ buffer_c.clone(), JsValue::from_f64(0.0), staging.clone(), JsValue::from_f64(0.0), JsValue::from_f64((output_size * 4) as f64), ], )?; // Submit commands let command_buffer = call_method(&encoder, "finish", &[])?; let commands = Array::new(); commands.push(&command_buffer); call_method(&self.queue, "submit", &[commands.into()])?; // Map staging buffer and read result let map_promise = call_method(&staging, "mapAsync", &[JsValue::from_f64(1.0)])?; // MAP_READ = 1 JsFuture::from(map_promise.dyn_into::()?).await?; let mapped_range = call_method(&staging, "getMappedRange", &[])?; let data = Float32Array::new(&mapped_range).to_vec(); call_method(&staging, "unmap", &[])?; Ok(data) } #[cfg(not(target_arch = "wasm32"))] { // CPU fallback - naive implementation let mut c = vec![0.0f32; (m as usize) * (n as usize)]; for i in 0..m as usize { for j in 0..n as usize { let mut sum = 0.0f32; for l in 0..k as usize { sum += a[i * k as usize + l] * b[l * n as usize + j]; } c[i * n as usize + j] = sum; } } Ok(c) } } /// Perform attention: Output = softmax(Q * K^T / sqrt(d_k)) * V #[wasm_bindgen] pub async fn attention( &self, q: &[f32], k: &[f32], v: &[f32], config: &AttentionConfig, ) -> Result, JsValue> { let hidden_dim = config.hidden_dim(); let expected_size = (config.seq_len as usize) * (hidden_dim as usize); if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size { return Err(JsValue::from_str(&format!( "Attention tensor dimension mismatch: expected {}, got Q:{}, K:{}, V:{}", expected_size, q.len(), k.len(), v.len() ))); } // CPU fallback for attention (GPU implementation similar to matmul pattern) // For production, would implement full GPU attention here self.attention_cpu(q, k, v, config) } /// CPU fallback for attention fn attention_cpu( &self, q: &[f32], k: &[f32], v: &[f32], config: &AttentionConfig, ) -> Result, JsValue> { let seq_len = config.seq_len as usize; let num_heads = config.num_heads as usize; let head_dim = config.head_dim as usize; let hidden_dim = num_heads * head_dim; let scale = config.scale(); let mut output = vec![0.0f32; seq_len * hidden_dim]; // Process each head independently for h in 0..num_heads { for i in 0..seq_len { // For this query position, compute attention to all key positions let q_offset = i * hidden_dim + h * head_dim; // Compute attention scores let mut scores = vec![0.0f32; seq_len]; let mut max_score = f32::NEG_INFINITY; for j in 0..seq_len { // Causal masking if config.causal && j > i { scores[j] = f32::NEG_INFINITY; continue; } let k_offset = j * hidden_dim + h * head_dim; let mut score = 0.0f32; for d in 0..head_dim { score += q[q_offset + d] * k[k_offset + d]; } score *= scale; scores[j] = score; if score > max_score { max_score = score; } } // Softmax let mut sum = 0.0f32; for j in 0..seq_len { scores[j] = (scores[j] - max_score).exp(); sum += scores[j]; } for j in 0..seq_len { scores[j] /= sum; } // Compute weighted sum of values let out_offset = i * hidden_dim + h * head_dim; for d in 0..head_dim { let mut weighted_sum = 0.0f32; for j in 0..seq_len { let v_offset = j * hidden_dim + h * head_dim; weighted_sum += scores[j] * v[v_offset + d]; } output[out_offset + d] = weighted_sum; } } } Ok(output) } /// Perform RMS normalization #[wasm_bindgen(js_name = rmsNorm)] pub async fn rms_norm( &self, input: &[f32], weight: &[f32], hidden_dim: u32, eps: f32, ) -> Result, JsValue> { if weight.len() != hidden_dim as usize { return Err(JsValue::from_str(&format!( "Weight dimension mismatch: expected {}, got {}", hidden_dim, weight.len() ))); } if input.len() % hidden_dim as usize != 0 { return Err(JsValue::from_str(&format!( "Input size {} not divisible by hidden_dim {}", input.len(), hidden_dim ))); } // CPU implementation let batch_size = input.len() / hidden_dim as usize; let mut output = vec![0.0f32; input.len()]; for b in 0..batch_size { let offset = b * hidden_dim as usize; // Compute sum of squares let mut sum_sq = 0.0f32; for i in 0..hidden_dim as usize { let x = input[offset + i]; sum_sq += x * x; } // RMS scale let rms = (sum_sq / hidden_dim as f32 + eps).sqrt(); // Normalize and scale for i in 0..hidden_dim as usize { output[offset + i] = input[offset + i] / rms * weight[i]; } } Ok(output) } /// Perform softmax #[wasm_bindgen] pub async fn softmax( &self, input: &[f32], dim: u32, temperature: f32, ) -> Result, JsValue> { if input.len() % dim as usize != 0 { return Err(JsValue::from_str(&format!( "Input size {} not divisible by dim {}", input.len(), dim ))); } let batch_size = input.len() / dim as usize; let mut output = vec![0.0f32; input.len()]; for b in 0..batch_size { let offset = b * dim as usize; // Find max (for numerical stability) let mut max_val = f32::NEG_INFINITY; for i in 0..dim as usize { let x = input[offset + i] / temperature; if x > max_val { max_val = x; } } // Compute exp and sum let mut sum = 0.0f32; for i in 0..dim as usize { let x = (input[offset + i] / temperature - max_val).exp(); output[offset + i] = x; sum += x; } // Normalize for i in 0..dim as usize { output[offset + i] /= sum; } } Ok(output) } // Helper methods for GPU buffer management #[cfg(target_arch = "wasm32")] fn create_buffer( &self, size: usize, usage: u32, label: Option<&str>, ) -> Result { let descriptor = Object::new(); Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?; Reflect::set( &descriptor, &"usage".into(), &JsValue::from_f64(usage as f64), )?; if let Some(lbl) = label { Reflect::set(&descriptor, &"label".into(), &lbl.into())?; } call_method(&self.device, "createBuffer", &[descriptor.into()]) } #[cfg(target_arch = "wasm32")] fn write_buffer(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> { let data_array = Float32Array::from(data); call_method( &self.queue, "writeBuffer", &[ buffer.clone(), JsValue::from_f64(0.0), data_array.buffer().into(), ], )?; Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_cpu_matmul_fallback() { // Test the CPU fallback logic (in non-wasm mode) let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2 let b = vec![5.0, 6.0, 7.0, 8.0]; // 2x2 // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] // = [[19, 22], [43, 50]] let mut c = vec![0.0f32; 4]; for i in 0..2usize { for j in 0..2usize { let mut sum = 0.0f32; for l in 0..2usize { sum += a[i * 2 + l] * b[l * 2 + j]; } c[i * 2 + j] = sum; } } assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]); } #[test] fn test_rms_norm_cpu() { let input = vec![1.0, 2.0, 3.0, 4.0]; let weight = vec![1.0, 1.0, 1.0, 1.0]; let hidden_dim = 4; let eps = 1e-5f32; // sum_sq = 1 + 4 + 9 + 16 = 30 // rms = sqrt(30/4 + eps) = sqrt(7.5) ≈ 2.7386 let rms = (30.0f32 / 4.0 + eps).sqrt(); let expected: Vec = input.iter().map(|&x| x / rms).collect(); // Verify calculation assert!((expected[0] - 0.3651).abs() < 0.001); } #[test] fn test_softmax_cpu() { let input = vec![1.0, 2.0, 3.0]; let temperature = 1.0f32; // max = 3 // exp(1-3) = exp(-2), exp(2-3) = exp(-1), exp(3-3) = exp(0) = 1 let exps: Vec = vec![(-2.0f32).exp(), (-1.0f32).exp(), 1.0]; let sum: f32 = exps.iter().sum(); let expected: Vec = exps.iter().map(|&x| x / sum).collect(); // Verify softmax sums to 1 let softmax_sum: f32 = expected.iter().sum(); assert!((softmax_sum - 1.0).abs() < 0.001); } }