Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,705 @@
//! GPU Buffer Management
//!
//! Provides efficient GPU buffer allocation, management, and data transfer
//! for the coherence engine. Implements a buffer pool for reuse and
//! minimizes CPU-GPU synchronization overhead.
use super::error::{GpuError, GpuResult};
use bytemuck::{Pod, Zeroable};
use std::collections::HashMap;
use std::sync::Arc;
use wgpu::{Buffer, BufferDescriptor, BufferUsages, Device, Queue};
/// Buffer usage flags for coherence computation.
///
/// Each variant selects a set of `wgpu::BufferUsages` via
/// `GpuBuffer::to_wgpu_usage`: all storage-class variants share
/// STORAGE | COPY_SRC | COPY_DST; `Uniforms` and `Staging` have their own
/// mappings. Used together with the size bucket as a pool key, so buffers
/// are only reused for the same usage class.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BufferUsage {
    /// Storage buffer for node states
    NodeStates,
    /// Storage buffer for edge data
    EdgeData,
    /// Storage buffer for restriction maps
    RestrictionMaps,
    /// Storage buffer for residuals
    Residuals,
    /// Storage buffer for energy values
    Energies,
    /// Storage buffer for attention weights
    AttentionWeights,
    /// Storage buffer for routing decisions
    RoutingDecisions,
    /// Uniform buffer for shader parameters
    Uniforms,
    /// Staging buffer for CPU readback
    Staging,
}
/// GPU-side node state representation.
///
/// `#[repr(C)]` + `Pod`/`Zeroable` so values can be byte-cast straight into
/// a storage buffer with `bytemuck`. The state array has a fixed capacity of
/// 128 floats; only the first `dim` entries are meaningful, the rest is
/// padding. NOTE(review): the inline comment suggested dynamic sizing, but
/// the layout here is fixed — any WGSL struct reading this buffer must
/// match the 128-element capacity.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuNodeState {
    /// Flattened state vector, fixed capacity of 128 floats; entries past
    /// `dim` are padding.
    pub state: [f32; 128],
    /// Actual dimension of the state vector (number of valid entries in `state`)
    pub dim: u32,
    /// Node index
    pub index: u32,
    /// Padding for alignment
    pub _padding: [u32; 2],
}
/// GPU-side edge representation.
///
/// `#[repr(C)]` + `Pod` for direct upload; 8 × 4-byte fields give the
/// 32-byte size/4-byte alignment asserted in this module's tests.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuEdge {
    /// Source node index
    pub source_idx: u32,
    /// Target node index
    pub target_idx: u32,
    /// Edge weight
    pub weight: f32,
    /// Restriction map index for source
    pub rho_source_idx: u32,
    /// Restriction map index for target
    pub rho_target_idx: u32,
    /// Output dimension of restriction maps
    pub comparison_dim: u32,
    /// Padding for alignment
    pub _padding: [u32; 2],
}
/// GPU-side restriction map (dense matrix stored row-major).
///
/// Matrix elements live in a separate shared data buffer; this header only
/// records where (`data_offset`) and how many (`data_len`). `#[repr(C)]` +
/// `Pod` keeps the 32-byte layout the tests assert.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuRestrictionMap {
    /// Matrix type: 0=identity, 1=diagonal, 2=projection, 3=dense
    pub map_type: u32,
    /// Input dimension
    pub input_dim: u32,
    /// Output dimension
    pub output_dim: u32,
    /// Offset into the shared data buffer (element index, presumably — TODO confirm against shader)
    pub data_offset: u32,
    /// Number of elements in data
    pub data_len: u32,
    /// Padding for alignment
    pub _padding: [u32; 3],
}
/// GPU-side shader parameters (uniform buffer contents).
///
/// 8 × 4-byte fields: 32 bytes, 4-byte alignment (asserted in tests),
/// matching a WGSL uniform struct field-for-field.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuParams {
    /// Number of edges
    pub num_edges: u32,
    /// Number of nodes
    pub num_nodes: u32,
    /// State dimension
    pub state_dim: u32,
    /// Beta parameter for attention
    pub beta: f32,
    /// Lane 0 threshold (reflex)
    pub threshold_lane0: f32,
    /// Lane 1 threshold (retrieval)
    pub threshold_lane1: f32,
    /// Lane 2 threshold (heavy)
    pub threshold_lane2: f32,
    /// Flag to control residual storage (0 = skip, 1 = store).
    /// When computing energy only, skip storage for better performance.
    pub store_residuals: u32,
}
/// Wrapper around a wgpu Buffer with metadata.
///
/// Records the allocation size, the coherence-level usage class, and a
/// debug label so the pool/manager can key and reuse allocations.
pub struct GpuBuffer {
    /// The underlying wgpu buffer
    pub buffer: Buffer,
    /// Size in bytes (allocation size, which may exceed the bytes written)
    pub size: usize,
    /// Usage class; determines the wgpu usage flags at creation time
    pub usage: BufferUsage,
    /// Label for debugging
    pub label: String,
}
impl GpuBuffer {
    /// Allocate an uninitialized GPU buffer of `size` bytes.
    ///
    /// The wgpu usage flags are derived from `usage` via `to_wgpu_usage`.
    ///
    /// # Errors
    ///
    /// Currently infallible, but returns `GpuResult` for interface stability.
    pub fn new(
        device: &Device,
        size: usize,
        usage: BufferUsage,
        label: impl Into<String>,
    ) -> GpuResult<Self> {
        let label = label.into();
        let buffer = device.create_buffer(&BufferDescriptor {
            label: Some(&label),
            size: size as u64,
            usage: Self::to_wgpu_usage(usage),
            mapped_at_creation: false,
        });
        Ok(Self {
            buffer,
            size,
            usage,
            label,
        })
    }

    /// Allocate a GPU buffer sized exactly to `data` and upload it.
    ///
    /// Delegates to [`GpuBuffer::new`], then enqueues the upload through
    /// `queue.write_buffer` (the copy happens at queue submission).
    ///
    /// # Errors
    ///
    /// Propagates any error from buffer creation.
    pub fn new_with_data<T: Pod>(
        device: &Device,
        queue: &Queue,
        data: &[T],
        usage: BufferUsage,
        label: impl Into<String>,
    ) -> GpuResult<Self> {
        let bytes: &[u8] = bytemuck::cast_slice(data);
        let created = Self::new(device, bytes.len(), usage, label)?;
        queue.write_buffer(&created.buffer, 0, bytes);
        Ok(created)
    }

    /// Write `data` into the buffer starting at offset 0.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::BufferSizeMismatch` when `data` is larger than the
    /// buffer; writing fewer bytes than the buffer holds is allowed.
    pub fn write<T: Pod>(&self, queue: &Queue, data: &[T]) -> GpuResult<()> {
        let bytes: &[u8] = bytemuck::cast_slice(data);
        if bytes.len() <= self.size {
            queue.write_buffer(&self.buffer, 0, bytes);
            Ok(())
        } else {
            Err(GpuError::BufferSizeMismatch {
                expected: self.size,
                actual: bytes.len(),
            })
        }
    }

    /// Map the coherence-level usage class onto wgpu usage flags.
    fn to_wgpu_usage(usage: BufferUsage) -> BufferUsages {
        match usage {
            // Uniform parameters: written by the CPU, read by shaders.
            BufferUsage::Uniforms => BufferUsages::UNIFORM | BufferUsages::COPY_DST,
            // Staging: CPU-mappable readback target.
            BufferUsage::Staging => BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            // All storage-class buffers share the same flag set.
            BufferUsage::NodeStates
            | BufferUsage::EdgeData
            | BufferUsage::RestrictionMaps
            | BufferUsage::Residuals
            | BufferUsage::Energies
            | BufferUsage::AttentionWeights
            | BufferUsage::RoutingDecisions => {
                BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST
            }
        }
    }
}
/// Buffer manager for efficient allocation and reuse.
///
/// Buffers move between two states: `active` (checked out under a string
/// label) and `pool` (idle, keyed by usage class and power-of-two size
/// bucket, awaiting reuse). Not thread-safe; all methods take `&mut self`.
pub struct GpuBufferManager {
    device: Arc<Device>,
    queue: Arc<Queue>,
    /// Buffer pool keyed by (usage, size_bucket); see `size_bucket` for bucketing
    pool: HashMap<(BufferUsage, usize), Vec<GpuBuffer>>,
    /// Active buffers currently in use, keyed by caller-supplied label
    active: HashMap<String, GpuBuffer>,
}
impl GpuBufferManager {
    /// Create a new buffer manager that allocates from `device` and uploads via `queue`.
    pub fn new(device: Arc<Device>, queue: Arc<Queue>) -> Self {
        Self {
            device,
            queue,
            pool: HashMap::new(),
            active: HashMap::new(),
        }
    }

    /// Allocate or reuse a buffer and register it under `label`.
    ///
    /// The actual buffer size is `size` rounded up to a power-of-two bucket
    /// (256-byte minimum) so allocations can be reused across requests of
    /// similar sizes. If a buffer was already active under `label`, it is
    /// first returned to the pool (the previous implementation silently
    /// dropped it, losing the allocation).
    ///
    /// # Errors
    ///
    /// Propagates any error from buffer creation.
    pub fn allocate(
        &mut self,
        size: usize,
        usage: BufferUsage,
        label: impl Into<String>,
    ) -> GpuResult<&GpuBuffer> {
        let label = label.into();
        let bucket = Self::size_bucket(size);
        // Recycle any buffer previously registered under this label.
        self.release(&label);
        // Reuse a pooled buffer of the same (usage, bucket), or allocate anew.
        let buffer = match self.pool.get_mut(&(usage, bucket)).and_then(|v| v.pop()) {
            Some(buffer) => buffer,
            None => GpuBuffer::new(&self.device, bucket, usage, &label)?,
        };
        // `release` above guarantees the slot is vacant, so `or_insert`
        // always inserts; `entry` also avoids the insert-then-get double
        // lookup of the original implementation.
        Ok(self.active.entry(label).or_insert(buffer))
    }

    /// Allocate or reuse a buffer, upload `data`, and register it under `label`.
    ///
    /// The buffer is allocated at the full bucket size rather than `data`'s
    /// exact byte length. The original implementation created fresh buffers
    /// at exact size while pooling them under the bucket key, so a recycled
    /// buffer could be smaller than its bucket and a later, larger write in
    /// the same bucket would fail with `BufferSizeMismatch`. Bucket-sizing
    /// restores the invariant that every pooled buffer is at least
    /// bucket-sized (matching `allocate`).
    ///
    /// # Errors
    ///
    /// Propagates buffer-creation and write errors.
    pub fn allocate_with_data<T: Pod>(
        &mut self,
        data: &[T],
        usage: BufferUsage,
        label: impl Into<String>,
    ) -> GpuResult<&GpuBuffer> {
        let label = label.into();
        let size = std::mem::size_of_val(data);
        let bucket = Self::size_bucket(size);
        self.release(&label);
        let buffer = match self.pool.get_mut(&(usage, bucket)).and_then(|v| v.pop()) {
            Some(buffer) => buffer,
            None => GpuBuffer::new(&self.device, bucket, usage, &label)?,
        };
        // Upload after (re)allocation; cannot overflow since data fits the bucket.
        buffer.write(&self.queue, data)?;
        Ok(self.active.entry(label).or_insert(buffer))
    }

    /// Get an active buffer by label.
    pub fn get(&self, label: &str) -> Option<&GpuBuffer> {
        self.active.get(label)
    }

    /// Release a buffer back to the pool for reuse.
    ///
    /// No-op when `label` is not active. The pool itself is unbounded; call
    /// `clear` to free pooled allocations.
    pub fn release(&mut self, label: &str) {
        if let Some(buffer) = self.active.remove(label) {
            let bucket = Self::size_bucket(buffer.size);
            self.pool
                .entry((buffer.usage, bucket))
                .or_default()
                .push(buffer);
        }
    }

    /// Release all active buffers back to the pool.
    pub fn release_all(&mut self) {
        let labels: Vec<_> = self.active.keys().cloned().collect();
        for label in labels {
            self.release(&label);
        }
    }

    /// Clear all buffers (both pool and active), dropping the GPU allocations.
    pub fn clear(&mut self) {
        self.active.clear();
        self.pool.clear();
    }

    /// Round size up to the nearest power of two (256-byte floor) for efficient reuse.
    fn size_bucket(size: usize) -> usize {
        const MIN_BUCKET: usize = 256;
        if size <= MIN_BUCKET {
            MIN_BUCKET
        } else {
            size.next_power_of_two()
        }
    }

    /// Get the underlying device.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Get the underlying queue.
    pub fn queue(&self) -> &Queue {
        &self.queue
    }
}
// ============================================================================
// BUFFER USAGE FLAGS (for pipeline.rs compatibility)
// ============================================================================
/// Buffer usage flags for flexible configuration.
///
/// A plain-struct mirror of `wgpu::BufferUsages` bits (see `to_wgpu`),
/// used as part of pool keys, hence the `Eq`/`Hash` derives. Construct via
/// the named presets (`storage_readonly`, `uniform`, …) rather than by hand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferUsageFlags {
    /// Can be read from GPU (STORAGE)
    pub storage_read: bool,
    /// Can be written to by GPU (STORAGE)
    pub storage_write: bool,
    /// Can be used as uniform buffer
    pub uniform: bool,
    /// Can be mapped for CPU read
    pub map_read: bool,
    /// Can be mapped for CPU write
    pub map_write: bool,
    /// Can be used as copy source
    pub copy_src: bool,
    /// Can be used as copy destination
    pub copy_dst: bool,
    /// Can be used for indirect dispatch
    pub indirect: bool,
}
impl BufferUsageFlags {
    /// Every flag disabled; base for the named presets below.
    const NONE: Self = Self {
        storage_read: false,
        storage_write: false,
        uniform: false,
        map_read: false,
        map_write: false,
        copy_src: false,
        copy_dst: false,
        indirect: false,
    };

    /// Storage buffer (read-only): GPU-readable storage, copyable both ways.
    pub const fn storage_readonly() -> Self {
        Self {
            storage_read: true,
            copy_src: true,
            copy_dst: true,
            ..Self::NONE
        }
    }

    /// Storage buffer (read-write): read-only preset plus GPU writes.
    pub const fn storage_readwrite() -> Self {
        Self {
            storage_write: true,
            ..Self::storage_readonly()
        }
    }

    /// Uniform buffer: shader-visible parameters, CPU-uploadable.
    pub const fn uniform() -> Self {
        Self {
            uniform: true,
            copy_dst: true,
            ..Self::NONE
        }
    }

    /// Staging buffer for read-back: CPU-mappable copy destination.
    pub const fn staging_read() -> Self {
        Self {
            map_read: true,
            copy_dst: true,
            ..Self::NONE
        }
    }

    /// Staging buffer for upload: CPU-mappable copy source.
    pub const fn staging_write() -> Self {
        Self {
            map_write: true,
            copy_src: true,
            ..Self::NONE
        }
    }

    /// Indirect dispatch buffer: read-write storage plus INDIRECT.
    pub const fn indirect() -> Self {
        Self {
            indirect: true,
            ..Self::storage_readwrite()
        }
    }

    /// Convert to wgpu buffer usages by OR-ing the bit for each enabled flag.
    ///
    /// Note that either storage flag (read or write) maps to the single
    /// `STORAGE` bit — wgpu does not distinguish them at this level.
    pub fn to_wgpu(&self) -> BufferUsages {
        let bits = [
            (self.storage_read || self.storage_write, BufferUsages::STORAGE),
            (self.uniform, BufferUsages::UNIFORM),
            (self.map_read, BufferUsages::MAP_READ),
            (self.map_write, BufferUsages::MAP_WRITE),
            (self.copy_src, BufferUsages::COPY_SRC),
            (self.copy_dst, BufferUsages::COPY_DST),
            (self.indirect, BufferUsages::INDIRECT),
        ];
        bits.into_iter()
            .filter(|&(enabled, _)| enabled)
            .fold(BufferUsages::empty(), |acc, (_, usage)| acc | usage)
    }
}
// ============================================================================
// BUFFER KEY AND POOL (for dispatch.rs compatibility)
// ============================================================================
/// Key for buffer pool lookups.
///
/// Pool entries in `GpuBufferPool` are grouped by exact byte size and usage
/// flags; two requests reuse each other's buffers only when both match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferKey {
    /// Buffer size in bytes (exact size, not bucketed)
    pub size: u64,
    /// Buffer usage flags
    pub usage: BufferUsageFlags,
}
impl BufferKey {
/// Create a new buffer key
pub fn new(size: u64, usage: BufferUsageFlags) -> Self {
Self { size, usage }
}
}
/// Buffer pool for reusing GPU allocations with DashMap for concurrent access.
///
/// Unlike `GpuBufferManager`, this pool is shareable across threads (interior
/// mutability via `dashmap`) and keys by exact size rather than bucket.
pub struct GpuBufferPool {
    device: Arc<Device>,
    /// Idle buffers grouped by (size, usage) key
    buffers: dashmap::DashMap<BufferKey, Vec<GpuBuffer>>,
    /// Maximum number of idle buffers retained per key; excess is dropped on release
    max_pool_size: usize,
}
impl GpuBufferPool {
    /// Create a pool with the module's default capacity.
    pub fn new(device: Arc<Device>) -> Self {
        Self::with_capacity(device, super::DEFAULT_POOL_CAPACITY)
    }

    /// Create a pool retaining at most `max_pool_size` idle buffers per key.
    pub fn with_capacity(device: Arc<Device>, max_pool_size: usize) -> Self {
        Self {
            device,
            buffers: dashmap::DashMap::new(),
            max_pool_size,
        }
    }

    /// Acquire a buffer from the pool or create a new one.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::BufferTooLarge` when `size` exceeds the module's
    /// `MAX_BUFFER_SIZE` limit.
    pub fn acquire(&self, size: u64, usage: BufferUsageFlags) -> GpuResult<GpuBuffer> {
        if size > super::MAX_BUFFER_SIZE {
            return Err(GpuError::BufferTooLarge {
                size,
                max: super::MAX_BUFFER_SIZE,
            });
        }
        // Fast path: pop an idle buffer with an identical (size, usage) key.
        let key = BufferKey::new(size, usage);
        let pooled = self.buffers.get_mut(&key).and_then(|mut v| v.pop());
        if let Some(buffer) = pooled {
            return Ok(buffer);
        }
        // Slow path: allocate a fresh buffer with the requested wgpu usages.
        let wgpu_buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("pooled_buffer"),
            size,
            usage: usage.to_wgpu(),
            mapped_at_creation: false,
        });
        // NOTE(review): the `usage` metadata stored below is a hard-coded
        // `Staging` placeholder and does not reflect the wgpu usages the
        // buffer was actually created with — callers must not rely on it.
        Ok(GpuBuffer {
            buffer: wgpu_buffer,
            size: size as usize,
            usage: BufferUsage::Staging, // Default usage type
            label: "pooled_buffer".to_string(),
        })
    }

    /// Return a buffer to the pool for reuse.
    ///
    /// Buffers beyond `max_pool_size` for a given key are simply dropped.
    pub fn release(&self, buffer: GpuBuffer) {
        // NOTE(review): the pool key is rebuilt with default read-write flags
        // rather than the buffer's real usages, so a buffer created with a
        // different usage set can later be handed out for a read-write
        // request — verify against callers before relying on pooled reuse.
        let key = BufferKey::new(buffer.size as u64, BufferUsageFlags::storage_readwrite());
        let mut idle = self.buffers.entry(key).or_insert_with(Vec::new);
        if idle.len() < self.max_pool_size {
            idle.push(buffer);
        }
    }

    /// Drop every pooled buffer.
    pub fn clear(&self) {
        self.buffers.clear();
    }

    /// Snapshot aggregate statistics about the pool.
    pub fn stats(&self) -> PoolStats {
        let (total_buffers, total_bytes) =
            self.buffers.iter().fold((0usize, 0u64), |(count, bytes), entry| {
                let n = entry.value().len();
                (count + n, bytes + entry.key().size * n as u64)
            });
        PoolStats {
            total_buffers,
            total_bytes,
            bucket_count: self.buffers.len(),
        }
    }
}
/// Statistics about the buffer pool, as returned by `GpuBufferPool::stats`.
///
/// Counts only *idle* (pooled) buffers; buffers currently handed out via
/// `acquire` are not tracked.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Total number of pooled buffers
    pub total_buffers: usize,
    /// Total bytes allocated in pool
    pub total_bytes: u64,
    /// Number of unique buffer configurations (distinct pool keys)
    pub bucket_count: usize,
}
// ============================================================================
// EXTENDED GPUBUFFER METHODS (for pipeline.rs compatibility)
// ============================================================================
impl GpuBuffer {
    /// Build a bind-group entry referencing this entire buffer at slot `binding`.
    pub fn binding(&self, binding: u32) -> wgpu::BindGroupEntry {
        wgpu::BindGroupEntry {
            binding,
            resource: self.buffer.as_entire_binding(),
        }
    }

