Files
wifi-densepose/crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

236 lines
7.1 KiB
WebGPU Shading Language

// RMSNorm and LayerNorm Shaders for WebGPU WASM
//
// Implements normalization layers used in transformer architectures:
// - RMSNorm: Used in Llama, Mistral (no mean subtraction)
// - LayerNorm: Standard transformer normalization
//
// RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight
// LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias
// Sizing constants.
// NOTE(review): neither constant is referenced anywhere in this file —
// they look like documentation of host-side limits (SIMD width, max row
// width). Confirm against the host code before removing.
const WARP_SIZE: u32 = 32u;
const MAX_DIM: u32 = 8192u;
// Per-dispatch parameters shared by every normalization entry point below.
struct NormUniforms {
// Row width: number of elements normalized together.
hidden_dim: u32,
// Number of rows. NOTE(review): not read by any shader in this file —
// presumably the host uses it to size the dispatch grid; confirm with caller.
batch_size: u32,
// Numerical-stability epsilon added under the sqrt.
eps: f32,
// Explicit padding to a 16-byte uniform block size.
_pad: u32,
}
// Bind group 0 layout shared by all entry points in this module:
//   binding 0: input activations, indexed as row * hidden_dim + col
//   binding 1: per-channel scale (gamma), hidden_dim elements
//   binding 2: output buffer — rms_norm_residual also READS the residual
//              from here and updates it in place
//   binding 3: NormUniforms (hidden_dim, batch_size, eps)
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: NormUniforms;
// Shared memory for parallel reduction
// One slot per thread of the 256-wide workgroups used by the entry
// points below (rms_norm_small does not use it).
var<workgroup> partial_sums: array<f32, 256>;
// RMSNorm: y = x * rsqrt(mean(x^2) + eps) * weight
// RMSNorm: y = x * inverseSqrt(mean(x^2) + eps) * weight
//
// One workgroup per batch row. The 256 threads each accumulate a strided
// partial sum of squares, a shared-memory tree reduction folds them into
// partial_sums[0], and every thread then rescales its strided slice.
@compute @workgroup_size(256, 1, 1)
fn rms_norm(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;
    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * hidden_dim;

    // Each thread computes a partial sum of squares over a strided slice
    // (indices thread_id, thread_id + 256, thread_id + 512, ...).
    var thread_sum = 0.0f;
    let elements_per_thread = (hidden_dim + 255u) / 256u;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let x = input[offset + idx];
            thread_sum += x * x;
        }
    }
    partial_sums[thread_id] = thread_sum;
    workgroupBarrier();

    // Tree reduction into partial_sums[0]. The barrier sits in uniform
    // control flow (outside the thread_id guard), as WGSL requires.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }

    // All threads read the reduced value; inverseSqrt is the dedicated
    // WGSL builtin for 1/sqrt. No further barrier is needed because
    // partial_sums is not written again in this entry point.
    let mean_sq = partial_sums[0] / f32(hidden_dim);
    let rms_scale = inverseSqrt(mean_sq + eps);

    // Normalize and apply the learned per-channel weight.
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let x = input[offset + idx];
            output[offset + idx] = x * rms_scale * weight[idx];
        }
    }
}
// Fused RMSNorm + Residual: y = (x + residual) * rsqrt(mean((x+res)^2) + eps) * weight
// Fused residual add + RMSNorm:
//   y = (x + r) * inverseSqrt(mean((x + r)^2) + eps) * weight
//
// The residual r arrives in `output` and is overwritten in place with the
// normalized result. Each element index is owned by exactly one thread
// (idx = thread_id + i * 256), so the read-then-write on `output` is
// race-free; note the pre-norm sum x + r is NOT preserved for the caller.
@compute @workgroup_size(256, 1, 1)
fn rms_norm_residual(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;
    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * hidden_dim;

    // Partial sum of (x + residual)^2 over this thread's strided slice.
    var thread_sum = 0.0f;
    let elements_per_thread = (hidden_dim + 255u) / 256u;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let x = input[offset + idx] + output[offset + idx]; // x + residual
            thread_sum += x * x;
        }
    }
    partial_sums[thread_id] = thread_sum;
    workgroupBarrier();

    // Shared-memory tree reduction into partial_sums[0].
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }

    // inverseSqrt is the dedicated WGSL builtin for 1/sqrt. partial_sums
    // is not written again below, so no trailing barrier is required.
    let mean_sq = partial_sums[0] / f32(hidden_dim);
    let rms_scale = inverseSqrt(mean_sq + eps);

    // Recompute x + residual (cheaper than staging it in shared memory)
    // and write the normalized, weighted result back in place.
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let x = input[offset + idx] + output[offset + idx];
            output[offset + idx] = x * rms_scale * weight[idx];
        }
    }
}
// Standard LayerNorm with bias
// LayerNorm only: per-channel shift (beta), hidden_dim elements.
@group(0) @binding(4) var<storage, read> bias: array<f32>;

