Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
235
crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
Normal file
235
crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
Normal file
@@ -0,0 +1,235 @@
|
||||
// RMSNorm and LayerNorm Shaders for WebGPU WASM
|
||||
//
|
||||
// Implements normalization layers used in transformer architectures:
|
||||
// - RMSNorm: Used in Llama, Mistral (no mean subtraction)
|
||||
// - LayerNorm: Standard transformer normalization
|
||||
//
|
||||
// RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight
|
||||
// LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias
|
||||
|
||||
const WARP_SIZE: u32 = 32u;
|
||||
const MAX_DIM: u32 = 8192u;
|
||||
|
||||
struct NormUniforms {
|
||||
hidden_dim: u32,
|
||||
batch_size: u32,
|
||||
eps: f32,
|
||||
_pad: u32,
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var<storage, read> input: array<f32>;
|
||||
@group(0) @binding(1) var<storage, read> weight: array<f32>;
|
||||
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
|
||||
@group(0) @binding(3) var<uniform> uniforms: NormUniforms;
|
||||
|
||||
// Shared memory for parallel reduction
|
||||
var<workgroup> partial_sums: array<f32, 256>;
|
||||
|
||||
// RMSNorm: y = x * rsqrt(mean(x^2) + eps) * weight
|
||||
@compute @workgroup_size(256, 1, 1)
|
||||
fn rms_norm(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) group_id: vec3<u32>,
|
||||
) {
|
||||
let hidden_dim = uniforms.hidden_dim;
|
||||
let eps = uniforms.eps;
|
||||
|
||||
let batch_idx = group_id.x;
|
||||
let thread_id = local_id.x;
|
||||
let offset = batch_idx * hidden_dim;
|
||||
|
||||
// Each thread computes partial sum of squares
|
||||
var thread_sum = 0.0f;
|
||||
let elements_per_thread = (hidden_dim + 255u) / 256u;
|
||||
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
let x = input[offset + idx];
|
||||
thread_sum += x * x;
|
||||
}
|
||||
}
|
||||
|
||||
// Store partial sum
|
||||
partial_sums[thread_id] = thread_sum;
|
||||
workgroupBarrier();
|
||||
|
||||
// Parallel reduction for sum of squares
|
||||
for (var stride = 128u; stride > 0u; stride >>= 1u) {
|
||||
if (thread_id < stride) {
|
||||
partial_sums[thread_id] += partial_sums[thread_id + stride];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// Compute RMS scale factor
|
||||
let mean_sq = partial_sums[0] / f32(hidden_dim);
|
||||
let rms_scale = 1.0f / sqrt(mean_sq + eps);
|
||||
workgroupBarrier();
|
||||
|
||||
// Apply normalization and weight
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
let x = input[offset + idx];
|
||||
output[offset + idx] = x * rms_scale * weight[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fused RMSNorm + Residual: y = (x + residual) * rsqrt(mean((x+res)^2) + eps) * weight
|
||||
@compute @workgroup_size(256, 1, 1)
|
||||
fn rms_norm_residual(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) group_id: vec3<u32>,
|
||||
) {
|
||||
let hidden_dim = uniforms.hidden_dim;
|
||||
let eps = uniforms.eps;
|
||||
|
||||
let batch_idx = group_id.x;
|
||||
let thread_id = local_id.x;
|
||||
let offset = batch_idx * hidden_dim;
|
||||
|
||||
// Compute partial sum of (x + residual)^2
|
||||
var thread_sum = 0.0f;
|
||||
let elements_per_thread = (hidden_dim + 255u) / 256u;
|
||||
|
||||
// First pass: compute residual sum and store in shared for reduction
|
||||
// Note: residual is passed in output buffer for in-place update
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
let x = input[offset + idx] + output[offset + idx]; // x + residual
|
||||
thread_sum += x * x;
|
||||
}
|
||||
}
|
||||
|
||||
partial_sums[thread_id] = thread_sum;
|
||||
workgroupBarrier();
|
||||
|
||||
// Parallel reduction
|
||||
for (var stride = 128u; stride > 0u; stride >>= 1u) {
|
||||
if (thread_id < stride) {
|
||||
partial_sums[thread_id] += partial_sums[thread_id + stride];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let mean_sq = partial_sums[0] / f32(hidden_dim);
|
||||
let rms_scale = 1.0f / sqrt(mean_sq + eps);
|
||||
workgroupBarrier();
|
||||
|
||||
// Apply normalization
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
let x = input[offset + idx] + output[offset + idx];
|
||||
output[offset + idx] = x * rms_scale * weight[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Standard LayerNorm with bias
|
||||
@group(0) @binding(4) var<storage, read> bias: array<f32>;
|
||||
|
||||
@compute @workgroup_size(256, 1, 1)
|
||||
fn layer_norm(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) group_id: vec3<u32>,
|
||||
) {
|
||||
let hidden_dim = uniforms.hidden_dim;
|
||||
let eps = uniforms.eps;
|
||||
|
||||
let batch_idx = group_id.x;
|
||||
let thread_id = local_id.x;
|
||||
let offset = batch_idx * hidden_dim;
|
||||
|
||||
let elements_per_thread = (hidden_dim + 255u) / 256u;
|
||||
|
||||
// First pass: compute mean
|
||||
var thread_sum = 0.0f;
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
thread_sum += input[offset + idx];
|
||||
}
|
||||
}
|
||||
|
||||
partial_sums[thread_id] = thread_sum;
|
||||
workgroupBarrier();
|
||||
|
||||
for (var stride = 128u; stride > 0u; stride >>= 1u) {
|
||||
if (thread_id < stride) {
|
||||
partial_sums[thread_id] += partial_sums[thread_id + stride];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let mean = partial_sums[0] / f32(hidden_dim);
|
||||
workgroupBarrier();
|
||||
|
||||
// Second pass: compute variance
|
||||
var thread_var = 0.0f;
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
let diff = input[offset + idx] - mean;
|
||||
thread_var += diff * diff;
|
||||
}
|
||||
}
|
||||
|
||||
partial_sums[thread_id] = thread_var;
|
||||
workgroupBarrier();
|
||||
|
||||
for (var stride = 128u; stride > 0u; stride >>= 1u) {
|
||||
if (thread_id < stride) {
|
||||
partial_sums[thread_id] += partial_sums[thread_id + stride];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let variance = partial_sums[0] / f32(hidden_dim);
|
||||
let inv_std = 1.0f / sqrt(variance + eps);
|
||||
workgroupBarrier();
|
||||
|
||||
// Third pass: normalize and apply affine transform
|
||||
for (var i = 0u; i < elements_per_thread; i++) {
|
||||
let idx = thread_id + i * 256u;
|
||||
if (idx < hidden_dim) {
|
||||
let x = input[offset + idx];
|
||||
output[offset + idx] = (x - mean) * inv_std * weight[idx] + bias[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast RMSNorm for small hidden dimensions (direct reduction)
|
||||
@compute @workgroup_size(128, 1, 1)
|
||||
fn rms_norm_small(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) group_id: vec3<u32>,
|
||||
) {
|
||||
let hidden_dim = uniforms.hidden_dim;
|
||||
let eps = uniforms.eps;
|
||||
|
||||
let batch_idx = group_id.x;
|
||||
let thread_id = local_id.x;
|
||||
let offset = batch_idx * hidden_dim;
|
||||
|
||||
// For small hidden_dim (<= 128), direct computation
|
||||
if (thread_id < hidden_dim) {
|
||||
// Compute sum of squares (all threads contribute)
|
||||
var sum_sq = 0.0f;
|
||||
for (var i = 0u; i < hidden_dim; i++) {
|
||||
let x = input[offset + i];
|
||||
sum_sq += x * x;
|
||||
}
|
||||
|
||||
let rms = sqrt(sum_sq / f32(hidden_dim) + eps);
|
||||
let x = input[offset + thread_id];
|
||||
output[offset + thread_id] = x / rms * weight[thread_id];
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user