    /// Borrow the underlying wgpu buffer.
    pub fn buffer(&self) -> &Buffer {
        &self.buffer
    }

    /// Create a new storage buffer with initial data (for dispatch compatibility).
    ///
    /// Both usage classes chosen here map to the same wgpu storage flags;
    /// the enum value only records intent.
    pub fn new_storage<T: Pod>(
        device: &Device,
        queue: &Queue,
        data: &[T],
        read_write: bool,
    ) -> GpuResult<Self> {
        let usage = if read_write {
            BufferUsage::Residuals
        } else {
            BufferUsage::NodeStates
        };
        Self::new_with_data(device, queue, data, usage, "storage_buffer")
    }

    /// Create a new uninitialized storage buffer for `count` elements of `T`.
    pub fn new_storage_uninit<T: Pod>(
        device: &Device,
        count: usize,
        read_write: bool,
    ) -> GpuResult<Self> {
        let usage = if read_write {
            BufferUsage::Residuals
        } else {
            BufferUsage::NodeStates
        };
        let byte_len = std::mem::size_of::<T>() * count;
        Self::new(device, byte_len, usage, "storage_buffer_uninit")
    }

    /// Create a uniform buffer initialized from a single value.
    pub fn new_uniform<T: Pod>(device: &Device, queue: &Queue, data: &T) -> GpuResult<Self> {
        let as_slice = std::slice::from_ref(data);
        Self::new_with_data(device, queue, as_slice, BufferUsage::Uniforms, "uniform_buffer")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_size_bucket() {
        // Sizes round up to a power-of-two bucket with a 256-byte floor.
        let cases = [(100, 256), (256, 256), (257, 512), (1000, 1024)];
        for (size, expected) in cases {
            assert_eq!(GpuBufferManager::size_bucket(size), expected);
        }
    }

    #[test]
    fn test_gpu_params_alignment() {
        // Ensure our GPU structs are properly aligned for wgpu.
        assert_eq!(std::mem::size_of::<GpuParams>(), 32);
        assert_eq!(std::mem::align_of::<GpuParams>(), 4);
    }

    #[test]
    fn test_gpu_edge_alignment() {
        assert_eq!(std::mem::size_of::<GpuEdge>(), 32);
        assert_eq!(std::mem::align_of::<GpuEdge>(), 4);
    }

    #[test]
    fn test_gpu_restriction_map_alignment() {
        assert_eq!(std::mem::size_of::<GpuRestrictionMap>(), 32);
        assert_eq!(std::mem::align_of::<GpuRestrictionMap>(), 4);
    }

    #[test]
    fn test_buffer_usage_flags() {
        let ro = BufferUsageFlags::storage_readonly();
        let rw = BufferUsageFlags::storage_readwrite();
        assert!(ro.storage_read && !ro.storage_write);
        assert!(rw.storage_read && rw.storage_write);
    }

    #[test]
    fn test_buffer_key_equality() {
        let key_for = |size| BufferKey::new(size, BufferUsageFlags::storage_readonly());
        assert_eq!(key_for(1024), key_for(1024));
        assert_ne!(key_for(1024), key_for(2048));
    }
}

View File

@@ -0,0 +1,290 @@
//! GPU device initialization and management.
//!
//! This module provides the core GPU device abstraction using wgpu,
//! handling adapter selection, device creation, and queue management.
use std::sync::Arc;
use tracing::{debug, info, warn};
use wgpu::{Adapter, Device, Instance, Queue};
use super::error::{GpuError, GpuResult};
/// Information about the GPU device.
///
/// A plain-data snapshot captured at device creation; values derive from the
/// adapter info and the limits used in the device request.
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Device name
    pub name: String,
    /// Vendor ID
    pub vendor: u32,
    /// Device ID
    pub device_id: u32,
    /// Device type (discrete, integrated, etc.), `Debug`-formatted
    pub device_type: String,
    /// Backend API (Vulkan, Metal, DX12, etc.), `Debug`-formatted
    pub backend: String,
    /// Maximum buffer size in bytes
    pub max_buffer_size: u64,
    /// Maximum compute workgroup size per dimension [x, y, z]
    pub max_workgroup_size: [u32; 3],
    /// Maximum compute workgroups per dimension [x, y, z]
    pub max_workgroups: [u32; 3],
    /// Maximum storage buffers per shader stage
    pub max_storage_buffers: u32,
}
/// GPU device wrapper providing access to wgpu resources.
///
/// Owns the full wgpu stack (instance → adapter → device/queue); the device
/// and queue are `Arc`-wrapped so subsystems can hold shared references.
pub struct GpuDevice {
    instance: Instance,
    adapter: Adapter,
    device: Arc<Device>,
    queue: Arc<Queue>,
    /// Snapshot of adapter/limit info taken at creation time
    info: GpuDeviceInfo,
}
impl GpuDevice {
    /// Create a new GPU device with default configuration.
    ///
    /// This will:
    /// 1. Create a wgpu instance with all available backends
    /// 2. Request a high-performance adapter
    /// 3. Create the device and queue
    ///
    /// # Errors
    ///
    /// Returns `GpuError::NoAdapter` if no suitable GPU is found.
    /// Returns `GpuError::DeviceRequestFailed` if device creation fails.
    pub async fn new() -> GpuResult<Self> {
        Self::with_options(GpuDeviceOptions::default()).await
    }

    /// Create a new GPU device with custom options.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::NoAdapter` if no adapter matches `options`; a
    /// failed device request is converted into a `GpuError` via `?`.
    pub async fn with_options(options: GpuDeviceOptions) -> GpuResult<Self> {
        let instance = Instance::new(wgpu::InstanceDescriptor {
            backends: options.backends,
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::default(),
        });
        debug!(
            "Created wgpu instance with backends: {:?}",
            options.backends
        );
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: options.power_preference,
                // Compute-only workload: no presentation surface required.
                compatible_surface: None,
                force_fallback_adapter: options.force_fallback,
            })
            .await
            .ok_or(GpuError::NoAdapter)?;
        let adapter_info = adapter.get_info();
        info!(
            "Selected GPU adapter: {} ({:?})",
            adapter_info.name, adapter_info.backend
        );
        // Downlevel limits trade capability for compatibility on weaker hardware.
        let limits = if options.use_downlevel_limits {
            wgpu::Limits::downlevel_defaults()
        } else {
            wgpu::Limits::default()
        };
        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("prime-radiant-gpu"),
                    required_features: options.required_features,
                    required_limits: limits.clone(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                // No API trace path.
                None,
            )
            .await?;
        // Set up error handling: log uncaptured errors instead of panicking.
        device.on_uncaptured_error(Box::new(|error| {
            warn!("Uncaptured GPU error: {:?}", error);
        }));
        // NOTE(review): `info` reflects the limits we *requested*, not the
        // limits actually granted (`device.limits()`) — confirm this is
        // intended before using these values for strict validation.
        let info = GpuDeviceInfo {
            name: adapter_info.name.clone(),
            vendor: adapter_info.vendor,
            device_id: adapter_info.device,
            device_type: format!("{:?}", adapter_info.device_type),
            backend: format!("{:?}", adapter_info.backend),
            max_buffer_size: limits.max_buffer_size as u64,
            max_workgroup_size: [
                limits.max_compute_workgroup_size_x,
                limits.max_compute_workgroup_size_y,
                limits.max_compute_workgroup_size_z,
            ],
            // wgpu exposes one per-dimension cap, so all three axes share it.
            max_workgroups: [
                limits.max_compute_workgroups_per_dimension,
                limits.max_compute_workgroups_per_dimension,
                limits.max_compute_workgroups_per_dimension,
            ],
            max_storage_buffers: limits.max_storage_buffers_per_shader_stage,
        };
        debug!("GPU device info: {:?}", info);
        Ok(Self {
            instance,
            adapter,
            device: Arc::new(device),
            queue: Arc::new(queue),
            info,
        })
    }

    /// Get a reference to the wgpu device
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Get a shared reference to the wgpu device
    pub fn device_arc(&self) -> Arc<Device> {
        Arc::clone(&self.device)
    }

    /// Get a reference to the command queue
    pub fn queue(&self) -> &Queue {
        &self.queue
    }

    /// Get a shared reference to the command queue
    pub fn queue_arc(&self) -> Arc<Queue> {
        Arc::clone(&self.queue)
    }

    /// Get device information
    pub fn info(&self) -> &GpuDeviceInfo {
        &self.info
    }

    /// Get the wgpu instance
    pub fn instance(&self) -> &Instance {
        &self.instance
    }

    /// Get the wgpu adapter
    pub fn adapter(&self) -> &Adapter {
        &self.adapter
    }

    /// Check if a feature is supported.
    ///
    /// Queries the *adapter's* feature set, which may be a superset of the
    /// features actually enabled on the device at creation time.
    pub fn supports_feature(&self, feature: wgpu::Features) -> bool {
        self.adapter.features().contains(feature)
    }

    /// Poll the device for completed work.
    ///
    /// This is useful when you need to ensure GPU work has completed
    /// before continuing on the CPU.
    ///
    /// When `wait` is true this blocks until submitted work finishes;
    /// otherwise it is a non-blocking poll. Returns whether the queue is
    /// empty after polling.
    pub fn poll(&self, wait: bool) -> bool {
        self.device
            .poll(if wait {
                wgpu::Maintain::Wait
            } else {
                wgpu::Maintain::Poll
            })
            .is_queue_empty()
    }

    /// Submit a command buffer to the queue
    pub fn submit(&self, command_buffer: wgpu::CommandBuffer) -> wgpu::SubmissionIndex {
        self.queue.submit(std::iter::once(command_buffer))
    }

    /// Submit multiple command buffers to the queue
    pub fn submit_multiple(
        &self,
        command_buffers: impl IntoIterator<Item = wgpu::CommandBuffer>,
    ) -> wgpu::SubmissionIndex {
        self.queue.submit(command_buffers)
    }
}
/// Options for GPU device creation.
///
/// Construct via `Default` or the named presets (`low_power`, `compatible`,
/// `software`) and pass to `GpuDevice::with_options`.
#[derive(Debug, Clone)]
pub struct GpuDeviceOptions {
    /// Backends to use (default: all)
    pub backends: wgpu::Backends,
    /// Power preference (default: high performance)
    pub power_preference: wgpu::PowerPreference,
    /// Required GPU features; device creation fails if unsupported
    pub required_features: wgpu::Features,
    /// Use downlevel limits for broader compatibility
    pub use_downlevel_limits: bool,
    /// Force fallback adapter (software rendering)
    pub force_fallback: bool,
}
impl Default for GpuDeviceOptions {
fn default() -> Self {
Self {
backends: wgpu::Backends::all(),
power_preference: wgpu::PowerPreference::HighPerformance,
required_features: wgpu::Features::empty(),
use_downlevel_limits: false,
force_fallback: false,
}
}
}
impl GpuDeviceOptions {
    /// Options for low-power mode (integrated GPU preferred).
    pub fn low_power() -> Self {
        let mut options = Self::default();
        options.power_preference = wgpu::PowerPreference::LowPower;
        options
    }

    /// Options for maximum compatibility (downlevel limits).
    pub fn compatible() -> Self {
        let mut options = Self::default();
        options.use_downlevel_limits = true;
        options
    }

    /// Options for software fallback (CPU rasterizer, downlevel limits).
    pub fn software() -> Self {
        let mut options = Self::compatible();
        options.force_fallback = true;
        options
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_options_default() {
        let options = GpuDeviceOptions::default();
        // Defaults favor performance and real hardware.
        assert_eq!(options.power_preference, wgpu::PowerPreference::HighPerformance);
        assert!(!options.force_fallback);
    }

    #[test]
    fn test_device_options_low_power() {
        assert_eq!(
            GpuDeviceOptions::low_power().power_preference,
            wgpu::PowerPreference::LowPower
        );
    }

    #[test]
    fn test_device_options_compatible() {
        assert!(GpuDeviceOptions::compatible().use_downlevel_limits);
    }
}

View File

@@ -0,0 +1,426 @@
//! Kernel dispatch and synchronization for GPU compute operations.
//!
//! This module provides the dispatcher for executing compute kernels on the GPU,
//! including support for:
//! - Single kernel dispatch
//! - Indirect dispatch (workgroup count from GPU buffer)
//! - Chained dispatch for fused kernels
//! - Synchronization and timing
use std::sync::Arc;
use tracing::{debug, trace};
use wgpu::{CommandEncoder, Device, Queue};
use super::buffer::{GpuBuffer, GpuBufferPool};
use super::device::GpuDevice;
use super::error::{GpuError, GpuResult};
use super::pipeline::{ComputePipeline, PipelineCache};
/// Configuration for a dispatch operation.
#[derive(Debug, Clone)]
pub struct DispatchConfig {
    /// Label for debugging (used for command encoders and compute passes)
    pub label: Option<String>,
    /// Whether to block (device poll) until the GPU finishes after submit
    pub wait: bool,
    /// Timeout in milliseconds (0 = no timeout).
    /// NOTE(review): this field is never read by `GpuDispatcher` in this
    /// module — the timeout is currently not enforced anywhere visible here.
    pub timeout_ms: u64,
}
impl Default for DispatchConfig {
fn default() -> Self {
Self {
label: None,
wait: false,
timeout_ms: 0,
}
}
}
impl DispatchConfig {
/// Create a config that waits for completion
pub fn wait() -> Self {
Self {
wait: true,
..Default::default()
}
}
/// Create a config with a label
pub fn with_label(label: impl Into<String>) -> Self {
Self {
label: Some(label.into()),
..Default::default()
}
}
/// Set the timeout
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
self.timeout_ms = timeout_ms;
self
}
/// Set wait flag
pub fn with_wait(mut self, wait: bool) -> Self {
self.wait = wait;
self
}
}
/// GPU dispatcher for executing compute kernels.
///
/// Bundles a device with the shared pipeline cache and buffer pool used by
/// all kernels.
pub struct GpuDispatcher {
    device: Arc<GpuDevice>,
    /// Compiled compute pipelines, cached per shader
    pipeline_cache: PipelineCache,
    /// Reusable GPU buffer allocations
    buffer_pool: GpuBufferPool,
}
impl GpuDispatcher {
    /// Create a new dispatcher backed by `device`, with its own pipeline
    /// cache and buffer pool.
    pub fn new(device: Arc<GpuDevice>) -> Self {
        let pipeline_cache = PipelineCache::new(device.device_arc());
        let buffer_pool = GpuBufferPool::new(device.device_arc());
        Self {
            device,
            pipeline_cache,
            buffer_pool,
        }
    }

    /// Get the underlying GPU device
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Get the pipeline cache
    pub fn pipeline_cache(&self) -> &PipelineCache {
        &self.pipeline_cache
    }

    /// Get the buffer pool
    pub fn buffer_pool(&self) -> &GpuBufferPool {
        &self.buffer_pool
    }

    /// Validate a workgroup count against the device's reported limits.
    ///
    /// Shared by direct and chained dispatch so both reject oversized counts
    /// on the CPU instead of triggering a GPU-side validation error at submit
    /// (previously only the single-dispatch path checked this).
    fn validate_workgroups(&self, workgroups: [u32; 3]) -> GpuResult<()> {
        let limits = &self.device.info().max_workgroups;
        if workgroups[0] > limits[0] || workgroups[1] > limits[1] || workgroups[2] > limits[2] {
            return Err(GpuError::InvalidWorkgroupSize {
                x: workgroups[0],
                y: workgroups[1],
                z: workgroups[2],
            });
        }
        Ok(())
    }

    /// Dispatch a compute kernel.
    ///
    /// # Arguments
    ///
    /// * `pipeline` - The compute pipeline to execute
    /// * `bind_group` - The bind group with buffer bindings
    /// * `workgroups` - Number of workgroups [x, y, z]
    ///
    /// # Errors
    ///
    /// Returns `GpuError::InvalidWorkgroupSize` when `workgroups` exceeds the
    /// device's per-dimension limit.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// dispatcher.dispatch(&pipeline, &bind_group, [4, 1, 1]).await?;
    /// ```
    pub async fn dispatch(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroups: [u32; 3],
    ) -> GpuResult<()> {
        self.dispatch_with_config(pipeline, bind_group, workgroups, DispatchConfig::default())
            .await
    }

