Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions


@@ -0,0 +1,469 @@
//! GPU Buffer Management for WebGPU WASM
//!
//! This module provides buffer abstractions for GPU memory management
//! in the browser WebGPU environment.
use js_sys::{Float32Array, Uint8Array};
use std::cell::RefCell;
use wasm_bindgen::prelude::*;
/// Buffer usage flags
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuBufferUsage {
/// Can be mapped for reading
#[wasm_bindgen(skip)]
pub map_read: bool,
/// Can be mapped for writing
#[wasm_bindgen(skip)]
pub map_write: bool,
/// Can be used as copy source
#[wasm_bindgen(skip)]
pub copy_src: bool,
/// Can be used as copy destination
#[wasm_bindgen(skip)]
pub copy_dst: bool,
/// Can be used as storage buffer
#[wasm_bindgen(skip)]
pub storage: bool,
/// Can be used as uniform buffer
#[wasm_bindgen(skip)]
pub uniform: bool,
}
#[wasm_bindgen]
impl GpuBufferUsage {
/// Create storage buffer usage (read/write compute)
#[wasm_bindgen(js_name = storage)]
pub fn new_storage() -> Self {
Self {
storage: true,
copy_dst: true,
copy_src: true,
..Default::default()
}
}
/// Create uniform buffer usage
#[wasm_bindgen(js_name = uniform)]
pub fn new_uniform() -> Self {
Self {
uniform: true,
copy_dst: true,
..Default::default()
}
}
/// Create staging buffer for upload
#[wasm_bindgen(js_name = stagingUpload)]
pub fn staging_upload() -> Self {
Self {
map_write: true,
copy_src: true,
..Default::default()
}
}
/// Create staging buffer for download
#[wasm_bindgen(js_name = stagingDownload)]
pub fn staging_download() -> Self {
Self {
map_read: true,
copy_dst: true,
..Default::default()
}
}
/// Create read-only storage buffer
#[wasm_bindgen(js_name = storageReadOnly)]
pub fn storage_read_only() -> Self {
Self {
storage: true,
copy_dst: true,
..Default::default()
}
}
/// Convert to WebGPU usage flags (as raw u32)
///
/// WebGPU buffer usage flags:
/// - MAP_READ = 0x0001
/// - MAP_WRITE = 0x0002
/// - COPY_SRC = 0x0004
/// - COPY_DST = 0x0008
/// - INDEX = 0x0010
/// - VERTEX = 0x0020
/// - UNIFORM = 0x0040
/// - STORAGE = 0x0080
/// - INDIRECT = 0x0100
/// - QUERY_RESOLVE = 0x0200
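///
/// For example, `GpuBufferUsage::new_storage()` (exposed to JS as
/// `GpuBufferUsage.storage()`) maps to
/// `STORAGE | COPY_DST | COPY_SRC = 0x0080 | 0x0008 | 0x0004 = 0x008C`.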
pub fn to_u32(&self) -> u32 {
let mut flags = 0u32;
if self.map_read {
flags |= 0x0001;
}
if self.map_write {
flags |= 0x0002;
}
if self.copy_src {
flags |= 0x0004;
}
if self.copy_dst {
flags |= 0x0008;
}
if self.uniform {
flags |= 0x0040;
}
if self.storage {
flags |= 0x0080;
}
flags
}
#[wasm_bindgen(getter, js_name = mapRead)]
pub fn get_map_read(&self) -> bool {
self.map_read
}
#[wasm_bindgen(setter, js_name = mapRead)]
pub fn set_map_read(&mut self, value: bool) {
self.map_read = value;
}
#[wasm_bindgen(getter, js_name = mapWrite)]
pub fn get_map_write(&self) -> bool {
self.map_write
}
#[wasm_bindgen(setter, js_name = mapWrite)]
pub fn set_map_write(&mut self, value: bool) {
self.map_write = value;
}
#[wasm_bindgen(getter, js_name = copySrc)]
pub fn get_copy_src(&self) -> bool {
self.copy_src
}
#[wasm_bindgen(setter, js_name = copySrc)]
pub fn set_copy_src(&mut self, value: bool) {
self.copy_src = value;
}
#[wasm_bindgen(getter, js_name = copyDst)]
pub fn get_copy_dst(&self) -> bool {
self.copy_dst
}
#[wasm_bindgen(setter, js_name = copyDst)]
pub fn set_copy_dst(&mut self, value: bool) {
self.copy_dst = value;
}
#[wasm_bindgen(getter, js_name = isStorage)]
pub fn get_storage(&self) -> bool {
self.storage
}
#[wasm_bindgen(setter, js_name = isStorage)]
pub fn set_storage(&mut self, value: bool) {
self.storage = value;
}
#[wasm_bindgen(getter, js_name = isUniform)]
pub fn get_uniform(&self) -> bool {
self.uniform
}
#[wasm_bindgen(setter, js_name = isUniform)]
pub fn set_uniform(&mut self, value: bool) {
self.uniform = value;
}
}
/// GPU buffer handle
///
/// Wraps a WebGPU buffer with metadata for safe operations.
#[wasm_bindgen]
pub struct GpuBuffer {
/// Internal buffer handle (web_sys::GpuBuffer when on wasm32)
#[cfg(target_arch = "wasm32")]
buffer: web_sys::GpuBuffer,
/// Placeholder for non-wasm32 builds
#[cfg(not(target_arch = "wasm32"))]
buffer: Vec<u8>,
/// Buffer size in bytes
size: usize,
/// Buffer usage flags
usage: GpuBufferUsage,
/// Optional label for debugging
label: Option<String>,
}
#[wasm_bindgen]
impl GpuBuffer {
/// Get buffer size in bytes
#[wasm_bindgen(getter)]
pub fn size(&self) -> usize {
self.size
}
/// Get buffer label
#[wasm_bindgen(getter)]
pub fn label(&self) -> Option<String> {
self.label.clone()
}
/// Check if buffer supports mapping for read
#[wasm_bindgen(getter, js_name = canMapRead)]
pub fn can_map_read(&self) -> bool {
self.usage.map_read
}
/// Check if buffer supports mapping for write
#[wasm_bindgen(getter, js_name = canMapWrite)]
pub fn can_map_write(&self) -> bool {
self.usage.map_write
}
/// Get size as number of f32 elements
#[wasm_bindgen(js_name = sizeAsF32)]
pub fn size_as_f32(&self) -> usize {
self.size / 4
}
/// Get the raw web_sys buffer (for advanced usage)
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen(getter, js_name = rawBuffer)]
pub fn raw_buffer(&self) -> web_sys::GpuBuffer {
self.buffer.clone()
}
}
impl GpuBuffer {
/// Create a new GPU buffer (internal constructor)
#[cfg(target_arch = "wasm32")]
pub(crate) fn new(
buffer: web_sys::GpuBuffer,
size: usize,
usage: GpuBufferUsage,
label: Option<String>,
) -> Self {
Self {
buffer,
size,
usage,
label,
}
}
/// Create a new GPU buffer (non-wasm32 placeholder)
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn new(size: usize, usage: GpuBufferUsage, label: Option<String>) -> Self {
Self {
buffer: vec![0u8; size],
size,
usage,
label,
}
}
/// Get internal buffer reference
#[cfg(target_arch = "wasm32")]
pub(crate) fn inner(&self) -> &web_sys::GpuBuffer {
&self.buffer
}
}
/// Staging buffer pool for efficient CPU<->GPU transfers
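///
/// A small sketch: `StagingBufferPool::new(8)` keeps up to 8 reusable buffers
/// per direction, and `clear()` drops them and resets `totalAllocated` to 0.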
#[wasm_bindgen]
pub struct StagingBufferPool {
/// Pool of upload staging buffers
upload_pool: RefCell<Vec<GpuBuffer>>,
/// Pool of download staging buffers
download_pool: RefCell<Vec<GpuBuffer>>,
/// Maximum buffers per pool
max_per_pool: usize,
/// Total bytes allocated
total_allocated: RefCell<usize>,
}
#[wasm_bindgen]
impl StagingBufferPool {
/// Create a new staging buffer pool
#[wasm_bindgen(constructor)]
pub fn new(max_per_pool: usize) -> Self {
Self {
upload_pool: RefCell::new(Vec::with_capacity(max_per_pool)),
download_pool: RefCell::new(Vec::with_capacity(max_per_pool)),
max_per_pool,
total_allocated: RefCell::new(0),
}
}
/// Get the number of upload buffers in pool
#[wasm_bindgen(getter, js_name = uploadBufferCount)]
pub fn upload_buffer_count(&self) -> usize {
self.upload_pool.borrow().len()
}
/// Get the number of download buffers in pool
#[wasm_bindgen(getter, js_name = downloadBufferCount)]
pub fn download_buffer_count(&self) -> usize {
self.download_pool.borrow().len()
}
/// Get total bytes allocated
#[wasm_bindgen(getter, js_name = totalAllocated)]
pub fn total_allocated(&self) -> usize {
*self.total_allocated.borrow()
}
/// Clear all pooled buffers
#[wasm_bindgen]
pub fn clear(&self) {
self.upload_pool.borrow_mut().clear();
self.download_pool.borrow_mut().clear();
*self.total_allocated.borrow_mut() = 0;
}
}
/// Tensor descriptor for buffer allocation
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct TensorDescriptor {
/// Shape dimensions
shape: Vec<u32>,
/// Data type (0=f32, 1=f16, 2=i32, 3=u8)
dtype: u8,
}
#[wasm_bindgen]
impl TensorDescriptor {
/// Create tensor descriptor for a matrix
#[wasm_bindgen(js_name = matrix)]
pub fn matrix(rows: u32, cols: u32) -> Self {
Self {
shape: vec![rows, cols],
dtype: 0, // f32
}
}
/// Create tensor descriptor for a vector
#[wasm_bindgen(js_name = vector)]
pub fn vector(len: u32) -> Self {
Self {
shape: vec![len],
dtype: 0,
}
}
/// Create tensor descriptor with arbitrary shape
#[wasm_bindgen(constructor)]
pub fn new(shape: Vec<u32>, dtype: u8) -> Self {
Self { shape, dtype }
}
/// Get total number of elements
#[wasm_bindgen(js_name = numElements)]
pub fn num_elements(&self) -> usize {
self.shape.iter().map(|&d| d as usize).product()
}
/// Get size in bytes
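/// (`num_elements * element_size`; e.g. a 1024x768 f32 matrix occupies
/// 1024 * 768 * 4 = 3,145,728 bytes)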
#[wasm_bindgen(js_name = sizeBytes)]
pub fn size_bytes(&self) -> usize {
let element_size = match self.dtype {
0 => 4, // f32
1 => 2, // f16
2 => 4, // i32
3 => 1, // u8
_ => 4, // default to f32
};
self.num_elements() * element_size
}
/// Get shape dimensions
#[wasm_bindgen(getter)]
pub fn shape(&self) -> Vec<u32> {
self.shape.clone()
}
/// Get data type
#[wasm_bindgen(getter)]
pub fn dtype(&self) -> u8 {
self.dtype
}
/// Get number of dimensions
#[wasm_bindgen(getter)]
pub fn ndim(&self) -> usize {
self.shape.len()
}
}
/// Helper functions for creating typed arrays from GPU buffers
#[wasm_bindgen]
pub struct BufferHelpers;
#[wasm_bindgen]
impl BufferHelpers {
/// Create a Float32Array view from a Uint8Array
#[wasm_bindgen(js_name = asFloat32Array)]
pub fn as_float32_array(data: &Uint8Array) -> Float32Array {
Float32Array::new(&data.buffer())
}
/// Calculate aligned size for GPU buffers (must be multiple of 4)
#[wasm_bindgen(js_name = alignedSize)]
pub fn aligned_size(size: usize) -> usize {
(size + 3) & !3
}
/// Calculate workgroup count for a given dimension
#[wasm_bindgen(js_name = workgroupCount)]
pub fn workgroup_count(total: u32, workgroup_size: u32) -> u32 {
(total + workgroup_size - 1) / workgroup_size
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_buffer_usage() {
let storage = GpuBufferUsage::new_storage();
assert!(storage.storage);
assert!(storage.copy_dst);
assert!(storage.copy_src);
assert!(!storage.uniform);
}
#[test]
fn test_tensor_descriptor() {
let matrix = TensorDescriptor::matrix(1024, 768);
assert_eq!(matrix.num_elements(), 1024 * 768);
assert_eq!(matrix.size_bytes(), 1024 * 768 * 4);
assert_eq!(matrix.ndim(), 2);
}
#[test]
fn test_aligned_size() {
assert_eq!(BufferHelpers::aligned_size(0), 0);
assert_eq!(BufferHelpers::aligned_size(1), 4);
assert_eq!(BufferHelpers::aligned_size(4), 4);
assert_eq!(BufferHelpers::aligned_size(5), 8);
}
#[test]
fn test_workgroup_count() {
assert_eq!(BufferHelpers::workgroup_count(1000, 256), 4);
assert_eq!(BufferHelpers::workgroup_count(256, 256), 1);
assert_eq!(BufferHelpers::workgroup_count(257, 256), 2);
}
}


@@ -0,0 +1,882 @@
//! WebGPU Compute Context and Pipelines
//!
//! This module provides the core WebGPU compute functionality for WASM,
//! including context initialization, pipeline creation, and kernel execution.
//!
//! Note: WebGPU bindings use JavaScript interop via js_sys/Reflect since
//! web-sys WebGPU bindings are still unstable.
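//!
//! A minimal, browser-only usage sketch (error handling elided):
//!
//! ```ignore
//! let engine = WebGpuInference::init().await?;
//! let c = engine.matmul(&a, &b, m, n, k).await?; // c.len() == (m * n) as usize
//! ```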
use js_sys::{Array, Float32Array, Object, Promise, Reflect};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use super::{shaders, AdapterInfo, AttentionConfig};
/// Check if WebGPU is available in this browser
pub async fn is_webgpu_available() -> bool {
#[cfg(target_arch = "wasm32")]
{
if let Some(gpu) = get_gpu_object() {
return !gpu.is_undefined() && !gpu.is_null();
}
false
}
#[cfg(not(target_arch = "wasm32"))]
false
}
/// Get GPU adapter information if available
pub async fn get_gpu_info() -> Option<AdapterInfo> {
#[cfg(target_arch = "wasm32")]
{
let gpu = get_gpu_object()?;
// Request adapter
let options = Object::new();
let _ = Reflect::set(
&options,
&"powerPreference".into(),
&"high-performance".into(),
);
let adapter_promise = call_method(&gpu, "requestAdapter", &[options.into()]).ok()?;
let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>().ok()?)
.await
.ok()?;
if adapter.is_null() || adapter.is_undefined() {
return None;
}
// Get adapter info via requestAdapterInfo()
let info_promise = call_method(&adapter, "requestAdapterInfo", &[]).ok()?;
let info = JsFuture::from(info_promise.dyn_into::<Promise>().ok()?)
.await
.ok()?;
// Extract limits
let limits = Reflect::get(&adapter, &"limits".into()).ok()?;
Some(AdapterInfo {
vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
device_type: get_string_prop(&info, "device").unwrap_or_else(|| "unknown".to_string()),
backend: "WebGPU".to_string(),
max_buffer_size: get_number_prop(&limits, "maxBufferSize")
.unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
.unwrap_or(256.0) as u32,
})
}
#[cfg(not(target_arch = "wasm32"))]
None
}
// ============================================================================
// Helper Functions
// ============================================================================
#[cfg(target_arch = "wasm32")]
fn get_gpu_object() -> Option<JsValue> {
let window = web_sys::window()?;
let navigator = Reflect::get(&window, &"navigator".into()).ok()?;
let gpu = Reflect::get(&navigator, &"gpu".into()).ok()?;
if gpu.is_undefined() || gpu.is_null() {
None
} else {
Some(gpu)
}
}
#[cfg(target_arch = "wasm32")]
fn get_string_prop(obj: &JsValue, key: &str) -> Option<String> {
Reflect::get(obj, &key.into())
.ok()
.and_then(|v| v.as_string())
}
#[cfg(target_arch = "wasm32")]
fn get_number_prop(obj: &JsValue, key: &str) -> Option<f64> {
Reflect::get(obj, &key.into()).ok().and_then(|v| v.as_f64())
}
#[cfg(target_arch = "wasm32")]
fn call_method(obj: &JsValue, method: &str, args: &[JsValue]) -> Result<JsValue, JsValue> {
let func = Reflect::get(obj, &method.into())?.dyn_into::<js_sys::Function>()?;
let args_array = Array::new();
for arg in args {
args_array.push(arg);
}
Reflect::apply(&func, obj, &args_array)
}
// ============================================================================
// WebGPU Context
// ============================================================================
/// WebGPU context holding device and queue references
#[wasm_bindgen]
pub struct WebGpuContext {
/// GPU device object (JsValue wrapper)
#[cfg(target_arch = "wasm32")]
device: JsValue,
/// Command queue object
#[cfg(target_arch = "wasm32")]
queue: JsValue,
/// Placeholder for non-wasm builds
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData<()>,
/// Adapter information
adapter_info: AdapterInfo,
}
#[wasm_bindgen]
impl WebGpuContext {
/// Initialize WebGPU context
#[wasm_bindgen(js_name = init)]
pub async fn init() -> Result<WebGpuContext, JsValue> {
#[cfg(target_arch = "wasm32")]
{
let gpu = get_gpu_object().ok_or_else(|| JsValue::from_str("WebGPU not available"))?;
// Request adapter with high performance preference
let adapter_options = Object::new();
Reflect::set(
&adapter_options,
&"powerPreference".into(),
&"high-performance".into(),
)?;
let adapter_promise = call_method(&gpu, "requestAdapter", &[adapter_options.into()])?;
let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>()?).await?;
if adapter.is_null() || adapter.is_undefined() {
return Err(JsValue::from_str("No suitable GPU adapter found"));
}
// Get adapter info
let info_promise = call_method(&adapter, "requestAdapterInfo", &[])?;
let info = JsFuture::from(info_promise.dyn_into::<Promise>()?).await?;
let limits = Reflect::get(&adapter, &"limits".into())?;
let adapter_info = AdapterInfo {
vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
device_type: get_string_prop(&info, "device")
.unwrap_or_else(|| "unknown".to_string()),
backend: "WebGPU".to_string(),
max_buffer_size: get_number_prop(&limits, "maxBufferSize")
.unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
.unwrap_or(256.0) as u32,
};
// Request device
let device_descriptor = Object::new();
Reflect::set(&device_descriptor, &"label".into(), &"ruvllm-wasm".into())?;
let device_promise =
call_method(&adapter, "requestDevice", &[device_descriptor.into()])?;
let device = JsFuture::from(device_promise.dyn_into::<Promise>()?).await?;
// Get queue
let queue = Reflect::get(&device, &"queue".into())?;
Ok(WebGpuContext {
device,
queue,
adapter_info,
})
}
#[cfg(not(target_arch = "wasm32"))]
Err(JsValue::from_str("WebGPU only available in WASM"))
}
/// Get adapter information
#[wasm_bindgen(getter, js_name = adapterInfo)]
pub fn adapter_info(&self) -> AdapterInfo {
self.adapter_info.clone()
}
/// Check if context is valid
#[wasm_bindgen(getter, js_name = isValid)]
pub fn is_valid(&self) -> bool {
#[cfg(target_arch = "wasm32")]
{
!self.device.is_undefined() && !self.device.is_null()
}
#[cfg(not(target_arch = "wasm32"))]
false
}
/// Create a GPU buffer
#[cfg(target_arch = "wasm32")]
fn create_buffer_internal(
&self,
size: usize,
usage: u32,
label: Option<&str>,
) -> Result<JsValue, JsValue> {
let descriptor = Object::new();
Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?;
Reflect::set(
&descriptor,
&"usage".into(),
&JsValue::from_f64(usage as f64),
)?;
if let Some(lbl) = label {
Reflect::set(&descriptor, &"label".into(), &lbl.into())?;
}
call_method(&self.device, "createBuffer", &[descriptor.into()])
}
/// Write data to GPU buffer
#[cfg(target_arch = "wasm32")]
fn write_buffer_internal(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
let data_array = Float32Array::from(data);
call_method(
&self.queue,
"writeBuffer",
&[
buffer.clone(),
JsValue::from_f64(0.0),
data_array.buffer().into(),
],
)?;
Ok(())
}
}
// ============================================================================
// Compute Pipeline
// ============================================================================
/// Compute pipeline handle
#[wasm_bindgen]
pub struct ComputePipeline {
#[cfg(target_arch = "wasm32")]
pipeline: JsValue,
#[cfg(target_arch = "wasm32")]
bind_group_layout: JsValue,
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData<()>,
entry_point: String,
workgroup_size: [u32; 3],
}
#[wasm_bindgen]
impl ComputePipeline {
/// Get the entry point name
#[wasm_bindgen(getter, js_name = entryPoint)]
pub fn entry_point(&self) -> String {
self.entry_point.clone()
}
/// Get the workgroup size
#[wasm_bindgen(getter, js_name = workgroupSize)]
pub fn workgroup_size(&self) -> Vec<u32> {
self.workgroup_size.to_vec()
}
}
// ============================================================================
// WebGPU Inference Engine
// ============================================================================
/// WebGPU inference engine for LLM operations
#[wasm_bindgen]
pub struct WebGpuInference {
#[cfg(target_arch = "wasm32")]
device: JsValue,
#[cfg(target_arch = "wasm32")]
queue: JsValue,
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData<()>,
adapter_info: AdapterInfo,
}
#[wasm_bindgen]
impl WebGpuInference {
/// Check if WebGPU is available
#[wasm_bindgen(js_name = isAvailable)]
pub async fn is_available() -> bool {
is_webgpu_available().await
}
/// Initialize WebGPU inference engine
#[wasm_bindgen(js_name = init)]
pub async fn init() -> Result<WebGpuInference, JsValue> {
let ctx = WebGpuContext::init().await?;
Ok(WebGpuInference {
#[cfg(target_arch = "wasm32")]
device: ctx.device,
#[cfg(target_arch = "wasm32")]
queue: ctx.queue,
#[cfg(not(target_arch = "wasm32"))]
_phantom: std::marker::PhantomData,
adapter_info: ctx.adapter_info,
})
}
/// Get adapter information
#[wasm_bindgen(getter, js_name = adapterInfo)]
pub fn adapter_info(&self) -> AdapterInfo {
self.adapter_info.clone()
}
/// Perform matrix multiplication: C = A * B
///
/// Args:
/// a: Matrix A as flat f32 array (M x K)
/// b: Matrix B as flat f32 array (K x N)
/// m: Number of rows in A
/// n: Number of columns in B
/// k: Shared dimension
///
/// Returns: Result matrix C as f32 array (M x N)
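///
/// A minimal sketch (2x2 case, mirroring the CPU-fallback test below):
///
/// ```ignore
/// let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2, row-major
/// let b = vec![5.0, 6.0, 7.0, 8.0]; // 2x2, row-major
/// let c = engine.matmul(&a, &b, 2, 2, 2).await?;
/// assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
/// ```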
#[wasm_bindgen]
pub async fn matmul(
&self,
a: &[f32],
b: &[f32],
m: u32,
n: u32,
k: u32,
) -> Result<Vec<f32>, JsValue> {
// Validate dimensions
let expected_a = (m as usize) * (k as usize);
let expected_b = (k as usize) * (n as usize);
if a.len() != expected_a {
return Err(JsValue::from_str(&format!(
"Matrix A dimension mismatch: expected {}, got {}",
expected_a,
a.len()
)));
}
if b.len() != expected_b {
return Err(JsValue::from_str(&format!(
"Matrix B dimension mismatch: expected {}, got {}",
expected_b,
b.len()
)));
}
#[cfg(target_arch = "wasm32")]
{
let output_size = (m as usize) * (n as usize);
// GPU buffer usage flags
const STORAGE: u32 = 0x80; // GPUBufferUsage.STORAGE
const COPY_SRC: u32 = 0x04; // GPUBufferUsage.COPY_SRC
const COPY_DST: u32 = 0x08; // GPUBufferUsage.COPY_DST
const MAP_READ: u32 = 0x01; // GPUBufferUsage.MAP_READ
const UNIFORM: u32 = 0x40; // GPUBufferUsage.UNIFORM
// Create buffers
let buffer_a = self.create_buffer(a.len() * 4, STORAGE | COPY_DST, Some("matmul_a"))?;
let buffer_b = self.create_buffer(b.len() * 4, STORAGE | COPY_DST, Some("matmul_b"))?;
let buffer_c =
self.create_buffer(output_size * 4, STORAGE | COPY_SRC, Some("matmul_c"))?;
// Create uniform buffer for dimensions.
// The shader's Uniforms struct declares M, N, K as u32 (and alpha as f32),
// so the dimensions are bit-cast into the f32 staging array to preserve
// their raw u32 bit patterns when written to the GPU.
let uniform_data: [f32; 4] = [
f32::from_bits(m),
f32::from_bits(n),
f32::from_bits(k),
1.0, // alpha
];
let uniform_buffer =
self.create_buffer(16, UNIFORM | COPY_DST, Some("matmul_uniforms"))?;
// Write data to buffers
self.write_buffer(&buffer_a, a)?;
self.write_buffer(&buffer_b, b)?;
self.write_buffer(&uniform_buffer, &uniform_data)?;
// Create shader module
let shader_desc = Object::new();
Reflect::set(&shader_desc, &"code".into(), &shaders::MATMUL_SHADER.into())?;
let shader_module =
call_method(&self.device, "createShaderModule", &[shader_desc.into()])?;
// Create bind group layout
let layout_entries = Array::new();
// Storage buffer entries (A, B, C)
for i in 0..3u32 {
let entry = Object::new();
Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
Reflect::set(&entry, &"visibility".into(), &JsValue::from_f64(4.0))?; // COMPUTE stage
let buffer_layout = Object::new();
Reflect::set(
&buffer_layout,
&"type".into(),
&(if i < 2 {
"read-only-storage"
} else {
"storage"
})
.into(),
)?;
Reflect::set(&entry, &"buffer".into(), &buffer_layout)?;
layout_entries.push(&entry);
}
// Uniform buffer entry
let uniform_entry = Object::new();
Reflect::set(&uniform_entry, &"binding".into(), &JsValue::from_f64(3.0))?;
Reflect::set(
&uniform_entry,
&"visibility".into(),
&JsValue::from_f64(4.0),
)?;
let uniform_layout = Object::new();
Reflect::set(&uniform_layout, &"type".into(), &"uniform".into())?;
Reflect::set(&uniform_entry, &"buffer".into(), &uniform_layout)?;
layout_entries.push(&uniform_entry);
let layout_desc = Object::new();
Reflect::set(&layout_desc, &"entries".into(), &layout_entries)?;
let bind_group_layout =
call_method(&self.device, "createBindGroupLayout", &[layout_desc.into()])?;
// Create pipeline layout
let layouts = Array::new();
layouts.push(&bind_group_layout);
let pipeline_layout_desc = Object::new();
Reflect::set(&pipeline_layout_desc, &"bindGroupLayouts".into(), &layouts)?;
let pipeline_layout = call_method(
&self.device,
"createPipelineLayout",
&[pipeline_layout_desc.into()],
)?;
// Create compute pipeline
let compute_stage = Object::new();
Reflect::set(&compute_stage, &"module".into(), &shader_module)?;
Reflect::set(&compute_stage, &"entryPoint".into(), &"main".into())?;
let pipeline_desc = Object::new();
Reflect::set(&pipeline_desc, &"layout".into(), &pipeline_layout)?;
Reflect::set(&pipeline_desc, &"compute".into(), &compute_stage)?;
let pipeline = call_method(
&self.device,
"createComputePipeline",
&[pipeline_desc.into()],
)?;
// Create bind group
let bind_entries = Array::new();
for (i, buffer) in [&buffer_a, &buffer_b, &buffer_c, &uniform_buffer]
.iter()
.enumerate()
{
let entry = Object::new();
Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
let resource = Object::new();
Reflect::set(&resource, &"buffer".into(), buffer)?;
Reflect::set(&entry, &"resource".into(), &resource)?;
bind_entries.push(&entry);
}
let bind_group_desc = Object::new();
Reflect::set(&bind_group_desc, &"layout".into(), &bind_group_layout)?;
Reflect::set(&bind_group_desc, &"entries".into(), &bind_entries)?;
let bind_group =
call_method(&self.device, "createBindGroup", &[bind_group_desc.into()])?;
// Create command encoder
let encoder_desc = Object::new();
let encoder =
call_method(&self.device, "createCommandEncoder", &[encoder_desc.into()])?;
// Begin compute pass
let pass_desc = Object::new();
let pass = call_method(&encoder, "beginComputePass", &[pass_desc.into()])?;
// Set pipeline and bind group
call_method(&pass, "setPipeline", &[pipeline.clone()])?;
call_method(
&pass,
"setBindGroup",
&[JsValue::from_f64(0.0), bind_group.clone()],
)?;
// Dispatch workgroups (16x16 tile size)
let workgroups_x = (m + 15) / 16;
let workgroups_y = (n + 15) / 16;
call_method(
&pass,
"dispatchWorkgroups",
&[
JsValue::from_f64(workgroups_x as f64),
JsValue::from_f64(workgroups_y as f64),
],
)?;
call_method(&pass, "end", &[])?;
// Create staging buffer for readback
let staging =
self.create_buffer(output_size * 4, MAP_READ | COPY_DST, Some("staging"))?;
// Copy result to staging
call_method(
&encoder,
"copyBufferToBuffer",
&[
buffer_c.clone(),
JsValue::from_f64(0.0),
staging.clone(),
JsValue::from_f64(0.0),
JsValue::from_f64((output_size * 4) as f64),
],
)?;
// Submit commands
let command_buffer = call_method(&encoder, "finish", &[])?;
let commands = Array::new();
commands.push(&command_buffer);
call_method(&self.queue, "submit", &[commands.into()])?;
// Map staging buffer and read result
let map_promise = call_method(&staging, "mapAsync", &[JsValue::from_f64(1.0)])?; // MAP_READ = 1
JsFuture::from(map_promise.dyn_into::<Promise>()?).await?;
let mapped_range = call_method(&staging, "getMappedRange", &[])?;
let data = Float32Array::new(&mapped_range).to_vec();
call_method(&staging, "unmap", &[])?;
Ok(data)
}
#[cfg(not(target_arch = "wasm32"))]
{
// CPU fallback - naive implementation
let mut c = vec![0.0f32; (m as usize) * (n as usize)];
for i in 0..m as usize {
for j in 0..n as usize {
let mut sum = 0.0f32;
for l in 0..k as usize {
sum += a[i * k as usize + l] * b[l * n as usize + j];
}
c[i * n as usize + j] = sum;
}
}
Ok(c)
}
}
/// Perform attention: Output = softmax(Q * K^T / sqrt(d_k)) * V
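///
/// Q, K, and V are flat row-major `(seq_len, num_heads * head_dim)` arrays;
/// the output has the same shape as Q.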
#[wasm_bindgen]
pub async fn attention(
&self,
q: &[f32],
k: &[f32],
v: &[f32],
config: &AttentionConfig,
) -> Result<Vec<f32>, JsValue> {
let hidden_dim = config.hidden_dim();
let expected_size = (config.seq_len as usize) * (hidden_dim as usize);
if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
return Err(JsValue::from_str(&format!(
"Attention tensor dimension mismatch: expected {}, got Q:{}, K:{}, V:{}",
expected_size,
q.len(),
k.len(),
v.len()
)));
}
// CPU fallback for attention; a GPU path would follow the same
// buffer/pipeline pattern as matmul() using shaders::ATTENTION_SHADER.
// Note: both the validation above and the fallback assume
// kv_seq_len == seq_len (self-attention).
self.attention_cpu(q, k, v, config)
}
/// CPU fallback for attention
fn attention_cpu(
&self,
q: &[f32],
k: &[f32],
v: &[f32],
config: &AttentionConfig,
) -> Result<Vec<f32>, JsValue> {
let seq_len = config.seq_len as usize;
let num_heads = config.num_heads as usize;
let head_dim = config.head_dim as usize;
let hidden_dim = num_heads * head_dim;
let scale = config.scale();
let mut output = vec![0.0f32; seq_len * hidden_dim];
// Process each head independently
for h in 0..num_heads {
for i in 0..seq_len {
// For this query position, compute attention to all key positions
let q_offset = i * hidden_dim + h * head_dim;
// Compute attention scores
let mut scores = vec![0.0f32; seq_len];
let mut max_score = f32::NEG_INFINITY;
for j in 0..seq_len {
// Causal masking
if config.causal && j > i {
scores[j] = f32::NEG_INFINITY;
continue;
}
let k_offset = j * hidden_dim + h * head_dim;
let mut score = 0.0f32;
for d in 0..head_dim {
score += q[q_offset + d] * k[k_offset + d];
}
score *= scale;
scores[j] = score;
if score > max_score {
max_score = score;
}
}
// Softmax
let mut sum = 0.0f32;
for j in 0..seq_len {
scores[j] = (scores[j] - max_score).exp();
sum += scores[j];
}
for j in 0..seq_len {
scores[j] /= sum;
}
// Compute weighted sum of values
let out_offset = i * hidden_dim + h * head_dim;
for d in 0..head_dim {
let mut weighted_sum = 0.0f32;
for j in 0..seq_len {
let v_offset = j * hidden_dim + h * head_dim;
weighted_sum += scores[j] * v[v_offset + d];
}
output[out_offset + d] = weighted_sum;
}
}
}
Ok(output)
}
/// Perform RMS normalization
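///
/// Computes `y = x / sqrt(mean(x^2) + eps) * weight` independently for each
/// row of `hidden_dim` elements.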
#[wasm_bindgen(js_name = rmsNorm)]
pub async fn rms_norm(
&self,
input: &[f32],
weight: &[f32],
hidden_dim: u32,
eps: f32,
) -> Result<Vec<f32>, JsValue> {
if weight.len() != hidden_dim as usize {
return Err(JsValue::from_str(&format!(
"Weight dimension mismatch: expected {}, got {}",
hidden_dim,
weight.len()
)));
}
if input.len() % hidden_dim as usize != 0 {
return Err(JsValue::from_str(&format!(
"Input size {} not divisible by hidden_dim {}",
input.len(),
hidden_dim
)));
}
// CPU implementation
let batch_size = input.len() / hidden_dim as usize;
let mut output = vec![0.0f32; input.len()];
for b in 0..batch_size {
let offset = b * hidden_dim as usize;
// Compute sum of squares
let mut sum_sq = 0.0f32;
for i in 0..hidden_dim as usize {
let x = input[offset + i];
sum_sq += x * x;
}
// RMS scale
let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
// Normalize and scale
for i in 0..hidden_dim as usize {
output[offset + i] = input[offset + i] / rms * weight[i];
}
}
Ok(output)
}
/// Perform softmax
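///
/// Computes row-wise `softmax(x / temperature)` over chunks of `dim`
/// elements, subtracting the row max first for numerical stability.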
#[wasm_bindgen]
pub async fn softmax(
&self,
input: &[f32],
dim: u32,
temperature: f32,
) -> Result<Vec<f32>, JsValue> {
if input.len() % dim as usize != 0 {
return Err(JsValue::from_str(&format!(
"Input size {} not divisible by dim {}",
input.len(),
dim
)));
}
let batch_size = input.len() / dim as usize;
let mut output = vec![0.0f32; input.len()];
for b in 0..batch_size {
let offset = b * dim as usize;
// Find max (for numerical stability)
let mut max_val = f32::NEG_INFINITY;
for i in 0..dim as usize {
let x = input[offset + i] / temperature;
if x > max_val {
max_val = x;
}
}
// Compute exp and sum
let mut sum = 0.0f32;
for i in 0..dim as usize {
let x = (input[offset + i] / temperature - max_val).exp();
output[offset + i] = x;
sum += x;
}
// Normalize
for i in 0..dim as usize {
output[offset + i] /= sum;
}
}
Ok(output)
}
// Helper methods for GPU buffer management
#[cfg(target_arch = "wasm32")]
fn create_buffer(
&self,
size: usize,
usage: u32,
label: Option<&str>,
) -> Result<JsValue, JsValue> {
let descriptor = Object::new();
Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?;
Reflect::set(
&descriptor,
&"usage".into(),
&JsValue::from_f64(usage as f64),
)?;
if let Some(lbl) = label {
Reflect::set(&descriptor, &"label".into(), &lbl.into())?;
}
call_method(&self.device, "createBuffer", &[descriptor.into()])
}
#[cfg(target_arch = "wasm32")]
fn write_buffer(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
let data_array = Float32Array::from(data);
call_method(
&self.queue,
"writeBuffer",
&[
buffer.clone(),
JsValue::from_f64(0.0),
data_array.buffer().into(),
],
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cpu_matmul_fallback() {
// Test the CPU fallback logic (in non-wasm mode)
let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
let b = vec![5.0, 6.0, 7.0, 8.0]; // 2x2
// Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
// = [[19, 22], [43, 50]]
let mut c = vec![0.0f32; 4];
for i in 0..2usize {
for j in 0..2usize {
let mut sum = 0.0f32;
for l in 0..2usize {
sum += a[i * 2 + l] * b[l * 2 + j];
}
c[i * 2 + j] = sum;
}
}
assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_rms_norm_cpu() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let weight = vec![1.0, 1.0, 1.0, 1.0];
let hidden_dim = 4;
let eps = 1e-5f32;
// sum_sq = 1 + 4 + 9 + 16 = 30
// rms = sqrt(30/4 + eps) = sqrt(7.5) ≈ 2.7386
let rms = (30.0f32 / 4.0 + eps).sqrt();
let expected: Vec<f32> = input.iter().map(|&x| x / rms).collect();
// Verify calculation
assert!((expected[0] - 0.3651).abs() < 0.001);
}
#[test]
fn test_softmax_cpu() {
let input = vec![1.0, 2.0, 3.0];
let temperature = 1.0f32;
// max = 3
// exp(1-3) = exp(-2), exp(2-3) = exp(-1), exp(3-3) = exp(0) = 1
let exps: Vec<f32> = vec![(-2.0f32).exp(), (-1.0f32).exp(), 1.0];
let sum: f32 = exps.iter().sum();
let expected: Vec<f32> = exps.iter().map(|&x| x / sum).collect();
// Verify softmax sums to 1
let softmax_sum: f32 = expected.iter().sum();
assert!((softmax_sum - 1.0).abs() < 0.001);
}
}


@@ -0,0 +1,345 @@
//! WebGPU Compute Module for WASM-based GPU Acceleration
//!
//! This module provides WebGPU compute shader support for LLM inference
//! operations in the browser. It includes:
//!
//! - Matrix multiplication (tiled, batched, GEMV)
//! - Flash Attention (causal, GQA, decode)
//! - RMSNorm and LayerNorm
//! - Softmax (standard, temperature-scaled, log-softmax)
//!
//! ## Feature Detection
//!
//! WebGPU availability is checked at runtime with graceful fallback:
//!
//! ```javascript
//! if (await WebGpuInference.isAvailable()) {
//! const gpu = await WebGpuInference.init();
//! const result = await gpu.matmul(a, b, m, n, k);
//! } else {
//! // Fall back to CPU implementation
//! }
//! ```
//!
//! ## Performance Targets
//!
//! - Matrix multiply: ~1 TFLOPS on integrated GPUs, ~10 TFLOPS on discrete GPUs
//! - Attention: ~2 ms for a 4K context on a discrete GPU
//! - Normalization: <0.5 ms for typical hidden dimensions
pub mod buffers;
pub mod compute;
pub mod shaders;
use wasm_bindgen::prelude::*;
pub use buffers::{GpuBuffer, GpuBufferUsage};
pub use compute::{ComputePipeline, WebGpuContext};
pub use shaders::ShaderModule;
/// GPU adapter information
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct AdapterInfo {
/// GPU vendor name
#[wasm_bindgen(skip)]
pub vendor: String,
/// GPU architecture/device name
#[wasm_bindgen(skip)]
pub architecture: String,
/// Device type (integrated, discrete, etc.)
#[wasm_bindgen(skip)]
pub device_type: String,
/// Backend API (WebGPU, etc.)
#[wasm_bindgen(skip)]
pub backend: String,
/// Maximum buffer size in bytes
#[wasm_bindgen(skip)]
pub max_buffer_size: u64,
/// Maximum compute workgroup size
#[wasm_bindgen(skip)]
pub max_workgroup_size: u32,
}
#[wasm_bindgen]
impl AdapterInfo {
/// Get GPU vendor name
#[wasm_bindgen(getter)]
pub fn vendor(&self) -> String {
self.vendor.clone()
}
/// Get GPU architecture
#[wasm_bindgen(getter)]
pub fn architecture(&self) -> String {
self.architecture.clone()
}
/// Get device type
#[wasm_bindgen(getter, js_name = deviceType)]
pub fn device_type(&self) -> String {
self.device_type.clone()
}
/// Get backend API
#[wasm_bindgen(getter)]
pub fn backend(&self) -> String {
self.backend.clone()
}
/// Get maximum buffer size
#[wasm_bindgen(getter, js_name = maxBufferSize)]
pub fn max_buffer_size(&self) -> u64 {
self.max_buffer_size
}
/// Get maximum workgroup size
#[wasm_bindgen(getter, js_name = maxWorkgroupSize)]
pub fn max_workgroup_size(&self) -> u32 {
self.max_workgroup_size
}
/// Convert to JSON string
#[wasm_bindgen(js_name = toJson)]
pub fn to_json(&self) -> Result<String, JsValue> {
let json = serde_json::json!({
"vendor": self.vendor,
"architecture": self.architecture,
"deviceType": self.device_type,
"backend": self.backend,
"maxBufferSize": self.max_buffer_size,
"maxWorkgroupSize": self.max_workgroup_size,
});
serde_json::to_string(&json).map_err(|e| JsValue::from_str(&e.to_string()))
}
}
/// Attention configuration for compute shaders
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct AttentionConfig {
/// Sequence length for queries
#[wasm_bindgen(skip)]
pub seq_len: u32,
/// Key/Value sequence length (can differ for encoder-decoder)
#[wasm_bindgen(skip)]
pub kv_seq_len: u32,
/// Number of attention heads
#[wasm_bindgen(skip)]
pub num_heads: u32,
/// Dimension per head
#[wasm_bindgen(skip)]
pub head_dim: u32,
/// Whether to apply causal masking
#[wasm_bindgen(skip)]
pub causal: bool,
}
#[wasm_bindgen]
impl AttentionConfig {
/// Create new attention configuration
#[wasm_bindgen(constructor)]
pub fn new(seq_len: u32, num_heads: u32, head_dim: u32, causal: bool) -> Self {
Self {
seq_len,
kv_seq_len: seq_len,
num_heads,
head_dim,
causal,
}
}
/// Create for encoder-decoder models with different KV length
#[wasm_bindgen(js_name = forEncoderDecoder)]
pub fn for_encoder_decoder(
seq_len: u32,
kv_seq_len: u32,
num_heads: u32,
head_dim: u32,
) -> Self {
Self {
seq_len,
kv_seq_len,
num_heads,
head_dim,
causal: false,
}
}
/// Get the scaling factor (1/sqrt(head_dim))
pub fn scale(&self) -> f32 {
1.0 / (self.head_dim as f32).sqrt()
}
/// Get total hidden dimension
pub fn hidden_dim(&self) -> u32 {
self.num_heads * self.head_dim
}
#[wasm_bindgen(getter, js_name = seqLen)]
pub fn get_seq_len(&self) -> u32 {
self.seq_len
}
#[wasm_bindgen(setter, js_name = seqLen)]
pub fn set_seq_len(&mut self, value: u32) {
self.seq_len = value;
}
#[wasm_bindgen(getter, js_name = kvSeqLen)]
pub fn get_kv_seq_len(&self) -> u32 {
self.kv_seq_len
}
#[wasm_bindgen(setter, js_name = kvSeqLen)]
pub fn set_kv_seq_len(&mut self, value: u32) {
self.kv_seq_len = value;
}
#[wasm_bindgen(getter, js_name = numHeads)]
pub fn get_num_heads(&self) -> u32 {
self.num_heads
}
#[wasm_bindgen(setter, js_name = numHeads)]
pub fn set_num_heads(&mut self, value: u32) {
self.num_heads = value;
}
#[wasm_bindgen(getter, js_name = headDim)]
pub fn get_head_dim(&self) -> u32 {
self.head_dim
}
#[wasm_bindgen(setter, js_name = headDim)]
pub fn set_head_dim(&mut self, value: u32) {
self.head_dim = value;
}
#[wasm_bindgen(getter)]
pub fn get_causal(&self) -> bool {
self.causal
}
#[wasm_bindgen(setter)]
pub fn set_causal(&mut self, value: bool) {
self.causal = value;
}
}
/// Check if WebGPU is available in this browser
#[wasm_bindgen(js_name = isWebGpuAvailable)]
pub async fn is_webgpu_available() -> bool {
compute::is_webgpu_available().await
}
/// Get GPU information if available
#[wasm_bindgen(js_name = getGpuInfo)]
pub async fn get_gpu_info() -> Result<JsValue, JsValue> {
match compute::get_gpu_info().await {
Some(info) => {
let js_obj = js_sys::Object::new();
js_sys::Reflect::set(&js_obj, &"vendor".into(), &info.vendor.into())?;
js_sys::Reflect::set(&js_obj, &"architecture".into(), &info.architecture.into())?;
js_sys::Reflect::set(&js_obj, &"deviceType".into(), &info.device_type.into())?;
js_sys::Reflect::set(&js_obj, &"backend".into(), &info.backend.into())?;
js_sys::Reflect::set(
&js_obj,
&"maxBufferSize".into(),
&JsValue::from_f64(info.max_buffer_size as f64),
)?;
js_sys::Reflect::set(
&js_obj,
&"maxWorkgroupSize".into(),
&JsValue::from_f64(info.max_workgroup_size as f64),
)?;
Ok(js_obj.into())
}
None => Ok(JsValue::NULL),
}
}
/// WebGPU error types
#[derive(Debug)]
pub enum WebGpuError {
/// WebGPU not available in this browser
NotAvailable,
/// Failed to get GPU adapter
AdapterNotFound,
/// Failed to create device
DeviceCreationFailed(String),
/// Buffer allocation failed
BufferAllocationFailed { requested: usize, available: usize },
/// Shader compilation failed
ShaderCompilationFailed(String),
/// Invalid dimensions for operation
DimensionMismatch { expected: String, actual: String },
/// Operation timed out
Timeout,
/// Generic GPU error
GpuError(String),
}
impl std::fmt::Display for WebGpuError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotAvailable => write!(f, "WebGPU is not available in this browser"),
Self::AdapterNotFound => write!(f, "No suitable GPU adapter found"),
Self::DeviceCreationFailed(msg) => write!(f, "Failed to create GPU device: {}", msg),
Self::BufferAllocationFailed {
requested,
available,
} => {
write!(
f,
"Buffer allocation failed: requested {} bytes, {} available",
requested, available
)
}
Self::ShaderCompilationFailed(msg) => write!(f, "Shader compilation failed: {}", msg),
Self::DimensionMismatch { expected, actual } => {
write!(
f,
"Dimension mismatch: expected {}, got {}",
expected, actual
)
}
Self::Timeout => write!(f, "GPU operation timed out"),
Self::GpuError(msg) => write!(f, "GPU error: {}", msg),
}
}
}
impl std::error::Error for WebGpuError {}
impl From<WebGpuError> for JsValue {
fn from(error: WebGpuError) -> Self {
JsValue::from_str(&error.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_attention_config() {
let config = AttentionConfig::new(512, 8, 64, true);
assert_eq!(config.hidden_dim(), 512);
assert!((config.scale() - 0.125).abs() < 0.001); // 1/sqrt(64) = 0.125
}
#[test]
fn test_adapter_info_json() {
let info = AdapterInfo {
vendor: "TestVendor".to_string(),
architecture: "TestArch".to_string(),
device_type: "integrated".to_string(),
backend: "WebGPU".to_string(),
max_buffer_size: 1024 * 1024 * 256,
max_workgroup_size: 256,
};
let json = info.to_json().unwrap();
assert!(json.contains("TestVendor"));
}
}


@@ -0,0 +1,195 @@
//! WGSL Shader Module Definitions
//!
//! This module contains the embedded WGSL shader source code for all
//! compute operations. Shaders are embedded at compile time for efficient
//! loading in WASM.
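//!
//! A small sketch of looking up a shader and its entry points:
//!
//! ```ignore
//! let module = ShaderModule::get_matmul(); // `ShaderModule.matmul()` from JS
//! assert!(module.has_entry_point("main_gemv"));
//! ```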
/// Matrix multiplication shader (tiled with shared memory)
pub const MATMUL_SHADER: &str = include_str!("shaders/matmul.wgsl");
/// Flash attention shader (online softmax, causal masking)
pub const ATTENTION_SHADER: &str = include_str!("shaders/attention.wgsl");
/// RMSNorm and LayerNorm shader
pub const NORM_SHADER: &str = include_str!("shaders/norm.wgsl");
/// Softmax shader (numerically stable)
pub const SOFTMAX_SHADER: &str = include_str!("shaders/softmax.wgsl");
/// Shader entry points for matrix multiplication
pub mod matmul {
/// Standard tiled matrix multiply
pub const MAIN: &str = "main";
/// Batched matrix multiply for attention projections
pub const BATCHED: &str = "main_batched";
/// Vector-matrix multiply for single token generation
pub const GEMV: &str = "main_gemv";
}
/// Shader entry points for attention
pub mod attention {
/// Standard multi-head attention
pub const MAIN: &str = "main";
/// Grouped query attention (GQA)
pub const GQA: &str = "main_gqa";
/// Single token decode attention
pub const DECODE: &str = "main_decode";
}
/// Shader entry points for normalization
pub mod norm {
/// RMSNorm (Llama-style)
pub const RMS_NORM: &str = "rms_norm";
/// RMSNorm with fused residual connection
pub const RMS_NORM_RESIDUAL: &str = "rms_norm_residual";
/// Standard LayerNorm
pub const LAYER_NORM: &str = "layer_norm";
/// Fast RMSNorm for small dimensions
pub const RMS_NORM_SMALL: &str = "rms_norm_small";
}
/// Shader entry points for softmax
pub mod softmax {
/// Standard row-wise softmax
pub const MAIN: &str = "softmax";
/// In-place softmax
pub const INPLACE: &str = "softmax_inplace";
/// Small dimension softmax
pub const SMALL: &str = "softmax_small";
/// Log softmax for loss computation
pub const LOG_SOFTMAX: &str = "log_softmax";
}
use wasm_bindgen::prelude::*;
/// Shader module wrapper for wasm-bindgen
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct ShaderModule {
name: String,
source: String,
entry_points: Vec<String>,
}
#[wasm_bindgen]
impl ShaderModule {
/// Get the matrix multiplication shader module
#[wasm_bindgen(js_name = matmul)]
pub fn get_matmul() -> ShaderModule {
ShaderModule {
name: "matmul".to_string(),
source: MATMUL_SHADER.to_string(),
entry_points: vec![
matmul::MAIN.to_string(),
matmul::BATCHED.to_string(),
matmul::GEMV.to_string(),
],
}
}
/// Get the attention shader module
#[wasm_bindgen(js_name = attention)]
pub fn get_attention() -> ShaderModule {
ShaderModule {
name: "attention".to_string(),
source: ATTENTION_SHADER.to_string(),
entry_points: vec![
attention::MAIN.to_string(),
attention::GQA.to_string(),
attention::DECODE.to_string(),
],
}
}
/// Get the normalization shader module
#[wasm_bindgen(js_name = norm)]
pub fn get_norm() -> ShaderModule {
ShaderModule {
name: "norm".to_string(),
source: NORM_SHADER.to_string(),
entry_points: vec![
norm::RMS_NORM.to_string(),
norm::RMS_NORM_RESIDUAL.to_string(),
norm::LAYER_NORM.to_string(),
norm::RMS_NORM_SMALL.to_string(),
],
}
}
/// Get the softmax shader module
#[wasm_bindgen(js_name = softmax)]
pub fn get_softmax() -> ShaderModule {
ShaderModule {
name: "softmax".to_string(),
source: SOFTMAX_SHADER.to_string(),
entry_points: vec![
softmax::MAIN.to_string(),
softmax::INPLACE.to_string(),
softmax::SMALL.to_string(),
softmax::LOG_SOFTMAX.to_string(),
],
}
}
/// Get shader name
#[wasm_bindgen(getter)]
pub fn name(&self) -> String {
self.name.clone()
}
/// Get shader source code
#[wasm_bindgen(getter)]
pub fn source(&self) -> String {
self.source.clone()
}
/// Get available entry points
#[wasm_bindgen(getter, js_name = entryPoints)]
pub fn entry_points(&self) -> Vec<String> {
self.entry_points.clone()
}
/// Check if an entry point exists
#[wasm_bindgen(js_name = hasEntryPoint)]
pub fn has_entry_point(&self, name: &str) -> bool {
self.entry_points.iter().any(|ep| ep == name)
}
}
/// Get all available shader modules
#[wasm_bindgen(js_name = getAllShaderModules)]
pub fn get_all_shader_modules() -> Vec<ShaderModule> {
vec![
ShaderModule::get_matmul(),
ShaderModule::get_attention(),
ShaderModule::get_norm(),
ShaderModule::get_softmax(),
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_shader_sources_not_empty() {
assert!(!MATMUL_SHADER.is_empty());
assert!(!ATTENTION_SHADER.is_empty());
assert!(!NORM_SHADER.is_empty());
assert!(!SOFTMAX_SHADER.is_empty());
}
#[test]
fn test_shader_module_creation() {
let matmul = ShaderModule::get_matmul();
assert_eq!(matmul.name(), "matmul");
assert!(matmul.has_entry_point("main"));
assert!(matmul.has_entry_point("main_batched"));
}
#[test]
fn test_all_shader_modules() {
let modules = get_all_shader_modules();
assert_eq!(modules.len(), 4);
}
}


@@ -0,0 +1,283 @@
// Flash Attention Shader for WebGPU WASM
//
// Implements memory-efficient attention using online softmax algorithm.
// Supports causal masking for autoregressive generation.
//
// Algorithm:
// 1. Process Q in blocks, streaming K and V
// 2. Maintain running max and sum for numerical stability
// 3. Rescale outputs on-the-fly (Flash Attention v2)
// 4. O(n) memory vs O(n^2) for standard attention
//
// Memory Layout:
// - Q: (seq_len, num_heads, head_dim)
// - K: (seq_len, num_heads, head_dim)
// - V: (seq_len, num_heads, head_dim)
// - Output: (seq_len, num_heads, head_dim)
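//
// Per-block online softmax update used in main() below, for the running
// state (m = max, l = sum, o = output accumulator):
//   m_new = max(m_old, block_max)
//   o     = o * exp(m_old - m_new) + sum_j exp(s_j - m_new) * V_j
//   l     = l * exp(m_old - m_new) + sum_j exp(s_j - m_new)
// and the final output is o / l.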
const BLOCK_SIZE: u32 = 32u; // Reduced for WebGPU limits
const MAX_HEAD_DIM: u32 = 128u;
struct AttentionUniforms {
seq_len: u32,
head_dim: u32,
num_heads: u32,
scale: f32, // 1/sqrt(head_dim)
causal_mask: u32, // 1 for causal, 0 for full attention
kv_seq_len: u32, // For encoder-decoder or prefill
_pad0: u32,
_pad1: u32,
}
@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<f32>;
@group(0) @binding(2) var<storage, read> V: array<f32>;
@group(0) @binding(3) var<storage, read_write> Output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: AttentionUniforms;
// Shared memory for blocks
var<workgroup> Q_shared: array<f32, 4096>; // BLOCK_SIZE * MAX_HEAD_DIM
var<workgroup> K_shared: array<f32, 4096>;
var<workgroup> V_shared: array<f32, 4096>;
var<workgroup> scores_shared: array<f32, 1024>; // BLOCK_SIZE * BLOCK_SIZE
// Thread-local state for online softmax
var<private> m_i: f32; // Running max
var<private> l_i: f32; // Running sum
var<private> o_i: array<f32, 128>; // Output accumulator
@compute @workgroup_size(32, 1, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let seq_len = uniforms.seq_len;
let head_dim = uniforms.head_dim;
let num_heads = uniforms.num_heads;
let scale = uniforms.scale;
let is_causal = uniforms.causal_mask == 1u;
let kv_seq_len = uniforms.kv_seq_len;
// This workgroup handles one Q block for one head
let head_idx = group_id.y;
let q_block_idx = group_id.x;
let q_start = q_block_idx * BLOCK_SIZE;
let thread_id = local_id.x;
let hidden_stride = num_heads * head_dim;
// Initialize online softmax state
m_i = -1e10f;
l_i = 0.0f;
for (var d = 0u; d < head_dim; d++) {
o_i[d] = 0.0f;
}
// Load Q block into shared memory
let q_pos = q_start + thread_id;
if (q_pos < seq_len && thread_id < BLOCK_SIZE) {
for (var d = 0u; d < head_dim; d++) {
let q_idx = q_pos * hidden_stride + head_idx * head_dim + d;
Q_shared[thread_id * head_dim + d] = Q[q_idx];
}
}
workgroupBarrier();
// Iterate over K/V blocks
let num_kv_blocks = (kv_seq_len + BLOCK_SIZE - 1u) / BLOCK_SIZE;
for (var kv_block = 0u; kv_block < num_kv_blocks; kv_block++) {
let kv_start = kv_block * BLOCK_SIZE;
// Early exit for causal attention
if (is_causal && kv_start > q_start + BLOCK_SIZE) {
break;
}
// Load K block
let k_pos = kv_start + thread_id;
if (k_pos < kv_seq_len && thread_id < BLOCK_SIZE) {
for (var d = 0u; d < head_dim; d++) {
let k_idx = k_pos * hidden_stride + head_idx * head_dim + d;
K_shared[thread_id * head_dim + d] = K[k_idx];
}
}
// Load V block
let v_pos = kv_start + thread_id;
if (v_pos < kv_seq_len && thread_id < BLOCK_SIZE) {
for (var d = 0u; d < head_dim; d++) {
let v_idx = v_pos * hidden_stride + head_idx * head_dim + d;
V_shared[thread_id * head_dim + d] = V[v_idx];
}
}
workgroupBarrier();
// Compute attention scores and update online softmax
if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
let kv_block_len = min(BLOCK_SIZE, kv_seq_len - kv_start);
// Compute row max for this block
var block_max = -1e10f;
var local_scores: array<f32, 32>;
for (var k = 0u; k < kv_block_len; k++) {
let k_global = kv_start + k;
// Apply causal mask
if (is_causal && k_global > q_pos) {
local_scores[k] = -1e10f;
continue;
}
// Compute Q[q_pos] dot K[k]
var score = 0.0f;
for (var d = 0u; d < head_dim; d++) {
score += Q_shared[thread_id * head_dim + d] * K_shared[k * head_dim + d];
}
score *= scale;
local_scores[k] = score;
block_max = max(block_max, score);
}
// Update running statistics
let m_ij = max(m_i, block_max);
// Rescale previous accumulator
let alpha = exp(m_i - m_ij);
for (var d = 0u; d < head_dim; d++) {
o_i[d] *= alpha;
}
l_i *= alpha;
// Accumulate weighted V for this block
for (var k = 0u; k < kv_block_len; k++) {
let k_global = kv_start + k;
if (is_causal && k_global > q_pos) {
continue;
}
let p_ij = exp(local_scores[k] - m_ij);
l_i += p_ij;
for (var d = 0u; d < head_dim; d++) {
o_i[d] += p_ij * V_shared[k * head_dim + d];
}
}
m_i = m_ij;
}
workgroupBarrier();
}
// Normalize and write output
if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
let inv_l = select(1.0f / l_i, 0.0f, l_i == 0.0f);
for (var d = 0u; d < head_dim; d++) {
let out_idx = q_pos * hidden_stride + head_idx * head_dim + d;
Output[out_idx] = o_i[d] * inv_l;
}
}
}
// Grouped Query Attention (GQA) variant
// Multiple Q heads share same K/V heads
@compute @workgroup_size(32, 1, 1)
fn main_gqa(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
// Stub: not yet implemented.
// For GQA, kv_head_idx = q_head_idx / num_q_per_kv, so several Q heads
// share one K/V head (Llama 2/3 style). The body would mirror main()
// with that modified K/V indexing.
}
// Single token attention for generation phase
// More efficient when seq_len = 1 (decoding)
@compute @workgroup_size(256, 1, 1)
fn main_decode(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let head_dim = uniforms.head_dim;
let num_heads = uniforms.num_heads;
let scale = uniforms.scale;
let kv_seq_len = uniforms.kv_seq_len;
let is_causal = uniforms.causal_mask == 1u;
let head_idx = group_id.x;
let thread_id = local_id.x;
let hidden_stride = num_heads * head_dim;
// Each thread handles part of the KV sequence
let kv_per_thread = (kv_seq_len + 255u) / 256u;
// Thread-local accumulators
var local_max = -1e10f;
var local_sum = 0.0f;
var local_out: array<f32, 128>;
for (var d = 0u; d < head_dim; d++) {
local_out[d] = 0.0f;
}
// Load Q (single token) into shared memory. Every thread writes the same
// values, so the overlapping stores are benign and no per-thread copy is needed.
for (var d = 0u; d < head_dim; d++) {
Q_shared[d] = Q[head_idx * head_dim + d];
}
workgroupBarrier();
// Process assigned KV positions
for (var i = 0u; i < kv_per_thread; i++) {
let k_pos = thread_id * kv_per_thread + i;
if (k_pos >= kv_seq_len) {
break;
}
// Compute attention score
var score = 0.0f;
for (var d = 0u; d < head_dim; d++) {
let k_idx = k_pos * hidden_stride + head_idx * head_dim + d;
score += Q_shared[d] * K[k_idx];
}
score *= scale;
// Update local max
let new_max = max(local_max, score);
let alpha = exp(local_max - new_max);
for (var d = 0u; d < head_dim; d++) {
local_out[d] *= alpha;
}
local_sum = local_sum * alpha + exp(score - new_max);
// Accumulate weighted V
let p = exp(score - new_max);
for (var d = 0u; d < head_dim; d++) {
let v_idx = k_pos * hidden_stride + head_idx * head_dim + d;
local_out[d] += p * V[v_idx];
}
local_max = new_max;
}
// Reduction across threads is simplified: only thread 0's partial
// (max, sum, out) state is written, so contributions from the other 255
// threads' KV slices are dropped. A complete implementation would merge the
// per-thread online-softmax states with a parallel reduction in shared
// memory (or atomics) before normalizing.
if (thread_id == 0u) {
let inv_sum = select(1.0f / local_sum, 0.0f, local_sum == 0.0f);
for (var d = 0u; d < head_dim; d++) {
Output[head_idx * head_dim + d] = local_out[d] * inv_sum;
}
}
}


@@ -0,0 +1,182 @@
// Tiled Matrix Multiplication Shader for WebGPU WASM
//
// Computes C = A * B using 16x16 tiles optimized for browser WebGPU.
// Uses workgroup shared memory for cache-efficient tile loading.
//
// Memory Layout (row-major):
// - A: M x K matrix
// - B: K x N matrix
// - C: M x N matrix (output)
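//
// Dispatch (see main): ceil(M/16) x ceil(N/16) workgroups; each 16x16
// workgroup produces one tile of C, accumulating partial dot products over
// ceil(K/16) shared-memory tiles of A and B.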
// Tile size optimized for WebGPU limits
const TILE_SIZE: u32 = 16u;
struct Uniforms {
M: u32, // Rows of A, rows of C
N: u32, // Cols of B, cols of C
K: u32, // Cols of A, rows of B
alpha: f32, // Scaling factor (default 1.0)
}
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: Uniforms;
// Shared memory for tile caching
var<workgroup> A_tile: array<f32, 256>; // TILE_SIZE * TILE_SIZE
var<workgroup> B_tile: array<f32, 256>;
@compute @workgroup_size(16, 16, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let M = uniforms.M;
let N = uniforms.N;
let K = uniforms.K;
let alpha = uniforms.alpha;
// Global row and column
let row = global_id.x;
let col = global_id.y;
// Thread position within tile
let local_row = local_id.x;
let local_col = local_id.y;
// Accumulator for this thread's output element
var sum = 0.0f;
// Number of tiles to process along K dimension
let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;
// Iterate over tiles
for (var t = 0u; t < num_tiles; t++) {
let tile_k = t * TILE_SIZE;
// Load A tile element
let a_row = row;
let a_col = tile_k + local_col;
if (a_row < M && a_col < K) {
A_tile[local_row * TILE_SIZE + local_col] = A[a_row * K + a_col];
} else {
A_tile[local_row * TILE_SIZE + local_col] = 0.0;
}
// Load B tile element
let b_row = tile_k + local_row;
let b_col = col;
if (b_row < K && b_col < N) {
B_tile[local_row * TILE_SIZE + local_col] = B[b_row * N + b_col];
} else {
B_tile[local_row * TILE_SIZE + local_col] = 0.0;
}
// Synchronize to ensure tile is fully loaded
workgroupBarrier();
// Compute partial dot product for this tile
let tile_k_end = min(TILE_SIZE, K - tile_k);
for (var k = 0u; k < tile_k_end; k++) {
sum += A_tile[local_row * TILE_SIZE + k] * B_tile[k * TILE_SIZE + local_col];
}
// Synchronize before loading next tile
workgroupBarrier();
}
// Write result with optional scaling
if (row < M && col < N) {
C[row * N + col] = sum * alpha;
}
}
// Batched matrix multiply for multi-head attention projections
// C[b] = A[b] * B where A is batch_size x M x K and B is K x N
@compute @workgroup_size(16, 16, 1)
fn main_batched(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let M = uniforms.M;
let N = uniforms.N;
let K = uniforms.K;
let batch_idx = group_id.z;
let row = global_id.x;
let col = global_id.y;
let local_row = local_id.x;
let local_col = local_id.y;
var sum = 0.0f;
let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;
// Offset into batched A
let batch_offset_a = batch_idx * M * K;
let batch_offset_c = batch_idx * M * N;
for (var t = 0u; t < num_tiles; t++) {
let tile_k = t * TILE_SIZE;
// Load A tile (batched)
let a_row = row;
let a_col = tile_k + local_col;
if (a_row < M && a_col < K) {
A_tile[local_row * TILE_SIZE + local_col] = A[batch_offset_a + a_row * K + a_col];
} else {
A_tile[local_row * TILE_SIZE + local_col] = 0.0;
}
// Load B tile (shared across batch)
let b_row = tile_k + local_row;
let b_col = col;
if (b_row < K && b_col < N) {
B_tile[local_row * TILE_SIZE + local_col] = B[b_row * N + b_col];
} else {
B_tile[local_row * TILE_SIZE + local_col] = 0.0;
}
workgroupBarrier();
let tile_k_end = min(TILE_SIZE, K - tile_k);
for (var k = 0u; k < tile_k_end; k++) {
sum += A_tile[local_row * TILE_SIZE + k] * B_tile[k * TILE_SIZE + local_col];
}
workgroupBarrier();
}
if (row < M && col < N) {
C[batch_offset_c + row * N + col] = sum;
}
}
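// Note: unlike main(), this batched entry point writes the raw product and
// never multiplies by uniforms.alpha; if scaling is needed it would have to be
// folded into B by the caller (an assumption about usage, not shown here).
// Expected dispatch (sketch): (ceil(M/16), ceil(N/16), batch_size), so that
// group_id.z selects the batch.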
// Vector-matrix multiply optimized for single token generation
// y = x * W where x is 1 x K and W is K x N
@compute @workgroup_size(256, 1, 1)
fn main_gemv(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let K = uniforms.K;
let N = uniforms.N;
let col = global_id.x;
if (col >= N) {
return;
}
var sum = 0.0f;
// Each thread computes one output element via a sequential dot product over K
for (var k = 0u; k < K; k++) {
sum += A[k] * B[k * N + col];
}
C[col] = sum * uniforms.alpha;
}
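// Dispatch sketch: ceil(N / 256) workgroups in x, one thread per output
// column. For a fixed k, neighbouring threads read adjacent weights
// B[k*N + col] and B[k*N + col + 1u], so the inner loop streams W contiguously.
// Example: N = 4096 -> (4096 + 255u) / 256u = 16 workgroups.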

View File

@@ -0,0 +1,235 @@
// RMSNorm and LayerNorm Shaders for WebGPU WASM
//
// Implements normalization layers used in transformer architectures:
// - RMSNorm: Used in Llama, Mistral (no mean subtraction)
// - LayerNorm: Standard transformer normalization
//
// RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight
// LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias
const WARP_SIZE: u32 = 32u;
const MAX_DIM: u32 = 8192u;
struct NormUniforms {
hidden_dim: u32,
batch_size: u32,
eps: f32,
_pad: u32,
}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: NormUniforms;
// Shared memory for parallel reduction
var<workgroup> partial_sums: array<f32, 256>;
// RMSNorm: y = x * rsqrt(mean(x^2) + eps) * weight
@compute @workgroup_size(256, 1, 1)
fn rms_norm(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let hidden_dim = uniforms.hidden_dim;
let eps = uniforms.eps;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * hidden_dim;
// Each thread computes partial sum of squares
var thread_sum = 0.0f;
let elements_per_thread = (hidden_dim + 255u) / 256u;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
let x = input[offset + idx];
thread_sum += x * x;
}
}
// Store partial sum
partial_sums[thread_id] = thread_sum;
workgroupBarrier();
// Parallel reduction for sum of squares
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
partial_sums[thread_id] += partial_sums[thread_id + stride];
}
workgroupBarrier();
}
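// Illustrative trace with 8 lanes (strides 4, 2, 1):
// [a,b,c,d,e,f,g,h] -> [a+e, b+f, c+g, d+h, ...] -> [a+e+c+g, b+f+d+h, ...]
// -> slot 0 holds the full sum; the real kernel starts at stride 128.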
// Compute RMS scale factor
let mean_sq = partial_sums[0] / f32(hidden_dim);
let rms_scale = 1.0f / sqrt(mean_sq + eps);
workgroupBarrier();
// Apply normalization and weight
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
let x = input[offset + idx];
output[offset + idx] = x * rms_scale * weight[idx];
}
}
}
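// Worked example (illustrative): x = [1, 2, 2, 1] with eps ~ 0:
// mean(x^2) = (1 + 4 + 4 + 1) / 4 = 2.5, rms_scale = 1/sqrt(2.5) ~ 0.6325,
// so pre-weight outputs are [0.63, 1.26, 1.26, 0.63].
// One workgroup normalizes one row: dispatch (batch_size, 1, 1).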
// Fused RMSNorm + Residual: y = (x + residual) * rsqrt(mean((x+res)^2) + eps) * weight
@compute @workgroup_size(256, 1, 1)
fn rms_norm_residual(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let hidden_dim = uniforms.hidden_dim;
let eps = uniforms.eps;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * hidden_dim;
// Compute partial sum of (x + residual)^2
var thread_sum = 0.0f;
let elements_per_thread = (hidden_dim + 255u) / 256u;
// First pass: accumulate the sum of squares of (x + residual) into shared memory
// Note: the residual is passed in via the output buffer for in-place update
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
let x = input[offset + idx] + output[offset + idx]; // x + residual
thread_sum += x * x;
}
}
partial_sums[thread_id] = thread_sum;
workgroupBarrier();
// Parallel reduction
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
partial_sums[thread_id] += partial_sums[thread_id + stride];
}
workgroupBarrier();
}
let mean_sq = partial_sums[0] / f32(hidden_dim);
let rms_scale = 1.0f / sqrt(mean_sq + eps);
workgroupBarrier();
// Apply normalization
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
let x = input[offset + idx] + output[offset + idx];
output[offset + idx] = x * rms_scale * weight[idx];
}
}
}
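// Calling convention implied by the reads above: the residual stream must be
// pre-loaded into `output`, which is then overwritten with the normalized sum.
// The second loop recomputes x + residual instead of caching it, trading one
// extra buffer read per element for zero additional shared memory.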
// Standard LayerNorm with bias
@group(0) @binding(4) var<storage, read> bias: array<f32>;
@compute @workgroup_size(256, 1, 1)
fn layer_norm(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let hidden_dim = uniforms.hidden_dim;
let eps = uniforms.eps;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * hidden_dim;
let elements_per_thread = (hidden_dim + 255u) / 256u;
// First pass: compute mean
var thread_sum = 0.0f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
thread_sum += input[offset + idx];
}
}
partial_sums[thread_id] = thread_sum;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
partial_sums[thread_id] += partial_sums[thread_id + stride];
}
workgroupBarrier();
}
let mean = partial_sums[0] / f32(hidden_dim);
workgroupBarrier();
// Second pass: compute variance
var thread_var = 0.0f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
let diff = input[offset + idx] - mean;
thread_var += diff * diff;
}
}
partial_sums[thread_id] = thread_var;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
partial_sums[thread_id] += partial_sums[thread_id + stride];
}
workgroupBarrier();
}
let variance = partial_sums[0] / f32(hidden_dim);
let inv_std = 1.0f / sqrt(variance + eps);
workgroupBarrier();
// Third pass: normalize and apply affine transform
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < hidden_dim) {
let x = input[offset + idx];
output[offset + idx] = (x - mean) * inv_std * weight[idx] + bias[idx];
}
}
}
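// layer_norm takes three passes (mean, variance, affine) with two tree
// reductions. A one-pass variance via E[x^2] - E[x]^2 would halve the
// reductions but is less stable in f32 when the mean is large relative to the
// spread, so the two-pass form is the safer default here.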
// Fast RMSNorm for small hidden dimensions (direct reduction)
@compute @workgroup_size(128, 1, 1)
fn rms_norm_small(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let hidden_dim = uniforms.hidden_dim;
let eps = uniforms.eps;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * hidden_dim;
// For small hidden_dim (<= 128), direct computation
if (thread_id < hidden_dim) {
// Compute sum of squares (all threads contribute)
var sum_sq = 0.0f;
for (var i = 0u; i < hidden_dim; i++) {
let x = input[offset + i];
sum_sq += x * x;
}
let rms = sqrt(sum_sq / f32(hidden_dim) + eps);
let x = input[offset + thread_id];
output[offset + thread_id] = x / rms * weight[thread_id];
}
}
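// Selection heuristic (illustrative): rms_norm_small avoids reduction traffic
// entirely but makes every active thread redo the O(hidden_dim) sum, so it
// only pays off for hidden_dim <= 128; larger rows should use rms_norm above.
// The exact crossover is device-dependent and worth benchmarking.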

View File

@@ -0,0 +1,288 @@
// Softmax Shader for WebGPU WASM
//
// Numerically stable softmax: y = exp(x - max(x)) / sum(exp(x - max(x)))
// Uses parallel reduction for finding max and computing sum.
//
// Variants:
// - Full softmax for attention scores
// - Temperature-scaled softmax for sampling
// - Top-k softmax for efficient sampling
const MAX_SEQ_LEN: u32 = 8192u;
struct SoftmaxUniforms {
dim: u32, // Dimension to reduce over
batch_size: u32, // Number of rows
temperature: f32, // Scaling factor (1.0 for standard)
top_k: u32, // 0 for full softmax, >0 for top-k
}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> uniforms: SoftmaxUniforms;
// Shared memory for reductions
var<workgroup> reduction_buf: array<f32, 256>;
// Standard row-wise softmax
@compute @workgroup_size(256, 1, 1)
fn softmax(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let dim = uniforms.dim;
let temperature = uniforms.temperature;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * dim;
let elements_per_thread = (dim + 255u) / 256u;
// Phase 1: Find max value
var thread_max = -1e10f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
thread_max = max(thread_max, input[offset + idx] / temperature);
}
}
reduction_buf[thread_id] = thread_max;
workgroupBarrier();
// Parallel max reduction
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
reduction_buf[thread_id] = max(reduction_buf[thread_id], reduction_buf[thread_id + stride]);
}
workgroupBarrier();
}
let max_val = reduction_buf[0];
workgroupBarrier();
// Phase 2: Compute sum of exp(x - max)
var thread_sum = 0.0f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
let x = input[offset + idx] / temperature - max_val;
thread_sum += exp(x);
}
}
reduction_buf[thread_id] = thread_sum;
workgroupBarrier();
// Parallel sum reduction
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
reduction_buf[thread_id] += reduction_buf[thread_id + stride];
}
workgroupBarrier();
}
let sum_val = reduction_buf[0];
let inv_sum = 1.0f / sum_val;
workgroupBarrier();
// Phase 3: Compute normalized softmax
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
let x = input[offset + idx] / temperature - max_val;
output[offset + idx] = exp(x) * inv_sum;
}
}
}
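// Stability example (illustrative): for logits [1000, 1001], exp(1000)
// overflows f32, but after subtracting max the kernel evaluates exp(-1) and
// exp(0), giving softmax [0.269, 0.731] -- identical to the unshifted formula
// in exact arithmetic.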
// In-place softmax (input and output point to same buffer)
@compute @workgroup_size(256, 1, 1)
fn softmax_inplace(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let dim = uniforms.dim;
let temperature = uniforms.temperature;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * dim;
let elements_per_thread = (dim + 255u) / 256u;
// Find max
var thread_max = -1e10f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
thread_max = max(thread_max, output[offset + idx] / temperature);
}
}
reduction_buf[thread_id] = thread_max;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
reduction_buf[thread_id] = max(reduction_buf[thread_id], reduction_buf[thread_id + stride]);
}
workgroupBarrier();
}
let max_val = reduction_buf[0];
workgroupBarrier();
// Compute exp and sum
var thread_sum = 0.0f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
let x = exp(output[offset + idx] / temperature - max_val);
output[offset + idx] = x; // Store intermediate exp value
thread_sum += x;
}
}
reduction_buf[thread_id] = thread_sum;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
reduction_buf[thread_id] += reduction_buf[thread_id + stride];
}
workgroupBarrier();
}
let inv_sum = 1.0f / reduction_buf[0];
workgroupBarrier();
// Normalize in place
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
output[offset + idx] *= inv_sum;
}
}
}
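// softmax_inplace stores exp(x - max) back in place during the sum phase, so
// the final loop only scales by 1/sum: exp() is evaluated once per element
// (versus twice in softmax above) at the cost of one extra global store.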
// Small dimension softmax (dim <= 256)
@compute @workgroup_size(256, 1, 1)
fn softmax_small(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let dim = uniforms.dim;
let temperature = uniforms.temperature;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * dim;
// Load value for this thread
var x = -1e10f;
if (thread_id < dim) {
x = input[offset + thread_id] / temperature;
}
reduction_buf[thread_id] = x;
workgroupBarrier();
// Find max with a serial scan over shared memory (dim <= 256, so no tree reduction needed)
var max_val = x;
for (var i = 0u; i < dim; i++) {
max_val = max(max_val, reduction_buf[i]);
}
workgroupBarrier();
// Compute exp and sum
var exp_val = 0.0f;
if (thread_id < dim) {
exp_val = exp(x - max_val);
}
reduction_buf[thread_id] = exp_val;
workgroupBarrier();
var sum_val = 0.0f;
for (var i = 0u; i < dim; i++) {
sum_val += reduction_buf[i];
}
// Write normalized output
if (thread_id < dim) {
output[offset + thread_id] = exp_val / sum_val;
}
}
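// softmax_small replaces the two tree reductions with serial scans of shared
// memory: every thread re-reads all `dim` slots, trading O(dim) redundant
// work per thread for far fewer barriers. Only sensible while dim <= 256.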
// Log softmax for numerical stability in loss computation
@compute @workgroup_size(256, 1, 1)
fn log_softmax(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let dim = uniforms.dim;
let temperature = uniforms.temperature;
let batch_idx = group_id.x;
let thread_id = local_id.x;
let offset = batch_idx * dim;
let elements_per_thread = (dim + 255u) / 256u;
// Find max
var thread_max = -1e10f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
thread_max = max(thread_max, input[offset + idx] / temperature);
}
}
reduction_buf[thread_id] = thread_max;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
reduction_buf[thread_id] = max(reduction_buf[thread_id], reduction_buf[thread_id + stride]);
}
workgroupBarrier();
}
let max_val = reduction_buf[0];
workgroupBarrier();
// Compute log-sum-exp
var thread_sum = 0.0f;
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
thread_sum += exp(input[offset + idx] / temperature - max_val);
}
}
reduction_buf[thread_id] = thread_sum;
workgroupBarrier();
for (var stride = 128u; stride > 0u; stride >>= 1u) {
if (thread_id < stride) {
reduction_buf[thread_id] += reduction_buf[thread_id + stride];
}
workgroupBarrier();
}
let log_sum = log(reduction_buf[0]) + max_val;
workgroupBarrier();
// Compute log softmax: log(softmax(x)) = x - log_sum_exp(x)
for (var i = 0u; i < elements_per_thread; i++) {
let idx = thread_id + i * 256u;
if (idx < dim) {
output[offset + idx] = input[offset + idx] / temperature - log_sum;
}
}
}
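// Identity used above: log(softmax(x)) = x - logsumexp(x), where
// logsumexp(x) = max(x) + log(sum(exp(x - max(x)))).
// Example (illustrative): x = [0, ln 3] -> logsumexp = ln 4, so the outputs
// are [ln(1/4), ln(3/4)] ~ [-1.386, -0.288].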