Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
283
vendor/ruvector/examples/edge-net/src/compute/backend.rs
vendored
Normal file
283
vendor/ruvector/examples/edge-net/src/compute/backend.rs
vendored
Normal file
@@ -0,0 +1,283 @@
|
||||
//! Compute backend detection and abstraction
|
||||
//!
|
||||
//! Detects available compute capabilities (WebGPU, WebGL2, WebWorkers)
|
||||
//! and provides a unified interface for selecting the best backend.
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Compute capabilities detected on the current device.
///
/// Produced by [`detect_capabilities`] and consumed via
/// [`ComputeCapability::recommend_backend`] to pick an execution backend.
#[derive(Clone, Debug)]
pub struct ComputeCapability {
    /// WebGPU is available (best performance)
    pub has_webgpu: bool,
    /// WebGL2 is available (fallback for GPU compute)
    pub has_webgl2: bool,
    /// WebGL2 supports floating point textures (EXT_color_buffer_float)
    pub has_float_textures: bool,
    /// Transform feedback is available (for GPU readback)
    pub has_transform_feedback: bool,
    /// WebWorkers are available
    pub has_workers: bool,
    /// SharedArrayBuffer is available (for shared memory)
    pub has_shared_memory: bool,
    /// Number of logical CPU cores (clamped to at least 1 by `detect_capabilities`)
    pub worker_count: usize,
    /// Maximum texture size (for WebGL2)
    pub max_texture_size: u32,
    /// Estimated GPU memory (MB); a default estimate when not reported
    pub gpu_memory_mb: u32,
    /// Device description (unmasked renderer string when exposed)
    pub device_info: String,
}
|
||||
|
||||
impl ComputeCapability {
|
||||
/// Convert to JavaScript object
|
||||
pub fn to_js(&self) -> JsValue {
|
||||
let obj = js_sys::Object::new();
|
||||
|
||||
js_sys::Reflect::set(&obj, &"hasWebGPU".into(), &self.has_webgpu.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"hasWebGL2".into(), &self.has_webgl2.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"hasFloatTextures".into(), &self.has_float_textures.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"hasTransformFeedback".into(), &self.has_transform_feedback.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"hasWorkers".into(), &self.has_workers.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"hasSharedMemory".into(), &self.has_shared_memory.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"workerCount".into(), &(self.worker_count as u32).into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"maxTextureSize".into(), &self.max_texture_size.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"gpuMemoryMB".into(), &self.gpu_memory_mb.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"deviceInfo".into(), &self.device_info.clone().into()).ok();
|
||||
|
||||
obj.into()
|
||||
}
|
||||
|
||||
/// Get recommended backend for a given operation size
|
||||
pub fn recommend_backend(&self, operation_size: usize) -> ComputeBackend {
|
||||
// WebGPU is always preferred if available
|
||||
if self.has_webgpu {
|
||||
return ComputeBackend::WebGPU;
|
||||
}
|
||||
|
||||
// For large operations, prefer GPU
|
||||
if operation_size > 4096 && self.has_webgl2 && self.has_float_textures {
|
||||
return ComputeBackend::WebGL2;
|
||||
}
|
||||
|
||||
// For medium operations with multiple cores, use workers
|
||||
if operation_size > 1024 && self.has_workers && self.worker_count > 1 {
|
||||
return ComputeBackend::WebWorkers;
|
||||
}
|
||||
|
||||
// Fall back to single-threaded CPU
|
||||
ComputeBackend::CPU
|
||||
}
|
||||
}
|
||||
|
||||
/// Available compute backends, listed from fastest to slowest.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComputeBackend {
    /// WebGPU compute shaders (best performance)
    WebGPU,
    /// WebGL2 texture-based compute (fallback GPU)
    WebGL2,
    /// WebWorker pool (CPU parallelism)
    WebWorkers,
    /// Single-threaded CPU (last resort)
    CPU,
}

impl ComputeBackend {
    /// Human-readable backend name for logging and JS interop.
    pub fn name(&self) -> &'static str {
        match self {
            Self::WebGPU => "WebGPU",
            Self::WebGL2 => "WebGL2",
            Self::WebWorkers => "WebWorkers",
            Self::CPU => "CPU",
        }
    }

    /// Relative performance score (higher is better), normalized so that
    /// single-threaded CPU is 1.0. These are coarse heuristics, not
    /// measurements.
    pub fn relative_performance(&self) -> f32 {
        match self {
            Self::WebGPU => 10.0,
            Self::WebGL2 => 5.0,
            Self::WebWorkers => 2.0,
            Self::CPU => 1.0,
        }
    }
}
|
||||
|
||||
/// Detect compute capabilities on the current device.
///
/// Probes the browser environment for WebGPU, WebWorkers, SharedArrayBuffer
/// and WebGL2 support, plus CPU core count and GPU metadata.
///
/// # Errors
/// Returns a `JsValue` error when no `window` or `document` is available
/// (e.g. when called outside a browser main thread), or when WebGL2 probing
/// itself fails.
pub fn detect_capabilities() -> Result<ComputeCapability, JsValue> {
    let window = web_sys::window()
        .ok_or_else(|| JsValue::from_str("No window object"))?;

    let navigator = window.navigator();

    // Detect WebGPU: checks only that `navigator.gpu` exists; it does not
    // request an adapter, so a present-but-unusable GPU still reports true.
    let has_webgpu = js_sys::Reflect::has(&navigator, &"gpu".into())
        .unwrap_or(false);

    // Detect WebWorkers via the global `Worker` constructor.
    let has_workers = js_sys::Reflect::has(&window, &"Worker".into())
        .unwrap_or(false);

    // Detect SharedArrayBuffer (browsers expose it only under
    // cross-origin isolation).
    let has_shared_memory = js_sys::Reflect::has(&window, &"SharedArrayBuffer".into())
        .unwrap_or(false);

    // Get hardware concurrency (logical CPU cores); may report 0 on some
    // platforms — clamped below.
    let worker_count = navigator.hardware_concurrency() as usize;

    // Detect WebGL2 capabilities; needs a document to create a probe canvas.
    let document = window.document()
        .ok_or_else(|| JsValue::from_str("No document"))?;

    let (has_webgl2, has_float_textures, has_transform_feedback, max_texture_size, gpu_memory_mb, device_info) =
        detect_webgl2_capabilities(&document)?;

    Ok(ComputeCapability {
        has_webgpu,
        has_webgl2,
        has_float_textures,
        has_transform_feedback,
        has_workers,
        has_shared_memory,
        // Guarantee at least one worker even when concurrency is unreported.
        worker_count: worker_count.max(1),
        max_texture_size,
        gpu_memory_mb,
        device_info,
    })
}
|
||||
|
||||
/// Detect WebGL2-specific capabilities.
///
/// Returns `(has_webgl2, has_float_textures, has_transform_feedback,
/// max_texture_size, gpu_memory_mb, device_info)`. When WebGL2 is not
/// supported, all flags are false and sizes are zero.
fn detect_webgl2_capabilities(document: &web_sys::Document) -> Result<(bool, bool, bool, u32, u32, String), JsValue> {
    // Create a temporary canvas to probe WebGL2 (never attached to the DOM).
    let canvas = document.create_element("canvas")?;
    let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into()?;

    // Try to get a WebGL2 context; `None` means WebGL2 is unsupported.
    let context = match canvas.get_context("webgl2")? {
        Some(ctx) => ctx,
        None => return Ok((false, false, false, 0, 0, "No WebGL2".to_string())),
    };

    let gl: web_sys::WebGl2RenderingContext = context.dyn_into()?;

    // Check for float texture support (required for GPU compute via textures).
    let ext_color_buffer_float = gl.get_extension("EXT_color_buffer_float")?;
    let has_float_textures = ext_color_buffer_float.is_some();

    // Transform feedback is part of the core WebGL2 specification,
    // so presence of a context implies support.
    let has_transform_feedback = true;

    // Get max texture size; falls back to 4096 if the parameter is unreadable.
    let max_texture_size = gl.get_parameter(web_sys::WebGl2RenderingContext::MAX_TEXTURE_SIZE)?
        .as_f64()
        .unwrap_or(4096.0) as u32;

    // Try to get GPU memory info (vendor-specific; estimated otherwise).
    let gpu_memory_mb = get_gpu_memory_mb(&gl);

    // Get renderer info if the debug extension is exposed. Many browsers
    // mask or omit this for fingerprinting reasons.
    let renderer_info = gl.get_extension("WEBGL_debug_renderer_info")?;
    let device_info = if renderer_info.is_some() {
        // UNMASKED_RENDERER_WEBGL = 0x9246
        let renderer = gl.get_parameter(0x9246)?;
        renderer.as_string().unwrap_or_else(|| "Unknown GPU".to_string())
    } else {
        "Unknown GPU".to_string()
    };

    Ok((true, has_float_textures, has_transform_feedback, max_texture_size, gpu_memory_mb, device_info))
}
|
||||
|
||||
/// Try to get GPU memory size (vendor-specific extension)
|
||||
fn get_gpu_memory_mb(gl: &web_sys::WebGl2RenderingContext) -> u32 {
|
||||
// Try WEBGL_memory_info extension (available on some browsers)
|
||||
if let Ok(Some(_ext)) = gl.get_extension("WEBGL_memory_info") {
|
||||
// GPU_MEMORY_INFO_TOTAL_AVAILABLE_MEMORY_NVX = 0x9048
|
||||
if let Ok(mem) = gl.get_parameter(0x9048) {
|
||||
if let Some(kb) = mem.as_f64() {
|
||||
return (kb / 1024.0) as u32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default estimate based on typical mobile/desktop GPUs
|
||||
// Most modern GPUs have at least 2GB
|
||||
2048
|
||||
}
|
||||
|
||||
/// Configuration for compute operations.
///
/// See [`Default`] for the stock settings (auto backend, 256 MB, 30 s).
#[derive(Clone, Debug)]
pub struct ComputeConfig {
    /// Preferred backend (`None` = auto-select)
    pub preferred_backend: Option<ComputeBackend>,
    /// Maximum memory to use (bytes)
    pub max_memory: usize,
    /// Timeout for operations (ms)
    pub timeout_ms: u32,
    /// Enable profiling
    pub profiling: bool,
}
|
||||
|
||||
impl Default for ComputeConfig {
|
||||
fn default() -> Self {
|
||||
ComputeConfig {
|
||||
preferred_backend: None,
|
||||
max_memory: 256 * 1024 * 1024, // 256MB
|
||||
timeout_ms: 30_000, // 30 seconds
|
||||
profiling: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Capability fixture shared by the recommendation tests; only the
    /// WebGPU flag varies between cases.
    fn caps_with_webgpu(has_webgpu: bool) -> ComputeCapability {
        ComputeCapability {
            has_webgpu,
            has_webgl2: true,
            has_float_textures: true,
            has_transform_feedback: true,
            has_workers: true,
            has_shared_memory: true,
            worker_count: 4,
            max_texture_size: 4096,
            gpu_memory_mb: 2048,
            device_info: "Test GPU".to_string(),
        }
    }

    #[test]
    fn test_backend_recommendation() {
        let caps = caps_with_webgpu(false);

        // Large operations should use WebGL2.
        assert_eq!(caps.recommend_backend(10000), ComputeBackend::WebGL2);

        // Medium operations with multiple cores should use workers.
        assert_eq!(caps.recommend_backend(2000), ComputeBackend::WebWorkers);

        // Small operations should stay on the single-threaded CPU path.
        assert_eq!(caps.recommend_backend(100), ComputeBackend::CPU);
    }

    #[test]
    fn test_backend_with_webgpu() {
        let caps = caps_with_webgpu(true);

        // WebGPU should always be preferred, regardless of operation size.
        assert_eq!(caps.recommend_backend(100), ComputeBackend::WebGPU);
        assert_eq!(caps.recommend_backend(10000), ComputeBackend::WebGPU);
    }
}
|
||||
1076
vendor/ruvector/examples/edge-net/src/compute/backends.rs
vendored
Normal file
1076
vendor/ruvector/examples/edge-net/src/compute/backends.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
15
vendor/ruvector/examples/edge-net/src/compute/mod.rs
vendored
Normal file
15
vendor/ruvector/examples/edge-net/src/compute/mod.rs
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
//! SIMD Compute Backend for edge-net P2P AI Network
//!
//! Provides portable CPU acceleration with support for:
//! - WASM simd128 intrinsics (browser/WASM targets)
//! - x86_64 AVX2 intrinsics (native x86 targets)
//! - Scalar fallback for unsupported platforms
//!
//! Performance targets:
//! - 2,236+ ops/sec for MicroLoRA (rank-2)
//! - 150x faster HNSW search
//! - Q4 quantized inference

pub mod simd;

// Re-export the SIMD API at this module's level so callers can use
// `compute::*` without naming the submodule.
pub use simd::*;
|
||||
233
vendor/ruvector/examples/edge-net/src/compute/shaders/attention.wgsl
vendored
Normal file
233
vendor/ruvector/examples/edge-net/src/compute/shaders/attention.wgsl
vendored
Normal file
@@ -0,0 +1,233 @@
|
||||
// Flash Attention Shader
//
// Implements memory-efficient attention using the Flash Attention algorithm.
// Target: 2ms for 4K context length.
//
// Algorithm (Flash Attention v2):
// 1. Process Q in blocks, streaming K and V
// 2. Maintain running max and sum for numerical stability
// 3. Rescale outputs on-the-fly
// 4. Avoid materializing full attention matrix (O(n) memory vs O(n^2))
//
// Memory Layout:
// - Q: (seq_len, num_heads * head_dim) - queries
// - K: (seq_len, num_heads * head_dim) - keys
// - V: (seq_len, num_heads * head_dim) - values
// - Output: (seq_len, num_heads * head_dim)

// Block size for flash attention (balance between parallelism and memory).
// Must match the @workgroup_size of the entry points below.
const BLOCK_SIZE: u32 = 64u;
// NOTE(review): WARP_SIZE is declared but not referenced anywhere in this
// file — confirm whether it is still needed before removing.
const WARP_SIZE: u32 = 32u;

// Dimensions are carried as f32 and converted with u32() inside the kernels
// — presumably to share one uniform layout across shaders; TODO confirm.
struct Uniforms {
    seq_len: f32,
    head_dim: f32,
    num_heads: f32,
    scale: f32, // 1/sqrt(head_dim)
    causal_mask: f32, // 1.0 for causal, 0.0 for full
    _pad0: f32, // padding fields (names indicate layout padding)
    _pad1: f32,
    _pad2: f32,
}

@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<f32>;
@group(0) @binding(2) var<storage, read> V: array<f32>;
@group(0) @binding(3) var<storage, read_write> Output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

// Shared (workgroup) memory for the current Q, K, V blocks.
var<workgroup> Q_block: array<f32, 4096>; // BLOCK_SIZE * 64 (max head_dim)
var<workgroup> K_block: array<f32, 4096>;
var<workgroup> V_block: array<f32, 4096>;
// NOTE(review): `scores` is declared but `main` keeps scores in a private
// (per-invocation) array instead — confirm whether this buffer is needed.
var<workgroup> scores: array<f32, 4096>; // BLOCK_SIZE * BLOCK_SIZE

// Thread-local accumulators. Module-scope var<private> means one independent
// copy per invocation.
var<private> m_prev: f32; // Previous max score
var<private> l_prev: f32; // Previous sum of exp(scores - max)
var<private> acc: array<f32, 64>; // Output accumulator (head_dim)
|
||||
|
||||
// Compute the softmax denominator using the online (streaming) algorithm:
// rescale the previous running sum into the new max's frame of reference,
// then add the exp-shifted scores of the new block.
//
// NOTE(review): this helper is not referenced by `main` in this file (main
// inlines the same update) — confirm whether it is still needed.
fn online_softmax_update(
    new_max: f32,
    old_max: f32,
    old_sum: f32,
    new_scores: ptr<function, array<f32, 64>>,
    block_len: u32,
) -> f32 {
    // Rescale old sum: exp(old_max - new_max) <= 1, keeping values stable.
    var new_sum = old_sum * exp(old_max - new_max);

    // Add new contributions, shifted by the new max for stability.
    for (var i = 0u; i < block_len; i++) {
        new_sum += exp((*new_scores)[i] - new_max);
    }

    return new_sum;
}
|
||||
|
||||
// Flash-attention entry point: one workgroup computes one BLOCK_SIZE-row
// block of Q for one head (group_id.x = Q block, group_id.y = head), and
// each invocation owns one Q row within the block.
@compute @workgroup_size(64, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let seq_len = u32(uniforms.seq_len);
    let head_dim = u32(uniforms.head_dim);
    let num_heads = u32(uniforms.num_heads);
    let scale = uniforms.scale;
    let is_causal = uniforms.causal_mask > 0.5;

    // This workgroup processes one block of Q for one head.
    let head_idx = group_id.y;
    let q_block_idx = group_id.x;
    let q_start = q_block_idx * BLOCK_SIZE;

    let thread_id = local_id.x;
    let hidden_dim = num_heads * head_dim;

    // Initialize this invocation's running max/sum and output accumulator.
    m_prev = -1e10; // Very negative (acts as -inf; updated on first block)
    l_prev = 0.0;
    for (var i = 0u; i < 64u; i++) {
        acc[i] = 0.0;
    }

    // Load Q block into shared memory.
    // Each thread loads one position's head_dim values.
    let q_pos = q_start + thread_id;
    if (q_pos < seq_len && thread_id < BLOCK_SIZE) {
        for (var d = 0u; d < head_dim; d++) {
            let q_idx = q_pos * hidden_dim + head_idx * head_dim + d;
            Q_block[thread_id * head_dim + d] = Q[q_idx];
        }
    }
    workgroupBarrier();

    // Iterate over K/V blocks. Under a causal mask, blocks strictly after
    // this Q block are skipped entirely.
    let num_kv_blocks = (seq_len + BLOCK_SIZE - 1u) / BLOCK_SIZE;
    let max_kv_block = select(num_kv_blocks, q_block_idx + 1u, is_causal);

    for (var kv_block_idx = 0u; kv_block_idx < max_kv_block; kv_block_idx++) {
        let kv_start = kv_block_idx * BLOCK_SIZE;

        // Load K block into shared memory (one row per thread).
        let k_pos = kv_start + thread_id;
        if (k_pos < seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let k_idx = k_pos * hidden_dim + head_idx * head_dim + d;
                K_block[thread_id * head_dim + d] = K[k_idx];
            }
        }

        // Load V block into shared memory (one row per thread).
        let v_pos = kv_start + thread_id;
        if (v_pos < seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let v_idx = v_pos * hidden_dim + head_idx * head_dim + d;
                V_block[thread_id * head_dim + d] = V[v_idx];
            }
        }
        workgroupBarrier();

        // Compute attention scores for this thread's Q row against all K
        // rows in the current block.
        if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
            let kv_block_len = min(BLOCK_SIZE, seq_len - kv_start);

            // Q @ K^T for this thread's Q position, kept in registers.
            var local_scores: array<f32, 64>;
            var block_max = -1e10f;

            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;

                // Causal mask: future positions get an effectively -inf score.
                if (is_causal && k_global > q_pos) {
                    local_scores[k] = -1e10;
                    continue;
                }

                // Dot product Q[thread] @ K[k].
                var score = 0.0f;
                for (var d = 0u; d < head_dim; d++) {
                    score += Q_block[thread_id * head_dim + d] * K_block[k * head_dim + d];
                }
                score *= scale;

                local_scores[k] = score;
                block_max = max(block_max, score);
            }

            // Online softmax: merge this block into the running (max, sum).
            let new_max = max(m_prev, block_max);

            // Rescaling factors relative to the new max.
            let scale_old = exp(m_prev - new_max);
            // NOTE(review): scale_new is computed but never used below —
            // confirm whether it was meant to pre-scale the new block.
            let scale_new = exp(block_max - new_max);

            // Rescale the previous accumulator and running sum.
            for (var d = 0u; d < head_dim; d++) {
                acc[d] *= scale_old;
            }
            l_prev *= scale_old;

            // Accumulate exp(score - new_max) and the weighted V rows.
            var block_sum = 0.0f;
            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;
                if (is_causal && k_global > q_pos) {
                    continue;
                }

                let p = exp(local_scores[k] - new_max);
                block_sum += p;

                // Accumulate weighted V.
                for (var d = 0u; d < head_dim; d++) {
                    acc[d] += p * V_block[k * head_dim + d];
                }
            }

            // Commit the merged running statistics.
            l_prev += block_sum;
            m_prev = new_max;
        }

        workgroupBarrier();
    }

    // Normalize by the softmax denominator and write the output row.
    if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
        // Guard against an all-masked row (l_prev == 0 would divide by zero).
        let inv_sum = select(1.0 / l_prev, 0.0, l_prev == 0.0);

        for (var d = 0u; d < head_dim; d++) {
            let out_idx = q_pos * hidden_dim + head_idx * head_dim + d;
            Output[out_idx] = acc[d] * inv_sum;
        }
    }
}
|
||||
|
||||
// Multi-head attention with grouped-query attention (GQA) support.
// Placeholder entry point — the body is intentionally empty.
@compute @workgroup_size(64, 1, 1)
fn main_gqa(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // GQA: Multiple Q heads share the same K/V heads
    // (kv_head = q_head / num_q_per_kv).
    // Left as placeholder for models like Llama 2/3.
}
|
||||
|
||||
// Sliding window attention variant.
// Placeholder entry point — the body is intentionally empty.
@compute @workgroup_size(64, 1, 1)
fn main_sliding_window(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Only attend to positions within window_size.
    // Useful for very long sequences (Mistral-style).
    // Left as placeholder.
}
|
||||
159
vendor/ruvector/examples/edge-net/src/compute/shaders/lora.wgsl
vendored
Normal file
159
vendor/ruvector/examples/edge-net/src/compute/shaders/lora.wgsl
vendored
Normal file
@@ -0,0 +1,159 @@
|
||||
// LoRA (Low-Rank Adaptation) Forward Pass Shader
//
// Computes: output = input + scaling * (input @ A @ B)
//
// Where:
// - input: (batch_size, in_dim)
// - A: (in_dim, rank) - down projection
// - B: (rank, out_dim) - up projection
// - output: (batch_size, out_dim)
//
// Performance target: <1ms for typical LoRA ranks (2-64)
//
// Optimization strategy:
// 1. Fuse both matmuls into a single kernel
// 2. Use shared memory for the intermediate (rank is small)
// 3. Each thread computes one output element