// Standard LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias
//
// One workgroup per batch row; three passes: (1) mean via tree reduction,
// (2) variance via tree reduction, (3) affine transform. Uses the biased
// 1/N variance estimator, as is conventional for LayerNorm.
@compute @workgroup_size(256, 1, 1)
fn layer_norm(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;
    let batch_idx = group_id.x;
    let thread_id = local_id.x;
    let offset = batch_idx * hidden_dim;
    let elements_per_thread = (hidden_dim + 255u) / 256u;

    // First pass: per-thread strided partial sums, reduced to the mean.
    var thread_sum = 0.0f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            thread_sum += input[offset + idx];
        }
    }
    partial_sums[thread_id] = thread_sum;
    workgroupBarrier();
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }
    let mean = partial_sums[0] / f32(hidden_dim);
    // This barrier is required: the variance pass below rewrites
    // partial_sums, which every thread just read.
    workgroupBarrier();

    // Second pass: sum of squared deviations, reduced to the variance.
    var thread_var = 0.0f;
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let diff = input[offset + idx] - mean;
            thread_var += diff * diff;
        }
    }
    partial_sums[thread_id] = thread_var;
    workgroupBarrier();
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (thread_id < stride) {
            partial_sums[thread_id] += partial_sums[thread_id + stride];
        }
        workgroupBarrier();
    }

    // inverseSqrt is the dedicated WGSL builtin for 1/sqrt. partial_sums
    // is not touched again, so no trailing barrier is required.
    let variance = partial_sums[0] / f32(hidden_dim);
    let inv_std = inverseSqrt(variance + eps);

    // Third pass: normalize and apply the affine transform (gamma, beta).
    for (var i = 0u; i < elements_per_thread; i++) {
        let idx = thread_id + i * 256u;
        if (idx < hidden_dim) {
            let x = input[offset + idx];
            output[offset + idx] = (x - mean) * inv_std * weight[idx] + bias[idx];
        }
    }
}
// Fast RMSNorm for small hidden dimensions (direct reduction)
// Fast-path RMSNorm for small hidden dimensions (direct, no shared memory).
//
// Intended for hidden_dim <= 128: every active thread serially recomputes
// the full row sum of squares and then normalizes the single element it
// owns. Redundant work, but barrier-free and cheap when the row is small.
// Rows wider than 128 would be left partially unwritten.
@compute @workgroup_size(128, 1, 1)
fn rms_norm_small(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.hidden_dim;
    let lane = local_id.x;
    let row_base = group_id.x * dim;

    // Threads beyond the row width have nothing to do (no barriers follow,
    // so an early return is safe here).
    if (lane >= dim) {
        return;
    }

    // Serial sum of squares over the whole row.
    var sum_sq = 0.0f;
    var j = 0u;
    while (j < dim) {
        let v = input[row_base + j];
        sum_sq += v * v;
        j++;
    }

    // Same operation order as the tree-reduction kernels' math:
    // divide by sqrt(mean + eps), then scale by the channel weight.
    let rms = sqrt(sum_sq / f32(dim) + uniforms.eps);
    let x = input[row_base + lane];
    output[row_base + lane] = x / rms * weight[lane];
}