Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,283 @@
//! Compute backend detection and abstraction
//!
//! Detects available compute capabilities (WebGPU, WebGL2, WebWorkers)
//! and provides a unified interface for selecting the best backend.
use wasm_bindgen::prelude::*;
/// Compute capabilities detected on the current device.
///
/// Produced by `detect_capabilities`; convert to a JS object with
/// `to_js` and pick an execution strategy with `recommend_backend`.
#[derive(Clone, Debug)]
pub struct ComputeCapability {
    /// WebGPU is available (best performance)
    pub has_webgpu: bool,
    /// WebGL2 is available (fallback for GPU compute)
    pub has_webgl2: bool,
    /// WebGL2 supports floating point textures (EXT_color_buffer_float)
    pub has_float_textures: bool,
    /// Transform feedback is available (for GPU readback); always true when
    /// WebGL2 is present, since it is a core WebGL2 feature
    pub has_transform_feedback: bool,
    /// WebWorkers are available
    pub has_workers: bool,
    /// SharedArrayBuffer is available (for shared memory)
    pub has_shared_memory: bool,
    /// Number of logical CPU cores (clamped to >= 1 by `detect_capabilities`)
    pub worker_count: usize,
    /// Maximum texture size (for WebGL2); 0 when WebGL2 is unavailable
    pub max_texture_size: u32,
    /// Estimated GPU memory (MB); best-effort, defaults to 2048
    pub gpu_memory_mb: u32,
    /// Device description (GPU renderer string when available)
    pub device_info: String,
}
impl ComputeCapability {
/// Convert to JavaScript object
pub fn to_js(&self) -> JsValue {
let obj = js_sys::Object::new();
js_sys::Reflect::set(&obj, &"hasWebGPU".into(), &self.has_webgpu.into()).ok();
js_sys::Reflect::set(&obj, &"hasWebGL2".into(), &self.has_webgl2.into()).ok();
js_sys::Reflect::set(&obj, &"hasFloatTextures".into(), &self.has_float_textures.into()).ok();
js_sys::Reflect::set(&obj, &"hasTransformFeedback".into(), &self.has_transform_feedback.into()).ok();
js_sys::Reflect::set(&obj, &"hasWorkers".into(), &self.has_workers.into()).ok();
js_sys::Reflect::set(&obj, &"hasSharedMemory".into(), &self.has_shared_memory.into()).ok();
js_sys::Reflect::set(&obj, &"workerCount".into(), &(self.worker_count as u32).into()).ok();
js_sys::Reflect::set(&obj, &"maxTextureSize".into(), &self.max_texture_size.into()).ok();
js_sys::Reflect::set(&obj, &"gpuMemoryMB".into(), &self.gpu_memory_mb.into()).ok();
js_sys::Reflect::set(&obj, &"deviceInfo".into(), &self.device_info.clone().into()).ok();
obj.into()
}
/// Get recommended backend for a given operation size
pub fn recommend_backend(&self, operation_size: usize) -> ComputeBackend {
// WebGPU is always preferred if available
if self.has_webgpu {
return ComputeBackend::WebGPU;
}
// For large operations, prefer GPU
if operation_size > 4096 && self.has_webgl2 && self.has_float_textures {
return ComputeBackend::WebGL2;
}
// For medium operations with multiple cores, use workers
if operation_size > 1024 && self.has_workers && self.worker_count > 1 {
return ComputeBackend::WebWorkers;
}
// Fall back to single-threaded CPU
ComputeBackend::CPU
}
}
/// Available compute backends, ordered from fastest to slowest.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComputeBackend {
    /// WebGPU compute shaders (best performance)
    WebGPU,
    /// WebGL2 texture-based compute (fallback GPU path)
    WebGL2,
    /// WebWorker pool (CPU parallelism)
    WebWorkers,
    /// Single-threaded CPU (last resort)
    CPU,
}
impl ComputeBackend {
/// Get backend name
pub fn name(&self) -> &'static str {
match self {
ComputeBackend::WebGPU => "WebGPU",
ComputeBackend::WebGL2 => "WebGL2",
ComputeBackend::WebWorkers => "WebWorkers",
ComputeBackend::CPU => "CPU",
}
}
/// Get relative performance (higher is better)
pub fn relative_performance(&self) -> f32 {
match self {
ComputeBackend::WebGPU => 10.0,
ComputeBackend::WebGL2 => 5.0,
ComputeBackend::WebWorkers => 2.0,
ComputeBackend::CPU => 1.0,
}
}
}
/// Probe the current browser environment and report available compute
/// capabilities (WebGPU, WebGL2, workers, shared memory, core count).
///
/// # Errors
/// Returns a `JsValue` error when no `window` or `document` is available
/// (e.g. when called outside a browser main thread).
pub fn detect_capabilities() -> Result<ComputeCapability, JsValue> {
    let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window object"))?;
    let navigator = window.navigator();

    // Feature probes go through Reflect so that missing globals return
    // false instead of throwing.
    let has_webgpu = js_sys::Reflect::has(&navigator, &"gpu".into()).unwrap_or(false);
    let has_workers = js_sys::Reflect::has(&window, &"Worker".into()).unwrap_or(false);
    let has_shared_memory =
        js_sys::Reflect::has(&window, &"SharedArrayBuffer".into()).unwrap_or(false);

    // hardware_concurrency reports logical cores; clamp to at least one so
    // downstream sizing logic never sees zero.
    let worker_count = (navigator.hardware_concurrency() as usize).max(1);

    // WebGL2 probing needs a canvas, hence the document.
    let document = window
        .document()
        .ok_or_else(|| JsValue::from_str("No document"))?;
    let (has_webgl2, has_float_textures, has_transform_feedback, max_texture_size, gpu_memory_mb, device_info) =
        detect_webgl2_capabilities(&document)?;

    Ok(ComputeCapability {
        has_webgpu,
        has_webgl2,
        has_float_textures,
        has_transform_feedback,
        has_workers,
        has_shared_memory,
        worker_count,
        max_texture_size,
        gpu_memory_mb,
        device_info,
    })
}
/// Probe WebGL2 support via a temporary offscreen canvas.
///
/// Returns `(has_webgl2, has_float_textures, has_transform_feedback,
/// max_texture_size, gpu_memory_mb, device_info)`. When no WebGL2 context
/// can be created, the all-false/zero tuple is returned (not an `Err`).
///
/// # Errors
/// Propagates DOM errors from element creation, context lookup, extension
/// queries, or parameter queries.
fn detect_webgl2_capabilities(document: &web_sys::Document) -> Result<(bool, bool, bool, u32, u32, String), JsValue> {
    // Create a temporary canvas to probe WebGL2; it is never attached to
    // the DOM and is dropped when this function returns.
    let canvas = document.create_element("canvas")?;
    let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into()?;
    // Try to get WebGL2 context; None means WebGL2 is unsupported.
    let context = match canvas.get_context("webgl2")? {
        Some(ctx) => ctx,
        None => return Ok((false, false, false, 0, 0, "No WebGL2".to_string())),
    };
    let gl: web_sys::WebGl2RenderingContext = context.dyn_into()?;
    // EXT_color_buffer_float is required to render into float textures,
    // which the GPU-compute path depends on.
    let ext_color_buffer_float = gl.get_extension("EXT_color_buffer_float")?;
    let has_float_textures = ext_color_buffer_float.is_some();
    // Transform feedback is a core WebGL2 feature, so no probe is needed.
    let has_transform_feedback = true;
    // Driver texture-size limit; fall back to 4096 if the query yields a
    // non-numeric value.
    let max_texture_size = gl.get_parameter(web_sys::WebGl2RenderingContext::MAX_TEXTURE_SIZE)?
        .as_f64()
        .unwrap_or(4096.0) as u32;
    // Vendor-specific GPU memory estimate (best effort).
    let gpu_memory_mb = get_gpu_memory_mb(&gl);
    // Renderer string via WEBGL_debug_renderer_info.
    // NOTE(review): modern browsers may hide or anonymize this extension
    // for anti-fingerprinting reasons — "Unknown GPU" is a normal outcome.
    let renderer_info = gl.get_extension("WEBGL_debug_renderer_info")?;
    let device_info = if renderer_info.is_some() {
        // UNMASKED_RENDERER_WEBGL = 0x9246
        let renderer = gl.get_parameter(0x9246)?;
        renderer.as_string().unwrap_or_else(|| "Unknown GPU".to_string())
    } else {
        "Unknown GPU".to_string()
    };
    Ok((true, has_float_textures, has_transform_feedback, max_texture_size, gpu_memory_mb, device_info))
}
/// Best-effort GPU memory estimate in MB.
///
/// Probes the non-standard `WEBGL_memory_info` extension and, when present,
/// queries `GPU_MEMORY_INFO_TOTAL_AVAILABLE_MEMORY_NVX` (reported in KB).
/// Any failure falls back to 2048 MB, a typical modern-GPU floor.
fn get_gpu_memory_mb(gl: &web_sys::WebGl2RenderingContext) -> u32 {
    const FALLBACK_MB: u32 = 2048;

    // The extension object itself is unused; its presence gates the query.
    match gl.get_extension("WEBGL_memory_info") {
        Ok(Some(_)) => {}
        _ => return FALLBACK_MB,
    }

    // GPU_MEMORY_INFO_TOTAL_AVAILABLE_MEMORY_NVX = 0x9048 (value in KB).
    match gl.get_parameter(0x9048) {
        Ok(mem) => mem
            .as_f64()
            .map(|kb| (kb / 1024.0) as u32)
            .unwrap_or(FALLBACK_MB),
        Err(_) => FALLBACK_MB,
    }
}
/// Configuration for compute operations.
///
/// Use `ComputeConfig::default()` for sensible defaults (auto backend,
/// 256 MB memory cap, 30 s timeout, profiling off).
#[derive(Clone, Debug)]
pub struct ComputeConfig {
    /// Preferred backend (None = auto-select via `recommend_backend`)
    pub preferred_backend: Option<ComputeBackend>,
    /// Maximum memory to use (bytes)
    pub max_memory: usize,
    /// Timeout for operations (ms)
    pub timeout_ms: u32,
    /// Enable profiling
    pub profiling: bool,
}
impl Default for ComputeConfig {
fn default() -> Self {
ComputeConfig {
preferred_backend: None,
max_memory: 256 * 1024 * 1024, // 256MB
timeout_ms: 30_000, // 30 seconds
profiling: false,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: WebGL2 + float textures + 4 workers; WebGPU toggled by flag.
    fn caps(has_webgpu: bool) -> ComputeCapability {
        ComputeCapability {
            has_webgpu,
            has_webgl2: true,
            has_float_textures: true,
            has_transform_feedback: true,
            has_workers: true,
            has_shared_memory: true,
            worker_count: 4,
            max_texture_size: 4096,
            gpu_memory_mb: 2048,
            device_info: "Test GPU".to_string(),
        }
    }

    #[test]
    fn test_backend_recommendation() {
        let c = caps(false);
        // Large operations should use WebGL2.
        assert_eq!(c.recommend_backend(10000), ComputeBackend::WebGL2);
        // Medium operations with workers should use workers.
        assert_eq!(c.recommend_backend(2000), ComputeBackend::WebWorkers);
        // Small operations should use CPU.
        assert_eq!(c.recommend_backend(100), ComputeBackend::CPU);
    }

    #[test]
    fn test_backend_with_webgpu() {
        let c = caps(true);
        // WebGPU should always be preferred, regardless of size.
        assert_eq!(c.recommend_backend(100), ComputeBackend::WebGPU);
        assert_eq!(c.recommend_backend(10000), ComputeBackend::WebGPU);
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
//! SIMD Compute Backend for edge-net P2P AI Network
//!
//! Provides portable CPU acceleration with support for:
//! - WASM simd128 intrinsics (browser/WASM targets)
//! - x86_64 AVX2 intrinsics (native x86 targets)
//! - Scalar fallback for unsupported platforms
//!
//! Performance targets:
//! - 2,236+ ops/sec for MicroLoRA (rank-2)
//! - 150x faster HNSW search
//! - Q4 quantized inference
pub mod simd;
// Flat re-export so callers can use the crate-level paths directly.
pub use simd::*;

View File

@@ -0,0 +1,233 @@
// Flash Attention Shader
//
// Implements memory-efficient attention using the Flash Attention algorithm.
// Target: 2ms for 4K context length.
//
// Algorithm (Flash Attention v2):
// 1. Process Q in blocks, streaming K and V
// 2. Maintain running max and sum for numerical stability
// 3. Rescale outputs on-the-fly
// 4. Avoid materializing full attention matrix (O(n) memory vs O(n^2))
//
// Memory Layout:
// - Q: (seq_len, num_heads * head_dim) - queries
// - K: (seq_len, num_heads * head_dim) - keys
// - V: (seq_len, num_heads * head_dim) - values
// - Output: (seq_len, num_heads * head_dim)
// Block size for flash attention: must equal the @workgroup_size(64) of the
// kernels below, since each thread owns exactly one query row of its block.
const BLOCK_SIZE: u32 = 64u;
// NOTE(review): WARP_SIZE is declared but not referenced in this shader.
const WARP_SIZE: u32 = 32u;
// Dimensions are packed as f32 (uniform-friendly 8-float layout) and
// converted to u32 inside the kernels.
struct Uniforms {
    seq_len: f32,
    head_dim: f32,
    num_heads: f32,
    scale: f32, // 1/sqrt(head_dim)
    causal_mask: f32, // 1.0 for causal, 0.0 for full attention
    _pad0: f32,
    _pad1: f32,
    _pad2: f32,
}
@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<f32>;
@group(0) @binding(2) var<storage, read> V: array<f32>;
@group(0) @binding(3) var<storage, read_write> Output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: Uniforms;
// Workgroup-shared tiles: BLOCK_SIZE rows x up to 64 head-dim columns each,
// so this shader assumes head_dim <= 64.
var<workgroup> Q_block: array<f32, 4096>; // BLOCK_SIZE * 64 (max head_dim)
var<workgroup> K_block: array<f32, 4096>;
var<workgroup> V_block: array<f32, 4096>;
// NOTE(review): `scores` appears unused by the kernels in this file.
var<workgroup> scores: array<f32, 4096>; // BLOCK_SIZE * BLOCK_SIZE
// Per-thread online-softmax state (module-scope private variables).
var<private> m_prev: f32; // Running max score for this thread's query row
var<private> l_prev: f32; // Running sum of exp(score - max)
var<private> acc: array<f32, 64>; // Output accumulator (head_dim values)
// Online-softmax denominator update: rescales the previous running sum to
// the new reference max, then adds this block's exp contributions.
//
// NOTE(review): not called by any kernel in this file — `main` inlines the
// equivalent update; kept here as a reference helper.
fn online_softmax_update(
    new_max: f32,
    old_max: f32,
    old_sum: f32,
    new_scores: ptr<function, array<f32, 64>>,
    block_len: u32,
) -> f32 {
    // Rescale old sum to the new reference max.
    var new_sum = old_sum * exp(old_max - new_max);
    // Add new contributions from this block's scores.
    for (var i = 0u; i < block_len; i++) {
        new_sum += exp((*new_scores)[i] - new_max);
    }
    return new_sum;
}
// Flash-attention forward pass: each workgroup handles one BLOCK_SIZE-row
// block of Q for a single head (group_id.x = Q block, group_id.y = head).
// K/V are streamed block-by-block while a running (max, sum) pair keeps the
// softmax numerically stable without materializing the full score matrix.
@compute @workgroup_size(64, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Dimensions arrive as f32 uniforms; convert once up front.
    let seq_len = u32(uniforms.seq_len);
    let head_dim = u32(uniforms.head_dim);
    let num_heads = u32(uniforms.num_heads);
    let scale = uniforms.scale;
    let is_causal = uniforms.causal_mask > 0.5;
    // This workgroup processes one block of Q for one head.
    let head_idx = group_id.y;
    let q_block_idx = group_id.x;
    let q_start = q_block_idx * BLOCK_SIZE;
    let thread_id = local_id.x;
    let hidden_dim = num_heads * head_dim;
    // Reset per-thread online-softmax state for this thread's query row.
    m_prev = -1e10; // Effectively -inf; updated with real scores below
    l_prev = 0.0;
    for (var i = 0u; i < 64u; i++) {
        acc[i] = 0.0;
    }
    // Load the Q block into shared memory: each thread loads the head_dim
    // values of its own query position.
    let q_pos = q_start + thread_id;
    if (q_pos < seq_len && thread_id < BLOCK_SIZE) {
        for (var d = 0u; d < head_dim; d++) {
            let q_idx = q_pos * hidden_dim + head_idx * head_dim + d;
            Q_block[thread_id * head_dim + d] = Q[q_idx];
        }
    }
    workgroupBarrier();
    // Stream K/V blocks. Under a causal mask, blocks strictly past the
    // diagonal are skipped entirely via max_kv_block.
    let num_kv_blocks = (seq_len + BLOCK_SIZE - 1u) / BLOCK_SIZE;
    let max_kv_block = select(num_kv_blocks, q_block_idx + 1u, is_causal);
    for (var kv_block_idx = 0u; kv_block_idx < max_kv_block; kv_block_idx++) {
        let kv_start = kv_block_idx * BLOCK_SIZE;
        // Load K block into shared memory (one key row per thread).
        let k_pos = kv_start + thread_id;
        if (k_pos < seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let k_idx = k_pos * hidden_dim + head_idx * head_dim + d;
                K_block[thread_id * head_dim + d] = K[k_idx];
            }
        }
        // Load V block into shared memory (one value row per thread).
        let v_pos = kv_start + thread_id;
        if (v_pos < seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let v_idx = v_pos * hidden_dim + head_idx * head_dim + d;
                V_block[thread_id * head_dim + d] = V[v_idx];
            }
        }
        workgroupBarrier();
        // Score this thread's query row against every key in the block.
        // Stale shared-memory rows past seq_len are never read because the
        // loops below stop at kv_block_len.
        if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
            let kv_block_len = min(BLOCK_SIZE, seq_len - kv_start);
            // Compute Q @ K^T for this thread's Q position.
            var local_scores: array<f32, 64>;
            var block_max = -1e10f;
            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;
                // Causal mask: future positions get a -inf-like score and
                // do not move the block max.
                if (is_causal && k_global > q_pos) {
                    local_scores[k] = -1e10;
                    continue;
                }
                // Dot product Q[thread] @ K[k].
                var score = 0.0f;
                for (var d = 0u; d < head_dim; d++) {
                    score += Q_block[thread_id * head_dim + d] * K_block[k * head_dim + d];
                }
                score *= scale;
                local_scores[k] = score;
                block_max = max(block_max, score);
            }
            // Online-softmax update: fold this block's max into the running
            // max and rescale the previous partial results accordingly.
            let new_max = max(m_prev, block_max);
            let scale_old = exp(m_prev - new_max);
            // NOTE(review): scale_new is computed but never used below.
            let scale_new = exp(block_max - new_max);
            for (var d = 0u; d < head_dim; d++) {
                acc[d] *= scale_old;
            }
            l_prev *= scale_old;
            // Accumulate exp(score - new_max) weights and weighted V rows.
            var block_sum = 0.0f;
            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;
                if (is_causal && k_global > q_pos) {
                    continue;
                }
                let p = exp(local_scores[k] - new_max);
                block_sum += p;
                // Accumulate weighted V.
                for (var d = 0u; d < head_dim; d++) {
                    acc[d] += p * V_block[k * head_dim + d];
                }
            }
            // Commit running statistics for the next block.
            l_prev += block_sum;
            m_prev = new_max;
        }
        // All threads must finish reading K/V before the next iteration
        // overwrites the shared tiles.
        workgroupBarrier();
    }
    // Normalize by the softmax denominator and write out; an empty row
    // (l_prev == 0) writes zeros instead of dividing by zero.
    if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
        let inv_sum = select(1.0 / l_prev, 0.0, l_prev == 0.0);
        for (var d = 0u; d < head_dim; d++) {
            let out_idx = q_pos * hidden_dim + head_idx * head_dim + d;
            Output[out_idx] = acc[d] * inv_sum;
        }
    }
}
// Multi-head attention with grouped-query attention (GQA) support.
// GQA maps several Q heads onto one shared K/V head:
//   kv_head = q_head / num_q_per_kv
// NOTE(review): placeholder — the body is intentionally empty; intended for
// models such as Llama 2/3.
@compute @workgroup_size(64, 1, 1)
fn main_gqa(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}
// Sliding-window attention variant: each query attends only to keys within
// a fixed window, useful for very long sequences (Mistral-style).
// NOTE(review): placeholder — the body is intentionally empty.
@compute @workgroup_size(64, 1, 1)
fn main_sliding_window(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}

