git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
310 lines
8.7 KiB
Rust
310 lines
8.7 KiB
Rust
//! RMSNorm (Root Mean Square Layer Normalization) Kernel
//!
//! This kernel implements RMS normalization as used in models like LLaMA.
//! Unlike LayerNorm, RMSNorm only uses the root mean square, without
//! centering the distribution.
//!
//! Formula: y = (x / rms(x)) * weight
//! where rms(x) = sqrt(mean(x^2) + eps)
//!
//! # Compilation
//!
//! To compile this kernel to WASM:
//!
//! ```bash
//! rustc --target wasm32-unknown-unknown \
//!     --crate-type cdylib \
//!     -C opt-level=3 \
//!     -C lto=fat \
//!     kernels/rmsnorm.rs \
//!     -o kernels/rmsnorm_f32.wasm
//! ```
|
|
#![no_std]
#![no_main]

// Panic handler required in no_std: there is no unwinding runtime, so we
// spin forever and rely on the WASM host to trap or kill the instance.
#[panic_handler]
fn panic(_info: &core::panic::PanicInfo) -> ! {
    loop {}
}
|
/// Kernel descriptor structure
///
/// Passed by the host to `kernel_forward`/`kernel_backward`. All offsets are
/// byte offsets into WASM linear memory (base address 0); all sizes are in
/// bytes. Field order is ABI: `#[repr(C)]` must match the host side exactly.
#[repr(C)]
pub struct KernelDescriptor {
    pub input_a_offset: u32, // x tensor (forward) / grad_y (backward)
    pub input_a_size: u32,
    pub input_b_offset: u32, // weight tensor (gamma) (forward) / x (backward)
    pub input_b_size: u32,
    pub output_offset: u32,
    pub output_size: u32,
    pub scratch_offset: u32, // For storing intermediate RMS values
    pub scratch_size: u32,
    pub params_offset: u32,
    pub params_size: u32,
}
|
|
|
|
/// RMSNorm parameters
///
/// Read from linear memory at `KernelDescriptor::params_offset`; the
/// `#[repr(C)]` layout must match what the host writes there.
#[repr(C)]
pub struct RmsNormParams {
    /// Epsilon for numerical stability (typically 1e-5 or 1e-6)
    pub eps: f32,
    /// Hidden dimension (normalizing dimension)
    pub hidden_dim: u32,
    /// Number of elements to normalize (batch * seq)
    pub num_elements: u32,
}
|
|
|
|
// Error codes returned by all kernel entry points.
/// Success
const OK: i32 = 0;
/// Input buffer missing or unusable
const INVALID_INPUT: i32 = 1;
/// Output buffer missing or unusable
const INVALID_OUTPUT: i32 = 2;
/// Parameter block missing or too small
const INVALID_PARAMS: i32 = 3;
|
|
|
|
/// Initialize kernel
///
/// No setup is required for this kernel; the parameters are ignored and the
/// call always succeeds with `OK`.
#[no_mangle]
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
    OK
}
|
|
|
|
/// Execute RMSNorm forward pass
|
|
///
|
|
/// # Memory Layout
|
|
///
|
|
/// Input A (x): [num_elements, hidden_dim] as f32
|
|
/// Input B (weight): [hidden_dim] as f32 (gamma scaling factors)
|
|
/// Output (y): [num_elements, hidden_dim] as f32
|
|
/// Scratch: [num_elements] as f32 (RMS values for backward pass)
|
|
///
|
|
/// For each row i:
|
|
/// rms[i] = sqrt(mean(x[i]^2) + eps)
|
|
/// y[i] = (x[i] / rms[i]) * weight
|
|
#[no_mangle]
|
|
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
|
|
let desc = unsafe { &*desc_ptr };
|
|
|
|
// Validate inputs
|
|
if desc.input_a_size == 0 {
|
|
return INVALID_INPUT;
|
|
}
|
|
if desc.output_size == 0 {
|
|
return INVALID_OUTPUT;
|
|
}
|
|
if desc.params_size < core::mem::size_of::<RmsNormParams>() as u32 {
|
|
return INVALID_PARAMS;
|
|
}
|
|
|
|
let memory_base = 0usize as *mut u8;
|
|
|
|
let params = unsafe {
|
|
&*(memory_base.add(desc.params_offset as usize) as *const RmsNormParams)
|
|
};
|
|
|
|
let hidden_dim = params.hidden_dim as usize;
|
|
let num_elements = params.num_elements as usize;
|
|
let eps = params.eps;
|
|
|
|
// Get tensor pointers
|
|
let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
|
let weight_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
|
let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
|
|
|
// Optional: Store RMS values in scratch for backward pass
|
|
let rms_ptr = if desc.scratch_size >= (num_elements * 4) as u32 {
|
|
Some(unsafe { memory_base.add(desc.scratch_offset as usize) as *mut f32 })
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Process each element (row)
|
|
for i in 0..num_elements {
|
|
let row_offset = i * hidden_dim;
|
|
|
|
// Compute sum of squares
|
|
let mut sum_sq: f32 = 0.0;
|
|
for j in 0..hidden_dim {
|
|
unsafe {
|
|
let val = *x_ptr.add(row_offset + j);
|
|
sum_sq += val * val;
|
|
}
|
|
}
|
|
|
|
// Compute RMS
|
|
let mean_sq = sum_sq / (hidden_dim as f32);
|
|
let rms = sqrtf(mean_sq + eps);
|
|
let inv_rms = 1.0 / rms;
|
|
|
|
// Store RMS for backward pass if scratch is available
|
|
if let Some(rms_store) = rms_ptr {
|
|
unsafe {
|
|
*rms_store.add(i) = rms;
|
|
}
|
|
}
|
|
|
|
// Normalize and scale
|
|
for j in 0..hidden_dim {
|
|
unsafe {
|
|
let x_val = *x_ptr.add(row_offset + j);
|
|
let w_val = *weight_ptr.add(j);
|
|
*y_ptr.add(row_offset + j) = (x_val * inv_rms) * w_val;
|
|
}
|
|
}
|
|
}
|
|
|
|
OK
|
|
}
|
|
|
|
/// Execute RMSNorm backward pass
|
|
///
|
|
/// Computes gradients for x and weight given gradient of output.
|
|
///
|
|
/// # Memory Layout (for backward)
|
|
///
|
|
/// Input A (grad_y): [num_elements, hidden_dim] as f32
|
|
/// Input B (x): Original input (needed for gradient)
|
|
/// Output (grad_x): [num_elements, hidden_dim] as f32
|
|
/// Scratch: [hidden_dim] as f32 (for grad_weight accumulation)
|
|
/// Params: Contains weight pointer separately
|
|
#[no_mangle]
|
|
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
|
|
let desc = unsafe { &*desc_ptr };
|
|
|
|
if desc.input_a_size == 0 {
|
|
return INVALID_INPUT;
|
|
}
|
|
if desc.output_size == 0 {
|
|
return INVALID_OUTPUT;
|
|
}
|
|
if desc.params_size < core::mem::size_of::<RmsNormParams>() as u32 {
|
|
return INVALID_PARAMS;
|
|
}
|
|
|
|
let memory_base = 0usize as *mut u8;
|
|
|
|
let params = unsafe {
|
|
&*(memory_base.add(desc.params_offset as usize) as *const RmsNormParams)
|
|
};
|
|
|
|
let hidden_dim = params.hidden_dim as usize;
|
|
let num_elements = params.num_elements as usize;
|
|
let eps = params.eps;
|
|
|
|
// Note: For a complete backward pass, we would need:
|
|
// - grad_y: gradient from upstream
|
|
// - x: original input
|
|
// - weight: scale parameters
|
|
// - Output: grad_x
|
|
// - Accumulate: grad_weight
|
|
|
|
// This is a simplified implementation showing the structure
|
|
let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
|
let x_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
|
let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
|
|
|
// For each element
|
|
for i in 0..num_elements {
|
|
let row_offset = i * hidden_dim;
|
|
|
|
// Recompute RMS (or load from scratch if saved during forward)
|
|
let mut sum_sq: f32 = 0.0;
|
|
for j in 0..hidden_dim {
|
|
unsafe {
|
|
let val = *x_ptr.add(row_offset + j);
|
|
sum_sq += val * val;
|
|
}
|
|
}
|
|
let mean_sq = sum_sq / (hidden_dim as f32);
|
|
let rms = sqrtf(mean_sq + eps);
|
|
let inv_rms = 1.0 / rms;
|
|
let inv_rms_cubed = inv_rms * inv_rms * inv_rms;
|
|
|
|
// Compute grad_norm_x = grad_y * weight
|
|
// Then grad_x = inv_rms * grad_norm_x - inv_rms^3 * x * mean(x * grad_norm_x)
|
|
// This is the chain rule applied to RMSNorm
|
|
|
|
// First pass: compute sum(x * grad_y) for this row
|
|
let mut sum_x_grad: f32 = 0.0;
|
|
for j in 0..hidden_dim {
|
|
unsafe {
|
|
let x_val = *x_ptr.add(row_offset + j);
|
|
let gy_val = *grad_y_ptr.add(row_offset + j);
|
|
sum_x_grad += x_val * gy_val;
|
|
}
|
|
}
|
|
let mean_x_grad = sum_x_grad / (hidden_dim as f32);
|
|
|
|
// Second pass: compute grad_x
|
|
for j in 0..hidden_dim {
|
|
unsafe {
|
|
let x_val = *x_ptr.add(row_offset + j);
|
|
let gy_val = *grad_y_ptr.add(row_offset + j);
|
|
|
|
// Simplified gradient (without weight consideration for this demo)
|
|
let grad = inv_rms * gy_val - inv_rms_cubed * x_val * mean_x_grad;
|
|
*grad_x_ptr.add(row_offset + j) = grad;
|
|
}
|
|
}
|
|
}
|
|
|
|
OK
|
|
}
|
|
|
|
/// Kernel info structure
///
/// Filled in by `kernel_info` so the host can discover the kernel's name,
/// version, and capabilities. `#[repr(C)]` layout is ABI with the host.
#[repr(C)]
pub struct KernelInfo {
    pub name_ptr: *const u8, // points at a NUL-terminated byte string
    pub name_len: u32,       // length WITHOUT the trailing NUL
    pub version_major: u16,
    pub version_minor: u16,
    pub version_patch: u16,
    pub supports_backward: bool,
}
|
|
|
|
// NUL-terminated so a host reading C strings can use the pointer directly;
// kernel_info reports the length without the terminator.
static KERNEL_NAME: &[u8] = b"rmsnorm_f32\0";
|
|
|
|
/// Get kernel metadata
|
|
#[no_mangle]
|
|
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
|
|
if info_ptr.is_null() {
|
|
return INVALID_PARAMS;
|
|
}
|
|
|
|
unsafe {
|
|
(*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
|
|
(*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1;
|
|
(*info_ptr).version_major = 1;
|
|
(*info_ptr).version_minor = 0;
|
|
(*info_ptr).version_patch = 0;
|
|
(*info_ptr).supports_backward = true;
|
|
}
|
|
|
|
OK
|
|
}
|
|
|
|
/// Cleanup kernel resources
///
/// Nothing to release (this kernel makes no allocations); always succeeds.
#[no_mangle]
pub extern "C" fn kernel_cleanup() -> i32 {
    OK
}
|
|
|
|
/// Minimal square-root implementation for `no_std` (no libm available).
///
/// Uses the classic bit-manipulation initial guess (halving the IEEE-754
/// exponent field, with a magic constant correcting the bias) followed by
/// three Newton-Raphson refinement steps, which is plenty for f32 precision.
///
/// Edge cases: returns 0.0 for zero and negative inputs; NaN falls through
/// the `x <= 0.0` test and propagates through the iterations, yielding NaN.
fn sqrtf(x: f32) -> f32 {
    if x <= 0.0 {
        return 0.0;
    }

    // Initial guess via bit manipulation. (The previous version had a dead
    // `let mut guess = x;` store that was immediately overwritten here.)
    let bits = 0x1fbd1df5_u32 + (x.to_bits() >> 1);
    let mut guess = f32::from_bits(bits);

    // Newton-Raphson: guess_{n+1} = (guess_n + x / guess_n) / 2
    for _ in 0..3 {
        guess = 0.5 * (guess + x / guess);
    }

    guess
}
|