    /// Dispatch with custom configuration.
    ///
    /// Note: `config.timeout_ms` is currently recorded but not enforced;
    /// `config.wait` blocks via a device poll after submission.
    pub async fn dispatch_with_config(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroups: [u32; 3],
        config: DispatchConfig,
    ) -> GpuResult<()> {
        self.validate_workgroups(workgroups)?;
        let label = config.label.as_deref().unwrap_or("dispatch");
        debug!(
            "Dispatching '{}' with workgroups [{}, {}, {}]",
            label, workgroups[0], workgroups[1], workgroups[2]
        );
        let mut encoder = self
            .device
            .device()
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
        {
            // Scope the pass so it ends before `encoder.finish()`.
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some(label),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline.pipeline());
            pass.set_bind_group(0, Some(bind_group), &[]);
            pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
        }
        self.device.submit(encoder.finish());
        if config.wait {
            self.device.poll(true);
        }
        Ok(())
    }

    /// Dispatch using indirect workgroup count from a buffer.
    ///
    /// The indirect buffer must contain [x, y, z] workgroup counts as u32.
    pub async fn dispatch_indirect(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        indirect_buffer: &GpuBuffer,
    ) -> GpuResult<()> {
        self.dispatch_indirect_with_config(
            pipeline,
            bind_group,
            indirect_buffer,
            0,
            DispatchConfig::default(),
        )
        .await
    }

    /// Dispatch indirect with offset and configuration.
    ///
    /// Workgroup counts live GPU-side here, so no CPU-side limit validation
    /// is possible; oversized counts surface as GPU validation errors.
    pub async fn dispatch_indirect_with_config(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        indirect_buffer: &GpuBuffer,
        indirect_offset: u64,
        config: DispatchConfig,
    ) -> GpuResult<()> {
        let label = config.label.as_deref().unwrap_or("dispatch_indirect");
        debug!("Dispatching indirect '{}'", label);
        let mut encoder = self
            .device
            .device()
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some(label),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline.pipeline());
            pass.set_bind_group(0, Some(bind_group), &[]);
            pass.dispatch_workgroups_indirect(indirect_buffer.buffer(), indirect_offset);
        }
        self.device.submit(encoder.finish());
        if config.wait {
            self.device.poll(true);
        }
        Ok(())
    }

    /// Dispatch multiple kernels in a chain (fused execution).
    ///
    /// All dispatches are recorded into a single command buffer for
    /// optimal GPU utilization.
    ///
    /// # Arguments
    ///
    /// * `dispatches` - List of (pipeline, bind_group, workgroups) tuples
    pub async fn dispatch_chain(
        &self,
        dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
    ) -> GpuResult<()> {
        self.dispatch_chain_with_config(dispatches, DispatchConfig::default())
            .await
    }

    /// Dispatch chain with custom configuration.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::InvalidWorkgroupSize` if any entry exceeds the
    /// device's workgroup limits; nothing is submitted in that case.
    /// (Previously chained dispatches skipped the limit check that
    /// `dispatch_with_config` performs.)
    pub async fn dispatch_chain_with_config(
        &self,
        dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
        config: DispatchConfig,
    ) -> GpuResult<()> {
        if dispatches.is_empty() {
            return Ok(());
        }
        // Validate every entry up front so we never submit a partial chain.
        for (_, _, workgroups) in dispatches {
            self.validate_workgroups(*workgroups)?;
        }
        let label = config.label.as_deref().unwrap_or("dispatch_chain");
        debug!(
            "Dispatching chain '{}' with {} kernels",
            label,
            dispatches.len()
        );
        let mut encoder = self
            .device
            .device()
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
        for (i, (pipeline, bind_group, workgroups)) in dispatches.iter().enumerate() {
            trace!(
                "Chain dispatch {}: workgroups [{}, {}, {}]",
                i,
                workgroups[0],
                workgroups[1],
                workgroups[2]
            );
            let pass_label = format!("{}_pass_{}", label, i);
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some(&pass_label),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline.pipeline());
            pass.set_bind_group(0, Some(*bind_group), &[]);
            pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
        }
        self.device.submit(encoder.finish());
        if config.wait {
            self.device.poll(true);
        }
        Ok(())
    }

    /// Record dispatches to a command encoder without submitting.
    ///
    /// This is useful when you want to combine compute with other operations.
    /// No workgroup-limit validation is performed here (the signature has no
    /// error path); the caller is responsible for valid counts.
    pub fn record_dispatch(
        &self,
        encoder: &mut CommandEncoder,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroups: [u32; 3],
        label: Option<&str>,
    ) {
        let pass_label = label.unwrap_or("recorded_dispatch");
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(pass_label),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline.pipeline());
        pass.set_bind_group(0, Some(bind_group), &[]);
        pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
    }

    /// Wait for all pending GPU work to complete.
    pub fn synchronize(&self) {
        self.device.poll(true);
    }

    /// Poll for completed work without blocking. Returns whether the queue is empty.
    pub fn poll(&self) -> bool {
        self.device.poll(false)
    }
}
/// Builder for constructing complex dispatch operations.
///
/// Accumulates (pipeline, bind group, workgroups) triples and a shared
/// `DispatchConfig`, then submits them as one chained command buffer via
/// `execute`.
pub struct DispatchBuilder<'a> {
    dispatcher: &'a GpuDispatcher,
    /// Queued dispatches, executed in insertion order
    dispatches: Vec<(Arc<ComputePipeline>, wgpu::BindGroup, [u32; 3])>,
    config: DispatchConfig,
}
impl<'a> DispatchBuilder<'a> {
    /// Create an empty builder bound to `dispatcher`.
    pub fn new(dispatcher: &'a GpuDispatcher) -> Self {
        Self {
            dispatcher,
            dispatches: Vec::new(),
            config: DispatchConfig::default(),
        }
    }
    /// Queue another compute dispatch; dispatches run in the order added.
    pub fn add(
        mut self,
        pipeline: Arc<ComputePipeline>,
        bind_group: wgpu::BindGroup,
        workgroups: [u32; 3],
    ) -> Self {
        self.dispatches.push((pipeline, bind_group, workgroups));
        self
    }
    /// Replace the entire dispatch configuration.
    pub fn config(mut self, config: DispatchConfig) -> Self {
        self.config = config;
        self
    }
    /// Set the debug label on the configuration.
    pub fn label(mut self, label: impl Into<String>) -> Self {
        self.config.label = Some(label.into());
        self
    }
    /// Request a blocking wait after submission.
    pub fn wait(mut self) -> Self {
        self.config.wait = true;
        self
    }
    /// Execute all queued dispatches as a single chained submission.
    ///
    /// A builder with nothing queued is a no-op and returns `Ok(())`.
    pub async fn execute(self) -> GpuResult<()> {
        if self.dispatches.is_empty() {
            return Ok(());
        }
        // Borrow each queued triple for the dispatcher's slice-based API.
        let mut refs: Vec<(&ComputePipeline, &wgpu::BindGroup, [u32; 3])> =
            Vec::with_capacity(self.dispatches.len());
        for (pipeline, bind_group, workgroups) in &self.dispatches {
            refs.push((pipeline.as_ref(), bind_group, *workgroups));
        }
        self.dispatcher
            .dispatch_chain_with_config(&refs, self.config)
            .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: no label, no timeout, non-blocking.
    #[test]
    fn test_dispatch_config_default() {
        let cfg = DispatchConfig::default();
        assert!(cfg.label.is_none());
        assert_eq!(cfg.timeout_ms, 0);
        assert!(!cfg.wait);
    }

    // The `wait` constructor must produce a blocking config.
    #[test]
    fn test_dispatch_config_wait() {
        assert!(DispatchConfig::wait().wait);
    }

    // Builder-style setters compose and preserve each field.
    #[test]
    fn test_dispatch_config_builder() {
        let cfg = DispatchConfig::with_label("test")
            .with_timeout(1000)
            .with_wait(true);
        assert_eq!(cfg.timeout_ms, 1000);
        assert!(cfg.wait);
        assert_eq!(cfg.label.as_deref(), Some("test"));
    }
}

View File

@@ -0,0 +1,814 @@
//! GPU Coherence Engine
//!
//! Main entry point for GPU-accelerated coherence computation.
//! Provides automatic CPU fallback when GPU is unavailable.
use super::buffer::{BufferUsage, GpuBufferManager, GpuEdge, GpuParams, GpuRestrictionMap};
use super::error::{GpuError, GpuResult};
use super::kernels::{
ComputeEnergyKernel, ComputeResidualsKernel, EnergyParams, SheafAttentionKernel,
TokenRoutingKernel,
};
use crate::coherence::{CoherenceEnergy as CpuCoherenceEnergy, EdgeEnergy};
use crate::substrate::restriction::MatrixStorage;
use crate::substrate::{EdgeId, NodeId, SheafGraph};
use chrono::Utc;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
use wgpu::{
Adapter, Device, DeviceDescriptor, Features, Instance, InstanceDescriptor, Limits,
PowerPreference, Queue, RequestAdapterOptions,
};
/// GPU configuration
///
/// Controls adapter selection, CPU-fallback policy, and the numeric
/// parameters fed to the attention and routing kernels.
#[derive(Debug, Clone)]
pub struct GpuConfig {
/// Preferred power preference (high performance vs low power)
pub power_preference: PowerPreference,
/// Enable CPU fallback when GPU is unavailable
pub enable_fallback: bool,
/// Maximum buffer size in bytes (0 = no limit)
pub max_buffer_size: usize,
/// Beta parameter for attention computation
pub beta: f32,
/// Lane 0 (reflex) threshold — expected to be the smallest of the three
pub threshold_lane0: f32,
/// Lane 1 (retrieval) threshold
pub threshold_lane1: f32,
/// Lane 2 (heavy) threshold — expected to be the largest of the three
pub threshold_lane2: f32,
/// Timeout for GPU operations in milliseconds
pub timeout_ms: u64,
}
impl Default for GpuConfig {
fn default() -> Self {
Self {
power_preference: PowerPreference::HighPerformance,
enable_fallback: true,
max_buffer_size: 0, // No limit
beta: 1.0,
threshold_lane0: 0.01,
threshold_lane1: 0.1,
threshold_lane2: 1.0,
timeout_ms: 5000,
}
}
}
/// GPU capabilities and limits
///
/// Snapshot of adapter info and device limits taken at engine creation;
/// used for diagnostics and for deciding whether to fall back to CPU.
#[derive(Debug, Clone)]
pub struct GpuCapabilities {
/// Device name
pub device_name: String,
/// Vendor (debug-formatted vendor id)
pub vendor: String,
/// Backend (Vulkan, Metal, DX12, etc.)
pub backend: String,
/// Maximum buffer size in bytes
pub max_buffer_size: u64,
/// Maximum compute workgroup size (x dimension)
pub max_workgroup_size: u32,
/// Maximum compute workgroups per dimension (same limit on all three axes)
pub max_workgroups: [u32; 3],
/// Whether the GPU supports required features
pub supported: bool,
}
/// GPU energy result
///
/// `edge_energies` and `edge_indices` are parallel vectors:
/// `edge_energies[i]` is the energy of edge `edge_indices[i]`.
#[derive(Debug, Clone)]
pub struct GpuCoherenceEnergy {
/// Total system energy
pub total_energy: f32,
/// Per-edge energies (parallel to `edge_indices`)
pub edge_energies: Vec<f32>,
/// Edge indices (matches edge_energies)
pub edge_indices: Vec<EdgeId>,
/// Computation time in microseconds
pub compute_time_us: u64,
/// Whether GPU was used (false = CPU fallback)
pub used_gpu: bool,
}
impl GpuCoherenceEnergy {
    /// Convert to CPU `CoherenceEnergy` format.
    ///
    /// Looks each recorded edge up in `graph`; edges that no longer exist
    /// are skipped. Iterates `edge_indices` and `edge_energies` pairwise
    /// with `zip`, so a length mismatch between the two parallel vectors
    /// can no longer cause an out-of-bounds panic (surplus entries in
    /// either vector are simply ignored).
    pub fn to_cpu_format(&self, graph: &SheafGraph) -> CpuCoherenceEnergy {
        let mut edge_energy_map = HashMap::new();
        for (&edge_id, &energy) in self.edge_indices.iter().zip(self.edge_energies.iter()) {
            if let Some(edge) = graph.get_edge(edge_id) {
                let edge_energy = EdgeEnergy::new_lightweight(
                    edge_id.to_string(),
                    edge.source.to_string(),
                    edge.target.to_string(),
                    // GPU energy includes the edge weight; divide it back out
                    // (clamped away from zero) to recover the raw norm_sq.
                    energy / edge.weight.max(0.001),
                    edge.weight,
                );
                edge_energy_map.insert(edge_id.to_string(), edge_energy);
            }
        }
        CpuCoherenceEnergy::new(
            edge_energy_map,
            &HashMap::new(),
            graph.node_count(),
            format!("gpu-{}", Utc::now().timestamp()),
        )
    }
}
/// GPU-accelerated coherence engine
///
/// Owns the wgpu device/queue pair, the compiled compute kernels, the
/// buffer manager, and whatever graph data is currently resident on GPU.
pub struct GpuCoherenceEngine {
// Device handle shared with the buffer manager and kernels.
device: Arc<Device>,
// Queue used for buffer writes and command submission.
queue: Arc<Queue>,
// Named-buffer pool for graph/input/output storage.
buffer_manager: GpuBufferManager,
// User-supplied configuration (thresholds, beta, fallback policy).
config: GpuConfig,
// Adapter capabilities captured at creation time.
capabilities: GpuCapabilities,
// Kernels
residuals_kernel: ComputeResidualsKernel,
energy_kernel: ComputeEnergyKernel,
attention_kernel: SheafAttentionKernel,
routing_kernel: TokenRoutingKernel,
// Cached graph data; `None` until `upload_graph` succeeds.
graph_data: Option<GpuGraphData>,
}
/// Cached graph data on GPU
///
/// Populated by `upload_graph` and consumed by `compute_energy`.
/// Holds ID↔index mappings plus every pre-allocated buffer the hot
/// path needs, so per-frame computation performs no allocations.
struct GpuGraphData {
// Number of nodes uploaded.
num_nodes: u32,
// Number of edges uploaded.
num_edges: u32,
// Per-node state vector length (taken from the first node).
state_dim: u32,
// NodeId -> dense GPU index.
node_id_map: HashMap<NodeId, u32>,
// EdgeId -> dense GPU index.
edge_id_map: HashMap<EdgeId, u32>,
// Dense GPU index -> EdgeId (inverse of `edge_id_map`).
edge_id_reverse: Vec<EdgeId>,
// Pre-allocated computation buffers (eliminates per-frame allocations)
params_buffer: wgpu::Buffer,
energy_params_buffer: wgpu::Buffer,
partial_sums_buffer: wgpu::Buffer,
final_params_buffer: wgpu::Buffer,
total_energy_buffer: wgpu::Buffer,
energies_staging: wgpu::Buffer,
total_staging: wgpu::Buffer,
// Pre-computed workgroup count for energy reduction
num_workgroups: u32,
}
impl GpuCoherenceEngine {
/// Create a new GPU coherence engine
///
/// Requests an adapter matching `config.power_preference`, creates a
/// device/queue pair, compiles all four compute kernels, and sets up the
/// buffer manager. No graph data is resident until `upload_graph` is
/// called.
///
/// # Errors
/// * `GpuError::AdapterRequest` if no suitable adapter is found.
/// * `GpuError::UnsupportedFeature` if required features are missing.
/// * `GpuError::DeviceCreation` if device creation fails.
/// * Kernel construction errors are propagated unchanged.
pub async fn new(config: GpuConfig) -> GpuResult<Self> {
// Create wgpu instance
let instance = Instance::new(InstanceDescriptor::default());
// Request adapter
let adapter = instance
.request_adapter(&RequestAdapterOptions {
power_preference: config.power_preference,
compatible_surface: None, // compute-only: no presentation surface
force_fallback_adapter: false,
})
.await
.ok_or_else(|| GpuError::AdapterRequest("No suitable GPU adapter found".into()))?;
let capabilities = Self::get_capabilities(&adapter);
if !capabilities.supported {
return Err(GpuError::UnsupportedFeature(
"GPU does not support required features".into(),
));
}
info!(
"Using GPU: {} ({}) - {}",
capabilities.device_name, capabilities.vendor, capabilities.backend
);
// Request device with default features/limits only — no optional
// wgpu features are required by the kernels.
let (device, queue) = adapter
.request_device(
&DeviceDescriptor {
label: Some("prime_radiant_gpu"),
required_features: Features::empty(),
required_limits: Limits::default(),
memory_hints: Default::default(),
},
None,
)
.await
.map_err(|e| GpuError::DeviceCreation(e.to_string()))?;
let device = Arc::new(device);
let queue = Arc::new(queue);
// Create kernels (each compiles its own WGSL shader module).
let residuals_kernel = ComputeResidualsKernel::new(&device)?;
let energy_kernel = ComputeEnergyKernel::new(&device)?;
let attention_kernel = SheafAttentionKernel::new(&device)?;
let routing_kernel = TokenRoutingKernel::new(&device)?;
// Create buffer manager
let buffer_manager = GpuBufferManager::new(device.clone(), queue.clone());
Ok(Self {
device,
queue,
buffer_manager,
config,
capabilities,
residuals_kernel,
energy_kernel,
attention_kernel,
routing_kernel,
graph_data: None,
})
}
/// Try to create a GPU engine, returning `None` if GPU is unavailable.
///
/// Any initialization error is logged at `warn` level; callers are
/// expected to use the CPU path when this returns `None`.
pub async fn try_new(config: GpuConfig) -> Option<Self> {
    Self::new(config)
        .await
        .map_err(|e| warn!("GPU initialization failed: {}. Will use CPU fallback.", e))
        .ok()
}
/// Snapshot the adapter's info and limits into a `GpuCapabilities`.
fn get_capabilities(adapter: &Adapter) -> GpuCapabilities {
    let info = adapter.get_info();
    let limits = adapter.limits();
    // wgpu exposes a single per-dimension workgroup-count limit; it
    // applies equally to x, y and z.
    let per_dim = limits.max_compute_workgroups_per_dimension;
    GpuCapabilities {
        device_name: info.name,
        vendor: format!("{:?}", info.vendor),
        backend: format!("{:?}", info.backend),
        max_buffer_size: limits.max_buffer_size as u64,
        max_workgroup_size: limits.max_compute_workgroup_size_x,
        max_workgroups: [per_dim; 3],
        // NOTE(review): always reported as supported here; actual feature
        // probing happens implicitly when the device is created.
        supported: true,
    }
}
/// Upload graph data to GPU
///
/// Flattens node states, edges, and restriction maps into GPU-friendly
/// buffers, uploads them through the buffer manager, and pre-allocates
/// every buffer `compute_energy` needs so the hot path performs no
/// allocations. Replaces any previously uploaded graph.
///
/// # Errors
/// Returns `GpuError::EmptyGraph` if the graph has no edges; buffer
/// allocation failures are propagated from the buffer manager.
pub fn upload_graph(&mut self, graph: &SheafGraph) -> GpuResult<()> {
if graph.edge_count() == 0 {
return Err(GpuError::EmptyGraph);
}
let num_nodes = graph.node_count() as u32;
let num_edges = graph.edge_count() as u32;
// Build node ID mapping (NodeId -> dense index used on the GPU).
let mut node_id_map = HashMap::new();
let node_ids = graph.node_ids();
for (i, node_id) in node_ids.iter().enumerate() {
node_id_map.insert(*node_id, i as u32);
}
// Determine state dimension from first node
// NOTE(review): assumes all nodes share the first node's dimension.
let state_dim = node_ids
.first()
.and_then(|id| graph.get_node(*id))
.map(|n| n.dim())
.unwrap_or(64) as u32;
// Flatten node states into one contiguous row-major array.
let mut node_states: Vec<f32> = Vec::with_capacity((num_nodes * state_dim) as usize);
for node_id in &node_ids {
if let Some(state) = graph.node_state(*node_id) {
node_states.extend(state.iter().cloned());
// Pad if needed
// NOTE(review): states longer than state_dim are NOT truncated,
// which would corrupt the flattened layout — confirm upstream
// invariant that every state has len <= state_dim.
for _ in state.len()..(state_dim as usize) {
node_states.push(0.0);
}
}
}
// Build edge data and restriction maps
let mut edges: Vec<GpuEdge> = Vec::with_capacity(num_edges as usize);
let mut restriction_maps: Vec<GpuRestrictionMap> = Vec::new();
let mut restriction_data: Vec<f32> = Vec::new();
let mut edge_id_map = HashMap::new();
let mut edge_id_reverse = Vec::new();
let edge_ids = graph.edge_ids();
for (i, edge_id) in edge_ids.iter().enumerate() {
edge_id_map.insert(*edge_id, i as u32);
edge_id_reverse.push(*edge_id);
if let Some(edge) = graph.get_edge(*edge_id) {
// Unknown endpoints silently map to node 0 — presumably cannot
// happen for a well-formed graph; verify against graph invariants.
let source_idx = *node_id_map.get(&edge.source).unwrap_or(&0);
let target_idx = *node_id_map.get(&edge.target).unwrap_or(&0);
// Convert restriction maps (two per edge: source then target).
let rho_source_idx = restriction_maps.len() as u32;
let gpu_rho_source =
Self::convert_restriction_map(&edge.rho_source, &mut restriction_data);
restriction_maps.push(gpu_rho_source);
let rho_target_idx = restriction_maps.len() as u32;
let gpu_rho_target =
Self::convert_restriction_map(&edge.rho_target, &mut restriction_data);
restriction_maps.push(gpu_rho_target);
edges.push(GpuEdge {
source_idx,
target_idx,
weight: edge.weight,
rho_source_idx,
rho_target_idx,
comparison_dim: edge.comparison_dim() as u32,
_padding: [0; 2],
});
}
}
// Ensure restriction_data is not empty (GPU buffers can't be zero-sized)
if restriction_data.is_empty() {
restriction_data.push(0.0);
}
// Upload to GPU
self.buffer_manager.allocate_with_data(
&node_states,
BufferUsage::NodeStates,
"node_states",
)?;
self.buffer_manager
.allocate_with_data(&edges, BufferUsage::EdgeData, "edges")?;
self.buffer_manager.allocate_with_data(
&restriction_maps,
BufferUsage::RestrictionMaps,
"restriction_maps",
)?;
self.buffer_manager.allocate_with_data(
&restriction_data,
BufferUsage::RestrictionMaps,
"restriction_data",
)?;
// Allocate output buffers sized for the widest edge comparison space.
let max_comparison_dim = edges
.iter()
.map(|e| e.comparison_dim)
.max()
.unwrap_or(state_dim);
let residuals_size = (num_edges * max_comparison_dim) as usize * std::mem::size_of::<f32>();
let energies_size = num_edges as usize * std::mem::size_of::<f32>();
self.buffer_manager
.allocate(residuals_size, BufferUsage::Residuals, "residuals")?;
self.buffer_manager
.allocate(energies_size, BufferUsage::Energies, "edge_energies")?;
// Pre-allocate computation buffers to eliminate per-frame allocations
let num_workgroups = ComputeEnergyKernel::workgroup_count(num_edges);
// Params buffer (GpuParams)
let params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("params_preallocated"),
size: std::mem::size_of::<GpuParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Energy params buffer (EnergyParams)
let energy_params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("energy_params_preallocated"),
size: std::mem::size_of::<EnergyParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Partial sums buffer (one f32 per workgroup)
let partial_sums_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("partial_sums_preallocated"),
size: ((num_workgroups as usize).max(1) * std::mem::size_of::<f32>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Final params buffer (for second reduction pass)
let final_params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("final_params_preallocated"),
size: std::mem::size_of::<EnergyParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Total energy buffer (single f32)
let total_energy_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("total_energy_preallocated"),
size: std::mem::size_of::<f32>() as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Staging buffer for edge energies readback
let energies_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("energies_staging_preallocated"),
size: (num_edges as usize * std::mem::size_of::<f32>()) as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Staging buffer for total energy readback
let total_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("total_staging_preallocated"),
size: std::mem::size_of::<f32>() as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Store graph data with pre-allocated buffers
self.graph_data = Some(GpuGraphData {
num_nodes,
num_edges,
state_dim,
node_id_map,
edge_id_map,
edge_id_reverse,
params_buffer,
energy_params_buffer,
partial_sums_buffer,
final_params_buffer,
total_energy_buffer,
energies_staging,
total_staging,
num_workgroups,
});
debug!(
"Uploaded graph to GPU: {} nodes, {} edges, state_dim={}, workgroups={}",
num_nodes, num_edges, state_dim, num_workgroups
);
Ok(())
}
/// Convert a RestrictionMap to GPU format
///
/// Appends the map's payload (scales, indices, or matrix values) to the
/// shared `data` arena and returns a descriptor holding the map type tag
/// plus the offset/length of that payload.
///
/// Type tags written: 0 = identity (no payload), 1 = diagonal,
/// 2 = projection (indices stored as f32 — presumably exact for the
/// index ranges in use; values above 2^24 would lose precision),
/// 3 = sparse/CSR/dense (values only; see notes below).
fn convert_restriction_map(
map: &crate::substrate::RestrictionMap,
data: &mut Vec<f32>,
) -> GpuRestrictionMap {
let data_offset = data.len() as u32;
let (map_type, data_len) = match &map.matrix {
MatrixStorage::Identity => (0, 0),
MatrixStorage::Diagonal(scales) => {
data.extend(scales.iter().cloned());
(1, scales.len() as u32)
}
MatrixStorage::Projection { indices, .. } => {
data.extend(indices.iter().map(|&i| i as f32));
(2, indices.len() as u32)
}
MatrixStorage::Sparse { values, .. } => {
// Simplified: just store values (would need row/col in practice)
data.extend(values.iter().cloned());
(3, values.len() as u32)
}
MatrixStorage::Csr(csr) => {
// CSR format: store values similar to sparse
// Note: In practice, GPU would need row_ptr and col_indices too
data.extend(csr.values.iter().cloned());
(3, csr.values.len() as u32)
}
MatrixStorage::Dense {
data: matrix_data, ..
} => {
data.extend(matrix_data.iter().cloned());
(3, matrix_data.len() as u32)
}
};
GpuRestrictionMap {
map_type,
input_dim: map.input_dim() as u32,
output_dim: map.output_dim() as u32,
data_offset,
data_len,
_padding: [0; 3],
}
}
/// Compute coherence energy on GPU
///
/// Records three passes into one command buffer — per-edge residuals,
/// per-workgroup energy reduction, and (when more than one workgroup was
/// needed) a final single-workgroup reduction — then copies results to
/// staging buffers and reads them back.
/// Uses pre-allocated buffers to eliminate per-frame allocations.
///
/// # Errors
/// Returns `GpuError::Internal` if `upload_graph` has not been called or
/// a named buffer is missing, and `GpuError::BufferRead` if the staging
/// readback fails.
pub async fn compute_energy(&mut self) -> GpuResult<GpuCoherenceEnergy> {
let start = std::time::Instant::now();
let graph_data = self
.graph_data
.as_ref()
.ok_or_else(|| GpuError::Internal("Graph not uploaded".into()))?;
let num_edges = graph_data.num_edges;
let num_workgroups = graph_data.num_workgroups;
// Write params to pre-allocated buffer (no allocation)
let params = GpuParams {
num_edges,
num_nodes: graph_data.num_nodes,
state_dim: graph_data.state_dim,
beta: self.config.beta,
threshold_lane0: self.config.threshold_lane0,
threshold_lane1: self.config.threshold_lane1,
threshold_lane2: self.config.threshold_lane2,
store_residuals: 1, // Store residuals by default for gradient computation
};
self.queue
.write_buffer(&graph_data.params_buffer, 0, bytemuck::bytes_of(&params));
// Write energy params to pre-allocated buffer (no allocation)
let energy_params = EnergyParams {
num_elements: num_edges,
_padding: [0; 7],
};
self.queue.write_buffer(
&graph_data.energy_params_buffer,
0,
bytemuck::bytes_of(&energy_params),
);
// Get managed buffers for bind group creation
let node_states_buf = self
.buffer_manager
.get("node_states")
.ok_or_else(|| GpuError::Internal("Node states buffer not found".into()))?;
let edges_buf = self
.buffer_manager
.get("edges")
.ok_or_else(|| GpuError::Internal("Edges buffer not found".into()))?;
let restriction_maps_buf = self
.buffer_manager
.get("restriction_maps")
.ok_or_else(|| GpuError::Internal("Restriction maps buffer not found".into()))?;
let restriction_data_buf = self
.buffer_manager
.get("restriction_data")
.ok_or_else(|| GpuError::Internal("Restriction data buffer not found".into()))?;
let residuals_buf = self
.buffer_manager
.get("residuals")
.ok_or_else(|| GpuError::Internal("Residuals buffer not found".into()))?;
let energies_buf = self
.buffer_manager
.get("edge_energies")
.ok_or_else(|| GpuError::Internal("Edge energies buffer not found".into()))?;
// Create bind group for residuals kernel using pre-allocated params buffer
let residuals_bind_group = self.residuals_kernel.create_bind_group_raw(
&self.device,
&graph_data.params_buffer,
&node_states_buf.buffer,
&edges_buf.buffer,
&restriction_maps_buf.buffer,
&restriction_data_buf.buffer,
&residuals_buf.buffer,
&energies_buf.buffer,
);
// Create bind group for energy reduction using pre-allocated buffers
let energy_bind_group = self.energy_kernel.create_bind_group_raw(
&self.device,
&graph_data.energy_params_buffer,
&energies_buf.buffer,
&graph_data.partial_sums_buffer,
);
// Create command encoder
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("compute_energy_encoder"),
});
// Dispatch residuals computation
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_residuals_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(self.residuals_kernel.pipeline());
compute_pass.set_bind_group(0, &residuals_bind_group, &[]);
compute_pass.dispatch_workgroups(
ComputeResidualsKernel::workgroup_count(num_edges),
1,
1,
);
}
// Dispatch energy reduction
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_energy_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(self.energy_kernel.main_pipeline());
compute_pass.set_bind_group(0, &energy_bind_group, &[]);
compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
}
// If we have multiple workgroups, do final reduction
if num_workgroups > 1 {
// Write final params to pre-allocated buffer (no allocation)
let final_params = EnergyParams {
num_elements: num_workgroups,
_padding: [0; 7],
};
self.queue.write_buffer(
&graph_data.final_params_buffer,
0,
bytemuck::bytes_of(&final_params),
);
let final_bind_group = self.energy_kernel.create_bind_group_raw(
&self.device,
&graph_data.final_params_buffer,
&graph_data.partial_sums_buffer,
&graph_data.total_energy_buffer,
);
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("final_reduce_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(self.energy_kernel.final_pipeline());
compute_pass.set_bind_group(0, &final_bind_group, &[]);
compute_pass.dispatch_workgroups(1, 1, 1);
}
}
// Copy results to pre-allocated staging buffers (no allocation)
encoder.copy_buffer_to_buffer(
&energies_buf.buffer,
0,
&graph_data.energies_staging,
0,
(num_edges as usize * std::mem::size_of::<f32>()) as u64,
);
if num_workgroups > 1 {
encoder.copy_buffer_to_buffer(
&graph_data.total_energy_buffer,
0,
&graph_data.total_staging,
0,
std::mem::size_of::<f32>() as u64,
);
} else {
// With a single workgroup the lone partial sum already is the total.
encoder.copy_buffer_to_buffer(
&graph_data.partial_sums_buffer,
0,
&graph_data.total_staging,
0,
std::mem::size_of::<f32>() as u64,
);
}
// Submit commands
self.queue.submit(std::iter::once(encoder.finish()));
// Read back results from pre-allocated staging buffers
let edge_energies = Self::read_buffer_f32(
&self.device,
&graph_data.energies_staging,
num_edges as usize,
)
.await?;
let total_energy =
Self::read_buffer_f32(&self.device, &graph_data.total_staging, 1).await?[0];
let compute_time_us = start.elapsed().as_micros() as u64;
debug!(
"GPU energy computation: total={:.6}, {} edges, {}us (pre-allocated buffers)",
total_energy, num_edges, compute_time_us
);
Ok(GpuCoherenceEnergy {
total_energy,
edge_energies,
edge_indices: graph_data.edge_id_reverse.clone(),
compute_time_us,
used_gpu: true,
})
}
/// Read `count` f32 values back from a mapped-readable GPU buffer.
///
/// Maps the buffer asynchronously, blocks the device until the mapping
/// completes, copies the data out, and unmaps before returning.
async fn read_buffer_f32(
    device: &Device,
    buffer: &wgpu::Buffer,
    count: usize,
) -> GpuResult<Vec<f32>> {
    let slice = buffer.slice(..);
    let (tx, rx) = futures::channel::oneshot::channel();
    slice.map_async(wgpu::MapMode::Read, move |map_result| {
        let _ = tx.send(map_result);
    });
    // map_async callbacks only fire during device polling; Wait blocks
    // until the mapping is resolved, so the channel is ready below.
    device.poll(wgpu::Maintain::Wait);
    rx.await
        .map_err(|_| GpuError::BufferRead("Channel closed".into()))?
        .map_err(|e| GpuError::BufferRead(e.to_string()))?;
    // Scope the mapped view so it is dropped before `unmap`.
    let values = {
        let mapped = slice.get_mapped_range();
        let byte_len = count * std::mem::size_of::<f32>();
        bytemuck::cast_slice::<u8, f32>(&mapped[..byte_len]).to_vec()
    };
    buffer.unmap();
    Ok(values)
}
/// Get GPU capabilities captured at engine creation.
pub fn capabilities(&self) -> &GpuCapabilities {
&self.capabilities
}
/// Get the configuration this engine was created with.
pub fn config(&self) -> &GpuConfig {
&self.config
}
/// Check if GPU is available (mirrors `capabilities.supported`).
pub fn is_available(&self) -> bool {
self.capabilities.supported
}
/// Release all GPU resources.
///
/// Clears the buffer pool and drops cached graph data; `upload_graph`
/// must be called again before the next `compute_energy`.
pub fn release(&mut self) {
self.buffer_manager.clear();
self.graph_data = None;
}
}
/// Synchronous wrapper for GPU coherence engine using pollster
///
/// Each function blocks the calling thread until the underlying async
/// operation finishes — do not call these from inside an async runtime.
pub mod sync {
use super::*;
/// Synchronously create a GPU engine (blocks on adapter/device setup).
pub fn create_engine(config: GpuConfig) -> GpuResult<GpuCoherenceEngine> {
pollster::block_on(GpuCoherenceEngine::new(config))
}
/// Try to create GPU engine synchronously; `None` if GPU is unavailable.
pub fn try_create_engine(config: GpuConfig) -> Option<GpuCoherenceEngine> {
pollster::block_on(GpuCoherenceEngine::try_new(config))
}
/// Compute energy synchronously (blocks on GPU submission and readback).
pub fn compute_energy(engine: &mut GpuCoherenceEngine) -> GpuResult<GpuCoherenceEnergy> {
pollster::block_on(engine.compute_energy())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must enable fallback and keep lane thresholds ordered.
    #[test]
    fn test_gpu_config_default() {
        let cfg = GpuConfig::default();
        assert_eq!(cfg.beta, 1.0);
        assert!(cfg.enable_fallback);
        // Lane thresholds must be strictly increasing.
        assert!(cfg.threshold_lane0 < cfg.threshold_lane1);
        assert!(cfg.threshold_lane1 < cfg.threshold_lane2);
    }

    // Must match the WGSL uniform layout (8 four-byte fields).
    #[test]
    fn test_gpu_params_size() {
        assert_eq!(std::mem::size_of::<GpuParams>(), 32);
    }

    // Must match the WGSL uniform layout (num_elements + 7 padding words).
    #[test]
    fn test_energy_params_size() {
        assert_eq!(std::mem::size_of::<EnergyParams>(), 32);
    }
}

View File

@@ -0,0 +1,228 @@
//! GPU Error Types
//!
//! Error handling for GPU operations including device initialization,
//! buffer management, shader execution, and kernel dispatch.
use thiserror::Error;
/// Result type for GPU operations
pub type GpuResult<T> = Result<T, GpuError>;
/// Errors that can occur during GPU operations
///
/// Use `should_fallback` to decide whether to switch to the CPU path and
/// `is_recoverable` to decide whether a retry may help.
#[derive(Debug, Error)]
pub enum GpuError {
/// No suitable GPU adapter found
#[error("No suitable GPU adapter found. Ensure a GPU with compute capabilities is available.")]
NoAdapter,
/// No compatible GPU device found
#[error("No compatible GPU device found: {0}")]
NoDevice(String),
/// GPU device creation failed
#[error("Failed to create GPU device: {0}")]
DeviceCreation(String),
/// Device request failed
#[error("Failed to request GPU device: {0}")]
DeviceRequestFailed(String),
/// Shader compilation failed
#[error("Shader compilation failed: {0}")]
ShaderCompilation(String),
/// Buffer allocation failed
#[error("Buffer allocation failed: {0}")]
BufferAllocation(String),
/// Buffer allocation failed with details
#[error("Buffer allocation failed: requested {requested_bytes} bytes, reason: {reason}")]
BufferAllocationFailed {
/// Number of bytes requested
requested_bytes: u64,
/// Reason for failure
reason: String,
},
/// Buffer size exceeds maximum allowed
#[error("Buffer size {size} exceeds maximum allowed {max}")]
BufferTooLarge {
/// Requested size
size: u64,
/// Maximum allowed size
max: u64,
},
/// Buffer size mismatch
#[error("Buffer size mismatch: expected {expected}, got {actual}")]
BufferSizeMismatch {
/// Expected size
expected: usize,
/// Actual size
actual: usize,
},
/// Buffer read-back failed
#[error("Buffer read-back failed: {0}")]
BufferReadFailed(String),
/// Buffer mapping failed
#[error("Buffer mapping failed: {0}")]
BufferMapFailed(String),
/// Dimension mismatch
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch {
/// Expected dimension
expected: usize,
/// Actual dimension
actual: usize,
},
/// Invalid binding configuration
#[error("Invalid binding configuration: expected {expected} bindings, got {actual}")]
InvalidBindingCount {
/// Expected number of bindings
expected: usize,
/// Actual number of bindings
actual: usize,
},
/// Invalid workgroup configuration
#[error("Invalid workgroup configuration: [{x}, {y}, {z}] exceeds device limits")]
InvalidWorkgroupSize {
/// X dimension
x: u32,
/// Y dimension
y: u32,
/// Z dimension
z: u32,
},
/// Compute pipeline creation failed
#[error("Failed to create compute pipeline: {0}")]
PipelineCreation(String),
/// Command encoding failed
#[error("Command encoding failed: {0}")]
CommandEncoding(String),
/// GPU execution failed
#[error("GPU execution failed: {0}")]
ExecutionFailed(String),
/// Buffer read failed
#[error("Failed to read buffer: {0}")]
BufferRead(String),
/// Buffer write failed
#[error("Failed to write buffer: {0}")]
BufferWrite(String),
/// Timeout waiting for GPU operation
#[error("GPU operation timed out after {0}ms")]
Timeout(u64),
/// Graph has no edges
#[error("Graph has no edges to compute")]
EmptyGraph,
/// Invalid configuration
#[error("Invalid GPU configuration: {0}")]
InvalidConfig(String),
/// Feature not supported
#[error("GPU feature not supported: {0}")]
UnsupportedFeature(String),
/// Adapter request failed
#[error("Failed to request GPU adapter: {0}")]
AdapterRequest(String),
/// Out of GPU memory
#[error("Out of GPU memory: requested {requested_bytes} bytes")]
OutOfMemory {
/// Number of bytes requested
requested_bytes: u64,
},
/// Device lost
#[error("GPU device lost: {0}")]
DeviceLost(String),
/// Internal error
#[error("Internal GPU error: {0}")]
Internal(String),
}
impl GpuError {
    /// Whether this error means the GPU is unusable and the caller
    /// should switch to the CPU implementation.
    pub fn should_fallback(&self) -> bool {
        match self {
            GpuError::NoAdapter
            | GpuError::NoDevice(_)
            | GpuError::DeviceCreation(_)
            | GpuError::DeviceRequestFailed(_)
            | GpuError::AdapterRequest(_)
            | GpuError::UnsupportedFeature(_) => true,
            _ => false,
        }
    }
    /// Whether retrying the failed operation may succeed.
    pub fn is_recoverable(&self) -> bool {
        match self {
            GpuError::Timeout(_)
            | GpuError::BufferRead(_)
            | GpuError::BufferReadFailed(_)
            | GpuError::ExecutionFailed(_) => true,
            _ => false,
        }
    }
}
// Map wgpu device-request failures onto the crate's error type so `?`
// works at device-creation call sites.
impl From<wgpu::RequestDeviceError> for GpuError {
fn from(e: wgpu::RequestDeviceError) -> Self {
Self::DeviceRequestFailed(e.to_string())
}
}
// Map async buffer-mapping failures onto the crate's error type so `?`
// works at readback call sites.
impl From<wgpu::BufferAsyncError> for GpuError {
fn from(e: wgpu::BufferAsyncError) -> Self {
Self::BufferMapFailed(e.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Initialization-class failures trigger fallback; runtime ones don't.
    #[test]
    fn test_should_fallback() {
        for err in [
            GpuError::NoAdapter,
            GpuError::NoDevice("test".into()),
            GpuError::DeviceCreation("test".into()),
        ] {
            assert!(err.should_fallback());
        }
        assert!(!GpuError::Timeout(100).should_fallback());
        assert!(!GpuError::EmptyGraph.should_fallback());
    }

    // Transient runtime failures are recoverable; setup failures are not.
    #[test]
    fn test_is_recoverable() {
        for err in [
            GpuError::Timeout(100),
            GpuError::BufferRead("test".into()),
            GpuError::BufferReadFailed("test".into()),
        ] {
            assert!(err.is_recoverable());
        }
        assert!(!GpuError::NoDevice("test".into()).is_recoverable());
        assert!(!GpuError::NoAdapter.is_recoverable());
    }

    // Display output must surface both structured fields.
    #[test]
    fn test_error_display() {
        let msg = GpuError::BufferAllocationFailed {
            requested_bytes: 1024,
            reason: "out of memory".to_string(),
        }
        .to_string();
        assert!(msg.contains("1024"));
        assert!(msg.contains("out of memory"));
    }

    // Display output must include the offending dimension.
    #[test]
    fn test_workgroup_error() {
        let msg = GpuError::InvalidWorkgroupSize { x: 1000, y: 1, z: 1 }.to_string();
        assert!(msg.contains("1000"));
    }
}

View File

@@ -0,0 +1,760 @@
//! GPU Kernel Wrappers
//!
//! Provides Rust wrappers around WGSL compute shaders for coherence computation.
//! Each kernel handles pipeline creation, bind group setup, and dispatch.
use super::buffer::{
BufferUsage, GpuBuffer, GpuBufferManager, GpuEdge, GpuParams, GpuRestrictionMap,
};
use super::error::{GpuError, GpuResult};
use super::shaders;
use super::workgroup;
use bytemuck::{Pod, Zeroable};
use std::sync::Arc;
use wgpu::{
BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor,
BindGroupLayoutEntry, BindingResource, BindingType, BufferBindingType, ComputePipeline,
ComputePipelineDescriptor, Device, PipelineLayoutDescriptor, Queue, ShaderModule,
ShaderModuleDescriptor, ShaderSource, ShaderStages,
};
/// Compute residuals kernel
/// Computes r_e = rho_source(x_source) - rho_target(x_target) for all edges
///
/// Holds the compiled pipeline plus the layout needed to rebuild bind
/// groups when buffers change.
pub struct ComputeResidualsKernel {
// Compiled compute pipeline for the residuals shader.
pipeline: ComputePipeline,
// Layout matching the shader's seven bindings; reused for every bind group.
bind_group_layout: BindGroupLayout,
}
impl ComputeResidualsKernel {
/// Create a new compute residuals kernel
pub fn new(device: &Device) -> GpuResult<Self> {
let shader = device.create_shader_module(ShaderModuleDescriptor {
label: Some("compute_residuals"),
source: ShaderSource::Wgsl(shaders::COMPUTE_RESIDUALS.into()),
});
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("compute_residuals_bind_group_layout"),
entries: &[
// Params uniform
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Node states
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Edges
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Restriction maps
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Restriction data
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Residuals output
BindGroupLayoutEntry {
binding: 5,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Residual norms output
BindGroupLayoutEntry {
binding: 6,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("compute_residuals_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("compute_residuals_pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
Ok(Self {
pipeline,
bind_group_layout,
})
}
/// Create a bind group for execution
pub fn create_bind_group(
&self,
device: &Device,
params_buffer: &GpuBuffer,
node_states_buffer: &GpuBuffer,
edges_buffer: &GpuBuffer,
restriction_maps_buffer: &GpuBuffer,
restriction_data_buffer: &GpuBuffer,
residuals_buffer: &GpuBuffer,
residual_norms_buffer: &GpuBuffer,
) -> BindGroup {
device.create_bind_group(&BindGroupDescriptor {
label: Some("compute_residuals_bind_group"),
layout: &self.bind_group_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: params_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: node_states_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: edges_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: restriction_maps_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 4,
resource: restriction_data_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 5,
resource: residuals_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 6,
resource: residual_norms_buffer.buffer.as_entire_binding(),
},
],
})
}
/// Create a bind group using raw wgpu buffers (for pre-allocated buffer optimization)
pub fn create_bind_group_raw(
&self,
device: &Device,
params_buffer: &wgpu::Buffer,
node_states_buffer: &wgpu::Buffer,
edges_buffer: &wgpu::Buffer,
restriction_maps_buffer: &wgpu::Buffer,
restriction_data_buffer: &wgpu::Buffer,
residuals_buffer: &wgpu::Buffer,
residual_norms_buffer: &wgpu::Buffer,
) -> BindGroup {
device.create_bind_group(&BindGroupDescriptor {
label: Some("compute_residuals_bind_group_raw"),
layout: &self.bind_group_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: params_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: node_states_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: edges_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: restriction_maps_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 4,
resource: restriction_data_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 5,
resource: residuals_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 6,
resource: residual_norms_buffer.as_entire_binding(),
},
],
})
}
    /// Borrow the compiled residuals pipeline, for binding in a command
    /// encoder's compute pass.
    pub fn pipeline(&self) -> &ComputePipeline {
        &self.pipeline
    }
/// Calculate number of workgroups needed
pub fn workgroup_count(num_edges: u32) -> u32 {
// One thread per edge, 256 threads per workgroup
(num_edges + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
}
}
/// Compute energy kernel with parallel reduction
///
/// Two-phase reduction over the same shader module: `main_pipeline`
/// (entry point `main`) produces per-workgroup partial sums, and
/// `final_pipeline` (entry point `final_reduce`) collapses those partials.
pub struct ComputeEnergyKernel {
    // First-pass reduction pipeline (shader entry point "main").
    main_pipeline: ComputePipeline,
    // Second-pass reduction pipeline (shader entry point "final_reduce").
    final_pipeline: ComputePipeline,
    // Shared layout for both passes: params / input / output at bindings 0-2.
    bind_group_layout: BindGroupLayout,
}
/// Parameters for energy reduction
///
/// Uniform-buffer payload for the energy shader; the seven trailing `u32`s
/// pad the struct to 32 bytes, matching the shader-side `EnergyParams`
/// struct (see `compute_energy.wgsl`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct EnergyParams {
    /// Number of elements to reduce
    pub num_elements: u32,
    /// Padding (keeps the layout in sync with the WGSL struct)
    pub _padding: [u32; 7],
}
impl ComputeEnergyKernel {
/// Create a new compute energy kernel
pub fn new(device: &Device) -> GpuResult<Self> {
let shader = device.create_shader_module(ShaderModuleDescriptor {
label: Some("compute_energy"),
source: ShaderSource::Wgsl(shaders::COMPUTE_ENERGY.into()),
});
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("compute_energy_bind_group_layout"),
entries: &[
// Params uniform
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Input energies
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Output partial sums
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("compute_energy_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let main_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("compute_energy_main_pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let final_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("compute_energy_final_pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("final_reduce"),
compilation_options: Default::default(),
cache: None,
});
Ok(Self {
main_pipeline,
final_pipeline,
bind_group_layout,
})
}
/// Create a bind group for execution
pub fn create_bind_group(
&self,
device: &Device,
params_buffer: &GpuBuffer,
input_buffer: &GpuBuffer,
output_buffer: &GpuBuffer,
) -> BindGroup {
device.create_bind_group(&BindGroupDescriptor {
label: Some("compute_energy_bind_group"),
layout: &self.bind_group_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: params_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: input_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: output_buffer.buffer.as_entire_binding(),
},
],
})
}
/// Create a bind group using raw wgpu buffers (for pre-allocated buffer optimization)
pub fn create_bind_group_raw(
&self,
device: &Device,
params_buffer: &wgpu::Buffer,
input_buffer: &wgpu::Buffer,
output_buffer: &wgpu::Buffer,
) -> BindGroup {
device.create_bind_group(&BindGroupDescriptor {
label: Some("compute_energy_bind_group_raw"),
layout: &self.bind_group_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: params_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: input_buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: output_buffer.as_entire_binding(),
},
],
})
}
    /// Get the main reduction pipeline (entry point `main`): first pass,
    /// reduces the input values into per-workgroup partial sums.
    pub fn main_pipeline(&self) -> &ComputePipeline {
        &self.main_pipeline
    }
    /// Get the final reduction pipeline (entry point `final_reduce`):
    /// second pass, collapses the partial sums from the first pass.
    pub fn final_pipeline(&self) -> &ComputePipeline {
        &self.final_pipeline
    }
/// Calculate number of workgroups for first pass
pub fn workgroup_count(num_elements: u32) -> u32 {
// One element per thread, 256 threads per workgroup
(num_elements + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
}
}
/// Sheaf attention kernel
///
/// Wraps the single-pass attention shader (`sheaf_attention.wgsl`, entry
/// point `compute_attention_single_pass`), which computes per-edge
/// attention weights from edge energies.
pub struct SheafAttentionKernel {
    // Pipeline for the one-shot attention computation over all edges.
    single_pass_pipeline: ComputePipeline,
    // Layout: params, edges, edge energies, attention output, node exp sums.
    bind_group_layout: BindGroupLayout,
}
/// Attention weight output
///
/// Per-edge record written by the attention shader; 32 bytes, `#[repr(C)]`.
/// Per the module docs, attention follows A_ij = exp(-beta * E_ij) / Z,
/// normalized per node via the exp-sum buffer.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionWeight {
    /// Index of the edge this weight belongs to
    pub edge_idx: u32,
    /// Source node index of the edge
    pub source_idx: u32,
    /// Target node index of the edge
    pub target_idx: u32,
    /// Unnormalized attention score (pre-normalization)
    pub raw_score: f32,
    /// Normalized attention weight
    pub attention: f32,
    /// Padding to 32 bytes
    pub _padding: [u32; 3],
}
impl SheafAttentionKernel {
/// Create a new sheaf attention kernel
pub fn new(device: &Device) -> GpuResult<Self> {
let shader = device.create_shader_module(ShaderModuleDescriptor {
label: Some("sheaf_attention"),
source: ShaderSource::Wgsl(shaders::SHEAF_ATTENTION.into()),
});
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("sheaf_attention_bind_group_layout"),
entries: &[
// Params
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Edges
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Edge energies
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Attention weights output
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Node exp sums (for normalization)
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("sheaf_attention_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let single_pass_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("sheaf_attention_single_pass_pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("compute_attention_single_pass"),
compilation_options: Default::default(),
cache: None,
});
Ok(Self {
single_pass_pipeline,
bind_group_layout,
})
}
/// Create a bind group
pub fn create_bind_group(
&self,
device: &Device,
params_buffer: &GpuBuffer,
edges_buffer: &GpuBuffer,
edge_energies_buffer: &GpuBuffer,
attention_weights_buffer: &GpuBuffer,
node_exp_sums_buffer: &GpuBuffer,
) -> BindGroup {
device.create_bind_group(&BindGroupDescriptor {
label: Some("sheaf_attention_bind_group"),
layout: &self.bind_group_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: params_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: edges_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: edge_energies_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: attention_weights_buffer.buffer.as_entire_binding(),
},
BindGroupEntry {
binding: 4,
resource: node_exp_sums_buffer.buffer.as_entire_binding(),
},
],
})
}
    /// Get the single-pass pipeline (entry point `compute_attention_single_pass`).
    pub fn pipeline(&self) -> &ComputePipeline {
        &self.single_pass_pipeline
    }
