Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
309
crates/ruvector-wasm/kernels/rmsnorm.rs
Normal file
309
crates/ruvector-wasm/kernels/rmsnorm.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
//! RMSNorm (Root Mean Square Layer Normalization) Kernel
|
||||
//!
|
||||
//! This kernel implements RMS normalization as used in models like LLaMA.
|
||||
//! Unlike LayerNorm, RMSNorm only uses the root mean square, without
|
||||
//! centering the distribution.
|
||||
//!
|
||||
//! Formula: y = (x / rms(x)) * weight
|
||||
//! where rms(x) = sqrt(mean(x^2) + eps)
|
||||
//!
|
||||
//! # Compilation
|
||||
//!
|
||||
//! To compile this kernel to WASM:
|
||||
//! ```bash
|
||||
//! rustc --target wasm32-unknown-unknown \
|
||||
//! --crate-type cdylib \
|
||||
//! -C opt-level=3 \
|
||||
//! -C lto=fat \
|
||||
//! kernels/rmsnorm.rs \
|
||||
//! -o kernels/rmsnorm_f32.wasm
|
||||
//! ```
|
||||
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
// Panic handler for no_std
|
||||
#[panic_handler]
|
||||
fn panic(_info: &core::panic::PanicInfo) -> ! {
|
||||
loop {}
|
||||
}
|
||||
|
||||
/// Kernel descriptor structure
|
||||
#[repr(C)]
|
||||
pub struct KernelDescriptor {
|
||||
pub input_a_offset: u32, // x tensor
|
||||
pub input_a_size: u32,
|
||||
pub input_b_offset: u32, // weight tensor (gamma)
|
||||
pub input_b_size: u32,
|
||||
pub output_offset: u32,
|
||||
pub output_size: u32,
|
||||
pub scratch_offset: u32, // For storing intermediate RMS values
|
||||
pub scratch_size: u32,
|
||||
pub params_offset: u32,
|
||||
pub params_size: u32,
|
||||
}
|
||||
|
||||
/// RMSNorm parameters
|
||||
#[repr(C)]
|
||||
pub struct RmsNormParams {
|
||||
/// Epsilon for numerical stability (typically 1e-5 or 1e-6)
|
||||
pub eps: f32,
|
||||
/// Hidden dimension (normalizing dimension)
|
||||
pub hidden_dim: u32,
|
||||
/// Number of elements to normalize (batch * seq)
|
||||
pub num_elements: u32,
|
||||
}
|
||||
|
||||
/// Error codes
|
||||
const OK: i32 = 0;
|
||||
const INVALID_INPUT: i32 = 1;
|
||||
const INVALID_OUTPUT: i32 = 2;
|
||||
const INVALID_PARAMS: i32 = 3;
|
||||
|
||||
/// Initialize kernel
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
|
||||
OK
|
||||
}
|
||||
|
||||
/// Execute RMSNorm forward pass
|
||||
///
|
||||
/// # Memory Layout
|
||||
///
|
||||
/// Input A (x): [num_elements, hidden_dim] as f32
|
||||
/// Input B (weight): [hidden_dim] as f32 (gamma scaling factors)
|
||||
/// Output (y): [num_elements, hidden_dim] as f32
|
||||
/// Scratch: [num_elements] as f32 (RMS values for backward pass)
|
||||
///
|
||||
/// For each row i:
|
||||
/// rms[i] = sqrt(mean(x[i]^2) + eps)
|
||||
/// y[i] = (x[i] / rms[i]) * weight
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
// Validate inputs
|
||||
if desc.input_a_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.output_size == 0 {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<RmsNormParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const RmsNormParams)
|
||||
};
|
||||
|
||||
let hidden_dim = params.hidden_dim as usize;
|
||||
let num_elements = params.num_elements as usize;
|
||||
let eps = params.eps;
|
||||
|
||||
// Get tensor pointers
|
||||
let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let weight_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Optional: Store RMS values in scratch for backward pass
|
||||
let rms_ptr = if desc.scratch_size >= (num_elements * 4) as u32 {
|
||||
Some(unsafe { memory_base.add(desc.scratch_offset as usize) as *mut f32 })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Process each element (row)
|
||||
for i in 0..num_elements {
|
||||
let row_offset = i * hidden_dim;
|
||||
|
||||
// Compute sum of squares
|
||||
let mut sum_sq: f32 = 0.0;
|
||||
for j in 0..hidden_dim {
|
||||
unsafe {
|
||||
let val = *x_ptr.add(row_offset + j);
|
||||
sum_sq += val * val;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute RMS
|
||||
let mean_sq = sum_sq / (hidden_dim as f32);
|
||||
let rms = sqrtf(mean_sq + eps);
|
||||
let inv_rms = 1.0 / rms;
|
||||
|
||||
// Store RMS for backward pass if scratch is available
|
||||
if let Some(rms_store) = rms_ptr {
|
||||
unsafe {
|
||||
*rms_store.add(i) = rms;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize and scale
|
||||
for j in 0..hidden_dim {
|
||||
unsafe {
|
||||
let x_val = *x_ptr.add(row_offset + j);
|
||||
let w_val = *weight_ptr.add(j);
|
||||
*y_ptr.add(row_offset + j) = (x_val * inv_rms) * w_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Execute RMSNorm backward pass
|
||||
///
|
||||
/// Computes gradients for x and weight given gradient of output.
|
||||
///
|
||||
/// # Memory Layout (for backward)
|
||||
///
|
||||
/// Input A (grad_y): [num_elements, hidden_dim] as f32
|
||||
/// Input B (x): Original input (needed for gradient)
|
||||
/// Output (grad_x): [num_elements, hidden_dim] as f32
|
||||
/// Scratch: [hidden_dim] as f32 (for grad_weight accumulation)
|
||||
/// Params: Contains weight pointer separately
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
if desc.input_a_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.output_size == 0 {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<RmsNormParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const RmsNormParams)
|
||||
};
|
||||
|
||||
let hidden_dim = params.hidden_dim as usize;
|
||||
let num_elements = params.num_elements as usize;
|
||||
let eps = params.eps;
|
||||
|
||||
// Note: For a complete backward pass, we would need:
|
||||
// - grad_y: gradient from upstream
|
||||
// - x: original input
|
||||
// - weight: scale parameters
|
||||
// - Output: grad_x
|
||||
// - Accumulate: grad_weight
|
||||
|
||||
// This is a simplified implementation showing the structure
|
||||
let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let x_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// For each element
|
||||
for i in 0..num_elements {
|
||||
let row_offset = i * hidden_dim;
|
||||
|
||||
// Recompute RMS (or load from scratch if saved during forward)
|
||||
let mut sum_sq: f32 = 0.0;
|
||||
for j in 0..hidden_dim {
|
||||
unsafe {
|
||||
let val = *x_ptr.add(row_offset + j);
|
||||
sum_sq += val * val;
|
||||
}
|
||||
}
|
||||
let mean_sq = sum_sq / (hidden_dim as f32);
|
||||
let rms = sqrtf(mean_sq + eps);
|
||||
let inv_rms = 1.0 / rms;
|
||||
let inv_rms_cubed = inv_rms * inv_rms * inv_rms;
|
||||
|
||||
// Compute grad_norm_x = grad_y * weight
|
||||
// Then grad_x = inv_rms * grad_norm_x - inv_rms^3 * x * mean(x * grad_norm_x)
|
||||
// This is the chain rule applied to RMSNorm
|
||||
|
||||
// First pass: compute sum(x * grad_y) for this row
|
||||
let mut sum_x_grad: f32 = 0.0;
|
||||
for j in 0..hidden_dim {
|
||||
unsafe {
|
||||
let x_val = *x_ptr.add(row_offset + j);
|
||||
let gy_val = *grad_y_ptr.add(row_offset + j);
|
||||
sum_x_grad += x_val * gy_val;
|
||||
}
|
||||
}
|
||||
let mean_x_grad = sum_x_grad / (hidden_dim as f32);
|
||||
|
||||
// Second pass: compute grad_x
|
||||
for j in 0..hidden_dim {
|
||||
unsafe {
|
||||
let x_val = *x_ptr.add(row_offset + j);
|
||||
let gy_val = *grad_y_ptr.add(row_offset + j);
|
||||
|
||||
// Simplified gradient (without weight consideration for this demo)
|
||||
let grad = inv_rms * gy_val - inv_rms_cubed * x_val * mean_x_grad;
|
||||
*grad_x_ptr.add(row_offset + j) = grad;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Kernel info structure
|
||||
#[repr(C)]
|
||||
pub struct KernelInfo {
|
||||
pub name_ptr: *const u8,
|
||||
pub name_len: u32,
|
||||
pub version_major: u16,
|
||||
pub version_minor: u16,
|
||||
pub version_patch: u16,
|
||||
pub supports_backward: bool,
|
||||
}
|
||||
|
||||
static KERNEL_NAME: &[u8] = b"rmsnorm_f32\0";
|
||||
|
||||
/// Get kernel metadata
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
|
||||
if info_ptr.is_null() {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
(*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
|
||||
(*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1;
|
||||
(*info_ptr).version_major = 1;
|
||||
(*info_ptr).version_minor = 0;
|
||||
(*info_ptr).version_patch = 0;
|
||||
(*info_ptr).supports_backward = true;
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Cleanup kernel resources
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_cleanup() -> i32 {
|
||||
OK
|
||||
}
|
||||
|
||||
// Minimal sqrt implementation for no_std
|
||||
fn sqrtf(x: f32) -> f32 {
|
||||
if x <= 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Newton-Raphson method
|
||||
let mut guess = x;
|
||||
|
||||
// Initial guess using bit manipulation
|
||||
let i = x.to_bits();
|
||||
let i = 0x1fbd1df5 + (i >> 1);
|
||||
guess = f32::from_bits(i);
|
||||
|
||||
// Newton-Raphson iterations
|
||||
for _ in 0..3 {
|
||||
guess = 0.5 * (guess + x / guess);
|
||||
}
|
||||
|
||||
guess
|
||||
}
|
||||
304
crates/ruvector-wasm/kernels/rope.rs
Normal file
304
crates/ruvector-wasm/kernels/rope.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
//! RoPE (Rotary Position Embedding) Kernel
|
||||
//!
|
||||
//! This kernel implements rotary position embeddings as described in the
|
||||
//! RoFormer paper (https://arxiv.org/abs/2104.09864).
|
||||
//!
|
||||
//! RoPE applies rotation to the query and key vectors in attention,
|
||||
//! encoding relative positional information.
|
||||
//!
|
||||
//! # Compilation
|
||||
//!
|
||||
//! To compile this kernel to WASM:
|
||||
//! ```bash
|
||||
//! rustc --target wasm32-unknown-unknown \
|
||||
//! --crate-type cdylib \
|
||||
//! -C opt-level=3 \
|
||||
//! -C lto=fat \
|
||||
//! kernels/rope.rs \
|
||||
//! -o kernels/rope_f32.wasm
|
||||
//! ```
|
||||
//!
|
||||
//! Or use the provided build script in the kernels directory.
|
||||
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
// Panic handler for no_std
|
||||
#[panic_handler]
|
||||
fn panic(_info: &core::panic::PanicInfo) -> ! {
|
||||
loop {}
|
||||
}
|
||||
|
||||
/// Kernel descriptor structure (must match host definition)
|
||||
#[repr(C)]
|
||||
pub struct KernelDescriptor {
|
||||
pub input_a_offset: u32, // x tensor
|
||||
pub input_a_size: u32,
|
||||
pub input_b_offset: u32, // freqs tensor
|
||||
pub input_b_size: u32,
|
||||
pub output_offset: u32,
|
||||
pub output_size: u32,
|
||||
pub scratch_offset: u32,
|
||||
pub scratch_size: u32,
|
||||
pub params_offset: u32,
|
||||
pub params_size: u32,
|
||||
}
|
||||
|
||||
/// RoPE parameters
|
||||
#[repr(C)]
|
||||
pub struct RopeParams {
|
||||
/// Base frequency (typically 10000.0)
|
||||
pub theta: f32,
|
||||
/// Sequence length
|
||||
pub seq_len: u32,
|
||||
/// Head dimension (must be even)
|
||||
pub head_dim: u32,
|
||||
/// Number of heads
|
||||
pub num_heads: u32,
|
||||
/// Batch size
|
||||
pub batch_size: u32,
|
||||
}
|
||||
|
||||
/// Error codes
|
||||
const OK: i32 = 0;
|
||||
const INVALID_INPUT: i32 = 1;
|
||||
const INVALID_OUTPUT: i32 = 2;
|
||||
const INVALID_PARAMS: i32 = 3;
|
||||
|
||||
/// Initialize kernel (optional, for stateful kernels)
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
|
||||
OK
|
||||
}
|
||||
|
||||
/// Execute RoPE forward pass
|
||||
///
|
||||
/// # Memory Layout
|
||||
///
|
||||
/// Input A (x): [batch, seq, heads, dim] as f32
|
||||
/// Input B (freqs): [seq, dim/2] as f32 (precomputed frequencies)
|
||||
/// Output (y): [batch, seq, heads, dim] as f32
|
||||
///
|
||||
/// The kernel applies rotation to pairs of elements:
|
||||
/// y[..., 2i] = x[..., 2i] * cos(freq) - x[..., 2i+1] * sin(freq)
|
||||
/// y[..., 2i+1] = x[..., 2i] * sin(freq) + x[..., 2i+1] * cos(freq)
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
// Safety: We trust the host to provide valid pointers
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
// Validate inputs
|
||||
if desc.input_a_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.output_size == 0 || desc.output_size != desc.input_a_size {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<RopeParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
// Get memory base pointer (WASM linear memory starts at 0)
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
// Get params
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const RopeParams)
|
||||
};
|
||||
|
||||
// Validate head_dim is even
|
||||
if params.head_dim % 2 != 0 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let half_dim = params.head_dim / 2;
|
||||
|
||||
// Get tensor pointers
|
||||
let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let freqs_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Apply RoPE
|
||||
// Loop order: batch -> seq -> head -> dim_pair
|
||||
for b in 0..params.batch_size {
|
||||
for s in 0..params.seq_len {
|
||||
for h in 0..params.num_heads {
|
||||
for d in 0..half_dim {
|
||||
// Calculate indices
|
||||
let idx = ((b * params.seq_len + s) * params.num_heads + h) * params.head_dim + d * 2;
|
||||
let freq_idx = s * half_dim + d;
|
||||
|
||||
unsafe {
|
||||
// Get input values
|
||||
let x0 = *x_ptr.add(idx as usize);
|
||||
let x1 = *x_ptr.add(idx as usize + 1);
|
||||
|
||||
// Get frequency (precomputed cos and sin are interleaved)
|
||||
let freq = *freqs_ptr.add(freq_idx as usize);
|
||||
let cos_f = libm::cosf(freq);
|
||||
let sin_f = libm::sinf(freq);
|
||||
|
||||
// Apply rotation
|
||||
let y0 = x0 * cos_f - x1 * sin_f;
|
||||
let y1 = x0 * sin_f + x1 * cos_f;
|
||||
|
||||
// Write output
|
||||
*y_ptr.add(idx as usize) = y0;
|
||||
*y_ptr.add(idx as usize + 1) = y1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Execute RoPE backward pass (gradient computation)
|
||||
///
|
||||
/// The backward pass is the same rotation with negated sin,
|
||||
/// since the Jacobian of rotation is another rotation.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
// For RoPE, backward is essentially the same operation with transposed rotation
|
||||
// (negated sin terms), but the structure is identical
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
if desc.input_a_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.output_size == 0 || desc.output_size != desc.input_a_size {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<RopeParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const RopeParams)
|
||||
};
|
||||
|
||||
if params.head_dim % 2 != 0 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let half_dim = params.head_dim / 2;
|
||||
|
||||
let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let freqs_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Backward RoPE: apply inverse rotation (transpose = negate sin)
|
||||
for b in 0..params.batch_size {
|
||||
for s in 0..params.seq_len {
|
||||
for h in 0..params.num_heads {
|
||||
for d in 0..half_dim {
|
||||
let idx = ((b * params.seq_len + s) * params.num_heads + h) * params.head_dim + d * 2;
|
||||
let freq_idx = s * half_dim + d;
|
||||
|
||||
unsafe {
|
||||
let gy0 = *grad_y_ptr.add(idx as usize);
|
||||
let gy1 = *grad_y_ptr.add(idx as usize + 1);
|
||||
|
||||
let freq = *freqs_ptr.add(freq_idx as usize);
|
||||
let cos_f = libm::cosf(freq);
|
||||
let sin_f = libm::sinf(freq);
|
||||
|
||||
// Inverse rotation (transpose)
|
||||
let gx0 = gy0 * cos_f + gy1 * sin_f;
|
||||
let gx1 = -gy0 * sin_f + gy1 * cos_f;
|
||||
|
||||
*grad_x_ptr.add(idx as usize) = gx0;
|
||||
*grad_x_ptr.add(idx as usize + 1) = gx1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Kernel info structure
|
||||
#[repr(C)]
|
||||
pub struct KernelInfo {
|
||||
pub name_ptr: *const u8,
|
||||
pub name_len: u32,
|
||||
pub version_major: u16,
|
||||
pub version_minor: u16,
|
||||
pub version_patch: u16,
|
||||
pub supports_backward: bool,
|
||||
}
|
||||
|
||||
static KERNEL_NAME: &[u8] = b"rope_f32\0";
|
||||
|
||||
/// Get kernel metadata
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
|
||||
if info_ptr.is_null() {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
(*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
|
||||
(*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1; // Exclude null terminator
|
||||
(*info_ptr).version_major = 1;
|
||||
(*info_ptr).version_minor = 0;
|
||||
(*info_ptr).version_patch = 0;
|
||||
(*info_ptr).supports_backward = true;
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Cleanup kernel resources
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_cleanup() -> i32 {
|
||||
// No resources to cleanup for this stateless kernel
|
||||
OK
|
||||
}
|
||||
|
||||
// Minimal libm implementations for no_std
|
||||
mod libm {
|
||||
// Simple Taylor series approximations for sin and cos
|
||||
// In production, use more accurate implementations or link to libm
|
||||
|
||||
const PI: f32 = 3.14159265358979323846;
|
||||
const TWO_PI: f32 = 2.0 * PI;
|
||||
|
||||
fn normalize_angle(mut x: f32) -> f32 {
|
||||
// Reduce to [-PI, PI]
|
||||
while x > PI {
|
||||
x -= TWO_PI;
|
||||
}
|
||||
while x < -PI {
|
||||
x += TWO_PI;
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
pub fn sinf(x: f32) -> f32 {
|
||||
let x = normalize_angle(x);
|
||||
// Taylor series: sin(x) = x - x^3/3! + x^5/5! - x^7/7! + ...
|
||||
let x2 = x * x;
|
||||
let x3 = x2 * x;
|
||||
let x5 = x3 * x2;
|
||||
let x7 = x5 * x2;
|
||||
let x9 = x7 * x2;
|
||||
|
||||
x - x3 / 6.0 + x5 / 120.0 - x7 / 5040.0 + x9 / 362880.0
|
||||
}
|
||||
|
||||
pub fn cosf(x: f32) -> f32 {
|
||||
let x = normalize_angle(x);
|
||||
// Taylor series: cos(x) = 1 - x^2/2! + x^4/4! - x^6/6! + ...
|
||||
let x2 = x * x;
|
||||
let x4 = x2 * x2;
|
||||
let x6 = x4 * x2;
|
||||
let x8 = x6 * x2;
|
||||
|
||||
1.0 - x2 / 2.0 + x4 / 24.0 - x6 / 720.0 + x8 / 40320.0
|
||||
}
|
||||
}
|
||||
299
crates/ruvector-wasm/kernels/swiglu.rs
Normal file
299
crates/ruvector-wasm/kernels/swiglu.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
//! SwiGLU (Swish-Gated Linear Unit) Activation Kernel
|
||||
//!
|
||||
//! This kernel implements the SwiGLU activation function used in models
|
||||
//! like LLaMA and PaLM. It combines the Swish activation with a gating
|
||||
//! mechanism.
|
||||
//!
|
||||
//! Formula: SwiGLU(x, gate) = swish(gate) * x
|
||||
//! where swish(x) = x * sigmoid(x)
|
||||
//!
|
||||
//! In practice, this is often used in the FFN:
|
||||
//! FFN(x) = (swish(x * W_gate) * (x * W_up)) * W_down
|
||||
//!
|
||||
//! This kernel computes: swish(gate) * x
|
||||
//!
|
||||
//! # Compilation
|
||||
//!
|
||||
//! To compile this kernel to WASM:
|
||||
//! ```bash
|
||||
//! rustc --target wasm32-unknown-unknown \
|
||||
//! --crate-type cdylib \
|
||||
//! -C opt-level=3 \
|
||||
//! -C lto=fat \
|
||||
//! kernels/swiglu.rs \
|
||||
//! -o kernels/swiglu_f32.wasm
|
||||
//! ```
|
||||
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
// Panic handler for no_std
|
||||
#[panic_handler]
|
||||
fn panic(_info: &core::panic::PanicInfo) -> ! {
|
||||
loop {}
|
||||
}
|
||||
|
||||
/// Kernel descriptor structure
|
||||
#[repr(C)]
|
||||
pub struct KernelDescriptor {
|
||||
pub input_a_offset: u32, // x tensor (to be gated)
|
||||
pub input_a_size: u32,
|
||||
pub input_b_offset: u32, // gate tensor
|
||||
pub input_b_size: u32,
|
||||
pub output_offset: u32,
|
||||
pub output_size: u32,
|
||||
pub scratch_offset: u32,
|
||||
pub scratch_size: u32,
|
||||
pub params_offset: u32,
|
||||
pub params_size: u32,
|
||||
}
|
||||
|
||||
/// SwiGLU parameters
|
||||
#[repr(C)]
|
||||
pub struct SwiGluParams {
|
||||
/// Number of elements (total size = num_elements * hidden_dim)
|
||||
pub num_elements: u32,
|
||||
/// Hidden dimension
|
||||
pub hidden_dim: u32,
|
||||
/// Beta parameter for SiLU/Swish (typically 1.0)
|
||||
pub beta: f32,
|
||||
}
|
||||
|
||||
/// Error codes
|
||||
const OK: i32 = 0;
|
||||
const INVALID_INPUT: i32 = 1;
|
||||
const INVALID_OUTPUT: i32 = 2;
|
||||
const INVALID_PARAMS: i32 = 3;
|
||||
|
||||
/// Initialize kernel
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
|
||||
OK
|
||||
}
|
||||
|
||||
/// Compute swish activation: x * sigmoid(beta * x)
|
||||
#[inline]
|
||||
fn swish(x: f32, beta: f32) -> f32 {
|
||||
x * sigmoid(beta * x)
|
||||
}
|
||||
|
||||
/// Sigmoid function: 1 / (1 + exp(-x))
|
||||
#[inline]
|
||||
fn sigmoid(x: f32) -> f32 {
|
||||
1.0 / (1.0 + expf(-x))
|
||||
}
|
||||
|
||||
/// Execute SwiGLU forward pass
|
||||
///
|
||||
/// # Memory Layout
|
||||
///
|
||||
/// Input A (x): [num_elements, hidden_dim] as f32 (value to gate)
|
||||
/// Input B (gate): [num_elements, hidden_dim] as f32 (gate values)
|
||||
/// Output (y): [num_elements, hidden_dim] as f32
|
||||
///
|
||||
/// y = swish(gate) * x
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
// Validate inputs
|
||||
if desc.input_a_size == 0 || desc.input_b_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.input_a_size != desc.input_b_size {
|
||||
return INVALID_INPUT; // x and gate must have same size
|
||||
}
|
||||
if desc.output_size == 0 || desc.output_size != desc.input_a_size {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<SwiGluParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const SwiGluParams)
|
||||
};
|
||||
|
||||
let total_elements = (params.num_elements * params.hidden_dim) as usize;
|
||||
let beta = params.beta;
|
||||
|
||||
// Get tensor pointers
|
||||
let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let gate_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Apply SwiGLU: y = swish(gate) * x
|
||||
for i in 0..total_elements {
|
||||
unsafe {
|
||||
let x_val = *x_ptr.add(i);
|
||||
let gate_val = *gate_ptr.add(i);
|
||||
let swish_gate = swish(gate_val, beta);
|
||||
*y_ptr.add(i) = swish_gate * x_val;
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Execute SwiGLU backward pass
|
||||
///
|
||||
/// Given grad_y, compute grad_x and grad_gate.
|
||||
///
|
||||
/// grad_x = swish(gate) * grad_y
|
||||
/// grad_gate = x * grad_y * (sigmoid(gate) + gate * sigmoid(gate) * (1 - sigmoid(gate)))
|
||||
/// = x * grad_y * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))
|
||||
///
|
||||
/// For this simplified kernel:
|
||||
/// Input A (grad_y): gradient from upstream
|
||||
/// Input B contains both (x, gate) - simplified layout
|
||||
/// Output (grad_x): gradient w.r.t. x
|
||||
/// Scratch: gradient w.r.t. gate (if space available)
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
if desc.input_a_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.output_size == 0 {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<SwiGluParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const SwiGluParams)
|
||||
};
|
||||
|
||||
let total_elements = (params.num_elements * params.hidden_dim) as usize;
|
||||
let beta = params.beta;
|
||||
|
||||
// For backward, input_b should contain original gate values
|
||||
// This is a simplified layout - real implementation would use separate descriptors
|
||||
let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let gate_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Compute grad_x = swish(gate) * grad_y
|
||||
// (simplified: we would also need original x to compute grad_gate)
|
||||
for i in 0..total_elements {
|
||||
unsafe {
|
||||
let grad_y_val = *grad_y_ptr.add(i);
|
||||
let gate_val = *gate_ptr.add(i);
|
||||
let swish_gate = swish(gate_val, beta);
|
||||
*grad_x_ptr.add(i) = swish_gate * grad_y_val;
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Kernel info structure
|
||||
#[repr(C)]
|
||||
pub struct KernelInfo {
|
||||
pub name_ptr: *const u8,
|
||||
pub name_len: u32,
|
||||
pub version_major: u16,
|
||||
pub version_minor: u16,
|
||||
pub version_patch: u16,
|
||||
pub supports_backward: bool,
|
||||
}
|
||||
|
||||
static KERNEL_NAME: &[u8] = b"swiglu_f32\0";
|
||||
|
||||
/// Get kernel metadata
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
|
||||
if info_ptr.is_null() {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
(*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
|
||||
(*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1;
|
||||
(*info_ptr).version_major = 1;
|
||||
(*info_ptr).version_minor = 0;
|
||||
(*info_ptr).version_patch = 0;
|
||||
(*info_ptr).supports_backward = true;
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Cleanup kernel resources
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_cleanup() -> i32 {
|
||||
OK
|
||||
}
|
||||
|
||||
// Minimal exp implementation for no_std
|
||||
fn expf(x: f32) -> f32 {
|
||||
// Handle edge cases
|
||||
if x > 88.0 {
|
||||
return f32::INFINITY;
|
||||
}
|
||||
if x < -88.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Use range reduction: exp(x) = 2^k * exp(r)
|
||||
// where k = round(x / ln(2)) and r = x - k * ln(2)
|
||||
const LN2: f32 = 0.693147180559945;
|
||||
const LN2_INV: f32 = 1.442695040888963;
|
||||
|
||||
let k = (x * LN2_INV + 0.5).floor();
|
||||
let r = x - k * LN2;
|
||||
|
||||
// Taylor series for exp(r) where |r| <= ln(2)/2
|
||||
// exp(r) ≈ 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5! + r^6/6!
|
||||
let r2 = r * r;
|
||||
let r3 = r2 * r;
|
||||
let r4 = r2 * r2;
|
||||
let r5 = r4 * r;
|
||||
let r6 = r3 * r3;
|
||||
|
||||
let exp_r = 1.0 + r + r2 * 0.5 + r3 * 0.166666667 + r4 * 0.041666667 + r5 * 0.008333333 + r6 * 0.001388889;
|
||||
|
||||
// Combine: exp(x) = 2^k * exp(r)
|
||||
// 2^k can be computed via bit manipulation
|
||||
let k_int = k as i32;
|
||||
let scale_bits = ((127 + k_int) as u32) << 23;
|
||||
let scale = f32::from_bits(scale_bits);
|
||||
|
||||
exp_r * scale
|
||||
}
|
||||
|
||||
/// Compute GeGLU variant (alternative activation)
|
||||
/// GeGLU(x, gate) = gelu(gate) * x
|
||||
/// This is provided as an alternative, not used in default forward
|
||||
#[allow(dead_code)]
|
||||
fn gelu(x: f32) -> f32 {
|
||||
// Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
|
||||
const COEFF: f32 = 0.044715;
|
||||
|
||||
let x3 = x * x * x;
|
||||
let inner = SQRT_2_OVER_PI * (x + COEFF * x3);
|
||||
0.5 * x * (1.0 + tanhf(inner))
|
||||
}
|
||||
|
||||
/// Minimal tanh implementation
|
||||
#[allow(dead_code)]
|
||||
fn tanhf(x: f32) -> f32 {
|
||||
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
|
||||
// For numerical stability with large |x|
|
||||
if x > 10.0 {
|
||||
return 1.0;
|
||||
}
|
||||
if x < -10.0 {
|
||||
return -1.0;
|
||||
}
|
||||
|
||||
let exp_2x = expf(2.0 * x);
|
||||
(exp_2x - 1.0) / (exp_2x + 1.0)
|
||||
}
|
||||
Reference in New Issue
Block a user