View File

@@ -0,0 +1,159 @@
// LoRA (Low-Rank Adaptation) Forward Pass Shader
//
// Computes: output = input + scaling * (input @ A @ B)
//
// Where:
// - input: (batch_size, in_dim)
// - A: (in_dim, rank) - down projection
// - B: (rank, out_dim) - up projection
// - output: (batch_size, out_dim)
//
// Performance target: <1ms for typical LoRA ranks (2-64)
//
// Optimization strategy:
// 1. Fuse both matmuls into single kernel
// 2. Use shared memory for intermediate (rank is small)
// 3. Each thread computes one output element
// NOTE(review): WARP_SIZE is declared but not referenced in this shader.
const WARP_SIZE: u32 = 32u;
// Maximum supported LoRA rank (sizes the shared/private rank buffers).
const MAX_RANK: u32 = 64u;
// Dimensions packed as f32; converted to u32 inside the kernels.
struct Uniforms {
    batch_size: f32,
    in_dim: f32,
    rank: f32,
    out_dim: f32,
    scaling: f32, // alpha / rank
    _pad0: f32,
    _pad1: f32,
    _pad2: f32,
}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> lora_A: array<f32>; // (in_dim, rank), row-major
@group(0) @binding(2) var<storage, read> lora_B: array<f32>; // (rank, out_dim), row-major
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: Uniforms;
// Shared memory for the intermediate (input @ A) row.
var<workgroup> intermediate: array<f32, 2048>;
// NOTE(review): these private caches are not referenced by any kernel below.
var<private> input_cache: array<f32, 32>; // Cache input values
var<private> a_cache: array<f32, 64>; // Cache A column
// LoRA delta kernel: writes scaling * (input @ A @ B) into `output`
// (delta only — the base-model output is NOT added here).
//
// One thread per output element: thread t computes
// output[batch_idx, out_idx] with batch_idx = t / out_dim, out_idx = t % out_dim.
//
// Fixes relative to the previous revision:
// - Threads with batch_idx >= batch_size returned before workgroupBarrier(),
//   which violates WGSL's barrier uniformity rules. Inactive threads now
//   stay in uniform control flow and simply skip the writes.
// - Phase 1 only filled `intermediate` for the batch rows owned by threads
//   with thread_id < rank, so whenever one workgroup spanned several batch
//   rows (out_dim < 256), Phase 2 read uninitialized shared memory; the
//   `% 32` wrap additionally aliased distinct batch rows. The cooperative
//   shared-memory path is now taken only when the whole workgroup maps to a
//   single batch row (a workgroup-uniform condition); otherwise each thread
//   recomputes its own (input @ A) row privately — slower but correct.
@compute @workgroup_size(256, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let batch_size = u32(uniforms.batch_size);
    let in_dim = u32(uniforms.in_dim);
    let rank = u32(uniforms.rank);
    let out_dim = u32(uniforms.out_dim);
    let scaling = uniforms.scaling;

    let thread_id = local_id.x;
    let global_thread = global_id.x;

    // Which output element this thread handles.
    let batch_idx = global_thread / out_dim;
    let out_idx = global_thread % out_dim;
    let active = batch_idx < batch_size;

    // Workgroup-uniform test: do all 256 threads of this workgroup map to
    // the same batch row? (Depends only on group_id, so it is uniform and
    // the barrier below is legal.)
    let wg_first_batch = (group_id.x * 256u) / out_dim;
    let wg_last_batch = (group_id.x * 256u + 255u) / out_dim;
    let single_batch = (wg_first_batch == wg_last_batch) && (rank <= MAX_RANK);

    if (single_batch) {
        // Cooperative path: threads 0..rank each compute one element of
        // intermediate = input[wg_first_batch, :] @ A, shared by everyone.
        if (thread_id < rank && wg_first_batch < batch_size) {
            var sum = 0.0f;
            for (var i = 0u; i < in_dim; i++) {
                sum += input[wg_first_batch * in_dim + i] * lora_A[i * rank + thread_id];
            }
            intermediate[thread_id] = sum;
        }
        workgroupBarrier();
        if (active) {
            // Phase 2: intermediate @ B[:, out_idx], scaled.
            var lora_output = 0.0f;
            for (var r = 0u; r < rank; r++) {
                lora_output += intermediate[r] * lora_B[r * out_dim + out_idx];
            }
            output[batch_idx * out_dim + out_idx] = lora_output * scaling;
        }
    } else {
        // Fallback path: workgroup straddles batch rows; each thread
        // recomputes its own row of (input @ A) in registers. Redundant
        // work, but correct for any out_dim / batch_size combination.
        if (active) {
            var lora_output = 0.0f;
            for (var r = 0u; r < rank; r++) {
                var sum = 0.0f;
                for (var i = 0u; i < in_dim; i++) {
                    sum += input[batch_idx * in_dim + i] * lora_A[i * rank + r];
                }
                lora_output += sum * lora_B[r * out_dim + out_idx];
            }
            output[batch_idx * out_dim + out_idx] = lora_output * scaling;
        }
    }
}
// Fused LoRA with base weight: output = (input @ W) + scaling * (input @ A @ B).
// More efficient when the base weights are available to the same pass.
// NOTE(review): placeholder — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_fused(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}
// Batched LoRA for multiple adapters (multi-task serving): each batch
// element may use different LoRA weights, enabling several fine-tuned
// models to share one batch.
// NOTE(review): placeholder — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_batched_lora(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}
// Quantized LoRA (int4 weights): A and B stored as int4 with scale factors,
// dequantized on-the-fly — significant memory savings for large rank or
// many adapters.
// NOTE(review): placeholder — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_quantized(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}
// DoRA (Weight-Decomposed Low-Rank Adaptation):
//   output = m * (W + scaling * A @ B) / ||W + scaling * A @ B||
// where m is a learned per-column magnitude.
// NOTE(review): placeholder — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_dora(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}

View File

@@ -0,0 +1,102 @@
#version 300 es
//! Matrix Multiplication Fragment Shader
//!
//! Computes C = A * B using texture-based GPU compute: the output texture
//! is rendered via a framebuffer attachment and each fragment computes one
//! element of C.
//!
//! ## Texture Layout (all row-major, R32F single-channel)
//!
//! - A: rows = M, cols = K
//! - B: rows = K, cols = N
//! - C: rows = M, cols = N (output)
//!
//! ## Performance Notes
//!
//! - NEAREST filtering is required so each texture() call fetches exactly
//!   one texel
//! - Power-of-two texture sizes tend to perform best
//! - Define UNROLL_4 at shader-compile time to enable the unrolled
//!   small-K path in main()
precision highp float;
// Input matrices as textures
uniform sampler2D u_A;
uniform sampler2D u_B;
// Matrix dimensions: (M, K, N). A is MxK, B is KxN, C is MxN.
uniform vec3 u_dims;
// Normalized [0,1] texture coordinate interpolated from the vertex shader.
in vec2 v_texcoord;
// Output value (single float stored in R channel)
out float fragColor;
void main() {
    float M = u_dims.x;
    float K = u_dims.y;
    float N = u_dims.z;
    // Map the normalized fragment coordinate to an output element (i, j).
    // Fragment centers give v_texcoord < 1.0, so floor() lands in range.
    float i = floor(v_texcoord.y * M);
    float j = floor(v_texcoord.x * N);
    // Bounds check (fragments outside the valid range output 0).
    if (i >= M || j >= N) {
        fragColor = 0.0;
        return;
    }
    // Dot product of row i of A with column j of B.
    float sum = 0.0;
    // Optional unrolled path for K <= 4 — helps mobile GPUs with limited
    // loop support. Compiled in only when UNROLL_4 is defined; samples at
    // fixed texel centers (k + 0.5) / K for k = 0..3.
    #if defined(UNROLL_4)
    if (K <= 4.0) {
        if (K >= 1.0) {
            float a0 = texture(u_A, vec2(0.5 / K, (i + 0.5) / M)).r;
            float b0 = texture(u_B, vec2((j + 0.5) / N, 0.5 / K)).r;
            sum += a0 * b0;
        }
        if (K >= 2.0) {
            float a1 = texture(u_A, vec2(1.5 / K, (i + 0.5) / M)).r;
            float b1 = texture(u_B, vec2((j + 0.5) / N, 1.5 / K)).r;
            sum += a1 * b1;
        }
        if (K >= 3.0) {
            float a2 = texture(u_A, vec2(2.5 / K, (i + 0.5) / M)).r;
            float b2 = texture(u_B, vec2((j + 0.5) / N, 2.5 / K)).r;
            sum += a2 * b2;
        }
        if (K >= 4.0) {
            float a3 = texture(u_A, vec2(3.5 / K, (i + 0.5) / M)).r;
            float b3 = texture(u_B, vec2((j + 0.5) / N, 3.5 / K)).r;
            sum += a3 * b3;
        }
    } else
    #endif
    {
        // General loop for arbitrary K. The +0.5 centers each sample
        // within its texel (requires NEAREST filtering for exact fetch).
        for (float k = 0.0; k < K; k += 1.0) {
            // A[i, k]: x = (k + 0.5) / K, y = (i + 0.5) / M
            float a_val = texture(u_A, vec2((k + 0.5) / K, (i + 0.5) / M)).r;
            // B[k, j]: x = (j + 0.5) / N, y = (k + 0.5) / K
            float b_val = texture(u_B, vec2((j + 0.5) / N, (k + 0.5) / K)).r;
            sum += a_val * b_val;
        }
    }
    fragColor = sum;
}

View File

@@ -0,0 +1,171 @@
// Tiled Matrix Multiplication Shader
//
// Computes C = A * B using 128x128 tiles for cache efficiency.
// Targets 10+ TFLOPS on discrete GPUs.
//
// Algorithm:
// 1. Each workgroup computes a TILE_SIZE x TILE_SIZE block of C
// 2. A and B are loaded into shared memory in tiles
// 3. Each thread computes a 4x4 subblock for register tiling
// 4. Accumulation happens in registers, then written to C
//
// Memory Layout:
// - A: M x K matrix (row-major)
// - B: K x N matrix (row-major)
// - C: M x N matrix (row-major, output)
// Tile dimensions (must match host code)
const TILE_SIZE: u32 = 128u; // C tile per workgroup: TILE_SIZE x TILE_SIZE
const BLOCK_SIZE: u32 = 16u; // Threads per dimension in workgroup (16x16 = 256)
const THREAD_TILE: u32 = 8u; // Each thread computes an 8x8 sub-block of C
// Uniforms
struct Uniforms {
    M: u32, // Rows of A, rows of C
    N: u32, // Cols of B, cols of C
    K: u32, // Cols of A, rows of B
    tile_size: u32, // NOTE(review): unused by the kernels in this file
}
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: Uniforms;
// Shared-memory slabs: 2048 floats each = TILE_SIZE * BLOCK_SIZE, i.e. a
// 128-row x 16-k-column slab of A and a 16-k-row x 128-column slab of B,
// so the K dimension is consumed 16 columns at a time.
var<workgroup> A_tile: array<f32, 2048>; // TILE_SIZE * BLOCK_SIZE = 128 * 16
var<workgroup> B_tile: array<f32, 2048>;
// Thread-local accumulator registers for one 8x8 C sub-block.
var<private> acc: array<f32, 64>; // THREAD_TILE * THREAD_TILE = 8 * 8
// Tiled matmul: each 16x16 workgroup computes a TILE_SIZE x TILE_SIZE block
// of C; each thread owns an 8x8 register sub-block. The K dimension is
// consumed in BLOCK_SIZE(16)-wide slabs so the A slab (128x16) and B slab
// (16x128) each fit exactly in the 2048-element shared arrays.
//
// Fixes relative to the previous revision:
// - The A_tile store index ignored the inner `j` loop (the same slot was
//   overwritten THREAD_TILE times) and the B_tile index wrapped its column
//   with `% BLOCK_SIZE`, so distinct matrix elements collided in shared
//   memory and C was computed from garbage.
// - Out-of-range loads left stale values from the previous K slab in
//   shared memory; slabs are now explicitly zero-padded.
// - The inner k loop ran over TILE_SIZE (128) while shared memory only
//   holds BLOCK_SIZE (16) K columns; it now matches the slab width.
@compute @workgroup_size(16, 16, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let M = uniforms.M;
    let N = uniforms.N;
    let K = uniforms.K;

    // Top-left corner of this workgroup's C tile.
    let block_row = group_id.x * TILE_SIZE;
    let block_col = group_id.y * TILE_SIZE;

    // Thread position within the 16x16 workgroup.
    let thread_row = local_id.x;
    let thread_col = local_id.y;

    // Zero the 8x8 register accumulator.
    for (var i = 0u; i < THREAD_TILE * THREAD_TILE; i++) {
        acc[i] = 0.0;
    }

    // K is consumed in BLOCK_SIZE-wide slabs.
    let num_k_slabs = (K + BLOCK_SIZE - 1u) / BLOCK_SIZE;

    for (var slab = 0u; slab < num_k_slabs; slab++) {
        let k_base = slab * BLOCK_SIZE;

        // Load A slab (TILE_SIZE rows x BLOCK_SIZE k-cols, row-major).
        // Thread (r, c) loads tile rows r*8 .. r*8+7 at k-column c:
        // 256 threads x 8 loads = 2048 = slab size. Out-of-range -> 0.
        for (var i = 0u; i < THREAD_TILE; i++) {
            let tile_row = thread_row * THREAD_TILE + i;
            let a_row = block_row + tile_row;
            let a_col = k_base + thread_col;
            var value = 0.0;
            if (a_row < M && a_col < K) {
                value = A[a_row * K + a_col];
            }
            A_tile[tile_row * BLOCK_SIZE + thread_col] = value;
        }

        // Load B slab (BLOCK_SIZE k-rows x TILE_SIZE cols, row-major).
        // Thread (r, c) loads k-row r, tile columns c*8 .. c*8+7.
        for (var j = 0u; j < THREAD_TILE; j++) {
            let tile_col = thread_col * THREAD_TILE + j;
            let b_row = k_base + thread_row;
            let b_col = block_col + tile_col;
            var value = 0.0;
            if (b_row < K && b_col < N) {
                value = B[b_row * N + b_col];
            }
            B_tile[thread_row * TILE_SIZE + tile_col] = value;
        }

        // Ensure the whole slab is resident before anyone reads it.
        workgroupBarrier();

        // Register-tiled outer-product accumulation over the slab's K
        // columns. Zero padding makes the full BLOCK_SIZE loop safe even
        // when K is not a multiple of 16.
        for (var k = 0u; k < BLOCK_SIZE; k++) {
            var a_regs: array<f32, 8>;
            for (var i = 0u; i < THREAD_TILE; i++) {
                a_regs[i] = A_tile[(thread_row * THREAD_TILE + i) * BLOCK_SIZE + k];
            }
            var b_regs: array<f32, 8>;
            for (var j = 0u; j < THREAD_TILE; j++) {
                b_regs[j] = B_tile[k * TILE_SIZE + thread_col * THREAD_TILE + j];
            }
            for (var i = 0u; i < THREAD_TILE; i++) {
                for (var j = 0u; j < THREAD_TILE; j++) {
                    acc[i * THREAD_TILE + j] += a_regs[i] * b_regs[j];
                }
            }
        }

        // All threads must finish reading before the next slab overwrites
        // shared memory.
        workgroupBarrier();
    }

    // Write back the 8x8 sub-block, clipped to the matrix bounds.
    for (var i = 0u; i < THREAD_TILE; i++) {
        let c_row = block_row + thread_row * THREAD_TILE + i;
        for (var j = 0u; j < THREAD_TILE; j++) {
            let c_col = block_col + thread_col * THREAD_TILE + j;
            if (c_row < M && c_col < N) {
                C[c_row * N + c_col] = acc[i * THREAD_TILE + j];
            }
        }
    }
}
// Quantized int8 matrix multiplication variant: packed i8x4 inputs with
// int32 accumulation, scaled by quantization factors to f32 at the end.
// NOTE(review): placeholder — the body is intentionally empty.
@compute @workgroup_size(16, 16, 1)
fn main_int8(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Placeholder: no-op.
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,751 @@
//! Tensor abstraction layer for unified compute operations
//!
//! Provides a minimal tensor abstraction that works across all compute backends
//! (WebGPU, WebGL2, SIMD, WebWorkers, and naive fallback).
use serde::{Deserialize, Serialize};
use std::fmt;
/// Data type for tensor elements.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
    /// 32-bit floating point (the default element type)
    F32,
    /// 16-bit floating point (for WebGPU)
    F16,
    /// 8-bit signed integer (for quantized models)
    I8,
    /// 8-bit unsigned integer (for embeddings)
    U8,
    /// Binary (for HDC hypervectors; storage packs 64 elements per u64)
    Binary,
}
impl DType {
    /// Size in bytes of a single element of this data type.
    ///
    /// `Binary` reports 1 byte per element for bookkeeping purposes, even
    /// though binary storage is bit-packed (8 elements per byte).
    pub fn size_bytes(&self) -> usize {
        match *self {
            DType::F16 => 2,
            DType::F32 => 4,
            DType::I8 | DType::U8 | DType::Binary => 1,
        }
    }
}
impl Default for DType {
fn default() -> Self {
DType::F32
}
}
/// Tensor shape holding a variable number of dimensions (typically 1-4).
///
/// Row-major (C-order) layout is assumed by `strides()`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Shape {
    // Dimension sizes, outermost axis first.
    dims: Vec<usize>,
}
impl Shape {
    /// Create a new shape from a slice of dimensions.
    pub fn new(dims: &[usize]) -> Self {
        Self { dims: dims.to_vec() }
    }
    /// 1D shape (vector of length `n`).
    pub fn d1(n: usize) -> Self {
        Self { dims: vec![n] }
    }
    /// 2D shape (matrix).
    pub fn d2(rows: usize, cols: usize) -> Self {
        Self { dims: vec![rows, cols] }
    }
    /// 3D shape (batch of matrices).
    pub fn d3(batch: usize, rows: usize, cols: usize) -> Self {
        Self { dims: vec![batch, rows, cols] }
    }
    /// 4D shape (e.g. attention tensors: batch, heads, seq, dim).
    pub fn d4(b: usize, h: usize, s: usize, d: usize) -> Self {
        Self { dims: vec![b, h, s, d] }
    }
    /// Total number of elements. A zero-dimensional shape yields 1 by the
    /// empty-product convention (scalar).
    pub fn numel(&self) -> usize {
        self.dims.iter().product()
    }
    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.dims.len()
    }
    /// Dimension at `idx`; out-of-range indices read as 1, so missing axes
    /// behave like singleton dimensions.
    pub fn dim(&self, idx: usize) -> usize {
        self.dims.get(idx).copied().unwrap_or(1)
    }
    /// All dimensions as a slice.
    pub fn dims(&self) -> &[usize] {
        &self.dims
    }
    /// Check matmul compatibility: the last dim of `self` must equal the
    /// second-to-last dim of `other` (or its only dim when `other` is 1D).
    pub fn matmul_compatible(&self, other: &Shape) -> bool {
        if self.ndim() < 1 || other.ndim() < 1 {
            return false;
        }
        let self_k = self.dim(self.ndim() - 1);
        let other_k = if other.ndim() >= 2 {
            other.dim(other.ndim() - 2)
        } else {
            other.dim(0)
        };
        self_k == other_k
    }
    /// Row-major (C-order) strides, in elements.
    ///
    /// Returns an empty vector for a zero-dimensional shape. (The previous
    /// implementation computed `0..self.dims.len() - 1`, which overflows
    /// `usize` and panics when `dims` is empty.)
    pub fn strides(&self) -> Vec<usize> {
        let mut strides = vec![1; self.dims.len()];
        for i in (0..self.dims.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * self.dims[i + 1];
        }
        strides
    }
}
impl fmt::Display for Shape {
    /// Formats the shape as a parenthesized, comma-separated tuple,
    /// e.g. `(2, 3, 4)` or `()` for a zero-dimensional shape.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("(")?;
        let mut sep = "";
        for d in &self.dims {
            write!(f, "{}{}", sep, d)?;
            sep = ", ";
        }
        f.write_str(")")
    }
}
/// Memory layout for tensors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Layout {
    /// Row-major (C-style); the default and most common layout
    RowMajor,
    /// Column-major (Fortran-style)
    ColMajor,
    /// Strided (non-contiguous, e.g. views/slices)
    Strided,
}
impl Default for Layout {
    /// Row-major is the default layout.
    fn default() -> Self {
        Layout::RowMajor
    }
}
/// Tensor storage - holds the actual data.
///
/// CPU variants own their buffers; GPU/shared variants are opaque handles
/// whose backing memory is managed by the corresponding backend.
#[derive(Clone, Debug)]
pub enum TensorStorage {
    /// CPU storage (f32 buffer; also backs F16 tensors, held as f32)
    Cpu(Vec<f32>),
    /// Quantized storage: (int8 data, dequantization scale)
    Quantized(Vec<i8>, f32),
    /// Binary storage for HDC hypervectors, bit-packed 64 bits per u64
    Binary(Vec<u64>),
    /// GPU buffer reference (opaque WebGPU buffer ID)
    GpuBuffer(u32),
    /// Shared memory reference for WebWorkers (SharedArrayBuffer ID)
    SharedBuffer(u32),
}
impl TensorStorage {
    /// Size of the owned backing buffer in bytes.
    ///
    /// Remote buffers (GPU / shared memory) report 0: their size is
    /// tracked by the owning backend, not here.
    pub fn size_bytes(&self) -> usize {
        match self {
            TensorStorage::Cpu(v) => v.len() * std::mem::size_of::<f32>(),
            TensorStorage::Quantized(v, _) => v.len() * std::mem::size_of::<i8>(),
            TensorStorage::Binary(v) => v.len() * std::mem::size_of::<u64>(),
            TensorStorage::GpuBuffer(_) | TensorStorage::SharedBuffer(_) => 0,
        }
    }

