Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
469
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/buffers.rs
vendored
Normal file
469
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/buffers.rs
vendored
Normal file
@@ -0,0 +1,469 @@
|
||||
//! GPU Buffer Management for WebGPU WASM
|
||||
//!
|
||||
//! This module provides buffer abstractions for GPU memory management
|
||||
//! in the browser WebGPU environment.
|
||||
|
||||
use js_sys::{Float32Array, Uint8Array};
|
||||
use std::cell::RefCell;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Buffer usage flags
///
/// A plain-boolean mirror of (a subset of) the WebGPU `GPUBufferUsage`
/// bitflags; `to_u32` on the impl converts these to the raw flag word.
/// Fields carry `#[wasm_bindgen(skip)]` so JS access goes through the
/// explicit getter/setter pairs on the impl instead of auto-generated
/// field accessors.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuBufferUsage {
    /// Can be mapped for reading
    #[wasm_bindgen(skip)]
    pub map_read: bool,
    /// Can be mapped for writing
    #[wasm_bindgen(skip)]
    pub map_write: bool,
    /// Can be used as copy source
    #[wasm_bindgen(skip)]
    pub copy_src: bool,
    /// Can be used as copy destination
    #[wasm_bindgen(skip)]
    pub copy_dst: bool,
    /// Can be used as storage buffer
    #[wasm_bindgen(skip)]
    pub storage: bool,
    /// Can be used as uniform buffer
    #[wasm_bindgen(skip)]
    pub uniform: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl GpuBufferUsage {
|
||||
/// Create storage buffer usage (read/write compute)
|
||||
#[wasm_bindgen(js_name = storage)]
|
||||
pub fn new_storage() -> Self {
|
||||
Self {
|
||||
storage: true,
|
||||
copy_dst: true,
|
||||
copy_src: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create uniform buffer usage
|
||||
#[wasm_bindgen(js_name = uniform)]
|
||||
pub fn new_uniform() -> Self {
|
||||
Self {
|
||||
uniform: true,
|
||||
copy_dst: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create staging buffer for upload
|
||||
#[wasm_bindgen(js_name = stagingUpload)]
|
||||
pub fn staging_upload() -> Self {
|
||||
Self {
|
||||
map_write: true,
|
||||
copy_src: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create staging buffer for download
|
||||
#[wasm_bindgen(js_name = stagingDownload)]
|
||||
pub fn staging_download() -> Self {
|
||||
Self {
|
||||
map_read: true,
|
||||
copy_dst: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create read-only storage buffer
|
||||
#[wasm_bindgen(js_name = storageReadOnly)]
|
||||
pub fn storage_read_only() -> Self {
|
||||
Self {
|
||||
storage: true,
|
||||
copy_dst: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to WebGPU usage flags (as raw u32)
|
||||
///
|
||||
/// WebGPU buffer usage flags:
|
||||
/// - MAP_READ = 0x0001
|
||||
/// - MAP_WRITE = 0x0002
|
||||
/// - COPY_SRC = 0x0004
|
||||
/// - COPY_DST = 0x0008
|
||||
/// - INDEX = 0x0010
|
||||
/// - VERTEX = 0x0020
|
||||
/// - UNIFORM = 0x0040
|
||||
/// - STORAGE = 0x0080
|
||||
/// - INDIRECT = 0x0100
|
||||
/// - QUERY_RESOLVE = 0x0200
|
||||
pub fn to_u32(&self) -> u32 {
|
||||
let mut flags = 0u32;
|
||||
if self.map_read {
|
||||
flags |= 0x0001;
|
||||
}
|
||||
if self.map_write {
|
||||
flags |= 0x0002;
|
||||
}
|
||||
if self.copy_src {
|
||||
flags |= 0x0004;
|
||||
}
|
||||
if self.copy_dst {
|
||||
flags |= 0x0008;
|
||||
}
|
||||
if self.uniform {
|
||||
flags |= 0x0040;
|
||||
}
|
||||
if self.storage {
|
||||
flags |= 0x0080;
|
||||
}
|
||||
flags
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = mapRead)]
|
||||
pub fn get_map_read(&self) -> bool {
|
||||
self.map_read
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = mapRead)]
|
||||
pub fn set_map_read(&mut self, value: bool) {
|
||||
self.map_read = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = mapWrite)]
|
||||
pub fn get_map_write(&self) -> bool {
|
||||
self.map_write
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = mapWrite)]
|
||||
pub fn set_map_write(&mut self, value: bool) {
|
||||
self.map_write = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = copySrc)]
|
||||
pub fn get_copy_src(&self) -> bool {
|
||||
self.copy_src
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = copySrc)]
|
||||
pub fn set_copy_src(&mut self, value: bool) {
|
||||
self.copy_src = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = copyDst)]
|
||||
pub fn get_copy_dst(&self) -> bool {
|
||||
self.copy_dst
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = copyDst)]
|
||||
pub fn set_copy_dst(&mut self, value: bool) {
|
||||
self.copy_dst = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = isStorage)]
|
||||
pub fn get_storage(&self) -> bool {
|
||||
self.storage
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = isStorage)]
|
||||
pub fn set_storage(&mut self, value: bool) {
|
||||
self.storage = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = isUniform)]
|
||||
pub fn get_uniform(&self) -> bool {
|
||||
self.uniform
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = isUniform)]
|
||||
pub fn set_uniform(&mut self, value: bool) {
|
||||
self.uniform = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU buffer handle
///
/// Wraps a WebGPU buffer with metadata for safe operations.
#[wasm_bindgen]
pub struct GpuBuffer {
    /// Internal buffer handle (web_sys::GpuBuffer when on wasm32)
    #[cfg(target_arch = "wasm32")]
    buffer: web_sys::GpuBuffer,

    /// Placeholder for non-wasm32 builds — a host-side byte vec so the
    /// same struct shape compiles (and can be tested) natively.
    #[cfg(not(target_arch = "wasm32"))]
    buffer: Vec<u8>,

    /// Buffer size in bytes (not re-validated against the JS object)
    size: usize,

    /// Buffer usage flags recorded at creation time
    usage: GpuBufferUsage,

    /// Optional label for debugging
    label: Option<String>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl GpuBuffer {
|
||||
/// Get buffer size in bytes
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
/// Get buffer label
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn label(&self) -> Option<String> {
|
||||
self.label.clone()
|
||||
}
|
||||
|
||||
/// Check if buffer supports mapping for read
|
||||
#[wasm_bindgen(getter, js_name = canMapRead)]
|
||||
pub fn can_map_read(&self) -> bool {
|
||||
self.usage.map_read
|
||||
}
|
||||
|
||||
/// Check if buffer supports mapping for write
|
||||
#[wasm_bindgen(getter, js_name = canMapWrite)]
|
||||
pub fn can_map_write(&self) -> bool {
|
||||
self.usage.map_write
|
||||
}
|
||||
|
||||
/// Get size as number of f32 elements
|
||||
#[wasm_bindgen(js_name = sizeAsF32)]
|
||||
pub fn size_as_f32(&self) -> usize {
|
||||
self.size / 4
|
||||
}
|
||||
|
||||
/// Get the raw web_sys buffer (for advanced usage)
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[wasm_bindgen(getter, js_name = rawBuffer)]
|
||||
pub fn raw_buffer(&self) -> web_sys::GpuBuffer {
|
||||
self.buffer.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl GpuBuffer {
    /// Create a new GPU buffer (internal constructor)
    ///
    /// `size` must match the byte size `buffer` was created with; nothing
    /// here re-validates it against the JS object.
    #[cfg(target_arch = "wasm32")]
    pub(crate) fn new(
        buffer: web_sys::GpuBuffer,
        size: usize,
        usage: GpuBufferUsage,
        label: Option<String>,
    ) -> Self {
        Self {
            buffer,
            size,
            usage,
            label,
        }
    }

    /// Create a new GPU buffer (non-wasm32 placeholder)
    ///
    /// Backs the "buffer" with a zeroed host Vec so native builds/tests work.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) fn new(size: usize, usage: GpuBufferUsage, label: Option<String>) -> Self {
        Self {
            buffer: vec![0u8; size],
            size,
            usage,
            label,
        }
    }

    /// Get internal buffer reference (crate-internal; JS uses `rawBuffer`)
    #[cfg(target_arch = "wasm32")]
    pub(crate) fn inner(&self) -> &web_sys::GpuBuffer {
        &self.buffer
    }
}
|
||||
|
||||
/// Staging buffer pool for efficient CPU<->GPU transfers
///
/// Interior mutability via `RefCell` (not a lock) is appropriate because
/// wasm-bindgen objects live on the single JS thread.
#[wasm_bindgen]
pub struct StagingBufferPool {
    /// Pool of upload staging buffers
    upload_pool: RefCell<Vec<GpuBuffer>>,
    /// Pool of download staging buffers
    download_pool: RefCell<Vec<GpuBuffer>>,
    /// Maximum buffers per pool
    /// NOTE(review): not enforced by any visible method — the code shown
    /// only reports and clears; confirm acquire/release paths respect it.
    max_per_pool: usize,
    /// Total bytes accounted to pooled buffers
    total_allocated: RefCell<usize>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl StagingBufferPool {
|
||||
/// Create a new staging buffer pool
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(max_per_pool: usize) -> Self {
|
||||
Self {
|
||||
upload_pool: RefCell::new(Vec::with_capacity(max_per_pool)),
|
||||
download_pool: RefCell::new(Vec::with_capacity(max_per_pool)),
|
||||
max_per_pool,
|
||||
total_allocated: RefCell::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of upload buffers in pool
|
||||
#[wasm_bindgen(getter, js_name = uploadBufferCount)]
|
||||
pub fn upload_buffer_count(&self) -> usize {
|
||||
self.upload_pool.borrow().len()
|
||||
}
|
||||
|
||||
/// Get the number of download buffers in pool
|
||||
#[wasm_bindgen(getter, js_name = downloadBufferCount)]
|
||||
pub fn download_buffer_count(&self) -> usize {
|
||||
self.download_pool.borrow().len()
|
||||
}
|
||||
|
||||
/// Get total bytes allocated
|
||||
#[wasm_bindgen(getter, js_name = totalAllocated)]
|
||||
pub fn total_allocated(&self) -> usize {
|
||||
*self.total_allocated.borrow()
|
||||
}
|
||||
|
||||
/// Clear all pooled buffers
|
||||
#[wasm_bindgen]
|
||||
pub fn clear(&self) {
|
||||
self.upload_pool.borrow_mut().clear();
|
||||
self.download_pool.borrow_mut().clear();
|
||||
*self.total_allocated.borrow_mut() = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor descriptor for buffer allocation
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct TensorDescriptor {
    /// Shape dimensions (e.g. [rows, cols] for a matrix)
    shape: Vec<u32>,
    /// Data type code (0=f32, 1=f16, 2=i32, 3=u8); see `size_bytes` for
    /// how unknown codes are treated
    dtype: u8,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl TensorDescriptor {
|
||||
/// Create tensor descriptor for a matrix
|
||||
#[wasm_bindgen(js_name = matrix)]
|
||||
pub fn matrix(rows: u32, cols: u32) -> Self {
|
||||
Self {
|
||||
shape: vec![rows, cols],
|
||||
dtype: 0, // f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Create tensor descriptor for a vector
|
||||
#[wasm_bindgen(js_name = vector)]
|
||||
pub fn vector(len: u32) -> Self {
|
||||
Self {
|
||||
shape: vec![len],
|
||||
dtype: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create tensor descriptor with arbitrary shape
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(shape: Vec<u32>, dtype: u8) -> Self {
|
||||
Self { shape, dtype }
|
||||
}
|
||||
|
||||
/// Get total number of elements
|
||||
#[wasm_bindgen(js_name = numElements)]
|
||||
pub fn num_elements(&self) -> usize {
|
||||
self.shape.iter().map(|&d| d as usize).product()
|
||||
}
|
||||
|
||||
/// Get size in bytes
|
||||
#[wasm_bindgen(js_name = sizeBytes)]
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
let element_size = match self.dtype {
|
||||
0 => 4, // f32
|
||||
1 => 2, // f16
|
||||
2 => 4, // i32
|
||||
3 => 1, // u8
|
||||
_ => 4, // default to f32
|
||||
};
|
||||
self.num_elements() * element_size
|
||||
}
|
||||
|
||||
/// Get shape dimensions
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn shape(&self) -> Vec<u32> {
|
||||
self.shape.clone()
|
||||
}
|
||||
|
||||
/// Get data type
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dtype(&self) -> u8 {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// Get number of dimensions
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper functions for creating typed arrays from GPU buffers
///
/// Stateless namespace struct: all methods are associated functions, so
/// they surface as statics on the JS `BufferHelpers` class.
#[wasm_bindgen]
pub struct BufferHelpers;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BufferHelpers {
|
||||
/// Create a Float32Array view from a Uint8Array
|
||||
#[wasm_bindgen(js_name = asFloat32Array)]
|
||||
pub fn as_float32_array(data: &Uint8Array) -> Float32Array {
|
||||
Float32Array::new(&data.buffer())
|
||||
}
|
||||
|
||||
/// Calculate aligned size for GPU buffers (must be multiple of 4)
|
||||
#[wasm_bindgen(js_name = alignedSize)]
|
||||
pub fn aligned_size(size: usize) -> usize {
|
||||
(size + 3) & !3
|
||||
}
|
||||
|
||||
/// Calculate workgroup count for a given dimension
|
||||
#[wasm_bindgen(js_name = workgroupCount)]
|
||||
pub fn workgroup_count(total: u32, workgroup_size: u32) -> u32 {
|
||||
(total + workgroup_size - 1) / workgroup_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The storage preset sets storage + both copy directions, nothing else.
    #[test]
    fn test_buffer_usage() {
        let usage = GpuBufferUsage::new_storage();
        assert!(usage.storage && usage.copy_dst && usage.copy_src);
        assert!(!usage.uniform);
    }

    /// Matrix descriptor: element count, f32 byte size, and rank.
    #[test]
    fn test_tensor_descriptor() {
        let desc = TensorDescriptor::matrix(1024, 768);
        let elems = 1024 * 768;
        assert_eq!(desc.num_elements(), elems);
        assert_eq!(desc.size_bytes(), elems * 4);
        assert_eq!(desc.ndim(), 2);
    }

    /// Alignment rounds up to the next multiple of 4.
    #[test]
    fn test_aligned_size() {
        for (input, expected) in [(0, 0), (1, 4), (4, 4), (5, 8)] {
            assert_eq!(BufferHelpers::aligned_size(input), expected);
        }
    }

    /// Workgroup count is ceil(total / workgroup_size).
    #[test]
    fn test_workgroup_count() {
        for (total, expected) in [(1000u32, 4u32), (256, 1), (257, 2)] {
            assert_eq!(BufferHelpers::workgroup_count(total, 256), expected);
        }
    }
}
|
||||
882
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/compute.rs
vendored
Normal file
882
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/compute.rs
vendored
Normal file
@@ -0,0 +1,882 @@
|
||||
//! WebGPU Compute Context and Pipelines
|
||||
//!
|
||||
//! This module provides the core WebGPU compute functionality for WASM,
|
||||
//! including context initialization, pipeline creation, and kernel execution.
|
||||
//!
|
||||
//! Note: WebGPU bindings use JavaScript interop via js_sys/Reflect since
|
||||
//! web-sys WebGPU bindings are still unstable.
|
||||
|
||||
use js_sys::{Array, Float32Array, Object, Promise, Reflect};
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_futures::JsFuture;
|
||||
|
||||
use super::{shaders, AdapterInfo, AttentionConfig};
|
||||
|
||||
/// Check if WebGPU is available in this browser
///
/// Returns `true` iff `navigator.gpu` exists and is a real object.
/// Always `false` on non-wasm32 targets.
pub async fn is_webgpu_available() -> bool {
    #[cfg(target_arch = "wasm32")]
    {
        // Simplification: `get_gpu_object` already returns `None` for a
        // `null`/`undefined` `navigator.gpu`, so the previous re-check of
        // `is_undefined()`/`is_null()` on the returned value was dead code.
        get_gpu_object().is_some()
    }

    #[cfg(not(target_arch = "wasm32"))]
    false
}
|
||||
|
||||
/// Get GPU adapter information if available
///
/// Returns `None` when WebGPU is missing, the adapter request fails or
/// resolves to null, or any of the JS-interop lookups below fail.
/// Non-wasm32 builds always return `None`.
pub async fn get_gpu_info() -> Option<AdapterInfo> {
    #[cfg(target_arch = "wasm32")]
    {
        let gpu = get_gpu_object()?;

        // Request adapter, preferring the discrete/high-performance one.
        // The set failure is deliberately ignored: a missing preference
        // just falls back to the browser default.
        let options = Object::new();
        let _ = Reflect::set(
            &options,
            &"powerPreference".into(),
            &"high-performance".into(),
        );

        let adapter_promise = call_method(&gpu, "requestAdapter", &[options.into()]).ok()?;
        let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>().ok()?)
            .await
            .ok()?;

        // requestAdapter resolves with null (not a rejection) when no
        // adapter is available, so check explicitly.
        if adapter.is_null() || adapter.is_undefined() {
            return None;
        }

        // Get adapter info via requestAdapterInfo()
        // NOTE(review): requestAdapterInfo() has been removed from the
        // WebGPU spec in favor of the synchronous `GPUAdapter.info`
        // attribute; on browsers without the legacy method this lookup
        // fails and we return None even though an adapter exists —
        // verify against the targeted browsers.
        let info_promise = call_method(&adapter, "requestAdapterInfo", &[]).ok()?;
        let info = JsFuture::from(info_promise.dyn_into::<Promise>().ok()?)
            .await
            .ok()?;

        // Extract limits
        let limits = Reflect::get(&adapter, &"limits".into()).ok()?;

        Some(AdapterInfo {
            vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
            architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
            device_type: get_string_prop(&info, "device").unwrap_or_else(|| "unknown".to_string()),
            backend: "WebGPU".to_string(),
            // Defaults when a limit is absent: 256 MiB buffer / 256 workgroup X.
            max_buffer_size: get_number_prop(&limits, "maxBufferSize")
                .unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
            max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
                .unwrap_or(256.0) as u32,
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    None
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
fn get_gpu_object() -> Option<JsValue> {
    // Resolve `window.navigator.gpu` via Reflect so this compiles without
    // the unstable web-sys WebGPU bindings.
    let window = web_sys::window()?;
    let navigator = Reflect::get(&window, &"navigator".into()).ok()?;
    let gpu = Reflect::get(&navigator, &"gpu".into()).ok()?;
    // Only a real object counts as "available".
    (!gpu.is_undefined() && !gpu.is_null()).then_some(gpu)
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
fn get_string_prop(obj: &JsValue, key: &str) -> Option<String> {
    // None when the property is missing, unreadable, or not a JS string.
    let value = Reflect::get(obj, &key.into()).ok()?;
    value.as_string()
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
fn get_number_prop(obj: &JsValue, key: &str) -> Option<f64> {
    // None when the property is missing, unreadable, or not a JS number.
    let value = Reflect::get(obj, &key.into()).ok()?;
    value.as_f64()
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
fn call_method(obj: &JsValue, method: &str, args: &[JsValue]) -> Result<JsValue, JsValue> {
    // Look the property up by name and require it to be callable.
    let func: js_sys::Function = Reflect::get(obj, &method.into())?.dyn_into()?;

    // Marshal the Rust slice into a JS Array for Reflect.apply.
    let js_args = Array::new();
    for value in args.iter() {
        js_args.push(value);
    }

    // Invoke with `obj` as `this`.
    Reflect::apply(&func, obj, &js_args)
}
|
||||
|
||||
// ============================================================================
|
||||
// WebGPU Context
|
||||
// ============================================================================
|
||||
|
||||
/// WebGPU context holding device and queue references
///
/// Handles are opaque `JsValue`s driven through Reflect (see `call_method`)
/// rather than typed web-sys bindings.
#[wasm_bindgen]
pub struct WebGpuContext {
    /// GPU device object (JsValue wrapper)
    #[cfg(target_arch = "wasm32")]
    device: JsValue,

    /// Command queue object (the device's `queue` property)
    #[cfg(target_arch = "wasm32")]
    queue: JsValue,

    /// Placeholder for non-wasm builds so the type still compiles natively
    #[cfg(not(target_arch = "wasm32"))]
    _phantom: std::marker::PhantomData<()>,

    /// Adapter information captured during init
    adapter_info: AdapterInfo,
}
|
||||
|
||||
#[wasm_bindgen]
impl WebGpuContext {
    /// Initialize WebGPU context
    ///
    /// Requests a high-performance adapter, queries its info/limits, then
    /// requests a device and caches its default queue. Errors surface as
    /// `JsValue`s for JS callers; non-wasm32 builds always fail.
    #[wasm_bindgen(js_name = init)]
    pub async fn init() -> Result<WebGpuContext, JsValue> {
        #[cfg(target_arch = "wasm32")]
        {
            let gpu = get_gpu_object().ok_or_else(|| JsValue::from_str("WebGPU not available"))?;

            // Request adapter with high performance preference
            let adapter_options = Object::new();
            Reflect::set(
                &adapter_options,
                &"powerPreference".into(),
                &"high-performance".into(),
            )?;

            let adapter_promise = call_method(&gpu, "requestAdapter", &[adapter_options.into()])?;
            let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>()?).await?;

            // requestAdapter resolves with null (not a rejection) when no
            // adapter is available, so check explicitly.
            if adapter.is_null() || adapter.is_undefined() {
                return Err(JsValue::from_str("No suitable GPU adapter found"));
            }

            // Get adapter info
            // NOTE(review): requestAdapterInfo() was removed from the WebGPU
            // spec in favor of the `GPUAdapter.info` attribute; on browsers
            // without the legacy method this call errors and init fails even
            // though an adapter exists — confirm against target browsers.
            let info_promise = call_method(&adapter, "requestAdapterInfo", &[])?;
            let info = JsFuture::from(info_promise.dyn_into::<Promise>()?).await?;
            let limits = Reflect::get(&adapter, &"limits".into())?;

            let adapter_info = AdapterInfo {
                vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
                architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
                device_type: get_string_prop(&info, "device")
                    .unwrap_or_else(|| "unknown".to_string()),
                backend: "WebGPU".to_string(),
                // Defaults when a limit is absent: 256 MiB / 256 threads.
                max_buffer_size: get_number_prop(&limits, "maxBufferSize")
                    .unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
                max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
                    .unwrap_or(256.0) as u32,
            };

            // Request device
            let device_descriptor = Object::new();
            Reflect::set(&device_descriptor, &"label".into(), &"ruvllm-wasm".into())?;

            let device_promise =
                call_method(&adapter, "requestDevice", &[device_descriptor.into()])?;
            let device = JsFuture::from(device_promise.dyn_into::<Promise>()?).await?;

            // Get queue (the device's default queue)
            let queue = Reflect::get(&device, &"queue".into())?;

            Ok(WebGpuContext {
                device,
                queue,
                adapter_info,
            })
        }

        #[cfg(not(target_arch = "wasm32"))]
        Err(JsValue::from_str("WebGPU only available in WASM"))
    }

    /// Get adapter information
    #[wasm_bindgen(getter, js_name = adapterInfo)]
    pub fn adapter_info(&self) -> AdapterInfo {
        self.adapter_info.clone()
    }

    /// Check if context is valid
    ///
    /// True only when a real device handle is held (always false natively).
    #[wasm_bindgen(getter, js_name = isValid)]
    pub fn is_valid(&self) -> bool {
        #[cfg(target_arch = "wasm32")]
        {
            !self.device.is_undefined() && !self.device.is_null()
        }

        #[cfg(not(target_arch = "wasm32"))]
        false
    }

    /// Create a GPU buffer
    ///
    /// `size` is in bytes; `usage` is the raw GPUBufferUsage bitmask.
    /// Numeric descriptor values cross the Reflect boundary as f64 because
    /// that is how JS represents numbers.
    #[cfg(target_arch = "wasm32")]
    fn create_buffer_internal(
        &self,
        size: usize,
        usage: u32,
        label: Option<&str>,
    ) -> Result<JsValue, JsValue> {
        let descriptor = Object::new();
        Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?;
        Reflect::set(
            &descriptor,
            &"usage".into(),
            &JsValue::from_f64(usage as f64),
        )?;
        if let Some(lbl) = label {
            Reflect::set(&descriptor, &"label".into(), &lbl.into())?;
        }

        call_method(&self.device, "createBuffer", &[descriptor.into()])
    }

    /// Write data to GPU buffer
    ///
    /// Writes the full `data` slice at buffer offset 0 via queue.writeBuffer.
    /// Passing `data_array.buffer()` hands over the typed array's whole
    /// backing ArrayBuffer; that is correct here because the array is
    /// freshly created from `data`, so the view spans the entire buffer.
    #[cfg(target_arch = "wasm32")]
    fn write_buffer_internal(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
        let data_array = Float32Array::from(data);
        call_method(
            &self.queue,
            "writeBuffer",
            &[
                buffer.clone(),
                JsValue::from_f64(0.0),
                data_array.buffer().into(),
            ],
        )?;
        Ok(())
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Compute Pipeline
|
||||
// ============================================================================
|
||||
|
||||
/// Compute pipeline handle
///
/// Holds the JS pipeline / bind-group-layout objects (wasm32 only) together
/// with the WGSL entry-point name and workgroup dimensions used to build it.
#[wasm_bindgen]
pub struct ComputePipeline {
    // GPUComputePipeline object (opaque JsValue)
    #[cfg(target_arch = "wasm32")]
    pipeline: JsValue,

    // GPUBindGroupLayout matching the pipeline's bindings
    #[cfg(target_arch = "wasm32")]
    bind_group_layout: JsValue,

    // Keeps the struct shape valid on native builds
    #[cfg(not(target_arch = "wasm32"))]
    _phantom: std::marker::PhantomData<()>,

    // WGSL entry-point function name
    entry_point: String,

    // Workgroup size as [x, y, z]
    workgroup_size: [u32; 3],
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ComputePipeline {
|
||||
/// Get the entry point name
|
||||
#[wasm_bindgen(getter, js_name = entryPoint)]
|
||||
pub fn entry_point(&self) -> String {
|
||||
self.entry_point.clone()
|
||||
}
|
||||
|
||||
/// Get the workgroup size
|
||||
#[wasm_bindgen(getter, js_name = workgroupSize)]
|
||||
pub fn workgroup_size(&self) -> Vec<u32> {
|
||||
self.workgroup_size.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebGPU Inference Engine
|
||||
// ============================================================================
|
||||
|
||||
/// WebGPU inference engine for LLM operations
///
/// Owns the device/queue handles taken from a `WebGpuContext` at init and
/// drives compute kernels (e.g. `matmul`) through Reflect-based interop.
#[wasm_bindgen]
pub struct WebGpuInference {
    // GPUDevice handle (opaque JsValue)
    #[cfg(target_arch = "wasm32")]
    device: JsValue,

    // The device's default GPUQueue
    #[cfg(target_arch = "wasm32")]
    queue: JsValue,

    // Keeps the struct shape valid on native builds
    #[cfg(not(target_arch = "wasm32"))]
    _phantom: std::marker::PhantomData<()>,

    // Adapter information captured during init
    adapter_info: AdapterInfo,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WebGpuInference {
|
||||
/// Check if WebGPU is available
|
||||
#[wasm_bindgen(js_name = isAvailable)]
|
||||
pub async fn is_available() -> bool {
|
||||
is_webgpu_available().await
|
||||
}
|
||||
|
||||
    /// Initialize WebGPU inference engine
    ///
    /// Builds a `WebGpuContext` and moves its device/queue handles into the
    /// engine. Fails when WebGPU is unavailable or adapter/device requests
    /// are rejected; on non-wasm32 targets it always fails.
    #[wasm_bindgen(js_name = init)]
    pub async fn init() -> Result<WebGpuInference, JsValue> {
        let ctx = WebGpuContext::init().await?;

        Ok(WebGpuInference {
            // cfg attributes on the initializers mirror the cfg-gated
            // fields on the struct definition.
            #[cfg(target_arch = "wasm32")]
            device: ctx.device,
            #[cfg(target_arch = "wasm32")]
            queue: ctx.queue,
            #[cfg(not(target_arch = "wasm32"))]
            _phantom: std::marker::PhantomData,
            adapter_info: ctx.adapter_info,
        })
    }
|
||||
|
||||
/// Get adapter information
|
||||
#[wasm_bindgen(getter, js_name = adapterInfo)]
|
||||
pub fn adapter_info(&self) -> AdapterInfo {
|
||||
self.adapter_info.clone()
|
||||
}
|
||||
|
||||
    /// Perform matrix multiplication: C = A * B
    ///
    /// Args:
    /// a: Matrix A as flat f32 array (M x K)
    /// b: Matrix B as flat f32 array (K x N)
    /// m: Number of rows in A
    /// n: Number of columns in B
    /// k: Shared dimension
    ///
    /// Returns: Result matrix C as f32 array (M x N)
    ///
    /// On wasm32 the whole pipeline (buffers, shader module, bind groups,
    /// dispatch, staging readback) is rebuilt per call; on other targets a
    /// naive CPU triple loop runs instead.
    #[wasm_bindgen]
    pub async fn matmul(
        &self,
        a: &[f32],
        b: &[f32],
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<Vec<f32>, JsValue> {
        // Validate dimensions before touching the GPU.
        let expected_a = (m as usize) * (k as usize);
        let expected_b = (k as usize) * (n as usize);

        if a.len() != expected_a {
            return Err(JsValue::from_str(&format!(
                "Matrix A dimension mismatch: expected {}, got {}",
                expected_a,
                a.len()
            )));
        }

        if b.len() != expected_b {
            return Err(JsValue::from_str(&format!(
                "Matrix B dimension mismatch: expected {}, got {}",
                expected_b,
                b.len()
            )));
        }

        #[cfg(target_arch = "wasm32")]
        {
            let output_size = (m as usize) * (n as usize);

            // GPU buffer usage flags (raw GPUBufferUsage bit values)
            const STORAGE: u32 = 0x80; // GPUBufferUsage.STORAGE
            const COPY_SRC: u32 = 0x04; // GPUBufferUsage.COPY_SRC
            const COPY_DST: u32 = 0x08; // GPUBufferUsage.COPY_DST
            const MAP_READ: u32 = 0x01; // GPUBufferUsage.MAP_READ
            const UNIFORM: u32 = 0x40; // GPUBufferUsage.UNIFORM

            // Create buffers
            // NOTE(review): `self.create_buffer` / `self.write_buffer` are
            // not defined in the visible part of this impl — confirm they
            // exist elsewhere (they appear to mirror
            // WebGpuContext::create_buffer_internal / write_buffer_internal).
            let buffer_a = self.create_buffer(a.len() * 4, STORAGE | COPY_DST, Some("matmul_a"))?;
            let buffer_b = self.create_buffer(b.len() * 4, STORAGE | COPY_DST, Some("matmul_b"))?;
            let buffer_c =
                self.create_buffer(output_size * 4, STORAGE | COPY_SRC, Some("matmul_c"))?;

            // Create uniform buffer for dimensions.
            // Dimensions are encoded as f32 — the WGSL side must decode them
            // the same way (assumed; verify against shaders::MATMUL_SHADER).
            let uniform_data: [f32; 4] = [m as f32, n as f32, k as f32, 1.0]; // M, N, K, alpha
            let uniform_buffer =
                self.create_buffer(16, UNIFORM | COPY_DST, Some("matmul_uniforms"))?;

            // Write data to buffers
            self.write_buffer(&buffer_a, a)?;
            self.write_buffer(&buffer_b, b)?;
            self.write_buffer(&uniform_buffer, &uniform_data)?;

            // Create shader module from the embedded WGSL source
            let shader_desc = Object::new();
            Reflect::set(&shader_desc, &"code".into(), &shaders::MATMUL_SHADER.into())?;
            let shader_module =
                call_method(&self.device, "createShaderModule", &[shader_desc.into()])?;

            // Create bind group layout
            let layout_entries = Array::new();

            // Storage buffer entries (A, B, C): inputs (bindings 0, 1) are
            // read-only storage; the output (binding 2) is writable storage.
            for i in 0..3u32 {
                let entry = Object::new();
                Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
                Reflect::set(&entry, &"visibility".into(), &JsValue::from_f64(4.0))?; // COMPUTE stage
                let buffer_layout = Object::new();
                Reflect::set(
                    &buffer_layout,
                    &"type".into(),
                    &(if i < 2 {
                        "read-only-storage"
                    } else {
                        "storage"
                    })
                    .into(),
                )?;
                Reflect::set(&entry, &"buffer".into(), &buffer_layout)?;
                layout_entries.push(&entry);
            }

            // Uniform buffer entry (binding 3: dimensions/alpha)
            let uniform_entry = Object::new();
            Reflect::set(&uniform_entry, &"binding".into(), &JsValue::from_f64(3.0))?;
            Reflect::set(
                &uniform_entry,
                &"visibility".into(),
                &JsValue::from_f64(4.0),
            )?;
            let uniform_layout = Object::new();
            Reflect::set(&uniform_layout, &"type".into(), &"uniform".into())?;
            Reflect::set(&uniform_entry, &"buffer".into(), &uniform_layout)?;
            layout_entries.push(&uniform_entry);

            let layout_desc = Object::new();
            Reflect::set(&layout_desc, &"entries".into(), &layout_entries)?;
            let bind_group_layout =
                call_method(&self.device, "createBindGroupLayout", &[layout_desc.into()])?;

            // Create pipeline layout referencing the single bind group layout
            let layouts = Array::new();
            layouts.push(&bind_group_layout);
            let pipeline_layout_desc = Object::new();
            Reflect::set(&pipeline_layout_desc, &"bindGroupLayouts".into(), &layouts)?;
            let pipeline_layout = call_method(
                &self.device,
                "createPipelineLayout",
                &[pipeline_layout_desc.into()],
            )?;

            // Create compute pipeline with entry point "main"
            let compute_stage = Object::new();
            Reflect::set(&compute_stage, &"module".into(), &shader_module)?;
            Reflect::set(&compute_stage, &"entryPoint".into(), &"main".into())?;

            let pipeline_desc = Object::new();
            Reflect::set(&pipeline_desc, &"layout".into(), &pipeline_layout)?;
            Reflect::set(&pipeline_desc, &"compute".into(), &compute_stage)?;

            let pipeline = call_method(
                &self.device,
                "createComputePipeline",
                &[pipeline_desc.into()],
            )?;

            // Create bind group (bindings 0..3 in the same order as the layout)
            let bind_entries = Array::new();
            for (i, buffer) in [&buffer_a, &buffer_b, &buffer_c, &uniform_buffer]
                .iter()
                .enumerate()
            {
                let entry = Object::new();
                Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
                let resource = Object::new();
                Reflect::set(&resource, &"buffer".into(), buffer)?;
                Reflect::set(&entry, &"resource".into(), &resource)?;
                bind_entries.push(&entry);
            }

            let bind_group_desc = Object::new();
            Reflect::set(&bind_group_desc, &"layout".into(), &bind_group_layout)?;
            Reflect::set(&bind_group_desc, &"entries".into(), &bind_entries)?;
            let bind_group =
                call_method(&self.device, "createBindGroup", &[bind_group_desc.into()])?;

            // Create command encoder
            let encoder_desc = Object::new();
            let encoder =
                call_method(&self.device, "createCommandEncoder", &[encoder_desc.into()])?;

            // Begin compute pass
            let pass_desc = Object::new();
            let pass = call_method(&encoder, "beginComputePass", &[pass_desc.into()])?;

            // Set pipeline and bind group (group index 0)
            call_method(&pass, "setPipeline", &[pipeline.clone()])?;
            call_method(
                &pass,
                "setBindGroup",
                &[JsValue::from_f64(0.0), bind_group.clone()],
            )?;

            // Dispatch workgroups (16x16 tile size): x covers rows (m),
            // y covers columns (n). This mapping must match the shader's
            // tiling — assumed; verify against shaders::MATMUL_SHADER.
            let workgroups_x = (m + 15) / 16;
            let workgroups_y = (n + 15) / 16;
            call_method(
                &pass,
                "dispatchWorkgroups",
                &[
                    JsValue::from_f64(workgroups_x as f64),
                    JsValue::from_f64(workgroups_y as f64),
                ],
            )?;

            call_method(&pass, "end", &[])?;

            // Create staging buffer for readback
            let staging =
                self.create_buffer(output_size * 4, MAP_READ | COPY_DST, Some("staging"))?;

            // Copy result to staging — recorded on the encoder after the
            // compute pass has ended, which is the order WebGPU requires.
            call_method(
                &encoder,
                "copyBufferToBuffer",
                &[
                    buffer_c.clone(),
                    JsValue::from_f64(0.0),
                    staging.clone(),
                    JsValue::from_f64(0.0),
                    JsValue::from_f64((output_size * 4) as f64),
                ],
            )?;

            // Submit commands
            let command_buffer = call_method(&encoder, "finish", &[])?;
            let commands = Array::new();
            commands.push(&command_buffer);
            call_method(&self.queue, "submit", &[commands.into()])?;

            // Map staging buffer and read result; mapAsync resolves once the
            // submitted work has completed on the queue.
            let map_promise = call_method(&staging, "mapAsync", &[JsValue::from_f64(1.0)])?; // MAP_READ = 1
            JsFuture::from(map_promise.dyn_into::<Promise>()?).await?;

            let mapped_range = call_method(&staging, "getMappedRange", &[])?;
            let data = Float32Array::new(&mapped_range).to_vec();

            call_method(&staging, "unmap", &[])?;

            Ok(data)
        }

        #[cfg(not(target_arch = "wasm32"))]
        {
            // CPU fallback - naive O(m*n*k) triple loop, row-major layout
            let mut c = vec![0.0f32; (m as usize) * (n as usize)];
            for i in 0..m as usize {
                for j in 0..n as usize {
                    let mut sum = 0.0f32;
                    for l in 0..k as usize {
                        sum += a[i * k as usize + l] * b[l * n as usize + j];
                    }
                    c[i * n as usize + j] = sum;
                }
            }
            Ok(c)
        }
    }
|
||||
|
||||
/// Perform attention: Output = softmax(Q * K^T / sqrt(d_k)) * V
|
||||
#[wasm_bindgen]
|
||||
pub async fn attention(
|
||||
&self,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
config: &AttentionConfig,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
let hidden_dim = config.hidden_dim();
|
||||
let expected_size = (config.seq_len as usize) * (hidden_dim as usize);
|
||||
|
||||
if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Attention tensor dimension mismatch: expected {}, got Q:{}, K:{}, V:{}",
|
||||
expected_size,
|
||||
q.len(),
|
||||
k.len(),
|
||||
v.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// CPU fallback for attention (GPU implementation similar to matmul pattern)
|
||||
// For production, would implement full GPU attention here
|
||||
self.attention_cpu(q, k, v, config)
|
||||
}
|
||||
|
||||
/// CPU fallback for attention.
///
/// Computes, per head: Output = softmax(Q K^T * scale) V, with optional
/// causal masking (position i may only attend to positions j <= i).
/// All tensors are laid out as (seq_len, num_heads * head_dim) row-major.
///
/// NOTE(review): `config.kv_seq_len` is ignored here — K and V are assumed
/// to share `seq_len` with Q; confirm behavior for encoder-decoder configs.
fn attention_cpu(
    &self,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    config: &AttentionConfig,
) -> Result<Vec<f32>, JsValue> {
    let seq_len = config.seq_len as usize;
    let num_heads = config.num_heads as usize;
    let head_dim = config.head_dim as usize;
    let hidden_dim = num_heads * head_dim;
    // scale = 1/sqrt(head_dim), see AttentionConfig::scale
    let scale = config.scale();

    let mut output = vec![0.0f32; seq_len * hidden_dim];

    // Process each head independently
    for h in 0..num_heads {
        for i in 0..seq_len {
            // For this query position, compute attention to all key positions
            let q_offset = i * hidden_dim + h * head_dim;

            // Compute attention scores
            let mut scores = vec![0.0f32; seq_len];
            let mut max_score = f32::NEG_INFINITY;

            for j in 0..seq_len {
                // Causal masking: future positions get -inf so they
                // contribute exp(-inf) = 0 after the softmax below.
                if config.causal && j > i {
                    scores[j] = f32::NEG_INFINITY;
                    continue;
                }

                let k_offset = j * hidden_dim + h * head_dim;
                let mut score = 0.0f32;

                // Dot product of the i-th query row with the j-th key row.
                for d in 0..head_dim {
                    score += q[q_offset + d] * k[k_offset + d];
                }

                score *= scale;
                scores[j] = score;
                if score > max_score {
                    max_score = score;
                }
            }

            // Softmax (max-subtracted for numerical stability; the max term
            // contributes exp(0) = 1, so `sum` is always >= 1 here).
            let mut sum = 0.0f32;
            for j in 0..seq_len {
                scores[j] = (scores[j] - max_score).exp();
                sum += scores[j];
            }
            for j in 0..seq_len {
                scores[j] /= sum;
            }

            // Compute weighted sum of values
            let out_offset = i * hidden_dim + h * head_dim;
            for d in 0..head_dim {
                let mut weighted_sum = 0.0f32;
                for j in 0..seq_len {
                    let v_offset = j * hidden_dim + h * head_dim;
                    weighted_sum += scores[j] * v[v_offset + d];
                }
                output[out_offset + d] = weighted_sum;
            }
        }
    }

    Ok(output)
}
|
||||
|
||||
/// Perform RMS normalization
|
||||
#[wasm_bindgen(js_name = rmsNorm)]
|
||||
pub async fn rms_norm(
|
||||
&self,
|
||||
input: &[f32],
|
||||
weight: &[f32],
|
||||
hidden_dim: u32,
|
||||
eps: f32,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if weight.len() != hidden_dim as usize {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Weight dimension mismatch: expected {}, got {}",
|
||||
hidden_dim,
|
||||
weight.len()
|
||||
)));
|
||||
}
|
||||
|
||||
if input.len() % hidden_dim as usize != 0 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size {} not divisible by hidden_dim {}",
|
||||
input.len(),
|
||||
hidden_dim
|
||||
)));
|
||||
}
|
||||
|
||||
// CPU implementation
|
||||
let batch_size = input.len() / hidden_dim as usize;
|
||||
let mut output = vec![0.0f32; input.len()];
|
||||
|
||||
for b in 0..batch_size {
|
||||
let offset = b * hidden_dim as usize;
|
||||
|
||||
// Compute sum of squares
|
||||
let mut sum_sq = 0.0f32;
|
||||
for i in 0..hidden_dim as usize {
|
||||
let x = input[offset + i];
|
||||
sum_sq += x * x;
|
||||
}
|
||||
|
||||
// RMS scale
|
||||
let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
|
||||
|
||||
// Normalize and scale
|
||||
for i in 0..hidden_dim as usize {
|
||||
output[offset + i] = input[offset + i] / rms * weight[i];
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Perform softmax
|
||||
#[wasm_bindgen]
|
||||
pub async fn softmax(
|
||||
&self,
|
||||
input: &[f32],
|
||||
dim: u32,
|
||||
temperature: f32,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if input.len() % dim as usize != 0 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size {} not divisible by dim {}",
|
||||
input.len(),
|
||||
dim
|
||||
)));
|
||||
}
|
||||
|
||||
let batch_size = input.len() / dim as usize;
|
||||
let mut output = vec![0.0f32; input.len()];
|
||||
|
||||
for b in 0..batch_size {
|
||||
let offset = b * dim as usize;
|
||||
|
||||
// Find max (for numerical stability)
|
||||
let mut max_val = f32::NEG_INFINITY;
|
||||
for i in 0..dim as usize {
|
||||
let x = input[offset + i] / temperature;
|
||||
if x > max_val {
|
||||
max_val = x;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute exp and sum
|
||||
let mut sum = 0.0f32;
|
||||
for i in 0..dim as usize {
|
||||
let x = (input[offset + i] / temperature - max_val).exp();
|
||||
output[offset + i] = x;
|
||||
sum += x;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for i in 0..dim as usize {
|
||||
output[offset + i] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
// Helper methods for GPU buffer management
#[cfg(target_arch = "wasm32")]
/// Create a GPU buffer of `size` bytes via `device.createBuffer`.
///
/// `usage` is a raw GPUBufferUsage bitmask forwarded to WebGPU unchanged;
/// `label` is attached to the descriptor for debugging when provided.
/// Returns the JS `GPUBuffer` object as an opaque `JsValue`.
fn create_buffer(
    &self,
    size: usize,
    usage: u32,
    label: Option<&str>,
) -> Result<JsValue, JsValue> {
    // WebGPU descriptors are plain JS objects; numbers cross the JS
    // boundary as f64.
    let descriptor = Object::new();
    Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?;
    Reflect::set(
        &descriptor,
        &"usage".into(),
        &JsValue::from_f64(usage as f64),
    )?;
    if let Some(lbl) = label {
        Reflect::set(&descriptor, &"label".into(), &lbl.into())?;
    }

    // `call_method` (defined elsewhere in this module) invokes the named JS
    // method on the receiver, surfacing JS exceptions as Err(JsValue).
    call_method(&self.device, "createBuffer", &[descriptor.into()])
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
/// Upload `data` into `buffer` at byte offset 0 via `queue.writeBuffer`.
///
/// `Float32Array::from(data)` copies the slice into a fresh JS-side array,
/// so passing its whole backing ArrayBuffer writes exactly `data` (the
/// array is not a view into some larger shared buffer).
fn write_buffer(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
    let data_array = Float32Array::from(data);
    call_method(
        &self.queue,
        "writeBuffer",
        &[
            buffer.clone(),
            JsValue::from_f64(0.0), // destination byte offset
            data_array.buffer().into(),
        ],
    )?;
    Ok(())
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Verify the naive CPU matmul reference on a 2x2 case.
    #[test]
    fn test_cpu_matmul_fallback() {
        // Test the CPU fallback logic (in non-wasm mode)
        let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
        let b = vec![5.0, 6.0, 7.0, 8.0]; // 2x2

        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        // = [[19, 22], [43, 50]]
        let mut c = vec![0.0f32; 4];
        for i in 0..2usize {
            for j in 0..2usize {
                let mut sum = 0.0f32;
                for l in 0..2usize {
                    sum += a[i * 2 + l] * b[l * 2 + j];
                }
                c[i * 2 + j] = sum;
            }
        }

        assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// Verify the RMS-norm math on a unit-weight row.
    /// Previously only element 0 of `expected` was asserted; now every
    /// element is checked (and the unused weight/hidden_dim locals dropped).
    #[test]
    fn test_rms_norm_cpu() {
        let input = vec![1.0f32, 2.0, 3.0, 4.0];
        let eps = 1e-5f32;

        // sum_sq = 1 + 4 + 9 + 16 = 30
        // rms = sqrt(30/4 + eps) = sqrt(7.5) ≈ 2.7386
        let rms = (30.0f32 / 4.0 + eps).sqrt();
        let expected: Vec<f32> = input.iter().map(|&x| x / rms).collect();

        // Independently computed references for all four lanes.
        let reference = [0.36515f32, 0.73030, 1.09545, 1.46059];
        for (e, r) in expected.iter().zip(reference.iter()) {
            assert!((e - r).abs() < 1e-3, "got {}, want {}", e, r);
        }
    }

    /// Verify softmax normalization: sums to 1 and matches the closed-form
    /// per-element values (previously only the sum was asserted).
    #[test]
    fn test_softmax_cpu() {
        // max = 3; exp(1-3), exp(2-3), exp(3-3) = exp(-2), exp(-1), 1
        let exps: Vec<f32> = vec![(-2.0f32).exp(), (-1.0f32).exp(), 1.0];
        let sum: f32 = exps.iter().sum();
        let softmax: Vec<f32> = exps.iter().map(|&x| x / sum).collect();

        // Verify softmax sums to 1
        let softmax_sum: f32 = softmax.iter().sum();
        assert!((softmax_sum - 1.0).abs() < 0.001);

        // Per-element reference: [0.09003, 0.24473, 0.66524]
        let reference = [0.09003f32, 0.24473, 0.66524];
        for (s, r) in softmax.iter().zip(reference.iter()) {
            assert!((s - r).abs() < 1e-3, "got {}, want {}", s, r);
        }
        // Monotonic: larger logits get larger probabilities.
        assert!(softmax[0] < softmax[1] && softmax[1] < softmax[2]);
    }
}
|
||||
345
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/mod.rs
vendored
Normal file
345
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/mod.rs
vendored
Normal file
@@ -0,0 +1,345 @@
|
||||
//! WebGPU Compute Module for WASM-based GPU Acceleration
|
||||
//!
|
||||
//! This module provides WebGPU compute shader support for LLM inference
|
||||
//! operations in the browser. It includes:
|
||||
//!
|
||||
//! - Matrix multiplication (tiled, batched, GEMV)
|
||||
//! - Flash Attention (causal, GQA, decode)
|
||||
//! - RMSNorm and LayerNorm
|
||||
//! - Softmax (standard, temperature-scaled, log-softmax)
|
||||
//!
|
||||
//! ## Feature Detection
|
||||
//!
|
||||
//! WebGPU availability is checked at runtime with graceful fallback:
|
||||
//!
|
||||
//! ```javascript
|
||||
//! if (await WebGpuInference.isAvailable()) {
|
||||
//! const gpu = await WebGpuInference.init();
|
||||
//! const result = await gpu.matmul(a, b, m, n, k);
|
||||
//! } else {
|
||||
//! // Fall back to CPU implementation
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! ## Performance Targets
|
||||
//!
|
||||
//! - Matrix multiply: ~1 TFLOP on integrated GPUs, ~10 TFLOPS on discrete
|
||||
//! - Attention: 2ms for 4K context on discrete GPU
|
||||
//! - Normalization: <0.5ms for typical hidden dimensions
|
||||
|
||||
pub mod buffers;
|
||||
pub mod compute;
|
||||
pub mod shaders;
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub use buffers::{GpuBuffer, GpuBufferUsage};
|
||||
pub use compute::{ComputePipeline, WebGpuContext};
|
||||
pub use shaders::ShaderModule;
|
||||
|
||||
/// GPU adapter information
///
/// Plain data snapshot of the adapter reported by the browser; exposed to
/// JS through the getter methods in the `impl` block (fields themselves are
/// skipped by wasm-bindgen because String/u64 fields can't be exported
/// directly as public struct fields).
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct AdapterInfo {
    /// GPU vendor name
    #[wasm_bindgen(skip)]
    pub vendor: String,
    /// GPU architecture/device name
    #[wasm_bindgen(skip)]
    pub architecture: String,
    /// Device type (integrated, discrete, etc.)
    #[wasm_bindgen(skip)]
    pub device_type: String,
    /// Backend API (WebGPU, etc.)
    #[wasm_bindgen(skip)]
    pub backend: String,
    /// Maximum buffer size in bytes
    #[wasm_bindgen(skip)]
    pub max_buffer_size: u64,
    /// Maximum compute workgroup size
    #[wasm_bindgen(skip)]
    pub max_workgroup_size: u32,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl AdapterInfo {
|
||||
/// Get GPU vendor name
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn vendor(&self) -> String {
|
||||
self.vendor.clone()
|
||||
}
|
||||
|
||||
/// Get GPU architecture
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn architecture(&self) -> String {
|
||||
self.architecture.clone()
|
||||
}
|
||||
|
||||
/// Get device type
|
||||
#[wasm_bindgen(getter, js_name = deviceType)]
|
||||
pub fn device_type(&self) -> String {
|
||||
self.device_type.clone()
|
||||
}
|
||||
|
||||
/// Get backend API
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn backend(&self) -> String {
|
||||
self.backend.clone()
|
||||
}
|
||||
|
||||
/// Get maximum buffer size
|
||||
#[wasm_bindgen(getter, js_name = maxBufferSize)]
|
||||
pub fn max_buffer_size(&self) -> u64 {
|
||||
self.max_buffer_size
|
||||
}
|
||||
|
||||
/// Get maximum workgroup size
|
||||
#[wasm_bindgen(getter, js_name = maxWorkgroupSize)]
|
||||
pub fn max_workgroup_size(&self) -> u32 {
|
||||
self.max_workgroup_size
|
||||
}
|
||||
|
||||
/// Convert to JSON string
|
||||
#[wasm_bindgen(js_name = toJson)]
|
||||
pub fn to_json(&self) -> Result<String, JsValue> {
|
||||
let json = serde_json::json!({
|
||||
"vendor": self.vendor,
|
||||
"architecture": self.architecture,
|
||||
"deviceType": self.device_type,
|
||||
"backend": self.backend,
|
||||
"maxBufferSize": self.max_buffer_size,
|
||||
"maxWorkgroupSize": self.max_workgroup_size,
|
||||
});
|
||||
serde_json::to_string(&json).map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Attention configuration for compute shaders
///
/// Fields are `skip`ped for wasm-bindgen and exposed via the accessor
/// methods in the `impl` block below. Total hidden dimension is
/// `num_heads * head_dim` (see `hidden_dim()`).
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Sequence length for queries
    #[wasm_bindgen(skip)]
    pub seq_len: u32,
    /// Key/Value sequence length (can differ for encoder-decoder)
    #[wasm_bindgen(skip)]
    pub kv_seq_len: u32,
    /// Number of attention heads
    #[wasm_bindgen(skip)]
    pub num_heads: u32,
    /// Dimension per head
    #[wasm_bindgen(skip)]
    pub head_dim: u32,
    /// Whether to apply causal masking
    #[wasm_bindgen(skip)]
    pub causal: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl AttentionConfig {
|
||||
/// Create new attention configuration
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(seq_len: u32, num_heads: u32, head_dim: u32, causal: bool) -> Self {
|
||||
Self {
|
||||
seq_len,
|
||||
kv_seq_len: seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create for encoder-decoder models with different KV length
|
||||
#[wasm_bindgen(js_name = forEncoderDecoder)]
|
||||
pub fn for_encoder_decoder(
|
||||
seq_len: u32,
|
||||
kv_seq_len: u32,
|
||||
num_heads: u32,
|
||||
head_dim: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
seq_len,
|
||||
kv_seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
causal: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the scaling factor (1/sqrt(head_dim))
|
||||
pub fn scale(&self) -> f32 {
|
||||
1.0 / (self.head_dim as f32).sqrt()
|
||||
}
|
||||
|
||||
/// Get total hidden dimension
|
||||
pub fn hidden_dim(&self) -> u32 {
|
||||
self.num_heads * self.head_dim
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = seqLen)]
|
||||
pub fn get_seq_len(&self) -> u32 {
|
||||
self.seq_len
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = seqLen)]
|
||||
pub fn set_seq_len(&mut self, value: u32) {
|
||||
self.seq_len = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = kvSeqLen)]
|
||||
pub fn get_kv_seq_len(&self) -> u32 {
|
||||
self.kv_seq_len
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = kvSeqLen)]
|
||||
pub fn set_kv_seq_len(&mut self, value: u32) {
|
||||
self.kv_seq_len = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = numHeads)]
|
||||
pub fn get_num_heads(&self) -> u32 {
|
||||
self.num_heads
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = numHeads)]
|
||||
pub fn set_num_heads(&mut self, value: u32) {
|
||||
self.num_heads = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter, js_name = headDim)]
|
||||
pub fn get_head_dim(&self) -> u32 {
|
||||
self.head_dim
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter, js_name = headDim)]
|
||||
pub fn set_head_dim(&mut self, value: u32) {
|
||||
self.head_dim = value;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn get_causal(&self) -> bool {
|
||||
self.causal
|
||||
}
|
||||
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_causal(&mut self, value: bool) {
|
||||
self.causal = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if WebGPU is available in this browser.
///
/// Thin async wrapper delegating to `compute::is_webgpu_available`
/// (defined in the sibling `compute` module).
#[wasm_bindgen(js_name = isWebGpuAvailable)]
pub async fn is_webgpu_available() -> bool {
    compute::is_webgpu_available().await
}
|
||||
|
||||
/// Get GPU information if available
|
||||
#[wasm_bindgen(js_name = getGpuInfo)]
|
||||
pub async fn get_gpu_info() -> Result<JsValue, JsValue> {
|
||||
match compute::get_gpu_info().await {
|
||||
Some(info) => {
|
||||
let js_obj = js_sys::Object::new();
|
||||
js_sys::Reflect::set(&js_obj, &"vendor".into(), &info.vendor.into())?;
|
||||
js_sys::Reflect::set(&js_obj, &"architecture".into(), &info.architecture.into())?;
|
||||
js_sys::Reflect::set(&js_obj, &"deviceType".into(), &info.device_type.into())?;
|
||||
js_sys::Reflect::set(&js_obj, &"backend".into(), &info.backend.into())?;
|
||||
js_sys::Reflect::set(
|
||||
&js_obj,
|
||||
&"maxBufferSize".into(),
|
||||
&JsValue::from_f64(info.max_buffer_size as f64),
|
||||
)?;
|
||||
js_sys::Reflect::set(
|
||||
&js_obj,
|
||||
&"maxWorkgroupSize".into(),
|
||||
&JsValue::from_f64(info.max_workgroup_size as f64),
|
||||
)?;
|
||||
Ok(js_obj.into())
|
||||
}
|
||||
None => Ok(JsValue::NULL),
|
||||
}
|
||||
}
|
||||
|
||||
/// WebGPU error types
///
/// Converted to `JsValue` (via the `From` impl below) at the JS boundary,
/// using the `Display` rendering as the error message.
#[derive(Debug)]
pub enum WebGpuError {
    /// WebGPU not available in this browser
    NotAvailable,
    /// Failed to get GPU adapter
    AdapterNotFound,
    /// Failed to create device (payload: browser-provided reason)
    DeviceCreationFailed(String),
    /// Buffer allocation failed (sizes in bytes)
    BufferAllocationFailed { requested: usize, available: usize },
    /// Shader compilation failed (payload: compiler diagnostics)
    ShaderCompilationFailed(String),
    /// Invalid dimensions for operation (human-readable shape strings)
    DimensionMismatch { expected: String, actual: String },
    /// Operation timed out
    Timeout,
    /// Generic GPU error
    GpuError(String),
}
|
||||
|
||||
impl std::fmt::Display for WebGpuError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::NotAvailable => write!(f, "WebGPU is not available in this browser"),
|
||||
Self::AdapterNotFound => write!(f, "No suitable GPU adapter found"),
|
||||
Self::DeviceCreationFailed(msg) => write!(f, "Failed to create GPU device: {}", msg),
|
||||
Self::BufferAllocationFailed {
|
||||
requested,
|
||||
available,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Buffer allocation failed: requested {} bytes, {} available",
|
||||
requested, available
|
||||
)
|
||||
}
|
||||
Self::ShaderCompilationFailed(msg) => write!(f, "Shader compilation failed: {}", msg),
|
||||
Self::DimensionMismatch { expected, actual } => {
|
||||
write!(
|
||||
f,
|
||||
"Dimension mismatch: expected {}, got {}",
|
||||
expected, actual
|
||||
)
|
||||
}
|
||||
Self::Timeout => write!(f, "GPU operation timed out"),
|
||||
Self::GpuError(msg) => write!(f, "GPU error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: WebGpuError already has Debug + Display, so the default
// std::error::Error methods suffice (no source/backtrace to expose).
impl std::error::Error for WebGpuError {}
|
||||
|
||||
impl From<WebGpuError> for JsValue {
|
||||
fn from(error: WebGpuError) -> Self {
|
||||
JsValue::from_str(&error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check derived quantities: hidden_dim = heads * head_dim and
    /// scale = 1/sqrt(head_dim).
    #[test]
    fn test_attention_config() {
        let config = AttentionConfig::new(512, 8, 64, true);
        assert_eq!(config.hidden_dim(), 512);
        assert!((config.scale() - 0.125).abs() < 0.001); // 1/sqrt(64) = 0.125
    }

    /// to_json should serialize without error and include field values.
    #[test]
    fn test_adapter_info_json() {
        let info = AdapterInfo {
            vendor: "TestVendor".to_string(),
            architecture: "TestArch".to_string(),
            device_type: "integrated".to_string(),
            backend: "WebGPU".to_string(),
            max_buffer_size: 1024 * 1024 * 256,
            max_workgroup_size: 256,
        };
        let json = info.to_json().unwrap();
        assert!(json.contains("TestVendor"));
    }
}
|
||||
195
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders.rs
vendored
Normal file
195
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders.rs
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
//! WGSL Shader Module Definitions
|
||||
//!
|
||||
//! This module contains the embedded WGSL shader source code for all
|
||||
//! compute operations. Shaders are embedded at compile time for efficient
|
||||
//! loading in WASM.
|
||||
|
||||
/// Matrix multiplication shader (tiled with shared memory)
// WGSL sources are embedded at compile time via include_str!, so the WASM
// bundle needs no extra fetches to load shaders.
pub const MATMUL_SHADER: &str = include_str!("shaders/matmul.wgsl");

/// Flash attention shader (online softmax, causal masking)
pub const ATTENTION_SHADER: &str = include_str!("shaders/attention.wgsl");

/// RMSNorm and LayerNorm shader
pub const NORM_SHADER: &str = include_str!("shaders/norm.wgsl");

/// Softmax shader (numerically stable)
pub const SOFTMAX_SHADER: &str = include_str!("shaders/softmax.wgsl");
|
||||
|
||||
/// Shader entry points for matrix multiplication
// NOTE: each string must match a `fn` name in the corresponding .wgsl file.
pub mod matmul {
    /// Standard tiled matrix multiply
    pub const MAIN: &str = "main";
    /// Batched matrix multiply for attention projections
    pub const BATCHED: &str = "main_batched";
    /// Vector-matrix multiply for single token generation
    pub const GEMV: &str = "main_gemv";
}

/// Shader entry points for attention
pub mod attention {
    /// Standard multi-head attention
    pub const MAIN: &str = "main";
    /// Grouped query attention (GQA)
    pub const GQA: &str = "main_gqa";
    /// Single token decode attention
    pub const DECODE: &str = "main_decode";
}

/// Shader entry points for normalization
pub mod norm {
    /// RMSNorm (Llama-style)
    pub const RMS_NORM: &str = "rms_norm";
    /// RMSNorm with fused residual connection
    pub const RMS_NORM_RESIDUAL: &str = "rms_norm_residual";
    /// Standard LayerNorm
    pub const LAYER_NORM: &str = "layer_norm";
    /// Fast RMSNorm for small dimensions
    pub const RMS_NORM_SMALL: &str = "rms_norm_small";
}

/// Shader entry points for softmax
pub mod softmax {
    /// Standard row-wise softmax
    pub const MAIN: &str = "softmax";
    /// In-place softmax
    pub const INPLACE: &str = "softmax_inplace";
    /// Small dimension softmax
    pub const SMALL: &str = "softmax_small";
    /// Log softmax for loss computation
    pub const LOG_SOFTMAX: &str = "log_softmax";
}
|
||||
|
||||
/// Shader module wrapper for wasm-bindgen
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct ShaderModule {
    // Human-readable module name ("matmul", "attention", ...).
    name: String,
    // Full WGSL source text, embedded at compile time.
    source: String,
    // Entry-point function names available within `source`.
    entry_points: Vec<String>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ShaderModule {
|
||||
/// Get the matrix multiplication shader module
|
||||
#[wasm_bindgen(js_name = matmul)]
|
||||
pub fn get_matmul() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "matmul".to_string(),
|
||||
source: MATMUL_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
matmul::MAIN.to_string(),
|
||||
matmul::BATCHED.to_string(),
|
||||
matmul::GEMV.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the attention shader module
|
||||
#[wasm_bindgen(js_name = attention)]
|
||||
pub fn get_attention() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "attention".to_string(),
|
||||
source: ATTENTION_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
attention::MAIN.to_string(),
|
||||
attention::GQA.to_string(),
|
||||
attention::DECODE.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the normalization shader module
|
||||
#[wasm_bindgen(js_name = norm)]
|
||||
pub fn get_norm() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "norm".to_string(),
|
||||
source: NORM_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
norm::RMS_NORM.to_string(),
|
||||
norm::RMS_NORM_RESIDUAL.to_string(),
|
||||
norm::LAYER_NORM.to_string(),
|
||||
norm::RMS_NORM_SMALL.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the softmax shader module
|
||||
#[wasm_bindgen(js_name = softmax)]
|
||||
pub fn get_softmax() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "softmax".to_string(),
|
||||
source: SOFTMAX_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
softmax::MAIN.to_string(),
|
||||
softmax::INPLACE.to_string(),
|
||||
softmax::SMALL.to_string(),
|
||||
softmax::LOG_SOFTMAX.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get shader name
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
/// Get shader source code
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn source(&self) -> String {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
/// Get available entry points
|
||||
#[wasm_bindgen(getter, js_name = entryPoints)]
|
||||
pub fn entry_points(&self) -> Vec<String> {
|
||||
self.entry_points.clone()
|
||||
}
|
||||
|
||||
/// Check if an entry point exists
|
||||
#[wasm_bindgen(js_name = hasEntryPoint)]
|
||||
pub fn has_entry_point(&self, name: &str) -> bool {
|
||||
self.entry_points.iter().any(|ep| ep == name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all available shader modules
|
||||
#[wasm_bindgen(js_name = getAllShaderModules)]
|
||||
pub fn get_all_shader_modules() -> Vec<ShaderModule> {
|
||||
vec![
|
||||
ShaderModule::get_matmul(),
|
||||
ShaderModule::get_attention(),
|
||||
ShaderModule::get_norm(),
|
||||
ShaderModule::get_softmax(),
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The include_str! sources should never be empty files.
    #[test]
    fn test_shader_sources_not_empty() {
        assert!(!MATMUL_SHADER.is_empty());
        assert!(!ATTENTION_SHADER.is_empty());
        assert!(!NORM_SHADER.is_empty());
        assert!(!SOFTMAX_SHADER.is_empty());
    }

    /// Module constructors carry the expected name and entry points.
    #[test]
    fn test_shader_module_creation() {
        let matmul = ShaderModule::get_matmul();
        assert_eq!(matmul.name(), "matmul");
        assert!(matmul.has_entry_point("main"));
        assert!(matmul.has_entry_point("main_batched"));
    }

    /// One module per shader family: matmul, attention, norm, softmax.
    #[test]
    fn test_all_shader_modules() {
        let modules = get_all_shader_modules();
        assert_eq!(modules.len(), 4);
    }
}
|
||||
283
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/attention.wgsl
vendored
Normal file
283
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/attention.wgsl
vendored
Normal file
@@ -0,0 +1,283 @@
|
||||
// Flash Attention Shader for WebGPU WASM
|
||||
//
|
||||
// Implements memory-efficient attention using online softmax algorithm.
|
||||
// Supports causal masking for autoregressive generation.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Process Q in blocks, streaming K and V
|
||||
// 2. Maintain running max and sum for numerical stability
|
||||
// 3. Rescale outputs on-the-fly (Flash Attention v2)
|
||||
// 4. O(n) memory vs O(n^2) for standard attention
|
||||
//
|
||||
// Memory Layout:
|
||||
// - Q: (seq_len, num_heads, head_dim)
|
||||
// - K: (seq_len, num_heads, head_dim)
|
||||
// - V: (seq_len, num_heads, head_dim)
|
||||
// - Output: (seq_len, num_heads, head_dim)
|
||||
|
||||
const BLOCK_SIZE: u32 = 32u; // Reduced for WebGPU limits
const MAX_HEAD_DIM: u32 = 128u;

struct AttentionUniforms {
    seq_len: u32,
    head_dim: u32,      // must be <= MAX_HEAD_DIM (shared/private arrays are sized for it)
    num_heads: u32,
    scale: f32,         // 1/sqrt(head_dim)
    causal_mask: u32,   // 1 for causal, 0 for full attention
    kv_seq_len: u32,    // For encoder-decoder or prefill
    _pad0: u32,         // pad struct to 16-byte uniform alignment
    _pad1: u32,
}

@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<f32>;
@group(0) @binding(2) var<storage, read> V: array<f32>;
@group(0) @binding(3) var<storage, read_write> Output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: AttentionUniforms;

// Shared memory for blocks; 4096 = BLOCK_SIZE * MAX_HEAD_DIM elements.
var<workgroup> Q_shared: array<f32, 4096>; // BLOCK_SIZE * MAX_HEAD_DIM
var<workgroup> K_shared: array<f32, 4096>;
var<workgroup> V_shared: array<f32, 4096>;
var<workgroup> scores_shared: array<f32, 1024>; // BLOCK_SIZE * BLOCK_SIZE

// Thread-local (per-invocation) state for the online softmax; each thread
// owns one query row of the block.
var<private> m_i: f32; // Running max
var<private> l_i: f32; // Running sum
var<private> o_i: array<f32, 128>; // Output accumulator (MAX_HEAD_DIM)
|
||||
|
||||
// Flash-attention forward pass: each workgroup processes one BLOCK_SIZE
// chunk of query rows (group_id.x) for one head (group_id.y); each of the
// 32 threads owns one query row and streams K/V block by block, keeping a
// running max (m_i), running sum (l_i) and rescaled accumulator (o_i) so
// the full seq_len x seq_len score matrix is never materialized.
@compute @workgroup_size(32, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let seq_len = uniforms.seq_len;
    let head_dim = uniforms.head_dim;
    let num_heads = uniforms.num_heads;
    let scale = uniforms.scale;
    let is_causal = uniforms.causal_mask == 1u;
    let kv_seq_len = uniforms.kv_seq_len;

    // This workgroup handles one Q block for one head
    let head_idx = group_id.y;
    let q_block_idx = group_id.x;
    let q_start = q_block_idx * BLOCK_SIZE;

    let thread_id = local_id.x;
    // Row stride of the (seq, heads*dim) layout.
    let hidden_stride = num_heads * head_dim;

    // Initialize online softmax state (-1e10 stands in for -infinity).
    m_i = -1e10f;
    l_i = 0.0f;
    for (var d = 0u; d < head_dim; d++) {
        o_i[d] = 0.0f;
    }

    // Load Q block into shared memory (one row per thread).
    let q_pos = q_start + thread_id;
    if (q_pos < seq_len && thread_id < BLOCK_SIZE) {
        for (var d = 0u; d < head_dim; d++) {
            let q_idx = q_pos * hidden_stride + head_idx * head_dim + d;
            Q_shared[thread_id * head_dim + d] = Q[q_idx];
        }
    }
    workgroupBarrier();

    // Iterate over K/V blocks
    let num_kv_blocks = (kv_seq_len + BLOCK_SIZE - 1u) / BLOCK_SIZE;

    // NOTE: the loop trip count and the break condition below depend only
    // on uniforms and group_id, so control flow around the barriers stays
    // uniform across the workgroup.
    for (var kv_block = 0u; kv_block < num_kv_blocks; kv_block++) {
        let kv_start = kv_block * BLOCK_SIZE;

        // Early exit for causal attention.
        // NOTE(review): uses `>` rather than `>=`, so the first fully-masked
        // block (kv_start == q_start + BLOCK_SIZE) is still processed —
        // wasted work but harmless, since the per-element mask zeroes it.
        if (is_causal && kv_start > q_start + BLOCK_SIZE) {
            break;
        }

        // Load K block (one row per thread).
        let k_pos = kv_start + thread_id;
        if (k_pos < kv_seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let k_idx = k_pos * hidden_stride + head_idx * head_dim + d;
                K_shared[thread_id * head_dim + d] = K[k_idx];
            }
        }

        // Load V block
        let v_pos = kv_start + thread_id;
        if (v_pos < kv_seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let v_idx = v_pos * hidden_stride + head_idx * head_dim + d;
                V_shared[thread_id * head_dim + d] = V[v_idx];
            }
        }
        workgroupBarrier();

        // Compute attention scores and update online softmax
        if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
            let kv_block_len = min(BLOCK_SIZE, kv_seq_len - kv_start);

            // Compute row max for this block
            var block_max = -1e10f;
            var local_scores: array<f32, 32>; // one score per K row in block

            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;

                // Apply causal mask
                if (is_causal && k_global > q_pos) {
                    local_scores[k] = -1e10f;
                    continue;
                }

                // Compute Q[q_pos] dot K[k]
                var score = 0.0f;
                for (var d = 0u; d < head_dim; d++) {
                    score += Q_shared[thread_id * head_dim + d] * K_shared[k * head_dim + d];
                }
                score *= scale;
                local_scores[k] = score;
                block_max = max(block_max, score);
            }

            // Update running statistics (Flash Attention v2 rescaling):
            // fold the new block max into m_ij, then scale the previous
            // accumulator and sum by exp(m_i - m_ij) before accumulating.
            let m_ij = max(m_i, block_max);

            // Rescale previous accumulator
            let alpha = exp(m_i - m_ij);
            for (var d = 0u; d < head_dim; d++) {
                o_i[d] *= alpha;
            }
            l_i *= alpha;

            // Accumulate weighted V for this block
            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;
                if (is_causal && k_global > q_pos) {
                    continue;
                }

                let p_ij = exp(local_scores[k] - m_ij);
                l_i += p_ij;

                for (var d = 0u; d < head_dim; d++) {
                    o_i[d] += p_ij * V_shared[k * head_dim + d];
                }
            }

            m_i = m_ij;
        }

        // All threads must finish reading K_shared/V_shared before the next
        // iteration overwrites them.
        workgroupBarrier();
    }

    // Normalize and write output
    if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
        // select(a, b, cond) yields b when cond is true, so a fully-masked
        // row (l_i == 0) writes zeros instead of dividing by zero.
        let inv_l = select(1.0f / l_i, 0.0f, l_i == 0.0f);

        for (var d = 0u; d < head_dim; d++) {
            let out_idx = q_pos * hidden_stride + head_idx * head_dim + d;
            Output[out_idx] = o_i[d] * inv_l;
        }
    }
}
|
||||
|
||||
// Grouped Query Attention (GQA) variant
// Multiple Q heads share same K/V heads
//
// NOTE(review): this entry point is an empty stub — it performs no work and
// writes no output, so dispatching it is a no-op. Confirm whether the host
// ever selects this variant before relying on it.
@compute @workgroup_size(32, 1, 1)
fn main_gqa(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // For GQA: kv_head_idx = q_head_idx / num_q_per_kv
    // This allows Llama2/3 style grouped attention
    // Implementation similar to main() with modified indexing
}
|
||||
|
||||
// Single-token attention for the generation (decode) phase.
//
// Optimized for seq_len == 1: one workgroup per head; each thread scans a
// contiguous slice of the KV sequence with an online-softmax accumulator
// (running max / running sum / rescaled output vector).
//
// NOTE(review): the final writeback uses only thread 0's partial state, so
// the result reflects just thread 0's KV slice. A complete implementation
// needs a cross-thread merge of (max, sum, out) — e.g. a shared-memory
// parallel reduction. Preserved as in the original design; confirm against
// the host-side dispatch before relying on this path.
@compute @workgroup_size(256, 1, 1)
fn main_decode(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let head_dim = uniforms.head_dim;
    let num_heads = uniforms.num_heads;
    let scale = uniforms.scale;
    let kv_seq_len = uniforms.kv_seq_len;

    let head_idx = group_id.x;
    let thread_id = local_id.x;
    let hidden_stride = num_heads * head_dim;

    // Each thread handles a contiguous chunk of the KV sequence.
    let kv_per_thread = (kv_seq_len + 255u) / 256u;

    // Thread-local online-softmax accumulators.
    var local_max = -1e10f;
    var local_sum = 0.0f;
    var local_out: array<f32, 128>;
    for (var d = 0u; d < head_dim; d++) {
        local_out[d] = 0.0f;
    }

    // Stage the single query vector in shared memory. Only thread 0 writes
    // (the original had every thread redundantly store identical values —
    // a benign-looking but racy pattern — plus a dead q_vec load).
    if (thread_id == 0u) {
        for (var d = 0u; d < head_dim; d++) {
            Q_shared[d] = Q[head_idx * head_dim + d];
        }
    }
    workgroupBarrier();

    // Process this thread's assigned KV positions.
    for (var i = 0u; i < kv_per_thread; i++) {
        let k_pos = thread_id * kv_per_thread + i;
        if (k_pos >= kv_seq_len) {
            break;
        }

        // Attention score: scale * dot(q, k).
        var score = 0.0f;
        for (var d = 0u; d < head_dim; d++) {
            let k_idx = k_pos * hidden_stride + head_idx * head_dim + d;
            score += Q_shared[d] * K[k_idx];
        }
        score *= scale;

        // Online-softmax update: rescale prior state to the new running max.
        let new_max = max(local_max, score);
        let alpha = exp(local_max - new_max);
        let p = exp(score - new_max); // computed once, reused below

        for (var d = 0u; d < head_dim; d++) {
            local_out[d] *= alpha;
        }
        local_sum = local_sum * alpha + p;

        // Accumulate probability-weighted V.
        for (var d = 0u; d < head_dim; d++) {
            let v_idx = k_pos * hidden_stride + head_idx * head_dim + d;
            local_out[d] += p * V[v_idx];
        }

        local_max = new_max;
    }

    // Writeback (thread 0 only — see NOTE above about the missing reduction).
    if (thread_id == 0u) {
        // Guard against an empty row: emit zeros rather than inf/NaN.
        let inv_sum = select(1.0f / local_sum, 0.0f, local_sum == 0.0f);
        for (var d = 0u; d < head_dim; d++) {
            Output[head_idx * head_dim + d] = local_out[d] * inv_sum;
        }
    }
}
|
||||
182
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/matmul.wgsl
vendored
Normal file
182
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/matmul.wgsl
vendored
Normal file
@@ -0,0 +1,182 @@
|
||||
// Tiled Matrix Multiplication Shader for WebGPU WASM
//
// Computes C = A * B using 16x16 tiles optimized for browser WebGPU.
// Uses workgroup shared memory for cache-efficient tile loading.
//
// Memory Layout (row-major):
// - A: M x K matrix
// - B: K x N matrix
// - C: M x N matrix (output)

// Tile size optimized for WebGPU limits (16x16 = 256 threads per workgroup).
const TILE_SIZE: u32 = 16u;

// Per-dispatch GEMM dimensions and scaling factor, supplied by the host.
struct Uniforms {
    M: u32, // Rows of A, rows of C
    N: u32, // Cols of B, cols of C
    K: u32, // Cols of A, rows of B
    alpha: f32, // Scaling factor (default 1.0)
}

// A and B are read-only inputs; C is the read-write output.
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: Uniforms;

// Shared memory for tile caching (one TILE_SIZE x TILE_SIZE tile each).
var<workgroup> A_tile: array<f32, 256>; // TILE_SIZE * TILE_SIZE
var<workgroup> B_tile: array<f32, 256>;
|
||||
|
||||
// Tiled GEMM: C = alpha * (A x B), staged through 16x16 workgroup tiles.
//
// Each thread owns exactly one element of C. Per K-tile, the workgroup
// cooperatively loads one tile of A and one tile of B into shared memory
// (zero-padded past the matrix edges), synchronizes, accumulates the
// partial dot product, and synchronizes again before the next tile.
@compute @workgroup_size(16, 16, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let m_dim = uniforms.M;
    let n_dim = uniforms.N;
    let k_dim = uniforms.K;
    let alpha = uniforms.alpha;

    // This thread's output coordinates and its slot inside the tile.
    let out_row = global_id.x;
    let out_col = global_id.y;
    let tile_row = local_id.x;
    let tile_col = local_id.y;

    // Accumulator for the single C element this thread produces.
    var acc = 0.0f;

    // Tiles needed to cover the K dimension (ceiling division).
    let tile_count = (k_dim + TILE_SIZE - 1u) / TILE_SIZE;

    for (var t = 0u; t < tile_count; t++) {
        let k_base = t * TILE_SIZE;

        // Stage one element of A's tile, zero-padding out-of-range slots.
        var a_val = 0.0f;
        let a_col = k_base + tile_col;
        if (out_row < m_dim && a_col < k_dim) {
            a_val = A[out_row * k_dim + a_col];
        }
        A_tile[tile_row * TILE_SIZE + tile_col] = a_val;

        // Stage one element of B's tile, zero-padding out-of-range slots.
        var b_val = 0.0f;
        let b_row = k_base + tile_row;
        if (b_row < k_dim && out_col < n_dim) {
            b_val = B[b_row * n_dim + out_col];
        }
        B_tile[tile_row * TILE_SIZE + tile_col] = b_val;

        // Wait until both tiles are fully resident in shared memory.
        workgroupBarrier();

        // Partial dot product over the valid K range of this tile.
        let k_limit = min(TILE_SIZE, k_dim - k_base);
        for (var k = 0u; k < k_limit; k++) {
            acc += A_tile[tile_row * TILE_SIZE + k] * B_tile[k * TILE_SIZE + tile_col];
        }

        // Keep faster threads from overwriting the tiles early.
        workgroupBarrier();
    }

    // Scaled writeback for in-bounds threads only.
    if (out_row < m_dim && out_col < n_dim) {
        C[out_row * n_dim + out_col] = acc * alpha;
    }
}
|
||||
|
||||
// Batched matrix multiply for multi-head attention projections.
// C[b] = alpha * (A[b] * B) where A is batch_size x M x K and B is K x N,
// with the batch index taken from workgroup z. Tiling matches main().
//
// Consistency fix: main() and main_gemv() both scale their result by
// uniforms.alpha, but this variant previously dropped it. Alpha is now
// applied here as well; hosts using the documented default alpha = 1.0
// are unaffected.
@compute @workgroup_size(16, 16, 1)
fn main_batched(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let M = uniforms.M;
    let N = uniforms.N;
    let K = uniforms.K;

    let batch_idx = group_id.z;
    let row = global_id.x;
    let col = global_id.y;

    let local_row = local_id.x;
    let local_col = local_id.y;

    var sum = 0.0f;
    let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;

    // Per-batch offsets into A and C; B carries no batch dimension.
    let batch_offset_a = batch_idx * M * K;
    let batch_offset_c = batch_idx * M * N;

    for (var t = 0u; t < num_tiles; t++) {
        let tile_k = t * TILE_SIZE;

        // Load A tile element (batched), zero-padding out of range.
        let a_row = row;
        let a_col = tile_k + local_col;
        if (a_row < M && a_col < K) {
            A_tile[local_row * TILE_SIZE + local_col] = A[batch_offset_a + a_row * K + a_col];
        } else {
            A_tile[local_row * TILE_SIZE + local_col] = 0.0;
        }

        // Load B tile element (shared across the batch).
        let b_row = tile_k + local_row;
        let b_col = col;
        if (b_row < K && b_col < N) {
            B_tile[local_row * TILE_SIZE + local_col] = B[b_row * N + b_col];
        } else {
            B_tile[local_row * TILE_SIZE + local_col] = 0.0;
        }

        // Ensure both tiles are fully loaded before use.
        workgroupBarrier();

        // Accumulate over the valid portion of this K tile.
        let tile_k_end = min(TILE_SIZE, K - tile_k);
        for (var k = 0u; k < tile_k_end; k++) {
            sum += A_tile[local_row * TILE_SIZE + k] * B_tile[k * TILE_SIZE + local_col];
        }

        // Guard the tiles until every thread has consumed them.
        workgroupBarrier();
    }

    if (row < M && col < N) {
        // Apply alpha for consistency with main() and main_gemv().
        C[batch_offset_c + row * N + col] = sum * uniforms.alpha;
    }
}
|
||||
|
||||
// GEMV for single-token generation: y = x * W (x is 1 x K, W is K x N).
// One thread per output column; each thread walks the full K dimension.
// Reuses the GEMM bindings: A holds the input vector, B the weight matrix,
// C the output row.
@compute @workgroup_size(256, 1, 1)
fn main_gemv(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let k_dim = uniforms.K;
    let n_dim = uniforms.N;

    // Guard clause: threads past the last column have nothing to do.
    let out_col = global_id.x;
    if (out_col >= n_dim) {
        return;
    }

    // Dot product of the input row with column `out_col` of B.
    var acc = 0.0f;
    for (var i = 0u; i < k_dim; i++) {
        acc += A[i] * B[i * n_dim + out_col];
    }

    // Scaled writeback, matching main().
    C[out_col] = acc * uniforms.alpha;
}
|
||||
235
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
vendored
Normal file
235
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
vendored
Normal file
@@ -0,0 +1,235 @@
|
||||
// RMSNorm and LayerNorm Shaders for WebGPU WASM
//
// Implements normalization layers used in transformer architectures:
// - RMSNorm: Used in Llama, Mistral (no mean subtraction)
// - LayerNorm: Standard transformer normalization
//
// RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight
// LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias

// NOTE(review): WARP_SIZE and MAX_DIM are not referenced by any kernel in
// this file — presumably reserved for future variants; confirm before removal.
const WARP_SIZE: u32 = 32u;
const MAX_DIM: u32 = 8192u;

// Per-dispatch parameters supplied by the host.
struct NormUniforms {
    hidden_dim: u32, // length of each row being normalized
    batch_size: u32, // number of rows
    eps: f32,        // numerical-stability epsilon
    _pad: u32,       // padding for uniform-buffer alignment
}

// input: rows to normalize; weight: per-channel scale; output: result
// (rms_norm_residual additionally reads the residual stream from output).
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: NormUniforms;

// Shared memory for parallel reduction (one slot per thread, workgroup 256).
var<workgroup> partial_sums: array<f32, 256>;
|
||||
|
||||
// RMSNorm: y = x * rsqrt(mean(x^2) + eps) * weight.
//
// One workgroup per row. Each thread accumulates squares over a strided
// share of the row, the workgroup tree-reduces the partials, and every
// thread then normalizes its own share with the common scale factor.
@compute @workgroup_size(256, 1, 1)
fn rms_norm(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let row = group_id.x;
    let tid = local_id.x;
    let base = row * hidden_dim;

    // Per-thread sum of squares over elements tid, tid+256, tid+512, ...
    var sq_acc = 0.0f;
    for (var i = tid; i < hidden_dim; i += 256u) {
        let x = input[base + i];
        sq_acc += x * x;
    }
    partial_sums[tid] = sq_acc;
    workgroupBarrier();

    // Tree reduction of the 256 partial sums down to partial_sums[0].
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        workgroupBarrier();
    }

    // RMS scale factor, identical across all lanes.
    let mean_sq = partial_sums[0] / f32(hidden_dim);
    let inv_rms = 1.0f / sqrt(mean_sq + eps);
    workgroupBarrier();

    // Normalize and apply the learned per-channel weight.
    for (var i = tid; i < hidden_dim; i += 256u) {
        output[base + i] = input[base + i] * inv_rms * weight[i];
    }
}
|
||||
|
||||
// Fused residual-add + RMSNorm:
// y = (x + residual) * rsqrt(mean((x + residual)^2) + eps) * weight.
//
// The residual stream is passed in via `output` (for the in-place update);
// each element is read as input + output, the RMS statistic is taken over
// that fused value, and the normalized result overwrites `output`.
@compute @workgroup_size(256, 1, 1)
fn rms_norm_residual(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let row = group_id.x;
    let tid = local_id.x;
    let base = row * hidden_dim;

    // Pass 1: per-thread sum of squares of the fused (x + residual) values.
    var sq_acc = 0.0f;
    for (var i = tid; i < hidden_dim; i += 256u) {
        let fused = input[base + i] + output[base + i];
        sq_acc += fused * fused;
    }
    partial_sums[tid] = sq_acc;
    workgroupBarrier();

    // Tree reduction of the partial sums.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        workgroupBarrier();
    }

    let mean_sq = partial_sums[0] / f32(hidden_dim);
    let inv_rms = 1.0f / sqrt(mean_sq + eps);
    workgroupBarrier();

    // Pass 2: recompute the fused value (each lane only reads and writes
    // its own indices, so the in-place overwrite is race-free) and scale.
    for (var i = tid; i < hidden_dim; i += 256u) {
        let fused = input[base + i] + output[base + i];
        output[base + i] = fused * inv_rms * weight[i];
    }
}
|
||||
|
||||
// Standard LayerNorm with bias:
// y = (x - mean) / sqrt(var + eps) * weight + bias
// Three barrier-separated passes per row: mean, variance, normalize+affine.
@group(0) @binding(4) var<storage, read> bias: array<f32>;

@compute @workgroup_size(256, 1, 1)
fn layer_norm(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    // One workgroup per row.
    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * hidden_dim;

    let elements_per_thread = (hidden_dim + 255u) / 256u;

    // First pass: compute mean
    var thread_sum = 0.0f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            thread_sum += input[offset + idx];
        }
    }

    partial_sums[thread_id] = thread_sum;
    workgroupBarrier();

    // Tree reduction of the per-thread partial sums.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }

    let mean = partial_sums[0] / f32(hidden_dim);
    // Barrier before partial_sums is reused by the variance reduction.
    workgroupBarrier();

    // Second pass: compute variance
    var thread_var = 0.0f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let diff = input[offset + idx] - mean;
            thread_var += diff * diff;
        }
    }

    partial_sums[thread_id] = thread_var;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }

    // Biased (population) variance, as is standard for LayerNorm.
    let variance = partial_sums[0] / f32(hidden_dim);
    let inv_std = 1.0f / sqrt(variance + eps);
    workgroupBarrier();

    // Third pass: normalize and apply affine transform
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let x = input[offset + idx];
            output[offset + idx] = (x - mean) * inv_std * weight[idx] + bias[idx];
        }
    }
}
|
||||
|
||||
// Fast RMSNorm for small hidden dimensions.
//
// Perf fix: the original had every active thread redundantly re-sum the
// entire row (O(hidden_dim^2) work per row). This version uses the same
// parallel-reduction pattern as rms_norm(), via the shared partial_sums
// buffer. Intended for hidden_dim <= 128 (one output element per thread);
// a larger row still reduces correctly but, as before, only its first 128
// elements are written.
@compute @workgroup_size(128, 1, 1)
fn rms_norm_small(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * hidden_dim;

    // Each thread accumulates the squares of its strided share of the row.
    var thread_sq = 0.0f;
    for (var i = thread_id; i < hidden_dim; i += 128u) {
        let x = input[offset + i];
        thread_sq += x * x;
    }
    partial_sums[thread_id] = thread_sq;
    workgroupBarrier();

    // Tree reduction over the 128 partial sums. Barriers are outside the
    // guard so control flow stays uniform across the workgroup.
    for (var stride = 64u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }

    // Same formula as the original: eps added inside the sqrt.
    let rms = sqrt(partial_sums[0] / f32(hidden_dim) + eps);

    // One output element per thread.
    if (thread_id < hidden_dim) {
        let x = input[offset + thread_id];
        output[offset + thread_id] = x / rms * weight[thread_id];
    }
}
|
||||
288
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/softmax.wgsl
vendored
Normal file
288
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/softmax.wgsl
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
// Softmax Shader for WebGPU WASM
//
// Numerically stable softmax: y = exp(x - max(x)) / sum(exp(x - max(x)))
// Uses parallel reduction for finding max and computing sum.
//
// Variants:
// - Full softmax for attention scores
// - Temperature-scaled softmax for sampling
// - Top-k softmax for efficient sampling

// NOTE(review): MAX_SEQ_LEN is not referenced by any kernel in this file.
const MAX_SEQ_LEN: u32 = 8192u;

// Per-dispatch parameters supplied by the host.
struct SoftmaxUniforms {
    dim: u32, // Dimension to reduce over
    batch_size: u32, // Number of rows
    temperature: f32, // Scaling factor (1.0 for standard)
    top_k: u32, // 0 for full softmax, >0 for top-k
                // NOTE(review): top_k is declared but unused by the kernels
                // visible in this file — confirm against the host side.
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> uniforms: SoftmaxUniforms;

// Shared memory for reductions (one slot per thread, workgroup size 256).
var<workgroup> reduction_buf: array<f32, 256>;
|
||||
|
||||
// Row-wise numerically stable softmax, one workgroup per row.
//
// Three phases: (1) reduce the row max, (2) reduce the sum of
// exp(x/T - max), (3) write exp(x/T - max) / sum. Temperature scaling is
// applied before the max subtraction throughout.
@compute @workgroup_size(256, 1, 1)
fn softmax(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    let row = group_id.x;
    let tid = local_id.x;
    let base = row * dim;

    // Phase 1: per-thread max over a strided share of the row.
    var running_max = -1e10f;
    for (var i = tid; i < dim; i += 256u) {
        running_max = max(running_max, input[base + i] / temperature);
    }
    reduction_buf[tid] = running_max;
    workgroupBarrier();

    // Tree-reduce the 256 partial maxima.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] = max(reduction_buf[tid], reduction_buf[tid + stride]);
        }
        workgroupBarrier();
    }
    let row_max = reduction_buf[0];
    workgroupBarrier();

    // Phase 2: per-thread exp-sum, then tree-reduce.
    var running_sum = 0.0f;
    for (var i = tid; i < dim; i += 256u) {
        running_sum += exp(input[base + i] / temperature - row_max);
    }
    reduction_buf[tid] = running_sum;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] += reduction_buf[tid + stride];
        }
        workgroupBarrier();
    }
    let inv_sum = 1.0f / reduction_buf[0];
    workgroupBarrier();

    // Phase 3: write normalized probabilities.
    for (var i = tid; i < dim; i += 256u) {
        output[base + i] = exp(input[base + i] / temperature - row_max) * inv_sum;
    }
}
|
||||
|
||||
// In-place softmax (input and output point to same buffer).
// Identical to softmax() except every read comes from `output`, and the
// intermediate exp values are stored back into the row between passes so
// the final pass is a single multiply.
@compute @workgroup_size(256, 1, 1)
fn softmax_inplace(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    // One workgroup per row.
    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * dim;

    let elements_per_thread = (dim + 255u) / 256u;

    // Find max (reading from output, since the data lives there)
    var thread_max = -1e10f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < dim) {
            thread_max = max(thread_max, output[offset + idx] / temperature);
        }
    }

    reduction_buf[thread_id] = thread_max;
    workgroupBarrier();

    // Tree reduction for the row max.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            reduction_buf[thread_id] = max(reduction_buf[thread_id], reduction_buf[thread_id + stride]);
        }
        workgroupBarrier();
    }

    let max_val = reduction_buf[0];
    workgroupBarrier();

    // Compute exp and sum; each lane overwrites only its own indices, so
    // the in-place store is race-free.
    var thread_sum = 0.0f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < dim) {
            let x = exp(output[offset + idx] / temperature - max_val);
            output[offset + idx] = x; // Store intermediate exp value
            thread_sum += x;
        }
    }

    reduction_buf[thread_id] = thread_sum;
    workgroupBarrier();

    // Tree reduction for the exp-sum.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            reduction_buf[thread_id] += reduction_buf[thread_id + stride];
        }
        workgroupBarrier();
    }

    let inv_sum = 1.0f / reduction_buf[0];
    workgroupBarrier();

    // Normalize in place
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < dim) {
            output[offset + idx] *= inv_sum;
        }
    }
}
|
||||
|
||||
// Softmax for small rows (dim <= 256): one element per thread, with the
// max and sum scans done sequentially over shared memory by every thread.
@compute @workgroup_size(256, 1, 1)
fn softmax_small(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    let row = group_id.x;
    let tid = local_id.x;
    let base = row * dim;

    // Stage this thread's (temperature-scaled) value; lanes past the end of
    // the row contribute the -1e10 sentinel so they never win the max scan.
    var val = -1e10f;
    if (tid < dim) {
        val = input[base + tid] / temperature;
    }
    reduction_buf[tid] = val;
    workgroupBarrier();

    // Sequential max scan over the staged row.
    var row_max = val;
    for (var i = 0u; i < dim; i++) {
        row_max = max(row_max, reduction_buf[i]);
    }
    workgroupBarrier();

    // Stage exp(x - max); inactive lanes contribute zero to the sum.
    var e = 0.0f;
    if (tid < dim) {
        e = exp(val - row_max);
    }
    reduction_buf[tid] = e;
    workgroupBarrier();

    // Sequential sum scan.
    var denom = 0.0f;
    for (var i = 0u; i < dim; i++) {
        denom += reduction_buf[i];
    }

    // Normalized probability for this thread's element.
    if (tid < dim) {
        output[base + tid] = e / denom;
    }
}
|
||||
|
||||
// Log softmax for numerical stability in loss computation:
// log(softmax(x)) = x - (log(sum(exp(x - max))) + max)
@compute @workgroup_size(256, 1, 1)
fn log_softmax(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    // One workgroup per row.
    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * dim;

    let elements_per_thread = (dim + 255u) / 256u;

    // Find max (per-thread partial max over a strided share of the row)
    var thread_max = -1e10f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < dim) {
            thread_max = max(thread_max, input[offset + idx] / temperature);
        }
    }

    reduction_buf[thread_id] = thread_max;
    workgroupBarrier();

    // Tree reduction for the row max.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            reduction_buf[thread_id] = max(reduction_buf[thread_id], reduction_buf[thread_id + stride]);
        }
        workgroupBarrier();
    }

    let max_val = reduction_buf[0];
    workgroupBarrier();

    // Compute log-sum-exp (shifted by max_val for stability)
    var thread_sum = 0.0f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < dim) {
            thread_sum += exp(input[offset + idx] / temperature - max_val);
        }
    }

    reduction_buf[thread_id] = thread_sum;
    workgroupBarrier();

    // Tree reduction for the shifted exp-sum.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            reduction_buf[thread_id] += reduction_buf[thread_id + stride];
        }
        workgroupBarrier();
    }

    // Undo the shift: log(sum(exp(x - max))) + max = log_sum_exp(x).
    let log_sum = log(reduction_buf[0]) + max_val;
    workgroupBarrier();

    // Compute log softmax: log(softmax(x)) = x - log_sum_exp(x)
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < dim) {
            output[offset + idx] = input[offset + idx] / temperature - log_sum;
        }
    }
}
|
||||
Reference in New Issue
Block a user