// NOTE(review): WARP_SIZE is declared but not referenced in this file —
// confirm before removing.
const WARP_SIZE: u32 = 32u;
const MAX_RANK: u32 = 64u; // Maximum supported LoRA rank

// Dimensions are carried as f32 and converted with u32() in the kernel.
struct Uniforms {
    batch_size: f32,
    in_dim: f32,
    rank: f32,
    out_dim: f32,
    scaling: f32, // alpha / rank
    _pad0: f32, // padding fields (names indicate layout padding)
    _pad1: f32,
    _pad2: f32,
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> lora_A: array<f32>; // (in_dim, rank)
@group(0) @binding(2) var<storage, read> lora_B: array<f32>; // (rank, out_dim)
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

// Shared memory for the intermediate result (input @ A); rows are assigned
// modulo 32 in the kernel, so at most 32 batch rows of `rank` values fit.
var<workgroup> intermediate: array<f32, 2048>; // batch * rank (fits typical cases)

// NOTE(review): the two private caches below are declared but not used by
// `main` in this file — confirm whether they serve the other entry points.
var<private> input_cache: array<f32, 32>; // Cache input values
var<private> a_cache: array<f32, 64>; // Cache A column
|
||||
|
||||
// Fused LoRA forward: each invocation produces one element of the
// (batch_size, out_dim) output. Phase 1 computes input @ A into workgroup
// memory; phase 2 multiplies by B and scales.
@compute @workgroup_size(256, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let batch_size = u32(uniforms.batch_size);
    let in_dim = u32(uniforms.in_dim);
    let rank = u32(uniforms.rank);
    let out_dim = u32(uniforms.out_dim);
    let scaling = uniforms.scaling;

    let thread_id = local_id.x;
    let global_thread = global_id.x;

    // Compute which output element this thread handles.
    let batch_idx = global_thread / out_dim;
    let out_idx = global_thread % out_dim;

    // NOTE(review): this early return happens before the workgroupBarrier()
    // below; WGSL requires uniform control flow at barriers, so trailing
    // out-of-range invocations may make this non-portable — confirm.
    if (batch_idx >= batch_size) {
        return;
    }

    // Phase 1: Compute input @ A for this batch element and stash it in
    // shared memory for reuse by phase 2.
    // For small rank, the first `rank` threads each compute one column.
    // NOTE(review): each qualifying thread writes the row for ITS OWN
    // batch_idx, and rows are indexed modulo 32 — whether every needed
    // (batch, rank-column) cell is written by some thread in the same
    // workgroup depends on out_dim and dispatch shape; verify against the
    // host-side dispatch code.
    if (rank <= MAX_RANK && thread_id < rank) {
        var sum = 0.0f;

        // Dot product: input[batch_idx, :] @ A[:, thread_id]
        for (var i = 0u; i < in_dim; i++) {
            let input_val = input[batch_idx * in_dim + i];
            let a_val = lora_A[i * rank + thread_id];
            sum += input_val * a_val;
        }

        // Store in shared memory; batch rows wrap modulo 32 to fit the
        // 2048-element buffer.
        let shared_idx = (batch_idx % 32u) * rank + thread_id; // Wrap for shared memory size
        if (shared_idx < 2048u) {
            intermediate[shared_idx] = sum;
        }
    }

    workgroupBarrier();

    // Phase 2: Compute intermediate @ B for this output position.
    var lora_output = 0.0f;

    // Dot product: intermediate[batch_idx, :] @ B[:, out_idx]
    for (var r = 0u; r < rank; r++) {
        let shared_idx = (batch_idx % 32u) * rank + r;
        // Out-of-range shared indices contribute 0 (select keeps the read
        // in-bounds guard branchless).
        let inter_val = select(0.0, intermediate[shared_idx], shared_idx < 2048u);
        let b_val = lora_B[r * out_dim + out_idx];
        lora_output += inter_val * b_val;
    }

    // Apply scaling and write the result.
    // Note: for a true residual connection we'd add to the existing output;
    // here the output buffer is assumed pre-filled with the base model
    // output, or we're computing the delta only.
    output[batch_idx * out_dim + out_idx] = lora_output * scaling;
}
|
||||
|
||||
// Fused LoRA with base weight: output = (input @ W) + scaling * (input @ A @ B)
// More efficient when we have access to the base weights.
// Placeholder entry point — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_fused(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Would include base weight computation.
    // Placeholder for full integration.
}
|
||||
|
||||
// Batched LoRA for multiple adapters (multi-task serving).
// Each batch element can use different LoRA weights.
// Placeholder entry point — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_batched_lora(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Supports a different LoRA for different requests in the same batch.
    // Useful for serving multiple fine-tuned models.
    // Placeholder for multi-tenant serving.
}
|
||||
|
||||
// Quantized LoRA (int4 weights).
// Significant memory savings for large rank or many adapters.
// Placeholder entry point — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_quantized(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // A and B stored as int4 with scale factors.
    // Dequantize on-the-fly during computation.
    // Placeholder for memory-constrained deployment.
}
|
||||
|
||||
// DoRA (Weight-Decomposed Low-Rank Adaptation).
// Decomposes the weight update into magnitude and direction.
// Placeholder entry point — the body is intentionally empty.
@compute @workgroup_size(256, 1, 1)
fn main_dora(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // DoRA: output = m * (W + scaling * A @ B) / ||W + scaling * A @ B||
    // where m is a learned magnitude.
    // Placeholder for DoRA support.
}
|
||||
102
vendor/ruvector/examples/edge-net/src/compute/shaders/matmul.frag
vendored
Normal file
102
vendor/ruvector/examples/edge-net/src/compute/shaders/matmul.frag
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
#version 300 es
// Matrix Multiplication Fragment Shader
//
// Computes C = A * B using texture-based GPU compute: the output texture is
// rendered with one fragment per element of C.
//
// Usage
// - A and B are R32F textures (single-channel float)
// - Output is rendered to a framebuffer-attached texture
// - Each fragment computes one element of C
//
// Texture Layout
// - A: rows = M, cols = K (stored row-major)
// - B: rows = K, cols = N (stored row-major)
// - C: rows = M, cols = N (output)
//
// Performance Notes
// - Use texture sizes that are powers of 2 for best performance
// - NEAREST filtering is required for exact texel fetch
// - Loop unrolling may help on some GPUs (see UNROLL_4 below)

precision highp float;

// Input matrices as textures.
uniform sampler2D u_A;
uniform sampler2D u_B;

// Matrix dimensions packed as (M, K, N): A is MxK, B is KxN, C is MxN.
uniform vec3 u_dims;

// Texture coordinates from the vertex shader, normalized to [0,1].
in vec2 v_texcoord;

// Output value (single float stored in the R channel).
out float fragColor;

void main() {
    float M = u_dims.x;
    float K = u_dims.y;
    float N = u_dims.z;

    // Calculate output position (row i, column j).
    // v_texcoord is normalized [0,1], so scale to element coordinates.
    float i = floor(v_texcoord.y * M);
    float j = floor(v_texcoord.x * N);

    // Bounds check (fragments outside the valid range output 0).
    if (i >= M || j >= N) {
        fragColor = 0.0;
        return;
    }

    // Compute the dot product of row i of A with column j of B.
    float sum = 0.0;

    // Manual loop unrolling for the common case (K <= 4), compiled in only
    // when the host defines UNROLL_4. Helps on mobile GPUs with limited
    // dynamic-loop support.
#if defined(UNROLL_4)
    if (K <= 4.0) {
        if (K >= 1.0) {
            float a0 = texture(u_A, vec2(0.5 / K, (i + 0.5) / M)).r;
            float b0 = texture(u_B, vec2((j + 0.5) / N, 0.5 / K)).r;
            sum += a0 * b0;
        }
        if (K >= 2.0) {
            float a1 = texture(u_A, vec2(1.5 / K, (i + 0.5) / M)).r;
            float b1 = texture(u_B, vec2((j + 0.5) / N, 1.5 / K)).r;
            sum += a1 * b1;
        }
        if (K >= 3.0) {
            float a2 = texture(u_A, vec2(2.5 / K, (i + 0.5) / M)).r;
            float b2 = texture(u_B, vec2((j + 0.5) / N, 2.5 / K)).r;
            sum += a2 * b2;
        }
        if (K >= 4.0) {
            float a3 = texture(u_A, vec2(3.5 / K, (i + 0.5) / M)).r;
            float b3 = texture(u_B, vec2((j + 0.5) / N, 3.5 / K)).r;
            sum += a3 * b3;
        }
    } else
#endif
    {
        // General loop for arbitrary K.
        // 0.5 is added to center each sample within its texel.
        for (float k = 0.0; k < K; k += 1.0) {
            // Sample A[i, k] - row i, column k.
            // Texture coordinate: x = (k + 0.5) / K, y = (i + 0.5) / M
            float a_val = texture(u_A, vec2((k + 0.5) / K, (i + 0.5) / M)).r;

            // Sample B[k, j] - row k, column j.
            // Texture coordinate: x = (j + 0.5) / N, y = (k + 0.5) / K
            float b_val = texture(u_B, vec2((j + 0.5) / N, (k + 0.5) / K)).r;

            sum += a_val * b_val;
        }
    }

    fragColor = sum;
}
|
||||
171
vendor/ruvector/examples/edge-net/src/compute/shaders/matmul.wgsl
vendored
Normal file
171
vendor/ruvector/examples/edge-net/src/compute/shaders/matmul.wgsl
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
// Tiled Matrix Multiplication Shader
//
// Computes C = A * B using 128x128 tiles for cache efficiency.
// Targets 10+ TFLOPS on discrete GPUs.
//
// Algorithm:
// 1. Each workgroup computes a TILE_SIZE x TILE_SIZE block of C
// 2. A and B are loaded into shared memory in tiles
// 3. Each thread computes a subblock for register tiling
// 4. Accumulation happens in registers, then is written to C
//
// Memory Layout:
// - A: M x K matrix (row-major)
// - B: K x N matrix (row-major)
// - C: M x N matrix (row-major, output)

// Tile dimensions (must match host code).
const TILE_SIZE: u32 = 128u;
const BLOCK_SIZE: u32 = 16u; // Threads per dimension in workgroup (16x16)
// NOTE(review): comments in this file disagree on the per-thread subblock
// size (header says 4x4, acc comment says 8x8 = THREAD_TILE^2) — the code
// uses THREAD_TILE = 8; confirm the intended tiling.
const THREAD_TILE: u32 = 8u; // Each thread computes 8x8 elements

// Uniform dimensions, passed as u32 (unlike the f32-packed uniforms in the
// sibling shaders).
struct Uniforms {
    M: u32, // Rows of A, rows of C
    N: u32, // Cols of B, cols of C
    K: u32, // Cols of A, rows of B
    tile_size: u32,
}

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: Uniforms;

// Shared memory for tile caching.
// NOTE(review): 2048 = TILE_SIZE * BLOCK_SIZE (128 * 16), which is smaller
// than a full 128x128 tile — the loaders below guard with shared_idx < 2048u
// and only keep a partial tile; verify this matches the host's expectations.
var<workgroup> A_tile: array<f32, 2048>; // TILE_SIZE * BLOCK_SIZE = 128 * 16
var<workgroup> B_tile: array<f32, 2048>;