    /// True for locally-owned CPU buffers (`Cpu` or `Quantized`).
    pub fn is_cpu(&self) -> bool {
        match self {
            TensorStorage::Cpu(_) | TensorStorage::Quantized(_, _) => true,
            _ => false,
        }
    }

    /// True for GPU-resident buffers.
    pub fn is_gpu(&self) -> bool {
        match self {
            TensorStorage::GpuBuffer(_) => true,
            _ => false,
        }
    }
}
/// Main tensor type for all compute operations.
///
/// A `Tensor` is a shape + dtype view over a `TensorStorage`. Views may carry
/// an `offset` and custom `strides`; `strides == None` means the natural
/// row-major order for `shape`.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Shape of the tensor
    shape: Shape,
    /// Data type
    dtype: DType,
    /// Memory layout
    layout: Layout,
    /// Underlying storage
    storage: TensorStorage,
    /// Offset into storage (for views)
    offset: usize,
    /// Custom strides (for non-contiguous tensors); None = contiguous order
    strides: Option<Vec<usize>>,
}
impl Tensor {
    // ========================================================================
    // Constructors
    // ========================================================================

    /// Create a tensor filled with zeros.
    ///
    /// Storage is picked from `dtype`: dense f32 for float types, i8 with a
    /// unit scale for integer types, and 64-elements-per-word bit packing for
    /// `Binary`.
    pub fn zeros(shape: Shape, dtype: DType) -> Self {
        let numel = shape.numel();
        let storage = match dtype {
            DType::F32 | DType::F16 => TensorStorage::Cpu(vec![0.0; numel]),
            DType::I8 | DType::U8 => TensorStorage::Quantized(vec![0; numel], 1.0),
            // Round up so a partial final word is still allocated.
            DType::Binary => TensorStorage::Binary(vec![0; (numel + 63) / 64]),
        };
        Self {
            shape,
            dtype,
            layout: Layout::RowMajor,
            storage,
            offset: 0,
            strides: None,
        }
    }

    /// Create a tensor filled with ones (all bits set for `Binary`).
    pub fn ones(shape: Shape, dtype: DType) -> Self {
        let numel = shape.numel();
        let storage = match dtype {
            DType::F32 | DType::F16 => TensorStorage::Cpu(vec![1.0; numel]),
            DType::I8 | DType::U8 => TensorStorage::Quantized(vec![1; numel], 1.0),
            DType::Binary => TensorStorage::Binary(vec![u64::MAX; (numel + 63) / 64]),
        };
        Self {
            shape,
            dtype,
            layout: Layout::RowMajor,
            storage,
            offset: 0,
            strides: None,
        }
    }

    /// Create an f32 tensor by copying `data`.
    ///
    /// # Panics
    /// Panics if `data.len()` does not equal `shape.numel()`.
    pub fn from_slice(data: &[f32], shape: Shape) -> Self {
        assert_eq!(
            data.len(),
            shape.numel(),
            "Data length {} doesn't match shape {}",
            data.len(),
            shape
        );
        Self {
            shape,
            dtype: DType::F32,
            layout: Layout::RowMajor,
            storage: TensorStorage::Cpu(data.to_vec()),
            offset: 0,
            strides: None,
        }
    }

    /// Create an f32 tensor that takes ownership of `data` (no copy).
    ///
    /// # Panics
    /// Panics if `data.len()` does not equal `shape.numel()`.
    pub fn from_vec(data: Vec<f32>, shape: Shape) -> Self {
        assert_eq!(
            data.len(),
            shape.numel(),
            "Data length {} doesn't match shape {}",
            data.len(),
            shape
        );
        Self {
            shape,
            dtype: DType::F32,
            layout: Layout::RowMajor,
            storage: TensorStorage::Cpu(data),
            offset: 0,
            strides: None,
        }
    }

    /// Create a tensor of uniform random values in [0, 1).
    ///
    /// Uses a fixed-seed 64-bit LCG so results are reproducible across runs.
    pub fn rand(shape: Shape) -> Self {
        let numel = shape.numel();
        let mut data = vec![0.0f32; numel];
        let mut seed = 0xDEADBEEFu64;
        for x in data.iter_mut() {
            // MMIX LCG constants; the top 31 bits of state feed the value.
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            *x = (seed >> 33) as f32 / (1u64 << 31) as f32;
        }
        Self::from_vec(data, shape)
    }

    /// Create a tensor of standard-normal random values (mean=0, std=1).
    ///
    /// Uses the Box-Muller transform over a fixed-seed LCG; each pair of
    /// uniforms yields two normal samples (the trailing odd element, if any,
    /// takes only the cosine branch).
    pub fn randn(shape: Shape) -> Self {
        let numel = shape.numel();
        let mut data = vec![0.0f32; numel];
        let mut seed = 0xCAFEBABEu64;
        for i in (0..numel).step_by(2) {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u1 = (seed >> 33) as f32 / (1u64 << 31) as f32;
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u2 = (seed >> 33) as f32 / (1u64 << 31) as f32;
            // Clamp u1 away from zero so ln() stays finite.
            let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
            let theta = 2.0 * std::f32::consts::PI * u2;
            data[i] = r * theta.cos();
            if i + 1 < numel {
                data[i + 1] = r * theta.sin();
            }
        }
        Self::from_vec(data, shape)
    }

    // ========================================================================
    // Accessors
    // ========================================================================

    /// Tensor shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Element data type.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }

    /// Memory layout tag.
    pub fn layout(&self) -> Layout {
        self.layout
    }

    /// True when the tensor is a dense, zero-offset view of its storage.
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_none() && self.offset == 0
    }

    /// Borrow the underlying storage.
    pub fn storage(&self) -> &TensorStorage {
        &self.storage
    }

    /// Borrow the data as an f32 slice, if stored densely on the CPU.
    ///
    /// NOTE(review): for a tensor with custom strides this returns a raw
    /// storage window, not logically-ordered elements — callers should use
    /// `contiguous()` first; confirm no call site relies on strided access.
    pub fn as_slice(&self) -> Option<&[f32]> {
        match &self.storage {
            TensorStorage::Cpu(data) => {
                if self.is_contiguous() {
                    Some(data.as_slice())
                } else {
                    Some(&data[self.offset..self.offset + self.numel()])
                }
            }
            _ => None,
        }
    }

    /// Mutably borrow the data as an f32 slice, if stored densely on the CPU.
    ///
    /// Same strided-view caveat as `as_slice`.
    pub fn as_mut_slice(&mut self) -> Option<&mut [f32]> {
        match &mut self.storage {
            TensorStorage::Cpu(data) => {
                if self.is_contiguous() {
                    Some(data.as_mut_slice())
                } else {
                    let start = self.offset;
                    let end = start + self.numel();
                    Some(&mut data[start..end])
                }
            }
            _ => None,
        }
    }

    /// Copy the data out as a `Vec<f32>`.
    ///
    /// Quantized storage is dequantized via its scale; GPU/shared storage
    /// cannot be read from here and yields zeros.
    pub fn to_vec(&self) -> Vec<f32> {
        match &self.storage {
            TensorStorage::Cpu(data) => {
                if self.is_contiguous() {
                    data.clone()
                } else {
                    data[self.offset..self.offset + self.numel()].to_vec()
                }
            }
            TensorStorage::Quantized(data, scale) => {
                data.iter().map(|&x| x as f32 * scale).collect()
            }
            _ => vec![0.0; self.numel()],
        }
    }

    // ========================================================================
    // Transformations
    // ========================================================================

    /// Reshape to `new_shape` without moving data.
    ///
    /// # Panics
    /// Panics if the element counts differ.
    ///
    /// NOTE(review): this drops custom strides, so reshaping a non-contiguous
    /// view reinterprets raw storage order — call `contiguous()` first;
    /// confirm callers never reshape strided views directly.
    pub fn reshape(&self, new_shape: Shape) -> Self {
        assert_eq!(
            self.numel(),
            new_shape.numel(),
            "Cannot reshape {} to {}",
            self.shape,
            new_shape
        );
        Self {
            shape: new_shape,
            dtype: self.dtype,
            layout: self.layout,
            storage: self.storage.clone(),
            offset: self.offset,
            strides: None, // Reshaping makes it contiguous
        }
    }

    /// Transpose a 2D tensor.
    ///
    /// CPU tensors are copied into a new contiguous buffer; other storage
    /// kinds get a strided view over the same storage.
    ///
    /// # Panics
    /// Panics if the tensor is not 2D.
    pub fn transpose(&self) -> Self {
        assert_eq!(self.shape.ndim(), 2, "Transpose only supports 2D tensors");
        let rows = self.shape.dim(0);
        let cols = self.shape.dim(1);
        if let TensorStorage::Cpu(data) = &self.storage {
            // Explicit copy-transpose for CPU data.
            let mut new_data = vec![0.0f32; self.numel()];
            for i in 0..rows {
                for j in 0..cols {
                    new_data[j * rows + i] = data[i * cols + j];
                }
            }
            Self::from_vec(new_data, Shape::d2(cols, rows))
        } else {
            // Strided view: element (i, j) of the transpose maps to storage
            // offset i * 1 + j * cols, because the row-major source has row
            // stride `cols`. (Fixed: this previously used `rows`, which is
            // only correct for square matrices.)
            Self {
                shape: Shape::d2(cols, rows),
                dtype: self.dtype,
                layout: Layout::Strided,
                storage: self.storage.clone(),
                offset: self.offset,
                strides: Some(vec![1, cols]),
            }
        }
    }

    /// Return a contiguous copy (or a clone if already contiguous).
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() {
            self.clone()
        } else {
            // Copy to new contiguous storage.
            Self::from_vec(self.to_vec(), self.shape.clone())
        }
    }

    /// Quantize to i8 with symmetric per-tensor scaling (scale = max|x| / 127).
    ///
    /// Fix: an all-zero (or empty) tensor previously produced a zero scale,
    /// and the `x / scale` division filled the storage with NaN. Such tensors
    /// now use a scale of 1.0 and quantize to all zeros.
    pub fn quantize(&self) -> Self {
        let data = self.to_vec();
        let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        // Guard against division by zero when every element is 0.
        let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
        let quantized: Vec<i8> = data
            .iter()
            .map(|&x| (x / scale).clamp(-127.0, 127.0) as i8)
            .collect();
        Self {
            shape: self.shape.clone(),
            dtype: DType::I8,
            layout: Layout::RowMajor,
            storage: TensorStorage::Quantized(quantized, scale),
            offset: 0,
            strides: None,
        }
    }

    /// Dequantize back to a dense f32 tensor.
    pub fn dequantize(&self) -> Self {
        Self::from_vec(self.to_vec(), self.shape.clone())
    }

    // ========================================================================
    // Size estimation
    // ========================================================================

    /// Estimated memory usage in bytes (0 for opaque GPU/shared handles).
    pub fn size_bytes(&self) -> usize {
        self.storage.size_bytes()
    }
}
/// LoRA adapter for efficient fine-tuning.
///
/// Stores the two low-rank factors; the effective weight delta applied to the
/// target layer is `scaling * (x @ a @ b)` per the construction in `new`.
#[derive(Clone, Debug)]
pub struct LoraAdapter {
    /// Low-rank A matrix (d x r)
    pub a: Tensor,
    /// Low-rank B matrix (r x d)
    pub b: Tensor,
    /// Scaling factor (alpha / rank)
    pub scaling: f32,
    /// Target layer name
    pub target: String,
}
impl LoraAdapter {
    /// Create a new LoRA adapter for layer `target`.
    ///
    /// Per the LoRA paper, A is drawn from a standard normal and B starts at
    /// zero, so the adapter initially contributes a zero delta.
    ///
    /// NOTE(review): `rank == 0` makes `scaling` infinite — callers are
    /// assumed to pass rank >= 1; confirm upstream validation.
    pub fn new(input_dim: usize, output_dim: usize, rank: usize, alpha: f32, target: &str) -> Self {
        Self {
            a: Tensor::randn(Shape::d2(input_dim, rank)),
            b: Tensor::zeros(Shape::d2(rank, output_dim), DType::F32),
            scaling: alpha / rank as f32,
            target: target.to_string(),
        }
    }
    /// Rank of the low-rank factorization (column count of A).
    pub fn rank(&self) -> usize {
        self.a.shape().dim(1)
    }
    /// Input dimension (row count of A).
    pub fn input_dim(&self) -> usize {
        self.a.shape().dim(0)
    }
    /// Output dimension (column count of B).
    pub fn output_dim(&self) -> usize {
        self.b.shape().dim(1)
    }
}
/// Workload classification for backend selection.
///
/// Produced by `WorkloadType::classify` from operand shapes; size thresholds
/// are in total element counts, not FLOPs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkloadType {
    /// Small matmul (< 1K elements)
    SmallMatmul,
    /// Medium matmul (1K - 100K elements)
    MediumMatmul,
    /// Large matmul (> 100K elements)
    LargeMatmul,
    /// Attention mechanism
    Attention,
    /// Element-wise operation
    Elementwise,
    /// Reduction (sum, mean, etc.)
    Reduction,
    /// Sparse operation (> 50% zeros)
    Sparse,
    /// Batch inference
    BatchInference,
    /// LoRA forward pass
    LoraForward,
}
impl WorkloadType {
    /// Classify a workload from tensor shapes.
    ///
    /// With two operands the decision is based on the combined element count,
    /// except that a >= 3-D `a` whose last two dims are equal is treated as
    /// attention (square inner dimensions). A single operand is classified by
    /// its size alone.
    pub fn classify(a: &Tensor, b: Option<&Tensor>) -> Self {
        let numel_a = a.numel();
        match b {
            Some(b_tensor) => {
                let nd = a.shape().ndim();
                let square_inner = nd >= 3 && a.shape().dim(nd - 2) == a.shape().dim(nd - 1);
                if square_inner {
                    return WorkloadType::Attention;
                }
                match numel_a + b_tensor.numel() {
                    t if t < 1_000 => WorkloadType::SmallMatmul,
                    t if t < 100_000 => WorkloadType::MediumMatmul,
                    _ => WorkloadType::LargeMatmul,
                }
            }
            None => {
                // Unary: small tensors are element-wise, large ones reductions.
                if numel_a < 1_000 {
                    WorkloadType::Elementwise
                } else {
                    WorkloadType::Reduction
                }
            }
        }
    }
    /// Rough FLOP estimate for this workload over `numel` elements.
    pub fn estimated_flops(&self, numel: usize) -> u64 {
        let n = numel as u64;
        match self {
            WorkloadType::SmallMatmul
            | WorkloadType::MediumMatmul
            | WorkloadType::LargeMatmul => n * 2,
            WorkloadType::Attention => n * 4, // Q*K + softmax + *V
            WorkloadType::Elementwise => n,
            WorkloadType::Reduction => n,
            WorkloadType::Sparse => n / 2, // Assumes 50% sparsity
            WorkloadType::BatchInference => n * 10,
            WorkloadType::LoraForward => n * 4, // A*x + B*(A*x)
        }
    }
}
/// Sparsity analysis for tensors.
///
/// Produced by `SparsityInfo::analyze`; "structured" means whole fixed-size
/// blocks of the flat data are zero, which sparse kernels can exploit.
#[derive(Clone, Debug)]
pub struct SparsityInfo {
    /// Fraction of zero elements
    pub sparsity: f32,
    /// Is structured sparsity (blocks of zeros)?
    pub is_structured: bool,
    /// Block size if structured
    pub block_size: Option<usize>,
}
impl SparsityInfo {
    /// Analyze the sparsity pattern of a tensor.
    ///
    /// Reports the fraction of exactly-zero elements and probes for block
    /// (structured) sparsity: the smallest block size in {4, 8, 16, 32} for
    /// which more than 30% of flat chunks are entirely zero.
    ///
    /// Fix: an empty tensor previously computed `0 / 0`, yielding a NaN
    /// sparsity; it now returns 0.0 / unstructured.
    pub fn analyze(tensor: &Tensor) -> Self {
        let data = tensor.to_vec();
        let total = data.len();
        // Guard the zero-length case before any division.
        if total == 0 {
            return Self {
                sparsity: 0.0,
                is_structured: false,
                block_size: None,
            };
        }
        let zeros = data.iter().filter(|&&x| x == 0.0).count();
        let sparsity = zeros as f32 / total as f32;
        // Structured-sparsity probe: require at least 4 blocks of material so
        // the 30% threshold is statistically meaningful.
        let mut is_structured = false;
        let mut detected_block = None;
        for &block in &[4usize, 8, 16, 32] {
            if total < block * 4 {
                continue;
            }
            // chunks() yields ceil(total / block) chunks.
            let total_blocks = (total + block - 1) / block;
            let block_zeros = data
                .chunks(block)
                .filter(|chunk| chunk.iter().all(|&x| x == 0.0))
                .count();
            // If > 30% of blocks are all zeros, consider structured.
            if block_zeros as f32 / total_blocks as f32 > 0.3 {
                is_structured = true;
                detected_block = Some(block);
                break;
            }
        }
        Self {
            sparsity,
            is_structured,
            block_size: detected_block,
        }
    }
}
// Unit tests for Shape, Tensor, workload classification, quantization and LoRA.
#[cfg(test)]
mod tests {
    use super::*;
    // Shape accessors: element count, rank, per-dim sizes.
    #[test]
    fn test_shape_creation() {
        let s = Shape::d2(3, 4);
        assert_eq!(s.numel(), 12);
        assert_eq!(s.ndim(), 2);
        assert_eq!(s.dim(0), 3);
        assert_eq!(s.dim(1), 4);
    }
    // zeros() must yield an all-zero dense tensor.
    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(Shape::d2(2, 3), DType::F32);
        assert_eq!(t.numel(), 6);
        let data = t.to_vec();
        assert!(data.iter().all(|&x| x == 0.0));
    }
    // Round-trip: from_slice then to_vec preserves data.
    #[test]
    fn test_tensor_from_slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_slice(&data, Shape::d2(2, 3));
        assert_eq!(t.to_vec(), data);
    }
    // (3,4) x (4,5) is valid; (3,4) x (3,5) is not.
    #[test]
    fn test_matmul_compatible() {
        let s1 = Shape::d2(3, 4);
        let s2 = Shape::d2(4, 5);
        let s3 = Shape::d2(3, 5);
        assert!(s1.matmul_compatible(&s2));
        assert!(!s1.matmul_compatible(&s3));
    }
    // CPU transpose swaps dims and reorders data column-first.
    #[test]
    fn test_transpose() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_slice(&data, Shape::d2(2, 3));
        let t_t = t.transpose();
        assert_eq!(t_t.shape().dims(), &[3, 2]);
        assert_eq!(t_t.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }
    // Size thresholds: 2*100 elements -> small, 2*1M elements -> large.
    #[test]
    fn test_workload_classification() {
        let small = Tensor::zeros(Shape::d2(10, 10), DType::F32);
        let large = Tensor::zeros(Shape::d2(1000, 1000), DType::F32);
        assert_eq!(
            WorkloadType::classify(&small, Some(&small)),
            WorkloadType::SmallMatmul
        );
        assert_eq!(
            WorkloadType::classify(&large, Some(&large)),
            WorkloadType::LargeMatmul
        );
    }
    // Quantize/dequantize round-trip is accurate to within ~1/127 of max|x|.
    #[test]
    fn test_quantization() {
        let data = vec![0.5, -0.5, 1.0, -1.0];
        let t = Tensor::from_slice(&data, Shape::d1(4));
        let q = t.quantize();
        assert_eq!(q.dtype(), DType::I8);
        // Dequantize and check approximate equality
        let dq = q.dequantize();
        let dq_data = dq.to_vec();
        for (a, b) in data.iter().zip(dq_data.iter()) {
            assert!((a - b).abs() < 0.01);
        }
    }
    // Adapter dimensions are derived from the A and B factor shapes.
    #[test]
    fn test_lora_adapter() {
        let lora = LoraAdapter::new(128, 128, 4, 1.0, "attention.q");
        assert_eq!(lora.rank(), 4);
        assert_eq!(lora.input_dim(), 128);
        assert_eq!(lora.output_dim(), 128);
    }
}