/// Calculate workgroup count
pub fn workgroup_count(num_edges: u32) -> u32 {
(num_edges + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
}
}
/// Token routing kernel
///
/// Wraps the `route_tokens` entry point of `token_routing.wgsl`, which
/// (per the module docs) assigns tokens to processing lanes in parallel
/// based on energy thresholds.
pub struct TokenRoutingKernel {
    // Pipeline for the `route_tokens` entry point.
    route_pipeline: ComputePipeline,
    // Layout with nine bindings: params, six read-only inputs, two outputs.
    bind_group_layout: BindGroupLayout,
}
/// Token input
///
/// GPU-side input record for routing; 16 bytes, `#[repr(C)]`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Token {
    /// Unique token identifier, echoed back in `RoutingDecision::token_id`
    pub token_id: u32,
    /// Index of the graph node this token refers to
    pub node_idx: u32,
    /// Action type discriminant (encoding defined by token_routing.wgsl)
    pub action_type: u32,
    /// Priority value (semantics defined by the routing shader)
    pub priority: f32,
}
/// Routing decision output
///
/// Per-token result written by the routing shader; 32 bytes, `#[repr(C)]`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RoutingDecision {
    /// Token this decision applies to
    pub token_id: u32,
    /// Lane the token was assigned to (4 lanes total; see `LaneStats`)
    pub assigned_lane: u32,
    /// Local energy at the token's node
    pub local_energy: f32,
    /// Confidence of the routing decision
    pub confidence: f32,
    /// Escalation reason code (encoding defined by token_routing.wgsl)
    pub escalation_reason: u32,
    /// Count of high-energy incident edges
    /// (threshold presumably set via the params uniform — confirm in shader)
    pub num_high_energy_edges: u32,
    /// Maximum energy among the token's incident edges
    pub max_edge_energy: f32,
    /// Padding to 32 bytes
    pub _padding: u32,
}
/// Lane statistics
///
/// Aggregate per-lane output of the routing shader; fixed at four lanes,
/// padded to 64 bytes, `#[repr(C)]`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LaneStats {
    /// Number of tokens assigned to each of the four lanes
    pub lane_counts: [u32; 4],
    /// Accumulated energy per lane (as computed by the shader)
    pub total_energy_per_lane: [f32; 4],
    /// Padding to 64 bytes
    pub _padding: [u32; 8],
}
impl TokenRoutingKernel {
/// Create a new token routing kernel
pub fn new(device: &Device) -> GpuResult<Self> {
let shader = device.create_shader_module(ShaderModuleDescriptor {
label: Some("token_routing"),
source: ShaderSource::Wgsl(shaders::TOKEN_ROUTING.into()),
});
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("token_routing_bind_group_layout"),
entries: &[
// Params
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Tokens
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Local energies
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Edge energies
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Node edge counts
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Node edge offsets
BindGroupLayoutEntry {
binding: 5,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Node edges
BindGroupLayoutEntry {
binding: 6,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Routing decisions output
BindGroupLayoutEntry {
binding: 7,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// Lane stats output
BindGroupLayoutEntry {
binding: 8,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("token_routing_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let route_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("token_routing_pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("route_tokens"),
compilation_options: Default::default(),
cache: None,
});
Ok(Self {
route_pipeline,
bind_group_layout,
})
}
    /// Get the routing pipeline (entry point `route_tokens`).
    pub fn pipeline(&self) -> &ComputePipeline {
        &self.route_pipeline
    }
    /// Get the bind group layout for this kernel's nine buffer bindings.
    pub fn bind_group_layout(&self) -> &BindGroupLayout {
        &self.bind_group_layout
    }
/// Calculate workgroup count
pub fn workgroup_count(num_tokens: u32) -> u32 {
(num_tokens + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
}
}

View File

@@ -0,0 +1,156 @@
//! GPU acceleration module for Prime-Radiant coherence engine.
//!
//! This module provides GPU-accelerated computation using wgpu for:
//! - Parallel residual calculations across large graphs
//! - Matrix operations for restriction maps
//! - Energy aggregation with atomic operations
//! - Spectral analysis via power iteration
//!
//! # Architecture
//!
//! ```text
//! +------------------+ +------------------+ +------------------+
//! | GpuDevice |---->| GpuBuffer |---->| GpuDispatcher |
//! | (Init/Queue) | | (Alloc/Transfer)| | (Kernels/Sync) |
//! +------------------+ +------------------+ +------------------+
//! | | |
//! v v v
//! +------------------+ +------------------+ +------------------+
//! | Instance/Adapter | | BufferPool | | PipelineCache |
//! | Device/Queue | | Read/Write | | BindGroups |
//! +------------------+ +------------------+ +------------------+
//! ```
//!
//! # Feature Flag
//!
//! This module requires the `gpu` feature flag:
//! ```toml
//! [dependencies]
//! prime-radiant = { version = "0.1", features = ["gpu"] }
//! ```
//!
//! # Example
//!
//! ```rust,ignore
//! use prime_radiant::gpu::{GpuDevice, GpuBuffer, GpuDispatcher, ComputePipeline};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Initialize GPU device
//! let device = GpuDevice::new().await?;
//!
//! // Create storage buffer with data
//! let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
//! let input_buffer = GpuBuffer::new_storage(device.device(), &input_data, false);
//!
//! // Create output buffer
//! let output_buffer = GpuBuffer::new_storage_uninit::<f32>(
//! device.device(),
//! input_data.len(),
//! true,
//! );
//!
//! // Create compute pipeline
//! let pipeline = ComputePipeline::from_shader(
//! device.device(),
//! include_str!("shaders/compute_residuals.wgsl"),
//! "main",
//! &[BindingDesc::storage_readonly(), BindingDesc::storage_readwrite()],
//! )?;
//!
//! // Create dispatcher and execute
//! let dispatcher = GpuDispatcher::new(Arc::new(device));
//! let bind_group = pipeline.create_bind_group(
//! dispatcher.device().device(),
//! &[&input_buffer, &output_buffer],
//! )?;
//! dispatcher.dispatch(&pipeline, &bind_group, [4, 1, 1]).await?;
//!
//! Ok(())
//! }
//! ```
//!
//! # GPU Kernels
//!
//! The following WGSL compute shaders are implemented:
//!
//! 1. **compute_residuals.wgsl** - Parallel residual computation for all edges
//! 2. **compute_energy.wgsl** - Parallel energy aggregation with tree reduction
//! 3. **sheaf_attention.wgsl** - Batched attention: A_ij = exp(-beta * E_ij) / Z
//! 4. **token_routing.wgsl** - Parallel lane assignment based on energy thresholds
//!
//! # Performance Targets
//!
//! | Operation | Target | Notes |
//! |-----------|--------|-------|
//! | Buffer allocation | < 1ms | Pooled for hot paths |
//! | Kernel dispatch | < 100us | Excludes GPU execution |
//! | Residual (10K edges) | < 1ms | GPU parallel |
//! | Energy aggregation | < 500us | Atomic reduction |
mod buffer;
mod device;
mod dispatch;
mod engine;
mod error;
mod kernels;
mod pipeline;
// Core exports
pub use buffer::{
BufferKey, BufferUsage, BufferUsageFlags, GpuBuffer, GpuBufferManager, GpuBufferPool,
};
pub use device::{GpuDevice, GpuDeviceInfo, GpuDeviceOptions};
pub use dispatch::{DispatchBuilder, DispatchConfig, GpuDispatcher};
pub use error::{GpuError, GpuResult};
pub use pipeline::{BindingDesc, BindingType, ComputePipeline, PipelineCache};
// Re-export buffer types
pub use buffer::{GpuEdge, GpuNodeState, GpuParams, GpuRestrictionMap};
// Re-export engine types
pub use engine::{GpuCapabilities, GpuCoherenceEnergy, GpuCoherenceEngine, GpuConfig};
/// Synchronous API for GPU coherence engine (uses pollster)
///
/// Re-exports the blocking wrappers from `engine::sync` so callers without
/// an async runtime can use the engine.
pub mod sync {
    pub use super::engine::sync::*;
}
// Re-export kernel types
pub use kernels::{
AttentionWeight, ComputeEnergyKernel, ComputeResidualsKernel, EnergyParams, LaneStats,
RoutingDecision, SheafAttentionKernel, Token, TokenRoutingKernel,
};
/// Default workgroup size for compute shaders
///
/// Must stay in sync with the `@workgroup_size(256)` attributes in the
/// WGSL kernels (see also [`workgroup::SIZE_1D`]).
pub const DEFAULT_WORKGROUP_SIZE: u32 = 256;
/// Maximum buffer size for a single allocation (256MB)
pub const MAX_BUFFER_SIZE: u64 = 256 * 1024 * 1024;
/// Default pool capacity for buffer reuse
pub const DEFAULT_POOL_CAPACITY: usize = 32;
/// Shader source code embedded at compile time
///
/// Each constant holds the full WGSL text of one kernel; the matching
/// kernel constructor compiles it into a shader module at runtime.
pub mod shaders {
    /// Compute residuals shader for parallel edge residual computation
    pub const COMPUTE_RESIDUALS: &str = include_str!("shaders/compute_residuals.wgsl");
    /// Compute energy shader for parallel reduction
    pub const COMPUTE_ENERGY: &str = include_str!("shaders/compute_energy.wgsl");
    /// Sheaf attention shader for attention weight computation
    pub const SHEAF_ATTENTION: &str = include_str!("shaders/sheaf_attention.wgsl");
    /// Token routing shader for lane assignment
    pub const TOKEN_ROUTING: &str = include_str!("shaders/token_routing.wgsl");
}
/// GPU workgroup size constants
pub mod workgroup {
    /// Default workgroup size for 1D compute
    ///
    /// Must match the `@workgroup_size(256)` attribute in the WGSL kernels;
    /// used by the kernels' `workgroup_count` helpers.
    pub const SIZE_1D: u32 = 256;
    /// Default workgroup size for 2D compute (x dimension)
    pub const SIZE_2D_X: u32 = 16;
    /// Default workgroup size for 2D compute (y dimension)
    pub const SIZE_2D_Y: u32 = 16;
    /// Maximum state vector dimension for GPU kernels
    pub const MAX_STATE_DIM: u32 = 512;
}

View File

@@ -0,0 +1,513 @@
//! Compute pipeline management for GPU operations.
//!
//! This module handles shader compilation, pipeline creation, and bind group
//! management for GPU compute operations.
use dashmap::DashMap;
use std::sync::Arc;
use tracing::{debug, info};
use wgpu::{Device, ShaderModule};
use super::buffer::GpuBuffer;
use super::error::{GpuError, GpuResult};
use super::DEFAULT_WORKGROUP_SIZE;
/// Type of binding in a compute shader
///
/// A simplified, buffer-only subset of `wgpu::BindingType`; converted to
/// the wgpu form when building a pipeline's bind group layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BindingType {
    /// Storage buffer (read-only) — `var<storage, read>` in WGSL
    StorageReadonly,
    /// Storage buffer (read-write) — `var<storage, read_write>` in WGSL
    StorageReadWrite,
    /// Uniform buffer — `var<uniform>` in WGSL
    Uniform,
}
impl BindingType {
    /// Convert to wgpu binding type.
    ///
    /// All three variants share `has_dynamic_offset: false` and no minimum
    /// binding size; only the buffer binding type differs.
    fn to_wgpu(&self) -> wgpu::BindingType {
        let ty = match self {
            Self::StorageReadonly => wgpu::BufferBindingType::Storage { read_only: true },
            Self::StorageReadWrite => wgpu::BufferBindingType::Storage { read_only: false },
            Self::Uniform => wgpu::BufferBindingType::Uniform,
        };
        wgpu::BindingType::Buffer {
            ty,
            has_dynamic_offset: false,
            min_binding_size: None,
        }
    }
}
/// Description of a binding in a compute shader
///
/// Used to declare a pipeline's layout; a descriptor's position in the
/// slice passed to `ComputePipeline` determines its binding index.
#[derive(Debug, Clone)]
pub struct BindingDesc {
    /// Binding type
    pub binding_type: BindingType,
    /// Optional label for debugging
    pub label: Option<String>,
}
impl BindingDesc {
/// Create a storage read-only binding
pub fn storage_readonly() -> Self {
Self {
binding_type: BindingType::StorageReadonly,
label: None,
}
}
/// Create a storage read-write binding
pub fn storage_readwrite() -> Self {
Self {
binding_type: BindingType::StorageReadWrite,
label: None,
}
}
/// Create a uniform binding
pub fn uniform() -> Self {
Self {
binding_type: BindingType::Uniform,
label: None,
}
}
/// Add a label to the binding
pub fn with_label(mut self, label: impl Into<String>) -> Self {
self.label = Some(label.into());
self
}
}
/// Compute pipeline wrapper
///
/// Bundles the wgpu pipeline with its bind group layout and the metadata
/// (workgroup size, entry point, binding count) needed to dispatch it and
/// to validate bind groups.
pub struct ComputePipeline {
    // Compiled wgpu compute pipeline.
    pipeline: wgpu::ComputePipeline,
    // Layout used when creating bind groups for this pipeline.
    bind_group_layout: wgpu::BindGroupLayout,
    // [x, y, z] threads per workgroup; drives calculate_workgroups*.
    workgroup_size: [u32; 3],
    // Shader entry point name (exposed via entry_point()).
    entry_point: String,
    // Expected number of bindings; checked in create_bind_group.
    binding_count: usize,
}
impl ComputePipeline {
    /// Create a new compute pipeline from shader source.
    ///
    /// Uses [`DEFAULT_WORKGROUP_SIZE`] for the x dimension (y = z = 1);
    /// use [`Self::from_shader_with_workgroup_size`] to override it.
    ///
    /// # Arguments
    ///
    /// * `device` - The wgpu device
    /// * `shader_source` - WGSL shader source code
    /// * `entry_point` - Entry point function name
    /// * `bindings` - Binding descriptions, in binding-index order
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let pipeline = ComputePipeline::from_shader(
    ///     &device,
    ///     r#"
    ///     @group(0) @binding(0) var<storage, read> input: array<f32>;
    ///     @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    ///
    ///     @compute @workgroup_size(256)
    ///     fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    ///         output[id.x] = input[id.x] * 2.0;
    ///     }
    ///     "#,
    ///     "main",
    ///     &[BindingDesc::storage_readonly(), BindingDesc::storage_readwrite()],
    /// );
    /// ```
    pub fn from_shader(
        device: &Device,
        shader_source: &str,
        entry_point: &str,
        bindings: &[BindingDesc],
    ) -> GpuResult<Self> {
        Self::from_shader_with_workgroup_size(
            device,
            shader_source,
            entry_point,
            bindings,
            [DEFAULT_WORKGROUP_SIZE, 1, 1],
        )
    }
    /// Create a pipeline with custom workgroup size.
    ///
    /// `workgroup_size` is `[x, y, z]` threads per workgroup and should
    /// match the shader's `@workgroup_size(...)` attribute so that the
    /// `calculate_workgroups*` helpers cover the data correctly.
    pub fn from_shader_with_workgroup_size(
        device: &Device,
        shader_source: &str,
        entry_point: &str,
        bindings: &[BindingDesc],
        workgroup_size: [u32; 3],
    ) -> GpuResult<Self> {
        debug!(
            "Creating compute pipeline with entry point '{}' and {} bindings",
            entry_point,
            bindings.len()
        );
        // Create shader module
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("compute_shader"),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });
        Self::from_module(device, &shader, entry_point, bindings, workgroup_size)
    }
/// Create a pipeline from a pre-compiled shader module.
pub fn from_module(
device: &Device,
shader: &ShaderModule,
entry_point: &str,
bindings: &[BindingDesc],
workgroup_size: [u32; 3],
) -> GpuResult<Self> {
// Create bind group layout entries
let layout_entries: Vec<wgpu::BindGroupLayoutEntry> = bindings
.iter()
.enumerate()
.map(|(i, desc)| wgpu::BindGroupLayoutEntry {
binding: i as u32,
visibility: wgpu::ShaderStages::COMPUTE,
ty: desc.binding_type.to_wgpu(),
count: None,
})
.collect();
// Create bind group layout
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("compute_bind_group_layout"),
entries: &layout_entries,
});
// Create pipeline layout
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("compute_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
// Create compute pipeline
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("compute_pipeline"),
layout: Some(&pipeline_layout),
module: shader,
entry_point: Some(entry_point),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
Ok(Self {
pipeline,
bind_group_layout,
workgroup_size,
entry_point: entry_point.to_string(),
binding_count: bindings.len(),
})
}
    /// Create a bind group for this pipeline.
    ///
    /// # Arguments
    ///
    /// * `device` - The wgpu device
    /// * `buffers` - Buffers to bind, in binding-index order
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidBindingCount`] if the number of buffers
    /// doesn't match the pipeline's binding count. (This method returns an
    /// error rather than panicking.)
    pub fn create_bind_group(
        &self,
        device: &Device,
        buffers: &[&GpuBuffer],
    ) -> GpuResult<wgpu::BindGroup> {
        if buffers.len() != self.binding_count {
            return Err(GpuError::InvalidBindingCount {
                expected: self.binding_count,
                actual: buffers.len(),
            });
        }
        let entries: Vec<wgpu::BindGroupEntry> = buffers
            .iter()
            .enumerate()
            .map(|(i, buffer)| buffer.binding(i as u32))
            .collect();
        Ok(device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("compute_bind_group"),
            layout: &self.bind_group_layout,
            entries: &entries,
        }))
    }
    /// Get the underlying wgpu pipeline
    pub fn pipeline(&self) -> &wgpu::ComputePipeline {
        &self.pipeline
    }
    /// Get the bind group layout used by [`Self::create_bind_group`]
    pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
        &self.bind_group_layout
    }
    /// Get the workgroup size as `[x, y, z]` threads per workgroup
    pub fn workgroup_size(&self) -> [u32; 3] {
        self.workgroup_size
    }
    /// Get the shader entry point name this pipeline was compiled with
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }
    /// Get the number of bindings expected by [`Self::create_bind_group`]
    pub fn binding_count(&self) -> usize {
        self.binding_count
    }
/// Calculate workgroup count for a given data size.
pub fn calculate_workgroups(&self, data_size: u32) -> [u32; 3] {
let x = (data_size + self.workgroup_size[0] - 1) / self.workgroup_size[0];
[x, 1, 1]
}
/// Calculate workgroup count for 2D data.
pub fn calculate_workgroups_2d(&self, width: u32, height: u32) -> [u32; 3] {
let x = (width + self.workgroup_size[0] - 1) / self.workgroup_size[0];
let y = (height + self.workgroup_size[1] - 1) / self.workgroup_size[1];
[x, y, 1]
}
/// Calculate workgroup count for 3D data.
pub fn calculate_workgroups_3d(&self, width: u32, height: u32, depth: u32) -> [u32; 3] {
let x = (width + self.workgroup_size[0] - 1) / self.workgroup_size[0];
let y = (height + self.workgroup_size[1] - 1) / self.workgroup_size[1];
let z = (depth + self.workgroup_size[2] - 1) / self.workgroup_size[2];
[x, y, z]
}
}
/// Cache for compute pipelines
///
/// Keyed by pipeline name; concurrent access is handled by `DashMap` and
/// cached pipelines are shared via `Arc`.
pub struct PipelineCache {
    // Device used to compile pipelines on cache miss.
    device: Arc<Device>,
    // name -> compiled pipeline.
    pipelines: DashMap<String, Arc<ComputePipeline>>,
}
impl PipelineCache {
    /// Create a new, empty pipeline cache.
    ///
    /// The device handle is retained and used to compile pipelines on
    /// cache misses in [`Self::get_or_create`].
    pub fn new(device: Arc<Device>) -> Self {
        Self {
            device,
            pipelines: DashMap::new(),
        }
    }
/// Get or create a pipeline.
///
/// # Arguments
///
/// * `name` - Unique name for the pipeline
/// * `shader_source` - WGSL shader source
/// * `entry_point` - Entry point function name
/// * `bindings` - Binding descriptions
pub fn get_or_create(
&self,
name: &str,
shader_source: &str,
entry_point: &str,
bindings: &[BindingDesc],
) -> GpuResult<Arc<ComputePipeline>> {
if let Some(pipeline) = self.pipelines.get(name) {
return Ok(Arc::clone(&pipeline));
}
info!("Creating and caching pipeline: {}", name);
let pipeline =
ComputePipeline::from_shader(&self.device, shader_source, entry_point, bindings)?;
let pipeline = Arc::new(pipeline);
self.pipelines
.insert(name.to_string(), Arc::clone(&pipeline));
Ok(pipeline)
}
    /// Get a cached pipeline by name.
    ///
    /// Returns a clone of the cached `Arc`, or `None` if absent.
    pub fn get(&self, name: &str) -> Option<Arc<ComputePipeline>> {
        self.pipelines.get(name).map(|p| Arc::clone(&p))
    }
    /// Check if a pipeline exists in cache.
    pub fn contains(&self, name: &str) -> bool {
        self.pipelines.contains_key(name)
    }
    /// Remove a pipeline from cache.
    ///
    /// Returns the removed pipeline if it was present.
    pub fn remove(&self, name: &str) -> Option<Arc<ComputePipeline>> {
        self.pipelines.remove(name).map(|(_, p)| p)
    }
    /// Clear all cached pipelines.
    pub fn clear(&self) {
        self.pipelines.clear();
    }
    /// Get the number of cached pipelines.
    pub fn len(&self) -> usize {
        self.pipelines.len()
    }
    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.pipelines.is_empty()
    }
    /// List all cached pipeline names.
    ///
    /// Order is unspecified (iteration order of the underlying map).
    pub fn names(&self) -> Vec<String> {
        self.pipelines.iter().map(|e| e.key().clone()).collect()
    }
}
/// Pre-defined shaders for common coherence operations
///
/// Minimal self-contained WGSL kernels, simpler than the full kernels in
/// `shaders/`: the residual kernel treats restriction maps as identity,
/// and the reduction emits per-workgroup partial sums needing a second pass.
pub mod shaders {
    /// WGSL shader for computing residuals
    ///
    /// One invocation per edge; writes `weight * ||x_u - x_v||^2`.
    /// NOTE(review): binding 2 (`restriction`) is declared but never read —
    /// the restriction map is effectively the identity, as the inline
    /// comment states.
    pub const RESIDUAL_COMPUTE: &str = r#"
// Node states: [node_count, dim]
@group(0) @binding(0) var<storage, read> node_states: array<f32>;
// Edge info: [edge_count, 4] - source_idx, target_idx, weight, padding
@group(0) @binding(1) var<storage, read> edges: array<vec4<f32>>;
// Restriction map (identity for simplicity): [dim, dim]
@group(0) @binding(2) var<storage, read> restriction: array<f32>;
// Output residuals: [edge_count]
@group(0) @binding(3) var<storage, read_write> residuals: array<f32>;
// Params: [dim, node_count, edge_count, 0]
@group(0) @binding(4) var<uniform> params: vec4<u32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let edge_idx = id.x;
    let edge_count = params.z;
    let dim = params.x;
    if (edge_idx >= edge_count) {
        return;
    }
    let edge = edges[edge_idx];
    let source_idx = u32(edge.x);
    let target_idx = u32(edge.y);
    let weight = edge.z;
    // Compute residual = ||rho_u(x_u) - rho_v(x_v)||^2
    var residual: f32 = 0.0;
    for (var d: u32 = 0u; d < dim; d = d + 1u) {
        let source_val = node_states[source_idx * dim + d];
        let target_val = node_states[target_idx * dim + d];
        let diff = source_val - target_val;
        residual = residual + diff * diff;
    }
    residuals[edge_idx] = weight * residual;
}
"#;
    /// WGSL shader for parallel reduction (sum)
    ///
    /// Tree reduction in workgroup shared memory; each workgroup writes one
    /// partial sum to `output[workgroup_id.x]`, so a second pass (or a CPU
    /// sum over the partials) produces the final total.
    pub const REDUCE_SUM: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> count: u32;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let gid = global_id.x;
    // Load data into shared memory
    if (gid < count) {
        shared_data[tid] = input[gid];
    } else {
        shared_data[tid] = 0.0;
    }
    workgroupBarrier();
    // Parallel reduction
    for (var s: u32 = 128u; s > 0u; s = s >> 1u) {
        if (tid < s) {
            shared_data[tid] = shared_data[tid] + shared_data[tid + s];
        }
        workgroupBarrier();
    }
    // Write result
    if (tid == 0u) {
        output[workgroup_id.x] = shared_data[0];
    }
}
"#;
    /// WGSL shader for matrix-vector multiplication
    ///
    /// One invocation per output row: `result[row] = dot(matrix[row, :], vector)`.
    pub const MATVEC: &str = r#"
@group(0) @binding(0) var<storage, read> matrix: array<f32>;
@group(0) @binding(1) var<storage, read> vector: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
// params: [rows, cols, 0, 0]
@group(0) @binding(3) var<uniform> params: vec4<u32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let row = id.x;
    let rows = params.x;
    let cols = params.y;
    if (row >= rows) {
        return;
    }
    var sum: f32 = 0.0;
    for (var c: u32 = 0u; c < cols; c = c + 1u) {
        sum = sum + matrix[row * cols + c] * vector[c];
    }
    result[row] = sum;
}
"#;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binding_desc() {
        // Each convenience constructor must yield its matching variant.
        let cases = [
            (BindingDesc::storage_readonly(), BindingType::StorageReadonly),
            (BindingDesc::storage_readwrite(), BindingType::StorageReadWrite),
            (BindingDesc::uniform(), BindingType::Uniform),
        ];
        for (desc, expected) in cases {
            assert_eq!(desc.binding_type, expected);
        }
    }

    #[test]
    fn test_binding_with_label() {
        let labeled = BindingDesc::storage_readonly().with_label("input_buffer");
        assert_eq!(labeled.label.as_deref(), Some("input_buffer"));
    }
}

View File

@@ -0,0 +1,134 @@
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Energy Computation
// =============================================================================
//
// Parallel reduction to compute total coherence energy:
// E(S) = sum(w_e * |r_e|^2)
//
// Uses a two-phase reduction strategy:
// 1. Local reduction within workgroups using shared memory
// 2. Global reduction across workgroup partial sums
// =============================================================================
// TYPE DEFINITIONS
// =============================================================================
struct EnergyParams {
num_elements: u32,
_padding0: u32,
_padding1: u32,
_padding2: u32,
_padding3: u32,
_padding4: u32,
_padding5: u32,
_padding6: u32,
}
const WORKGROUP_SIZE: u32 = 256u;
// =============================================================================
// BUFFER BINDINGS
// =============================================================================
// Layout matches Rust kernel bind group:
// binding 0: params (uniform)
// binding 1: input (storage, read) - edge energies or partial sums
// binding 2: output (storage, read_write) - partial sums or final result
/// Energy computation parameters
@group(0) @binding(0) var<uniform> params: EnergyParams;
/// Input values to reduce
@group(0) @binding(1) var<storage, read> input_values: array<f32>;
/// Output partial sums or final result
@group(0) @binding(2) var<storage, read_write> output_values: array<f32>;
// =============================================================================
// SHARED MEMORY
// =============================================================================
/// Shared memory for parallel reduction. Length must equal WORKGROUP_SIZE:
/// each thread owns slot `local_id.x` during the tree reduction.
var<workgroup> shared_data: array<f32, 256>;
// =============================================================================
// MAIN REDUCTION KERNEL
// =============================================================================
/// Phase 1: Reduce input values within workgroup
///
/// Each 256-thread workgroup sums up to 256 input values and writes a
/// single partial sum to output_values[workgroup_id.x]. Run final_reduce
/// afterwards to collapse the partial sums into one total.
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>
) {
let tid = local_id.x;
let gid = global_id.x;
let element_count = params.num_elements;
// Load element (or 0 if out of bounds)
// No early return here: every thread must reach the barriers below
// (WGSL requires barriers in uniform control flow).
var val: f32 = 0.0;
if (gid < element_count) {
val = input_values[gid];
}
// Store in shared memory
shared_data[tid] = val;
workgroupBarrier();
// Tree reduction with sequential addressing
// Stride halves each step; after the loop, slot 0 holds the workgroup sum.
for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride >>= 1u) {
if (tid < stride) {
shared_data[tid] += shared_data[tid + stride];
}
workgroupBarrier();
}
// Thread 0 writes the partial sum
if (tid == 0u) {
output_values[workgroup_id.x] = shared_data[0];
}
}
// =============================================================================
// FINAL REDUCTION PASS
// =============================================================================
/// Phase 2: Reduce partial sums to final total
/// Reads from input_values (the partial sums from phase 1)
/// Writes result to output_values[0]
///
/// Intended to run as a single workgroup: each thread first serially
/// accumulates every WORKGROUP_SIZE-th partial sum, then the workgroup
/// tree-reduces the 256 per-thread totals.
@compute @workgroup_size(256)
fn final_reduce(
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let tid = local_id.x;
let element_count = params.num_elements;
// Load partial sum from input (or 0 if out of bounds)
var sum: f32 = 0.0;
if (tid < element_count) {
sum = input_values[tid];
}
// Handle case where we have more partial sums than workgroup size
// (each thread strides through the input by WORKGROUP_SIZE)
var idx = tid + WORKGROUP_SIZE;
while (idx < element_count) {
sum += input_values[idx];
idx += WORKGROUP_SIZE;
}
shared_data[tid] = sum;
workgroupBarrier();
// Tree reduction
for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride >>= 1u) {
if (tid < stride) {
shared_data[tid] += shared_data[tid + stride];
}
workgroupBarrier();
}
// Write final result to output[0]
if (tid == 0u) {
output_values[0] = shared_data[0];
}
}

View File

@@ -0,0 +1,223 @@
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Residual Computation
// =============================================================================
//
// Computes sheaf Laplacian residuals: r_e = rho_source(x_source) - rho_target(x_target)
// and per-edge energy: E_e = w_e * ||r_e||^2
//
// Each thread processes one edge, computing the residual and squared norm.
// =============================================================================
// TYPE DEFINITIONS (must match Rust structs exactly)
// =============================================================================
struct GpuParams {
num_edges: u32,
num_nodes: u32,
state_dim: u32,
beta: f32,
threshold_lane0: f32,
threshold_lane1: f32,
threshold_lane2: f32,
store_residuals: u32, // 0 = skip storage (energy only), 1 = store residuals
}
struct GpuEdge {
source_idx: u32,
target_idx: u32,
weight: f32,
rho_source_idx: u32,
rho_target_idx: u32,
comparison_dim: u32,
_padding0: u32,
_padding1: u32,
}
struct GpuRestrictionMap {
map_type: u32, // 0=identity, 1=diagonal, 2=projection, 3=dense
input_dim: u32,
output_dim: u32,
data_offset: u32,
data_len: u32,
_padding0: u32,
_padding1: u32,
_padding2: u32,
}
const WORKGROUP_SIZE: u32 = 256u;
const MAP_IDENTITY: u32 = 0u;
const MAP_DIAGONAL: u32 = 1u;
const MAP_PROJECTION: u32 = 2u;
const MAP_DENSE: u32 = 3u;
// =============================================================================
// BUFFER BINDINGS (matches Rust kernel bind group layout)
// =============================================================================
// binding 0: params (uniform)
// binding 1: node_states (storage, read)
// binding 2: edges (storage, read)
// binding 3: restriction_maps (storage, read)
// binding 4: restriction_data (storage, read)
// binding 5: residuals (storage, read_write)
// binding 6: energies (storage, read_write)
@group(0) @binding(0) var<uniform> params: GpuParams;
// Flattened states: element d of node n is node_states[n * state_dim + d].
@group(0) @binding(1) var<storage, read> node_states: array<f32>;
@group(0) @binding(2) var<storage, read> edges: array<GpuEdge>;
@group(0) @binding(3) var<storage, read> restriction_maps: array<GpuRestrictionMap>;
// Concatenated payloads for all maps (addressed via data_offset/data_len).
@group(0) @binding(4) var<storage, read> restriction_data: array<f32>;
// Per-edge residual vectors; written only when params.store_residuals != 0.
@group(0) @binding(5) var<storage, read_write> residuals: array<f32>;
// Per-edge weighted energies E_e = w_e * ||r_e||^2; always written.
@group(0) @binding(6) var<storage, read_write> energies: array<f32>;
// =============================================================================
// HELPER FUNCTIONS
// =============================================================================
/// Apply restriction map `rho` to the node state stored at offset
/// `state_base` in `node_states`, returning the projected value for
/// output dimension `output_dim`.
///
/// Out-of-range dimensions yield 0.0, so callers may iterate up to the
/// edge's comparison dimension without extra bounds checks.
fn apply_restriction(
    rho: GpuRestrictionMap,
    state_base: u32,
    output_dim: u32
) -> f32 {
    switch(rho.map_type) {
        case MAP_IDENTITY: {
            // Identity: pass through the corresponding state element.
            if (output_dim < rho.output_dim && output_dim < params.state_dim) {
                return node_states[state_base + output_dim];
            }
            return 0.0;
        }
        case MAP_DIAGONAL: {
            // Diagonal: scale element d by diagonal entry d.
            // FIX: also bound by state_dim (consistent with the identity
            // case) so a diagonal longer than the state vector cannot read
            // past this node's state slice.
            if (output_dim < rho.data_len && output_dim < params.state_dim) {
                let scale = restriction_data[rho.data_offset + output_dim];
                return node_states[state_base + output_dim] * scale;
            }
            return 0.0;
        }
        case MAP_PROJECTION: {
            // Projection: payload holds the (float-encoded) source indices.
            if (output_dim < rho.data_len) {
                let idx = u32(restriction_data[rho.data_offset + output_dim]);
                if (idx < params.state_dim) {
                    return node_states[state_base + idx];
                }
            }
            return 0.0;
        }
        case MAP_DENSE, default: {
            // Dense: dot product of matrix row `output_dim` with the state.
            var result: f32 = 0.0;
            let row_offset = rho.data_offset + output_dim * rho.input_dim;
            for (var i = 0u; i < rho.input_dim && i < params.state_dim; i++) {
                result += restriction_data[row_offset + i] * node_states[state_base + i];
            }
            return result;
        }
    }
    // Unreachable (switch has a default), kept so every path returns.
    return 0.0;
}
// =============================================================================
// MAIN ENTRY POINT
// =============================================================================
// One thread per edge: computes r_e = rho_source(x_source) - rho_target(x_target)
// and the weighted energy E_e = w_e * ||r_e||^2. Energies are always written;
// residual vectors only when params.store_residuals != 0.
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let edge_idx = global_id.x;
// Bounds check
if (edge_idx >= params.num_edges) {
return;
}
// Get edge data
let edge = edges[edge_idx];
// Compute base offsets for source and target node states
let source_base = edge.source_idx * params.state_dim;
let target_base = edge.target_idx * params.state_dim;
// Get restriction maps
let rho_source = restriction_maps[edge.rho_source_idx];
let rho_target = restriction_maps[edge.rho_target_idx];
// Compute residual: r = rho_source(x_source) - rho_target(x_target)
// and accumulate squared norm
//
// OPTIMIZATION: Process 4 dimensions at a time using vec4 operations.
// This leverages GPU SIMD capabilities for ~4x throughput on high-dimensional
// state vectors. The dot(v, v) operation is particularly efficient on GPU.
var norm_sq: f32 = 0.0;
let comparison_dim = edge.comparison_dim;
// NOTE(review): this indexing assumes comparison_dim is identical for
// every edge; per-edge dims would make residual slices overlap — confirm
// against the Rust-side buffer layout.
let residual_base = edge_idx * comparison_dim;
// Calculate how many full vec4 iterations and remainder
let vec4_count = comparison_dim / 4u;
let remainder = comparison_dim % 4u;
// Process 4 dimensions at a time
var d = 0u;
for (var i = 0u; i < vec4_count; i++) {
// Load 4 source values via restriction maps
let source_vec = vec4<f32>(
apply_restriction(rho_source, source_base, d),
apply_restriction(rho_source, source_base, d + 1u),
apply_restriction(rho_source, source_base, d + 2u),
apply_restriction(rho_source, source_base, d + 3u)
);
// Load 4 target values via restriction maps
let target_vec = vec4<f32>(
apply_restriction(rho_target, target_base, d),
apply_restriction(rho_target, target_base, d + 1u),
apply_restriction(rho_target, target_base, d + 2u),
apply_restriction(rho_target, target_base, d + 3u)
);
// Compute residual vector (4 components at once)
let r_vec = source_vec - target_vec;
// Accumulate norm using dot product (very efficient on GPU - single instruction)
norm_sq += dot(r_vec, r_vec);
// Store residuals if requested (optional for energy-only computation)
// The 4-wide bounds check skips the whole quad if it would run past the
// buffer end; partial quads are never written.
if (params.store_residuals != 0u) {
let base_offset = residual_base + d;
if (base_offset + 3u < arrayLength(&residuals)) {
residuals[base_offset] = r_vec.x;
residuals[base_offset + 1u] = r_vec.y;
residuals[base_offset + 2u] = r_vec.z;
residuals[base_offset + 3u] = r_vec.w;
}
}
d += 4u;
}
// Handle remainder dimensions (0-3 elements)
for (var j = 0u; j < remainder; j++) {
let dim_idx = d + j;
let projected_source = apply_restriction(rho_source, source_base, dim_idx);
let projected_target = apply_restriction(rho_target, target_base, dim_idx);
let r = projected_source - projected_target;
norm_sq += r * r;
if (params.store_residuals != 0u) {
let offset = residual_base + dim_idx;
if (offset < arrayLength(&residuals)) {
residuals[offset] = r;
}
}
}
// Compute weighted energy: E_e = w_e * ||r_e||^2
let energy = edge.weight * norm_sq;
// Store per-edge energy
energies[edge_idx] = energy;
}

View File

@@ -0,0 +1,144 @@
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Sheaf Attention
// =============================================================================
//
// Energy-based sheaf attention: A_ij = softmax(-beta * E_ij)
//
// Attention weights are computed from coherence energy:
// - Low energy (coherent) edges get high attention
// - High energy (incoherent) edges get low attention
// =============================================================================
// TYPE DEFINITIONS
// =============================================================================
struct AttentionParams {
num_edges: u32,
num_nodes: u32,
beta: f32,
energy_threshold: f32,
use_sparse: u32,
_padding0: u32,
_padding1: u32,
_padding2: u32,
}
struct EdgeDescriptor {
source_idx: u32,
target_idx: u32,
weight: f32,
_padding: u32,
}
const WORKGROUP_SIZE: u32 = 256u;
const NEG_INF: f32 = -3.402823e+38;
const EPSILON: f32 = 1e-8;
// =============================================================================
// BUFFER BINDINGS
// =============================================================================
// Layout matches Rust kernel bind group:
// binding 0: params (uniform)
// binding 1: edges (storage, read)
// binding 2: edge_energies (storage, read)
// binding 3: attention_weights (storage, read_write)
// binding 4: node_exp_sums (storage, read_write)
/// Attention parameters
@group(0) @binding(0) var<uniform> params: AttentionParams;
/// Edge descriptors
@group(0) @binding(1) var<storage, read> edges: array<EdgeDescriptor>;
/// Edge energies from residual computation
@group(0) @binding(2) var<storage, read> edge_energies: array<f32>;
/// Output attention weights (one per edge)
@group(0) @binding(3) var<storage, read_write> attention_weights: array<f32>;
/// Per-node exponential sums for normalization
/// NOTE(review): no kernel in this file writes these sums (WGSL has no f32
/// atomics); they appear to be filled in host-side between the two passes —
/// confirm against the Rust driver.
@group(0) @binding(4) var<storage, read_write> node_exp_sums: array<f32>;
// =============================================================================
// SHARED MEMORY
// =============================================================================
/// Shared memory for parallel reduction
/// NOTE(review): not referenced by either entry point in this file.
var<workgroup> shared_data: array<f32, 256>;
// =============================================================================
// SINGLE-PASS ATTENTION COMPUTATION
// =============================================================================
/// Compute attention weights from edge energies
/// A_e = exp(-beta * E_e) (unnormalized)
/// Each workgroup processes multiple edges
///
/// Output is UNNORMALIZED; run normalize_attention after node_exp_sums has
/// been populated to complete the per-source-node softmax.
@compute @workgroup_size(256)
fn compute_attention_single_pass(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let edge_idx = global_id.x;
let num_edges = params.num_edges;
let beta = params.beta;
if (edge_idx >= num_edges) {
return;
}
// Get edge energy
let energy = edge_energies[edge_idx];
// Compute unnormalized attention weight
// For energy-based attention: A = exp(-beta * E)
// High energy (incoherent) -> low attention
// Low energy (coherent) -> high attention
var score = -beta * energy;
// Apply energy threshold masking for sparse attention
if (params.use_sparse == 1u && energy > params.energy_threshold) {
score = NEG_INF;
}
// Compute exp(score) - clamp to avoid overflow
// exp(80) is still finite in f32; exp(-80) underflows toward 0, which
// correctly zeroes out masked (NEG_INF) edges.
let clamped_score = clamp(score, -80.0, 80.0);
let exp_score = exp(clamped_score);
// Store unnormalized attention weight
attention_weights[edge_idx] = exp_score;
// Accumulate exp sum for source node (for later normalization)
// Note: This requires atomic operations for correctness in parallel
// For now, we store unnormalized weights; normalization done in separate pass
// NOTE(review): `edge` is unused because the atomicAdd below is disabled;
// the load presumably keeps the `edges` binding statically used by this
// entry point — confirm whether it can be dropped.
let edge = edges[edge_idx];
// atomicAdd(&node_exp_sums[edge.source_idx], exp_score);
// Note: WGSL doesn't have atomicAdd for f32, so we store for CPU normalization
}
// =============================================================================
// NORMALIZATION PASS
// =============================================================================
/// Second pass: scale each edge's unnormalized weight by the inverse of
/// its source node's exponential sum so that, per source node, outgoing
/// attention sums to 1.
@compute @workgroup_size(256)
fn normalize_attention(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let edge_idx = global_id.x;
    if (edge_idx >= params.num_edges) {
        return;
    }
    // Denominator: the exp-sum accumulated for this edge's source node,
    // clamped away from zero to keep the division finite.
    let source = edges[edge_idx].source_idx;
    let denom = max(node_exp_sums[source], EPSILON);
    attention_weights[edge_idx] = attention_weights[edge_idx] / denom;
}

View File

@@ -0,0 +1,471 @@
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Sparse Attention Mask
// =============================================================================
//
// Generate sparse attention masks from energy thresholds.
// Only edges with energy below threshold (coherent) are included.
//
// This enables efficient sparse attention where only meaningful
// (low-energy, coherent) connections are computed, dramatically
// reducing computation for large graphs.
//
// Output Formats:
// 1. Index list: Compact list of (row, col) pairs for valid edges
// 2. Dense mask: Full NxN boolean matrix (for small N)
// 3. CSR format: Compressed sparse row for efficient sparse matmul
//
// Optimizations:
// - Stream compaction for index list generation
// - Warp-level voting for efficient counting
// - Coalesced writes using shared memory staging
// =============================================================================
// TYPE DEFINITIONS
// =============================================================================
struct SparseMaskParams {
total_edges: u32,
coherence_threshold: f32,
max_edges: u32,
output_format: u32, // 0=indices, 1=dense, 2=csr
seq_len: u32,
batch_size: u32,
padding: array<u32, 2>,
}
struct EdgeIndex {
row: u32,
col: u32,
}
struct CSRPointers {
row_ptr: u32,
nnz: u32,
}
const WORKGROUP_SIZE: u32 = 256u;
const OUTPUT_INDICES: u32 = 0u;
const OUTPUT_DENSE: u32 = 1u;
const OUTPUT_CSR: u32 = 2u;
// =============================================================================
// BUFFER BINDINGS
// =============================================================================
/// Input edge energies (seq_len * seq_len per batch, or sparse)
@group(0) @binding(0) var<storage, read> edge_energies: array<f32>;
/// Output: sparse edge indices (for index format)
@group(0) @binding(1) var<storage, read_write> sparse_indices: array<EdgeIndex>;
/// Output: dense mask (for dense format), 32 edges packed per word.
/// FIX: declared as array<atomic<u32>> because generate_dense_mask and
/// combine_with_causal_mask update it via atomicOr, which WGSL only permits
/// on atomic types — a plain array<u32> fails validation there. atomic<u32>
/// has the same size/alignment as u32, so the host-side layout is unchanged.
@group(0) @binding(2) var<storage, read_write> dense_mask: array<atomic<u32>>;
/// Output: number of valid edges (atomic counter)
@group(0) @binding(3) var<storage, read_write> edge_count: atomic<u32>;
/// Mask parameters
@group(0) @binding(4) var<uniform> params: SparseMaskParams;
// =============================================================================
// SHARED MEMORY
// =============================================================================
/// Shared memory for stream compaction (per-thread keep flags)
var<workgroup> shared_valid: array<u32, 256>;
/// Prefix sum for compaction offsets
var<workgroup> shared_prefix: array<u32, 256>;
/// Staging buffer for coalesced writes
/// NOTE(review): not referenced by any kernel in this file.
var<workgroup> shared_indices: array<EdgeIndex, 256>;
/// Workgroup-level count of valid edges / base output offset
var<workgroup> workgroup_count: atomic<u32>;
// =============================================================================
// BASIC SPARSE MASK GENERATION
// =============================================================================
/// Generate the sparse mask as a compacted list of (row, col) indices.
///
/// Each thread classifies one edge (energy < threshold => kept), the
/// workgroup runs a Hillis-Steele prefix sum over the keep flags to obtain
/// stable local positions, and thread 0 reserves a contiguous range of the
/// global output via the `edge_count` atomic.
@compute @workgroup_size(256)
fn generate_sparse_indices(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let idx = global_id.x;
    let tid = local_id.x;
    let total_edges = params.total_edges;
    let threshold = params.coherence_threshold;
    let seq_len = params.seq_len;
    // Initialize the workgroup's base-offset slot (first thread only).
    if (tid == 0u) {
        atomicStore(&workgroup_count, 0u);
    }
    workgroupBarrier();
    // Classify this thread's edge; out-of-range threads contribute 0.
    var is_valid: u32 = 0u;
    var row: u32 = 0u;
    var col: u32 = 0u;
    if (idx < total_edges) {
        let energy = edge_energies[idx];
        is_valid = select(0u, 1u, energy < threshold);
        // Recover 2D coordinates from the linear index.
        row = idx / seq_len;
        col = idx % seq_len;
    }
    shared_valid[tid] = is_valid;
    workgroupBarrier();
    // Hillis-Steele inclusive prefix sum over the keep flags.
    shared_prefix[tid] = is_valid;
    workgroupBarrier();
    for (var offset = 1u; offset < WORKGROUP_SIZE; offset <<= 1u) {
        var val: u32 = 0u;
        if (tid >= offset) {
            val = shared_prefix[tid - offset];
        }
        workgroupBarrier();
        shared_prefix[tid] += val;
        workgroupBarrier();
    }
    // Last entry of the inclusive scan = this workgroup's total keep count.
    let total_valid = shared_prefix[WORKGROUP_SIZE - 1u];
    // Thread 0 reserves [base, base + total_valid) in the global output and
    // publishes the base through workgroup_count.
    if (tid == 0u && total_valid > 0u) {
        atomicStore(&workgroup_count, atomicAdd(&edge_count, total_valid));
    }
    workgroupBarrier();
    let global_offset = atomicLoad(&workgroup_count);
    // Scatter kept edges to their compacted positions.
    if (is_valid == 1u && idx < total_edges) {
        // Exclusive prefix = inclusive sum of the previous thread.
        // FIX: guarded with `if` instead of `select` — select evaluates both
        // operands, so `shared_prefix[tid - 1u]` underflowed for thread 0.
        var local_pos: u32 = 0u;
        if (tid > 0u) {
            local_pos = shared_prefix[tid - 1u];
        }
        let global_pos = global_offset + local_pos;
        if (global_pos < params.max_edges) {
            sparse_indices[global_pos] = EdgeIndex(row, col);
        }
    }
}
// =============================================================================
// DENSE MASK GENERATION
// =============================================================================
/// Generate dense boolean mask (packed as u32 bits)
/// Bit (idx % 32) of word (idx / 32) is set iff edge idx survives the
/// energy threshold. Bits are only ever set, never cleared, so the mask
/// buffer must be zeroed before dispatch.
@compute @workgroup_size(256)
fn generate_dense_mask(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let idx = global_id.x;
let total_edges = params.total_edges;
let threshold = params.coherence_threshold;
if (idx >= total_edges) {
return;
}
let energy = edge_energies[idx];
let is_valid = energy < threshold;
// Pack 32 boolean values per u32
let word_idx = idx / 32u;
let bit_idx = idx % 32u;
if (is_valid) {
// Atomic OR to set the bit
// (requires the dense_mask binding to be array<atomic<u32>>: WGSL only
// defines atomicOr on atomic types)
atomicOr(&dense_mask[word_idx], 1u << bit_idx);
}
}
/// Unpack dense mask bit
/// Returns true iff the bit for edge `idx` is set in the packed mask.
/// NOTE(review): storage-pointer function parameters are gated behind
/// naga's unrestricted_pointer_parameters capability — confirm this helper
/// (which has no callers in this file) compiles on the targeted backends.
fn is_edge_valid(dense_mask_ptr: ptr<storage, array<u32>, read>, idx: u32) -> bool {
let word_idx = idx / 32u;
let bit_idx = idx % 32u;
return ((*dense_mask_ptr)[word_idx] & (1u << bit_idx)) != 0u;
}
// =============================================================================
// CSR FORMAT GENERATION
// =============================================================================
// Three-phase construction: count_edges_per_row -> compute_row_pointers ->
// populate_csr_data. After phase 2, row_counts is rewritten to each row's
// start offset so phase 3 can reuse it as a per-row write cursor.
/// CSR row pointers
@group(1) @binding(0) var<storage, read_write> csr_row_ptr: array<u32>;
/// CSR column indices
@group(1) @binding(1) var<storage, read_write> csr_col_idx: array<u32>;
/// CSR values (attention weights or energies)
@group(1) @binding(2) var<storage, read_write> csr_values: array<f32>;
/// Per-row counters for CSR construction
@group(1) @binding(3) var<storage, read_write> row_counts: array<atomic<u32>>;
/// Phase 1: tally, per row, how many edges survive the energy threshold.
@compute @workgroup_size(256)
fn count_edges_per_row(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let idx = global_id.x;
    if (idx >= params.total_edges) {
        return;
    }
    // Keep only coherent (low-energy) edges; bump the owning row's counter.
    if (edge_energies[idx] < params.coherence_threshold) {
        atomicAdd(&row_counts[idx / params.seq_len], 1u);
    }
}
/// Phase 2: turn per-row counts into CSR row pointers (exclusive scan).
///
/// FIX: the previous version returned early for rows >= seq_len and then
/// executed `workgroupBarrier()`, which WGSL forbids (barriers must be
/// reached in uniform control flow). All threads now participate in the
/// scan; out-of-range rows contribute zero and skip only the final writes.
///
/// NOTE(review): the scan is shared-memory only, so results are correct
/// only when seq_len <= WORKGROUP_SIZE (256) — confirm against dispatch.
@compute @workgroup_size(256)
fn compute_row_pointers(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let row = global_id.x;
    let tid = local_id.x;
    let seq_len = params.seq_len;
    // Load this row's count (0 for out-of-range rows) into shared memory.
    var count: u32 = 0u;
    if (row < seq_len) {
        count = atomicLoad(&row_counts[row]);
    }
    shared_prefix[tid] = count;
    workgroupBarrier();
    // Hillis-Steele inclusive prefix sum.
    for (var offset = 1u; offset < WORKGROUP_SIZE; offset <<= 1u) {
        var val: u32 = 0u;
        if (tid >= offset) {
            val = shared_prefix[tid - offset];
        }
        workgroupBarrier();
        shared_prefix[tid] += val;
        workgroupBarrier();
    }
    // No barriers occur past this point, so in-range-only writes are safe.
    if (row >= seq_len) {
        return;
    }
    // row_ptr[i] = number of kept edges in rows 0..i-1 (exclusive sum).
    let inclusive_sum = shared_prefix[tid];
    let exclusive_sum = inclusive_sum - count;
    csr_row_ptr[row] = exclusive_sum;
    // Reuse the counter as the running write cursor for phase 3.
    atomicStore(&row_counts[row], exclusive_sum);
    // Last row also records the total nnz as the terminating pointer.
    if (row == seq_len - 1u) {
        csr_row_ptr[seq_len] = inclusive_sum;
    }
}
/// Phase 3: scatter surviving edges into the CSR column/value arrays.
/// Within a row, entry order follows the atomic increments and is
/// therefore not guaranteed to be sorted by column.
@compute @workgroup_size(256)
fn populate_csr_data(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let idx = global_id.x;
    if (idx >= params.total_edges) {
        return;
    }
    let energy = edge_energies[idx];
    if (energy >= params.coherence_threshold) {
        return;
    }
    let seq_len = params.seq_len;
    // Claim the next free slot in this edge's row, then write the entry.
    let slot = atomicAdd(&row_counts[idx / seq_len], 1u);
    csr_col_idx[slot] = idx % seq_len;
    csr_values[slot] = energy;
}
// =============================================================================
// BATCHED SPARSE MASK
// =============================================================================
/// Batch offsets for multi-batch processing
/// (base index of each batch's slice within sparse_indices)
@group(2) @binding(0) var<storage, read> batch_offsets: array<u32>;
/// Per-batch edge counts
@group(2) @binding(1) var<storage, read_write> batch_edge_counts: array<atomic<u32>>;
/// Generate sparse masks for several batches in one dispatch:
/// workgroup_id.z selects the batch, global_id.x walks that batch's edges.
///
/// FIX: the previous version returned early for threads past the batch's
/// edge count and then executed `workgroupBarrier()`, which WGSL forbids
/// (barriers must be reached in uniform control flow). Inactive threads now
/// stay in the scan with a zero contribution. Also replaced the
/// `select(0u, shared_prefix[tid - 1u], tid > 0u)` exclusive-prefix read —
/// select evaluates both operands, so `tid - 1u` underflowed for thread 0.
@compute @workgroup_size(256)
fn generate_batched_sparse_mask(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let batch_idx = workgroup_id.z;
    let local_idx = global_id.x;
    let tid = local_id.x;
    let seq_len = params.seq_len;
    let edges_per_batch = seq_len * seq_len;
    let threshold = params.coherence_threshold;
    let active = local_idx < edges_per_batch;
    // Classify this thread's edge; inactive threads contribute 0.
    var is_valid: u32 = 0u;
    if (active) {
        let global_idx = batch_idx * edges_per_batch + local_idx;
        is_valid = select(0u, 1u, edge_energies[global_idx] < threshold);
    }
    shared_valid[tid] = is_valid;
    workgroupBarrier();
    // Hillis-Steele inclusive prefix sum over the keep flags.
    shared_prefix[tid] = is_valid;
    workgroupBarrier();
    for (var offset = 1u; offset < WORKGROUP_SIZE; offset <<= 1u) {
        var val: u32 = 0u;
        if (tid >= offset) {
            val = shared_prefix[tid - offset];
        }
        workgroupBarrier();
        shared_prefix[tid] += val;
        workgroupBarrier();
    }
    // Thread 0 reserves this workgroup's slice of the batch's output.
    if (tid == 0u) {
        let total_valid = shared_prefix[WORKGROUP_SIZE - 1u];
        atomicStore(&workgroup_count,
                    atomicAdd(&batch_edge_counts[batch_idx], total_valid));
    }
    workgroupBarrier();
    let batch_offset = batch_offsets[batch_idx];
    let workgroup_offset = atomicLoad(&workgroup_count);
    // Scatter kept edges to their compacted positions.
    if (active && is_valid == 1u) {
        var local_pos: u32 = 0u;
        if (tid > 0u) {
            local_pos = shared_prefix[tid - 1u];
        }
        let global_pos = batch_offset + workgroup_offset + local_pos;
        if (global_pos < params.max_edges) {
            sparse_indices[global_pos] = EdgeIndex(local_idx / seq_len, local_idx % seq_len);
        }
    }
}
// =============================================================================
// DYNAMIC THRESHOLD ADJUSTMENT
// =============================================================================
/// Statistics for adaptive threshold
@group(3) @binding(0) var<storage, read_write> mask_stats: array<f32>;
/// Compute mask statistics for adaptive thresholding
///
/// NOTE(review): this kernel is unfinished — the thread-0 branch at the end
/// is empty, so mask_stats is never written and the dispatch has no effect.
/// WGSL has no f32 atomics, so completing it requires either u32-encoded
/// atomic accumulation or a second reduction pass over per-workgroup
/// partial sums.
@compute @workgroup_size(256)
fn compute_mask_statistics(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let idx = global_id.x;
let tid = local_id.x;
let total_edges = params.total_edges;
let threshold = params.coherence_threshold;
// Count valid and total, compute sparsity ratio
var valid_count: u32 = 0u;
if (idx < total_edges) {
let energy = edge_energies[idx];
valid_count = select(0u, 1u, energy < threshold);
}
shared_prefix[tid] = valid_count;
workgroupBarrier();
// Reduce to get total valid
for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride >>= 1u) {
if (tid < stride) {
shared_prefix[tid] += shared_prefix[tid + stride];
}
workgroupBarrier();
}
// Thread 0 updates global statistics
if (tid == 0u) {
// Atomic add to global counter
// mask_stats[0] = total valid edges
// mask_stats[1] = sparsity ratio (computed after all workgroups)
}
}
// =============================================================================
// CAUSAL MASK COMBINATION
// =============================================================================
/// Combine energy-based sparse mask with causal mask
/// Sets bit (row, col) iff energy < threshold AND col <= row. Launched on
/// a 2D grid (16x16 workgroups) over the seq_len x seq_len matrix; bits
/// are only set, so the mask must be zeroed before dispatch.
@compute @workgroup_size(16, 16)
fn combine_with_causal_mask(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let row = global_id.y;
let col = global_id.x;
let seq_len = params.seq_len;
let threshold = params.coherence_threshold;
if (row >= seq_len || col >= seq_len) {
return;
}
let idx = row * seq_len + col;
let energy = edge_energies[idx];
// Valid if: (1) below energy threshold AND (2) satisfies causal constraint
let energy_valid = energy < threshold;
let causal_valid = col <= row; // Can only attend to past
let is_valid = energy_valid && causal_valid;
// Write to dense mask
// (atomicOr requires the dense_mask binding to be array<atomic<u32>>)
let word_idx = idx / 32u;
let bit_idx = idx % 32u;
if (is_valid) {
atomicOr(&dense_mask[word_idx], 1u << bit_idx);
}
}

View File

@@ -0,0 +1,253 @@
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Token Routing
// =============================================================================
//
// Parallel lane assignment for tokens based on coherence energy thresholds.
// Routes tokens to different processing lanes (experts) based on their
// local coherence energy, enabling adaptive computation.
//
// Lane Semantics:
// - Lane 0: Coherent (energy < tau_0) - Fast path, minimal processing
// - Lane 1: Semi-coherent (tau_0 <= energy < tau_1) - Normal processing
// - Lane 2: Incoherent (tau_1 <= energy < tau_2) - Enhanced processing
// - Lane 3: Critical (energy >= tau_2) - Special handling required
// =============================================================================
// TYPE DEFINITIONS
// =============================================================================
struct RoutingParams {
num_tokens: u32,
num_nodes: u32,
threshold_0: f32,
threshold_1: f32,
threshold_2: f32,
high_energy_threshold: f32,
_padding0: u32,
_padding1: u32,
}
struct Token {
token_id: u32,
node_idx: u32,
action_type: u32,
priority: f32,
}
struct RoutingDecision {
token_id: u32,
assigned_lane: u32,
local_energy: f32,
confidence: f32,
escalation_reason: u32,
num_high_energy_edges: u32,
max_edge_energy: f32,
_padding: u32,
}
struct LaneStats {
lane_counts: vec4<u32>,
total_energy_per_lane: vec4<f32>,
_padding: array<u32, 8>,
}
const WORKGROUP_SIZE: u32 = 256u;
const NUM_LANES: u32 = 4u;
// =============================================================================
// BUFFER BINDINGS
// =============================================================================
// Layout matches Rust kernel bind group:
// binding 0: params (uniform)
// binding 1: tokens (storage, read)
// binding 2: local_energies (storage, read)
// binding 3: edge_energies (storage, read)
// binding 4: node_edge_counts (storage, read)
// binding 5: node_edge_offsets (storage, read)
// binding 6: node_edges (storage, read)
// binding 7: routing_decisions (storage, read_write)
// binding 8: lane_stats (storage, read_write)
/// Routing parameters
@group(0) @binding(0) var<uniform> params: RoutingParams;
/// Input tokens
@group(0) @binding(1) var<storage, read> tokens: array<Token>;
/// Pre-computed local energies per node
@group(0) @binding(2) var<storage, read> local_energies: array<f32>;
/// All edge energies
@group(0) @binding(3) var<storage, read> edge_energies: array<f32>;
/// Number of edges per node (CSR format)
@group(0) @binding(4) var<storage, read> node_edge_counts: array<u32>;
/// Edge start offsets per node (CSR format)
@group(0) @binding(5) var<storage, read> node_edge_offsets: array<u32>;
/// Edge indices per node (CSR format)
@group(0) @binding(6) var<storage, read> node_edges: array<u32>;
/// Output routing decisions
@group(0) @binding(7) var<storage, read_write> routing_decisions: array<RoutingDecision>;
/// Output lane statistics
@group(0) @binding(8) var<storage, read_write> lane_stats: LaneStats;
// =============================================================================
// SHARED MEMORY
// =============================================================================
/// Lane counts for workgroup-level reduction
var<workgroup> shared_lane_counts: array<atomic<u32>, 4>;
/// Lane energy sums for workgroup-level reduction
/// NOTE(review): plain f32 (WGSL lacks f32 atomics) while the counts above
/// are atomic — confirm the accumulation scheme in route_tokens is race-free.
var<workgroup> shared_lane_energies: array<f32, 4>;
// =============================================================================
// HELPER FUNCTIONS
// =============================================================================
/// Branchless lane computation: counts how many of the three thresholds the
/// energy meets or exceeds, yielding a lane index in [0, 3].
fn compute_lane_branchless(energy: f32, t0: f32, t1: f32, t2: f32) -> u32 {
    // bool -> u32 conversion yields exactly 1u for true and 0u for false,
    // so the sum is the number of thresholds crossed.
    return u32(energy >= t0) + u32(energy >= t1) + u32(energy >= t2);
}
/// Compute routing confidence based on how close energy sits to the lane's
/// threshold boundaries. The confidence is the distance to the nearest
/// boundary of the assigned lane, scaled by 10 and clamped to [0, 1];
/// larger values mean the assignment is further from flipping lanes.
fn compute_confidence(energy: f32, lane: u32, t0: f32, t1: f32, t2: f32) -> f32 {
    var dist_to_threshold: f32;
    if (lane == 0u) {
        // Lane 0 is bounded above by t0 only.
        dist_to_threshold = t0 - energy;
    } else if (lane == 1u) {
        // Interior lane: distance to whichever of [t0, t1] is nearer.
        dist_to_threshold = min(energy - t0, t1 - energy);
    } else if (lane == 2u) {
        // Interior lane: distance to whichever of [t1, t2] is nearer.
        dist_to_threshold = min(energy - t1, t2 - energy);
    } else {
        // Lane 3 (and any other value) is bounded below by t2 only.
        dist_to_threshold = energy - t2;
    }
    return clamp(dist_to_threshold * 10.0, 0.0, 1.0);
}
// =============================================================================
// MAIN ROUTING KERNEL
// =============================================================================
/// Route tokens to processing lanes based on local coherence energy.
///
/// One invocation per token:
///   1. Look up the pre-computed local energy for the token's node.
///   2. Assign a lane by threshold comparison and score routing confidence.
///   3. Scan the node's edges (CSR layout) for escalation conditions.
///   4. Emit a RoutingDecision and accumulate workgroup-local lane counts.
///
/// FIX: the original early-returned out-of-range invocations *between* the
/// two workgroupBarrier() calls, which places the second barrier in
/// non-uniform control flow — invalid WGSL, rejected by uniformity analysis
/// (naga/tint). The per-token work is now wrapped in a branch so that every
/// invocation in the workgroup reaches both barriers.
@compute @workgroup_size(256)
fn route_tokens(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let token_idx = global_id.x;
    let local_idx = local_id.x;
    let num_tokens = params.num_tokens;

    // Initialize shared counters (first thread in the workgroup only).
    if (local_idx == 0u) {
        for (var i = 0u; i < NUM_LANES; i++) {
            atomicStore(&shared_lane_counts[i], 0u);
            shared_lane_energies[i] = 0.0;
        }
    }
    workgroupBarrier();

    // Guard instead of early return: every invocation — including those past
    // the end of the token array — must reach the barrier below.
    if (token_idx < num_tokens) {
        let token = tokens[token_idx];
        let node_idx = token.node_idx;

        // Pre-computed local coherence energy for this node.
        let local_energy = local_energies[node_idx];

        // Lane assignment from the three routing thresholds.
        let lane = compute_lane_branchless(
            local_energy,
            params.threshold_0,
            params.threshold_1,
            params.threshold_2
        );

        // Confidence: distance of the energy from the lane boundaries.
        let confidence = compute_confidence(
            local_energy,
            lane,
            params.threshold_0,
            params.threshold_1,
            params.threshold_2
        );

        // Scan this node's edges (CSR: per-node count + offset into node_edges).
        let edge_count = node_edge_counts[node_idx];
        let edge_offset = node_edge_offsets[node_idx];
        var num_high_energy_edges: u32 = 0u;
        var max_edge_energy: f32 = 0.0;
        var escalation_reason: u32 = 0u;
        for (var i = 0u; i < edge_count; i++) {
            let edge_idx = node_edges[edge_offset + i];
            let edge_energy = edge_energies[edge_idx];
            if (edge_energy > params.high_energy_threshold) {
                num_high_energy_edges += 1u;
            }
            max_edge_energy = max(max_edge_energy, edge_energy);
        }

        // Escalation policy: several hot edges, or one critically hot edge.
        if (num_high_energy_edges > 2u) {
            escalation_reason = 1u; // Multiple high-energy edges
        } else if (max_edge_energy > params.threshold_2) {
            escalation_reason = 2u; // Single very high energy edge
        }

        // Write the routing decision for this token.
        var decision: RoutingDecision;
        decision.token_id = token.token_id;
        decision.assigned_lane = lane;
        decision.local_energy = local_energy;
        decision.confidence = confidence;
        decision.escalation_reason = escalation_reason;
        decision.num_high_energy_edges = num_high_energy_edges;
        decision.max_edge_energy = max_edge_energy;
        decision._padding = 0u;
        routing_decisions[token_idx] = decision;

        // Accumulate workgroup-local lane statistics.
        atomicAdd(&shared_lane_counts[lane], 1u);
        // Note: No atomic f32 add in WGSL, would need separate reduction pass
    }
    workgroupBarrier();

    // First thread of workgroup 0 writes its stats to the global buffer.
    // (In production, would do proper atomic accumulation across workgroups.)
    if (local_idx == 0u && workgroup_id.x == 0u) {
        lane_stats.lane_counts = vec4<u32>(
            atomicLoad(&shared_lane_counts[0]),
            atomicLoad(&shared_lane_counts[1]),
            atomicLoad(&shared_lane_counts[2]),
            atomicLoad(&shared_lane_counts[3])
        );
    }
}

// =============================================================================
// (second shader file begins here: shared type definitions — the "View File"
//  / "@@ -0,0 +1,234 @@" diff-viewer residue has been replaced by this marker)
// =============================================================================
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Shared Types
// =============================================================================
//
// This file contains shared struct definitions and constants used across
// all compute shaders in the Prime-Radiant coherence engine.
//
// Memory Layout:
// - All structs are aligned to 16 bytes for optimal GPU memory access
// - vec4<f32> is used where possible for coalesced memory operations
// - Padding fields ensure proper alignment
// =============================================================================
// COMPUTE PARAMETERS
// =============================================================================
/// Parameters for residual computation (4 x u32 = 16 bytes).
struct ComputeParams {
    /// Total number of edges to process
    edge_count: u32,
    /// Dimension of state vectors
    state_dim: u32,
    /// Restriction map type: 0=identity, 1=diagonal, 2=dense, 3=projection, 4=sparse
    /// (mirrors the RESTRICTION_* constants in this file)
    restriction_type: u32,
    /// Padding for 16-byte alignment
    padding: u32,
}
/// Parameters for parallel reduction operations (4 x u32 = 16 bytes).
struct ReductionParams {
    /// Number of elements to reduce
    element_count: u32,
    /// Stride between elements (for strided access patterns)
    stride: u32,
    /// Nonzero when this is the final reduction pass (u32 used as bool)
    is_final_pass: u32,
    /// Output offset for multi-pass reductions
    output_offset: u32,
}
/// Parameters for attention computation (8 x 4 bytes = 32 bytes).
struct AttentionParams {
    /// Batch size (number of independent attention operations)
    batch_size: u32,
    /// Sequence length (number of tokens/nodes)
    seq_len: u32,
    /// Dimension per attention head
    head_dim: u32,
    /// Inverse temperature parameter: A_ij = softmax(-beta * E_ij)
    beta: f32,
    /// Number of attention heads (for multi-head attention)
    num_heads: u32,
    /// Nonzero to apply causal masking (u32 used as bool)
    use_causal_mask: u32,
    /// Energy threshold for sparse attention (skip if E > threshold)
    energy_threshold: f32,
    /// Padding for 16-byte alignment
    padding: u32,
}
/// Parameters for token routing (4 x u32 = 16 bytes).
///
/// NOTE(review): the routing kernel in the token-routing shader reads
/// params.num_tokens, params.threshold_0..2 and params.high_energy_threshold,
/// which do not match these fields — confirm which RoutingParams definition
/// that shader is actually compiled against.
struct RoutingParams {
    /// Number of tokens to route
    token_count: u32,
    /// Number of lanes/experts
    num_lanes: u32,
    /// Nonzero to enable load balancing (u32 used as bool)
    use_load_balance: u32,
    /// Top-k selection for MoE
    top_k: u32,
}
/// Parameters for sparse mask generation (4 x 4 bytes = 16 bytes).
struct SparseMaskParams {
    /// Total number of potential edges
    total_edges: u32,
    /// Energy threshold for coherence (keep edges below this)
    coherence_threshold: f32,
    /// Maximum edges to keep (for memory bounds)
    max_edges: u32,
    /// Output format: 0=indices, 1=dense mask
    output_format: u32,
}
// =============================================================================
// EDGE AND NODE DATA STRUCTURES
// =============================================================================
/// Edge descriptor for graph connectivity (16-byte aligned: 3 x u32 + f32).
struct EdgeDescriptor {
    /// Index of source node
    source_idx: u32,
    /// Index of target node
    target_idx: u32,
    /// Offset into restriction data for this edge
    restriction_offset: u32,
    /// Weight for this edge
    weight: f32,
}
/// Node state with metadata (16-byte aligned: 4 x u32).
/// State vectors live in a separate flat buffer; this struct indexes into it.
struct NodeState {
    /// Offset into state buffer where this node's state begins
    state_offset: u32,
    /// Dimension of this node's state
    state_dim: u32,
    /// Scope ID for hierarchical energy aggregation
    scope_id: u32,
    /// Bit flags (bit 0: is_boundary, bit 1: is_fixed, etc.)
    flags: u32,
}
/// Per-edge energy result (16-byte aligned: 4 x f32).
struct EdgeEnergy {
    /// Weighted energy: w_e * |r_e|^2
    energy: f32,
    /// Raw residual norm squared: |r_e|^2
    residual_norm_sq: f32,
    /// Edge weight that was applied (so energy can be un-weighted)
    weight: f32,
    /// Padding for alignment
    padding: f32,
}
// =============================================================================
// ATTENTION STRUCTURES
// =============================================================================
/// Attention score for a single edge (16-byte aligned: 2 x u32 + 2 x f32).
struct AttentionScore {
    /// Source node index
    source: u32,
    /// Target node index
    target: u32,
    /// Attention weight (after softmax normalization)
    weight: f32,
    /// Raw score (before softmax), kept for debugging/re-normalization
    raw_score: f32,
}
/// Lane assignment result for token routing (16-byte aligned: 2 x u32 + 2 x f32).
struct LaneAssignment {
    /// Token index
    token_idx: u32,
    /// Assigned lane (0-3 typically; see DEFAULT_LANE_THRESHOLDS)
    lane: u32,
    /// Confidence score for this assignment
    confidence: f32,
    /// Energy value that determined routing
    energy: f32,
}
// =============================================================================
// CONSTANTS
// =============================================================================
/// Workgroup size for 1D dispatches
const WORKGROUP_SIZE_1D: u32 = 256u;
/// Workgroup dimensions for 2D dispatches (attention)
const WORKGROUP_SIZE_2D_X: u32 = 16u;
const WORKGROUP_SIZE_2D_Y: u32 = 16u;
/// Maximum supported state dimension (for stack allocation)
const MAX_STATE_DIM: u32 = 512u;
/// Epsilon for numerical stability (see safe_div)
const EPSILON: f32 = 1e-8;
/// Most negative finite f32 (not actual -inf), used to seed softmax
/// running-max initialization
const NEG_INF: f32 = -3.402823e+38;
/// Restriction map type constants (match ComputeParams.restriction_type)
const RESTRICTION_IDENTITY: u32 = 0u;
const RESTRICTION_DIAGONAL: u32 = 1u;
const RESTRICTION_DENSE: u32 = 2u;
const RESTRICTION_PROJECTION: u32 = 3u;
const RESTRICTION_SPARSE: u32 = 4u;
/// Lane thresholds for token routing (default values)
/// Lane 0: energy < 0.1 (coherent, fast path)
/// Lane 1: 0.1 <= energy < 0.5 (semi-coherent, normal path)
/// Lane 2: 0.5 <= energy < 1.0 (incoherent, slow path)
/// Lane 3: energy >= 1.0 (critical, special handling)
const DEFAULT_LANE_THRESHOLDS: vec4<f32> = vec4<f32>(0.1, 0.5, 1.0, 10.0);
// =============================================================================
// UTILITY FUNCTIONS
// =============================================================================
/// Squared L2 norm |v|^2 of a vec4, computed as the dot product of the
/// vector with itself.
fn norm_sq_vec4(v: vec4<f32>) -> f32 {
    let self_dot = dot(v, v);
    return self_dot;
}
/// Division guarded against tiny or zero denominators: the denominator is
/// floored at EPSILON before dividing.
fn safe_div(a: f32, b: f32) -> f32 {
    let denom = max(b, EPSILON);
    return a / denom;
}
/// Branchless step function: 1.0 when value >= threshold, else 0.0.
fn step_branchless(threshold: f32, value: f32) -> f32 {
    // bool -> f32 conversion yields exactly 1.0 / 0.0.
    return f32(value >= threshold);
}
/// Compute the lane index from energy by counting how many of the first
/// three thresholds the energy meets or exceeds (branchless).
fn compute_lane(energy: f32, thresholds: vec4<f32>) -> u32 {
    // step(edge, x) is 1.0 when x >= edge; three components give three
    // {0,1} indicators whose exact-integer sum is the lane index.
    let crossed = step(thresholds.xyz, vec3<f32>(energy));
    return u32(crossed.x + crossed.y + crossed.z);
}
/// Online softmax helper: fold one new value into a running (max, sum) pair.
/// The previous sum is rescaled by exp(old_max - new_max) so all terms stay
/// expressed relative to the current maximum (numerically stable softmax).
/// Returns vec2(new_max, new_sum).
fn online_softmax_update(
    old_max: f32,
    old_sum: f32,
    new_val: f32
) -> vec2<f32> {
    let running_max = max(old_max, new_val);
    let rescale = exp(old_max - running_max);
    let running_sum = old_sum * rescale + exp(new_val - running_max);
    return vec2<f32>(running_max, running_sum);
}
/// Approximate exponential for softmax hot paths.
/// Currently delegates to the native exp; a polynomial approximation can be
/// substituted here if profiling shows exp dominating and reduced precision
/// is acceptable.
fn fast_exp(x: f32) -> f32 {
    let value = exp(x);
    return value;
}
/// Clamp val into [min_val, max_val].
/// Deliberately keeps the original max(min_val, min(max_val, ...)) ordering
/// so behavior is byte-identical even for a degenerate range
/// (min_val > max_val resolves to min_val, as before).
fn clamp_f32(val: f32, min_val: f32, max_val: f32) -> f32 {
    let bounded_above = min(max_val, val);
    return max(min_val, bounded_above);
}