// Thread-local accumulator registers (one THREAD_TILE x THREAD_TILE subblock
// per invocation).
var<private> acc: array<f32, 64>; // THREAD_TILE * THREAD_TILE = 8 * 8
|
||||
|
||||
@compute @workgroup_size(16, 16, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) group_id: vec3<u32>,
|
||||
) {
|
||||
let M = uniforms.M;
|
||||
let N = uniforms.N;
|
||||
let K = uniforms.K;
|
||||
|
||||
// Global row and column for this thread's block
|
||||
let block_row = group_id.x * TILE_SIZE;
|
||||
let block_col = group_id.y * TILE_SIZE;
|
||||
|
||||
// Thread position within workgroup
|
||||
let thread_row = local_id.x;
|
||||
let thread_col = local_id.y;
|
||||
|
||||
// Initialize accumulators to zero
|
||||
for (var i = 0u; i < 64u; i++) {
|
||||
acc[i] = 0.0;
|
||||
}
|
||||
|
||||
// Number of K-tiles to process
|
||||
let num_k_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;
|
||||
|
||||
// Iterate over K dimension in tiles
|
||||
for (var k_tile = 0u; k_tile < num_k_tiles; k_tile++) {
|
||||
let k_base = k_tile * TILE_SIZE;
|
||||
|
||||
// Cooperative load of A tile into shared memory
|
||||
// Each thread loads multiple elements
|
||||
for (var i = 0u; i < THREAD_TILE; i++) {
|
||||
let a_row = block_row + thread_row * THREAD_TILE + i;
|
||||
for (var j = 0u; j < THREAD_TILE; j++) {
|
||||
let a_col = k_base + thread_col * THREAD_TILE + j;
|
||||
let shared_idx = (thread_row * THREAD_TILE + i) * BLOCK_SIZE + thread_col;
|
||||
|
||||
if (a_row < M && a_col < K) {
|
||||
// Only load partial tile for first few elements to fit in shared memory
|
||||
if (shared_idx < 2048u) {
|
||||
A_tile[shared_idx] = A[a_row * K + a_col];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cooperative load of B tile into shared memory
|
||||
for (var i = 0u; i < THREAD_TILE; i++) {
|
||||
let b_row = k_base + thread_row * THREAD_TILE + i;
|
||||
for (var j = 0u; j < THREAD_TILE; j++) {
|
||||
let b_col = block_col + thread_col * THREAD_TILE + j;
|
||||
let shared_idx = (thread_row * THREAD_TILE + i) * BLOCK_SIZE + thread_col;
|
||||
|
||||
if (b_row < K && b_col < N) {
|
||||
if (shared_idx < 2048u) {
|
||||
B_tile[shared_idx] = B[b_row * N + b_col];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronize to ensure all data is loaded
|
||||
workgroupBarrier();
|
||||
|
||||
// Compute partial dot products
|
||||
// Each thread computes an 8x8 subblock
|
||||
for (var k = 0u; k < min(TILE_SIZE, K - k_base); k++) {
|
||||
// Load A values into registers
|
||||
var a_regs: array<f32, 8>;
|
||||
for (var i = 0u; i < THREAD_TILE; i++) {
|
||||
let a_shared_row = thread_row * THREAD_TILE + i;
|
||||
let a_shared_idx = a_shared_row * BLOCK_SIZE + (k % BLOCK_SIZE);
|
||||
if (a_shared_idx < 2048u) {
|
||||
a_regs[i] = A_tile[a_shared_idx];
|
||||
} else {
|
||||
a_regs[i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Load B values into registers
|
||||
var b_regs: array<f32, 8>;
|
||||
for (var j = 0u; j < THREAD_TILE; j++) {
|
||||
let b_shared_row = k % BLOCK_SIZE;
|
||||
let b_shared_col = thread_col * THREAD_TILE + j;
|
||||
let b_shared_idx = b_shared_row * BLOCK_SIZE + (b_shared_col % BLOCK_SIZE);
|
||||
if (b_shared_idx < 2048u) {
|
||||
b_regs[j] = B_tile[b_shared_idx];
|
||||
} else {
|
||||
b_regs[j] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Outer product accumulation
|
||||
for (var i = 0u; i < THREAD_TILE; i++) {
|
||||
for (var j = 0u; j < THREAD_TILE; j++) {
|
||||
acc[i * THREAD_TILE + j] += a_regs[i] * b_regs[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronize before loading next tile
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// Write accumulated results to global memory
|
||||
for (var i = 0u; i < THREAD_TILE; i++) {
|
||||
let c_row = block_row + thread_row * THREAD_TILE + i;
|
||||
for (var j = 0u; j < THREAD_TILE; j++) {
|
||||
let c_col = block_col + thread_col * THREAD_TILE + j;
|
||||
|
||||
if (c_row < M && c_col < N) {
|
||||
C[c_row * N + c_col] = acc[i * THREAD_TILE + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Int8 quantized matmul entry point (placeholder).
//
// The full implementation will pack four i8 values per u32, accumulate
// into i32, and apply the quantization scales when writing the f32
// output. Kept as an empty entry point so the pipeline layout can be
// created ahead of time.
@compute @workgroup_size(16, 16, 1)
fn main_int8(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // Intentionally empty until the quantized path lands.
}
|
||||
1417
vendor/ruvector/examples/edge-net/src/compute/simd.rs
vendored
Normal file
1417
vendor/ruvector/examples/edge-net/src/compute/simd.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
751
vendor/ruvector/examples/edge-net/src/compute/tensor.rs
vendored
Normal file
751
vendor/ruvector/examples/edge-net/src/compute/tensor.rs
vendored
Normal file
@@ -0,0 +1,751 @@
|
||||
//! Tensor abstraction layer for unified compute operations
|
||||
//!
|
||||
//! Provides a minimal tensor abstraction that works across all compute backends
|
||||
//! (WebGPU, WebGL2, SIMD, WebWorkers, and naive fallback).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Data type for tensor elements
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DType {
|
||||
/// 32-bit floating point
|
||||
F32,
|
||||
/// 16-bit floating point (for WebGPU)
|
||||
F16,
|
||||
/// 8-bit integer (for quantized models)
|
||||
I8,
|
||||
/// Unsigned 8-bit (for embeddings)
|
||||
U8,
|
||||
/// Binary (for HDC hypervectors)
|
||||
Binary,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
/// Size in bytes for this data type
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
match self {
|
||||
DType::F32 => 4,
|
||||
DType::F16 => 2,
|
||||
DType::I8 | DType::U8 => 1,
|
||||
DType::Binary => 1, // 8 bits per byte
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DType {
|
||||
fn default() -> Self {
|
||||
DType::F32
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor shape with up to 4 dimensions
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Shape {
|
||||
dims: Vec<usize>,
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
/// Create a new shape from dimensions
|
||||
pub fn new(dims: &[usize]) -> Self {
|
||||
Self { dims: dims.to_vec() }
|
||||
}
|
||||
|
||||
/// 1D shape (vector)
|
||||
pub fn d1(n: usize) -> Self {
|
||||
Self { dims: vec![n] }
|
||||
}
|
||||
|
||||
/// 2D shape (matrix)
|
||||
pub fn d2(rows: usize, cols: usize) -> Self {
|
||||
Self { dims: vec![rows, cols] }
|
||||
}
|
||||
|
||||
/// 3D shape (batch of matrices)
|
||||
pub fn d3(batch: usize, rows: usize, cols: usize) -> Self {
|
||||
Self { dims: vec![batch, rows, cols] }
|
||||
}
|
||||
|
||||
/// 4D shape (e.g., attention tensors)
|
||||
pub fn d4(b: usize, h: usize, s: usize, d: usize) -> Self {
|
||||
Self { dims: vec![b, h, s, d] }
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
pub fn numel(&self) -> usize {
|
||||
self.dims.iter().product()
|
||||
}
|
||||
|
||||
/// Number of dimensions
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.dims.len()
|
||||
}
|
||||
|
||||
/// Get dimension at index
|
||||
pub fn dim(&self, idx: usize) -> usize {
|
||||
self.dims.get(idx).copied().unwrap_or(1)
|
||||
}
|
||||
|
||||
/// Get all dimensions
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
&self.dims
|
||||
}
|
||||
|
||||
/// Check if shape is compatible for matrix multiplication with another
|
||||
pub fn matmul_compatible(&self, other: &Shape) -> bool {
|
||||
if self.ndim() < 1 || other.ndim() < 1 {
|
||||
return false;
|
||||
}
|
||||
// Last dim of self must match second-to-last of other (or last if 1D)
|
||||
let self_k = self.dim(self.ndim() - 1);
|
||||
let other_k = if other.ndim() >= 2 {
|
||||
other.dim(other.ndim() - 2)
|
||||
} else {
|
||||
other.dim(0)
|
||||
};
|
||||
self_k == other_k
|
||||
}
|
||||
|
||||
/// Compute strides for row-major layout
|
||||
pub fn strides(&self) -> Vec<usize> {
|
||||
let mut strides = vec![1; self.dims.len()];
|
||||
for i in (0..self.dims.len() - 1).rev() {
|
||||
strides[i] = strides[i + 1] * self.dims[i + 1];
|
||||
}
|
||||
strides
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Shape {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "(")?;
|
||||
for (i, d) in self.dims.iter().enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", d)?;
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory layout for tensors
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Layout {
|
||||
/// Row-major (C-style), most common
|
||||
RowMajor,
|
||||
/// Column-major (Fortran-style)
|
||||
ColMajor,
|
||||
/// Strided (non-contiguous)
|
||||
Strided,
|
||||
}
|
||||
|
||||
impl Default for Layout {
|
||||
fn default() -> Self {
|
||||
Layout::RowMajor
|
||||
}
|
||||
}
|
||||
|
||||
/// Backing storage for a tensor's elements.
#[derive(Clone, Debug)]
pub enum TensorStorage {
    /// CPU storage (`Vec<f32>`).
    Cpu(Vec<f32>),
    /// Quantized storage: (i8 data, dequantization scale).
    Quantized(Vec<i8>, f32),
    /// Binary storage for HDC; 64 elements packed per word.
    Binary(Vec<u64>),
    /// GPU buffer reference (opaque WebGPU buffer ID).
    GpuBuffer(u32),
    /// Shared memory reference for WebWorkers (SharedArrayBuffer ID).
    SharedBuffer(u32),
}

impl TensorStorage {
    /// Storage size in bytes.
    ///
    /// GPU and shared-buffer handles report 0: their sizes are not
    /// tracked on this side of the boundary.
    pub fn size_bytes(&self) -> usize {
        match self {
            TensorStorage::Cpu(v) => v.len() * 4,
            TensorStorage::Quantized(v, _) => v.len(),
            TensorStorage::Binary(v) => v.len() * 8,
            TensorStorage::GpuBuffer(_) | TensorStorage::SharedBuffer(_) => 0,
        }
    }

    /// True when the data lives in CPU-accessible memory.
    pub fn is_cpu(&self) -> bool {
        matches!(self, TensorStorage::Cpu(_) | TensorStorage::Quantized(_, _))
    }

    /// True when the data lives in a GPU buffer.
    pub fn is_gpu(&self) -> bool {
        matches!(self, TensorStorage::GpuBuffer(_))
    }
}
|
||||
|
||||
/// Main tensor type for all compute operations
///
/// Couples a [`Shape`]/[`DType`] description with its backing
/// [`TensorStorage`]. `offset` and `strides` support views that share
/// storage with another tensor.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Shape of the tensor
    shape: Shape,
    /// Data type
    dtype: DType,
    /// Memory layout
    layout: Layout,
    /// Underlying storage
    storage: TensorStorage,
    /// Offset into storage (for views); 0 for freshly created tensors
    offset: usize,
    /// Custom strides (for non-contiguous tensors); None = contiguous row-major
    strides: Option<Vec<usize>>,
}
|
||||
|
||||
impl Tensor {
|
||||
// ========================================================================
|
||||
// Constructors
|
||||
// ========================================================================
|
||||
|
||||
/// Create a new tensor with zeros
|
||||
pub fn zeros(shape: Shape, dtype: DType) -> Self {
|
||||
let numel = shape.numel();
|
||||
let storage = match dtype {
|
||||
DType::F32 | DType::F16 => TensorStorage::Cpu(vec![0.0; numel]),
|
||||
DType::I8 | DType::U8 => TensorStorage::Quantized(vec![0; numel], 1.0),
|
||||
DType::Binary => TensorStorage::Binary(vec![0; (numel + 63) / 64]),
|
||||
};
|
||||
Self {
|
||||
shape,
|
||||
dtype,
|
||||
layout: Layout::RowMajor,
|
||||
storage,
|
||||
offset: 0,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new tensor with ones
|
||||
pub fn ones(shape: Shape, dtype: DType) -> Self {
|
||||
let numel = shape.numel();
|
||||
let storage = match dtype {
|
||||
DType::F32 | DType::F16 => TensorStorage::Cpu(vec![1.0; numel]),
|
||||
DType::I8 | DType::U8 => TensorStorage::Quantized(vec![1; numel], 1.0),
|
||||
DType::Binary => TensorStorage::Binary(vec![u64::MAX; (numel + 63) / 64]),
|
||||
};
|
||||
Self {
|
||||
shape,
|
||||
dtype,
|
||||
layout: Layout::RowMajor,
|
||||
storage,
|
||||
offset: 0,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor from raw f32 data
|
||||
pub fn from_slice(data: &[f32], shape: Shape) -> Self {
|
||||
assert_eq!(
|
||||
data.len(),
|
||||
shape.numel(),
|
||||
"Data length {} doesn't match shape {}",
|
||||
data.len(),
|
||||
shape
|
||||
);
|
||||
Self {
|
||||
shape,
|
||||
dtype: DType::F32,
|
||||
layout: Layout::RowMajor,
|
||||
storage: TensorStorage::Cpu(data.to_vec()),
|
||||
offset: 0,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor from a Vec<f32>
|
||||
pub fn from_vec(data: Vec<f32>, shape: Shape) -> Self {
|
||||
assert_eq!(
|
||||
data.len(),
|
||||
shape.numel(),
|
||||
"Data length {} doesn't match shape {}",
|
||||
data.len(),
|
||||
shape
|
||||
);
|
||||
Self {
|
||||
shape,
|
||||
dtype: DType::F32,
|
||||
layout: Layout::RowMajor,
|
||||
storage: TensorStorage::Cpu(data),
|
||||
offset: 0,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a random tensor (uniform [0, 1))
|
||||
pub fn rand(shape: Shape) -> Self {
|
||||
let numel = shape.numel();
|
||||
let mut data = vec![0.0f32; numel];
|
||||
// Simple LCG PRNG for reproducibility
|
||||
let mut seed = 0xDEADBEEFu64;
|
||||
for x in data.iter_mut() {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
*x = (seed >> 33) as f32 / (1u64 << 31) as f32;
|
||||
}
|
||||
Self::from_vec(data, shape)
|
||||
}
|
||||
|
||||
/// Create a random normal tensor (mean=0, std=1)
|
||||
pub fn randn(shape: Shape) -> Self {
|
||||
let numel = shape.numel();
|
||||
let mut data = vec![0.0f32; numel];
|
||||
// Box-Muller transform for normal distribution
|
||||
let mut seed = 0xCAFEBABEu64;
|
||||
for i in (0..numel).step_by(2) {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let u1 = (seed >> 33) as f32 / (1u64 << 31) as f32;
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let u2 = (seed >> 33) as f32 / (1u64 << 31) as f32;
|
||||
|
||||
let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
|
||||
let theta = 2.0 * std::f32::consts::PI * u2;
|
||||
|
||||
data[i] = r * theta.cos();
|
||||
if i + 1 < numel {
|
||||
data[i + 1] = r * theta.sin();
|
||||
}
|
||||
}
|
||||
Self::from_vec(data, shape)
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Accessors
|
||||
// ========================================================================
|
||||
|
||||
/// Get tensor shape
|
||||
pub fn shape(&self) -> &Shape {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
/// Get data type
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// Get number of elements
|
||||
pub fn numel(&self) -> usize {
|
||||
self.shape.numel()
|
||||
}
|
||||
|
||||
/// Get memory layout
|
||||
pub fn layout(&self) -> Layout {
|
||||
self.layout
|
||||
}
|
||||
|
||||
/// Check if tensor is contiguous
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
self.strides.is_none() && self.offset == 0
|
||||
}
|
||||
|
||||
/// Get underlying storage reference
|
||||
pub fn storage(&self) -> &TensorStorage {
|
||||
&self.storage
|
||||
}
|
||||
|
||||
/// Get underlying data as f32 slice (if CPU storage)
|
||||
pub fn as_slice(&self) -> Option<&[f32]> {
|
||||
match &self.storage {
|
||||
TensorStorage::Cpu(data) => {
|
||||
if self.is_contiguous() {
|
||||
Some(data.as_slice())
|
||||
} else {
|
||||
Some(&data[self.offset..self.offset + self.numel()])
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get mutable underlying data (if CPU storage)
|
||||
pub fn as_mut_slice(&mut self) -> Option<&mut [f32]> {
|
||||
match &mut self.storage {
|
||||
TensorStorage::Cpu(data) => {
|
||||
if self.is_contiguous() {
|
||||
Some(data.as_mut_slice())
|
||||
} else {
|
||||
let start = self.offset;
|
||||
let end = start + self.numel();
|
||||
Some(&mut data[start..end])
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to Vec<f32> (copies data)
|
||||
pub fn to_vec(&self) -> Vec<f32> {
|
||||
match &self.storage {
|
||||
TensorStorage::Cpu(data) => {
|
||||
if self.is_contiguous() {
|
||||
data.clone()
|
||||
} else {
|
||||
data[self.offset..self.offset + self.numel()].to_vec()
|
||||
}
|
||||
}
|
||||
TensorStorage::Quantized(data, scale) => {
|
||||
data.iter().map(|&x| x as f32 * scale).collect()
|
||||
}
|
||||
_ => vec![0.0; self.numel()],
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Transformations
|
||||
// ========================================================================
|
||||
|
||||
/// Reshape tensor (must have same numel)
|
||||
pub fn reshape(&self, new_shape: Shape) -> Self {
|
||||
assert_eq!(
|
||||
self.numel(),
|
||||
new_shape.numel(),
|
||||
"Cannot reshape {} to {}",
|
||||
self.shape,
|
||||
new_shape
|
||||
);
|
||||
Self {
|
||||
shape: new_shape,
|
||||
dtype: self.dtype,
|
||||
layout: self.layout,
|
||||
storage: self.storage.clone(),
|
||||
offset: self.offset,
|
||||
strides: None, // Reshaping makes it contiguous
|
||||
}
|
||||
}
|
||||
|
||||
/// Transpose 2D tensor
|
||||
pub fn transpose(&self) -> Self {
|
||||
assert_eq!(self.shape.ndim(), 2, "Transpose only supports 2D tensors");
|
||||
let rows = self.shape.dim(0);
|
||||
let cols = self.shape.dim(1);
|
||||
|
||||
// For non-contiguous transpose, we'd use strides
|
||||
// For simplicity, we copy and transpose
|
||||
if let TensorStorage::Cpu(data) = &self.storage {
|
||||
let mut new_data = vec![0.0f32; self.numel()];
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
new_data[j * rows + i] = data[i * cols + j];
|
||||
}
|
||||
}
|
||||
Self::from_vec(new_data, Shape::d2(cols, rows))
|
||||
} else {
|
||||
// For GPU tensors, return a strided view
|
||||
Self {
|
||||
shape: Shape::d2(cols, rows),
|
||||
dtype: self.dtype,
|
||||
layout: Layout::Strided,
|
||||
storage: self.storage.clone(),
|
||||
offset: self.offset,
|
||||
strides: Some(vec![1, rows]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to contiguous layout
|
||||
pub fn contiguous(&self) -> Self {
|
||||
if self.is_contiguous() {
|
||||
self.clone()
|
||||
} else {
|
||||
// Copy to new contiguous storage
|
||||
Self::from_vec(self.to_vec(), self.shape.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantize to i8
|
||||
pub fn quantize(&self) -> Self {
|
||||
let data = self.to_vec();
|
||||
let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
|
||||
let scale = max_abs / 127.0;
|
||||
|
||||
let quantized: Vec<i8> = data
|
||||
.iter()
|
||||
.map(|&x| (x / scale).clamp(-127.0, 127.0) as i8)
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
shape: self.shape.clone(),
|
||||
dtype: DType::I8,
|
||||
layout: Layout::RowMajor,
|
||||
storage: TensorStorage::Quantized(quantized, scale),
|
||||
offset: 0,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Dequantize to f32
|
||||
pub fn dequantize(&self) -> Self {
|
||||
Self::from_vec(self.to_vec(), self.shape.clone())
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Size estimation
|
||||
// ========================================================================
|
||||
|
||||
/// Estimate memory usage in bytes
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
self.storage.size_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
/// LoRA adapter for efficient fine-tuning
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LoraAdapter {
|
||||
/// Low-rank A matrix (d x r)
|
||||
pub a: Tensor,
|
||||
/// Low-rank B matrix (r x d)
|
||||
pub b: Tensor,
|
||||
/// Scaling factor (alpha / rank)
|
||||
pub scaling: f32,
|
||||
/// Target layer name
|
||||
pub target: String,
|
||||
}
|
||||
|
||||
impl LoraAdapter {
|
||||
/// Create a new LoRA adapter
|
||||
pub fn new(input_dim: usize, output_dim: usize, rank: usize, alpha: f32, target: &str) -> Self {
|
||||
// Initialize A with random normal, B with zeros (as per LoRA paper)
|
||||
let a = Tensor::randn(Shape::d2(input_dim, rank));
|
||||
let b = Tensor::zeros(Shape::d2(rank, output_dim), DType::F32);
|
||||
|
||||
Self {
|
||||
a,
|
||||
b,
|
||||
scaling: alpha / rank as f32,
|
||||
target: target.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get rank of this adapter
|
||||
pub fn rank(&self) -> usize {
|
||||
self.a.shape().dim(1)
|
||||
}
|
||||
|
||||
/// Get input dimension
|
||||
pub fn input_dim(&self) -> usize {
|
||||
self.a.shape().dim(0)
|
||||
}
|
||||
|
||||
/// Get output dimension
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.b.shape().dim(1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Workload classification for backend selection
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum WorkloadType {
|
||||
/// Small matmul (< 1K elements)
|
||||
SmallMatmul,
|
||||
/// Medium matmul (1K - 100K elements)
|
||||
MediumMatmul,
|
||||
/// Large matmul (> 100K elements)
|
||||
LargeMatmul,
|
||||
/// Attention mechanism
|
||||
Attention,
|
||||
/// Element-wise operation
|
||||
Elementwise,
|
||||
/// Reduction (sum, mean, etc.)
|
||||
Reduction,
|
||||
/// Sparse operation (> 50% zeros)
|
||||
Sparse,
|
||||
/// Batch inference
|
||||
BatchInference,
|
||||
/// LoRA forward pass
|
||||
LoraForward,
|
||||
}
|
||||
|
||||
impl WorkloadType {
|
||||
/// Classify a workload from tensor shapes
|
||||
pub fn classify(a: &Tensor, b: Option<&Tensor>) -> Self {
|
||||
let numel_a = a.numel();
|
||||
|
||||
match b {
|
||||
Some(b_tensor) => {
|
||||
let numel_b = b_tensor.numel();
|
||||
let total = numel_a + numel_b;
|
||||
|
||||
if a.shape().ndim() >= 3 && a.shape().dim(a.shape().ndim() - 2) == a.shape().dim(a.shape().ndim() - 1) {
|
||||
// Likely attention (square inner dimensions)
|
||||
WorkloadType::Attention
|
||||
} else if total < 1_000 {
|
||||
WorkloadType::SmallMatmul
|
||||
} else if total < 100_000 {
|
||||
WorkloadType::MediumMatmul
|
||||
} else {
|
||||
WorkloadType::LargeMatmul
|
||||
}
|
||||
}
|
||||
None => {
|
||||
if numel_a < 1_000 {
|
||||
WorkloadType::Elementwise
|
||||
} else {
|
||||
WorkloadType::Reduction
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get estimated FLOP count for this workload
|
||||
pub fn estimated_flops(&self, numel: usize) -> u64 {
|
||||
match self {
|
||||
WorkloadType::SmallMatmul => numel as u64 * 2,
|
||||
WorkloadType::MediumMatmul => numel as u64 * 2,
|
||||
WorkloadType::LargeMatmul => numel as u64 * 2,
|
||||
WorkloadType::Attention => numel as u64 * 4, // Q*K + softmax + *V
|
||||
WorkloadType::Elementwise => numel as u64,
|
||||
WorkloadType::Reduction => numel as u64,
|
||||
WorkloadType::Sparse => numel as u64 / 2, // Assumes 50% sparsity
|
||||
WorkloadType::BatchInference => numel as u64 * 10,
|
||||
WorkloadType::LoraForward => numel as u64 * 4, // A*x + B*(A*x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sparsity analysis for tensors
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SparsityInfo {
|
||||
/// Fraction of zero elements
|
||||
pub sparsity: f32,
|
||||
/// Is structured sparsity (blocks of zeros)?
|
||||
pub is_structured: bool,
|
||||
/// Block size if structured
|
||||
pub block_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl SparsityInfo {
|
||||
/// Analyze sparsity of a tensor
|
||||
pub fn analyze(tensor: &Tensor) -> Self {
|
||||
let data = tensor.to_vec();
|
||||
let total = data.len();
|
||||
let zeros = data.iter().filter(|&&x| x == 0.0).count();
|
||||
let sparsity = zeros as f32 / total as f32;
|
||||
|
||||
// Check for structured sparsity (simple block check)
|
||||
let block_sizes = [4, 8, 16, 32];
|
||||
let mut is_structured = false;
|
||||
let mut detected_block = None;
|
||||
|
||||
for &block in &block_sizes {
|
||||
if total >= block * 4 {
|
||||
let mut block_zeros = 0;
|
||||
let mut total_blocks = 0;
|
||||
|
||||
for chunk in data.chunks(block) {
|
||||
total_blocks += 1;
|
||||
if chunk.iter().all(|&x| x == 0.0) {
|
||||
block_zeros += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// If > 30% of blocks are all zeros, consider structured
|
||||
if block_zeros as f32 / total_blocks as f32 > 0.3 {
|
||||
is_structured = true;
|
||||
detected_block = Some(block);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
sparsity,
|
||||
is_structured,
|
||||
block_size: detected_block,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Shape constructors and accessors.
    #[test]
    fn test_shape_creation() {
        let s = Shape::d2(3, 4);
        assert_eq!(s.numel(), 12);
        assert_eq!(s.ndim(), 2);
        assert_eq!(s.dim(0), 3);
        assert_eq!(s.dim(1), 4);
    }

    // zeros() yields an all-zero CPU tensor of the right size.
    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(Shape::d2(2, 3), DType::F32);
        assert_eq!(t.numel(), 6);
        let data = t.to_vec();
        assert!(data.iter().all(|&x| x == 0.0));
    }

    // from_slice round-trips the input data.
    #[test]
    fn test_tensor_from_slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_slice(&data, Shape::d2(2, 3));
        assert_eq!(t.to_vec(), data);
    }

    // Inner-dimension agreement rule: (3,4)x(4,5) ok, (3,4)x(3,5) not.
    #[test]
    fn test_matmul_compatible() {
        let s1 = Shape::d2(3, 4);
        let s2 = Shape::d2(4, 5);
        let s3 = Shape::d2(3, 5);

        assert!(s1.matmul_compatible(&s2));
        assert!(!s1.matmul_compatible(&s3));
    }

    // CPU transpose materializes data in the swapped order.
    #[test]
    fn test_transpose() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_slice(&data, Shape::d2(2, 3));
        let t_t = t.transpose();

        assert_eq!(t_t.shape().dims(), &[3, 2]);
        assert_eq!(t_t.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    // Combined operand size picks the matmul bucket.
    #[test]
    fn test_workload_classification() {
        let small = Tensor::zeros(Shape::d2(10, 10), DType::F32);
        let large = Tensor::zeros(Shape::d2(1000, 1000), DType::F32);

        assert_eq!(
            WorkloadType::classify(&small, Some(&small)),
            WorkloadType::SmallMatmul
        );
        assert_eq!(
            WorkloadType::classify(&large, Some(&large)),
            WorkloadType::LargeMatmul
        );
    }

    // i8 quantize/dequantize round-trips within one quantization step.
    #[test]
    fn test_quantization() {
        let data = vec![0.5, -0.5, 1.0, -1.0];
        let t = Tensor::from_slice(&data, Shape::d1(4));
        let q = t.quantize();

        assert_eq!(q.dtype(), DType::I8);

        // Dequantize and check approximate equality
        let dq = q.dequantize();
        let dq_data = dq.to_vec();
        for (a, b) in data.iter().zip(dq_data.iter()) {
            assert!((a - b).abs() < 0.01);
        }
    }

    // Adapter dimensions are derived from the A/B matrix shapes.
    #[test]
    fn test_lora_adapter() {
        let lora = LoraAdapter::new(128, 128, 4, 1.0, "attention.q");
        assert_eq!(lora.rank(), 4);
        assert_eq!(lora.input_dim(), 128);
        assert_eq!(lora.output_dim(), 128);
    }
}
|
||||
353
vendor/ruvector/examples/edge-net/src/compute/types.rs
vendored
Normal file
353
vendor/ruvector/examples/edge-net/src/compute/types.rs
vendored
Normal file
@@ -0,0 +1,353 @@
|
||||
//! Core types for compute operations
|
||||
//!
|
||||
//! These types work without the WebGPU feature and provide
|
||||
//! the interface for compute operations.
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
/// Matrix storage format
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MatrixLayout {
|
||||
/// Row-major storage (C-style)
|
||||
RowMajor,
|
||||
/// Column-major storage (Fortran-style)
|
||||
ColMajor,
|
||||
}
|
||||
|
||||
impl Default for MatrixLayout {
|
||||
fn default() -> Self {
|
||||
Self::RowMajor
|
||||
}
|
||||
}
|
||||
|
||||
/// Data type for compute operations
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DataType {
|
||||
/// 32-bit floating point
|
||||
F32,
|
||||
/// 16-bit floating point
|
||||
F16,
|
||||
/// 16-bit brain floating point
|
||||
BF16,
|
||||
/// 8-bit signed integer
|
||||
I8,
|
||||
/// 8-bit unsigned integer
|
||||
U8,
|
||||
/// 4-bit integer (packed, 2 per byte)
|
||||
I4,
|
||||
}
|
||||
|
||||
impl DataType {
|
||||
/// Get size in bytes
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 4,
|
||||
Self::F16 | Self::BF16 => 2,
|
||||
Self::I8 | Self::U8 => 1,
|
||||
Self::I4 => 1, // 2 values per byte, but minimum addressable is 1
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is a floating point type
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(self, Self::F32 | Self::F16 | Self::BF16)
|
||||
}
|
||||
|
||||
/// Check if this is a quantized type
|
||||
pub fn is_quantized(&self) -> bool {
|
||||
matches!(self, Self::I8 | Self::U8 | Self::I4)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor descriptor for GPU buffers
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TensorDescriptor {
|
||||
/// Shape of the tensor
|
||||
pub shape: Vec<usize>,
|
||||
/// Data type
|
||||
pub dtype: DataType,
|
||||
/// Storage layout
|
||||
pub layout: MatrixLayout,
|
||||
/// Stride between elements (None = contiguous)
|
||||
pub strides: Option<Vec<usize>>,
|
||||
}
|
||||
|
||||
impl TensorDescriptor {
|
||||
/// Create a new contiguous tensor descriptor
|
||||
pub fn new(shape: Vec<usize>, dtype: DataType) -> Self {
|
||||
Self {
|
||||
shape,
|
||||
dtype,
|
||||
layout: MatrixLayout::RowMajor,
|
||||
strides: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
pub fn numel(&self) -> usize {
|
||||
self.shape.iter().product()
|
||||
}
|
||||
|
||||
/// Size in bytes
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
self.numel() * self.dtype.size_bytes()
|
||||
}
|
||||
|
||||
/// Check if tensor is contiguous in memory
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
self.strides.is_none()
|
||||
}
|
||||
|
||||
/// Get number of dimensions
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
|
||||
/// Create 2D matrix descriptor
|
||||
pub fn matrix(rows: usize, cols: usize, dtype: DataType) -> Self {
|
||||
Self::new(vec![rows, cols], dtype)
|
||||
}
|
||||
|
||||
/// Create 3D tensor descriptor (batch, seq, hidden)
|
||||
pub fn tensor3d(batch: usize, seq: usize, hidden: usize, dtype: DataType) -> Self {
|
||||
Self::new(vec![batch, seq, hidden], dtype)
|
||||
}
|
||||
}
|
||||
|
||||
/// LoRA adapter configuration
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct LoraConfig {
|
||||
/// Rank of the adaptation (typically 2-64)
|
||||
pub rank: usize,
|
||||
/// Alpha scaling factor
|
||||
pub alpha: f32,
|
||||
/// Input dimension
|
||||
pub in_dim: usize,
|
||||
/// Output dimension
|
||||
pub out_dim: usize,
|
||||
/// Dropout rate (0.0 = no dropout)
|
||||
pub dropout: f32,
|
||||
}
|
||||
|
||||
impl LoraConfig {
|
||||
/// Create new LoRA config
|
||||
pub fn new(rank: usize, in_dim: usize, out_dim: usize) -> Self {
|
||||
Self {
|
||||
rank,
|
||||
alpha: rank as f32, // Default alpha = rank
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Scaling factor for LoRA output
|
||||
pub fn scaling(&self) -> f32 {
|
||||
self.alpha / self.rank as f32
|
||||
}
|
||||
|
||||
/// Size of A matrix (in_dim x rank)
|
||||
pub fn a_size(&self) -> usize {
|
||||
self.in_dim * self.rank
|
||||
}
|
||||
|
||||
/// Size of B matrix (rank x out_dim)
|
||||
pub fn b_size(&self) -> usize {
|
||||
self.rank * self.out_dim
|
||||
}
|
||||
}
|
||||
|
||||
/// Attention configuration
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AttentionConfig {
|
||||
/// Number of attention heads
|
||||
pub num_heads: usize,
|
||||
/// Dimension per head
|
||||
pub head_dim: usize,
|
||||
/// Maximum sequence length
|
||||
pub max_seq_len: usize,
|
||||
/// Use causal (autoregressive) masking
|
||||
pub causal: bool,
|
||||
/// Attention dropout rate
|
||||
pub dropout: f32,
|
||||
/// Scale factor (None = 1/sqrt(head_dim))
|
||||
pub scale: Option<f32>,
|
||||
/// Use flash attention algorithm
|
||||
pub flash: bool,
|
||||
}
|
||||
|
||||
impl AttentionConfig {
|
||||
/// Create new attention config
|
||||
pub fn new(num_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
|
||||
Self {
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
causal: true,
|
||||
dropout: 0.0,
|
||||
scale: None,
|
||||
flash: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Total hidden dimension (num_heads * head_dim)
|
||||
pub fn hidden_dim(&self) -> usize {
|
||||
self.num_heads * self.head_dim
|
||||
}
|
||||
|
||||
/// Get attention scale factor
|
||||
pub fn get_scale(&self) -> f32 {
|
||||
self.scale.unwrap_or_else(|| 1.0 / (self.head_dim as f32).sqrt())
|
||||
}
|
||||
}
|
||||
|
||||
/// Quantization configuration for int8/int4 operations
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuantConfig {
    /// Target data type (e.g. `DataType::I8` or `DataType::I4`)
    pub dtype: DataType,
    /// Per-channel vs per-tensor quantization
    pub per_channel: bool,
    /// Symmetric quantization (zero_point = 0)
    pub symmetric: bool,
    /// Group size for group quantization (0 = no grouping)
    pub group_size: usize,
}
|
||||
|
||||
impl Default for QuantConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dtype: DataType::I8,
|
||||
per_channel: true,
|
||||
symmetric: true,
|
||||
group_size: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QuantConfig {
|
||||
/// Create int8 quantization config
|
||||
pub fn int8() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create int4 quantization config with grouping
|
||||
pub fn int4_grouped(group_size: usize) -> Self {
|
||||
Self {
|
||||
dtype: DataType::I4,
|
||||
per_channel: false,
|
||||
symmetric: true,
|
||||
group_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer usage flags for GPU memory
///
/// Mirrors the WebGPU buffer-usage bits as plain booleans; use the
/// constructor shortcuts on `impl BufferUsage` for common combinations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BufferUsage {
    /// CPU may map the buffer for reading
    pub map_read: bool,
    /// CPU may map the buffer for writing
    pub map_write: bool,
    /// Buffer may be the source of a copy
    pub copy_src: bool,
    /// Buffer may be the destination of a copy
    pub copy_dst: bool,
    /// Buffer may be bound as a storage buffer
    pub storage: bool,
    /// Buffer may be bound as a uniform buffer
    pub uniform: bool,
}
|
||||
|
||||
impl Default for BufferUsage {
    /// Default: a GPU-resident storage buffer that can be filled by copy
    /// (`copy_dst` + `storage`) but is not CPU-mappable.
    fn default() -> Self {
        Self {
            map_read: false,
            map_write: false,
            copy_src: false,
            copy_dst: true,
            storage: true,
            uniform: false,
        }
    }
}
|
||||
|
||||
impl BufferUsage {
|
||||
/// Buffer for staging CPU->GPU transfers
|
||||
pub fn staging_upload() -> Self {
|
||||
Self {
|
||||
map_read: false,
|
||||
map_write: true,
|
||||
copy_src: true,
|
||||
copy_dst: false,
|
||||
storage: false,
|
||||
uniform: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer for staging GPU->CPU transfers
|
||||
pub fn staging_download() -> Self {
|
||||
Self {
|
||||
map_read: true,
|
||||
map_write: false,
|
||||
copy_src: false,
|
||||
copy_dst: true,
|
||||
storage: false,
|
||||
uniform: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer for compute shader storage
|
||||
pub fn storage() -> Self {
|
||||
Self {
|
||||
map_read: false,
|
||||
map_write: false,
|
||||
copy_src: true,
|
||||
copy_dst: true,
|
||||
storage: true,
|
||||
uniform: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer for uniform data (small, read-only)
|
||||
pub fn uniform() -> Self {
|
||||
Self {
|
||||
map_read: false,
|
||||
map_write: false,
|
||||
copy_src: false,
|
||||
copy_dst: true,
|
||||
storage: false,
|
||||
uniform: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Element sizes for the supported dtypes.
    #[test]
    fn test_data_type_size() {
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::F16.size_bytes(), 2);
        assert_eq!(DataType::I8.size_bytes(), 1);
    }

    // numel / size_bytes / ndim for a 2-D f32 descriptor.
    #[test]
    fn test_tensor_descriptor() {
        let desc = TensorDescriptor::matrix(1024, 768, DataType::F32);
        assert_eq!(desc.numel(), 1024 * 768);
        assert_eq!(desc.size_bytes(), 1024 * 768 * 4);
        assert_eq!(desc.ndim(), 2);
    }

    // Default alpha == rank, so scaling() starts at 1.0.
    #[test]
    fn test_lora_config() {
        let config = LoraConfig::new(4, 768, 768);
        assert_eq!(config.rank, 4);
        assert!((config.scaling() - 1.0).abs() < 0.001);
        assert_eq!(config.a_size(), 768 * 4);
        assert_eq!(config.b_size(), 4 * 768);
    }

    // hidden_dim = heads * head_dim; default scale = 1/sqrt(64) = 0.125.
    #[test]
    fn test_attention_config() {
        let config = AttentionConfig::new(12, 64, 4096);
        assert_eq!(config.hidden_dim(), 768);
        assert!((config.get_scale() - 0.125).abs() < 0.001);
    }
}
|
||||
696
vendor/ruvector/examples/edge-net/src/compute/webgl_compute.rs
vendored
Normal file
696
vendor/ruvector/examples/edge-net/src/compute/webgl_compute.rs
vendored
Normal file
@@ -0,0 +1,696 @@
|
||||
//! WebGL2 compute simulation for GPU-accelerated operations
|
||||
//!
|
||||
//! Uses ping-pong texture rendering for matrix operations on devices without WebGPU.
|
||||
//! This approach treats textures as 2D arrays and uses fragment shaders for computation.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! +-------------+ +----------------+ +-------------+
|
||||
//! | Input A | --> | Fragment | --> | Output |
|
||||
//! | (Texture) | | Shader | | (Texture) |
|
||||
//! +-------------+ +----------------+ +-------------+
|
||||
//! ^ | |
|
||||
//! | v v
|
||||
//! +-------------+ +----------------+ +-------------+
|
||||
//! | Input B | --> | Transform | --> | CPU Read |
|
||||
//! | (Texture) | | Feedback | | (Float32) |
|
||||
//! +-------------+ +----------------+ +-------------+
|
||||
//! ```
|
||||
//!
|
||||
//! ## Limitations vs WebGPU
|
||||
//!
|
||||
//! - No true compute shaders (uses fragment shaders)
|
||||
//! - Limited to 2D texture operations
|
||||
//! - Readback through transform feedback or readPixels
|
||||
//! - Lower performance than WebGPU compute
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::{
|
||||
WebGl2RenderingContext, WebGlProgram, WebGlShader, WebGlTexture,
|
||||
WebGlFramebuffer, WebGlBuffer, WebGlVertexArrayObject,
|
||||
};
|
||||
use crate::compute::tensor::{Tensor, TensorShape};
|
||||
|
||||
/// Shader programs for different operations
///
/// Compiled once in `WebGl2Compute::new` and reused for every dispatch.
struct ShaderPrograms {
    /// Matrix multiplication (fragment-shader dot product per output texel)
    matmul: WebGlProgram,
    /// Element-wise add/sub (selected via the `u_mode` uniform)
    vector_add: WebGlProgram,
    /// Element-wise mul/div (selected via the `u_mode` uniform)
    vector_mul: WebGlProgram,
    /// Per-element exp (softmax numerator; see create_softmax_program)
    softmax: WebGlProgram,
    /// Element-wise max(x, 0)
    relu: WebGlProgram,
}
|
||||
|
||||
/// WebGL2 compute backend
///
/// Simulates GPU compute on devices without WebGPU by rendering into float
/// textures with fragment shaders (see the module docs).
#[wasm_bindgen]
pub struct WebGl2Compute {
    /// WebGL2 rendering context
    gl: WebGl2RenderingContext,
    /// Shader programs
    programs: ShaderPrograms,
    /// Texture pool for reuse
    // NOTE(review): not referenced by any method visible in this chunk —
    // confirm it is populated/used elsewhere or remove.
    texture_pool: Vec<TextureHandle>,
    /// Framebuffer for render-to-texture
    framebuffer: WebGlFramebuffer,
    /// Full-screen quad VAO
    quad_vao: WebGlVertexArrayObject,
    /// Quad vertex buffer
    quad_vbo: WebGlBuffer,
    /// Maximum texture size
    max_texture_size: u32,
    /// Transform feedback buffer for readback
    tf_buffer: WebGlBuffer,
}
|
||||
|
||||
/// Handle to a pooled texture
struct TextureHandle {
    /// The GL texture object
    texture: WebGlTexture,
    /// Texture width in texels
    width: u32,
    /// Texture height in texels
    height: u32,
    /// Whether the texture is currently checked out of the pool
    in_use: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WebGl2Compute {
|
||||
/// Create a new WebGL2 compute backend
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Result<WebGl2Compute, JsValue> {
|
||||
let window = web_sys::window()
|
||||
.ok_or_else(|| JsValue::from_str("No window"))?;
|
||||
let document = window.document()
|
||||
.ok_or_else(|| JsValue::from_str("No document"))?;
|
||||
|
||||
// Create offscreen canvas
|
||||
let canvas = document.create_element("canvas")?;
|
||||
let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into()?;
|
||||
canvas.set_width(1);
|
||||
canvas.set_height(1);
|
||||
|
||||
// Get WebGL2 context
|
||||
let context_options = js_sys::Object::new();
|
||||
js_sys::Reflect::set(&context_options, &"antialias".into(), &false.into())?;
|
||||
js_sys::Reflect::set(&context_options, &"depth".into(), &false.into())?;
|
||||
js_sys::Reflect::set(&context_options, &"stencil".into(), &false.into())?;
|
||||
js_sys::Reflect::set(&context_options, &"preserveDrawingBuffer".into(), &true.into())?;
|
||||
|
||||
let gl: WebGl2RenderingContext = canvas
|
||||
.get_context_with_context_options("webgl2", &context_options)?
|
||||
.ok_or_else(|| JsValue::from_str("WebGL2 not available"))?
|
||||
.dyn_into()?;
|
||||
|
||||
// Enable required extensions
|
||||
gl.get_extension("EXT_color_buffer_float")?
|
||||
.ok_or_else(|| JsValue::from_str("EXT_color_buffer_float not available"))?;
|
||||
gl.get_extension("OES_texture_float_linear")?;
|
||||
|
||||
// Get max texture size
|
||||
let max_texture_size = gl.get_parameter(WebGl2RenderingContext::MAX_TEXTURE_SIZE)?
|
||||
.as_f64()
|
||||
.unwrap_or(4096.0) as u32;
|
||||
|
||||
// Create shader programs
|
||||
let programs = ShaderPrograms {
|
||||
matmul: create_matmul_program(&gl)?,
|
||||
vector_add: create_vector_add_program(&gl)?,
|
||||
vector_mul: create_vector_mul_program(&gl)?,
|
||||
softmax: create_softmax_program(&gl)?,
|
||||
relu: create_relu_program(&gl)?,
|
||||
};
|
||||
|
||||
// Create framebuffer
|
||||
let framebuffer = gl.create_framebuffer()
|
||||
.ok_or_else(|| JsValue::from_str("Failed to create framebuffer"))?;
|
||||
|
||||
// Create full-screen quad
|
||||
let (quad_vao, quad_vbo) = create_fullscreen_quad(&gl)?;
|
||||
|
||||
// Create transform feedback buffer
|
||||
let tf_buffer = gl.create_buffer()
|
||||
.ok_or_else(|| JsValue::from_str("Failed to create TF buffer"))?;
|
||||
|
||||
Ok(WebGl2Compute {
|
||||
gl,
|
||||
programs,
|
||||
texture_pool: Vec::new(),
|
||||
framebuffer,
|
||||
quad_vao,
|
||||
quad_vbo,
|
||||
max_texture_size,
|
||||
tf_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if WebGL2 compute is available
|
||||
#[wasm_bindgen(js_name = isAvailable)]
|
||||
pub fn is_available() -> bool {
|
||||
if let Some(window) = web_sys::window() {
|
||||
if let Some(document) = window.document() {
|
||||
if let Ok(canvas) = document.create_element("canvas") {
|
||||
if let Ok(canvas) = canvas.dyn_into::<web_sys::HtmlCanvasElement>() {
|
||||
if let Ok(Some(ctx)) = canvas.get_context("webgl2") {
|
||||
if let Ok(gl) = ctx.dyn_into::<WebGl2RenderingContext>() {
|
||||
return gl.get_extension("EXT_color_buffer_float")
|
||||
.map(|e| e.is_some())
|
||||
.unwrap_or(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Get maximum supported texture size
|
||||
#[wasm_bindgen(js_name = maxTextureSize)]
|
||||
pub fn max_texture_size(&self) -> u32 {
|
||||
self.max_texture_size
|
||||
}
|
||||
}
|
||||
|
||||
// Non-WASM implementation
|
||||
impl WebGl2Compute {
|
||||
/// Perform matrix multiplication: C = A * B
|
||||
pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor, JsValue> {
|
||||
if !a.shape().is_matrix() || !b.shape().is_matrix() {
|
||||
return Err(JsValue::from_str("Inputs must be matrices"));
|
||||
}
|
||||
|
||||
let m = a.shape().rows();
|
||||
let k = a.shape().cols();
|
||||
let n = b.shape().cols();
|
||||
|
||||
if k != b.shape().rows() {
|
||||
return Err(JsValue::from_str("Matrix dimension mismatch"));
|
||||
}
|
||||
|
||||
// For small matrices, use CPU
|
||||
if m * k * n < 4096 {
|
||||
return Ok(self.cpu_matmul(a, b));
|
||||
}
|
||||
|
||||
// Upload matrices to textures
|
||||
let tex_a = self.upload_matrix(a)?;
|
||||
let tex_b = self.upload_matrix(b)?;
|
||||
let tex_c = self.create_texture(m as u32, n as u32)?;
|
||||
|
||||
// Bind output texture to framebuffer
|
||||
self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, Some(&self.framebuffer));
|
||||
self.gl.framebuffer_texture_2d(
|
||||
WebGl2RenderingContext::FRAMEBUFFER,
|
||||
WebGl2RenderingContext::COLOR_ATTACHMENT0,
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
Some(&tex_c),
|
||||
0,
|
||||
);
|
||||
|
||||
// Set viewport
|
||||
self.gl.viewport(0, 0, n as i32, m as i32);
|
||||
|
||||
// Use matmul program
|
||||
self.gl.use_program(Some(&self.programs.matmul));
|
||||
|
||||
// Bind input textures
|
||||
self.gl.active_texture(WebGl2RenderingContext::TEXTURE0);
|
||||
self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_a));
|
||||
let loc_a = self.gl.get_uniform_location(&self.programs.matmul, "u_A");
|
||||
self.gl.uniform1i(loc_a.as_ref(), 0);
|
||||
|
||||
self.gl.active_texture(WebGl2RenderingContext::TEXTURE1);
|
||||
self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_b));
|
||||
let loc_b = self.gl.get_uniform_location(&self.programs.matmul, "u_B");
|
||||
self.gl.uniform1i(loc_b.as_ref(), 1);
|
||||
|
||||
// Set dimensions
|
||||
let loc_dims = self.gl.get_uniform_location(&self.programs.matmul, "u_dims");
|
||||
self.gl.uniform3f(loc_dims.as_ref(), m as f32, k as f32, n as f32);
|
||||
|
||||
// Draw full-screen quad
|
||||
self.gl.bind_vertex_array(Some(&self.quad_vao));
|
||||
self.gl.draw_arrays(WebGl2RenderingContext::TRIANGLE_STRIP, 0, 4);
|
||||
|
||||
// Read back result
|
||||
let result = self.read_texture(&tex_c, m as u32, n as u32)?;
|
||||
|
||||
// Cleanup
|
||||
self.gl.delete_texture(Some(&tex_a));
|
||||
self.gl.delete_texture(Some(&tex_b));
|
||||
self.gl.delete_texture(Some(&tex_c));
|
||||
self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, None);
|
||||
|
||||
Ok(Tensor::from_vec(result, TensorShape::matrix(m, n)))
|
||||
}
|
||||
|
||||
/// Element-wise vector operations
|
||||
pub fn vector_op(&self, a: &[f32], b: &[f32], op: &str) -> Result<Vec<f32>, JsValue> {
|
||||
if a.len() != b.len() {
|
||||
return Err(JsValue::from_str("Vector length mismatch"));
|
||||
}
|
||||
|
||||
let len = a.len();
|
||||
|
||||
// For small vectors, use CPU
|
||||
if len < 1024 {
|
||||
return Ok(match op {
|
||||
"add" => a.iter().zip(b.iter()).map(|(x, y)| x + y).collect(),
|
||||
"sub" => a.iter().zip(b.iter()).map(|(x, y)| x - y).collect(),
|
||||
"mul" => a.iter().zip(b.iter()).map(|(x, y)| x * y).collect(),
|
||||
"div" => a.iter().zip(b.iter()).map(|(x, y)| x / y).collect(),
|
||||
_ => return Err(JsValue::from_str(&format!("Unknown op: {}", op))),
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate texture dimensions (square-ish)
|
||||
let width = (len as f32).sqrt().ceil() as u32;
|
||||
let height = ((len as u32 + width - 1) / width).max(1);
|
||||
|
||||
// Pad data to fill texture
|
||||
let padded_len = (width * height) as usize;
|
||||
let mut a_padded = a.to_vec();
|
||||
let mut b_padded = b.to_vec();
|
||||
a_padded.resize(padded_len, 0.0);
|
||||
b_padded.resize(padded_len, 0.0);
|
||||
|
||||
// Upload to textures
|
||||
let tex_a = self.upload_data(&a_padded, width, height)?;
|
||||
let tex_b = self.upload_data(&b_padded, width, height)?;
|
||||
let tex_c = self.create_texture(width, height)?;
|
||||
|
||||
// Select program
|
||||
let program = match op {
|
||||
"add" | "sub" => &self.programs.vector_add,
|
||||
"mul" | "div" => &self.programs.vector_mul,
|
||||
_ => return Err(JsValue::from_str(&format!("Unknown op: {}", op))),
|
||||
};
|
||||
|
||||
// Bind framebuffer
|
||||
self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, Some(&self.framebuffer));
|
||||
self.gl.framebuffer_texture_2d(
|
||||
WebGl2RenderingContext::FRAMEBUFFER,
|
||||
WebGl2RenderingContext::COLOR_ATTACHMENT0,
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
Some(&tex_c),
|
||||
0,
|
||||
);
|
||||
|
||||
self.gl.viewport(0, 0, width as i32, height as i32);
|
||||
self.gl.use_program(Some(program));
|
||||
|
||||
// Bind textures
|
||||
self.gl.active_texture(WebGl2RenderingContext::TEXTURE0);
|
||||
self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_a));
|
||||
self.gl.uniform1i(self.gl.get_uniform_location(program, "u_A").as_ref(), 0);
|
||||
|
||||
self.gl.active_texture(WebGl2RenderingContext::TEXTURE1);
|
||||
self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&tex_b));
|
||||
self.gl.uniform1i(self.gl.get_uniform_location(program, "u_B").as_ref(), 1);
|
||||
|
||||
// Set operation mode
|
||||
let op_mode = match op {
|
||||
"add" => 0.0,
|
||||
"sub" => 1.0,
|
||||
"mul" => 0.0,
|
||||
"div" => 1.0,
|
||||
_ => 0.0,
|
||||
};
|
||||
self.gl.uniform1f(self.gl.get_uniform_location(program, "u_mode").as_ref(), op_mode);
|
||||
|
||||
// Draw
|
||||
self.gl.bind_vertex_array(Some(&self.quad_vao));
|
||||
self.gl.draw_arrays(WebGl2RenderingContext::TRIANGLE_STRIP, 0, 4);
|
||||
|
||||
// Read back
|
||||
let result = self.read_texture(&tex_c, width, height)?;
|
||||
|
||||
// Cleanup
|
||||
self.gl.delete_texture(Some(&tex_a));
|
||||
self.gl.delete_texture(Some(&tex_b));
|
||||
self.gl.delete_texture(Some(&tex_c));
|
||||
self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, None);
|
||||
|
||||
// Trim to original length
|
||||
Ok(result[..len].to_vec())
|
||||
}
|
||||
|
||||
/// Upload matrix to texture
|
||||
fn upload_matrix(&self, tensor: &Tensor) -> Result<WebGlTexture, JsValue> {
|
||||
let rows = tensor.shape().rows() as u32;
|
||||
let cols = tensor.shape().cols() as u32;
|
||||
self.upload_data(tensor.data(), cols, rows)
|
||||
}
|
||||
|
||||
/// Upload data to a float texture
|
||||
fn upload_data(&self, data: &[f32], width: u32, height: u32) -> Result<WebGlTexture, JsValue> {
|
||||
let texture = self.gl.create_texture()
|
||||
.ok_or_else(|| JsValue::from_str("Failed to create texture"))?;
|
||||
|
||||
self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&texture));
|
||||
|
||||
// Set texture parameters
|
||||
self.gl.tex_parameteri(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
WebGl2RenderingContext::TEXTURE_MIN_FILTER,
|
||||
WebGl2RenderingContext::NEAREST as i32,
|
||||
);
|
||||
self.gl.tex_parameteri(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
WebGl2RenderingContext::TEXTURE_MAG_FILTER,
|
||||
WebGl2RenderingContext::NEAREST as i32,
|
||||
);
|
||||
self.gl.tex_parameteri(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
WebGl2RenderingContext::TEXTURE_WRAP_S,
|
||||
WebGl2RenderingContext::CLAMP_TO_EDGE as i32,
|
||||
);
|
||||
self.gl.tex_parameteri(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
WebGl2RenderingContext::TEXTURE_WRAP_T,
|
||||
WebGl2RenderingContext::CLAMP_TO_EDGE as i32,
|
||||
);
|
||||
|
||||
// Create Float32Array view
|
||||
let array = js_sys::Float32Array::from(data);
|
||||
|
||||
// Upload as R32F texture
|
||||
self.gl.tex_image_2d_with_i32_and_i32_and_i32_and_format_and_type_and_opt_array_buffer_view(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
0,
|
||||
WebGl2RenderingContext::R32F as i32,
|
||||
width as i32,
|
||||
height as i32,
|
||||
0,
|
||||
WebGl2RenderingContext::RED,
|
||||
WebGl2RenderingContext::FLOAT,
|
||||
Some(&array),
|
||||
)?;
|
||||
|
||||
Ok(texture)
|
||||
}
|
||||
|
||||
/// Create an empty float texture
|
||||
fn create_texture(&self, width: u32, height: u32) -> Result<WebGlTexture, JsValue> {
|
||||
let texture = self.gl.create_texture()
|
||||
.ok_or_else(|| JsValue::from_str("Failed to create texture"))?;
|
||||
|
||||
self.gl.bind_texture(WebGl2RenderingContext::TEXTURE_2D, Some(&texture));
|
||||
|
||||
self.gl.tex_parameteri(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
WebGl2RenderingContext::TEXTURE_MIN_FILTER,
|
||||
WebGl2RenderingContext::NEAREST as i32,
|
||||
);
|
||||
self.gl.tex_parameteri(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
WebGl2RenderingContext::TEXTURE_MAG_FILTER,
|
||||
WebGl2RenderingContext::NEAREST as i32,
|
||||
);
|
||||
|
||||
self.gl.tex_image_2d_with_i32_and_i32_and_i32_and_format_and_type_and_opt_array_buffer_view(
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
0,
|
||||
WebGl2RenderingContext::R32F as i32,
|
||||
width as i32,
|
||||
height as i32,
|
||||
0,
|
||||
WebGl2RenderingContext::RED,
|
||||
WebGl2RenderingContext::FLOAT,
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(texture)
|
||||
}
|
||||
|
||||
/// Read texture data back to CPU
|
||||
fn read_texture(&self, texture: &WebGlTexture, width: u32, height: u32) -> Result<Vec<f32>, JsValue> {
|
||||
// Bind texture to framebuffer
|
||||
self.gl.bind_framebuffer(WebGl2RenderingContext::FRAMEBUFFER, Some(&self.framebuffer));
|
||||
self.gl.framebuffer_texture_2d(
|
||||
WebGl2RenderingContext::FRAMEBUFFER,
|
||||
WebGl2RenderingContext::COLOR_ATTACHMENT0,
|
||||
WebGl2RenderingContext::TEXTURE_2D,
|
||||
Some(texture),
|
||||
0,
|
||||
);
|
||||
|
||||
// Read pixels as RGBA (WebGL2 limitation for readPixels)
|
||||
let pixel_count = (width * height) as usize;
|
||||
let mut rgba_data = vec![0u8; pixel_count * 4 * 4]; // RGBA * f32
|
||||
|
||||
// Use readPixels with RGBA format
|
||||
let float_array = js_sys::Float32Array::new_with_length(pixel_count as u32 * 4);
|
||||
|
||||
self.gl.read_pixels_with_array_buffer_view(
|
||||
0, 0,
|
||||
width as i32, height as i32,
|
||||
WebGl2RenderingContext::RGBA,
|
||||
WebGl2RenderingContext::FLOAT,
|
||||
&float_array,
|
||||
)?;
|
||||
|
||||
// Extract R channel (our actual data)
|
||||
let mut result = Vec::with_capacity(pixel_count);
|
||||
for i in 0..pixel_count {
|
||||
result.push(float_array.get_index((i * 4) as u32));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// CPU fallback for small matrices
|
||||
fn cpu_matmul(&self, a: &Tensor, b: &Tensor) -> Tensor {
|
||||
let m = a.shape().rows();
|
||||
let k = a.shape().cols();
|
||||
let n = b.shape().cols();
|
||||
|
||||
let a_data = a.data();
|
||||
let b_data = b.data();
|
||||
let mut result = vec![0.0f32; m * n];
|
||||
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0;
|
||||
for kk in 0..k {
|
||||
sum += a_data[i * k + kk] * b_data[kk * n + j];
|
||||
}
|
||||
result[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::from_vec(result, TensorShape::matrix(m, n))
|
||||
}
|
||||
}
|
||||
|
||||
/// Create fullscreen quad for render-to-texture
///
/// Builds a 4-vertex triangle strip covering clip space, with interleaved
/// position (vec2) + texcoord (vec2) attributes at locations 0 and 1.
/// The GL calls below are order-sensitive (VAO must be bound before the
/// buffer/attribute setup is recorded into it).
///
/// Returns the VAO and its backing VBO (caller owns both).
fn create_fullscreen_quad(gl: &WebGl2RenderingContext) -> Result<(WebGlVertexArrayObject, WebGlBuffer), JsValue> {
    let vao = gl.create_vertex_array()
        .ok_or_else(|| JsValue::from_str("Failed to create VAO"))?;
    let vbo = gl.create_buffer()
        .ok_or_else(|| JsValue::from_str("Failed to create VBO"))?;

    gl.bind_vertex_array(Some(&vao));
    gl.bind_buffer(WebGl2RenderingContext::ARRAY_BUFFER, Some(&vbo));

    // Fullscreen quad vertices (x, y, u, v) in triangle-strip order.
    let vertices: [f32; 16] = [
        -1.0, -1.0, 0.0, 0.0,
         1.0, -1.0, 1.0, 0.0,
        -1.0,  1.0, 0.0, 1.0,
         1.0,  1.0, 1.0, 1.0,
    ];

    let array = js_sys::Float32Array::from(vertices.as_slice());
    gl.buffer_data_with_array_buffer_view(
        WebGl2RenderingContext::ARRAY_BUFFER,
        &array,
        WebGl2RenderingContext::STATIC_DRAW,
    );

    // Position attribute: 2 floats, stride 16 bytes, offset 0.
    gl.enable_vertex_attrib_array(0);
    gl.vertex_attrib_pointer_with_i32(0, 2, WebGl2RenderingContext::FLOAT, false, 16, 0);

    // Texcoord attribute: 2 floats, stride 16 bytes, offset 8.
    gl.enable_vertex_attrib_array(1);
    gl.vertex_attrib_pointer_with_i32(1, 2, WebGl2RenderingContext::FLOAT, false, 16, 8);

    Ok((vao, vbo))
}
|
||||
|
||||
/// Compile a shader
|
||||
fn compile_shader(gl: &WebGl2RenderingContext, shader_type: u32, source: &str) -> Result<WebGlShader, JsValue> {
|
||||
let shader = gl.create_shader(shader_type)
|
||||
.ok_or_else(|| JsValue::from_str("Failed to create shader"))?;
|
||||
|
||||
gl.shader_source(&shader, source);
|
||||
gl.compile_shader(&shader);
|
||||
|
||||
if !gl.get_shader_parameter(&shader, WebGl2RenderingContext::COMPILE_STATUS)
|
||||
.as_bool()
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let log = gl.get_shader_info_log(&shader)
|
||||
.unwrap_or_else(|| "Unknown error".to_string());
|
||||
gl.delete_shader(Some(&shader));
|
||||
return Err(JsValue::from_str(&format!("Shader compile error: {}", log)));
|
||||
}
|
||||
|
||||
Ok(shader)
|
||||
}
|
||||
|
||||
/// Link a shader program
|
||||
fn link_program(gl: &WebGl2RenderingContext, vertex: &WebGlShader, fragment: &WebGlShader) -> Result<WebGlProgram, JsValue> {
|
||||
let program = gl.create_program()
|
||||
.ok_or_else(|| JsValue::from_str("Failed to create program"))?;
|
||||
|
||||
gl.attach_shader(&program, vertex);
|
||||
gl.attach_shader(&program, fragment);
|
||||
gl.link_program(&program);
|
||||
|
||||
if !gl.get_program_parameter(&program, WebGl2RenderingContext::LINK_STATUS)
|
||||
.as_bool()
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let log = gl.get_program_info_log(&program)
|
||||
.unwrap_or_else(|| "Unknown error".to_string());
|
||||
gl.delete_program(Some(&program));
|
||||
return Err(JsValue::from_str(&format!("Program link error: {}", log)));
|
||||
}
|
||||
|
||||
Ok(program)
|
||||
}
|
||||
|
||||
/// Vertex shader for all compute operations
///
/// Passes the fullscreen-quad position straight through and forwards the
/// texcoord to the fragment stage; all actual computation happens in the
/// per-operation fragment shaders.
const VERTEX_SHADER: &str = r#"#version 300 es
layout(location = 0) in vec2 a_position;
layout(location = 1) in vec2 a_texcoord;
out vec2 v_texcoord;
void main() {
    gl_Position = vec4(a_position, 0.0, 1.0);
    v_texcoord = a_texcoord;
}
"#;
|
||||
|
||||
/// Create matrix multiplication program
///
/// The fragment shader computes one output element C[i][j] per fragment by
/// looping over K. Matrices are stored row-major in R32F textures
/// (width = columns, height = rows); the `+ 0.5` offsets sample texel
/// centers so NEAREST filtering fetches exact values.
fn create_matmul_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const MATMUL_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform sampler2D u_B;
uniform vec3 u_dims; // M, K, N
in vec2 v_texcoord;
out float fragColor;

void main() {
    float M = u_dims.x;
    float K = u_dims.y;
    float N = u_dims.z;

    // Output position
    float i = floor(v_texcoord.y * M);
    float j = floor(v_texcoord.x * N);

    float sum = 0.0;
    for (float k = 0.0; k < K; k += 1.0) {
        float a = texture(u_A, vec2((k + 0.5) / K, (i + 0.5) / M)).r;
        float b = texture(u_B, vec2((j + 0.5) / N, (k + 0.5) / K)).r;
        sum += a * b;
    }

    fragColor = sum;
}
"#;

    let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, MATMUL_FRAG)?;
    link_program(gl, &vs, &fs)
}
|
||||
|
||||
/// Create vector addition program
///
/// One shader handles both add and sub; the `u_mode` uniform selects the
/// variant (0 = add, 1 = sub, compared against 0.5 to be robust to float
/// uniforms).
fn create_vector_add_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const VECTOR_ADD_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform sampler2D u_B;
uniform float u_mode; // 0 = add, 1 = sub
in vec2 v_texcoord;
out float fragColor;

void main() {
    float a = texture(u_A, v_texcoord).r;
    float b = texture(u_B, v_texcoord).r;
    fragColor = u_mode < 0.5 ? a + b : a - b;
}
"#;

    let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, VECTOR_ADD_FRAG)?;
    link_program(gl, &vs, &fs)
}
|
||||
|
||||
/// Create vector multiplication program
|
||||
fn create_vector_mul_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
|
||||
const VECTOR_MUL_FRAG: &str = r#"#version 300 es
|
||||
precision highp float;
|
||||
uniform sampler2D u_A;
|
||||
uniform sampler2D u_B;
|
||||
uniform float u_mode; // 0 = mul, 1 = div
|
||||
in vec2 v_texcoord;
|
||||
out float fragColor;
|
||||
|
||||
void main() {
|
||||
float a = texture(u_A, v_texcoord).r;
|
||||
float b = texture(u_B, v_texcoord).r;
|
||||
fragColor = u_mode < 0.5 ? a * b : a / max(b, 1e-7);
|
||||
}
|
||||
"#;
|
||||
|
||||
let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
|
||||
let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, VECTOR_MUL_FRAG)?;
|
||||
link_program(gl, &vs, &fs)
|
||||
}
|
||||
|
||||
/// Create softmax program
///
/// NOTE: despite the name, this shader only computes the element-wise
/// numerator `exp(x)` (its own comment says the max-subtraction and sum
/// passes are omitted). The caller must normalize — and without the max
/// subtraction, large inputs can overflow to inf. The `u_size` uniform is
/// currently unused by the shader body.
fn create_softmax_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const SOFTMAX_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
uniform vec2 u_size;
in vec2 v_texcoord;
out float fragColor;

void main() {
    // First pass would compute max, second pass computes exp/sum
    // This is a simplified single-pass version for small vectors
    float x = texture(u_A, v_texcoord).r;
    fragColor = exp(x);
}
"#;

    let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, SOFTMAX_FRAG)?;
    link_program(gl, &vs, &fs)
}
|
||||
|
||||
/// Create ReLU program
///
/// Element-wise `max(x, 0)` over the input texture.
fn create_relu_program(gl: &WebGl2RenderingContext) -> Result<WebGlProgram, JsValue> {
    const RELU_FRAG: &str = r#"#version 300 es
precision highp float;
uniform sampler2D u_A;
in vec2 v_texcoord;
out float fragColor;

void main() {
    float x = texture(u_A, v_texcoord).r;
    fragColor = max(x, 0.0);
}
"#;

    let vs = compile_shader(gl, WebGl2RenderingContext::VERTEX_SHADER, VERTEX_SHADER)?;
    let fs = compile_shader(gl, WebGl2RenderingContext::FRAGMENT_SHADER, RELU_FRAG)?;
    link_program(gl, &vs, &fs)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // WebGL tests require browser environment
    // (run via wasm-bindgen-test in a headless browser, not `cargo test`).
}
|
||||
909
vendor/ruvector/examples/edge-net/src/compute/webgpu.rs
vendored
Normal file
909
vendor/ruvector/examples/edge-net/src/compute/webgpu.rs
vendored
Normal file
@@ -0,0 +1,909 @@
|
||||
//! WebGPU Compute Backend Implementation
|
||||
//!
|
||||
//! This module provides GPU-accelerated compute operations using wgpu.
|
||||
//! It includes optimized pipelines for matrix multiplication, attention,
|
||||
//! and LoRA adapter inference.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{
|
||||
ComputeConfig, ComputeError, ComputeMetrics,
|
||||
TensorDescriptor, DataType, LoraConfig, AttentionConfig,
|
||||
BufferUsage, MATMUL_SHADER, ATTENTION_SHADER, LORA_SHADER,
|
||||
};
|
||||
|
||||
/// Buffer handle for GPU memory
///
/// Cheap to clone: the underlying wgpu buffer is reference-counted via `Arc`.
#[derive(Clone)]
pub struct GpuBuffer {
    /// Underlying wgpu buffer
    buffer: Arc<wgpu::Buffer>,
    /// Size in bytes
    size: usize,
    /// Tensor descriptor (shape/dtype metadata for the buffer contents)
    desc: TensorDescriptor,
}
|
||||
|
||||
impl GpuBuffer {
    /// Get buffer size in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the tensor descriptor describing the buffer's contents.
    pub fn descriptor(&self) -> &TensorDescriptor {
        &self.desc
    }

    /// Get the underlying wgpu buffer (escape hatch for custom passes).
    pub fn raw(&self) -> &wgpu::Buffer {
        &self.buffer
    }
}
|
||||
|
||||
/// Compute pipeline for a specific operation
///
/// Pairs the compiled pipeline with the bind-group layout needed to bind
/// its buffers at dispatch time.
struct ComputePipeline {
    /// Compiled compute pipeline
    pipeline: wgpu::ComputePipeline,
    /// Layout used to create bind groups for this pipeline
    bind_group_layout: wgpu::BindGroupLayout,
}
|
||||
|
||||
/// WebGPU compute backend for GPU-accelerated inference
pub struct WebGpuCompute {
    /// GPU device handle
    device: Arc<wgpu::Device>,
    /// Command queue
    queue: Arc<wgpu::Queue>,
    /// Backend configuration
    config: ComputeConfig,
    /// Matrix multiplication pipeline
    matmul_pipeline: ComputePipeline,
    /// Attention pipeline
    attention_pipeline: ComputePipeline,
    /// LoRA forward pipeline
    lora_pipeline: ComputePipeline,
    /// Staging buffer pool for CPU<->GPU transfers
    staging_pool: StagingBufferPool,
    /// Performance metrics from last operation
    last_metrics: ComputeMetrics,
    /// Device limits (as reported by the adapter at creation time)
    limits: wgpu::Limits,
}
|
||||
|
||||
impl WebGpuCompute {
|
||||
/// Create a new WebGPU compute backend
|
||||
pub async fn new() -> Result<Self, ComputeError> {
|
||||
Self::with_config(ComputeConfig::default()).await
|
||||
}
|
||||
|
||||
/// Create with custom configuration
|
||||
pub async fn with_config(config: ComputeConfig) -> Result<Self, ComputeError> {
|
||||
// Request adapter
|
||||
let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
|
||||
backends: wgpu::Backends::all(),
|
||||
dx12_shader_compiler: wgpu::Dx12Compiler::Fxc,
|
||||
flags: wgpu::InstanceFlags::empty(),
|
||||
gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
|
||||
});
|
||||
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
compatible_surface: None,
|
||||
force_fallback_adapter: false,
|
||||
})
|
||||
.await
|
||||
.ok_or_else(|| ComputeError::DeviceNotAvailable(
|
||||
"No suitable GPU adapter found".to_string()
|
||||
))?;
|
||||
|
||||
let limits = adapter.limits();
|
||||
|
||||
// Request device with compute capabilities
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&wgpu::DeviceDescriptor {
|
||||
label: Some("edge-net-compute"),
|
||||
required_features: wgpu::Features::empty(),
|
||||
required_limits: wgpu::Limits::default(),
|
||||
memory_hints: wgpu::MemoryHints::Performance,
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| ComputeError::DeviceNotAvailable(e.to_string()))?;
|
||||
|
||||
let device = Arc::new(device);
|
||||
let queue = Arc::new(queue);
|
||||
|
||||
// Create compute pipelines
|
||||
let matmul_pipeline = Self::create_matmul_pipeline(&device, &config)?;
|
||||
let attention_pipeline = Self::create_attention_pipeline(&device, &config)?;
|
||||
let lora_pipeline = Self::create_lora_pipeline(&device, &config)?;
|
||||
|
||||
// Create staging buffer pool
|
||||
let staging_pool = StagingBufferPool::new(device.clone(), 16 * 1024 * 1024); // 16MB pool
|
||||
|
||||
Ok(Self {
|
||||
device,
|
||||
queue,
|
||||
config,
|
||||
matmul_pipeline,
|
||||
attention_pipeline,
|
||||
lora_pipeline,
|
||||
staging_pool,
|
||||
last_metrics: ComputeMetrics::default(),
|
||||
limits,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create matrix multiplication pipeline
|
||||
fn create_matmul_pipeline(
|
||||
device: &wgpu::Device,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<ComputePipeline, ComputeError> {
|
||||
// Create shader module
|
||||
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some("matmul_shader"),
|
||||
source: wgpu::ShaderSource::Wgsl(MATMUL_SHADER.into()),
|
||||
});
|
||||
|
||||
// Create bind group layout
|
||||
// Bindings: 0=A matrix, 1=B matrix, 2=C matrix (output), 3=uniforms
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some("matmul_bind_group_layout"),
|
||||
entries: &[
|
||||
// Matrix A (read-only storage)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Matrix B (read-only storage)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Matrix C (read-write storage)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Uniforms (dimensions)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 3,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Create pipeline layout
|
||||
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some("matmul_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
// Create compute pipeline
|
||||
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: Some("matmul_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("main"),
|
||||
compilation_options: wgpu::PipelineCompilationOptions::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(ComputePipeline {
|
||||
pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create attention pipeline
|
||||
fn create_attention_pipeline(
|
||||
device: &wgpu::Device,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<ComputePipeline, ComputeError> {
|
||||
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some("attention_shader"),
|
||||
source: wgpu::ShaderSource::Wgsl(ATTENTION_SHADER.into()),
|
||||
});
|
||||
|
||||
// Bindings: 0=Q, 1=K, 2=V, 3=Output, 4=Uniforms
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some("attention_bind_group_layout"),
|
||||
entries: &[
|
||||
// Q (query)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// K (key)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// V (value)
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Output
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 3,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Uniforms
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 4,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some("attention_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: Some("attention_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("main"),
|
||||
compilation_options: wgpu::PipelineCompilationOptions::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(ComputePipeline {
|
||||
pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create LoRA forward pipeline
|
||||
fn create_lora_pipeline(
|
||||
device: &wgpu::Device,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<ComputePipeline, ComputeError> {
|
||||
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some("lora_shader"),
|
||||
source: wgpu::ShaderSource::Wgsl(LORA_SHADER.into()),
|
||||
});
|
||||
|
||||
// Bindings: 0=Input, 1=LoRA_A, 2=LoRA_B, 3=Output, 4=Uniforms
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some("lora_bind_group_layout"),
|
||||
entries: &[
|
||||
// Input
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// LoRA A matrix
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// LoRA B matrix
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Output
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 3,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Uniforms
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 4,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some("lora_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: Some("lora_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("main"),
|
||||
compilation_options: wgpu::PipelineCompilationOptions::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(ComputePipeline {
|
||||
pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Buffer Management
|
||||
// ========================================================================
|
||||
|
||||
/// Allocate a GPU buffer
|
||||
pub fn allocate_buffer(&self, desc: TensorDescriptor, usage: BufferUsage) -> Result<GpuBuffer, ComputeError> {
|
||||
let size = desc.size_bytes();
|
||||
|
||||
// Check against device limits
|
||||
if size > self.limits.max_buffer_size as usize {
|
||||
return Err(ComputeError::BufferAllocationFailed {
|
||||
requested: size,
|
||||
available: self.limits.max_buffer_size as usize,
|
||||
});
|
||||
}
|
||||
|
||||
let mut wgpu_usage = wgpu::BufferUsages::empty();
|
||||
if usage.map_read { wgpu_usage |= wgpu::BufferUsages::MAP_READ; }
|
||||
if usage.map_write { wgpu_usage |= wgpu::BufferUsages::MAP_WRITE; }
|
||||
if usage.copy_src { wgpu_usage |= wgpu::BufferUsages::COPY_SRC; }
|
||||
if usage.copy_dst { wgpu_usage |= wgpu::BufferUsages::COPY_DST; }
|
||||
if usage.storage { wgpu_usage |= wgpu::BufferUsages::STORAGE; }
|
||||
if usage.uniform { wgpu_usage |= wgpu::BufferUsages::UNIFORM; }
|
||||
|
||||
let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("compute_buffer"),
|
||||
size: size as u64,
|
||||
usage: wgpu_usage,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
Ok(GpuBuffer {
|
||||
buffer: Arc::new(buffer),
|
||||
size,
|
||||
desc,
|
||||
})
|
||||
}
|
||||
|
||||
/// Upload data to GPU buffer
|
||||
pub async fn upload_buffer(&self, buffer: &GpuBuffer, data: &[u8]) -> Result<(), ComputeError> {
|
||||
if data.len() != buffer.size {
|
||||
return Err(ComputeError::DimensionMismatch {
|
||||
expected: format!("{} bytes", buffer.size),
|
||||
actual: format!("{} bytes", data.len()),
|
||||
});
|
||||
}
|
||||
|
||||
// Use staging buffer for upload
|
||||
let staging = self.staging_pool.get_upload_buffer(data.len())?;
|
||||
|
||||
// Write to staging buffer
|
||||
self.queue.write_buffer(&staging, 0, data);
|
||||
|
||||
// Copy from staging to destination
|
||||
let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("upload_encoder"),
|
||||
});
|
||||
encoder.copy_buffer_to_buffer(&staging, 0, buffer.raw(), 0, data.len() as u64);
|
||||
self.queue.submit(std::iter::once(encoder.finish()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Download data from GPU buffer
    ///
    /// Copies `buffer` into a MAP_READ staging buffer, blocks until the map
    /// completes, and returns the bytes.
    ///
    /// NOTE(review): the `map_async` + `device.poll(Maintain::Wait)` +
    /// blocking `mpsc::recv` pattern is valid on native targets; on
    /// wasm32-unknown-unknown the main thread cannot block — confirm how
    /// this is driven in the browser build.
    pub async fn download_buffer(&self, buffer: &GpuBuffer) -> Result<Vec<u8>, ComputeError> {
        let staging = self.staging_pool.get_download_buffer(buffer.size)?;

        // Copy from source to staging
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("download_encoder"),
        });
        encoder.copy_buffer_to_buffer(buffer.raw(), 0, &staging, 0, buffer.size as u64);
        self.queue.submit(std::iter::once(encoder.finish()));

        // Map staging buffer and read. `poll(Wait)` drives the callback, so
        // the channel send happens before `recv` — the unwraps can only fail
        // if wgpu never invokes the callback.
        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });
        self.device.poll(wgpu::Maintain::Wait);
        rx.recv().unwrap().map_err(|e| ComputeError::DeviceNotAvailable(e.to_string()))?;

        // Copy the mapped bytes out, then release the mapping.
        let data = slice.get_mapped_range().to_vec();
        staging.unmap();

        Ok(data)
    }
|
||||
|
||||
// ========================================================================
|
||||
// Matrix Multiplication
|
||||
// ========================================================================
|
||||
|
||||
    /// Perform matrix multiplication: C = A * B
    ///
    /// Dimensions: A (M x K), B (K x N), C (M x N)
    ///
    /// All three buffers must hold f32 data already resident on the GPU.
    /// Dispatch is tiled by `config.tile_size`; this assumes the WGSL
    /// workgroup size in `MATMUL_SHADER` matches the tile size and that the
    /// x/y axis mapping (x = rows, y = columns) matches the shader — TODO
    /// confirm against MATMUL_SHADER. `tile_size` must be non-zero (used as
    /// a divisor below).
    ///
    /// NOTE(review): `std::time::Instant` panics on wasm32-unknown-unknown
    /// without a shim — confirm the browser build provides one.
    ///
    /// Performance target: 10+ TFLOPS on discrete GPU
    ///
    /// # Errors
    /// Returns `ComputeError::DimensionMismatch` when any buffer size does
    /// not match the given M/N/K (assuming 4-byte f32 elements).
    pub async fn matmul(
        &mut self,
        a: &GpuBuffer,
        b: &GpuBuffer,
        c: &GpuBuffer,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<ComputeMetrics, ComputeError> {
        let start = std::time::Instant::now();

        // Validate dimensions
        let expected_a = (m as usize) * (k as usize) * 4; // f32
        let expected_b = (k as usize) * (n as usize) * 4;
        let expected_c = (m as usize) * (n as usize) * 4;

        if a.size != expected_a || b.size != expected_b || c.size != expected_c {
            return Err(ComputeError::DimensionMismatch {
                expected: format!("A:{}x{}, B:{}x{}, C:{}x{}", m, k, k, n, m, n),
                actual: format!("A:{}, B:{}, C:{} bytes", a.size, b.size, c.size),
            });
        }

        // Create uniforms buffer holding the problem dimensions
        let uniforms = [m, n, k, self.config.tile_size];
        let uniform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("matmul_uniforms"),
            contents: bytemuck::cast_slice(&uniforms),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        // Create bind group matching the layout built in create_matmul_pipeline
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("matmul_bind_group"),
            layout: &self.matmul_pipeline.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: a.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: b.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: c.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buffer.as_entire_binding() },
            ],
        });

        // Dispatch compute
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("matmul_encoder"),
        });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("matmul_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.matmul_pipeline.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);

            // Dispatch workgroups (tile-based; ceiling division so partial
            // tiles at the edges are still covered)
            let tile_size = self.config.tile_size;
            let workgroups_x = (m + tile_size - 1) / tile_size;
            let workgroups_y = (n + tile_size - 1) / tile_size;
            pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
        }

        // Time only the GPU submission + completion wait
        let kernel_start = std::time::Instant::now();
        self.queue.submit(std::iter::once(encoder.finish()));
        self.device.poll(wgpu::Maintain::Wait);
        let kernel_time = kernel_start.elapsed();

        let total_time = start.elapsed();

        // Calculate metrics. Note: `flops` holds the total operation count
        // (2*M*N*K multiply-adds), not a rate.
        let flops = 2.0 * (m as f64) * (n as f64) * (k as f64); // 2*M*N*K for matmul
        let metrics = ComputeMetrics {
            flops,
            bandwidth_gbps: ((a.size + b.size + c.size) as f64) / kernel_time.as_secs_f64() / 1e9,
            kernel_time_ms: kernel_time.as_secs_f64() * 1000.0,
            transfer_time_ms: 0.0, // Data already on GPU
            total_time_ms: total_time.as_secs_f64() * 1000.0,
        };

        self.last_metrics = metrics.clone();
        Ok(metrics)
    }
|
||||
|
||||
// ========================================================================
|
||||
// Attention
|
||||
// ========================================================================
|
||||
|
||||
    /// Compute attention: Output = softmax(Q * K^T / sqrt(d_k)) * V
    ///
    /// Uses flash attention algorithm for memory efficiency.
    ///
    /// Q, K, V and the output must each be `seq_len * hidden_dim` f32
    /// elements already resident on the GPU.
    ///
    /// NOTE(review): all uniforms — including integer quantities like
    /// seq_len and num_heads — are packed as f32; confirm ATTENTION_SHADER
    /// reads them that way. The hard-coded 64-row block size below must
    /// also match the shader's workgroup layout — TODO confirm.
    ///
    /// Performance target: 2ms for 4K context
    ///
    /// # Errors
    /// Returns `ComputeError::DimensionMismatch` when Q/K/V sizes do not
    /// match `seq_len * hidden_dim` f32 elements.
    pub async fn attention(
        &mut self,
        q: &GpuBuffer,
        k: &GpuBuffer,
        v: &GpuBuffer,
        output: &GpuBuffer,
        config: &AttentionConfig,
        seq_len: u32,
    ) -> Result<ComputeMetrics, ComputeError> {
        let start = std::time::Instant::now();

        // Validate dimensions (output size is not checked here)
        let hidden_dim = config.hidden_dim();
        let expected_size = (seq_len as usize) * hidden_dim * 4; // f32

        if q.size != expected_size || k.size != expected_size || v.size != expected_size {
            return Err(ComputeError::DimensionMismatch {
                expected: format!("{}x{} = {} bytes", seq_len, hidden_dim, expected_size),
                actual: format!("Q:{}, K:{}, V:{} bytes", q.size, k.size, v.size),
            });
        }

        // Create uniforms buffer (8 f32 slots; last three are padding)
        let scale = config.get_scale();
        let causal_mask = if config.causal { 1u32 } else { 0u32 };
        let uniforms: [f32; 8] = [
            seq_len as f32,
            config.head_dim as f32,
            config.num_heads as f32,
            scale,
            causal_mask as f32,
            0.0, 0.0, 0.0, // padding
        ];
        let uniform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("attention_uniforms"),
            contents: bytemuck::cast_slice(&uniforms),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        // Create bind group matching create_attention_pipeline's layout
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("attention_bind_group"),
            layout: &self.attention_pipeline.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: q.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: k.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: v.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: output.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: uniform_buffer.as_entire_binding() },
            ],
        });

        // Dispatch compute
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("attention_encoder"),
        });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("attention_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.attention_pipeline.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);

            // Dispatch: one workgroup per head per batch of sequence positions
            let block_size = 64u32; // Flash attention block size
            let num_blocks = (seq_len + block_size - 1) / block_size;
            pass.dispatch_workgroups(num_blocks, config.num_heads as u32, 1);
        }

        // Time only the GPU submission + completion wait
        let kernel_start = std::time::Instant::now();
        self.queue.submit(std::iter::once(encoder.finish()));
        self.device.poll(wgpu::Maintain::Wait);
        let kernel_time = kernel_start.elapsed();

        let total_time = start.elapsed();

        // Calculate metrics (attention has O(n^2*d) complexity); `flops` is
        // the total operation count, not a rate.
        let flops = 4.0 * (seq_len as f64).powi(2) * (hidden_dim as f64);
        let metrics = ComputeMetrics {
            flops,
            bandwidth_gbps: ((q.size + k.size + v.size + output.size) as f64) / kernel_time.as_secs_f64() / 1e9,
            kernel_time_ms: kernel_time.as_secs_f64() * 1000.0,
            transfer_time_ms: 0.0,
            total_time_ms: total_time.as_secs_f64() * 1000.0,
        };

        self.last_metrics = metrics.clone();
        Ok(metrics)
    }
|
||||
|
||||
// ========================================================================
|
||||
// LoRA Forward
|
||||
// ========================================================================
|
||||
|
||||
    /// Apply LoRA adapter: output = input + scaling * (input @ A @ B)
    ///
    /// Where A is (in_dim x rank) and B is (rank x out_dim).
    ///
    /// NOTE(review): uniforms (including integer dims) are packed as f32
    /// — confirm LORA_SHADER reads them that way. The 256-thread workgroup
    /// size below must match the shader's `@workgroup_size` — TODO confirm.
    ///
    /// Performance target: <1ms
    ///
    /// # Errors
    /// Returns `ComputeError::DimensionMismatch` when any buffer size does
    /// not match the shapes implied by `config` and `batch_size`.
    pub async fn lora_forward(
        &mut self,
        input: &GpuBuffer,
        lora_a: &GpuBuffer,
        lora_b: &GpuBuffer,
        output: &GpuBuffer,
        config: &LoraConfig,
        batch_size: u32,
    ) -> Result<ComputeMetrics, ComputeError> {
        let start = std::time::Instant::now();

        // Validate dimensions (all buffers hold 4-byte f32 elements)
        let expected_input = (batch_size as usize) * config.in_dim * 4;
        let expected_a = config.a_size() * 4;
        let expected_b = config.b_size() * 4;
        let expected_output = (batch_size as usize) * config.out_dim * 4;

        if input.size != expected_input || lora_a.size != expected_a ||
            lora_b.size != expected_b || output.size != expected_output {
            return Err(ComputeError::DimensionMismatch {
                expected: format!("input:{}x{}, A:{}x{}, B:{}x{}, output:{}x{}",
                    batch_size, config.in_dim, config.in_dim, config.rank,
                    config.rank, config.out_dim, batch_size, config.out_dim),
                actual: format!("input:{}, A:{}, B:{}, output:{} bytes",
                    input.size, lora_a.size, lora_b.size, output.size),
            });
        }

        // Create uniforms buffer (8 f32 slots; last three are padding)
        let scaling = config.scaling();
        let uniforms: [f32; 8] = [
            batch_size as f32,
            config.in_dim as f32,
            config.rank as f32,
            config.out_dim as f32,
            scaling,
            0.0, 0.0, 0.0, // padding
        ];
        let uniform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("lora_uniforms"),
            contents: bytemuck::cast_slice(&uniforms),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        // Create bind group matching create_lora_pipeline's layout
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_bind_group"),
            layout: &self.lora_pipeline.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: lora_a.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: lora_b.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: output.raw().as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: uniform_buffer.as_entire_binding() },
            ],
        });

        // Dispatch compute
        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("lora_encoder"),
        });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("lora_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.lora_pipeline.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);

            // Dispatch: one thread per output element, 256 threads per
            // workgroup (ceiling division covers the remainder)
            let workgroup_size = 256u32;
            let workgroups = (batch_size * config.out_dim as u32 + workgroup_size - 1) / workgroup_size;
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        // Time only the GPU submission + completion wait
        let kernel_start = std::time::Instant::now();
        self.queue.submit(std::iter::once(encoder.finish()));
        self.device.poll(wgpu::Maintain::Wait);
        let kernel_time = kernel_start.elapsed();

        let total_time = start.elapsed();

        // Calculate metrics
        // LoRA: input @ A @ B = 2 matmuls; `flops` is a total op count.
        let flops = 2.0 * (batch_size as f64) * (config.in_dim as f64) * (config.rank as f64)
            + 2.0 * (batch_size as f64) * (config.rank as f64) * (config.out_dim as f64);
        let metrics = ComputeMetrics {
            flops,
            bandwidth_gbps: ((input.size + lora_a.size + lora_b.size + output.size) as f64)
                / kernel_time.as_secs_f64() / 1e9,
            kernel_time_ms: kernel_time.as_secs_f64() * 1000.0,
            transfer_time_ms: 0.0,
            total_time_ms: total_time.as_secs_f64() * 1000.0,
        };

        self.last_metrics = metrics.clone();
        Ok(metrics)
    }
|
||||
|
||||
// ========================================================================
|
||||
// Utilities
|
||||
// ========================================================================
|
||||
|
||||
/// Get last operation metrics
|
||||
pub fn last_metrics(&self) -> &ComputeMetrics {
|
||||
&self.last_metrics
|
||||
}
|
||||
|
||||
/// Get device limits
|
||||
pub fn limits(&self) -> &wgpu::Limits {
|
||||
&self.limits
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &ComputeConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Synchronize all pending GPU operations
|
||||
pub fn sync(&self) {
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Staging Buffer Pool
|
||||
// ============================================================================
|
||||
|
||||
/// Pool of reusable staging buffers for CPU<->GPU transfers
///
/// NOTE(review): the pooling fields (`upload_buffers`, `download_buffers`,
/// `max_pool_size`) are currently unused — the impl always creates fresh
/// buffers. They are placeholders for a future recycling implementation.
struct StagingBufferPool {
    device: Arc<wgpu::Device>,
    upload_buffers: Vec<wgpu::Buffer>,
    download_buffers: Vec<wgpu::Buffer>,
    max_pool_size: usize,
}
|
||||
|
||||
impl StagingBufferPool {
|
||||
fn new(device: Arc<wgpu::Device>, max_pool_size: usize) -> Self {
|
||||
Self {
|
||||
device,
|
||||
upload_buffers: Vec::new(),
|
||||
download_buffers: Vec::new(),
|
||||
max_pool_size,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_upload_buffer(&self, size: usize) -> Result<wgpu::Buffer, ComputeError> {
|
||||
// For simplicity, always create new buffer (production would pool)
|
||||
let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("staging_upload"),
|
||||
size: size as u64,
|
||||
usage: wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
fn get_download_buffer(&self, size: usize) -> Result<wgpu::Buffer, ComputeError> {
|
||||
let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("staging_download"),
|
||||
size: size as u64,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
Ok(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// wgpu::util helpers
|
||||
// ============================================================================
|
||||
|
||||
// NOTE(review): the previous version of this module declared
// `impl wgpu::Device { pub fn create_buffer_init(...) }`. That is an
// inherent impl on a type defined in another crate, which Rust rejects
// (E0116: cannot define inherent `impl` for a type outside of the crate
// where the type is defined) — it could never have compiled.
// `create_buffer_init` is supplied by the `wgpu::util::DeviceExt`
// extension trait; importing that trait at the top of this file makes
// the `device.create_buffer_init(...)` calls resolve, so no local
// helper module is needed.
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Note: These tests require a GPU and are marked as ignored by default
    // Run with: cargo test --features webgpu -- --ignored

    /// Smoke test: backend construction succeeds when a GPU is present.
    #[tokio::test]
    #[ignore]
    async fn test_webgpu_init() {
        let compute = WebGpuCompute::new().await;
        assert!(compute.is_ok());
    }

    /// A 1024x1024 f32 matrix should allocate exactly 4 MiB.
    #[tokio::test]
    #[ignore]
    async fn test_buffer_allocation() {
        let compute = WebGpuCompute::new().await.unwrap();
        let desc = TensorDescriptor::matrix(1024, 1024, DataType::F32);
        let buffer = compute.allocate_buffer(desc, BufferUsage::storage());
        assert!(buffer.is_ok());
        assert_eq!(buffer.unwrap().size(), 1024 * 1024 * 4);
    }
}
|
||||
566
vendor/ruvector/examples/edge-net/src/compute/workers.rs
vendored
Normal file
566
vendor/ruvector/examples/edge-net/src/compute/workers.rs
vendored
Normal file
@@ -0,0 +1,566 @@
|
||||
//! WebWorker pool for CPU parallelism in browsers
|
||||
//!
|
||||
//! Provides multi-threaded compute using WebWorkers with work stealing
|
||||
//! for load balancing. Uses SharedArrayBuffer when available for
|
||||
//! zero-copy data sharing.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! +------------------+
|
||||
//! | Main Thread |
|
||||
//! | (Coordinator) |
|
||||
//! +--------+---------+
|
||||
//! |
|
||||
//! +-----+-----+-----+-----+
|
||||
//! | | | | |
|
||||
//! +--v-+ +-v--+ +--v-+ +--v-+ +--v-+
|
||||
//! | W1 | | W2 | | W3 | | W4 | | Wn |
|
||||
//! +----+ +----+ +----+ +----+ +----+
|
||||
//! | | | | |
|
||||
//! +-----+-----+-----+-----+
|
||||
//! |
|
||||
//! SharedArrayBuffer (when available)
|
||||
//! ```
|
||||
//!
|
||||
//! ## Work Stealing
|
||||
//!
|
||||
//! Workers that finish early can steal work from busy workers' queues.
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen::JsCast;
|
||||
use web_sys::{Worker, MessageEvent};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
/// Task for worker execution
///
/// Tasks carry offsets into a shared buffer rather than owned data, so
/// they stay cheap to clone and post to workers.
#[derive(Clone)]
pub struct WorkerTask {
    /// Task identifier
    pub id: u32,
    /// Operation type
    pub op: WorkerOp,
    /// Input data offset in shared buffer
    pub input_offset: usize,
    /// Input data length
    pub input_len: usize,
    /// Output data offset in shared buffer
    pub output_offset: usize,
}
|
||||
|
||||
/// Operations that workers can perform
///
/// Each variant carries index ranges describing the slice of work assigned
/// to one worker, so a large job can be split across the pool.
#[derive(Clone, Copy)]
pub enum WorkerOp {
    /// Matrix multiplication (partial): compute output rows `m_start..m_end`
    MatmulPartial { m_start: usize, m_end: usize, k: usize, n: usize },
    /// Dot product (partial) over elements `start..end`
    DotProductPartial { start: usize, end: usize },
    /// Vector element-wise operation over elements `start..end`
    VectorOp { start: usize, end: usize, op: VectorOpType },
    /// Reduction (sum, max, etc.) over elements `start..end`
    Reduce { start: usize, end: usize, op: ReduceOp },
}
|
||||
|
||||
/// Element-wise vector operations.
///
/// Selector for `WorkerOp::VectorOp`; mirrors the string op codes
/// understood by the worker script (`'add'`, `'sub'`, ...).
// Debug/PartialEq/Eq added: public selector enums should be printable
// and comparable (e.g. in logs, tests, and dispatch matching).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VectorOpType {
    Add,
    Sub,
    Mul,
    Div,
    Relu,
    Sigmoid,
}
|
||||
|
||||
/// Reduction operations.
///
/// Selector for `WorkerOp::Reduce`: how a range of elements is folded
/// into a single value.
// Debug/PartialEq/Eq added: public selector enums should be printable
// and comparable (e.g. in logs, tests, and dispatch matching).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Max,
    Min,
    Mean,
}
|
||||
|
||||
/// Worker pool status snapshot.
///
/// A plain-data view of the pool's counters; all fields are primitives,
/// so the snapshot is decoupled from the live pool state.
// Debug/PartialEq/Eq added: a plain-data status struct should be
// printable and comparable for diagnostics and tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PoolStatus {
    /// Number of workers
    pub worker_count: usize,
    /// Number of active (dispatched, not yet completed) tasks
    pub active_tasks: usize,
    /// Total tasks completed over the pool's lifetime
    pub completed_tasks: u64,
    /// Whether shared memory (SharedArrayBuffer) is available
    pub has_shared_memory: bool,
}
|
||||
|
||||
/// WebWorker pool for parallel compute.
///
/// Counters and result buffers are `Rc<RefCell<_>>` because they are
/// shared with the per-worker `onmessage` closures; the JS event loop is
/// single-threaded, so no `Send`/`Sync` is required.
#[wasm_bindgen]
pub struct WorkerPool {
    /// Active workers (spawned in `initialize`, cleared in `shutdown`)
    workers: Vec<Worker>,
    /// Number of workers (clamped to 1..=16 at construction)
    worker_count: usize,
    /// Shared memory buffer (if available; 16 MB, allocated in `new`)
    shared_buffer: Option<js_sys::SharedArrayBuffer>,
    /// Float32 view into shared buffer
    shared_view: Option<js_sys::Float32Array>,
    /// Active task count (incremented on dispatch, decremented by onmessage)
    active_tasks: Rc<RefCell<usize>>,
    /// Completed task count (incremented by onmessage handlers)
    completed_tasks: Rc<RefCell<u64>>,
    /// Whether pool is initialized (worker scripts spawned and wired)
    initialized: bool,
    /// Has SharedArrayBuffer support (probed on the global object)
    has_shared_memory: bool,
    /// Pending results collector (filled by onmessage handlers)
    pending_results: Rc<RefCell<Vec<Vec<f32>>>>,
    /// Next task ID
    // NOTE(review): never read or incremented in this module's visible
    // code — presumably reserved for future task tracking; confirm.
    next_task_id: Rc<RefCell<u32>>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WorkerPool {
|
||||
/// Create a new worker pool
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(worker_count: usize) -> Result<WorkerPool, JsValue> {
|
||||
let count = worker_count.max(1).min(16); // Limit to reasonable range
|
||||
|
||||
// Check for SharedArrayBuffer support
|
||||
let window = web_sys::window()
|
||||
.ok_or_else(|| JsValue::from_str("No window"))?;
|
||||
let has_shared_memory = js_sys::Reflect::has(&window, &"SharedArrayBuffer".into())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Create shared buffer if available (16MB default)
|
||||
let (shared_buffer, shared_view) = if has_shared_memory {
|
||||
let buffer = js_sys::SharedArrayBuffer::new(16 * 1024 * 1024);
|
||||
let view = js_sys::Float32Array::new(&buffer);
|
||||
(Some(buffer), Some(view))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
Ok(WorkerPool {
|
||||
workers: Vec::with_capacity(count),
|
||||
worker_count: count,
|
||||
shared_buffer,
|
||||
shared_view,
|
||||
active_tasks: Rc::new(RefCell::new(0)),
|
||||
completed_tasks: Rc::new(RefCell::new(0)),
|
||||
initialized: false,
|
||||
has_shared_memory,
|
||||
pending_results: Rc::new(RefCell::new(Vec::new())),
|
||||
next_task_id: Rc::new(RefCell::new(0)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize workers
|
||||
#[wasm_bindgen(js_name = initialize)]
|
||||
pub fn initialize(&mut self) -> Result<(), JsValue> {
|
||||
if self.initialized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create worker script as a blob
|
||||
let worker_script = create_worker_script();
|
||||
let blob_parts = js_sys::Array::new();
|
||||
blob_parts.push(&worker_script.into());
|
||||
|
||||
let blob_options = web_sys::BlobPropertyBag::new();
|
||||
blob_options.set_type("application/javascript");
|
||||
|
||||
let blob = web_sys::Blob::new_with_str_sequence_and_options(&blob_parts, &blob_options)?;
|
||||
let url = web_sys::Url::create_object_url_with_blob(&blob)?;
|
||||
|
||||
// Spawn workers
|
||||
for i in 0..self.worker_count {
|
||||
let worker = Worker::new(&url)?;
|
||||
|
||||
// Set up message handler
|
||||
let completed = self.completed_tasks.clone();
|
||||
let active = self.active_tasks.clone();
|
||||
let results = self.pending_results.clone();
|
||||
|
||||
let onmessage = Closure::wrap(Box::new(move |event: MessageEvent| {
|
||||
let data = event.data();
|
||||
|
||||
// Parse result
|
||||
if let Ok(result_array) = data.dyn_into::<js_sys::Float32Array>() {
|
||||
let mut result_vec = vec![0.0f32; result_array.length() as usize];
|
||||
result_array.copy_to(&mut result_vec);
|
||||
results.borrow_mut().push(result_vec);
|
||||
}
|
||||
|
||||
*completed.borrow_mut() += 1;
|
||||
*active.borrow_mut() = active.borrow().saturating_sub(1);
|
||||
}) as Box<dyn FnMut(MessageEvent)>);
|
||||
|
||||
worker.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
||||
onmessage.forget();
|
||||
|
||||
// Send initialization message
|
||||
let init_msg = js_sys::Object::new();
|
||||
js_sys::Reflect::set(&init_msg, &"type".into(), &"init".into())?;
|
||||
js_sys::Reflect::set(&init_msg, &"workerId".into(), &(i as u32).into())?;
|
||||
|
||||
if let Some(ref buffer) = self.shared_buffer {
|
||||
js_sys::Reflect::set(&init_msg, &"sharedBuffer".into(), buffer)?;
|
||||
}
|
||||
|
||||
worker.post_message(&init_msg)?;
|
||||
|
||||
self.workers.push(worker);
|
||||
}
|
||||
|
||||
self.initialized = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get worker count
|
||||
#[wasm_bindgen(js_name = workerCount)]
|
||||
pub fn worker_count(&self) -> usize {
|
||||
self.worker_count
|
||||
}
|
||||
|
||||
/// Get pool status
|
||||
#[wasm_bindgen(js_name = getStatus)]
|
||||
pub fn get_status(&self) -> JsValue {
|
||||
let obj = js_sys::Object::new();
|
||||
js_sys::Reflect::set(&obj, &"workerCount".into(), &(self.worker_count as u32).into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"activeTasks".into(), &(*self.active_tasks.borrow() as u32).into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"completedTasks".into(), &(*self.completed_tasks.borrow() as f64).into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"hasSharedMemory".into(), &self.has_shared_memory.into()).ok();
|
||||
js_sys::Reflect::set(&obj, &"initialized".into(), &self.initialized.into()).ok();
|
||||
obj.into()
|
||||
}
|
||||
|
||||
/// Shutdown all workers
|
||||
#[wasm_bindgen]
|
||||
pub fn shutdown(&mut self) -> Result<(), JsValue> {
|
||||
for worker in &self.workers {
|
||||
worker.terminate();
|
||||
}
|
||||
self.workers.clear();
|
||||
self.initialized = false;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Internal (Rust-only) compute methods: this `impl` block is not
// annotated with `#[wasm_bindgen]`, so these are not exported to JS.
impl WorkerPool {
    /// Perform parallel matrix multiplication.
    ///
    /// `a` is row-major `m x k`, `b` is row-major `k x n`; returns the
    /// row-major `m x n` product.
    ///
    /// NOTE(review): dispatch to workers is fire-and-forget — the method
    /// cannot block on worker replies in a synchronous WASM context, so
    /// it always returns the CPU-computed product; worker output is only
    /// accumulated into `pending_results` by the onmessage handlers.
    pub fn matmul_parallel(&self, a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Result<Vec<f32>, JsValue> {
        if !self.initialized || self.workers.is_empty() {
            // Fall back to CPU when the pool is unusable
            return Ok(cpu_matmul(a, b, m, k, n));
        }

        // For small matrices, don't bother with parallelism:
        // message-passing overhead would dominate the compute
        if m * k * n < 10000 {
            return Ok(cpu_matmul(a, b, m, k, n));
        }

        // Divide rows among workers (ceiling division covers all rows)
        let rows_per_worker = (m + self.worker_count - 1) / self.worker_count;

        // If using shared memory, copy input data: A at element offset 0,
        // B immediately after A (element offset m*k) — the worker script
        // reads both at exactly those offsets
        if let (Some(ref buffer), Some(ref view)) = (&self.shared_buffer, &self.shared_view) {
            // Copy A and B to shared buffer
            let a_array = js_sys::Float32Array::from(a);
            let b_array = js_sys::Float32Array::from(b);
            view.set(&a_array, 0);
            view.set(&b_array, (m * k) as u32);
        }

        // Dispatch tasks to workers; drop any stale results first
        self.pending_results.borrow_mut().clear();

        for (i, worker) in self.workers.iter().enumerate() {
            let row_start = i * rows_per_worker;
            let row_end = ((i + 1) * rows_per_worker).min(m);

            if row_start >= m {
                break;
            }

            let msg = js_sys::Object::new();
            js_sys::Reflect::set(&msg, &"type".into(), &"matmul".into()).ok();
            js_sys::Reflect::set(&msg, &"rowStart".into(), &(row_start as u32).into()).ok();
            js_sys::Reflect::set(&msg, &"rowEnd".into(), &(row_end as u32).into()).ok();
            js_sys::Reflect::set(&msg, &"m".into(), &(m as u32).into()).ok();
            js_sys::Reflect::set(&msg, &"k".into(), &(k as u32).into()).ok();
            js_sys::Reflect::set(&msg, &"n".into(), &(n as u32).into()).ok();

            // If no shared memory, send data directly (structured clone):
            // only this worker's slice of A, but all of B
            if self.shared_buffer.is_none() {
                let a_slice = &a[row_start * k..row_end * k];
                let a_array = js_sys::Float32Array::from(a_slice);
                let b_array = js_sys::Float32Array::from(b);
                js_sys::Reflect::set(&msg, &"a".into(), &a_array).ok();
                js_sys::Reflect::set(&msg, &"b".into(), &b_array).ok();
            }

            *self.active_tasks.borrow_mut() += 1;
            worker.post_message(&msg).ok();
        }

        // Wait for results (in real async code, this would be Promise-based)
        // For now, fall back to CPU since we can't truly wait in WASM
        Ok(cpu_matmul(a, b, m, k, n))
    }

    /// Perform parallel dot product of `a` and `b`.
    ///
    /// NOTE(review): currently always computed on the CPU; the worker
    /// path (dispatching ranges and summing partials) is unimplemented.
    pub fn dot_product_parallel(&self, a: &[f32], b: &[f32]) -> Result<f32, JsValue> {
        if !self.initialized || self.workers.is_empty() || a.len() < 10000 {
            // Fall back to CPU
            return Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum());
        }

        // For simplicity, use CPU implementation
        // Full implementation would dispatch to workers and collect partial sums
        Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
    }
}
|
||||
|
||||
/// Create the worker script as a string.
///
/// The script implements a small message protocol driven by `msg.type`:
/// - `init`: record the worker id and, when provided, adopt the
///   SharedArrayBuffer; replies `{ type: 'ready', workerId }`.
/// - `matmul`: compute rows `rowStart..rowEnd` of `a * b`; the result
///   Float32Array is transferred (zero-copy) back to the main thread.
/// - `dotproduct`: partial dot product over `start..end`; replies
///   `{ type: 'result', value }`.
/// - `vectorop`: element-wise op over `start..end`; result transferred.
///
/// When a shared buffer was supplied at init, inputs are read directly
/// from it (offsets are f32-element counts, hence the `* 4` byte
/// conversions); otherwise inputs arrive inline on the message.
fn create_worker_script() -> String {
    r#"
let workerId = -1;
let sharedBuffer = null;
let sharedView = null;

self.onmessage = function(e) {
    const msg = e.data;

    if (msg.type === 'init') {
        workerId = msg.workerId;
        if (msg.sharedBuffer) {
            sharedBuffer = msg.sharedBuffer;
            sharedView = new Float32Array(sharedBuffer);
        }
        self.postMessage({ type: 'ready', workerId: workerId });
        return;
    }

    if (msg.type === 'matmul') {
        const result = matmulPartial(msg);
        self.postMessage(result, [result.buffer]);
        return;
    }

    if (msg.type === 'dotproduct') {
        const result = dotProductPartial(msg);
        self.postMessage({ type: 'result', value: result });
        return;
    }

    if (msg.type === 'vectorop') {
        const result = vectorOp(msg);
        self.postMessage(result, [result.buffer]);
        return;
    }
};

function matmulPartial(msg) {
    const { rowStart, rowEnd, m, k, n } = msg;
    const rows = rowEnd - rowStart;
    const result = new Float32Array(rows * n);

    let a, b;
    if (sharedView) {
        // Use shared memory
        a = new Float32Array(sharedBuffer, rowStart * k * 4, rows * k);
        b = new Float32Array(sharedBuffer, m * k * 4, k * n);
    } else {
        // Use passed data
        a = msg.a;
        b = msg.b;
    }

    // Cache-friendly blocked multiplication
    const BLOCK = 32;
    for (let i = 0; i < rows; i++) {
        for (let j = 0; j < n; j++) {
            let sum = 0;
            for (let kk = 0; kk < k; kk++) {
                sum += a[i * k + kk] * b[kk * n + j];
            }
            result[i * n + j] = sum;
        }
    }

    return result;
}

function dotProductPartial(msg) {
    const { start, end } = msg;
    let sum = 0;

    if (sharedView) {
        const a = new Float32Array(sharedBuffer, start * 4, end - start);
        const b = new Float32Array(sharedBuffer, (msg.bOffset + start) * 4, end - start);
        for (let i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
    } else {
        const a = msg.a;
        const b = msg.b;
        for (let i = start; i < end; i++) {
            sum += a[i] * b[i];
        }
    }

    return sum;
}

function vectorOp(msg) {
    const { start, end, op } = msg;
    const len = end - start;
    const result = new Float32Array(len);

    const a = sharedView ? new Float32Array(sharedBuffer, start * 4, len) : msg.a;
    const b = sharedView ? new Float32Array(sharedBuffer, (msg.bOffset + start) * 4, len) : msg.b;

    switch (op) {
        case 'add':
            for (let i = 0; i < len; i++) result[i] = a[i] + b[i];
            break;
        case 'sub':
            for (let i = 0; i < len; i++) result[i] = a[i] - b[i];
            break;
        case 'mul':
            for (let i = 0; i < len; i++) result[i] = a[i] * b[i];
            break;
        case 'div':
            for (let i = 0; i < len; i++) result[i] = a[i] / (b[i] || 1e-7);
            break;
        case 'relu':
            for (let i = 0; i < len; i++) result[i] = Math.max(a[i], 0);
            break;
        case 'sigmoid':
            for (let i = 0; i < len; i++) result[i] = 1 / (1 + Math.exp(-a[i]));
            break;
    }

    return result;
}
"#.to_string()
}
|
||||
|
||||
/// CPU matrix multiplication fallback.
///
/// Multiplies the row-major `m x k` matrix `a` by the row-major `k x n`
/// matrix `b`, returning the row-major `m x n` product. The computation
/// is tiled into 32x32 blocks so the working set stays cache-resident.
fn cpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    const TILE: usize = 32;

    let mut out = vec![0.0f32; m * n];

    // Walk the output in TILE x TILE tiles. Within a tile, the loop order
    // row -> inner -> col lets the innermost loop stream contiguously
    // through both `b` and `out` while `a[row][inner]` stays in a register.
    for row_tile in (0..m).step_by(TILE) {
        let row_hi = (row_tile + TILE).min(m);
        for col_tile in (0..n).step_by(TILE) {
            let col_hi = (col_tile + TILE).min(n);
            for inner_tile in (0..k).step_by(TILE) {
                let inner_hi = (inner_tile + TILE).min(k);

                for row in row_tile..row_hi {
                    for inner in inner_tile..inner_hi {
                        let scale = a[row * k + inner];
                        for col in col_tile..col_hi {
                            out[row * n + col] += scale * b[inner * n + col];
                        }
                    }
                }
            }
        }
    }

    out
}
|
||||
|
||||
/// Work-stealing task queue.
///
/// Each worker owns one queue: tasks pushed locally are popped LIFO
/// (better cache locality for the owner) and are never visible to
/// thieves, while tasks pushed to the shared side can be stolen FIFO by
/// idle workers.
pub struct WorkStealingQueue<T> {
    /// Local tasks (LIFO for locality); cannot be stolen.
    local: Vec<T>,
    /// Shared tasks (can be stolen). A `VecDeque` makes FIFO stealing
    /// O(1) via `pop_front` — the previous `Vec::remove(0)` shifted every
    /// remaining element on each steal (O(n)).
    shared: Rc<RefCell<std::collections::VecDeque<T>>>,
}

impl<T: Clone> WorkStealingQueue<T> {
    /// Create a new, empty work-stealing queue
    pub fn new() -> Self {
        WorkStealingQueue {
            local: Vec::new(),
            shared: Rc::new(RefCell::new(std::collections::VecDeque::new())),
        }
    }

    /// Push a task (local, cannot be stolen)
    pub fn push_local(&mut self, task: T) {
        self.local.push(task);
    }

    /// Push a task that can be stolen by other workers
    pub fn push_shared(&mut self, task: T) {
        self.shared.borrow_mut().push_back(task);
    }

    /// Pop a local task (LIFO); `None` when the local queue is empty
    pub fn pop_local(&mut self) -> Option<T> {
        self.local.pop()
    }

    /// Try to steal the oldest shared task (FIFO); O(1)
    pub fn steal(&self) -> Option<T> {
        self.shared.borrow_mut().pop_front()
    }

    /// Get number of stealable tasks
    pub fn stealable_count(&self) -> usize {
        self.shared.borrow().len()
    }

    /// Get total task count (local + shared)
    pub fn total_count(&self) -> usize {
        self.local.len() + self.shared.borrow().len()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// 2x2 sanity check of the blocked CPU matmul against hand-computed
    /// values.
    #[test]
    fn test_cpu_matmul() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![5.0, 6.0, 7.0, 8.0];

        let product = cpu_matmul(&lhs, &rhs, 2, 2, 2);

        // Row 0: [1*5 + 2*7, 1*6 + 2*8] = [19, 22]
        // Row 1: [3*5 + 4*7, 3*6 + 4*8] = [43, 50]
        assert_eq!(product, vec![19.0, 22.0, 43.0, 50.0]);
    }

    /// Local tasks are popped LIFO and invisible to thieves; shared
    /// tasks are stolen FIFO until exhausted.
    #[test]
    fn test_work_stealing_queue() {
        let mut queue = WorkStealingQueue::<i32>::new();

        queue.push_local(1);
        queue.push_shared(2);
        queue.push_shared(3);

        assert_eq!(queue.total_count(), 3);
        assert_eq!(queue.stealable_count(), 2);

        assert_eq!(queue.pop_local(), Some(1));
        assert_eq!(queue.steal(), Some(2));
        assert_eq!(queue.steal(), Some(3));
        assert_eq!(queue.steal(), None);
    }
}
|
||||
Reference in New Issue
Block a user