View File

@@ -0,0 +1,353 @@
//! Core types for compute operations
//!
//! These types work without the WebGPU feature and provide
//! the interface for compute operations.
use serde::{Serialize, Deserialize};
/// Matrix storage format.
///
/// Row-major is the default (see the `Default` impl below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MatrixLayout {
    /// Row-major storage (C-style)
    RowMajor,
    /// Column-major storage (Fortran-style)
    ColMajor,
}
impl Default for MatrixLayout {
fn default() -> Self {
Self::RowMajor
}
}
/// Data type for compute operations.
///
/// Float and quantized types are disjoint; see `is_float` / `is_quantized`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    /// 32-bit floating point
    F32,
    /// 16-bit floating point
    F16,
    /// 16-bit brain floating point
    BF16,
    /// 8-bit signed integer
    I8,
    /// 8-bit unsigned integer
    U8,
    /// 4-bit integer (packed, 2 per byte)
    I4,
}
impl DataType {
/// Get size in bytes
pub fn size_bytes(&self) -> usize {
match self {
Self::F32 => 4,
Self::F16 | Self::BF16 => 2,
Self::I8 | Self::U8 => 1,
Self::I4 => 1, // 2 values per byte, but minimum addressable is 1
}
}
/// Check if this is a floating point type
pub fn is_float(&self) -> bool {
matches!(self, Self::F32 | Self::F16 | Self::BF16)
}
/// Check if this is a quantized type
pub fn is_quantized(&self) -> bool {
matches!(self, Self::I8 | Self::U8 | Self::I4)
}
}
/// Tensor descriptor for GPU buffers.
///
/// Describes shape, element type, and layout of a buffer without owning any
/// data; `strides == None` means contiguous storage.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TensorDescriptor {
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Data type
    pub dtype: DataType,
    /// Storage layout
    pub layout: MatrixLayout,
    /// Stride between elements (None = contiguous)
    pub strides: Option<Vec<usize>>,
}
impl TensorDescriptor {
    /// Create a contiguous, row-major descriptor.
    pub fn new(shape: Vec<usize>, dtype: DataType) -> Self {
        Self {
            shape,
            dtype,
            layout: MatrixLayout::RowMajor,
            strides: None,
        }
    }
    /// Total number of elements across all dimensions.
    pub fn numel(&self) -> usize {
        self.shape.iter().copied().product()
    }
    /// Nominal size in bytes.
    ///
    /// NOTE(review): for `I4` this counts one byte per element, i.e. twice
    /// the packed size — confirm whether callers expect packed bytes.
    pub fn size_bytes(&self) -> usize {
        self.dtype.size_bytes() * self.numel()
    }
    /// Contiguous means no explicit strides are recorded.
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_none()
    }
    /// Number of dimensions (rank).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
    /// Convenience constructor for a 2-D matrix.
    pub fn matrix(rows: usize, cols: usize, dtype: DataType) -> Self {
        Self::new(vec![rows, cols], dtype)
    }
    /// Convenience constructor for a (batch, seq, hidden) 3-D tensor.
    pub fn tensor3d(batch: usize, seq: usize, hidden: usize, dtype: DataType) -> Self {
        Self::new(vec![batch, seq, hidden], dtype)
    }
}
/// LoRA adapter configuration.
///
/// Only configuration — the A/B matrices themselves live elsewhere; `a_size`
/// and `b_size` give their element counts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank of the adaptation (typically 2-64)
    pub rank: usize,
    /// Alpha scaling factor
    pub alpha: f32,
    /// Input dimension
    pub in_dim: usize,
    /// Output dimension
    pub out_dim: usize,
    /// Dropout rate (0.0 = no dropout)
    pub dropout: f32,
}
impl LoraConfig {
    /// Create a LoRA config with the default alpha (= rank) and no dropout.
    ///
    /// NOTE(review): `rank == 0` makes `scaling()` divide by zero — callers
    /// are assumed to pass rank >= 1; confirm upstream validation.
    pub fn new(rank: usize, in_dim: usize, out_dim: usize) -> Self {
        Self {
            rank,
            alpha: rank as f32,
            in_dim,
            out_dim,
            dropout: 0.0,
        }
    }
    /// Output scaling factor: alpha / rank.
    pub fn scaling(&self) -> f32 {
        self.alpha / self.rank as f32
    }
    /// Element count of the A matrix (in_dim x rank).
    pub fn a_size(&self) -> usize {
        self.rank * self.in_dim
    }
    /// Element count of the B matrix (rank x out_dim).
    pub fn b_size(&self) -> usize {
        self.out_dim * self.rank
    }
}
/// Attention configuration.
///
/// `scale == None` means the conventional 1/sqrt(head_dim) is used (see
/// `get_scale`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Number of attention heads
    pub num_heads: usize,
    /// Dimension per head
    pub head_dim: usize,
    /// Maximum sequence length
    pub max_seq_len: usize,
    /// Use causal (autoregressive) masking
    pub causal: bool,
    /// Attention dropout rate
    pub dropout: f32,
    /// Scale factor (None = 1/sqrt(head_dim))
    pub scale: Option<f32>,
    /// Use flash attention algorithm
    pub flash: bool,
}
impl AttentionConfig {
    /// Create a causal, flash-enabled attention config with no dropout and
    /// the default (1/sqrt(head_dim)) scale.
    pub fn new(num_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
        Self {
            num_heads,
            head_dim,
            max_seq_len,
            causal: true,
            dropout: 0.0,
            scale: None,
            flash: true,
        }
    }
    /// Model hidden size: heads times per-head dimension.
    pub fn hidden_dim(&self) -> usize {
        self.head_dim * self.num_heads
    }
    /// Softmax scale; defaults to 1/sqrt(head_dim) when unset.
    pub fn get_scale(&self) -> f32 {
        match self.scale {
            Some(s) => s,
            None => 1.0 / (self.head_dim as f32).sqrt(),
        }
    }
}
/// Quantization configuration for int8/int4 operations.
///
/// Defaults to symmetric per-channel int8 (see the `Default` impl).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuantConfig {
    /// Target data type
    pub dtype: DataType,
    /// Per-channel vs per-tensor quantization
    pub per_channel: bool,
    /// Symmetric quantization (zero_point = 0)
    pub symmetric: bool,
    /// Group size for group quantization (0 = no grouping)
    pub group_size: usize,
}
impl Default for QuantConfig {
fn default() -> Self {
Self {
dtype: DataType::I8,
per_channel: true,
symmetric: true,
group_size: 0,
}
}
}
impl QuantConfig {
    /// Symmetric per-channel int8 quantization (same as `Default`).
    pub fn int8() -> Self {
        Self::default()
    }
    /// Symmetric int4 quantization with per-group scales.
    pub fn int4_grouped(group_size: usize) -> Self {
        Self {
            dtype: DataType::I4,
            per_channel: false,
            group_size,
            ..Self::default()
        }
    }
}
/// Buffer usage flags for GPU memory.
///
/// Mirrors the usage bits a GPU API needs at buffer creation time; the
/// constructors on the impl build the common flag combinations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BufferUsage {
    /// Buffer can be mapped for CPU reads
    pub map_read: bool,
    /// Buffer can be mapped for CPU writes
    pub map_write: bool,
    /// Buffer can be the source of a copy
    pub copy_src: bool,
    /// Buffer can be the destination of a copy
    pub copy_dst: bool,
    /// Buffer is bindable as shader storage
    pub storage: bool,
    /// Buffer is bindable as a uniform
    pub uniform: bool,
}
impl Default for BufferUsage {
    /// Default: a storage buffer that can receive copies (copy_dst + storage).
    fn default() -> Self {
        Self {
            copy_dst: true,
            storage: true,
            map_read: false,
            map_write: false,
            copy_src: false,
            uniform: false,
        }
    }
}
impl BufferUsage {
    /// All flags cleared; base for the specialized constructors below.
    const NONE: Self = Self {
        map_read: false,
        map_write: false,
        copy_src: false,
        copy_dst: false,
        storage: false,
        uniform: false,
    };
    /// Staging buffer for CPU -> GPU uploads (mappable for write, copy source).
    pub fn staging_upload() -> Self {
        Self {
            map_write: true,
            copy_src: true,
            ..Self::NONE
        }
    }
    /// Staging buffer for GPU -> CPU downloads (copy target, mappable for read).
    pub fn staging_download() -> Self {
        Self {
            map_read: true,
            copy_dst: true,
            ..Self::NONE
        }
    }
    /// General compute-shader storage buffer (copyable in both directions).
    pub fn storage() -> Self {
        Self {
            storage: true,
            copy_src: true,
            copy_dst: true,
            ..Self::NONE
        }
    }
    /// Small read-only uniform buffer (copy target only).
    pub fn uniform() -> Self {
        Self {
            uniform: true,
            copy_dst: true,
            ..Self::NONE
        }
    }
}
// Unit tests for the core compute types.
#[cfg(test)]
mod tests {
    use super::*;
    // Element sizes per DataType.
    #[test]
    fn test_data_type_size() {
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::F16.size_bytes(), 2);
        assert_eq!(DataType::I8.size_bytes(), 1);
    }
    // Descriptor element count, byte size, and rank.
    #[test]
    fn test_tensor_descriptor() {
        let desc = TensorDescriptor::matrix(1024, 768, DataType::F32);
        assert_eq!(desc.numel(), 1024 * 768);
        assert_eq!(desc.size_bytes(), 1024 * 768 * 4);
        assert_eq!(desc.ndim(), 2);
    }
    // Default alpha == rank implies scaling of 1.0.
    #[test]
    fn test_lora_config() {
        let config = LoraConfig::new(4, 768, 768);
        assert_eq!(config.rank, 4);
        assert!((config.scaling() - 1.0).abs() < 0.001);
        assert_eq!(config.a_size(), 768 * 4);
        assert_eq!(config.b_size(), 4 * 768);
    }
    // 12 heads x 64 dims = 768 hidden; default scale is 1/sqrt(64) = 0.125.
    #[test]
    fn test_attention_config() {
        let config = AttentionConfig::new(12, 64, 4096);
        assert_eq!(config.hidden_dim(), 768);
        assert!((config.get_scale() - 0.125).abs() < 0.001);
    }
}

View File

@@ -0,0 +1,696 @@
//! WebGL2 compute simulation for GPU-accelerated operations
//!
//! Uses ping-pong texture rendering for matrix operations on devices without WebGPU.
//! This approach treats textures as 2D arrays and uses fragment shaders for computation.
//!
//! ## Architecture
//!
//! ```text
//! +-------------+ +----------------+ +-------------+
//! | Input A | --> | Fragment | --> | Output |
//! | (Texture) | | Shader | | (Texture) |
//! +-------------+ +----------------+ +-------------+
//! ^ | |
//! | v v
//! +-------------+ +----------------+ +-------------+
//! | Input B | --> | Transform | --> | CPU Read |
//! | (Texture) | | Feedback | | (Float32) |
//! +-------------+ +----------------+ +-------------+
//! ```
//!
//! ## Limitations vs WebGPU
//!
//! - No true compute shaders (uses fragment shaders)
//! - Limited to 2D texture operations
//! - Readback through transform feedback or readPixels
//! - Lower performance than WebGPU compute
use wasm_bindgen::prelude::*;
use web_sys::{
WebGl2RenderingContext, WebGlProgram, WebGlShader, WebGlTexture,
WebGlFramebuffer, WebGlBuffer, WebGlVertexArrayObject,
};
use crate::compute::tensor::{Tensor, TensorShape};
/// Shader programs for different operations.
///
/// One compiled fragment-shader program per supported operation; built once
/// in `WebGl2Compute::new` and reused for every dispatch.
struct ShaderPrograms {
    /// Matrix multiplication (C = A * B)
    matmul: WebGlProgram,
    /// Element-wise add/sub (mode selected via a uniform)
    vector_add: WebGlProgram,
    /// Element-wise mul/div (mode selected via a uniform)
    vector_mul: WebGlProgram,
    /// Row-wise softmax
    softmax: WebGlProgram,
    /// ReLU activation
    relu: WebGlProgram,
}
/// WebGL2 compute backend.
///
/// Runs "compute" as fragment-shader passes that render into R32F textures;
/// results are read back through the shared framebuffer.
#[wasm_bindgen]
pub struct WebGl2Compute {
    /// WebGL2 rendering context
    gl: WebGl2RenderingContext,
    /// Shader programs
    programs: ShaderPrograms,
    /// Texture pool for reuse
    /// NOTE(review): no code visible here checks textures in/out of this
    /// pool — confirm the pooling path exists elsewhere.
    texture_pool: Vec<TextureHandle>,
    /// Framebuffer for render-to-texture
    framebuffer: WebGlFramebuffer,
    /// Full-screen quad VAO
    quad_vao: WebGlVertexArrayObject,
    /// Quad vertex buffer
    quad_vbo: WebGlBuffer,
    /// Maximum texture size
    max_texture_size: u32,
    /// Transform feedback buffer for readback
    tf_buffer: WebGlBuffer,
}
/// Handle to a pooled texture.
struct TextureHandle {
    /// The underlying GL texture object
    texture: WebGlTexture,
    /// Texture width in texels
    width: u32,
    /// Texture height in texels
    height: u32,
    /// Whether the texture is currently checked out of the pool
    in_use: bool,
}
#[wasm_bindgen]
impl WebGl2Compute {
    /// Create a new WebGL2 compute backend.
    ///
    /// Creates a hidden 1x1 canvas, acquires a WebGL2 context configured for
    /// offscreen compute (no AA/depth/stencil), requires
    /// `EXT_color_buffer_float` so R32F textures are renderable, then compiles
    /// all shader programs and allocates the shared framebuffer, full-screen
    /// quad, and transform-feedback buffer.
    ///
    /// # Errors
    /// Returns a `JsValue` error if no window/document exists, WebGL2 or the
    /// float-render extension is unavailable, or any GL object fails to
    /// allocate.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<WebGl2Compute, JsValue> {
        let window = web_sys::window()
            .ok_or_else(|| JsValue::from_str("No window"))?;
        let document = window.document()
            .ok_or_else(|| JsValue::from_str("No document"))?;
        // Create offscreen canvas (never attached to the DOM).
        let canvas = document.create_element("canvas")?;
        let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into()?;
        canvas.set_width(1);
        canvas.set_height(1);
        // Get WebGL2 context with compute-friendly options.
        let context_options = js_sys::Object::new();
        js_sys::Reflect::set(&context_options, &"antialias".into(), &false.into())?;
        js_sys::Reflect::set(&context_options, &"depth".into(), &false.into())?;
        js_sys::Reflect::set(&context_options, &"stencil".into(), &false.into())?;
        js_sys::Reflect::set(&context_options, &"preserveDrawingBuffer".into(), &true.into())?;
        let gl: WebGl2RenderingContext = canvas
            .get_context_with_context_options("webgl2", &context_options)?
            .ok_or_else(|| JsValue::from_str("WebGL2 not available"))?
            .dyn_into()?;
        // EXT_color_buffer_float is mandatory (render-to-R32F); the linear
        // filtering extension is optional and its absence is tolerated.
        gl.get_extension("EXT_color_buffer_float")?
            .ok_or_else(|| JsValue::from_str("EXT_color_buffer_float not available"))?;
        gl.get_extension("OES_texture_float_linear")?;
        // Get max texture size (falls back to the common 4096 minimum).
        let max_texture_size = gl.get_parameter(WebGl2RenderingContext::MAX_TEXTURE_SIZE)?
            .as_f64()
            .unwrap_or(4096.0) as u32;
        // Compile all shader programs up front.
        let programs = ShaderPrograms {
            matmul: create_matmul_program(&gl)?,
            vector_add: create_vector_add_program(&gl)?,
            vector_mul: create_vector_mul_program(&gl)?,
            softmax: create_softmax_program(&gl)?,
            relu: create_relu_program(&gl)?,
        };
        // Create the shared render-to-texture framebuffer.
        let framebuffer = gl.create_framebuffer()
            .ok_or_else(|| JsValue::from_str("Failed to create framebuffer"))?;
        // Full-screen quad used to trigger one fragment per output texel.
        let (quad_vao, quad_vbo) = create_fullscreen_quad(&gl)?;
        // Transform feedback buffer (alternative readback path).
        let tf_buffer = gl.create_buffer()
            .ok_or_else(|| JsValue::from_str("Failed to create TF buffer"))?;
        Ok(WebGl2Compute {
            gl,
            programs,
            texture_pool: Vec::new(),
            framebuffer,
            quad_vao,
            quad_vbo,
            max_texture_size,
            tf_buffer,
        })
    }
    /// Check if WebGL2 compute is available.
    ///
    /// Probe-only: creates a throwaway canvas/context and reports whether a
    /// WebGL2 context with float render support can be obtained.
    #[wasm_bindgen(js_name = isAvailable)]
    pub fn is_available() -> bool {
        if let Some(window) = web_sys::window() {
            if let Some(document) = window.document() {
                if let Ok(canvas) = document.create_element("canvas") {
                    if let Ok(canvas) = canvas.dyn_into::<web_sys::HtmlCanvasElement>() {
                        if let Ok(Some(ctx)) = canvas.get_context("webgl2") {
                            if let Ok(gl) = ctx.dyn_into::<WebGl2RenderingContext>() {
                                return gl.get_extension("EXT_color_buffer_float")
                                    .map(|e| e.is_some())
                                    .unwrap_or(false);
                            }
                        }
                    }
                }
            }
        }
        false
    }
    /// Get maximum supported texture size (bounds matrix dimensions).
    #[wasm_bindgen(js_name = maxTextureSize)]
    pub fn max_texture_size(&self) -> u32 {
        self.max_texture_size
    }
}
// Non-WASM implementation
impl WebGl2Compute {
/// Perform matrix multiplication: C = A * B
///
/// Small products (m*k*n < 4096) run on the CPU; larger ones are rendered
/// through the matmul fragment shader into an R32F texture and read back.
///
/// Fix: the output texture was allocated as (width = m, height = n), which
/// disagreed with the (n, m) viewport and with the width-is-cols convention
/// used by `upload_matrix`; for non-square outputs the render target was
/// mis-sized. The texture is now n wide (cols) by m tall (rows) and the
/// readback uses the same dimensions.
pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor, JsValue> {
    if !a.shape().is_matrix() || !b.shape().is_matrix() {
        return Err(JsValue::from_str("Inputs must be matrices"));
    }
    let m = a.shape().rows();
    let k = a.shape().cols();
    let n = b.shape().cols();
    if k != b.shape().rows() {
        return Err(JsValue::from_str("Matrix dimension mismatch"));
    }
    // For small matrices, upload/readback overhead dominates; use CPU.
    if m * k * n < 4096 {
        return Ok(self.cpu_matmul(a, b));
    }
    // Upload operands; textures use width = cols, height = rows.
    let tex_a = self.upload_matrix(a)?;
    let tex_b = self.upload_matrix(b)?;
    // Output C is m x n, so its texture is n wide by m tall.
    let tex_c = self.create_texture(n as u32, m as u32)?;
    // Bind output texture to framebuffer.
    self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, Some(&self.framebuffer));
    self.gl.framebuffer_texture_2d(
        WebGl2RenderingContext::FRAMEBUFFER,
        WebGl2RenderingContext::COLOR_ATTACHMENT0,
        WebGl2RenderingContext::TEXTURE_2D,
        Some(&tex_c),
        0,
    );
    // Viewport matches the output texture: n columns wide, m rows tall.
    self.gl.viewport(0, 0, n as i32, m as i32);
    // Use matmul program.
    self.gl.use_program(Some(&self.programs.matmul));
    // Bind A and B to texture units 0 and 1.
    self.gl.active_texture(WebGl2RenderingContext::TEXTURE0);
    self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_a));
    let loc_a = self.gl.get_uniform_location(&self.programs.matmul, "u_A");
    self.gl.uniform1i(loc_a.as_ref(), 0);
    self.gl.active_texture(WebGl2RenderingContext::TEXTURE1);
    self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_b));
    let loc_b = self.gl.get_uniform_location(&self.programs.matmul, "u_B");
    self.gl.uniform1i(loc_b.as_ref(), 1);
    // Pass (m, k, n) so the shader can index rows/cols.
    let loc_dims = self.gl.get_uniform_location(&self.programs.matmul, "u_dims");
    self.gl.uniform3f(loc_dims.as_ref(), m as f32, k as f32, n as f32);
    // One full-screen quad draw computes every output element.
    self.gl.bind_vertex_array(Some(&self.quad_vao));
    self.gl.draw_arrays(WebGl2RenderingContext::TRIANGLE_STRIP, 0, 4);
    // Read back with the same (width, height) the texture was created with.
    let result = self.read_texture(&tex_c, n as u32, m as u32)?;
    // Free per-call GPU resources and unbind the framebuffer.
    self.gl.delete_texture(Some(&tex_a));
    self.gl.delete_texture(Some(&tex_b));
    self.gl.delete_texture(Some(&tex_c));
    self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, None);
    Ok(Tensor::from_vec(result, TensorShape::matrix(m, n)))
}
/// Element-wise vector operations.
///
/// `op` is one of "add", "sub", "mul", "div". Vectors shorter than 1024
/// elements are computed on the CPU; longer ones are packed into a
/// near-square R32F texture, rendered through the matching shader, and read
/// back. Division by zero follows f32 semantics (inf/NaN) on both paths.
///
/// # Errors
/// Returns an error for mismatched lengths, an unknown `op`, or GL failures.
pub fn vector_op(&self, a: &[f32], b: &[f32], op: &str) -> Result<Vec<f32>, JsValue> {
    if a.len() != b.len() {
        return Err(JsValue::from_str("Vector length mismatch"));
    }
    let len = a.len();
    // For small vectors, use CPU.
    if len < 1024 {
        return Ok(match op {
            "add" => a.iter().zip(b.iter()).map(|(x, y)| x + y).collect(),
            "sub" => a.iter().zip(b.iter()).map(|(x, y)| x - y).collect(),
            "mul" => a.iter().zip(b.iter()).map(|(x, y)| x * y).collect(),
            "div" => a.iter().zip(b.iter()).map(|(x, y)| x / y).collect(),
            _ => return Err(JsValue::from_str(&format!("Unknown op: {}", op))),
        });
    }
    // Calculate texture dimensions (square-ish) to fit `len` elements.
    let width = (len as f32).sqrt().ceil() as u32;
    let height = ((len as u32 + width - 1) / width).max(1);
    // Pad data with zeros to fill the full texture area.
    let padded_len = (width * height) as usize;
    let mut a_padded = a.to_vec();
    let mut b_padded = b.to_vec();
    a_padded.resize(padded_len, 0.0);
    b_padded.resize(padded_len, 0.0);
    // Upload to textures.
    let tex_a = self.upload_data(&a_padded, width, height)?;
    let tex_b = self.upload_data(&b_padded, width, height)?;
    let tex_c = self.create_texture(width, height)?;
    // One shader handles add/sub, another mul/div; u_mode picks the variant.
    let program = match op {
        "add" | "sub" => &self.programs.vector_add,
        "mul" | "div" => &self.programs.vector_mul,
        _ => return Err(JsValue::from_str(&format!("Unknown op: {}", op))),
    };
    // Bind the output texture for render-to-texture.
    self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, Some(&self.framebuffer));
    self.gl.framebuffer_texture_2d(
        WebGl2RenderingContext::FRAMEBUFFER,
        WebGl2RenderingContext::COLOR_ATTACHMENT0,
        WebGl2RenderingContext::TEXTURE_2D,
        Some(&tex_c),
        0,
    );
    self.gl.viewport(0, 0, width as i32, height as i32);
    self.gl.use_program(Some(program));
    // Bind input textures to units 0 and 1.
    self.gl.active_texture(WebGl2RenderingContext::TEXTURE0);
    self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_a));
    self.gl.uniform1i(self.gl.get_uniform_location(program, "u_A").as_ref(), 0);
    self.gl.active_texture(WebGl2RenderingContext::TEXTURE1);
    self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_b));
    self.gl.uniform1i(self.gl.get_uniform_location(program, "u_B").as_ref(), 1);
    // Mode 0 = first variant (add / mul), mode 1 = second (sub / div),
    // relative to whichever program was selected above.
    let op_mode = match op {
        "add" => 0.0,
        "sub" => 1.0,
        "mul" => 0.0,
        "div" => 1.0,
        _ => 0.0,
    };
    self.gl.uniform1f(self.gl.get_uniform_location(program, "u_mode").as_ref(), op_mode);
    // Draw the full-screen quad to compute all texels.
    self.gl.bind_vertex_array(Some(&self.quad_vao));
    self.gl.draw_arrays(WebGl2RenderingContext::TRIANGLE_STRIP, 0, 4);
    // Read back the padded result.
    let result = self.read_texture(&tex_c, width, height)?;
    // Cleanup per-call GPU resources.
    self.gl.delete_texture(Some(&tex_a));
    self.gl.delete_texture(Some(&tex_b));
    self.gl.delete_texture(Some(&tex_c));
    self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, None);
    // Trim the zero padding back to the original length.
    Ok(result[..len].to_vec())
}
/// Upload a 2-D tensor as an R32F texture (width = cols, height = rows).
fn upload_matrix(&self, tensor: &Tensor) -> Result<WebGlTexture, JsValue> {
    let shape = tensor.shape();
    self.upload_data(tensor.data(), shape.cols() as u32, shape.rows() as u32)
}
/// Upload data to a float texture.
///
/// Creates a `width` x `height` R32F texture with NEAREST filtering and
/// clamp-to-edge wrapping (required so texel fetches are exact, unfiltered
/// reads), and fills it from `data`. `data.len()` is expected to equal
/// `width * height` — NOTE(review): not checked here; confirm all callers
/// pad first, as `vector_op` does.
fn upload_data(&self, data: &[f32], width: u32, height: u32) -> Result<WebGlTexture, JsValue> {
    let texture = self.gl.create_texture()
        .ok_or_else(|| JsValue::from_str("Failed to create texture"))?;
    self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&texture));
    // NEAREST filtering: we read exact texel values, never interpolate.
    self.gl.tex_parameteri(
        WebGl2RenderingContext::TEXTURE_2D,
        WebGl2RenderingContext::TEXTURE_MIN_FILTER,
        WebGl2RenderingContext::NEAREST as i32,
    );
    self.gl.tex_parameteri(
        WebGl2RenderingContext::TEXTURE_2D,
        WebGl2RenderingContext::TEXTURE_MAG_FILTER,
        WebGl2RenderingContext::NEAREST as i32,
    );
    // Clamp so out-of-range coordinates do not wrap into other elements.
    self.gl.tex_parameteri(
        WebGl2RenderingContext::TEXTURE_2D,
        WebGl2RenderingContext::TEXTURE_WRAP_S,
        WebGl2RenderingContext::CLAMP_TO_EDGE as i32,
    );
    self.gl.tex_parameteri(
        WebGl2RenderingContext::TEXTURE_2D,
        WebGl2RenderingContext::TEXTURE_WRAP_T,
        WebGl2RenderingContext::CLAMP_TO_EDGE as i32,
    );
    // Create Float32Array view over the data for the upload call.
    let array = js_sys::Float32Array::from(data);
    // Upload as a single-channel R32F texture.
    self.gl.tex_image_2d_with_i32_and_i32_and_i32_and_format_and_type_and_opt_array_buffer_view(
        WebGl2RenderingContext::TEXTURE_2D,
        0,
        WebGl2RenderingContext::R32F as i32,
        width as i32,
        height as i32,
        0,
        WebGl2RenderingContext::RED,
        WebGl2RenderingContext::FLOAT,
        Some(&array),
    )?;
    Ok(texture)
}
/// Create an empty float texture.
///
/// Allocates an uninitialized `width` x `height` R32F texture with NEAREST
/// filtering, for use as a render target attached to the framebuffer.
fn create_texture(&self, width: u32, height: u32) -> Result<WebGlTexture, JsValue> {
    let texture = self.gl.create_texture()
        .ok_or_else(|| JsValue::from_str("Failed to create texture"))?;
    self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&texture));
    self.gl.tex_parameteri(
        WebGl2RenderingContext::TEXTURE_2D,
        WebGl2RenderingContext::TEXTURE_MIN_FILTER,
        WebGl2RenderingContext::NEAREST as i32,
    );
    self.gl.tex_parameteri(
        WebGl2RenderingContext::TEXTURE_2D,
        WebGl2RenderingContext::TEXTURE_MAG_FILTER,
        WebGl2RenderingContext::NEAREST as i32,
    );
    // `None` pixel data: allocate storage only; contents are written by a
    // later render pass.
    self.gl.tex_image_2d_with_i32_and_i32_and_i32_and_format_and_type_and_opt_array_buffer_view(
        WebGl2RenderingContext::TEXTURE_2D,
        0,
        WebGl2RenderingContext::R32F as i32,
        width as i32,
        height as i32,
        0,
        WebGl2RenderingContext::RED,
        WebGl2RenderingContext::FLOAT,
        None,
    )?;
    Ok(texture)
}
/// Read texture data back to CPU
///
/// Attaches `texture` to the internal framebuffer and reads all texels
/// back as f32. WebGL2's `readPixels` only guarantees RGBA readback for
/// float render targets, so four channels are read and only the R channel
/// (where our data lives) is kept.
///
/// # Errors
/// Propagates any `readPixels` failure as a `JsValue` error.
fn read_texture(&self, texture: &WebGlTexture, width: u32, height: u32) -> Result<Vec<f32>, JsValue> {
    // Bind texture to framebuffer
    self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, Some(&self.framebuffer));
    self.gl.framebuffer_texture_2d(
        WebGl2RenderingContext::FRAMEBUFFER,
        WebGl2RenderingContext::COLOR_ATTACHMENT0,
        WebGl2RenderingContext::TEXTURE_2D,
        Some(texture),
        0,
    );
    // Read pixels as RGBA (WebGL2 limitation for readPixels).
    // Fix: the previous version also allocated a `rgba_data: Vec<u8>` of
    // 16 bytes per pixel that was never used — dead allocation removed.
    let pixel_count = (width * height) as usize;
    let float_array = js_sys::Float32Array::new_with_length(pixel_count as u32 * 4);
    self.gl.read_pixels_with_array_buffer_view(
        0, 0,
        width as i32, height as i32,
        WebGl2RenderingContext::RGBA,
        WebGl2RenderingContext::FLOAT,
        &float_array,
    )?;
    // Extract R channel (our actual data)
    let mut result = Vec::with_capacity(pixel_count);
    for i in 0..pixel_count {
        result.push(float_array.get_index((i * 4) as u32));
    }
    Ok(result)
}
/// CPU fallback for small matrices
///
/// Plain row-major triple loop computing `A (m x k) * B (k x n)`.
/// Used when the matrices are too small to amortize GPU dispatch cost.
fn cpu_matmul(&self, a: &Tensor, b: &Tensor) -> Tensor {
    let (m, k) = (a.shape().rows(), a.shape().cols());
    let n = b.shape().cols();
    let lhs = a.data();
    let rhs = b.data();
    let mut out = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            // Dot product of row `row` of A with column `col` of B.
            let dot: f32 = (0..k)
                .map(|p| lhs[row * k + p] * rhs[p * n + col])
                .sum();
            out[row * n + col] = dot;
        }
    }
    Tensor::from_vec(out, TensorShape::matrix(m, n))
}
}
/// Create fullscreen quad for render-to-texture
///
/// Builds a VAO + VBO holding a 4-vertex triangle strip covering clip
/// space. Each vertex is 4 f32s: xy position followed by uv texcoord,
/// giving a 16-byte stride with the texcoord at byte offset 8.
fn create_fullscreen_quad(gl: &WebGl2RenderingContext) -> Result<(WebGlVertexArrayObject, WebGlBuffer), JsValue> {
    let vao = gl.create_vertex_array()
        .ok_or_else(|| JsValue::from_str("Failed to create VAO"))?;
    let vbo = gl.create_buffer()
        .ok_or_else(|| JsValue::from_str("Failed to create VBO"))?;
    gl.bind_vertex_array(Some(&vao));
    gl.bind_buffer(WebGl2RenderingContext::ARRAY_BUFFER, Some(&vbo));
    // Fullscreen quad vertices (position + texcoord), triangle-strip order
    let vertices: [f32; 16] = [
        -1.0, -1.0, 0.0, 0.0,
         1.0, -1.0, 1.0, 0.0,
        -1.0,  1.0, 0.0, 1.0,
         1.0,  1.0, 1.0, 1.0,
    ];
    let array = js_sys::Float32Array::from(vertices.as_slice());
    gl.buffer_data_with_array_buffer_view(
        WebGl2RenderingContext::ARRAY_BUFFER,
        &array,
        WebGl2RenderingContext::STATIC_DRAW,
    );
    // Position attribute: 2 floats at offset 0, 16-byte stride
    gl.enable_vertex_attrib_array(0);
    gl.vertex_attrib_pointer_with_i32(0, 2, WebGl2RenderingContext::FLOAT, false, 16, 0);
    // Texcoord attribute: 2 floats at byte offset 8
    gl.enable_vertex_attrib_array(1);
    gl.vertex_attrib_pointer_with_i32(1, 2, WebGl2RenderingContext::FLOAT, false, 16, 8);
    Ok((vao, vbo))
}
/// Compile a shader
fn compile_shader(gl: &WebGl2RenderingContext, shader_type: u32, source: &str) -> Result<WebGlShader, JsValue> {
let shader = gl.create_shader(shader_type)
.ok_or_else(|| JsValue::from_str("Failed to create shader"))?;
gl.shader_source(&shader, source);
gl.compile_shader(&shader);
if !gl.get_shader_parameter(&shader, WebGl2RenderingContext::COMPILE_STATUS)
.as_bool()
.unwrap_or(false)
{
let log = gl.get_shader_info_log(&shader)
.unwrap_or_else(|| "Unknown error".to_string());
gl.delete_shader(Some(&shader));
return Err(JsValue::from_str(&format!("Shader compile error: {}", log)));
}
Ok(shader)
}
/// Link a shader program
fn link_program(gl: &WebGl2RenderingContext, vertex: &WebGlShader, fragment: &WebGlShader) -> Result<WebGlProgram, JsValue> {
let program = gl.create_program()
.ok_or_else(|| JsValue::from_str("Failed to create program"))?;
gl.attach_shader(&program, vertex);
gl.attach_shader(&program, fragment);
gl.link_program(&program);
if !gl.get_program_parameter(&program, WebGl2RenderingContext::LINK_STATUS)
.as_bool()
.unwrap_or(false)
{
let log = gl.get_program_info_log(&program)
.unwrap_or_else(|| "Unknown error".to_string());
gl.delete_program(Some(&program));
return Err(JsValue::from_str(&format!("Program link error: {}", log)));
}
Ok(program)
}
/// Vertex shader for all compute operations
///
/// Passes the fullscreen-quad position straight through to clip space and
/// forwards the texture coordinate; all actual work happens in the
/// per-operation fragment shaders.
const VERTEX_SHADER: &str = r#"#version 300 es
layout(location = 0) in vec2 a_position;
layout(location = 1) in vec2 a_texcoord;
out vec2 v_texcoord;
void main() {
    gl_Position = vec4(a_position, 0.0, 1.0);
    v_texcoord = a_texcoord;
}
"#;
/// Create matrix multiplication program
///
/// The fragment shader computes one output element C[i][j] per fragment:
/// it derives (i, j) from the interpolated texcoord, then accumulates the
/// dot product over K. The `+ 0.5` offsets sample texel centers so NEAREST
/// filtering fetches exact matrix elements.
fn create_matmul_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const MATMUL_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform sampler2D u_B;
uniform vec3 u_dims;  // M, K, N
in vec2 v_texcoord;
out float fragColor;
void main() {
    float M = u_dims.x;
    float K = u_dims.y;
    float N = u_dims.z;
    // Output position
    float i = floor(v_texcoord.y * M);
    float j = floor(v_texcoord.x * N);
    float sum = 0.0;
    for (float k = 0.0; k < K; k += 1.0) {
        float a = texture(u_A, vec2((k + 0.5) / K, (i + 0.5) / M)).r;
        float b = texture(u_B, vec2((j + 0.5) / N, (k + 0.5) / K)).r;
        sum += a * b;
    }
    fragColor = sum;
}
"#;
    let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, MATMUL_FRAG)?;
    link_program(gl, &vs, &fs)
}
/// Create vector addition program
///
/// The fragment shader reads matching texels from two R32F textures and
/// writes `a + b` when `u_mode < 0.5`, otherwise `a - b`.
fn create_vector_add_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const VECTOR_ADD_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform sampler2D u_B;
uniform float u_mode;  // 0 = add, 1 = sub
in vec2 v_texcoord;
out float fragColor;
void main() {
    float a = texture(u_A, v_texcoord).r;
    float b = texture(u_B, v_texcoord).r;
    fragColor = u_mode < 0.5 ? a + b : a - b;
}
"#;
    let vertex = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fragment = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, VECTOR_ADD_FRAG)?;
    link_program(gl, &vertex, &fragment)
}
/// Create vector multiplication program
///
/// Element-wise multiply (`u_mode < 0.5`) or divide; division clamps the
/// denominator to at least 1e-7 to avoid division by zero.
fn create_vector_mul_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const VECTOR_MUL_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform sampler2D u_B;
uniform float u_mode;  // 0 = mul, 1 = div
in vec2 v_texcoord;
out float fragColor;
void main() {
    float a = texture(u_A, v_texcoord).r;
    float b = texture(u_B, v_texcoord).r;
    fragColor = u_mode < 0.5 ? a * b : a / max(b, 1e-7);
}
"#;
    let vertex = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fragment = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, VECTOR_MUL_FRAG)?;
    link_program(gl, &vertex, &fragment)
}
/// Create softmax program
///
/// NOTE(review): this shader only computes `exp(x)` per element — it does
/// NOT subtract the max (numerical stability) or divide by the sum. The
/// caller must finish the normalization on the CPU or in a second pass;
/// confirm the host-side code actually does so before relying on this.
fn create_softmax_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const SOFTMAX_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform vec2 u_size;
in vec2 v_texcoord;
out float fragColor;
void main() {
    // First pass would compute max, second pass computes exp/sum
    // This is a simplified single-pass version for small vectors
    float x = texture(u_A, v_texcoord).r;
    fragColor = exp(x);
}
"#;
    let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, SOFTMAX_FRAG)?;
    link_program(gl, &vs, &fs)
}
/// Create ReLU program
///
/// Element-wise `max(x, 0)` over an R32F texture.
fn create_relu_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const RELU_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
in vec2 v_texcoord;
out float fragColor;
void main() {
    float x = texture(u_A, v_texcoord).r;
    fragColor = max(x, 0.0);
}
"#;
    let vertex = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fragment = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, RELU_FRAG)?;
    link_program(gl, &vertex, &fragment)
}
#[cfg(test)]
mod tests {
    // WebGL tests require browser environment
    // (e.g. wasm-bindgen-test under headless Chrome/Firefox); no native
    // unit tests are possible for this module.
}

View File

@@ -0,0 +1,909 @@
//! WebGPU Compute Backend Implementation
//!
//! This module provides GPU-accelerated compute operations using wgpu.
//! It includes optimized pipelines for matrix multiplication, attention,
//! and LoRA adapter inference.
use std::collections::HashMap;
use std::sync::Arc;

use wgpu::util::DeviceExt;

use super::{
    ComputeConfig, ComputeError, ComputeMetrics,
    TensorDescriptor, DataType, LoraConfig, AttentionConfig,
    BufferUsage, MATMUL_SHADER, ATTENTION_SHADER, LORA_SHADER,
};
/// Buffer handle for GPU memory
///
/// Cloning is cheap: the underlying `wgpu::Buffer` is shared via `Arc`,
/// so clones alias the same GPU allocation.
#[derive(Clone)]
pub struct GpuBuffer {
    /// Underlying wgpu buffer (shared across clones)
    buffer: Arc<wgpu::Buffer>,
    /// Size in bytes
    size: usize,
    /// Tensor descriptor (logical shape/dtype of the contents)
    desc: TensorDescriptor,
}
impl GpuBuffer {
    /// Get buffer size in bytes
    pub fn size(&self) -> usize {
        self.size
    }
    /// Get tensor descriptor describing the buffer's logical contents
    pub fn descriptor(&self) -> &TensorDescriptor {
        &self.desc
    }
    /// Get underlying wgpu buffer (for binding / copy operations)
    pub fn raw(&self) -> &wgpu::Buffer {
        &self.buffer
    }
}
/// Compute pipeline for a specific operation
///
/// Bundles a compiled pipeline with the bind-group layout used to bind
/// its buffers at dispatch time.
struct ComputePipeline {
    // Compiled compute pipeline (shader + layout)
    pipeline: wgpu::ComputePipeline,
    // Layout describing the pipeline's buffer bindings
    bind_group_layout: wgpu::BindGroupLayout,
}
/// WebGPU compute backend for GPU-accelerated inference
///
/// Owns the wgpu device/queue plus one pre-built pipeline per supported
/// operation (matmul, attention, LoRA forward).
pub struct WebGpuCompute {
    /// GPU device handle
    device: Arc<wgpu::Device>,
    /// Command queue
    queue: Arc<wgpu::Queue>,
    /// Backend configuration
    config: ComputeConfig,
    /// Matrix multiplication pipeline
    matmul_pipeline: ComputePipeline,
    /// Attention pipeline
    attention_pipeline: ComputePipeline,
    /// LoRA forward pipeline
    lora_pipeline: ComputePipeline,
    /// Staging buffer pool for CPU<->GPU transfers
    staging_pool: StagingBufferPool,
    /// Performance metrics from last operation
    last_metrics: ComputeMetrics,
    /// Limits used to validate buffer allocations
    limits: wgpu::Limits,
}
impl WebGpuCompute {
    /// Create a new WebGPU compute backend
    ///
    /// Convenience wrapper around `with_config` using `ComputeConfig::default()`.
    pub async fn new() -> Result<Self, ComputeError> {
        Self::with_config(ComputeConfig::default()).await
    }
/// Create with custom configuration
pub async fn with_config(config: ComputeConfig) -> Result<Self, ComputeError> {
// Request adapter
let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
backends: wgpu::Backends::all(),
dx12_shader_compiler: wgpu::Dx12Compiler::Fxc,
flags: wgpu::InstanceFlags::empty(),
gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
});
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
})
.await
.ok_or_else(|| ComputeError::DeviceNotAvailable(
"No suitable GPU adapter found".to_string()
))?;
let limits = adapter.limits();
// Request device with compute capabilities
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: Some("edge-net-compute"),
required_features: wgpu::Features::empty(),
required_limits: wgpu::Limits::default(),
memory_hints: wgpu::MemoryHints::Performance,
},
None,
)
.await
.map_err(|e| ComputeError::DeviceNotAvailable(e.to_string()))?;
let device = Arc::new(device);
let queue = Arc::new(queue);
// Create compute pipelines
let matmul_pipeline = Self::create_matmul_pipeline(&device, &config)?;
let attention_pipeline = Self::create_attention_pipeline(&device, &config)?;
let lora_pipeline = Self::create_lora_pipeline(&device, &config)?;
// Create staging buffer pool
let staging_pool = StagingBufferPool::new(device.clone(), 16 * 1024 * 1024); // 16MB pool
Ok(Self {
device,
queue,
config,
matmul_pipeline,
attention_pipeline,
lora_pipeline,
staging_pool,
last_metrics: ComputeMetrics::default(),
limits,
})
}
    /// Create matrix multiplication pipeline
    ///
    /// Builds the shader module for `MATMUL_SHADER`, a bind-group layout
    /// (A, B read-only storage; C read-write storage; uniform dims), and
    /// the compute pipeline using entry point `main`.
    ///
    /// NOTE(review): `config` is unused here — presumably reserved for
    /// tile-size specialization; confirm before removing.
    fn create_matmul_pipeline(
        device: &wgpu::Device,
        config: &ComputeConfig,
    ) -> Result<ComputePipeline, ComputeError> {
        // Create shader module
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("matmul_shader"),
            source: wgpu::ShaderSource::Wgsl(MATMUL_SHADER.into()),
        });
        // Create bind group layout
        // Bindings: 0=A matrix, 1=B matrix, 2=C matrix (output), 3=uniforms
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("matmul_bind_group_layout"),
            entries: &[
                // Matrix A (read-only storage)
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Matrix B (read-only storage)
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Matrix C (read-write storage)
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Uniforms (dimensions)
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });
        // Create pipeline layout
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("matmul_pipeline_layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });
        // Create compute pipeline
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("matmul_pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
        Ok(ComputePipeline {
            pipeline,
            bind_group_layout,
        })
    }
    /// Create attention pipeline
    ///
    /// Builds the pipeline for `ATTENTION_SHADER`: Q/K/V as read-only
    /// storage, output as read-write storage, plus a uniform block.
    ///
    /// NOTE(review): `config` is unused here — presumably reserved; confirm.
    fn create_attention_pipeline(
        device: &wgpu::Device,
        config: &ComputeConfig,
    ) -> Result<ComputePipeline, ComputeError> {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("attention_shader"),
            source: wgpu::ShaderSource::Wgsl(ATTENTION_SHADER.into()),
        });
        // Bindings: 0=Q, 1=K, 2=V, 3=Output, 4=Uniforms
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("attention_bind_group_layout"),
            entries: &[
                // Q (query)
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // K (key)
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // V (value)
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Output
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Uniforms
                wgpu::BindGroupLayoutEntry {
                    binding: 4,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("attention_pipeline_layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("attention_pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
        Ok(ComputePipeline {
            pipeline,
            bind_group_layout,
        })
    }
    /// Create LoRA forward pipeline
    ///
    /// Builds the pipeline for `LORA_SHADER`: input and the low-rank A/B
    /// matrices as read-only storage, output as read-write storage, plus a
    /// uniform block.
    ///
    /// NOTE(review): `config` is unused here — presumably reserved; confirm.
    fn create_lora_pipeline(
        device: &wgpu::Device,
        config: &ComputeConfig,
    ) -> Result<ComputePipeline, ComputeError> {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("lora_shader"),
            source: wgpu::ShaderSource::Wgsl(LORA_SHADER.into()),
        });
        // Bindings: 0=Input, 1=LoRA_A, 2=LoRA_B, 3=Output, 4=Uniforms
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("lora_bind_group_layout"),
            entries: &[
                // Input
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // LoRA A matrix
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // LoRA B matrix
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Output
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // Uniforms
                wgpu::BindGroupLayoutEntry {
                    binding: 4,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("lora_pipeline_layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("lora_pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
        Ok(ComputePipeline {
            pipeline,
            bind_group_layout,
        })
    }
// ========================================================================
// Buffer Management
// ========================================================================
/// Allocate a GPU buffer
pub fn allocate_buffer(&self, desc: TensorDescriptor, usage: BufferUsage) -> Result<GpuBuffer, ComputeError> {
let size = desc.size_bytes();
// Check against device limits
if size > self.limits.max_buffer_size as usize {
return Err(ComputeError::BufferAllocationFailed {
requested: size,
available: self.limits.max_buffer_size as usize,
});
}
let mut wgpu_usage = wgpu::BufferUsages::empty();
if usage.map_read { wgpu_usage |= wgpu::BufferUsages::MAP_READ; }
if usage.map_write { wgpu_usage |= wgpu::BufferUsages::MAP_WRITE; }
if usage.copy_src { wgpu_usage |= wgpu::BufferUsages::COPY_SRC; }
if usage.copy_dst { wgpu_usage |= wgpu::BufferUsages::COPY_DST; }
if usage.storage { wgpu_usage |= wgpu::BufferUsages::STORAGE; }
if usage.uniform { wgpu_usage |= wgpu::BufferUsages::UNIFORM; }
let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("compute_buffer"),
size: size as u64,
usage: wgpu_usage,
mapped_at_creation: false,
});
Ok(GpuBuffer {
buffer: Arc::new(buffer),
size,
desc,
})
}
/// Upload data to GPU buffer
pub async fn upload_buffer(&self, buffer: &GpuBuffer, data: &[u8]) -> Result<(), ComputeError> {
if data.len() != buffer.size {
return Err(ComputeError::DimensionMismatch {
expected: format!("{} bytes", buffer.size),
actual: format!("{} bytes", data.len()),
});
}
// Use staging buffer for upload
let staging = self.staging_pool.get_upload_buffer(data.len())?;
// Write to staging buffer
self.queue.write_buffer(&staging, 0, data);
// Copy from staging to destination
let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("upload_encoder"),
});
encoder.copy_buffer_to_buffer(&staging, 0, buffer.raw(), 0, data.len() as u64);
self.queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
    /// Download data from GPU buffer
    ///
    /// Copies the buffer into a MAP_READ staging buffer, blocks until the
    /// GPU finishes (`Maintain::Wait`), then maps the staging buffer and
    /// copies the bytes out to a `Vec<u8>`.
    ///
    /// # Errors
    /// Returns `DeviceNotAvailable` if mapping the staging buffer fails.
    pub async fn download_buffer(&self, buffer: &GpuBuffer) -> Result<Vec<u8>, ComputeError> {
        let staging = self.staging_pool.get_download_buffer(buffer.size)?;
        // Copy from source to staging
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("download_encoder"),
        });
        encoder.copy_buffer_to_buffer(buffer.raw(), 0, &staging, 0, buffer.size as u64);
        self.queue.submit(std::iter::once(encoder.finish()));
        // Map staging buffer and read
        let slice = staging.slice(..);
        // map_async completes only once the device is polled; the channel
        // hands the mapping result back to this blocking context.
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });
        // Block until all queued work (the copy and the map) completes.
        self.device.poll(wgpu::Maintain::Wait);
        rx.recv().unwrap().map_err(|e| ComputeError::DeviceNotAvailable(e.to_string()))?;
        let data = slice.get_mapped_range().to_vec();
        staging.unmap();
        Ok(data)
    }
    // ========================================================================
    // Matrix Multiplication
    // ========================================================================
    /// Perform matrix multiplication: C = A * B
    ///
    /// Dimensions: A (M x K), B (K x N), C (M x N)
    ///
    /// All three buffers must already be resident on the GPU and hold f32
    /// data; no host transfers happen here.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when any buffer's byte size disagrees
    /// with M/N/K (assuming 4-byte f32 elements).
    ///
    /// NOTE(review): `std::time::Instant` panics on wasm32-unknown-unknown
    /// without a shim — confirm this backend is only compiled natively.
    ///
    /// Performance target: 10+ TFLOPS on discrete GPU
    pub async fn matmul(
        &mut self,
        a: &GpuBuffer,
        b: &GpuBuffer,
        c: &GpuBuffer,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<ComputeMetrics, ComputeError> {
        let start = std::time::Instant::now()
        // Validate dimensions
        let expected_a = (m as usize) * (k as usize) * 4; // f32
        let expected_b = (k as usize) * (n as usize) * 4;
        let expected_c = (m as usize) * (n as usize) * 4;
        if a.size != expected_a || b.size != expected_b || c.size != expected_c {
            return Err(ComputeError::DimensionMismatch {
                expected: format!("A:{}x{}, B:{}x{}, C:{}x{}", m, k, k, n, m, n),
                actual: format!("A:{}, B:{}, C:{} bytes", a.size, b.size, c.size),
            });
        }
        // Create uniforms buffer (four u32s: M, N, K, tile size)
        let uniforms = [m, n, k, self.config.tile_size];
        let uniform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("matmul_uniforms"),
            contents: bytemuck::cast_slice(&uniforms),
            usage: wgpu::BufferUsages::UNIFORM,
        });
        // Create bind group
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("matmul_bind_group"),
            layout: &self.matmul_pipeline.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: a.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: b.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: c.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buffer.as_entire_binding() },
            ],
        });
        // Dispatch compute
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("matmul_encoder"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("matmul_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.matmul_pipeline.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // Dispatch workgroups (tile-based): ceil(m / tile) x ceil(n / tile)
            let tile_size = self.config.tile_size;
            let workgroups_x = (m + tile_size - 1) / tile_size;
            let workgroups_y = (n + tile_size - 1) / tile_size;
            pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
        }
        let kernel_start = std::time::Instant::now();
        self.queue.submit(std::iter::once(encoder.finish()));
        // Block until the kernel completes so the timing below is meaningful.
        self.device.poll(wgpu::Maintain::Wait);
        let kernel_time = kernel_start.elapsed();
        let total_time = start.elapsed();
        // Calculate metrics
        let flops = 2.0 * (m as f64) * (n as f64) * (k as f64); // 2*M*N*K for matmul
        let metrics = ComputeMetrics {
            flops,
            bandwidth_gbps: ((a.size + b.size + c.size) as f64) / kernel_time.as_secs_f64() / 1e9,
            kernel_time_ms: kernel_time.as_secs_f64() * 1000.0,
            transfer_time_ms: 0.0, // Data already on GPU
            total_time_ms: total_time.as_secs_f64() * 1000.0,
        };
        self.last_metrics = metrics.clone();
        Ok(metrics)
    }
    // ========================================================================
    // Attention
    // ========================================================================
    /// Compute attention: Output = softmax(Q * K^T / sqrt(d_k)) * V
    ///
    /// Uses flash attention algorithm for memory efficiency.
    ///
    /// Q, K, V and the output must each hold `seq_len * hidden_dim` f32
    /// values on the GPU.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when Q/K/V byte sizes disagree with
    /// `seq_len * hidden_dim * 4`.
    ///
    /// Performance target: 2ms for 4K context
    pub async fn attention(
        &mut self,
        q: &GpuBuffer,
        k: &GpuBuffer,
        v: &GpuBuffer,
        output: &GpuBuffer,
        config: &AttentionConfig,
        seq_len: u32,
    ) -> Result<ComputeMetrics, ComputeError> {
        let start = std::time::Instant::now();
        // Validate dimensions
        let hidden_dim = config.hidden_dim();
        let expected_size = (seq_len as usize) * hidden_dim * 4; // f32
        if q.size != expected_size || k.size != expected_size || v.size != expected_size {
            return Err(ComputeError::DimensionMismatch {
                expected: format!("{}x{} = {} bytes", seq_len, hidden_dim, expected_size),
                actual: format!("Q:{}, K:{}, V:{} bytes", q.size, k.size, v.size),
            });
        }
        // Create uniforms buffer.
        // NOTE(review): all fields (including seq_len / num_heads / the
        // causal flag) are packed as f32 — confirm the WGSL shader reads
        // them as f32 rather than u32.
        let scale = config.get_scale();
        let causal_mask = if config.causal { 1u32 } else { 0u32 };
        let uniforms: [f32; 8] = [
            seq_len as f32,
            config.head_dim as f32,
            config.num_heads as f32,
            scale,
            causal_mask as f32,
            0.0, 0.0, 0.0, // padding
        ];
        let uniform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("attention_uniforms"),
            contents: bytemuck::cast_slice(&uniforms),
            usage: wgpu::BufferUsages::UNIFORM,
        });
        // Create bind group
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("attention_bind_group"),
            layout: &self.attention_pipeline.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: q.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: k.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: v.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: output.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: uniform_buffer.as_entire_binding() },
            ],
        });
        // Dispatch compute
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("attention_encoder"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("attention_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.attention_pipeline.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // Dispatch: one workgroup per head per batch of sequence positions
            let block_size = 64u32; // Flash attention block size
            let num_blocks = (seq_len + block_size - 1) / block_size;
            pass.dispatch_workgroups(num_blocks, config.num_heads as u32, 1);
        }
        let kernel_start = std::time::Instant::now();
        self.queue.submit(std::iter::once(encoder.finish()));
        self.device.poll(wgpu::Maintain::Wait);
        let kernel_time = kernel_start.elapsed();
        let total_time = start.elapsed();
        // Calculate metrics (attention has O(n^2*d) complexity)
        let flops = 4.0 * (seq_len as f64).powi(2) * (hidden_dim as f64);
        let metrics = ComputeMetrics {
            flops,
            bandwidth_gbps: ((q.size + k.size + v.size + output.size) as f64) / kernel_time.as_secs_f64() / 1e9,
            kernel_time_ms: kernel_time.as_secs_f64() * 1000.0,
            transfer_time_ms: 0.0,
            total_time_ms: total_time.as_secs_f64() * 1000.0,
        };
        self.last_metrics = metrics.clone();
        Ok(metrics)
    }
    // ========================================================================
    // LoRA Forward
    // ========================================================================
    /// Apply LoRA adapter: output = input + scaling * (input @ A @ B)
    ///
    /// Where A is (in_dim x rank) and B is (rank x out_dim).
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when any buffer's byte size disagrees
    /// with the dimensions implied by `config` and `batch_size`.
    ///
    /// Performance target: <1ms
    pub async fn lora_forward(
        &mut self,
        input: &GpuBuffer,
        lora_a: &GpuBuffer,
        lora_b: &GpuBuffer,
        output: &GpuBuffer,
        config: &LoraConfig,
        batch_size: u32,
    ) -> Result<ComputeMetrics, ComputeError> {
        let start = std::time::Instant::now();
        // Validate dimensions (all buffers hold 4-byte f32 elements)
        let expected_input = (batch_size as usize) * config.in_dim * 4;
        let expected_a = config.a_size() * 4;
        let expected_b = config.b_size() * 4;
        let expected_output = (batch_size as usize) * config.out_dim * 4;
        if input.size != expected_input || lora_a.size != expected_a ||
           lora_b.size != expected_b || output.size != expected_output {
            return Err(ComputeError::DimensionMismatch {
                expected: format!("input:{}x{}, A:{}x{}, B:{}x{}, output:{}x{}",
                    batch_size, config.in_dim, config.in_dim, config.rank,
                    config.rank, config.out_dim, batch_size, config.out_dim),
                actual: format!("input:{}, A:{}, B:{}, output:{} bytes",
                    input.size, lora_a.size, lora_b.size, output.size),
            });
        }
        // Create uniforms buffer (dimensions + scaling, packed as f32)
        let scaling = config.scaling();
        let uniforms: [f32; 8] = [
            batch_size as f32,
            config.in_dim as f32,
            config.rank as f32,
            config.out_dim as f32,
            scaling,
            0.0, 0.0, 0.0, // padding
        ];
        let uniform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("lora_uniforms"),
            contents: bytemuck::cast_slice(&uniforms),
            usage: wgpu::BufferUsages::UNIFORM,
        });
        // Create bind group
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_bind_group"),
            layout: &self.lora_pipeline.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: lora_a.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: lora_b.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: output.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: uniform_buffer.as_entire_binding() },
            ],
        });
        // Dispatch compute
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("lora_encoder"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("lora_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.lora_pipeline.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // Dispatch: one thread per output element, 256 threads/workgroup
            let workgroup_size = 256u32;
            let workgroups = (batch_size * config.out_dim as u32 + workgroup_size - 1) / workgroup_size;
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
        let kernel_start = std::time::Instant::now();
        self.queue.submit(std::iter::once(encoder.finish()));
        self.device.poll(wgpu::Maintain::Wait);
        let kernel_time = kernel_start.elapsed();
        let total_time = start.elapsed();
        // Calculate metrics
        // LoRA: input @ A @ B = 2 matmuls
        let flops = 2.0 * (batch_size as f64) * (config.in_dim as f64) * (config.rank as f64)
            + 2.0 * (batch_size as f64) * (config.rank as f64) * (config.out_dim as f64);
        let metrics = ComputeMetrics {
            flops,
            bandwidth_gbps: ((input.size + lora_a.size + lora_b.size + output.size) as f64)
                / kernel_time.as_secs_f64() / 1e9,
            kernel_time_ms: kernel_time.as_secs_f64() * 1000.0,
            transfer_time_ms: 0.0,
            total_time_ms: total_time.as_secs_f64() * 1000.0,
        };
        self.last_metrics = metrics.clone();
        Ok(metrics)
    }
    // ========================================================================
    // Utilities
    // ========================================================================
    /// Get last operation metrics (matmul / attention / lora_forward)
    pub fn last_metrics(&self) -> &ComputeMetrics {
        &self.last_metrics
    }
    /// Get device limits used for buffer-size validation
    pub fn limits(&self) -> &wgpu::Limits {
        &self.limits
    }
    /// Get configuration
    pub fn config(&self) -> &ComputeConfig {
        &self.config
    }
    /// Synchronize all pending GPU operations (blocks the calling thread)
    pub fn sync(&self) {
        self.device.poll(wgpu::Maintain::Wait);
    }
}
// ============================================================================
// Staging Buffer Pool
// ============================================================================
/// Pool of reusable staging buffers for CPU<->GPU transfers
///
/// NOTE(review): `upload_buffers`, `download_buffers`, and `max_pool_size`
/// are currently unused — `get_*_buffer` always allocates a fresh buffer
/// (see the impl). The fields sketch the shape of a real pooling
/// implementation to come.
struct StagingBufferPool {
/// Device handle used to allocate staging buffers.
device: Arc<wgpu::Device>,
/// Cached CPU->GPU (MAP_WRITE | COPY_SRC) buffers — not yet used.
upload_buffers: Vec<wgpu::Buffer>,
/// Cached GPU->CPU (MAP_READ | COPY_DST) buffers — not yet used.
download_buffers: Vec<wgpu::Buffer>,
/// Upper bound on cached buffers — not yet enforced.
max_pool_size: usize,
}
impl StagingBufferPool {
    /// Create an empty pool that allocates from `device`.
    fn new(device: Arc<wgpu::Device>, max_pool_size: usize) -> Self {
        Self {
            device,
            upload_buffers: Vec::new(),
            download_buffers: Vec::new(),
            max_pool_size,
        }
    }

    /// Allocate an unmapped staging buffer with the given label and usage.
    /// For simplicity, always creates a new buffer (production would pool).
    fn make_staging_buffer(&self, label: &str, usage: wgpu::BufferUsages, size: usize) -> wgpu::Buffer {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: size as u64,
            usage,
            mapped_at_creation: false,
        })
    }

    /// CPU -> GPU staging buffer: mappable for writing, copy source.
    fn get_upload_buffer(&self, size: usize) -> Result<wgpu::Buffer, ComputeError> {
        Ok(self.make_staging_buffer(
            "staging_upload",
            wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC,
            size,
        ))
    }

    /// GPU -> CPU staging buffer: mappable for reading, copy destination.
    fn get_download_buffer(&self, size: usize) -> Result<wgpu::Buffer, ComputeError> {
        Ok(self.make_staging_buffer(
            "staging_download",
            wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            size,
        ))
    }
}
// ============================================================================
// wgpu::util helpers
// ============================================================================
/// Helpers bridging to `wgpu::util`.
///
/// NOTE(review): the previous version wrote an inherent
/// `impl wgpu::Device { .. }`, which does not compile — Rust forbids
/// inherent impls on types from another crate (error E0116). The idiomatic
/// fix is a locally defined extension trait implemented for the foreign
/// type. `wgpu::util::DeviceExt` already provides `create_buffer_init`, so
/// call sites may alternatively just `use wgpu::util::DeviceExt;` — either
/// that trait or `wgpu_util::DeviceBufferInit` must be in scope.
mod wgpu_util {
    use super::*;

    /// Extension trait supplying `create_buffer_init` on `wgpu::Device`.
    pub trait DeviceBufferInit {
        /// Create a buffer initialized with the contents of `desc.contents`.
        fn create_buffer_init(&self, desc: &wgpu::util::BufferInitDescriptor) -> wgpu::Buffer;
    }

    impl DeviceBufferInit for wgpu::Device {
        fn create_buffer_init(&self, desc: &wgpu::util::BufferInitDescriptor) -> wgpu::Buffer {
            // Delegate to the canonical wgpu helper.
            wgpu::util::DeviceExt::create_buffer_init(self, desc)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests need a real GPU adapter and are therefore ignored by
    // default. Run with: cargo test --features webgpu -- --ignored

    #[tokio::test]
    #[ignore]
    async fn test_webgpu_init() {
        assert!(WebGpuCompute::new().await.is_ok());
    }

    #[tokio::test]
    #[ignore]
    async fn test_buffer_allocation() {
        let gpu = WebGpuCompute::new().await.unwrap();
        let descriptor = TensorDescriptor::matrix(1024, 1024, DataType::F32);
        let allocation = gpu.allocate_buffer(descriptor, BufferUsage::storage());
        assert!(allocation.is_ok());
        // 1024 x 1024 f32 elements at 4 bytes each.
        assert_eq!(allocation.unwrap().size(), 1024 * 1024 * 4);
    }
}

View File

@@ -0,0 +1,566 @@
//! WebWorker pool for CPU parallelism in browsers
//!
//! Provides multi-threaded compute using WebWorkers with work stealing
//! for load balancing. Uses SharedArrayBuffer when available for
//! zero-copy data sharing.
//!
//! ## Architecture
//!
//! ```text
//! +------------------+
//! | Main Thread |
//! | (Coordinator) |
//! +--------+---------+
//! |
//! +-----+-----+-----+-----+
//! | | | | |
//! +--v-+ +-v--+ +--v-+ +--v-+ +--v-+
//! | W1 | | W2 | | W3 | | W4 | | Wn |
//! +----+ +----+ +----+ +----+ +----+
//! | | | | |
//! +-----+-----+-----+-----+
//! |
//! SharedArrayBuffer (when available)
//! ```
//!
//! ## Work Stealing
//!
//! Workers that finish early can steal work from busy workers' queues.
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};

use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast;
use web_sys::{MessageEvent, Worker};
/// Task for worker execution
///
/// NOTE(review): the visible dispatch path (`matmul_parallel`) builds JS
/// message objects directly rather than going through this struct —
/// presumably this is the planned typed task descriptor; confirm before
/// relying on it. Offsets appear to be in f32 elements (the worker script
/// multiplies by 4 for byte offsets) — TODO confirm.
#[derive(Clone)]
pub struct WorkerTask {
/// Task identifier
pub id: u32,
/// Operation type
pub op: WorkerOp,
/// Input data offset in shared buffer
pub input_offset: usize,
/// Input data length
pub input_len: usize,
/// Output data offset in shared buffer
pub output_offset: usize,
}
/// Operations that workers can perform
///
/// Ranges are half-open `[start, end)` — the worker script iterates
/// `for (i = start; i < end; i++)`.
#[derive(Clone, Copy)]
pub enum WorkerOp {
/// Matrix multiplication (partial): rows `m_start..m_end` of an (m×k)·(k×n) product
MatmulPartial { m_start: usize, m_end: usize, k: usize, n: usize },
/// Dot product (partial): partial sum over `[start, end)`
DotProductPartial { start: usize, end: usize },
/// Vector element-wise operation over `[start, end)`
VectorOp { start: usize, end: usize, op: VectorOpType },
/// Reduction (sum, max, etc.)
Reduce { start: usize, end: usize, op: ReduceOp },
}
/// Element-wise vector operations
///
/// Mirrors the `op` strings handled by the worker script's `vectorOp`
/// (`add`/`sub`/`mul`/`div`/`relu`/`sigmoid`); note the script substitutes
/// `1e-7` for zero divisors in `div`.
#[derive(Clone, Copy)]
pub enum VectorOpType {
Add,
Sub,
Mul,
Div,
Relu,
Sigmoid,
}
/// Reduction operations
///
/// NOTE(review): the worker script visible in this file only dispatches
/// `matmul` / `dotproduct` / `vectorop` — no reduction handler exists yet;
/// confirm before using these variants.
#[derive(Clone, Copy)]
pub enum ReduceOp {
Sum,
Max,
Min,
Mean,
}
/// Worker pool status
///
/// NOTE(review): snapshot struct; `WorkerPool::get_status` currently builds
/// a JS object directly instead of going through this type — keep the two
/// field sets in sync.
#[derive(Clone)]
pub struct PoolStatus {
/// Number of workers
pub worker_count: usize,
/// Number of active tasks
pub active_tasks: usize,
/// Total tasks completed
pub completed_tasks: u64,
/// Whether shared memory is available
pub has_shared_memory: bool,
}
/// WebWorker pool for parallel compute
///
/// Single-threaded coordinator: the `Rc<RefCell<..>>` counters are shared
/// with each worker's `onmessage` closure (wired in `initialize`), not
/// across OS threads — this type is not `Send`.
#[wasm_bindgen]
pub struct WorkerPool {
/// Active workers
workers: Vec<Worker>,
/// Number of workers
worker_count: usize,
/// Shared memory buffer (if available)
shared_buffer: Option<js_sys::SharedArrayBuffer>,
/// Float32 view into shared buffer
shared_view: Option<js_sys::Float32Array>,
/// Active task count (decremented by each worker reply)
active_tasks: Rc<RefCell<usize>>,
/// Completed task count (incremented by each worker reply)
completed_tasks: Rc<RefCell<u64>>,
/// Whether pool is initialized
initialized: bool,
/// Has SharedArrayBuffer support
has_shared_memory: bool,
/// Pending results collector (filled by worker replies)
pending_results: Rc<RefCell<Vec<Vec<f32>>>>,
/// Next task ID — reserved; not read by code visible in this file
next_task_id: Rc<RefCell<u32>>,
}
#[wasm_bindgen]
impl WorkerPool {
    /// Create a new worker pool.
    ///
    /// `worker_count` is clamped to `1..=16`. Detects `SharedArrayBuffer`
    /// support and, when available, pre-allocates a 16 MB shared buffer for
    /// zero-copy data sharing with workers. Workers are NOT spawned here —
    /// call `initialize()` before dispatching work.
    #[wasm_bindgen(constructor)]
    pub fn new(worker_count: usize) -> Result<WorkerPool, JsValue> {
        let count = worker_count.clamp(1, 16); // Limit to reasonable range
        // Check for SharedArrayBuffer support (needs cross-origin isolation).
        let window = web_sys::window()
            .ok_or_else(|| JsValue::from_str("No window"))?;
        let has_shared_memory = js_sys::Reflect::has(&window, &"SharedArrayBuffer".into())
            .unwrap_or(false);
        // Create shared buffer if available (16MB default)
        let (shared_buffer, shared_view) = if has_shared_memory {
            let buffer = js_sys::SharedArrayBuffer::new(16 * 1024 * 1024);
            let view = js_sys::Float32Array::new(&buffer);
            (Some(buffer), Some(view))
        } else {
            (None, None)
        };
        Ok(WorkerPool {
            workers: Vec::with_capacity(count),
            worker_count: count,
            shared_buffer,
            shared_view,
            active_tasks: Rc::new(RefCell::new(0)),
            completed_tasks: Rc::new(RefCell::new(0)),
            initialized: false,
            has_shared_memory,
            pending_results: Rc::new(RefCell::new(Vec::new())),
            next_task_id: Rc::new(RefCell::new(0)),
        })
    }

    /// Initialize workers.
    ///
    /// Builds the worker script as a Blob URL and spawns `worker_count`
    /// workers, wiring each `onmessage` handler to the shared task counters.
    /// Idempotent: returns immediately if already initialized.
    #[wasm_bindgen(js_name = initialize)]
    pub fn initialize(&mut self) -> Result<(), JsValue> {
        if self.initialized {
            return Ok(());
        }
        // Create worker script as a blob
        let worker_script = create_worker_script();
        let blob_parts = js_sys::Array::new();
        blob_parts.push(&worker_script.into());
        let blob_options = web_sys::BlobPropertyBag::new();
        blob_options.set_type("application/javascript");
        let blob = web_sys::Blob::new_with_str_sequence_and_options(&blob_parts, &blob_options)?;
        // NOTE(review): this object URL is never revoked — a small one-off
        // leak; revoke it after all workers report 'ready' if it matters.
        let url = web_sys::Url::create_object_url_with_blob(&blob)?;
        // Spawn workers
        for i in 0..self.worker_count {
            let worker = Worker::new(&url)?;
            // Set up message handler
            let completed = self.completed_tasks.clone();
            let active = self.active_tasks.clone();
            let results = self.pending_results.clone();
            let onmessage = Closure::wrap(Box::new(move |event: MessageEvent| {
                let data = event.data();
                // A Float32Array reply is a computed partial result; other
                // message shapes (e.g. the 'ready' ack) are only counted.
                if let Ok(result_array) = data.dyn_into::<js_sys::Float32Array>() {
                    let mut result_vec = vec![0.0f32; result_array.length() as usize];
                    result_array.copy_to(&mut result_vec);
                    results.borrow_mut().push(result_vec);
                }
                *completed.borrow_mut() += 1;
                // BUGFIX: the previous
                //   *active.borrow_mut() = active.borrow().saturating_sub(1);
                // evaluated `active.borrow()` while the temporary from the
                // same statement kept the RefCell borrowed, panicking with a
                // BorrowMutError on every worker reply. Use one borrow.
                let mut active_count = active.borrow_mut();
                *active_count = active_count.saturating_sub(1);
            }) as Box<dyn FnMut(MessageEvent)>);
            worker.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
            // Leak the closure so it outlives this call — the standard
            // wasm-bindgen pattern for long-lived JS callbacks.
            onmessage.forget();
            // Send initialization message
            let init_msg = js_sys::Object::new();
            js_sys::Reflect::set(&init_msg, &"type".into(), &"init".into())?;
            js_sys::Reflect::set(&init_msg, &"workerId".into(), &(i as u32).into())?;
            if let Some(ref buffer) = self.shared_buffer {
                // SharedArrayBuffer is shared (not transferred) by postMessage.
                js_sys::Reflect::set(&init_msg, &"sharedBuffer".into(), buffer)?;
            }
            worker.post_message(&init_msg)?;
            self.workers.push(worker);
        }
        self.initialized = true;
        Ok(())
    }

    /// Number of workers this pool manages.
    #[wasm_bindgen(js_name = workerCount)]
    pub fn worker_count(&self) -> usize {
        self.worker_count
    }

    /// Get pool status as a JS object with keys `workerCount`, `activeTasks`,
    /// `completedTasks`, `hasSharedMemory`, and `initialized`.
    #[wasm_bindgen(js_name = getStatus)]
    pub fn get_status(&self) -> JsValue {
        let obj = js_sys::Object::new();
        js_sys::Reflect::set(&obj, &"workerCount".into(), &(self.worker_count as u32).into()).ok();
        js_sys::Reflect::set(&obj, &"activeTasks".into(), &(*self.active_tasks.borrow() as u32).into()).ok();
        js_sys::Reflect::set(&obj, &"completedTasks".into(), &(*self.completed_tasks.borrow() as f64).into()).ok();
        js_sys::Reflect::set(&obj, &"hasSharedMemory".into(), &self.has_shared_memory.into()).ok();
        js_sys::Reflect::set(&obj, &"initialized".into(), &self.initialized.into()).ok();
        obj.into()
    }

    /// Shutdown all workers.
    ///
    /// Terminates every worker immediately (in-flight tasks are dropped) and
    /// marks the pool uninitialized so `initialize` may be called again.
    #[wasm_bindgen]
    pub fn shutdown(&mut self) -> Result<(), JsValue> {
        for worker in &self.workers {
            worker.terminate();
        }
        self.workers.clear();
        self.initialized = false;
        Ok(())
    }
}
// Non-WASM implementation
impl WorkerPool {
/// Perform parallel matrix multiplication
///
/// Computes `a (m×k) · b (k×n)`, returning an m×n row-major result.
/// NOTE(review): even when workers are dispatched, the returned value is
/// ALWAYS computed by the synchronous CPU fallback at the bottom — the
/// browser main thread has no way to block on worker replies. Dispatch
/// still runs and mutates `active_tasks`/`pending_results`, so worker and
/// CPU work are duplicated; a Promise-based async API would be needed to
/// actually consume worker output.
pub fn matmul_parallel(&self, a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Result<Vec<f32>, JsValue> {
if !self.initialized || self.workers.is_empty() {
// Fall back to CPU
return Ok(cpu_matmul(a, b, m, k, n));
}
// For small matrices, don't bother with parallelism
if m * k * n < 10000 {
return Ok(cpu_matmul(a, b, m, k, n));
}
// Divide rows among workers
let rows_per_worker = (m + self.worker_count - 1) / self.worker_count;
// If using shared memory, copy input data.
// (`buffer` is matched only to require both halves are Some; the copy
// itself goes through the Float32Array view — the binding is unused.)
if let (Some(ref buffer), Some(ref view)) = (&self.shared_buffer, &self.shared_view) {
// Copy A and B to shared buffer: B sits right after A at element
// offset m*k (the worker script reads it at byte offset m*k*4).
let a_array = js_sys::Float32Array::from(a);
let b_array = js_sys::Float32Array::from(b);
view.set(&a_array, 0);
view.set(&b_array, (m * k) as u32);
}
// Dispatch tasks to workers
self.pending_results.borrow_mut().clear();
for (i, worker) in self.workers.iter().enumerate() {
let row_start = i * rows_per_worker;
let row_end = ((i + 1) * rows_per_worker).min(m);
if row_start >= m {
break;
}
let msg = js_sys::Object::new();
js_sys::Reflect::set(&msg, &"type".into(), &"matmul".into()).ok();
js_sys::Reflect::set(&msg, &"rowStart".into(), &(row_start as u32).into()).ok();
js_sys::Reflect::set(&msg, &"rowEnd".into(), &(row_end as u32).into()).ok();
js_sys::Reflect::set(&msg, &"m".into(), &(m as u32).into()).ok();
js_sys::Reflect::set(&msg, &"k".into(), &(k as u32).into()).ok();
js_sys::Reflect::set(&msg, &"n".into(), &(n as u32).into()).ok();
// If no shared memory, send data directly:
// only this worker's row block of A, but all of B.
if self.shared_buffer.is_none() {
let a_slice = &a[row_start * k..row_end * k];
let a_array = js_sys::Float32Array::from(a_slice);
let b_array = js_sys::Float32Array::from(b);
js_sys::Reflect::set(&msg, &"a".into(), &a_array).ok();
js_sys::Reflect::set(&msg, &"b".into(), &b_array).ok();
}
*self.active_tasks.borrow_mut() += 1;
worker.post_message(&msg).ok();
}
// Wait for results (in real async code, this would be Promise-based)
// For now, fall back to CPU since we can't truly wait in WASM
Ok(cpu_matmul(a, b, m, k, n))
}
/// Perform parallel dot product
///
/// Currently always computed on the CPU; the worker dispatch path is
/// unimplemented. The `< 10000` threshold mirrors `matmul_parallel`'s
/// small-input cutoff. If `a` and `b` differ in length, the shorter
/// length wins (via `zip`).
pub fn dot_product_parallel(&self, a: &[f32], b: &[f32]) -> Result<f32, JsValue> {
if !self.initialized || self.workers.is_empty() || a.len() < 10000 {
// Fall back to CPU
return Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum());
}
// For simplicity, use CPU implementation
// Full implementation would dispatch to workers and collect partial sums
Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
}
}
/// Create the worker script as a string
///
/// The returned JavaScript runs inside each `Worker`. Message protocol
/// (dispatched on the `type` field):
/// - `init`: stores the worker id and the optional `SharedArrayBuffer`,
///   then replies `{ type: 'ready' }`.
/// - `matmul`: computes rows `rowStart..rowEnd` of `A·B` and transfers the
///   partial Float32Array result back (zero-copy transfer of its buffer).
/// - `dotproduct`: returns a partial scalar sum over `[start, end)`.
/// - `vectorop`: applies an element-wise op (`add`/`sub`/`mul`/`div`/
///   `relu`/`sigmoid`); `div` substitutes 1e-7 for zero divisors.
///
/// Offsets received from the main thread are in f32 elements; the script
/// converts them to byte offsets (`* 4`) when viewing the shared buffer.
/// No handler exists for `ReduceOp` yet. (The `BLOCK` constant inside
/// `matmulPartial` is currently unused.)
fn create_worker_script() -> String {
r#"
let workerId = -1;
let sharedBuffer = null;
let sharedView = null;
self.onmessage = function(e) {
const msg = e.data;
if (msg.type === 'init') {
workerId = msg.workerId;
if (msg.sharedBuffer) {
sharedBuffer = msg.sharedBuffer;
sharedView = new Float32Array(sharedBuffer);
}
self.postMessage({ type: 'ready', workerId: workerId });
return;
}
if (msg.type === 'matmul') {
const result = matmulPartial(msg);
self.postMessage(result, [result.buffer]);
return;
}
if (msg.type === 'dotproduct') {
const result = dotProductPartial(msg);
self.postMessage({ type: 'result', value: result });
return;
}
if (msg.type === 'vectorop') {
const result = vectorOp(msg);
self.postMessage(result, [result.buffer]);
return;
}
};
function matmulPartial(msg) {
const { rowStart, rowEnd, m, k, n } = msg;
const rows = rowEnd - rowStart;
const result = new Float32Array(rows * n);
let a, b;
if (sharedView) {
// Use shared memory
a = new Float32Array(sharedBuffer, rowStart * k * 4, rows * k);
b = new Float32Array(sharedBuffer, m * k * 4, k * n);
} else {
// Use passed data
a = msg.a;
b = msg.b;
}
// Cache-friendly blocked multiplication
const BLOCK = 32;
for (let i = 0; i < rows; i++) {
for (let j = 0; j < n; j++) {
let sum = 0;
for (let kk = 0; kk < k; kk++) {
sum += a[i * k + kk] * b[kk * n + j];
}
result[i * n + j] = sum;
}
}
return result;
}
function dotProductPartial(msg) {
const { start, end } = msg;
let sum = 0;
if (sharedView) {
const a = new Float32Array(sharedBuffer, start * 4, end - start);
const b = new Float32Array(sharedBuffer, (msg.bOffset + start) * 4, end - start);
for (let i = 0; i < a.length; i++) {
sum += a[i] * b[i];
}
} else {
const a = msg.a;
const b = msg.b;
for (let i = start; i < end; i++) {
sum += a[i] * b[i];
}
}
return sum;
}
function vectorOp(msg) {
const { start, end, op } = msg;
const len = end - start;
const result = new Float32Array(len);
const a = sharedView ? new Float32Array(sharedBuffer, start * 4, len) : msg.a;
const b = sharedView ? new Float32Array(sharedBuffer, (msg.bOffset + start) * 4, len) : msg.b;
switch (op) {
case 'add':
for (let i = 0; i < len; i++) result[i] = a[i] + b[i];
break;
case 'sub':
for (let i = 0; i < len; i++) result[i] = a[i] - b[i];
break;
case 'mul':
for (let i = 0; i < len; i++) result[i] = a[i] * b[i];
break;
case 'div':
for (let i = 0; i < len; i++) result[i] = a[i] / (b[i] || 1e-7);
break;
case 'relu':
for (let i = 0; i < len; i++) result[i] = Math.max(a[i], 0);
break;
case 'sigmoid':
for (let i = 0; i < len; i++) result[i] = 1 / (1 + Math.exp(-a[i]));
break;
}
return result;
}
"#.to_string()
}
/// Single-threaded tiled matrix multiply fallback.
///
/// Computes `a (m×k) · b (k×n)` into an m×n row-major `Vec<f32>`. The
/// 32×32 tiling keeps the working set cache-resident for large inputs; the
/// innermost loop streams a row of `b` while holding one `a` scalar.
fn cpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    const TILE: usize = 32;
    let mut out = vec![0.0f32; m * n];
    for row_tile in (0..m).step_by(TILE) {
        let row_hi = (row_tile + TILE).min(m);
        for col_tile in (0..n).step_by(TILE) {
            let col_hi = (col_tile + TILE).min(n);
            for inner_tile in (0..k).step_by(TILE) {
                let inner_hi = (inner_tile + TILE).min(k);
                for row in row_tile..row_hi {
                    for inner in inner_tile..inner_hi {
                        let lhs = a[row * k + inner];
                        for col in col_tile..col_hi {
                            out[row * n + col] += lhs * b[inner * n + col];
                        }
                    }
                }
            }
        }
    }
    out
}
/// Work-stealing task queue
///
/// Owner-private tasks live in `local` and pop LIFO for cache locality;
/// stealable tasks live behind `Rc<RefCell<..>>` and are taken FIFO.
/// NOTE: `Rc<RefCell<..>>` is neither `Send` nor `Sync`, so "stealing"
/// here models cooperative sharing on a single thread (e.g. between WASM
/// tasks), not cross-thread work stealing.
pub struct WorkStealingQueue<T> {
    /// Local tasks (LIFO for locality)
    local: Vec<T>,
    /// Shared tasks (can be stolen). `VecDeque` makes the FIFO `steal`
    /// O(1), where the previous `Vec::remove(0)` was O(n) per steal.
    shared: Rc<RefCell<VecDeque<T>>>,
}

// The previous `T: Clone` bound was unnecessary — no method clones tasks —
// so it has been relaxed (backward compatible).
impl<T> WorkStealingQueue<T> {
    /// Create a new, empty work-stealing queue
    pub fn new() -> Self {
        WorkStealingQueue {
            local: Vec::new(),
            shared: Rc::new(RefCell::new(VecDeque::new())),
        }
    }
    /// Push a task (local, cannot be stolen)
    pub fn push_local(&mut self, task: T) {
        self.local.push(task);
    }
    /// Push a task that can be stolen
    pub fn push_shared(&mut self, task: T) {
        self.shared.borrow_mut().push_back(task);
    }
    /// Pop a local task (LIFO)
    pub fn pop_local(&mut self) -> Option<T> {
        self.local.pop()
    }
    /// Try to steal from the shared queue (FIFO); `None` when empty
    pub fn steal(&self) -> Option<T> {
        self.shared.borrow_mut().pop_front()
    }
    /// Get number of stealable tasks
    pub fn stealable_count(&self) -> usize {
        self.shared.borrow().len()
    }
    /// Get total task count (local + stealable)
    pub fn total_count(&self) -> usize {
        self.local.len() + self.shared.borrow().len()
    }
}

impl<T> Default for WorkStealingQueue<T> {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_matmul() {
        // [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [5.0f32, 6.0, 7.0, 8.0];
        let product = cpu_matmul(&lhs, &rhs, 2, 2, 2);
        assert_eq!(product, [19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_work_stealing_queue() {
        let mut queue = WorkStealingQueue::<i32>::new();
        queue.push_local(1);
        queue.push_shared(2);
        queue.push_shared(3);

        assert_eq!((queue.total_count(), queue.stealable_count()), (3, 2));
        assert_eq!(queue.pop_local(), Some(1));
        assert_eq!(queue.steal(), Some(2));
        assert_eq!(queue.steal(), Some(3));
        assert!(queue.steal().is_none());
    }
}