Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
705
vendor/ruvector/crates/prime-radiant/src/gpu/buffer.rs
vendored
Normal file
705
vendor/ruvector/crates/prime-radiant/src/gpu/buffer.rs
vendored
Normal file
@@ -0,0 +1,705 @@
|
||||
//! GPU Buffer Management
|
||||
//!
|
||||
//! Provides efficient GPU buffer allocation, management, and data transfer
|
||||
//! for the coherence engine. Implements a buffer pool for reuse and
|
||||
//! minimizes CPU-GPU synchronization overhead.
|
||||
|
||||
use super::error::{GpuError, GpuResult};
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use wgpu::{Buffer, BufferDescriptor, BufferUsages, Device, Queue};
|
||||
|
||||
/// Buffer usage flags for coherence computation.
///
/// Each variant selects a fixed set of `wgpu::BufferUsages` flags via
/// `GpuBuffer::to_wgpu_usage`: every storage variant maps to
/// `STORAGE | COPY_SRC | COPY_DST`, `Uniforms` maps to `UNIFORM | COPY_DST`,
/// and `Staging` maps to `MAP_READ | COPY_DST` (readback only).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BufferUsage {
    /// Storage buffer for node states
    NodeStates,
    /// Storage buffer for edge data
    EdgeData,
    /// Storage buffer for restriction maps
    RestrictionMaps,
    /// Storage buffer for residuals
    Residuals,
    /// Storage buffer for energy values
    Energies,
    /// Storage buffer for attention weights
    AttentionWeights,
    /// Storage buffer for routing decisions
    RoutingDecisions,
    /// Uniform buffer for shader parameters
    Uniforms,
    /// Staging buffer for CPU readback
    Staging,
}
|
||||
|
||||
/// GPU-side node state representation.
///
/// `#[repr(C)]` + `Pod`/`Zeroable` make this safe to upload byte-for-byte
/// via `bytemuck::cast_slice`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuNodeState {
    /// Flattened state vector (padded to MAX_STATE_DIM)
    pub state: [f32; 128], // Will be dynamically sized based on actual dim
    /// Actual dimension of the state vector; only `state[..dim]` is meaningful
    pub dim: u32,
    /// Node index
    pub index: u32,
    /// Padding for alignment (keeps total size a multiple of 16 bytes)
    pub _padding: [u32; 2],
}
|
||||
|
||||
/// GPU-side edge representation.
///
/// 8 x 4-byte fields = 32 bytes total (checked by `test_gpu_edge_alignment`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuEdge {
    /// Source node index
    pub source_idx: u32,
    /// Target node index
    pub target_idx: u32,
    /// Edge weight
    pub weight: f32,
    /// Restriction map index for source
    pub rho_source_idx: u32,
    /// Restriction map index for target
    pub rho_target_idx: u32,
    /// Output dimension of restriction maps
    pub comparison_dim: u32,
    /// Padding for alignment
    pub _padding: [u32; 2],
}
|
||||
|
||||
/// GPU-side restriction map (dense matrix stored row-major).
///
/// The matrix elements themselves live in a shared data buffer; this header
/// only records where they are (`data_offset` / `data_len`).
/// 8 x 4-byte fields = 32 bytes (checked by `test_gpu_restriction_map_alignment`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuRestrictionMap {
    /// Matrix type: 0=identity, 1=diagonal, 2=projection, 3=dense
    pub map_type: u32,
    /// Input dimension
    pub input_dim: u32,
    /// Output dimension
    pub output_dim: u32,
    /// Offset into the shared data buffer
    pub data_offset: u32,
    /// Number of elements in data
    pub data_len: u32,
    /// Padding for alignment
    pub _padding: [u32; 3],
}
|
||||
|
||||
/// GPU-side shader parameters.
///
/// Uploaded as a uniform buffer; layout is 8 x 4-byte fields = 32 bytes
/// (checked by `test_gpu_params_alignment`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GpuParams {
    /// Number of edges
    pub num_edges: u32,
    /// Number of nodes
    pub num_nodes: u32,
    /// State dimension
    pub state_dim: u32,
    /// Beta parameter for attention
    pub beta: f32,
    /// Lane 0 threshold (reflex)
    pub threshold_lane0: f32,
    /// Lane 1 threshold (retrieval)
    pub threshold_lane1: f32,
    /// Lane 2 threshold (heavy)
    pub threshold_lane2: f32,
    /// Flag to control residual storage (0 = skip, 1 = store)
    /// When computing energy only, skip storage for better performance
    pub store_residuals: u32,
}
|
||||
|
||||
/// Wrapper around a wgpu Buffer with metadata.
///
/// Records the allocation size, the coarse usage category, and a debug
/// label so pools (`GpuBufferManager`, `GpuBufferPool`) can key and reuse
/// allocations.
pub struct GpuBuffer {
    /// The underlying wgpu buffer
    pub buffer: Buffer,
    /// Size in bytes (the allocated size, which may exceed the payload)
    pub size: usize,
    /// Usage flags (coarse category, see `BufferUsage`)
    pub usage: BufferUsage,
    /// Label for debugging
    pub label: String,
}
|
||||
|
||||
impl GpuBuffer {
|
||||
/// Create a new GPU buffer
|
||||
pub fn new(
|
||||
device: &Device,
|
||||
size: usize,
|
||||
usage: BufferUsage,
|
||||
label: impl Into<String>,
|
||||
) -> GpuResult<Self> {
|
||||
let label = label.into();
|
||||
let wgpu_usage = Self::to_wgpu_usage(usage);
|
||||
|
||||
let buffer = device.create_buffer(&BufferDescriptor {
|
||||
label: Some(&label),
|
||||
size: size as u64,
|
||||
usage: wgpu_usage,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
size,
|
||||
usage,
|
||||
label,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new GPU buffer with initial data
|
||||
pub fn new_with_data<T: Pod>(
|
||||
device: &Device,
|
||||
queue: &Queue,
|
||||
data: &[T],
|
||||
usage: BufferUsage,
|
||||
label: impl Into<String>,
|
||||
) -> GpuResult<Self> {
|
||||
let label = label.into();
|
||||
let bytes = bytemuck::cast_slice(data);
|
||||
let size = bytes.len();
|
||||
let wgpu_usage = Self::to_wgpu_usage(usage);
|
||||
|
||||
let buffer = device.create_buffer(&BufferDescriptor {
|
||||
label: Some(&label),
|
||||
size: size as u64,
|
||||
usage: wgpu_usage,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
queue.write_buffer(&buffer, 0, bytes);
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
size,
|
||||
usage,
|
||||
label,
|
||||
})
|
||||
}
|
||||
|
||||
/// Write data to the buffer
|
||||
pub fn write<T: Pod>(&self, queue: &Queue, data: &[T]) -> GpuResult<()> {
|
||||
let bytes = bytemuck::cast_slice(data);
|
||||
if bytes.len() > self.size {
|
||||
return Err(GpuError::BufferSizeMismatch {
|
||||
expected: self.size,
|
||||
actual: bytes.len(),
|
||||
});
|
||||
}
|
||||
queue.write_buffer(&self.buffer, 0, bytes);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert our usage to wgpu usage flags
|
||||
fn to_wgpu_usage(usage: BufferUsage) -> BufferUsages {
|
||||
match usage {
|
||||
BufferUsage::NodeStates
|
||||
| BufferUsage::EdgeData
|
||||
| BufferUsage::RestrictionMaps
|
||||
| BufferUsage::Residuals
|
||||
| BufferUsage::Energies
|
||||
| BufferUsage::AttentionWeights
|
||||
| BufferUsage::RoutingDecisions => {
|
||||
BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST
|
||||
}
|
||||
BufferUsage::Uniforms => BufferUsages::UNIFORM | BufferUsages::COPY_DST,
|
||||
BufferUsage::Staging => BufferUsages::MAP_READ | BufferUsages::COPY_DST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer manager for efficient allocation and reuse.
///
/// Freed buffers are parked in `pool`, keyed by usage category and a
/// power-of-two size bucket, so later allocations of a similar size can
/// skip a fresh GPU allocation. Buffers handed out to callers are tracked
/// in `active` by their string label.
pub struct GpuBufferManager {
    device: Arc<Device>,
    queue: Arc<Queue>,
    /// Buffer pool keyed by (usage, size_bucket)
    pool: HashMap<(BufferUsage, usize), Vec<GpuBuffer>>,
    /// Active buffers currently in use, keyed by label
    active: HashMap<String, GpuBuffer>,
}
|
||||
|
||||
impl GpuBufferManager {
|
||||
/// Create a new buffer manager
|
||||
pub fn new(device: Arc<Device>, queue: Arc<Queue>) -> Self {
|
||||
Self {
|
||||
device,
|
||||
queue,
|
||||
pool: HashMap::new(),
|
||||
active: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate or reuse a buffer
|
||||
pub fn allocate(
|
||||
&mut self,
|
||||
size: usize,
|
||||
usage: BufferUsage,
|
||||
label: impl Into<String>,
|
||||
) -> GpuResult<&GpuBuffer> {
|
||||
let label = label.into();
|
||||
let bucket = Self::size_bucket(size);
|
||||
|
||||
// Try to reuse from pool
|
||||
if let Some(buffers) = self.pool.get_mut(&(usage, bucket)) {
|
||||
if let Some(buffer) = buffers.pop() {
|
||||
self.active.insert(label.clone(), buffer);
|
||||
return Ok(self.active.get(&label).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate new buffer
|
||||
let buffer = GpuBuffer::new(&self.device, bucket, usage, &label)?;
|
||||
self.active.insert(label.clone(), buffer);
|
||||
Ok(self.active.get(&label).unwrap())
|
||||
}
|
||||
|
||||
/// Allocate or reuse a buffer with initial data
|
||||
pub fn allocate_with_data<T: Pod>(
|
||||
&mut self,
|
||||
data: &[T],
|
||||
usage: BufferUsage,
|
||||
label: impl Into<String>,
|
||||
) -> GpuResult<&GpuBuffer> {
|
||||
let label = label.into();
|
||||
let size = std::mem::size_of_val(data);
|
||||
let bucket = Self::size_bucket(size);
|
||||
|
||||
// Try to reuse from pool
|
||||
if let Some(buffers) = self.pool.get_mut(&(usage, bucket)) {
|
||||
if let Some(buffer) = buffers.pop() {
|
||||
buffer.write(&self.queue, data)?;
|
||||
self.active.insert(label.clone(), buffer);
|
||||
return Ok(self.active.get(&label).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate new buffer with data
|
||||
let buffer = GpuBuffer::new_with_data(&self.device, &self.queue, data, usage, &label)?;
|
||||
self.active.insert(label.clone(), buffer);
|
||||
Ok(self.active.get(&label).unwrap())
|
||||
}
|
||||
|
||||
/// Get an active buffer by label
|
||||
pub fn get(&self, label: &str) -> Option<&GpuBuffer> {
|
||||
self.active.get(label)
|
||||
}
|
||||
|
||||
/// Release a buffer back to the pool for reuse
|
||||
pub fn release(&mut self, label: &str) {
|
||||
if let Some(buffer) = self.active.remove(label) {
|
||||
let bucket = Self::size_bucket(buffer.size);
|
||||
self.pool
|
||||
.entry((buffer.usage, bucket))
|
||||
.or_default()
|
||||
.push(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
/// Release all active buffers back to the pool
|
||||
pub fn release_all(&mut self) {
|
||||
let labels: Vec<_> = self.active.keys().cloned().collect();
|
||||
for label in labels {
|
||||
self.release(&label);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear all buffers (both pool and active)
|
||||
pub fn clear(&mut self) {
|
||||
self.active.clear();
|
||||
self.pool.clear();
|
||||
}
|
||||
|
||||
/// Round size up to nearest power of 2 for efficient reuse
|
||||
fn size_bucket(size: usize) -> usize {
|
||||
const MIN_BUCKET: usize = 256;
|
||||
if size <= MIN_BUCKET {
|
||||
MIN_BUCKET
|
||||
} else {
|
||||
size.next_power_of_two()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the underlying device
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Get the underlying queue
|
||||
pub fn queue(&self) -> &Queue {
|
||||
&self.queue
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BUFFER USAGE FLAGS (for pipeline.rs compatibility)
|
||||
// ============================================================================
|
||||
|
||||
/// Buffer usage flags for flexible configuration.
///
/// A fine-grained alternative to the `BufferUsage` enum; each boolean maps
/// directly onto one `wgpu::BufferUsages` bit in `to_wgpu` (storage_read
/// and storage_write both fold into the single STORAGE bit).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferUsageFlags {
    /// Can be read from GPU (STORAGE)
    pub storage_read: bool,
    /// Can be written to by GPU (STORAGE)
    pub storage_write: bool,
    /// Can be used as uniform buffer
    pub uniform: bool,
    /// Can be mapped for CPU read
    pub map_read: bool,
    /// Can be mapped for CPU write
    pub map_write: bool,
    /// Can be used as copy source
    pub copy_src: bool,
    /// Can be used as copy destination
    pub copy_dst: bool,
    /// Can be used for indirect dispatch
    pub indirect: bool,
}
|
||||
|
||||
impl BufferUsageFlags {
|
||||
/// Storage buffer (read-only)
|
||||
pub const fn storage_readonly() -> Self {
|
||||
Self {
|
||||
storage_read: true,
|
||||
storage_write: false,
|
||||
uniform: false,
|
||||
map_read: false,
|
||||
map_write: false,
|
||||
copy_src: true,
|
||||
copy_dst: true,
|
||||
indirect: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Storage buffer (read-write)
|
||||
pub const fn storage_readwrite() -> Self {
|
||||
Self {
|
||||
storage_read: true,
|
||||
storage_write: true,
|
||||
uniform: false,
|
||||
map_read: false,
|
||||
map_write: false,
|
||||
copy_src: true,
|
||||
copy_dst: true,
|
||||
indirect: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniform buffer
|
||||
pub const fn uniform() -> Self {
|
||||
Self {
|
||||
storage_read: false,
|
||||
storage_write: false,
|
||||
uniform: true,
|
||||
map_read: false,
|
||||
map_write: false,
|
||||
copy_src: false,
|
||||
copy_dst: true,
|
||||
indirect: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Staging buffer for read-back
|
||||
pub const fn staging_read() -> Self {
|
||||
Self {
|
||||
storage_read: false,
|
||||
storage_write: false,
|
||||
uniform: false,
|
||||
map_read: true,
|
||||
map_write: false,
|
||||
copy_src: false,
|
||||
copy_dst: true,
|
||||
indirect: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Staging buffer for upload
|
||||
pub const fn staging_write() -> Self {
|
||||
Self {
|
||||
storage_read: false,
|
||||
storage_write: false,
|
||||
uniform: false,
|
||||
map_read: false,
|
||||
map_write: true,
|
||||
copy_src: true,
|
||||
copy_dst: false,
|
||||
indirect: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Indirect dispatch buffer
|
||||
pub const fn indirect() -> Self {
|
||||
Self {
|
||||
storage_read: true,
|
||||
storage_write: true,
|
||||
uniform: false,
|
||||
map_read: false,
|
||||
map_write: false,
|
||||
copy_src: true,
|
||||
copy_dst: true,
|
||||
indirect: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to wgpu buffer usages
|
||||
pub fn to_wgpu(&self) -> BufferUsages {
|
||||
let mut usages = BufferUsages::empty();
|
||||
|
||||
if self.storage_read || self.storage_write {
|
||||
usages |= BufferUsages::STORAGE;
|
||||
}
|
||||
if self.uniform {
|
||||
usages |= BufferUsages::UNIFORM;
|
||||
}
|
||||
if self.map_read {
|
||||
usages |= BufferUsages::MAP_READ;
|
||||
}
|
||||
if self.map_write {
|
||||
usages |= BufferUsages::MAP_WRITE;
|
||||
}
|
||||
if self.copy_src {
|
||||
usages |= BufferUsages::COPY_SRC;
|
||||
}
|
||||
if self.copy_dst {
|
||||
usages |= BufferUsages::COPY_DST;
|
||||
}
|
||||
if self.indirect {
|
||||
usages |= BufferUsages::INDIRECT;
|
||||
}
|
||||
|
||||
usages
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BUFFER KEY AND POOL (for dispatch.rs compatibility)
|
||||
// ============================================================================
|
||||
|
||||
/// Key for buffer pool lookups.
///
/// Two buffers are interchangeable in the pool exactly when their size and
/// usage flags match, so the key is the pair of both.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferKey {
    /// Buffer size in bytes
    pub size: u64,
    /// Buffer usage flags
    pub usage: BufferUsageFlags,
}
|
||||
|
||||
impl BufferKey {
|
||||
/// Create a new buffer key
|
||||
pub fn new(size: u64, usage: BufferUsageFlags) -> Self {
|
||||
Self { size, usage }
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer pool for reusing GPU allocations, using DashMap for concurrent access.
///
/// Unlike `GpuBufferManager` (label-based, `&mut self`), this pool is
/// shared-reference friendly: `acquire`/`release` take `&self` and the
/// per-key vectors are guarded by the DashMap shards.
pub struct GpuBufferPool {
    device: Arc<Device>,
    // Pooled buffers keyed by exact (size, usage-flags) pairs.
    buffers: dashmap::DashMap<BufferKey, Vec<GpuBuffer>>,
    // Maximum number of buffers retained per key; excess releases are dropped.
    max_pool_size: usize,
}
|
||||
|
||||
impl GpuBufferPool {
|
||||
/// Create a new buffer pool
|
||||
pub fn new(device: Arc<Device>) -> Self {
|
||||
Self::with_capacity(device, super::DEFAULT_POOL_CAPACITY)
|
||||
}
|
||||
|
||||
/// Create a new buffer pool with custom capacity
|
||||
pub fn with_capacity(device: Arc<Device>, max_pool_size: usize) -> Self {
|
||||
Self {
|
||||
device,
|
||||
buffers: dashmap::DashMap::new(),
|
||||
max_pool_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a buffer from the pool or create a new one.
|
||||
pub fn acquire(&self, size: u64, usage: BufferUsageFlags) -> GpuResult<GpuBuffer> {
|
||||
if size > super::MAX_BUFFER_SIZE {
|
||||
return Err(GpuError::BufferTooLarge {
|
||||
size,
|
||||
max: super::MAX_BUFFER_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let key = BufferKey::new(size, usage);
|
||||
|
||||
// Try to get from pool
|
||||
if let Some(mut buffers) = self.buffers.get_mut(&key) {
|
||||
if let Some(buffer) = buffers.pop() {
|
||||
return Ok(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
// Create new buffer
|
||||
let wgpu_buffer = self.device.create_buffer(&BufferDescriptor {
|
||||
label: Some("pooled_buffer"),
|
||||
size,
|
||||
usage: usage.to_wgpu(),
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
Ok(GpuBuffer {
|
||||
buffer: wgpu_buffer,
|
||||
size: size as usize,
|
||||
usage: BufferUsage::Staging, // Default usage type
|
||||
label: "pooled_buffer".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a buffer to the pool for reuse.
|
||||
pub fn release(&self, buffer: GpuBuffer) {
|
||||
let size = buffer.size as u64;
|
||||
let usage = BufferUsageFlags::storage_readwrite(); // Default
|
||||
let key = BufferKey::new(size, usage);
|
||||
|
||||
let mut buffers = self.buffers.entry(key).or_insert_with(Vec::new);
|
||||
if buffers.len() < self.max_pool_size {
|
||||
buffers.push(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear all pooled buffers
|
||||
pub fn clear(&self) {
|
||||
self.buffers.clear();
|
||||
}
|
||||
|
||||
/// Get statistics about the pool
|
||||
pub fn stats(&self) -> PoolStats {
|
||||
let mut total_buffers = 0;
|
||||
let mut total_bytes = 0u64;
|
||||
|
||||
for entry in self.buffers.iter() {
|
||||
total_buffers += entry.value().len();
|
||||
total_bytes += entry.key().size * entry.value().len() as u64;
|
||||
}
|
||||
|
||||
PoolStats {
|
||||
total_buffers,
|
||||
total_bytes,
|
||||
bucket_count: self.buffers.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics about the buffer pool, as reported by `GpuBufferPool::stats`.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Total number of pooled buffers
    pub total_buffers: usize,
    /// Total bytes allocated in pool (sum of key size x buffer count)
    pub total_bytes: u64,
    /// Number of unique buffer configurations (distinct pool keys)
    pub bucket_count: usize,
}
|
||||
|
||||
// ============================================================================
|
||||
// EXTENDED GPUBUFFER METHODS (for pipeline.rs compatibility)
|
||||
// ============================================================================
|
||||
|
||||
impl GpuBuffer {
|
||||
/// Create a binding entry for this buffer.
|
||||
pub fn binding(&self, binding: u32) -> wgpu::BindGroupEntry {
|
||||
wgpu::BindGroupEntry {
|
||||
binding,
|
||||
resource: self.buffer.as_entire_binding(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the underlying wgpu buffer
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
/// Create a new storage buffer with initial data (for dispatch compatibility)
|
||||
pub fn new_storage<T: Pod>(
|
||||
device: &Device,
|
||||
queue: &Queue,
|
||||
data: &[T],
|
||||
read_write: bool,
|
||||
) -> GpuResult<Self> {
|
||||
let usage = if read_write {
|
||||
BufferUsage::Residuals
|
||||
} else {
|
||||
BufferUsage::NodeStates
|
||||
};
|
||||
Self::new_with_data(device, queue, data, usage, "storage_buffer")
|
||||
}
|
||||
|
||||
/// Create a new uninitialized storage buffer
|
||||
pub fn new_storage_uninit<T: Pod>(
|
||||
device: &Device,
|
||||
count: usize,
|
||||
read_write: bool,
|
||||
) -> GpuResult<Self> {
|
||||
let size = count * std::mem::size_of::<T>();
|
||||
let usage = if read_write {
|
||||
BufferUsage::Residuals
|
||||
} else {
|
||||
BufferUsage::NodeStates
|
||||
};
|
||||
Self::new(device, size, usage, "storage_buffer_uninit")
|
||||
}
|
||||
|
||||
/// Create a new uniform buffer with data
|
||||
pub fn new_uniform<T: Pod>(device: &Device, queue: &Queue, data: &T) -> GpuResult<Self> {
|
||||
Self::new_with_data(
|
||||
device,
|
||||
queue,
|
||||
std::slice::from_ref(data),
|
||||
BufferUsage::Uniforms,
|
||||
"uniform_buffer",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_size_bucket() {
        // Each (input, expected) pair exercises one rounding boundary.
        let cases = [(100usize, 256usize), (256, 256), (257, 512), (1000, 1024)];
        for (input, expected) in cases {
            assert_eq!(GpuBufferManager::size_bucket(input), expected);
        }
    }

    #[test]
    fn test_gpu_params_alignment() {
        // Ensure our GPU structs are properly aligned for wgpu.
        assert_eq!(std::mem::size_of::<GpuParams>(), 32);
        assert_eq!(std::mem::align_of::<GpuParams>(), 4);
    }

    #[test]
    fn test_gpu_edge_alignment() {
        assert_eq!(std::mem::size_of::<GpuEdge>(), 32);
        assert_eq!(std::mem::align_of::<GpuEdge>(), 4);
    }

    #[test]
    fn test_gpu_restriction_map_alignment() {
        assert_eq!(std::mem::size_of::<GpuRestrictionMap>(), 32);
        assert_eq!(std::mem::align_of::<GpuRestrictionMap>(), 4);
    }

    #[test]
    fn test_buffer_usage_flags() {
        let readonly = BufferUsageFlags::storage_readonly();
        assert!(readonly.storage_read && !readonly.storage_write);

        let readwrite = BufferUsageFlags::storage_readwrite();
        assert!(readwrite.storage_read && readwrite.storage_write);
    }

    #[test]
    fn test_buffer_key_equality() {
        let key1 = BufferKey::new(1024, BufferUsageFlags::storage_readonly());
        let key2 = BufferKey::new(1024, BufferUsageFlags::storage_readonly());
        let key3 = BufferKey::new(2048, BufferUsageFlags::storage_readonly());

        // Keys compare by (size, usage); only the size differs for key3.
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}
|
||||
290
vendor/ruvector/crates/prime-radiant/src/gpu/device.rs
vendored
Normal file
290
vendor/ruvector/crates/prime-radiant/src/gpu/device.rs
vendored
Normal file
@@ -0,0 +1,290 @@
|
||||
//! GPU device initialization and management.
|
||||
//!
|
||||
//! This module provides the core GPU device abstraction using wgpu,
|
||||
//! handling adapter selection, device creation, and queue management.
|
||||
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
use wgpu::{Adapter, Device, Instance, Queue};
|
||||
|
||||
use super::error::{GpuError, GpuResult};
|
||||
|
||||
/// Information about the GPU device.
///
/// Populated from the adapter info and device limits when the device is
/// created; purely descriptive.
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Device name
    pub name: String,
    /// Vendor ID
    pub vendor: u32,
    /// Device ID
    pub device_id: u32,
    /// Device type (discrete, integrated, etc.), formatted via Debug
    pub device_type: String,
    /// Backend API (Vulkan, Metal, DX12, etc.), formatted via Debug
    pub backend: String,
    /// Maximum buffer size
    pub max_buffer_size: u64,
    /// Maximum compute workgroup size per dimension
    pub max_workgroup_size: [u32; 3],
    /// Maximum compute workgroups per dimension (same wgpu limit repeated
    /// for x, y, and z)
    pub max_workgroups: [u32; 3],
    /// Maximum storage buffers per shader stage
    pub max_storage_buffers: u32,
}
|
||||
|
||||
/// GPU device wrapper providing access to wgpu resources.
///
/// Holds the full wgpu object chain (instance -> adapter -> device/queue)
/// plus a cached `GpuDeviceInfo` snapshot. Device and queue are behind
/// `Arc` so they can be shared with buffer pools and dispatchers.
pub struct GpuDevice {
    instance: Instance,
    adapter: Adapter,
    device: Arc<Device>,
    queue: Arc<Queue>,
    info: GpuDeviceInfo,
}
|
||||
|
||||
impl GpuDevice {
    /// Create a new GPU device with default configuration.
    ///
    /// This will:
    /// 1. Create a wgpu instance with all available backends
    /// 2. Request a high-performance adapter
    /// 3. Create the device and queue
    ///
    /// # Errors
    ///
    /// Returns `GpuError::NoAdapter` if no suitable GPU is found.
    /// Returns `GpuError::DeviceRequestFailed` if device creation fails.
    pub async fn new() -> GpuResult<Self> {
        Self::with_options(GpuDeviceOptions::default()).await
    }

    /// Create a new GPU device with custom options.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`GpuDevice::new`]; the `?` on `request_device`
    /// converts wgpu's request error through `GpuError`'s `From` impl.
    pub async fn with_options(options: GpuDeviceOptions) -> GpuResult<Self> {
        let instance = Instance::new(wgpu::InstanceDescriptor {
            backends: options.backends,
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::default(),
        });

        debug!(
            "Created wgpu instance with backends: {:?}",
            options.backends
        );

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: options.power_preference,
                compatible_surface: None,
                force_fallback_adapter: options.force_fallback,
            })
            .await
            .ok_or(GpuError::NoAdapter)?;

        let adapter_info = adapter.get_info();
        info!(
            "Selected GPU adapter: {} ({:?})",
            adapter_info.name, adapter_info.backend
        );

        // Downlevel limits trade capability for compatibility with older
        // hardware/drivers.
        let limits = if options.use_downlevel_limits {
            wgpu::Limits::downlevel_defaults()
        } else {
            wgpu::Limits::default()
        };

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("prime-radiant-gpu"),
                    required_features: options.required_features,
                    required_limits: limits.clone(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                None,
            )
            .await?;

        // Set up error handling: log uncaptured GPU errors instead of the
        // default behavior (which panics).
        device.on_uncaptured_error(Box::new(|error| {
            warn!("Uncaptured GPU error: {:?}", error);
        }));

        // Snapshot adapter/limit data into a plain struct for cheap access.
        // NOTE(review): this records the *requested* limits, not
        // `device.limits()` — confirm they are intended to be the same.
        let info = GpuDeviceInfo {
            name: adapter_info.name.clone(),
            vendor: adapter_info.vendor,
            device_id: adapter_info.device,
            device_type: format!("{:?}", adapter_info.device_type),
            backend: format!("{:?}", adapter_info.backend),
            max_buffer_size: limits.max_buffer_size as u64,
            max_workgroup_size: [
                limits.max_compute_workgroup_size_x,
                limits.max_compute_workgroup_size_y,
                limits.max_compute_workgroup_size_z,
            ],
            max_workgroups: [
                limits.max_compute_workgroups_per_dimension,
                limits.max_compute_workgroups_per_dimension,
                limits.max_compute_workgroups_per_dimension,
            ],
            max_storage_buffers: limits.max_storage_buffers_per_shader_stage,
        };

        debug!("GPU device info: {:?}", info);

        Ok(Self {
            instance,
            adapter,
            device: Arc::new(device),
            queue: Arc::new(queue),
            info,
        })
    }

    /// Get a reference to the wgpu device
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Get a shared reference to the wgpu device
    pub fn device_arc(&self) -> Arc<Device> {
        Arc::clone(&self.device)
    }

    /// Get a reference to the command queue
    pub fn queue(&self) -> &Queue {
        &self.queue
    }

    /// Get a shared reference to the command queue
    pub fn queue_arc(&self) -> Arc<Queue> {
        Arc::clone(&self.queue)
    }

    /// Get device information
    pub fn info(&self) -> &GpuDeviceInfo {
        &self.info
    }

    /// Get the wgpu instance
    pub fn instance(&self) -> &Instance {
        &self.instance
    }

    /// Get the wgpu adapter
    pub fn adapter(&self) -> &Adapter {
        &self.adapter
    }

    /// Check if a feature is supported by the adapter
    pub fn supports_feature(&self, feature: wgpu::Features) -> bool {
        self.adapter.features().contains(feature)
    }

    /// Poll the device for completed work.
    ///
    /// This is useful when you need to ensure GPU work has completed
    /// before continuing on the CPU.
    ///
    /// Returns `true` when the queue is empty after polling. With
    /// `wait = true` this blocks until outstanding work finishes.
    pub fn poll(&self, wait: bool) -> bool {
        self.device
            .poll(if wait {
                wgpu::Maintain::Wait
            } else {
                wgpu::Maintain::Poll
            })
            .is_queue_empty()
    }

    /// Submit a command buffer to the queue
    pub fn submit(&self, command_buffer: wgpu::CommandBuffer) -> wgpu::SubmissionIndex {
        self.queue.submit(std::iter::once(command_buffer))
    }

    /// Submit multiple command buffers to the queue
    pub fn submit_multiple(
        &self,
        command_buffers: impl IntoIterator<Item = wgpu::CommandBuffer>,
    ) -> wgpu::SubmissionIndex {
        self.queue.submit(command_buffers)
    }
}
|
||||
|
||||
/// Options for GPU device creation.
///
/// Construct via `Default` or the named presets (`low_power`,
/// `compatible`, `software`).
#[derive(Debug, Clone)]
pub struct GpuDeviceOptions {
    /// Backends to use (default: all)
    pub backends: wgpu::Backends,
    /// Power preference (default: high performance)
    pub power_preference: wgpu::PowerPreference,
    /// Required GPU features (default: none)
    pub required_features: wgpu::Features,
    /// Use downlevel limits for broader compatibility
    pub use_downlevel_limits: bool,
    /// Force fallback adapter (software rendering)
    pub force_fallback: bool,
}
|
||||
|
||||
impl Default for GpuDeviceOptions {
    /// All backends, high-performance adapter, no extra features, full
    /// limits, hardware adapter.
    fn default() -> Self {
        Self {
            backends: wgpu::Backends::all(),
            power_preference: wgpu::PowerPreference::HighPerformance,
            required_features: wgpu::Features::empty(),
            use_downlevel_limits: false,
            force_fallback: false,
        }
    }
}
|
||||
|
||||
impl GpuDeviceOptions {
|
||||
/// Create options for low-power mode (integrated GPU preferred)
|
||||
pub fn low_power() -> Self {
|
||||
Self {
|
||||
power_preference: wgpu::PowerPreference::LowPower,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create options for maximum compatibility
|
||||
pub fn compatible() -> Self {
|
||||
Self {
|
||||
use_downlevel_limits: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create options for software fallback
|
||||
pub fn software() -> Self {
|
||||
Self {
|
||||
force_fallback: true,
|
||||
use_downlevel_limits: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // These tests only exercise option construction; no GPU is required.

    #[test]
    fn test_device_options_default() {
        let options = GpuDeviceOptions::default();
        assert_eq!(
            options.power_preference,
            wgpu::PowerPreference::HighPerformance
        );
        assert!(!options.force_fallback);
    }

    #[test]
    fn test_device_options_low_power() {
        let options = GpuDeviceOptions::low_power();
        assert_eq!(options.power_preference, wgpu::PowerPreference::LowPower);
    }

    #[test]
    fn test_device_options_compatible() {
        let options = GpuDeviceOptions::compatible();
        assert!(options.use_downlevel_limits);
    }
}
|
||||
426
vendor/ruvector/crates/prime-radiant/src/gpu/dispatch.rs
vendored
Normal file
426
vendor/ruvector/crates/prime-radiant/src/gpu/dispatch.rs
vendored
Normal file
@@ -0,0 +1,426 @@
|
||||
//! Kernel dispatch and synchronization for GPU compute operations.
|
||||
//!
|
||||
//! This module provides the dispatcher for executing compute kernels on the GPU,
|
||||
//! including support for:
|
||||
//! - Single kernel dispatch
|
||||
//! - Indirect dispatch (workgroup count from GPU buffer)
|
||||
//! - Chained dispatch for fused kernels
|
||||
//! - Synchronization and timing
|
||||
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, trace};
|
||||
use wgpu::{CommandEncoder, Device, Queue};
|
||||
|
||||
use super::buffer::{GpuBuffer, GpuBufferPool};
|
||||
use super::device::GpuDevice;
|
||||
use super::error::{GpuError, GpuResult};
|
||||
use super::pipeline::{ComputePipeline, PipelineCache};
|
||||
|
||||
/// Configuration for a dispatch operation.
///
/// Built via `Default` plus the `with_*` builder methods, or the `wait` /
/// `with_label` constructors.
#[derive(Debug, Clone)]
pub struct DispatchConfig {
    /// Label for debugging
    pub label: Option<String>,
    /// Whether to wait for completion
    pub wait: bool,
    /// Timeout in milliseconds (0 = no timeout)
    pub timeout_ms: u64,
}
|
||||
|
||||
impl Default for DispatchConfig {
    /// Unlabeled, non-blocking dispatch with no timeout.
    fn default() -> Self {
        Self {
            label: None,
            wait: false,
            timeout_ms: 0,
        }
    }
}
|
||||
|
||||
impl DispatchConfig {
|
||||
/// Create a config that waits for completion
|
||||
pub fn wait() -> Self {
|
||||
Self {
|
||||
wait: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a config with a label
|
||||
pub fn with_label(label: impl Into<String>) -> Self {
|
||||
Self {
|
||||
label: Some(label.into()),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the timeout
|
||||
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
|
||||
self.timeout_ms = timeout_ms;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set wait flag
|
||||
pub fn with_wait(mut self, wait: bool) -> Self {
|
||||
self.wait = wait;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU dispatcher for executing compute kernels.
pub struct GpuDispatcher {
    // Shared handle to the device/queue wrapper; all submits and polls go through it.
    device: Arc<GpuDevice>,
    // Cache of compiled compute pipelines, reused across dispatches.
    pipeline_cache: PipelineCache,
    // Pool of reusable GPU buffers.
    buffer_pool: GpuBufferPool,
}
|
||||
|
||||
impl GpuDispatcher {
|
||||
/// Create a new dispatcher
|
||||
pub fn new(device: Arc<GpuDevice>) -> Self {
|
||||
let pipeline_cache = PipelineCache::new(device.device_arc());
|
||||
let buffer_pool = GpuBufferPool::new(device.device_arc());
|
||||
|
||||
Self {
|
||||
device,
|
||||
pipeline_cache,
|
||||
buffer_pool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the underlying GPU device
|
||||
pub fn device(&self) -> &GpuDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Get the pipeline cache
|
||||
pub fn pipeline_cache(&self) -> &PipelineCache {
|
||||
&self.pipeline_cache
|
||||
}
|
||||
|
||||
/// Get the buffer pool
|
||||
pub fn buffer_pool(&self) -> &GpuBufferPool {
|
||||
&self.buffer_pool
|
||||
}
|
||||
|
||||
/// Dispatch a compute kernel.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pipeline` - The compute pipeline to execute
|
||||
/// * `bind_group` - The bind group with buffer bindings
|
||||
/// * `workgroups` - Number of workgroups [x, y, z]
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// dispatcher.dispatch(&pipeline, &bind_group, [4, 1, 1]).await?;
|
||||
/// ```
|
||||
pub async fn dispatch(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
) -> GpuResult<()> {
|
||||
self.dispatch_with_config(pipeline, bind_group, workgroups, DispatchConfig::default())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Dispatch with custom configuration.
|
||||
pub async fn dispatch_with_config(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
config: DispatchConfig,
|
||||
) -> GpuResult<()> {
|
||||
// Validate workgroup count
|
||||
let limits = &self.device.info().max_workgroups;
|
||||
if workgroups[0] > limits[0] || workgroups[1] > limits[1] || workgroups[2] > limits[2] {
|
||||
return Err(GpuError::InvalidWorkgroupSize {
|
||||
x: workgroups[0],
|
||||
y: workgroups[1],
|
||||
z: workgroups[2],
|
||||
});
|
||||
}
|
||||
|
||||
let label = config.label.as_deref().unwrap_or("dispatch");
|
||||
debug!(
|
||||
"Dispatching '{}' with workgroups [{}, {}, {}]",
|
||||
label, workgroups[0], workgroups[1], workgroups[2]
|
||||
);
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.device()
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
|
||||
|
||||
{
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(bind_group), &[]);
|
||||
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
|
||||
}
|
||||
|
||||
self.device.submit(encoder.finish());
|
||||
|
||||
if config.wait {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dispatch using indirect workgroup count from a buffer.
|
||||
///
|
||||
/// The indirect buffer must contain [x, y, z] workgroup counts as u32.
|
||||
pub async fn dispatch_indirect(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
indirect_buffer: &GpuBuffer,
|
||||
) -> GpuResult<()> {
|
||||
self.dispatch_indirect_with_config(
|
||||
pipeline,
|
||||
bind_group,
|
||||
indirect_buffer,
|
||||
0,
|
||||
DispatchConfig::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Dispatch indirect with offset and configuration.
|
||||
pub async fn dispatch_indirect_with_config(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
indirect_buffer: &GpuBuffer,
|
||||
indirect_offset: u64,
|
||||
config: DispatchConfig,
|
||||
) -> GpuResult<()> {
|
||||
let label = config.label.as_deref().unwrap_or("dispatch_indirect");
|
||||
debug!("Dispatching indirect '{}'", label);
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.device()
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
|
||||
|
||||
{
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(bind_group), &[]);
|
||||
pass.dispatch_workgroups_indirect(indirect_buffer.buffer(), indirect_offset);
|
||||
}
|
||||
|
||||
self.device.submit(encoder.finish());
|
||||
|
||||
if config.wait {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dispatch multiple kernels in a chain (fused execution).
|
||||
///
|
||||
/// All dispatches are recorded into a single command buffer for
|
||||
/// optimal GPU utilization.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dispatches` - List of (pipeline, bind_group, workgroups) tuples
|
||||
pub async fn dispatch_chain(
|
||||
&self,
|
||||
dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
|
||||
) -> GpuResult<()> {
|
||||
self.dispatch_chain_with_config(dispatches, DispatchConfig::default())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Dispatch chain with custom configuration.
|
||||
pub async fn dispatch_chain_with_config(
|
||||
&self,
|
||||
dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
|
||||
config: DispatchConfig,
|
||||
) -> GpuResult<()> {
|
||||
if dispatches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let label = config.label.as_deref().unwrap_or("dispatch_chain");
|
||||
debug!(
|
||||
"Dispatching chain '{}' with {} kernels",
|
||||
label,
|
||||
dispatches.len()
|
||||
);
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.device()
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
|
||||
|
||||
for (i, (pipeline, bind_group, workgroups)) in dispatches.iter().enumerate() {
|
||||
trace!(
|
||||
"Chain dispatch {}: workgroups [{}, {}, {}]",
|
||||
i,
|
||||
workgroups[0],
|
||||
workgroups[1],
|
||||
workgroups[2]
|
||||
);
|
||||
|
||||
let pass_label = format!("{}_pass_{}", label, i);
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(&pass_label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(*bind_group), &[]);
|
||||
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
|
||||
}
|
||||
|
||||
self.device.submit(encoder.finish());
|
||||
|
||||
if config.wait {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record dispatches to a command encoder without submitting.
|
||||
///
|
||||
/// This is useful when you want to combine compute with other operations.
|
||||
pub fn record_dispatch(
|
||||
&self,
|
||||
encoder: &mut CommandEncoder,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
label: Option<&str>,
|
||||
) {
|
||||
let pass_label = label.unwrap_or("recorded_dispatch");
|
||||
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(pass_label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(bind_group), &[]);
|
||||
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
|
||||
}
|
||||
|
||||
/// Wait for all pending GPU work to complete.
|
||||
pub fn synchronize(&self) {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
/// Poll for completed work without blocking.
|
||||
pub fn poll(&self) -> bool {
|
||||
self.device.poll(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for constructing complex dispatch operations.
pub struct DispatchBuilder<'a> {
    // Dispatcher that will execute the assembled chain.
    dispatcher: &'a GpuDispatcher,
    // Owned (pipeline, bind group, workgroups) entries, executed in insertion order.
    dispatches: Vec<(Arc<ComputePipeline>, wgpu::BindGroup, [u32; 3])>,
    // Configuration applied to the whole chain on execute().
    config: DispatchConfig,
}
|
||||
|
||||
impl<'a> DispatchBuilder<'a> {
|
||||
/// Create a new dispatch builder
|
||||
pub fn new(dispatcher: &'a GpuDispatcher) -> Self {
|
||||
Self {
|
||||
dispatcher,
|
||||
dispatches: Vec::new(),
|
||||
config: DispatchConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a dispatch to the chain
|
||||
pub fn add(
|
||||
mut self,
|
||||
pipeline: Arc<ComputePipeline>,
|
||||
bind_group: wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
) -> Self {
|
||||
self.dispatches.push((pipeline, bind_group, workgroups));
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the configuration
|
||||
pub fn config(mut self, config: DispatchConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the label
|
||||
pub fn label(mut self, label: impl Into<String>) -> Self {
|
||||
self.config.label = Some(label.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set wait flag
|
||||
pub fn wait(mut self) -> Self {
|
||||
self.config.wait = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute all dispatches
|
||||
pub async fn execute(self) -> GpuResult<()> {
|
||||
if self.dispatches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let refs: Vec<(&ComputePipeline, &wgpu::BindGroup, [u32; 3])> = self
|
||||
.dispatches
|
||||
.iter()
|
||||
.map(|(p, b, w)| (p.as_ref(), b, *w))
|
||||
.collect();
|
||||
|
||||
self.dispatcher
|
||||
.dispatch_chain_with_config(&refs, self.config)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_config_default() {
        let cfg = DispatchConfig::default();
        assert!(cfg.label.is_none());
        assert!(!cfg.wait);
        assert_eq!(cfg.timeout_ms, 0);
    }

    #[test]
    fn test_dispatch_config_wait() {
        assert!(DispatchConfig::wait().wait);
    }

    #[test]
    fn test_dispatch_config_builder() {
        // Builder setters should compose and preserve earlier fields.
        let cfg = DispatchConfig::with_label("test")
            .with_timeout(1000)
            .with_wait(true);

        assert_eq!(cfg.label.as_deref(), Some("test"));
        assert_eq!(cfg.timeout_ms, 1000);
        assert!(cfg.wait);
    }
}
|
||||
814
vendor/ruvector/crates/prime-radiant/src/gpu/engine.rs
vendored
Normal file
814
vendor/ruvector/crates/prime-radiant/src/gpu/engine.rs
vendored
Normal file
@@ -0,0 +1,814 @@
|
||||
//! GPU Coherence Engine
//!
//! Main entry point for GPU-accelerated coherence computation.
//! Provides automatic CPU fallback when GPU is unavailable.
|
||||
|
||||
use super::buffer::{BufferUsage, GpuBufferManager, GpuEdge, GpuParams, GpuRestrictionMap};
|
||||
use super::error::{GpuError, GpuResult};
|
||||
use super::kernels::{
|
||||
ComputeEnergyKernel, ComputeResidualsKernel, EnergyParams, SheafAttentionKernel,
|
||||
TokenRoutingKernel,
|
||||
};
|
||||
use crate::coherence::{CoherenceEnergy as CpuCoherenceEnergy, EdgeEnergy};
|
||||
use crate::substrate::restriction::MatrixStorage;
|
||||
use crate::substrate::{EdgeId, NodeId, SheafGraph};
|
||||
|
||||
use chrono::Utc;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
use wgpu::{
|
||||
Adapter, Device, DeviceDescriptor, Features, Instance, InstanceDescriptor, Limits,
|
||||
PowerPreference, Queue, RequestAdapterOptions,
|
||||
};
|
||||
|
||||
/// GPU configuration.
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Preferred power preference (high performance vs low power).
    pub power_preference: PowerPreference,
    /// Enable CPU fallback when GPU is unavailable.
    // NOTE(review): not read anywhere in the visible engine code — callers
    // presumably consult it when choosing between GPU and CPU paths; confirm.
    pub enable_fallback: bool,
    /// Maximum buffer size in bytes (0 = no limit).
    pub max_buffer_size: usize,
    /// Beta parameter for attention computation.
    pub beta: f32,
    /// Lane 0 (reflex) threshold.
    pub threshold_lane0: f32,
    /// Lane 1 (retrieval) threshold.
    pub threshold_lane1: f32,
    /// Lane 2 (heavy) threshold.
    pub threshold_lane2: f32,
    /// Timeout for GPU operations in milliseconds.
    pub timeout_ms: u64,
}
|
||||
|
||||
impl Default for GpuConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
power_preference: PowerPreference::HighPerformance,
|
||||
enable_fallback: true,
|
||||
max_buffer_size: 0, // No limit
|
||||
beta: 1.0,
|
||||
threshold_lane0: 0.01,
|
||||
threshold_lane1: 0.1,
|
||||
threshold_lane2: 1.0,
|
||||
timeout_ms: 5000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU capabilities and limits.
#[derive(Debug, Clone)]
pub struct GpuCapabilities {
    /// Device name.
    pub device_name: String,
    /// Vendor.
    pub vendor: String,
    /// Backend (Vulkan, Metal, DX12, etc.).
    pub backend: String,
    /// Maximum buffer size.
    pub max_buffer_size: u64,
    /// Maximum compute workgroup size.
    pub max_workgroup_size: u32,
    /// Maximum compute workgroups per dimension.
    pub max_workgroups: [u32; 3],
    /// Whether the GPU supports required features.
    // NOTE(review): get_capabilities() in this file sets this to `true`
    // unconditionally, so the `!supported` check in new() can never fire —
    // confirm whether real feature detection is intended here.
    pub supported: bool,
}
|
||||
|
||||
/// GPU energy result.
#[derive(Debug, Clone)]
pub struct GpuCoherenceEnergy {
    /// Total system energy.
    pub total_energy: f32,
    /// Per-edge energies (index-aligned with `edge_indices`).
    pub edge_energies: Vec<f32>,
    /// Edge indices (matches `edge_energies`).
    pub edge_indices: Vec<EdgeId>,
    /// Computation time in microseconds.
    pub compute_time_us: u64,
    /// Whether GPU was used (false = CPU fallback).
    pub used_gpu: bool,
}
|
||||
|
||||
impl GpuCoherenceEnergy {
|
||||
/// Convert to CPU CoherenceEnergy format
|
||||
pub fn to_cpu_format(&self, graph: &SheafGraph) -> CpuCoherenceEnergy {
|
||||
let mut edge_energy_map = HashMap::new();
|
||||
|
||||
for (i, &edge_id) in self.edge_indices.iter().enumerate() {
|
||||
let energy = self.edge_energies[i];
|
||||
if let Some(edge) = graph.get_edge(edge_id) {
|
||||
let edge_energy = EdgeEnergy::new_lightweight(
|
||||
edge_id.to_string(),
|
||||
edge.source.to_string(),
|
||||
edge.target.to_string(),
|
||||
energy / edge.weight.max(0.001), // Remove weight to get raw norm_sq
|
||||
edge.weight,
|
||||
);
|
||||
edge_energy_map.insert(edge_id.to_string(), edge_energy);
|
||||
}
|
||||
}
|
||||
|
||||
CpuCoherenceEnergy::new(
|
||||
edge_energy_map,
|
||||
&HashMap::new(),
|
||||
graph.node_count(),
|
||||
format!("gpu-{}", Utc::now().timestamp()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU-accelerated coherence engine.
pub struct GpuCoherenceEngine {
    // wgpu device and queue handles shared with the buffer manager.
    device: Arc<Device>,
    queue: Arc<Queue>,
    // Named-buffer allocator/owner for graph and output buffers.
    buffer_manager: GpuBufferManager,
    config: GpuConfig,
    capabilities: GpuCapabilities,

    // Kernels (compiled once at engine creation)
    residuals_kernel: ComputeResidualsKernel,
    energy_kernel: ComputeEnergyKernel,
    attention_kernel: SheafAttentionKernel,
    routing_kernel: TokenRoutingKernel,

    // Cached graph data; None until upload_graph() has been called.
    graph_data: Option<GpuGraphData>,
}
|
||||
|
||||
/// Cached graph data on GPU.
struct GpuGraphData {
    num_nodes: u32,
    num_edges: u32,
    // State vector dimension, taken from the first node at upload time.
    state_dim: u32,
    // Mappings between graph IDs and dense GPU-side indices.
    node_id_map: HashMap<NodeId, u32>,
    edge_id_map: HashMap<EdgeId, u32>,
    // Dense index -> EdgeId, used to label per-edge results on readback.
    edge_id_reverse: Vec<EdgeId>,

    // Pre-allocated computation buffers (eliminates per-frame allocations)
    params_buffer: wgpu::Buffer,
    energy_params_buffer: wgpu::Buffer,
    // One partial sum per reduction workgroup.
    partial_sums_buffer: wgpu::Buffer,
    // Params for the second (final) reduction pass.
    final_params_buffer: wgpu::Buffer,
    // Single f32 holding the reduced total energy.
    total_energy_buffer: wgpu::Buffer,
    // MAP_READ staging buffers for CPU readback.
    energies_staging: wgpu::Buffer,
    total_staging: wgpu::Buffer,

    // Pre-computed workgroup count for energy reduction
    num_workgroups: u32,
}
|
||||
|
||||
impl GpuCoherenceEngine {
|
||||
/// Create a new GPU coherence engine
|
||||
pub async fn new(config: GpuConfig) -> GpuResult<Self> {
|
||||
// Create wgpu instance
|
||||
let instance = Instance::new(InstanceDescriptor::default());
|
||||
|
||||
// Request adapter
|
||||
let adapter = instance
|
||||
.request_adapter(&RequestAdapterOptions {
|
||||
power_preference: config.power_preference,
|
||||
compatible_surface: None,
|
||||
force_fallback_adapter: false,
|
||||
})
|
||||
.await
|
||||
.ok_or_else(|| GpuError::AdapterRequest("No suitable GPU adapter found".into()))?;
|
||||
|
||||
let capabilities = Self::get_capabilities(&adapter);
|
||||
if !capabilities.supported {
|
||||
return Err(GpuError::UnsupportedFeature(
|
||||
"GPU does not support required features".into(),
|
||||
));
|
||||
}
|
||||
|
||||
info!(
|
||||
"Using GPU: {} ({}) - {}",
|
||||
capabilities.device_name, capabilities.vendor, capabilities.backend
|
||||
);
|
||||
|
||||
// Request device
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&DeviceDescriptor {
|
||||
label: Some("prime_radiant_gpu"),
|
||||
required_features: Features::empty(),
|
||||
required_limits: Limits::default(),
|
||||
memory_hints: Default::default(),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| GpuError::DeviceCreation(e.to_string()))?;
|
||||
|
||||
let device = Arc::new(device);
|
||||
let queue = Arc::new(queue);
|
||||
|
||||
// Create kernels
|
||||
let residuals_kernel = ComputeResidualsKernel::new(&device)?;
|
||||
let energy_kernel = ComputeEnergyKernel::new(&device)?;
|
||||
let attention_kernel = SheafAttentionKernel::new(&device)?;
|
||||
let routing_kernel = TokenRoutingKernel::new(&device)?;
|
||||
|
||||
// Create buffer manager
|
||||
let buffer_manager = GpuBufferManager::new(device.clone(), queue.clone());
|
||||
|
||||
Ok(Self {
|
||||
device,
|
||||
queue,
|
||||
buffer_manager,
|
||||
config,
|
||||
capabilities,
|
||||
residuals_kernel,
|
||||
energy_kernel,
|
||||
attention_kernel,
|
||||
routing_kernel,
|
||||
graph_data: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to create a GPU engine, returning None if GPU is unavailable
|
||||
pub async fn try_new(config: GpuConfig) -> Option<Self> {
|
||||
match Self::new(config).await {
|
||||
Ok(engine) => Some(engine),
|
||||
Err(e) => {
|
||||
warn!("GPU initialization failed: {}. Will use CPU fallback.", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get GPU capabilities
|
||||
fn get_capabilities(adapter: &Adapter) -> GpuCapabilities {
|
||||
let info = adapter.get_info();
|
||||
let limits = adapter.limits();
|
||||
|
||||
GpuCapabilities {
|
||||
device_name: info.name,
|
||||
vendor: format!("{:?}", info.vendor),
|
||||
backend: format!("{:?}", info.backend),
|
||||
max_buffer_size: limits.max_buffer_size as u64,
|
||||
max_workgroup_size: limits.max_compute_workgroup_size_x,
|
||||
max_workgroups: [
|
||||
limits.max_compute_workgroups_per_dimension,
|
||||
limits.max_compute_workgroups_per_dimension,
|
||||
limits.max_compute_workgroups_per_dimension,
|
||||
],
|
||||
supported: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Upload graph data to GPU
|
||||
pub fn upload_graph(&mut self, graph: &SheafGraph) -> GpuResult<()> {
|
||||
if graph.edge_count() == 0 {
|
||||
return Err(GpuError::EmptyGraph);
|
||||
}
|
||||
|
||||
let num_nodes = graph.node_count() as u32;
|
||||
let num_edges = graph.edge_count() as u32;
|
||||
|
||||
// Build node ID mapping
|
||||
let mut node_id_map = HashMap::new();
|
||||
let node_ids = graph.node_ids();
|
||||
for (i, node_id) in node_ids.iter().enumerate() {
|
||||
node_id_map.insert(*node_id, i as u32);
|
||||
}
|
||||
|
||||
// Determine state dimension from first node
|
||||
let state_dim = node_ids
|
||||
.first()
|
||||
.and_then(|id| graph.get_node(*id))
|
||||
.map(|n| n.dim())
|
||||
.unwrap_or(64) as u32;
|
||||
|
||||
// Flatten node states
|
||||
let mut node_states: Vec<f32> = Vec::with_capacity((num_nodes * state_dim) as usize);
|
||||
for node_id in &node_ids {
|
||||
if let Some(state) = graph.node_state(*node_id) {
|
||||
node_states.extend(state.iter().cloned());
|
||||
// Pad if needed
|
||||
for _ in state.len()..(state_dim as usize) {
|
||||
node_states.push(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build edge data and restriction maps
|
||||
let mut edges: Vec<GpuEdge> = Vec::with_capacity(num_edges as usize);
|
||||
let mut restriction_maps: Vec<GpuRestrictionMap> = Vec::new();
|
||||
let mut restriction_data: Vec<f32> = Vec::new();
|
||||
let mut edge_id_map = HashMap::new();
|
||||
let mut edge_id_reverse = Vec::new();
|
||||
|
||||
let edge_ids = graph.edge_ids();
|
||||
for (i, edge_id) in edge_ids.iter().enumerate() {
|
||||
edge_id_map.insert(*edge_id, i as u32);
|
||||
edge_id_reverse.push(*edge_id);
|
||||
|
||||
if let Some(edge) = graph.get_edge(*edge_id) {
|
||||
let source_idx = *node_id_map.get(&edge.source).unwrap_or(&0);
|
||||
let target_idx = *node_id_map.get(&edge.target).unwrap_or(&0);
|
||||
|
||||
// Convert restriction maps
|
||||
let rho_source_idx = restriction_maps.len() as u32;
|
||||
let gpu_rho_source =
|
||||
Self::convert_restriction_map(&edge.rho_source, &mut restriction_data);
|
||||
restriction_maps.push(gpu_rho_source);
|
||||
|
||||
let rho_target_idx = restriction_maps.len() as u32;
|
||||
let gpu_rho_target =
|
||||
Self::convert_restriction_map(&edge.rho_target, &mut restriction_data);
|
||||
restriction_maps.push(gpu_rho_target);
|
||||
|
||||
edges.push(GpuEdge {
|
||||
source_idx,
|
||||
target_idx,
|
||||
weight: edge.weight,
|
||||
rho_source_idx,
|
||||
rho_target_idx,
|
||||
comparison_dim: edge.comparison_dim() as u32,
|
||||
_padding: [0; 2],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure restriction_data is not empty (GPU buffers can't be zero-sized)
|
||||
if restriction_data.is_empty() {
|
||||
restriction_data.push(0.0);
|
||||
}
|
||||
|
||||
// Upload to GPU
|
||||
self.buffer_manager.allocate_with_data(
|
||||
&node_states,
|
||||
BufferUsage::NodeStates,
|
||||
"node_states",
|
||||
)?;
|
||||
|
||||
self.buffer_manager
|
||||
.allocate_with_data(&edges, BufferUsage::EdgeData, "edges")?;
|
||||
|
||||
self.buffer_manager.allocate_with_data(
|
||||
&restriction_maps,
|
||||
BufferUsage::RestrictionMaps,
|
||||
"restriction_maps",
|
||||
)?;
|
||||
|
||||
self.buffer_manager.allocate_with_data(
|
||||
&restriction_data,
|
||||
BufferUsage::RestrictionMaps,
|
||||
"restriction_data",
|
||||
)?;
|
||||
|
||||
// Allocate output buffers
|
||||
let max_comparison_dim = edges
|
||||
.iter()
|
||||
.map(|e| e.comparison_dim)
|
||||
.max()
|
||||
.unwrap_or(state_dim);
|
||||
let residuals_size = (num_edges * max_comparison_dim) as usize * std::mem::size_of::<f32>();
|
||||
let energies_size = num_edges as usize * std::mem::size_of::<f32>();
|
||||
|
||||
self.buffer_manager
|
||||
.allocate(residuals_size, BufferUsage::Residuals, "residuals")?;
|
||||
|
||||
self.buffer_manager
|
||||
.allocate(energies_size, BufferUsage::Energies, "edge_energies")?;
|
||||
|
||||
// Pre-allocate computation buffers to eliminate per-frame allocations
|
||||
let num_workgroups = ComputeEnergyKernel::workgroup_count(num_edges);
|
||||
|
||||
// Params buffer (GpuParams)
|
||||
let params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("params_preallocated"),
|
||||
size: std::mem::size_of::<GpuParams>() as u64,
|
||||
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Energy params buffer (EnergyParams)
|
||||
let energy_params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("energy_params_preallocated"),
|
||||
size: std::mem::size_of::<EnergyParams>() as u64,
|
||||
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Partial sums buffer (one f32 per workgroup)
|
||||
let partial_sums_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("partial_sums_preallocated"),
|
||||
size: ((num_workgroups as usize).max(1) * std::mem::size_of::<f32>()) as u64,
|
||||
usage: wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_SRC
|
||||
| wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Final params buffer (for second reduction pass)
|
||||
let final_params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("final_params_preallocated"),
|
||||
size: std::mem::size_of::<EnergyParams>() as u64,
|
||||
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Total energy buffer (single f32)
|
||||
let total_energy_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("total_energy_preallocated"),
|
||||
size: std::mem::size_of::<f32>() as u64,
|
||||
usage: wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_SRC
|
||||
| wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Staging buffer for edge energies readback
|
||||
let energies_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("energies_staging_preallocated"),
|
||||
size: (num_edges as usize * std::mem::size_of::<f32>()) as u64,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Staging buffer for total energy readback
|
||||
let total_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("total_staging_preallocated"),
|
||||
size: std::mem::size_of::<f32>() as u64,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Store graph data with pre-allocated buffers
|
||||
self.graph_data = Some(GpuGraphData {
|
||||
num_nodes,
|
||||
num_edges,
|
||||
state_dim,
|
||||
node_id_map,
|
||||
edge_id_map,
|
||||
edge_id_reverse,
|
||||
params_buffer,
|
||||
energy_params_buffer,
|
||||
partial_sums_buffer,
|
||||
final_params_buffer,
|
||||
total_energy_buffer,
|
||||
energies_staging,
|
||||
total_staging,
|
||||
num_workgroups,
|
||||
});
|
||||
|
||||
debug!(
|
||||
"Uploaded graph to GPU: {} nodes, {} edges, state_dim={}, workgroups={}",
|
||||
num_nodes, num_edges, state_dim, num_workgroups
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert a RestrictionMap to GPU format
|
||||
fn convert_restriction_map(
|
||||
map: &crate::substrate::RestrictionMap,
|
||||
data: &mut Vec<f32>,
|
||||
) -> GpuRestrictionMap {
|
||||
let data_offset = data.len() as u32;
|
||||
|
||||
let (map_type, data_len) = match &map.matrix {
|
||||
MatrixStorage::Identity => (0, 0),
|
||||
MatrixStorage::Diagonal(scales) => {
|
||||
data.extend(scales.iter().cloned());
|
||||
(1, scales.len() as u32)
|
||||
}
|
||||
MatrixStorage::Projection { indices, .. } => {
|
||||
data.extend(indices.iter().map(|&i| i as f32));
|
||||
(2, indices.len() as u32)
|
||||
}
|
||||
MatrixStorage::Sparse { values, .. } => {
|
||||
// Simplified: just store values (would need row/col in practice)
|
||||
data.extend(values.iter().cloned());
|
||||
(3, values.len() as u32)
|
||||
}
|
||||
MatrixStorage::Csr(csr) => {
|
||||
// CSR format: store values similar to sparse
|
||||
// Note: In practice, GPU would need row_ptr and col_indices too
|
||||
data.extend(csr.values.iter().cloned());
|
||||
(3, csr.values.len() as u32)
|
||||
}
|
||||
MatrixStorage::Dense {
|
||||
data: matrix_data, ..
|
||||
} => {
|
||||
data.extend(matrix_data.iter().cloned());
|
||||
(3, matrix_data.len() as u32)
|
||||
}
|
||||
};
|
||||
|
||||
GpuRestrictionMap {
|
||||
map_type,
|
||||
input_dim: map.input_dim() as u32,
|
||||
output_dim: map.output_dim() as u32,
|
||||
data_offset,
|
||||
data_len,
|
||||
_padding: [0; 3],
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute coherence energy on GPU
|
||||
/// Uses pre-allocated buffers to eliminate per-frame allocations
|
||||
pub async fn compute_energy(&mut self) -> GpuResult<GpuCoherenceEnergy> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let graph_data = self
|
||||
.graph_data
|
||||
.as_ref()
|
||||
.ok_or_else(|| GpuError::Internal("Graph not uploaded".into()))?;
|
||||
|
||||
let num_edges = graph_data.num_edges;
|
||||
let num_workgroups = graph_data.num_workgroups;
|
||||
|
||||
// Write params to pre-allocated buffer (no allocation)
|
||||
let params = GpuParams {
|
||||
num_edges,
|
||||
num_nodes: graph_data.num_nodes,
|
||||
state_dim: graph_data.state_dim,
|
||||
beta: self.config.beta,
|
||||
threshold_lane0: self.config.threshold_lane0,
|
||||
threshold_lane1: self.config.threshold_lane1,
|
||||
threshold_lane2: self.config.threshold_lane2,
|
||||
store_residuals: 1, // Store residuals by default for gradient computation
|
||||
};
|
||||
self.queue
|
||||
.write_buffer(&graph_data.params_buffer, 0, bytemuck::bytes_of(¶ms));
|
||||
|
||||
// Write energy params to pre-allocated buffer (no allocation)
|
||||
let energy_params = EnergyParams {
|
||||
num_elements: num_edges,
|
||||
_padding: [0; 7],
|
||||
};
|
||||
self.queue.write_buffer(
|
||||
&graph_data.energy_params_buffer,
|
||||
0,
|
||||
bytemuck::bytes_of(&energy_params),
|
||||
);
|
||||
|
||||
// Get managed buffers for bind group creation
|
||||
let node_states_buf = self
|
||||
.buffer_manager
|
||||
.get("node_states")
|
||||
.ok_or_else(|| GpuError::Internal("Node states buffer not found".into()))?;
|
||||
let edges_buf = self
|
||||
.buffer_manager
|
||||
.get("edges")
|
||||
.ok_or_else(|| GpuError::Internal("Edges buffer not found".into()))?;
|
||||
let restriction_maps_buf = self
|
||||
.buffer_manager
|
||||
.get("restriction_maps")
|
||||
.ok_or_else(|| GpuError::Internal("Restriction maps buffer not found".into()))?;
|
||||
let restriction_data_buf = self
|
||||
.buffer_manager
|
||||
.get("restriction_data")
|
||||
.ok_or_else(|| GpuError::Internal("Restriction data buffer not found".into()))?;
|
||||
let residuals_buf = self
|
||||
.buffer_manager
|
||||
.get("residuals")
|
||||
.ok_or_else(|| GpuError::Internal("Residuals buffer not found".into()))?;
|
||||
let energies_buf = self
|
||||
.buffer_manager
|
||||
.get("edge_energies")
|
||||
.ok_or_else(|| GpuError::Internal("Edge energies buffer not found".into()))?;
|
||||
|
||||
// Create bind group for residuals kernel using pre-allocated params buffer
|
||||
let residuals_bind_group = self.residuals_kernel.create_bind_group_raw(
|
||||
&self.device,
|
||||
&graph_data.params_buffer,
|
||||
&node_states_buf.buffer,
|
||||
&edges_buf.buffer,
|
||||
&restriction_maps_buf.buffer,
|
||||
&restriction_data_buf.buffer,
|
||||
&residuals_buf.buffer,
|
||||
&energies_buf.buffer,
|
||||
);
|
||||
|
||||
// Create bind group for energy reduction using pre-allocated buffers
|
||||
let energy_bind_group = self.energy_kernel.create_bind_group_raw(
|
||||
&self.device,
|
||||
&graph_data.energy_params_buffer,
|
||||
&energies_buf.buffer,
|
||||
&graph_data.partial_sums_buffer,
|
||||
);
|
||||
|
||||
// Create command encoder
|
||||
let mut encoder = self
|
||||
.device
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("compute_energy_encoder"),
|
||||
});
|
||||
|
||||
// Dispatch residuals computation
|
||||
{
|
||||
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some("compute_residuals_pass"),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
compute_pass.set_pipeline(self.residuals_kernel.pipeline());
|
||||
compute_pass.set_bind_group(0, &residuals_bind_group, &[]);
|
||||
compute_pass.dispatch_workgroups(
|
||||
ComputeResidualsKernel::workgroup_count(num_edges),
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
|
||||
// Dispatch energy reduction
|
||||
{
|
||||
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some("compute_energy_pass"),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
compute_pass.set_pipeline(self.energy_kernel.main_pipeline());
|
||||
compute_pass.set_bind_group(0, &energy_bind_group, &[]);
|
||||
compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
|
||||
}
|
||||
|
||||
// If we have multiple workgroups, do final reduction
|
||||
if num_workgroups > 1 {
|
||||
// Write final params to pre-allocated buffer (no allocation)
|
||||
let final_params = EnergyParams {
|
||||
num_elements: num_workgroups,
|
||||
_padding: [0; 7],
|
||||
};
|
||||
self.queue.write_buffer(
|
||||
&graph_data.final_params_buffer,
|
||||
0,
|
||||
bytemuck::bytes_of(&final_params),
|
||||
);
|
||||
|
||||
let final_bind_group = self.energy_kernel.create_bind_group_raw(
|
||||
&self.device,
|
||||
&graph_data.final_params_buffer,
|
||||
&graph_data.partial_sums_buffer,
|
||||
&graph_data.total_energy_buffer,
|
||||
);
|
||||
|
||||
{
|
||||
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some("final_reduce_pass"),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
compute_pass.set_pipeline(self.energy_kernel.final_pipeline());
|
||||
compute_pass.set_bind_group(0, &final_bind_group, &[]);
|
||||
compute_pass.dispatch_workgroups(1, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy results to pre-allocated staging buffers (no allocation)
|
||||
encoder.copy_buffer_to_buffer(
|
||||
&energies_buf.buffer,
|
||||
0,
|
||||
&graph_data.energies_staging,
|
||||
0,
|
||||
(num_edges as usize * std::mem::size_of::<f32>()) as u64,
|
||||
);
|
||||
|
||||
if num_workgroups > 1 {
|
||||
encoder.copy_buffer_to_buffer(
|
||||
&graph_data.total_energy_buffer,
|
||||
0,
|
||||
&graph_data.total_staging,
|
||||
0,
|
||||
std::mem::size_of::<f32>() as u64,
|
||||
);
|
||||
} else {
|
||||
encoder.copy_buffer_to_buffer(
|
||||
&graph_data.partial_sums_buffer,
|
||||
0,
|
||||
&graph_data.total_staging,
|
||||
0,
|
||||
std::mem::size_of::<f32>() as u64,
|
||||
);
|
||||
}
|
||||
|
||||
// Submit commands
|
||||
self.queue.submit(std::iter::once(encoder.finish()));
|
||||
|
||||
// Read back results from pre-allocated staging buffers
|
||||
let edge_energies = Self::read_buffer_f32(
|
||||
&self.device,
|
||||
&graph_data.energies_staging,
|
||||
num_edges as usize,
|
||||
)
|
||||
.await?;
|
||||
let total_energy =
|
||||
Self::read_buffer_f32(&self.device, &graph_data.total_staging, 1).await?[0];
|
||||
|
||||
let compute_time_us = start.elapsed().as_micros() as u64;
|
||||
|
||||
debug!(
|
||||
"GPU energy computation: total={:.6}, {} edges, {}us (pre-allocated buffers)",
|
||||
total_energy, num_edges, compute_time_us
|
||||
);
|
||||
|
||||
Ok(GpuCoherenceEnergy {
|
||||
total_energy,
|
||||
edge_energies,
|
||||
edge_indices: graph_data.edge_id_reverse.clone(),
|
||||
compute_time_us,
|
||||
used_gpu: true,
|
||||
})
|
||||
}
|
||||
|
||||
    /// Read f32 buffer back to CPU
    ///
    /// Maps `buffer` for reading, blocks the device until the map resolves,
    /// copies out the first `count` f32 values, then unmaps the buffer so it
    /// can be reused as a copy destination on the next frame.
    ///
    /// # Errors
    /// Returns [`GpuError::BufferRead`] when the map callback channel is
    /// dropped or the mapping itself fails.
    async fn read_buffer_f32(
        device: &Device,
        buffer: &wgpu::Buffer,
        count: usize,
    ) -> GpuResult<Vec<f32>> {
        let buffer_slice = buffer.slice(..);

        // map_async reports completion via callback; forward the result
        // through a oneshot channel so it can be awaited below.
        let (sender, receiver) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });

        // Drive the device until the pending map operation completes;
        // without this poll the callback would never fire.
        device.poll(wgpu::Maintain::Wait);

        receiver
            .await
            .map_err(|_| GpuError::BufferRead("Channel closed".into()))?
            .map_err(|e| GpuError::BufferRead(e.to_string()))?;

        // Only the first `count` floats are read; the staging buffer may be
        // larger than the data copied into it this dispatch.
        let data = buffer_slice.get_mapped_range();
        let result: Vec<f32> =
            bytemuck::cast_slice(&data[..count * std::mem::size_of::<f32>()]).to_vec();

        // The mapped view must be dropped before unmap() is legal.
        drop(data);
        buffer.unmap();

        Ok(result)
    }
|
||||
|
||||
/// Get GPU capabilities
|
||||
pub fn capabilities(&self) -> &GpuCapabilities {
|
||||
&self.capabilities
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &GpuConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Check if GPU is available
|
||||
pub fn is_available(&self) -> bool {
|
||||
self.capabilities.supported
|
||||
}
|
||||
|
||||
/// Release all GPU resources
|
||||
pub fn release(&mut self) {
|
||||
self.buffer_manager.clear();
|
||||
self.graph_data = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// Synchronous wrapper for GPU coherence engine using pollster
|
||||
pub mod sync {
|
||||
use super::*;
|
||||
|
||||
/// Synchronously create a GPU engine
|
||||
pub fn create_engine(config: GpuConfig) -> GpuResult<GpuCoherenceEngine> {
|
||||
pollster::block_on(GpuCoherenceEngine::new(config))
|
||||
}
|
||||
|
||||
/// Try to create GPU engine synchronously
|
||||
pub fn try_create_engine(config: GpuConfig) -> Option<GpuCoherenceEngine> {
|
||||
pollster::block_on(GpuCoherenceEngine::try_new(config))
|
||||
}
|
||||
|
||||
/// Compute energy synchronously
|
||||
pub fn compute_energy(engine: &mut GpuCoherenceEngine) -> GpuResult<GpuCoherenceEnergy> {
|
||||
pollster::block_on(engine.compute_energy())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must enable CPU fallback and keep lane thresholds ordered.
    #[test]
    fn test_gpu_config_default() {
        let cfg = GpuConfig::default();
        assert!(cfg.enable_fallback);
        assert_eq!(cfg.beta, 1.0);
        assert!(cfg.threshold_lane0 < cfg.threshold_lane1);
        assert!(cfg.threshold_lane1 < cfg.threshold_lane2);
    }

    /// GpuParams must keep the 32-byte layout the shaders bind against.
    #[test]
    fn test_gpu_params_size() {
        assert_eq!(std::mem::size_of::<GpuParams>(), 32);
    }

    /// EnergyParams must keep the 32-byte layout the shaders bind against.
    #[test]
    fn test_energy_params_size() {
        assert_eq!(std::mem::size_of::<EnergyParams>(), 32);
    }
}
|
||||
228
vendor/ruvector/crates/prime-radiant/src/gpu/error.rs
vendored
Normal file
228
vendor/ruvector/crates/prime-radiant/src/gpu/error.rs
vendored
Normal file
@@ -0,0 +1,228 @@
|
||||
//! GPU Error Types
|
||||
//!
|
||||
//! Error handling for GPU operations including device initialization,
|
||||
//! buffer management, shader execution, and kernel dispatch.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type for GPU operations
pub type GpuResult<T> = Result<T, GpuError>;

/// Errors that can occur during GPU operations
///
/// Display strings come from the `thiserror` `#[error(...)]` attributes.
/// See [`GpuError::should_fallback`] / [`GpuError::is_recoverable`] for how
/// callers classify these variants.
#[derive(Debug, Error)]
pub enum GpuError {
    /// No suitable GPU adapter found
    #[error("No suitable GPU adapter found. Ensure a GPU with compute capabilities is available.")]
    NoAdapter,

    /// No compatible GPU device found
    #[error("No compatible GPU device found: {0}")]
    NoDevice(String),

    /// GPU device creation failed
    #[error("Failed to create GPU device: {0}")]
    DeviceCreation(String),

    /// Device request failed
    #[error("Failed to request GPU device: {0}")]
    DeviceRequestFailed(String),

    /// Shader compilation failed
    #[error("Shader compilation failed: {0}")]
    ShaderCompilation(String),

    /// Buffer allocation failed
    // NOTE(review): overlaps with `BufferAllocationFailed` below — consider
    // consolidating on the structured variant in a future breaking release.
    #[error("Buffer allocation failed: {0}")]
    BufferAllocation(String),

    /// Buffer allocation failed with details
    #[error("Buffer allocation failed: requested {requested_bytes} bytes, reason: {reason}")]
    BufferAllocationFailed {
        /// Number of bytes requested
        requested_bytes: u64,
        /// Reason for failure
        reason: String,
    },

    /// Buffer size exceeds maximum allowed
    #[error("Buffer size {size} exceeds maximum allowed {max}")]
    BufferTooLarge {
        /// Requested size
        size: u64,
        /// Maximum allowed size
        max: u64,
    },

    /// Buffer size mismatch
    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    /// Buffer read-back failed
    // NOTE(review): overlaps with `BufferRead` below; both are treated as
    // recoverable by `is_recoverable`.
    #[error("Buffer read-back failed: {0}")]
    BufferReadFailed(String),

    /// Buffer mapping failed
    #[error("Buffer mapping failed: {0}")]
    BufferMapFailed(String),

    /// Dimension mismatch
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// Invalid binding configuration
    #[error("Invalid binding configuration: expected {expected} bindings, got {actual}")]
    InvalidBindingCount {
        /// Expected number of bindings
        expected: usize,
        /// Actual number of bindings
        actual: usize,
    },

    /// Invalid workgroup configuration
    #[error("Invalid workgroup configuration: [{x}, {y}, {z}] exceeds device limits")]
    InvalidWorkgroupSize {
        /// X dimension
        x: u32,
        /// Y dimension
        y: u32,
        /// Z dimension
        z: u32,
    },

    /// Compute pipeline creation failed
    #[error("Failed to create compute pipeline: {0}")]
    PipelineCreation(String),

    /// Command encoding failed
    #[error("Command encoding failed: {0}")]
    CommandEncoding(String),

    /// GPU execution failed
    #[error("GPU execution failed: {0}")]
    ExecutionFailed(String),

    /// Buffer read failed
    #[error("Failed to read buffer: {0}")]
    BufferRead(String),

    /// Buffer write failed
    #[error("Failed to write buffer: {0}")]
    BufferWrite(String),

    /// Timeout waiting for GPU operation
    #[error("GPU operation timed out after {0}ms")]
    Timeout(u64),

    /// Graph has no edges
    #[error("Graph has no edges to compute")]
    EmptyGraph,

    /// Invalid configuration
    #[error("Invalid GPU configuration: {0}")]
    InvalidConfig(String),

    /// Feature not supported
    #[error("GPU feature not supported: {0}")]
    UnsupportedFeature(String),

    /// Adapter request failed
    #[error("Failed to request GPU adapter: {0}")]
    AdapterRequest(String),

    /// Out of GPU memory
    #[error("Out of GPU memory: requested {requested_bytes} bytes")]
    OutOfMemory {
        /// Number of bytes requested
        requested_bytes: u64,
    },

    /// Device lost
    #[error("GPU device lost: {0}")]
    DeviceLost(String),

    /// Internal error
    #[error("Internal GPU error: {0}")]
    Internal(String),
}
|
||||
|
||||
impl GpuError {
|
||||
/// Check if this error indicates GPU is unavailable and fallback should be used
|
||||
pub fn should_fallback(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
GpuError::NoAdapter
|
||||
| GpuError::NoDevice(_)
|
||||
| GpuError::DeviceCreation(_)
|
||||
| GpuError::DeviceRequestFailed(_)
|
||||
| GpuError::AdapterRequest(_)
|
||||
| GpuError::UnsupportedFeature(_)
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this error is recoverable
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
GpuError::Timeout(_)
|
||||
| GpuError::BufferRead(_)
|
||||
| GpuError::BufferReadFailed(_)
|
||||
| GpuError::ExecutionFailed(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<wgpu::RequestDeviceError> for GpuError {
|
||||
fn from(e: wgpu::RequestDeviceError) -> Self {
|
||||
Self::DeviceRequestFailed(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<wgpu::BufferAsyncError> for GpuError {
|
||||
fn from(e: wgpu::BufferAsyncError) -> Self {
|
||||
Self::BufferMapFailed(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Availability errors fall back; runtime errors do not.
    #[test]
    fn test_should_fallback() {
        assert!(GpuError::NoAdapter.should_fallback());
        assert!(GpuError::NoDevice("test".into()).should_fallback());
        assert!(GpuError::DeviceCreation("test".into()).should_fallback());
        assert!(!GpuError::Timeout(100).should_fallback());
        assert!(!GpuError::EmptyGraph.should_fallback());
    }

    /// Transient errors are recoverable; availability errors are not.
    #[test]
    fn test_is_recoverable() {
        assert!(GpuError::Timeout(100).is_recoverable());
        assert!(GpuError::BufferRead("test".into()).is_recoverable());
        assert!(GpuError::BufferReadFailed("test".into()).is_recoverable());
        assert!(!GpuError::NoDevice("test".into()).is_recoverable());
        assert!(!GpuError::NoAdapter.is_recoverable());
    }

    /// Structured allocation errors render both fields in their message.
    #[test]
    fn test_error_display() {
        let alloc_err = GpuError::BufferAllocationFailed {
            requested_bytes: 1024,
            reason: "out of memory".to_string(),
        };
        let rendered = alloc_err.to_string();
        assert!(rendered.contains("1024"));
        assert!(rendered.contains("out of memory"));
    }

    /// Workgroup-size errors render the offending dimension.
    #[test]
    fn test_workgroup_error() {
        let wg_err = GpuError::InvalidWorkgroupSize {
            x: 1000,
            y: 1,
            z: 1,
        };
        assert!(wg_err.to_string().contains("1000"));
    }
}
|
||||
760
vendor/ruvector/crates/prime-radiant/src/gpu/kernels.rs
vendored
Normal file
760
vendor/ruvector/crates/prime-radiant/src/gpu/kernels.rs
vendored
Normal file
@@ -0,0 +1,760 @@
|
||||
//! GPU Kernel Wrappers
|
||||
//!
|
||||
//! Provides Rust wrappers around WGSL compute shaders for coherence computation.
|
||||
//! Each kernel handles pipeline creation, bind group setup, and dispatch.
|
||||
|
||||
use super::buffer::{
|
||||
BufferUsage, GpuBuffer, GpuBufferManager, GpuEdge, GpuParams, GpuRestrictionMap,
|
||||
};
|
||||
use super::error::{GpuError, GpuResult};
|
||||
use super::shaders;
|
||||
use super::workgroup;
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use std::sync::Arc;
|
||||
use wgpu::{
|
||||
BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor,
|
||||
BindGroupLayoutEntry, BindingResource, BindingType, BufferBindingType, ComputePipeline,
|
||||
ComputePipelineDescriptor, Device, PipelineLayoutDescriptor, Queue, ShaderModule,
|
||||
ShaderModuleDescriptor, ShaderSource, ShaderStages,
|
||||
};
|
||||
|
||||
/// Compute residuals kernel
/// Computes r_e = rho_source(x_source) - rho_target(x_target) for all edges
pub struct ComputeResidualsKernel {
    // Compiled compute pipeline for the residuals shader ("main" entry point).
    pipeline: ComputePipeline,
    // Layout shared by every bind group created for this kernel.
    bind_group_layout: BindGroupLayout,
}
|
||||
|
||||
impl ComputeResidualsKernel {
|
||||
/// Create a new compute residuals kernel
|
||||
pub fn new(device: &Device) -> GpuResult<Self> {
|
||||
let shader = device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: Some("compute_residuals"),
|
||||
source: ShaderSource::Wgsl(shaders::COMPUTE_RESIDUALS.into()),
|
||||
});
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
|
||||
label: Some("compute_residuals_bind_group_layout"),
|
||||
entries: &[
|
||||
// Params uniform
|
||||
BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Node states
|
||||
BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Edges
|
||||
BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Restriction maps
|
||||
BindGroupLayoutEntry {
|
||||
binding: 3,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Restriction data
|
||||
BindGroupLayoutEntry {
|
||||
binding: 4,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Residuals output
|
||||
BindGroupLayoutEntry {
|
||||
binding: 5,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Residual norms output
|
||||
BindGroupLayoutEntry {
|
||||
binding: 6,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
|
||||
label: Some("compute_residuals_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
|
||||
label: Some("compute_residuals_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("main"),
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a bind group for execution
|
||||
pub fn create_bind_group(
|
||||
&self,
|
||||
device: &Device,
|
||||
params_buffer: &GpuBuffer,
|
||||
node_states_buffer: &GpuBuffer,
|
||||
edges_buffer: &GpuBuffer,
|
||||
restriction_maps_buffer: &GpuBuffer,
|
||||
restriction_data_buffer: &GpuBuffer,
|
||||
residuals_buffer: &GpuBuffer,
|
||||
residual_norms_buffer: &GpuBuffer,
|
||||
) -> BindGroup {
|
||||
device.create_bind_group(&BindGroupDescriptor {
|
||||
label: Some("compute_residuals_bind_group"),
|
||||
layout: &self.bind_group_layout,
|
||||
entries: &[
|
||||
BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: params_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: node_states_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: edges_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 3,
|
||||
resource: restriction_maps_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 4,
|
||||
resource: restriction_data_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 5,
|
||||
resource: residuals_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 6,
|
||||
resource: residual_norms_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a bind group using raw wgpu buffers (for pre-allocated buffer optimization)
|
||||
pub fn create_bind_group_raw(
|
||||
&self,
|
||||
device: &Device,
|
||||
params_buffer: &wgpu::Buffer,
|
||||
node_states_buffer: &wgpu::Buffer,
|
||||
edges_buffer: &wgpu::Buffer,
|
||||
restriction_maps_buffer: &wgpu::Buffer,
|
||||
restriction_data_buffer: &wgpu::Buffer,
|
||||
residuals_buffer: &wgpu::Buffer,
|
||||
residual_norms_buffer: &wgpu::Buffer,
|
||||
) -> BindGroup {
|
||||
device.create_bind_group(&BindGroupDescriptor {
|
||||
label: Some("compute_residuals_bind_group_raw"),
|
||||
layout: &self.bind_group_layout,
|
||||
entries: &[
|
||||
BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: params_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: node_states_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: edges_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 3,
|
||||
resource: restriction_maps_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 4,
|
||||
resource: restriction_data_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 5,
|
||||
resource: residuals_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 6,
|
||||
resource: residual_norms_buffer.as_entire_binding(),
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the pipeline for use in command encoder
|
||||
pub fn pipeline(&self) -> &ComputePipeline {
|
||||
&self.pipeline
|
||||
}
|
||||
|
||||
/// Calculate number of workgroups needed
|
||||
pub fn workgroup_count(num_edges: u32) -> u32 {
|
||||
// One thread per edge, 256 threads per workgroup
|
||||
(num_edges + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute energy kernel with parallel reduction
pub struct ComputeEnergyKernel {
    // First-pass reduction pipeline ("main" entry point): per-workgroup sums.
    main_pipeline: ComputePipeline,
    // Final reduction pipeline ("final_reduce" entry point): combines partials.
    final_pipeline: ComputePipeline,
    // Layout shared by both pipelines' bind groups.
    bind_group_layout: BindGroupLayout,
}
|
||||
|
||||
/// Parameters for energy reduction
///
/// `#[repr(C)]` + padding keeps this at 32 bytes so it can be written
/// directly into a uniform buffer (see the `test_energy_params_size` test).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct EnergyParams {
    /// Number of elements to reduce
    pub num_elements: u32,
    /// Padding
    pub _padding: [u32; 7],
}
|
||||
|
||||
impl ComputeEnergyKernel {
|
||||
/// Create a new compute energy kernel
|
||||
pub fn new(device: &Device) -> GpuResult<Self> {
|
||||
let shader = device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: Some("compute_energy"),
|
||||
source: ShaderSource::Wgsl(shaders::COMPUTE_ENERGY.into()),
|
||||
});
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
|
||||
label: Some("compute_energy_bind_group_layout"),
|
||||
entries: &[
|
||||
// Params uniform
|
||||
BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Input energies
|
||||
BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Output partial sums
|
||||
BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
|
||||
label: Some("compute_energy_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let main_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
|
||||
label: Some("compute_energy_main_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("main"),
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
let final_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
|
||||
label: Some("compute_energy_final_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("final_reduce"),
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
main_pipeline,
|
||||
final_pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a bind group for execution
|
||||
pub fn create_bind_group(
|
||||
&self,
|
||||
device: &Device,
|
||||
params_buffer: &GpuBuffer,
|
||||
input_buffer: &GpuBuffer,
|
||||
output_buffer: &GpuBuffer,
|
||||
) -> BindGroup {
|
||||
device.create_bind_group(&BindGroupDescriptor {
|
||||
label: Some("compute_energy_bind_group"),
|
||||
layout: &self.bind_group_layout,
|
||||
entries: &[
|
||||
BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: params_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: input_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: output_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a bind group using raw wgpu buffers (for pre-allocated buffer optimization)
|
||||
pub fn create_bind_group_raw(
|
||||
&self,
|
||||
device: &Device,
|
||||
params_buffer: &wgpu::Buffer,
|
||||
input_buffer: &wgpu::Buffer,
|
||||
output_buffer: &wgpu::Buffer,
|
||||
) -> BindGroup {
|
||||
device.create_bind_group(&BindGroupDescriptor {
|
||||
label: Some("compute_energy_bind_group_raw"),
|
||||
layout: &self.bind_group_layout,
|
||||
entries: &[
|
||||
BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: params_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: input_buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: output_buffer.as_entire_binding(),
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the main reduction pipeline
|
||||
pub fn main_pipeline(&self) -> &ComputePipeline {
|
||||
&self.main_pipeline
|
||||
}
|
||||
|
||||
/// Get the final reduction pipeline
|
||||
pub fn final_pipeline(&self) -> &ComputePipeline {
|
||||
&self.final_pipeline
|
||||
}
|
||||
|
||||
/// Calculate number of workgroups for first pass
|
||||
pub fn workgroup_count(num_elements: u32) -> u32 {
|
||||
// One element per thread, 256 threads per workgroup
|
||||
(num_elements + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
|
||||
}
|
||||
}
|
||||
|
||||
/// Sheaf attention kernel
pub struct SheafAttentionKernel {
    // Pipeline for the "compute_attention_single_pass" shader entry point.
    single_pass_pipeline: ComputePipeline,
    // Layout shared by every bind group created for this kernel.
    bind_group_layout: BindGroupLayout,
}
|
||||
|
||||
/// Attention weight output
///
/// One record per edge, written by the sheaf attention shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionWeight {
    /// Index of the edge this weight belongs to
    pub edge_idx: u32,
    /// Source node index of the edge
    pub source_idx: u32,
    /// Target node index of the edge
    pub target_idx: u32,
    // Presumably the unnormalized score before normalization into
    // `attention` — TODO confirm against the SHEAF_ATTENTION WGSL source.
    pub raw_score: f32,
    /// Attention value for the edge
    pub attention: f32,
    /// Padding for GPU struct alignment
    pub _padding: [u32; 3],
}
|
||||
|
||||
impl SheafAttentionKernel {
|
||||
/// Create a new sheaf attention kernel
|
||||
pub fn new(device: &Device) -> GpuResult<Self> {
|
||||
let shader = device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: Some("sheaf_attention"),
|
||||
source: ShaderSource::Wgsl(shaders::SHEAF_ATTENTION.into()),
|
||||
});
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
|
||||
label: Some("sheaf_attention_bind_group_layout"),
|
||||
entries: &[
|
||||
// Params
|
||||
BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Edges
|
||||
BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Edge energies
|
||||
BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Attention weights output
|
||||
BindGroupLayoutEntry {
|
||||
binding: 3,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Node exp sums (for normalization)
|
||||
BindGroupLayoutEntry {
|
||||
binding: 4,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
|
||||
label: Some("sheaf_attention_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let single_pass_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
|
||||
label: Some("sheaf_attention_single_pass_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("compute_attention_single_pass"),
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
single_pass_pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a bind group
|
||||
pub fn create_bind_group(
|
||||
&self,
|
||||
device: &Device,
|
||||
params_buffer: &GpuBuffer,
|
||||
edges_buffer: &GpuBuffer,
|
||||
edge_energies_buffer: &GpuBuffer,
|
||||
attention_weights_buffer: &GpuBuffer,
|
||||
node_exp_sums_buffer: &GpuBuffer,
|
||||
) -> BindGroup {
|
||||
device.create_bind_group(&BindGroupDescriptor {
|
||||
label: Some("sheaf_attention_bind_group"),
|
||||
layout: &self.bind_group_layout,
|
||||
entries: &[
|
||||
BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: params_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: edges_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: edge_energies_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 3,
|
||||
resource: attention_weights_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
BindGroupEntry {
|
||||
binding: 4,
|
||||
resource: node_exp_sums_buffer.buffer.as_entire_binding(),
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the single-pass pipeline
|
||||
pub fn pipeline(&self) -> &ComputePipeline {
|
||||
&self.single_pass_pipeline
|
||||
}
|
||||
|
||||
/// Calculate workgroup count
|
||||
pub fn workgroup_count(num_edges: u32) -> u32 {
|
||||
(num_edges + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
|
||||
}
|
||||
}
|
||||
|
||||
/// Token routing kernel
pub struct TokenRoutingKernel {
    // Pipeline for the routing shader (TOKEN_ROUTING WGSL source).
    route_pipeline: ComputePipeline,
    // Layout shared by every bind group created for this kernel.
    bind_group_layout: BindGroupLayout,
}
|
||||
|
||||
/// Token input
///
/// One routing request per token, uploaded as a storage-buffer element.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Token {
    /// Identifier echoed back in the matching [`RoutingDecision`]
    pub token_id: u32,
    /// Index of the graph node this token is associated with
    pub node_idx: u32,
    // NOTE(review): action codes are defined by the TOKEN_ROUTING shader —
    // confirm the enumeration against the WGSL source.
    pub action_type: u32,
    /// Scheduling priority of the token
    pub priority: f32,
}
|
||||
|
||||
/// Routing decision output
///
/// Per-token result written back by the routing shader; field semantics
/// beyond the names are defined in token_routing.wgsl.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RoutingDecision {
    /// Identifier of the token this decision applies to.
    pub token_id: u32,
    /// Lane index the token was assigned to.
    pub assigned_lane: u32,
    /// Local coherence energy observed for the token's node.
    pub local_energy: f32,
    /// Confidence score for the lane assignment.
    pub confidence: f32,
    /// Reason code when the token is escalated (shader-defined encoding).
    pub escalation_reason: u32,
    /// Count of high-energy incident edges — presumably relative to a
    /// shader-side threshold; see token_routing.wgsl.
    pub num_high_energy_edges: u32,
    /// Largest energy among the node's incident edges.
    pub max_edge_energy: f32,
    /// Padding so the struct size is a multiple of 16 bytes (32 total).
    pub _padding: u32,
}
|
||||
|
||||
/// Lane statistics
///
/// Aggregate per-lane counters written by the routing shader.
/// Layout is 64 bytes: 16 (counts) + 16 (energies) + 32 (padding).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LaneStats {
    /// Number of tokens assigned to each of the 4 lanes.
    pub lane_counts: [u32; 4],
    /// Total accumulated energy per lane.
    pub total_energy_per_lane: [f32; 4],
    /// Padding for GPU-side buffer alignment.
    pub _padding: [u32; 8],
}
|
||||
|
||||
impl TokenRoutingKernel {
|
||||
/// Create a new token routing kernel
|
||||
pub fn new(device: &Device) -> GpuResult<Self> {
|
||||
let shader = device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: Some("token_routing"),
|
||||
source: ShaderSource::Wgsl(shaders::TOKEN_ROUTING.into()),
|
||||
});
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
|
||||
label: Some("token_routing_bind_group_layout"),
|
||||
entries: &[
|
||||
// Params
|
||||
BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Tokens
|
||||
BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Local energies
|
||||
BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Edge energies
|
||||
BindGroupLayoutEntry {
|
||||
binding: 3,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Node edge counts
|
||||
BindGroupLayoutEntry {
|
||||
binding: 4,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Node edge offsets
|
||||
BindGroupLayoutEntry {
|
||||
binding: 5,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Node edges
|
||||
BindGroupLayoutEntry {
|
||||
binding: 6,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Routing decisions output
|
||||
BindGroupLayoutEntry {
|
||||
binding: 7,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
// Lane stats output
|
||||
BindGroupLayoutEntry {
|
||||
binding: 8,
|
||||
visibility: ShaderStages::COMPUTE,
|
||||
ty: BindingType::Buffer {
|
||||
ty: BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
|
||||
label: Some("token_routing_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let route_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
|
||||
label: Some("token_routing_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: Some("route_tokens"),
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
route_pipeline,
|
||||
bind_group_layout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the routing pipeline
|
||||
pub fn pipeline(&self) -> &ComputePipeline {
|
||||
&self.route_pipeline
|
||||
}
|
||||
|
||||
/// Get bind group layout
|
||||
pub fn bind_group_layout(&self) -> &BindGroupLayout {
|
||||
&self.bind_group_layout
|
||||
}
|
||||
|
||||
/// Calculate workgroup count
|
||||
pub fn workgroup_count(num_tokens: u32) -> u32 {
|
||||
(num_tokens + workgroup::SIZE_1D - 1) / workgroup::SIZE_1D
|
||||
}
|
||||
}
|
||||
156
vendor/ruvector/crates/prime-radiant/src/gpu/mod.rs
vendored
Normal file
156
vendor/ruvector/crates/prime-radiant/src/gpu/mod.rs
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
//! GPU acceleration module for Prime-Radiant coherence engine.
|
||||
//!
|
||||
//! This module provides GPU-accelerated computation using wgpu for:
|
||||
//! - Parallel residual calculations across large graphs
|
||||
//! - Matrix operations for restriction maps
|
||||
//! - Energy aggregation with atomic operations
|
||||
//! - Spectral analysis via power iteration
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! +------------------+ +------------------+ +------------------+
|
||||
//! | GpuDevice |---->| GpuBuffer |---->| GpuDispatcher |
|
||||
//! | (Init/Queue) | | (Alloc/Transfer)| | (Kernels/Sync) |
|
||||
//! +------------------+ +------------------+ +------------------+
|
||||
//! | | |
|
||||
//! v v v
|
||||
//! +------------------+ +------------------+ +------------------+
|
||||
//! | Instance/Adapter | | BufferPool | | PipelineCache |
|
||||
//! | Device/Queue | | Read/Write | | BindGroups |
|
||||
//! +------------------+ +------------------+ +------------------+
|
||||
//! ```
|
||||
//!
|
||||
//! # Feature Flag
|
||||
//!
|
||||
//! This module requires the `gpu` feature flag:
|
||||
//! ```toml
|
||||
//! [dependencies]
|
||||
//! prime-radiant = { version = "0.1", features = ["gpu"] }
|
||||
//! ```
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use prime_radiant::gpu::{GpuDevice, GpuBuffer, GpuDispatcher, ComputePipeline};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
//! // Initialize GPU device
|
||||
//! let device = GpuDevice::new().await?;
|
||||
//!
|
||||
//! // Create storage buffer with data
|
||||
//! let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
|
||||
//! let input_buffer = GpuBuffer::new_storage(device.device(), &input_data, false);
|
||||
//!
|
||||
//! // Create output buffer
|
||||
//! let output_buffer = GpuBuffer::new_storage_uninit::<f32>(
|
||||
//! device.device(),
|
||||
//! input_data.len(),
|
||||
//! true,
|
||||
//! );
|
||||
//!
|
||||
//! // Create compute pipeline
|
||||
//! let pipeline = ComputePipeline::from_shader(
|
||||
//! device.device(),
|
||||
//! include_str!("shaders/compute_residuals.wgsl"),
|
||||
//! "main",
|
||||
//! &[BindingDesc::storage_readonly(), BindingDesc::storage_readwrite()],
|
||||
//! )?;
|
||||
//!
|
||||
//! // Create dispatcher and execute
|
||||
//! let dispatcher = GpuDispatcher::new(Arc::new(device));
|
||||
//! let bind_group = pipeline.create_bind_group(
|
||||
//! dispatcher.device().device(),
|
||||
//! &[&input_buffer, &output_buffer],
|
||||
//! )?;
|
||||
//! dispatcher.dispatch(&pipeline, &bind_group, [4, 1, 1]).await?;
|
||||
//!
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! # GPU Kernels
|
||||
//!
|
||||
//! The following WGSL compute shaders are implemented:
|
||||
//!
|
||||
//! 1. **compute_residuals.wgsl** - Parallel residual computation for all edges
|
||||
//! 2. **compute_energy.wgsl** - Parallel energy aggregation with tree reduction
|
||||
//! 3. **sheaf_attention.wgsl** - Batched attention: A_ij = exp(-beta * E_ij) / Z
|
||||
//! 4. **token_routing.wgsl** - Parallel lane assignment based on energy thresholds
|
||||
//!
|
||||
//! # Performance Targets
|
||||
//!
|
||||
//! | Operation | Target | Notes |
|
||||
//! |-----------|--------|-------|
|
||||
//! | Buffer allocation | < 1ms | Pooled for hot paths |
|
||||
//! | Kernel dispatch | < 100us | Excludes GPU execution |
|
||||
//! | Residual (10K edges) | < 1ms | GPU parallel |
|
||||
//! | Energy aggregation | < 500us | Atomic reduction |
|
||||
|
||||
mod buffer;
|
||||
mod device;
|
||||
mod dispatch;
|
||||
mod engine;
|
||||
mod error;
|
||||
mod kernels;
|
||||
mod pipeline;
|
||||
|
||||
// Core exports
|
||||
pub use buffer::{
|
||||
BufferKey, BufferUsage, BufferUsageFlags, GpuBuffer, GpuBufferManager, GpuBufferPool,
|
||||
};
|
||||
pub use device::{GpuDevice, GpuDeviceInfo, GpuDeviceOptions};
|
||||
pub use dispatch::{DispatchBuilder, DispatchConfig, GpuDispatcher};
|
||||
pub use error::{GpuError, GpuResult};
|
||||
pub use pipeline::{BindingDesc, BindingType, ComputePipeline, PipelineCache};
|
||||
|
||||
// Re-export buffer types
|
||||
pub use buffer::{GpuEdge, GpuNodeState, GpuParams, GpuRestrictionMap};
|
||||
|
||||
// Re-export engine types
|
||||
pub use engine::{GpuCapabilities, GpuCoherenceEnergy, GpuCoherenceEngine, GpuConfig};
|
||||
|
||||
/// Synchronous API for GPU coherence engine (uses pollster)
///
/// Re-exports the blocking wrappers from `engine::sync` so callers without
/// an async runtime can still drive the GPU engine.
pub mod sync {
    pub use super::engine::sync::*;
}
|
||||
|
||||
// Re-export kernel types
|
||||
pub use kernels::{
|
||||
AttentionWeight, ComputeEnergyKernel, ComputeResidualsKernel, EnergyParams, LaneStats,
|
||||
RoutingDecision, SheafAttentionKernel, Token, TokenRoutingKernel,
|
||||
};
|
||||
|
||||
/// Default workgroup size for compute shaders
///
/// Matches `workgroup::SIZE_1D` and the `@workgroup_size(256)` attributes
/// declared in the bundled WGSL shaders.
pub const DEFAULT_WORKGROUP_SIZE: u32 = 256;

/// Maximum buffer size for a single allocation (256MB)
pub const MAX_BUFFER_SIZE: u64 = 256 * 1024 * 1024;

/// Default pool capacity for buffer reuse
pub const DEFAULT_POOL_CAPACITY: usize = 32;
|
||||
|
||||
/// Shader source code embedded at compile time
///
/// `include_str!` resolves relative to this file, so the WGSL sources live
/// in the sibling `shaders/` directory and ship inside the binary.
pub mod shaders {
    /// Compute residuals shader for parallel edge residual computation
    pub const COMPUTE_RESIDUALS: &str = include_str!("shaders/compute_residuals.wgsl");
    /// Compute energy shader for parallel reduction
    pub const COMPUTE_ENERGY: &str = include_str!("shaders/compute_energy.wgsl");
    /// Sheaf attention shader for attention weight computation
    pub const SHEAF_ATTENTION: &str = include_str!("shaders/sheaf_attention.wgsl");
    /// Token routing shader for lane assignment
    pub const TOKEN_ROUTING: &str = include_str!("shaders/token_routing.wgsl");
}
|
||||
|
||||
/// GPU workgroup size constants
pub mod workgroup {
    /// Default workgroup size for 1D compute
    ///
    /// Must stay in sync with the `@workgroup_size(256)` attributes in the
    /// bundled WGSL shaders.
    pub const SIZE_1D: u32 = 256;
    /// Default workgroup size for 2D compute (x dimension)
    pub const SIZE_2D_X: u32 = 16;
    /// Default workgroup size for 2D compute (y dimension)
    pub const SIZE_2D_Y: u32 = 16;
    /// Maximum state vector dimension for GPU kernels
    pub const MAX_STATE_DIM: u32 = 512;
}
|
||||
513
vendor/ruvector/crates/prime-radiant/src/gpu/pipeline.rs
vendored
Normal file
513
vendor/ruvector/crates/prime-radiant/src/gpu/pipeline.rs
vendored
Normal file
@@ -0,0 +1,513 @@
|
||||
//! Compute pipeline management for GPU operations.
|
||||
//!
|
||||
//! This module handles shader compilation, pipeline creation, and bind group
|
||||
//! management for GPU compute operations.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info};
|
||||
use wgpu::{Device, ShaderModule};
|
||||
|
||||
use super::buffer::GpuBuffer;
|
||||
use super::error::{GpuError, GpuResult};
|
||||
use super::DEFAULT_WORKGROUP_SIZE;
|
||||
|
||||
/// Type of binding in a compute shader
///
/// The three buffer-binding flavours supported by the pipeline builder;
/// converted to the wgpu representation via `to_wgpu`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BindingType {
    /// Storage buffer (read-only)
    StorageReadonly,
    /// Storage buffer (read-write)
    StorageReadWrite,
    /// Uniform buffer
    Uniform,
}
|
||||
|
||||
impl BindingType {
|
||||
/// Convert to wgpu binding type
|
||||
fn to_wgpu(&self) -> wgpu::BindingType {
|
||||
match self {
|
||||
Self::StorageReadonly => wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
Self::StorageReadWrite => wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
Self::Uniform => wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Uniform,
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Description of a binding in a compute shader
///
/// One entry per `@binding` slot; slots are assigned in declaration order
/// by the pipeline builder.
#[derive(Debug, Clone)]
pub struct BindingDesc {
    /// Binding type
    pub binding_type: BindingType,
    /// Optional label for debugging
    pub label: Option<String>,
}
|
||||
|
||||
impl BindingDesc {
|
||||
/// Create a storage read-only binding
|
||||
pub fn storage_readonly() -> Self {
|
||||
Self {
|
||||
binding_type: BindingType::StorageReadonly,
|
||||
label: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a storage read-write binding
|
||||
pub fn storage_readwrite() -> Self {
|
||||
Self {
|
||||
binding_type: BindingType::StorageReadWrite,
|
||||
label: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a uniform binding
|
||||
pub fn uniform() -> Self {
|
||||
Self {
|
||||
binding_type: BindingType::Uniform,
|
||||
label: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a label to the binding
|
||||
pub fn with_label(mut self, label: impl Into<String>) -> Self {
|
||||
self.label = Some(label.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute pipeline wrapper
///
/// Bundles a wgpu compute pipeline with the bind group layout and the
/// metadata needed to create bind groups and size dispatches.
pub struct ComputePipeline {
    /// Underlying wgpu pipeline object.
    pipeline: wgpu::ComputePipeline,
    /// Layout used when creating bind groups for this pipeline.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Workgroup dimensions assumed by the `calculate_workgroups*` helpers.
    workgroup_size: [u32; 3],
    /// Shader entry point name.
    entry_point: String,
    /// Number of buffers `create_bind_group` expects.
    binding_count: usize,
}
|
||||
|
||||
impl ComputePipeline {
|
||||
/// Create a new compute pipeline from shader source.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The wgpu device
|
||||
/// * `shader_source` - WGSL shader source code
|
||||
/// * `entry_point` - Entry point function name
|
||||
/// * `bindings` - Binding descriptions
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let pipeline = ComputePipeline::from_shader(
|
||||
/// &device,
|
||||
/// r#"
|
||||
/// @group(0) @binding(0) var<storage, read> input: array<f32>;
|
||||
/// @group(0) @binding(1) var<storage, read_write> output: array<f32>;
|
||||
///
|
||||
/// @compute @workgroup_size(256)
|
||||
/// fn main(@builtin(global_invocation_id) id: vec3<u32>) {
|
||||
/// output[id.x] = input[id.x] * 2.0;
|
||||
/// }
|
||||
/// "#,
|
||||
/// "main",
|
||||
/// &[BindingDesc::storage_readonly(), BindingDesc::storage_readwrite()],
|
||||
/// );
|
||||
/// ```
|
||||
pub fn from_shader(
|
||||
device: &Device,
|
||||
shader_source: &str,
|
||||
entry_point: &str,
|
||||
bindings: &[BindingDesc],
|
||||
) -> GpuResult<Self> {
|
||||
Self::from_shader_with_workgroup_size(
|
||||
device,
|
||||
shader_source,
|
||||
entry_point,
|
||||
bindings,
|
||||
[DEFAULT_WORKGROUP_SIZE, 1, 1],
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a pipeline with custom workgroup size.
|
||||
pub fn from_shader_with_workgroup_size(
|
||||
device: &Device,
|
||||
shader_source: &str,
|
||||
entry_point: &str,
|
||||
bindings: &[BindingDesc],
|
||||
workgroup_size: [u32; 3],
|
||||
) -> GpuResult<Self> {
|
||||
debug!(
|
||||
"Creating compute pipeline with entry point '{}' and {} bindings",
|
||||
entry_point,
|
||||
bindings.len()
|
||||
);
|
||||
|
||||
// Create shader module
|
||||
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some("compute_shader"),
|
||||
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
|
||||
});
|
||||
|
||||
Self::from_module(device, &shader, entry_point, bindings, workgroup_size)
|
||||
}
|
||||
|
||||
/// Create a pipeline from a pre-compiled shader module.
|
||||
pub fn from_module(
|
||||
device: &Device,
|
||||
shader: &ShaderModule,
|
||||
entry_point: &str,
|
||||
bindings: &[BindingDesc],
|
||||
workgroup_size: [u32; 3],
|
||||
) -> GpuResult<Self> {
|
||||
// Create bind group layout entries
|
||||
let layout_entries: Vec<wgpu::BindGroupLayoutEntry> = bindings
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, desc)| wgpu::BindGroupLayoutEntry {
|
||||
binding: i as u32,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: desc.binding_type.to_wgpu(),
|
||||
count: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create bind group layout
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some("compute_bind_group_layout"),
|
||||
entries: &layout_entries,
|
||||
});
|
||||
|
||||
// Create pipeline layout
|
||||
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some("compute_pipeline_layout"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
// Create compute pipeline
|
||||
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: Some("compute_pipeline"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: shader,
|
||||
entry_point: Some(entry_point),
|
||||
compilation_options: wgpu::PipelineCompilationOptions::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
pipeline,
|
||||
bind_group_layout,
|
||||
workgroup_size,
|
||||
entry_point: entry_point.to_string(),
|
||||
binding_count: bindings.len(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a bind group for this pipeline.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The wgpu device
|
||||
/// * `buffers` - Buffers to bind, in order
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of buffers doesn't match the pipeline's binding count.
|
||||
pub fn create_bind_group(
|
||||
&self,
|
||||
device: &Device,
|
||||
buffers: &[&GpuBuffer],
|
||||
) -> GpuResult<wgpu::BindGroup> {
|
||||
if buffers.len() != self.binding_count {
|
||||
return Err(GpuError::InvalidBindingCount {
|
||||
expected: self.binding_count,
|
||||
actual: buffers.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let entries: Vec<wgpu::BindGroupEntry> = buffers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, buffer)| buffer.binding(i as u32))
|
||||
.collect();
|
||||
|
||||
Ok(device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: Some("compute_bind_group"),
|
||||
layout: &self.bind_group_layout,
|
||||
entries: &entries,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get the underlying wgpu pipeline
|
||||
pub fn pipeline(&self) -> &wgpu::ComputePipeline {
|
||||
&self.pipeline
|
||||
}
|
||||
|
||||
/// Get the bind group layout
|
||||
pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
|
||||
&self.bind_group_layout
|
||||
}
|
||||
|
||||
/// Get the workgroup size
|
||||
pub fn workgroup_size(&self) -> [u32; 3] {
|
||||
self.workgroup_size
|
||||
}
|
||||
|
||||
/// Get the entry point name
|
||||
pub fn entry_point(&self) -> &str {
|
||||
&self.entry_point
|
||||
}
|
||||
|
||||
/// Get the number of bindings
|
||||
pub fn binding_count(&self) -> usize {
|
||||
self.binding_count
|
||||
}
|
||||
|
||||
/// Calculate workgroup count for a given data size.
|
||||
pub fn calculate_workgroups(&self, data_size: u32) -> [u32; 3] {
|
||||
let x = (data_size + self.workgroup_size[0] - 1) / self.workgroup_size[0];
|
||||
[x, 1, 1]
|
||||
}
|
||||
|
||||
/// Calculate workgroup count for 2D data.
|
||||
pub fn calculate_workgroups_2d(&self, width: u32, height: u32) -> [u32; 3] {
|
||||
let x = (width + self.workgroup_size[0] - 1) / self.workgroup_size[0];
|
||||
let y = (height + self.workgroup_size[1] - 1) / self.workgroup_size[1];
|
||||
[x, y, 1]
|
||||
}
|
||||
|
||||
/// Calculate workgroup count for 3D data.
|
||||
pub fn calculate_workgroups_3d(&self, width: u32, height: u32, depth: u32) -> [u32; 3] {
|
||||
let x = (width + self.workgroup_size[0] - 1) / self.workgroup_size[0];
|
||||
let y = (height + self.workgroup_size[1] - 1) / self.workgroup_size[1];
|
||||
let z = (depth + self.workgroup_size[2] - 1) / self.workgroup_size[2];
|
||||
[x, y, z]
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache for compute pipelines
///
/// Keyed by caller-supplied name; pipelines are stored as `Arc` so lookups
/// hand out cheap clones. Uses `DashMap` for concurrent access.
pub struct PipelineCache {
    /// Device used to compile pipelines on cache miss.
    device: Arc<Device>,
    /// Name -> compiled pipeline.
    pipelines: DashMap<String, Arc<ComputePipeline>>,
}
|
||||
|
||||
impl PipelineCache {
|
||||
/// Create a new pipeline cache
|
||||
pub fn new(device: Arc<Device>) -> Self {
|
||||
Self {
|
||||
device,
|
||||
pipelines: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create a pipeline.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - Unique name for the pipeline
|
||||
/// * `shader_source` - WGSL shader source
|
||||
/// * `entry_point` - Entry point function name
|
||||
/// * `bindings` - Binding descriptions
|
||||
pub fn get_or_create(
|
||||
&self,
|
||||
name: &str,
|
||||
shader_source: &str,
|
||||
entry_point: &str,
|
||||
bindings: &[BindingDesc],
|
||||
) -> GpuResult<Arc<ComputePipeline>> {
|
||||
if let Some(pipeline) = self.pipelines.get(name) {
|
||||
return Ok(Arc::clone(&pipeline));
|
||||
}
|
||||
|
||||
info!("Creating and caching pipeline: {}", name);
|
||||
|
||||
let pipeline =
|
||||
ComputePipeline::from_shader(&self.device, shader_source, entry_point, bindings)?;
|
||||
let pipeline = Arc::new(pipeline);
|
||||
|
||||
self.pipelines
|
||||
.insert(name.to_string(), Arc::clone(&pipeline));
|
||||
|
||||
Ok(pipeline)
|
||||
}
|
||||
|
||||
/// Get a cached pipeline by name.
|
||||
pub fn get(&self, name: &str) -> Option<Arc<ComputePipeline>> {
|
||||
self.pipelines.get(name).map(|p| Arc::clone(&p))
|
||||
}
|
||||
|
||||
/// Check if a pipeline exists in cache.
|
||||
pub fn contains(&self, name: &str) -> bool {
|
||||
self.pipelines.contains_key(name)
|
||||
}
|
||||
|
||||
/// Remove a pipeline from cache.
|
||||
pub fn remove(&self, name: &str) -> Option<Arc<ComputePipeline>> {
|
||||
self.pipelines.remove(name).map(|(_, p)| p)
|
||||
}
|
||||
|
||||
/// Clear all cached pipelines.
|
||||
pub fn clear(&self) {
|
||||
self.pipelines.clear();
|
||||
}
|
||||
|
||||
/// Get the number of cached pipelines.
|
||||
pub fn len(&self) -> usize {
|
||||
self.pipelines.len()
|
||||
}
|
||||
|
||||
/// Check if the cache is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.pipelines.is_empty()
|
||||
}
|
||||
|
||||
/// List all cached pipeline names.
|
||||
pub fn names(&self) -> Vec<String> {
|
||||
self.pipelines.iter().map(|e| e.key().clone()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-defined shaders for common coherence operations
pub mod shaders {
    /// WGSL shader for computing residuals
    ///
    /// One thread per edge: reads both endpoint state vectors and writes the
    /// weighted squared difference into `residuals[edge_idx]`.
    /// NOTE(review): binding 2 (`restriction`) is declared but never read by
    /// the kernel body — the restriction map is effectively the identity.
    pub const RESIDUAL_COMPUTE: &str = r#"
        // Node states: [node_count, dim]
        @group(0) @binding(0) var<storage, read> node_states: array<f32>;
        // Edge info: [edge_count, 4] - source_idx, target_idx, weight, padding
        @group(0) @binding(1) var<storage, read> edges: array<vec4<f32>>;
        // Restriction map (identity for simplicity): [dim, dim]
        @group(0) @binding(2) var<storage, read> restriction: array<f32>;
        // Output residuals: [edge_count]
        @group(0) @binding(3) var<storage, read_write> residuals: array<f32>;
        // Params: [dim, node_count, edge_count, 0]
        @group(0) @binding(4) var<uniform> params: vec4<u32>;

        @compute @workgroup_size(256)
        fn main(@builtin(global_invocation_id) id: vec3<u32>) {
            let edge_idx = id.x;
            let edge_count = params.z;
            let dim = params.x;

            if (edge_idx >= edge_count) {
                return;
            }

            let edge = edges[edge_idx];
            let source_idx = u32(edge.x);
            let target_idx = u32(edge.y);
            let weight = edge.z;

            // Compute residual = ||rho_u(x_u) - rho_v(x_v)||^2
            var residual: f32 = 0.0;
            for (var d: u32 = 0u; d < dim; d = d + 1u) {
                let source_val = node_states[source_idx * dim + d];
                let target_val = node_states[target_idx * dim + d];
                let diff = source_val - target_val;
                residual = residual + diff * diff;
            }

            residuals[edge_idx] = weight * residual;
        }
    "#;

    /// WGSL shader for parallel reduction (sum)
    ///
    /// Classic shared-memory tree reduction: each workgroup writes one
    /// partial sum to `output[workgroup_id.x]`. The inner loop starts at
    /// 128 because the workgroup size is fixed at 256.
    pub const REDUCE_SUM: &str = r#"
        @group(0) @binding(0) var<storage, read> input: array<f32>;
        @group(0) @binding(1) var<storage, read_write> output: array<f32>;
        @group(0) @binding(2) var<uniform> count: u32;

        var<workgroup> shared_data: array<f32, 256>;

        @compute @workgroup_size(256)
        fn main(
            @builtin(global_invocation_id) global_id: vec3<u32>,
            @builtin(local_invocation_id) local_id: vec3<u32>,
            @builtin(workgroup_id) workgroup_id: vec3<u32>
        ) {
            let tid = local_id.x;
            let gid = global_id.x;

            // Load data into shared memory (0 past the end of the input)
            if (gid < count) {
                shared_data[tid] = input[gid];
            } else {
                shared_data[tid] = 0.0;
            }
            workgroupBarrier();

            // Parallel reduction
            for (var s: u32 = 128u; s > 0u; s = s >> 1u) {
                if (tid < s) {
                    shared_data[tid] = shared_data[tid] + shared_data[tid + s];
                }
                workgroupBarrier();
            }

            // Write result
            if (tid == 0u) {
                output[workgroup_id.x] = shared_data[0];
            }
        }
    "#;

    /// WGSL shader for matrix-vector multiplication
    ///
    /// One thread per output row; row-major matrix layout.
    pub const MATVEC: &str = r#"
        @group(0) @binding(0) var<storage, read> matrix: array<f32>;
        @group(0) @binding(1) var<storage, read> vector: array<f32>;
        @group(0) @binding(2) var<storage, read_write> result: array<f32>;
        // params: [rows, cols, 0, 0]
        @group(0) @binding(3) var<uniform> params: vec4<u32>;

        @compute @workgroup_size(256)
        fn main(@builtin(global_invocation_id) id: vec3<u32>) {
            let row = id.x;
            let rows = params.x;
            let cols = params.y;

            if (row >= rows) {
                return;
            }

            var sum: f32 = 0.0;
            for (var c: u32 = 0u; c < cols; c = c + 1u) {
                sum = sum + matrix[row * cols + c] * vector[c];
            }

            result[row] = sum;
        }
    "#;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Each factory produces the matching binding type.
    #[test]
    fn test_binding_desc() {
        assert_eq!(
            BindingDesc::storage_readonly().binding_type,
            BindingType::StorageReadonly
        );
        assert_eq!(
            BindingDesc::storage_readwrite().binding_type,
            BindingType::StorageReadWrite
        );
        assert_eq!(BindingDesc::uniform().binding_type, BindingType::Uniform);
    }

    /// `with_label` stores the label on the builder.
    #[test]
    fn test_binding_with_label() {
        let labelled = BindingDesc::storage_readonly().with_label("input_buffer");
        assert_eq!(labelled.label.as_deref(), Some("input_buffer"));
    }
}
|
||||
134
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/compute_energy.wgsl
vendored
Normal file
134
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/compute_energy.wgsl
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
// =============================================================================
// Prime-Radiant GPU Compute Shaders - Energy Computation
// =============================================================================
//
// Parallel reduction to compute total coherence energy:
//   E(S) = sum(w_e * |r_e|^2)
//
// Uses a two-phase reduction strategy:
//   1. Local reduction within workgroups using shared memory
//   2. Global reduction across workgroup partial sums

// =============================================================================
// TYPE DEFINITIONS
// =============================================================================

// 32-byte uniform block; only num_elements is meaningful, the rest pads
// the struct to match the Rust-side EnergyParams layout.
struct EnergyParams {
    num_elements: u32,
    _padding0: u32,
    _padding1: u32,
    _padding2: u32,
    _padding3: u32,
    _padding4: u32,
    _padding5: u32,
    _padding6: u32,
}

const WORKGROUP_SIZE: u32 = 256u;

// =============================================================================
// BUFFER BINDINGS
// =============================================================================
// Layout matches Rust kernel bind group:
//   binding 0: params (uniform)
//   binding 1: input (storage, read) - edge energies or partial sums
//   binding 2: output (storage, read_write) - partial sums or final result

/// Energy computation parameters
@group(0) @binding(0) var<uniform> params: EnergyParams;

/// Input values to reduce
@group(0) @binding(1) var<storage, read> input_values: array<f32>;

/// Output partial sums or final result
@group(0) @binding(2) var<storage, read_write> output_values: array<f32>;

// =============================================================================
// SHARED MEMORY
// =============================================================================

/// Shared memory for parallel reduction (one slot per thread in the group)
var<workgroup> shared_data: array<f32, 256>;

// =============================================================================
// MAIN REDUCTION KERNEL
// =============================================================================

/// Phase 1: Reduce input values within workgroup
/// Each workgroup writes one partial sum to output_values[workgroup_id.x].
@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let gid = global_id.x;
    let element_count = params.num_elements;

    // Load element (or 0 if out of bounds)
    var val: f32 = 0.0;
    if (gid < element_count) {
        val = input_values[gid];
    }

    // Store in shared memory
    shared_data[tid] = val;
    workgroupBarrier();

    // Tree reduction with sequential addressing
    for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            shared_data[tid] += shared_data[tid + stride];
        }
        workgroupBarrier();
    }

    // Thread 0 writes the partial sum
    if (tid == 0u) {
        output_values[workgroup_id.x] = shared_data[0];
    }
}

// =============================================================================
// FINAL REDUCTION PASS
// =============================================================================

/// Phase 2: Reduce partial sums to final total
/// Reads from input_values (the partial sums from phase 1)
/// Writes result to output_values[0]
/// NOTE(review): written to be dispatched as a single workgroup — the
/// strided while-loop below folds any partial sums beyond the first 256
/// into the local accumulator before the tree reduction; confirm the Rust
/// dispatcher launches exactly one workgroup for this entry point.
@compute @workgroup_size(256)
fn final_reduce(
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let tid = local_id.x;
    let element_count = params.num_elements;

    // Load partial sum from input (or 0 if out of bounds)
    var sum: f32 = 0.0;
    if (tid < element_count) {
        sum = input_values[tid];
    }

    // Handle case where we have more partial sums than workgroup size
    var idx = tid + WORKGROUP_SIZE;
    while (idx < element_count) {
        sum += input_values[idx];
        idx += WORKGROUP_SIZE;
    }

    shared_data[tid] = sum;
    workgroupBarrier();

    // Tree reduction
    for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            shared_data[tid] += shared_data[tid + stride];
        }
        workgroupBarrier();
    }

    // Write final result to output[0]
    if (tid == 0u) {
        output_values[0] = shared_data[0];
    }
}
|
||||
223
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/compute_residuals.wgsl
vendored
Normal file
223
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/compute_residuals.wgsl
vendored
Normal file
@@ -0,0 +1,223 @@
|
||||
// =============================================================================
|
||||
// Prime-Radiant GPU Compute Shaders - Residual Computation
|
||||
// =============================================================================
|
||||
//
|
||||
// Computes sheaf Laplacian residuals: r_e = rho_source(x_source) - rho_target(x_target)
|
||||
// and per-edge energy: E_e = w_e * ||r_e||^2
|
||||
//
|
||||
// Each thread processes one edge, computing the residual and squared norm.
|
||||
|
||||
// =============================================================================
|
||||
// TYPE DEFINITIONS (must match Rust structs exactly)
|
||||
// =============================================================================
|
||||
|
||||
/// Uniform shader parameters. Field order and padding must match the
/// Rust-side GpuParams struct byte-for-byte.
struct GpuParams {
    num_edges: u32,
    num_nodes: u32,
    state_dim: u32,       // floats per node in the flat node_states buffer
    beta: f32,
    threshold_lane0: f32,
    threshold_lane1: f32,
    threshold_lane2: f32,
    store_residuals: u32, // 0 = skip storage (energy only), 1 = store residuals
}

/// One graph edge. Mirrors the Rust GpuEdge struct exactly.
struct GpuEdge {
    source_idx: u32,     // index of the source node
    target_idx: u32,     // index of the target node
    weight: f32,         // edge weight w_e in E_e = w_e * ||r_e||^2
    rho_source_idx: u32, // index into restriction_maps for the source side
    rho_target_idx: u32, // index into restriction_maps for the target side
    comparison_dim: u32, // dimensionality of this edge's residual
    _padding0: u32,
    _padding1: u32,
}

/// Restriction map descriptor; numeric payload lives in restriction_data.
struct GpuRestrictionMap {
    map_type: u32, // 0=identity, 1=diagonal, 2=projection, 3=dense
    input_dim: u32,
    output_dim: u32,
    data_offset: u32, // start index into restriction_data
    data_len: u32,    // number of f32 payload elements
    _padding0: u32,
    _padding1: u32,
    _padding2: u32,
}

/// Threads per workgroup; must match the @workgroup_size attribute below.
const WORKGROUP_SIZE: u32 = 256u;
// Restriction-map type tags; must mirror the Rust-side discriminants.
const MAP_IDENTITY: u32 = 0u;
const MAP_DIAGONAL: u32 = 1u;
const MAP_PROJECTION: u32 = 2u;
const MAP_DENSE: u32 = 3u;

// =============================================================================
// BUFFER BINDINGS (matches Rust kernel bind group layout)
// =============================================================================
// binding 0: params (uniform)
// binding 1: node_states (storage, read)
// binding 2: edges (storage, read)
// binding 3: restriction_maps (storage, read)
// binding 4: restriction_data (storage, read)
// binding 5: residuals (storage, read_write)
// binding 6: energies (storage, read_write)

@group(0) @binding(0) var<uniform> params: GpuParams;
@group(0) @binding(1) var<storage, read> node_states: array<f32>;
@group(0) @binding(2) var<storage, read> edges: array<GpuEdge>;
@group(0) @binding(3) var<storage, read> restriction_maps: array<GpuRestrictionMap>;
@group(0) @binding(4) var<storage, read> restriction_data: array<f32>;
@group(0) @binding(5) var<storage, read_write> residuals: array<f32>;
@group(0) @binding(6) var<storage, read_write> energies: array<f32>;
|
||||
|
||||
// =============================================================================
|
||||
// HELPER FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
/// Apply restriction map `rho` to the node state stored at `state_base`
/// in node_states, returning the projected component at `output_dim`.
///
/// Every branch returns 0.0 for out-of-range reads so that edges with
/// mismatched dimensions degrade gracefully instead of reading out of
/// bounds.
fn apply_restriction(
    rho: GpuRestrictionMap,
    state_base: u32,
    output_dim: u32
) -> f32 {
    switch(rho.map_type) {
        case MAP_IDENTITY: {
            // Identity: pass the corresponding input component through.
            if (output_dim < rho.output_dim && output_dim < params.state_dim) {
                return node_states[state_base + output_dim];
            }
            return 0.0;
        }
        case MAP_DIAGONAL: {
            // Diagonal: scale component d by diagonal entry d.
            // Guard both the diagonal payload AND the state read; the
            // original only checked data_len, which left the
            // node_states access unguarded (inconsistent with the
            // identity case above).
            if (output_dim < rho.data_len && output_dim < params.state_dim) {
                let scale = restriction_data[rho.data_offset + output_dim];
                return node_states[state_base + output_dim] * scale;
            }
            return 0.0;
        }
        case MAP_PROJECTION: {
            // Projection: payload holds source indices stored as f32.
            if (output_dim < rho.data_len) {
                let idx = u32(restriction_data[rho.data_offset + output_dim]);
                if (idx < params.state_dim) {
                    return node_states[state_base + idx];
                }
            }
            return 0.0;
        }
        case MAP_DENSE, default: {
            // Dense: dot product of matrix row `output_dim` with the state.
            var result: f32 = 0.0;
            let row_offset = rho.data_offset + output_dim * rho.input_dim;
            for (var i = 0u; i < rho.input_dim && i < params.state_dim; i++) {
                result += restriction_data[row_offset + i] * node_states[state_base + i];
            }
            return result;
        }
    }
    // Unreachable (the switch has a default); kept to satisfy
    // all-paths-return analysis.
    return 0.0;
}
|
||||
|
||||
// =============================================================================
|
||||
// MAIN ENTRY POINT
|
||||
// =============================================================================
|
||||
|
||||
/// One thread per edge: compute the sheaf residual
///   r_e = rho_source(x_source) - rho_target(x_target),
/// optionally store it, and always store the weighted edge energy
///   E_e = w_e * ||r_e||^2.
@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let edge_idx = global_id.x;

    // Bounds check: the dispatch may round up to a full workgroup.
    if (edge_idx >= params.num_edges) {
        return;
    }

    // Get edge data
    let edge = edges[edge_idx];

    // Base offsets of the source/target state vectors in the flat
    // node_states buffer (params.state_dim floats per node).
    let source_base = edge.source_idx * params.state_dim;
    let target_base = edge.target_idx * params.state_dim;

    // Restriction maps for each endpoint.
    let rho_source = restriction_maps[edge.rho_source_idx];
    let rho_target = restriction_maps[edge.rho_target_idx];

    // Compute residual: r = rho_source(x_source) - rho_target(x_target)
    // and accumulate squared norm
    //
    // OPTIMIZATION: Process 4 dimensions at a time using vec4 operations.
    // This leverages GPU SIMD capabilities for ~4x throughput on high-dimensional
    // state vectors. The dot(v, v) operation is particularly efficient on GPU.
    var norm_sq: f32 = 0.0;
    let comparison_dim = edge.comparison_dim;
    // NOTE(review): this packing assumes every edge has the same
    // comparison_dim; if it varies per edge, edge_idx * comparison_dim is
    // not a valid offset into the residuals buffer — confirm against the
    // Rust-side buffer layout.
    let residual_base = edge_idx * comparison_dim;

    // Split into full vec4 chunks plus a 0-3 element scalar tail.
    let vec4_count = comparison_dim / 4u;
    let remainder = comparison_dim % 4u;

    // Process 4 dimensions at a time
    var d = 0u;
    for (var i = 0u; i < vec4_count; i++) {
        // Load 4 source values via restriction maps
        let source_vec = vec4<f32>(
            apply_restriction(rho_source, source_base, d),
            apply_restriction(rho_source, source_base, d + 1u),
            apply_restriction(rho_source, source_base, d + 2u),
            apply_restriction(rho_source, source_base, d + 3u)
        );

        // Load 4 target values via restriction maps
        let target_vec = vec4<f32>(
            apply_restriction(rho_target, target_base, d),
            apply_restriction(rho_target, target_base, d + 1u),
            apply_restriction(rho_target, target_base, d + 2u),
            apply_restriction(rho_target, target_base, d + 3u)
        );

        // Residual for these 4 dimensions at once.
        let r_vec = source_vec - target_vec;

        // ||r||^2 contribution in a single dot product.
        norm_sq += dot(r_vec, r_vec);

        // Optionally persist the residual (skipped for energy-only runs).
        if (params.store_residuals != 0u) {
            let base_offset = residual_base + d;
            // Guard the whole 4-wide write at once.
            if (base_offset + 3u < arrayLength(&residuals)) {
                residuals[base_offset] = r_vec.x;
                residuals[base_offset + 1u] = r_vec.y;
                residuals[base_offset + 2u] = r_vec.z;
                residuals[base_offset + 3u] = r_vec.w;
            }
        }

        d += 4u;
    }

    // Handle remainder dimensions (0-3 elements)
    for (var j = 0u; j < remainder; j++) {
        let dim_idx = d + j;
        let projected_source = apply_restriction(rho_source, source_base, dim_idx);
        let projected_target = apply_restriction(rho_target, target_base, dim_idx);
        let r = projected_source - projected_target;

        norm_sq += r * r;

        if (params.store_residuals != 0u) {
            let offset = residual_base + dim_idx;
            if (offset < arrayLength(&residuals)) {
                residuals[offset] = r;
            }
        }
    }

    // Compute weighted energy: E_e = w_e * ||r_e||^2
    let energy = edge.weight * norm_sq;

    // Store per-edge energy
    energies[edge_idx] = energy;
}
|
||||
144
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/sheaf_attention.wgsl
vendored
Normal file
144
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/sheaf_attention.wgsl
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
// =============================================================================
|
||||
// Prime-Radiant GPU Compute Shaders - Sheaf Attention
|
||||
// =============================================================================
|
||||
//
|
||||
// Energy-based sheaf attention: A_ij = softmax(-beta * E_ij)
|
||||
//
|
||||
// Attention weights are computed from coherence energy:
|
||||
// - Low energy (coherent) edges get high attention
|
||||
// - High energy (incoherent) edges get low attention
|
||||
|
||||
// =============================================================================
|
||||
// TYPE DEFINITIONS
|
||||
// =============================================================================
|
||||
|
||||
/// Uniform attention parameters; layout must mirror the Rust-side struct.
struct AttentionParams {
    num_edges: u32,
    num_nodes: u32,
    beta: f32,             // temperature in A = exp(-beta * E)
    energy_threshold: f32, // sparse-mode cutoff: edges above are masked out
    use_sparse: u32,       // 1 = apply the energy-threshold mask
    _padding0: u32,
    _padding1: u32,
    _padding2: u32,
}

/// Per-edge descriptor: endpoints plus a weight.
struct EdgeDescriptor {
    source_idx: u32,
    target_idx: u32,
    weight: f32,
    _padding: u32,
}

const WORKGROUP_SIZE: u32 = 256u;
/// Most negative finite f32; used to mask edges out of the softmax.
const NEG_INF: f32 = -3.402823e+38;
/// Floor for softmax denominators to avoid division by zero.
const EPSILON: f32 = 1e-8;

// =============================================================================
// BUFFER BINDINGS
// =============================================================================
// Layout matches Rust kernel bind group:
// binding 0: params (uniform)
// binding 1: edges (storage, read)
// binding 2: edge_energies (storage, read)
// binding 3: attention_weights (storage, read_write)
// binding 4: node_exp_sums (storage, read_write)

/// Attention parameters
@group(0) @binding(0) var<uniform> params: AttentionParams;

/// Edge descriptors
@group(0) @binding(1) var<storage, read> edges: array<EdgeDescriptor>;

/// Edge energies from residual computation
@group(0) @binding(2) var<storage, read> edge_energies: array<f32>;

/// Output attention weights (one per edge); unnormalized after the first
/// pass, normalized after normalize_attention.
@group(0) @binding(3) var<storage, read_write> attention_weights: array<f32>;

/// Per-source-node exponential sums for normalization.
/// NOTE(review): nothing in this shader writes these (no f32 atomicAdd in
/// WGSL) — presumably filled on the CPU between the two passes; confirm
/// against the Rust dispatch code.
@group(0) @binding(4) var<storage, read_write> node_exp_sums: array<f32>;

// =============================================================================
// SHARED MEMORY
// =============================================================================

/// Shared memory for parallel reduction.
/// NOTE(review): not referenced by either entry point in this file —
/// appears reserved for a future in-shader reduction; confirm.
var<workgroup> shared_data: array<f32, 256>;
|
||||
|
||||
// =============================================================================
|
||||
// SINGLE-PASS ATTENTION COMPUTATION
|
||||
// =============================================================================
|
||||
|
||||
/// Compute unnormalized attention weights from edge energies:
///   A_e = exp(clamp(-beta * E_e, -80, 80))
/// Low energy (coherent) edges get high attention; high energy
/// (incoherent) edges get low attention. In sparse mode, edges whose
/// energy exceeds params.energy_threshold are masked to NEG_INF (so their
/// exp underflows to ~0).
///
/// Normalization (the per-source softmax denominator in node_exp_sums) is
/// deliberately NOT done here: WGSL has no f32 atomicAdd, so sums must be
/// accumulated elsewhere before normalize_attention runs. The original
/// also loaded edges[edge_idx] solely for a commented-out atomicAdd; that
/// dead read has been removed.
@compute @workgroup_size(256)
fn compute_attention_single_pass(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let edge_idx = global_id.x;
    if (edge_idx >= params.num_edges) {
        return;
    }

    // Energy-based score for this edge.
    let energy = edge_energies[edge_idx];
    var score = -params.beta * energy;

    // Sparse mode: drop incoherent edges from the softmax entirely.
    if (params.use_sparse == 1u && energy > params.energy_threshold) {
        score = NEG_INF;
    }

    // Clamp before exp to avoid inf/NaN; exp(±80) stays finite in f32.
    let clamped_score = clamp(score, -80.0, 80.0);

    // Store the unnormalized weight; normalization is a separate pass.
    attention_weights[edge_idx] = exp(clamped_score);
}
|
||||
|
||||
// =============================================================================
|
||||
// NORMALIZATION PASS
|
||||
// =============================================================================
|
||||
|
||||
/// Second pass: divide each edge's unnormalized weight by its source
/// node's exponential sum, so every node's outgoing attention sums to 1.
/// Runs after node_exp_sums has been populated.
@compute @workgroup_size(256)
fn normalize_attention(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let e = global_id.x;
    if (e >= params.num_edges) {
        return;
    }

    // Softmax denominator for this edge's source node, floored at
    // EPSILON to avoid division by zero when a node has no surviving
    // edges.
    let denom = max(node_exp_sums[edges[e].source_idx], EPSILON);
    attention_weights[e] = attention_weights[e] / denom;
}
|
||||
471
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/sparse_mask.wgsl
vendored
Normal file
471
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/sparse_mask.wgsl
vendored
Normal file
@@ -0,0 +1,471 @@
|
||||
// =============================================================================
|
||||
// Prime-Radiant GPU Compute Shaders - Sparse Attention Mask
|
||||
// =============================================================================
|
||||
//
|
||||
// Generate sparse attention masks from energy thresholds.
|
||||
// Only edges with energy below threshold (coherent) are included.
|
||||
//
|
||||
// This enables efficient sparse attention where only meaningful
|
||||
// (low-energy, coherent) connections are computed, dramatically
|
||||
// reducing computation for large graphs.
|
||||
//
|
||||
// Output Formats:
|
||||
// 1. Index list: Compact list of (row, col) pairs for valid edges
|
||||
// 2. Dense mask: Full NxN boolean matrix (for small N)
|
||||
// 3. CSR format: Compressed sparse row for efficient sparse matmul
|
||||
//
|
||||
// Optimizations:
|
||||
// - Stream compaction for index list generation
|
||||
// - Warp-level voting for efficient counting
|
||||
// - Coalesced writes using shared memory staging
|
||||
|
||||
// =============================================================================
|
||||
// TYPE DEFINITIONS
|
||||
// =============================================================================
|
||||
|
||||
/// Uniform mask parameters; layout must match the Rust-side struct.
struct SparseMaskParams {
    total_edges: u32,
    coherence_threshold: f32, // edges with energy strictly below this are kept
    max_edges: u32,           // capacity of the sparse_indices output buffer
    output_format: u32, // 0=indices, 1=dense, 2=csr
    seq_len: u32,
    batch_size: u32,
    // NOTE(review): WGSL requires a 16-byte element stride for arrays in
    // the uniform address space, so array<u32, 2> here likely fails
    // validation and cannot match a Rust [u32; 2] tail — confirm; two
    // scalar padding fields would be the safe equivalent.
    padding: array<u32, 2>,
}

/// A (row, col) pair identifying one kept edge.
struct EdgeIndex {
    row: u32,
    col: u32,
}

/// CSR bookkeeping pair.
/// NOTE(review): not referenced by any entry point in this file.
struct CSRPointers {
    row_ptr: u32,
    nnz: u32,
}

const WORKGROUP_SIZE: u32 = 256u;
// Discriminants for SparseMaskParams.output_format.
const OUTPUT_INDICES: u32 = 0u;
const OUTPUT_DENSE: u32 = 1u;
const OUTPUT_CSR: u32 = 2u;
|
||||
|
||||
// =============================================================================
|
||||
// BUFFER BINDINGS
|
||||
// =============================================================================
|
||||
|
||||
/// Input edge energies (seq_len * seq_len per batch, or sparse)
@group(0) @binding(0) var<storage, read> edge_energies: array<f32>;

/// Output: sparse edge indices (for index format)
@group(0) @binding(1) var<storage, read_write> sparse_indices: array<EdgeIndex>;

/// Output: dense mask, 32 edges packed per word.
/// Declared as array<atomic<u32>> because generate_dense_mask and
/// combine_with_causal_mask set bits with atomicOr — WGSL rejects
/// atomicOr on a plain array<u32>, so the original declaration did not
/// compile against those kernels.
@group(0) @binding(2) var<storage, read_write> dense_mask: array<atomic<u32>>;

/// Output: number of valid edges (atomic counter)
@group(0) @binding(3) var<storage, read_write> edge_count: atomic<u32>;

/// Mask parameters
@group(0) @binding(4) var<uniform> params: SparseMaskParams;

// =============================================================================
// SHARED MEMORY
// =============================================================================

/// Per-lane validity flags for stream compaction
var<workgroup> shared_valid: array<u32, 256>;

/// Prefix sum for compaction offsets
var<workgroup> shared_prefix: array<u32, 256>;

/// Staging buffer for coalesced writes
var<workgroup> shared_indices: array<EdgeIndex, 256>;

/// Workgroup-level count of valid edges / broadcast slot for the
/// workgroup's base output offset
var<workgroup> workgroup_count: atomic<u32>;
|
||||
|
||||
// =============================================================================
|
||||
// BASIC SPARSE MASK GENERATION
|
||||
// =============================================================================
|
||||
|
||||
/// Generate the sparse mask as a compacted (row, col) index list.
/// Each lane classifies one edge; a workgroup-wide prefix sum assigns
/// compacted slots, and one global atomicAdd reserves the workgroup's
/// output region. Output ordering across workgroups is therefore
/// nondeterministic.
@compute @workgroup_size(256)
fn generate_sparse_indices(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let idx = global_id.x;
    let tid = local_id.x;
    let total_edges = params.total_edges;
    let threshold = params.coherence_threshold;
    let seq_len = params.seq_len;

    // Reset this workgroup's shared broadcast slot before use.
    if (tid == 0u) {
        atomicStore(&workgroup_count, 0u);
    }
    workgroupBarrier();

    // Classify this lane's edge. Out-of-range lanes act as invalid so
    // they still participate in every barrier below (uniform control
    // flow — no early return in this kernel).
    var is_valid: u32 = 0u;
    var row: u32 = 0u;
    var col: u32 = 0u;

    if (idx < total_edges) {
        let energy = edge_energies[idx];
        is_valid = select(0u, 1u, energy < threshold);

        // Recover (row, col) from the flattened seq_len x seq_len index.
        row = idx / seq_len;
        col = idx % seq_len;
    }

    // NOTE(review): shared_valid is written but never read in this
    // kernel (the scan below uses shared_prefix) — looks like dead
    // state; confirm before removing.
    shared_valid[tid] = is_valid;
    workgroupBarrier();

    // Hillis-Steele inclusive prefix sum of the validity flags: each
    // valid lane learns its position within the workgroup's compacted
    // output.
    shared_prefix[tid] = is_valid;
    workgroupBarrier();

    for (var offset = 1u; offset < WORKGROUP_SIZE; offset <<= 1u) {
        var val: u32 = 0u;
        if (tid >= offset) {
            val = shared_prefix[tid - offset];
        }
        workgroupBarrier();
        shared_prefix[tid] += val;
        workgroupBarrier();
    }

    // Last lane's inclusive sum = number of valid edges in this workgroup.
    let total_valid = shared_prefix[WORKGROUP_SIZE - 1u];

    // Lane 0 reserves a contiguous output region via the global atomic
    // counter and broadcasts the base offset through workgroup_count.
    var global_offset: u32 = 0u;
    if (tid == 0u && total_valid > 0u) {
        global_offset = atomicAdd(&edge_count, total_valid);
        atomicStore(&workgroup_count, global_offset);
    }
    workgroupBarrier();
    global_offset = atomicLoad(&workgroup_count);

    // Scatter: exclusive prefix (inclusive sum of lane tid-1) is this
    // lane's slot within the workgroup's region.
    if (is_valid == 1u && idx < total_edges) {
        // Exclusive prefix sum gives position.
        // NOTE(review): select evaluates both arms, so shared_prefix[tid - 1u]
        // is computed with a wrapped u32 index when tid == 0; the result
        // is discarded, but some validators flag the out-of-bounds arm.
        let local_pos = select(0u, shared_prefix[tid - 1u], tid > 0u);
        let global_pos = global_offset + local_pos;

        // Drop writes beyond capacity rather than corrupting memory.
        if (global_pos < params.max_edges) {
            sparse_indices[global_pos] = EdgeIndex(row, col);
        }
    }
}
|
||||
|
||||
// =============================================================================
|
||||
// DENSE MASK GENERATION
|
||||
// =============================================================================
|
||||
|
||||
/// Generate a dense boolean mask, packed 32 edges per u32 word.
/// Bit idx%32 of word idx/32 is set iff edge idx is coherent
/// (energy below the threshold).
///
/// NOTE(review): bits are only ever set, never cleared, so the host must
/// zero dense_mask before dispatch. atomicOr also requires dense_mask to
/// be declared array<atomic<u32>>; a plain array<u32> binding rejects it.
@compute @workgroup_size(256)
fn generate_dense_mask(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let idx = global_id.x;
    let total_edges = params.total_edges;
    let threshold = params.coherence_threshold;

    if (idx >= total_edges) {
        return;
    }

    let energy = edge_energies[idx];
    let is_valid = energy < threshold;

    // Bit position within the packed mask.
    let word_idx = idx / 32u;
    let bit_idx = idx % 32u;

    if (is_valid) {
        // Atomic OR so threads sharing a word don't lose each other's bits.
        atomicOr(&dense_mask[word_idx], 1u << bit_idx);
    }
}
|
||||
|
||||
/// Test bit `idx` of a packed dense mask (32 edges per u32 word).
/// Returns true when the edge's bit is set.
fn is_edge_valid(dense_mask_ptr: ptr<storage, array<u32>, read>, idx: u32) -> bool {
    let word = (*dense_mask_ptr)[idx / 32u];
    let mask = 1u << (idx % 32u);
    return (word & mask) != 0u;
}
|
||||
|
||||
// =============================================================================
|
||||
// CSR FORMAT GENERATION
|
||||
// =============================================================================
|
||||
|
||||
/// CSR row pointers (row_ptr[seq_len] holds the total nnz — see
/// compute_row_pointers)
@group(1) @binding(0) var<storage, read_write> csr_row_ptr: array<u32>;

/// CSR column indices (one per kept edge)
@group(1) @binding(1) var<storage, read_write> csr_col_idx: array<u32>;

/// CSR values (attention weights or energies)
@group(1) @binding(2) var<storage, read_write> csr_values: array<f32>;

/// Per-row counters: edge histogram in phase 1, reset to each row's base
/// offset in phase 2, then reused as atomic write cursors in phase 3.
@group(1) @binding(3) var<storage, read_write> row_counts: array<atomic<u32>>;
|
||||
|
||||
/// CSR phase 1: build a per-row histogram of valid (below-threshold)
/// edges in row_counts via atomic increments.
@compute @workgroup_size(256)
fn count_edges_per_row(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let edge = global_id.x;
    if (edge >= params.total_edges) {
        return;
    }

    // Only coherent edges (energy below the threshold) count toward
    // their row's tally.
    if (edge_energies[edge] < params.coherence_threshold) {
        atomicAdd(&row_counts[edge / params.seq_len], 1u);
    }
}
|
||||
|
||||
/// CSR phase 2: turn the per-row counts into exclusive-prefix-sum row
/// pointers (row_ptr[i] = number of valid edges in rows 0..i-1), and
/// reset row_counts to each row's base offset for use as phase-3 write
/// cursors. The last row also writes row_ptr[seq_len] = total nnz.
///
/// NOTE(review): the scan runs entirely in workgroup shared memory, so
/// row pointers are only correct when seq_len <= WORKGROUP_SIZE (256).
/// Larger sequences need a multi-block scan — confirm against the
/// dispatch code.
@compute @workgroup_size(256)
fn compute_row_pointers(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let row = global_id.x;
    let tid = local_id.x;
    let seq_len = params.seq_len;

    // Do NOT early-return out-of-range lanes: workgroupBarrier must be
    // reached in uniform control flow, and the original `return` here
    // made the barriers below non-uniform (and left stale shared
    // memory). Out-of-range lanes instead contribute a zero count.
    let in_range = row < seq_len;

    var count: u32 = 0u;
    if (in_range) {
        count = atomicLoad(&row_counts[row]);
    }
    shared_prefix[tid] = count;
    workgroupBarrier();

    // Hillis-Steele inclusive prefix sum over the workgroup.
    for (var offset = 1u; offset < WORKGROUP_SIZE; offset <<= 1u) {
        var val: u32 = 0u;
        if (tid >= offset) {
            val = shared_prefix[tid - offset];
        }
        workgroupBarrier();
        shared_prefix[tid] += val;
        workgroupBarrier();
    }

    if (in_range) {
        // Exclusive sum = inclusive sum minus this row's own count.
        let inclusive_sum = shared_prefix[tid];
        let exclusive_sum = inclusive_sum - count;

        csr_row_ptr[row] = exclusive_sum;

        // Reuse the counter as the running write cursor for phase 3.
        atomicStore(&row_counts[row], exclusive_sum);

        // Last row records the terminating pointer (total nnz).
        if (row == seq_len - 1u) {
            csr_row_ptr[seq_len] = inclusive_sum;
        }
    }
}
|
||||
|
||||
/// CSR phase 3: scatter column indices and energy values into the CSR
/// arrays using the per-row atomic cursors prepared by phase 2.
/// Columns within a row may land in any order (the cursor increment is
/// an unordered atomic).
@compute @workgroup_size(256)
fn populate_csr_data(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let idx = global_id.x;
    if (idx >= params.total_edges) {
        return;
    }

    let energy = edge_energies[idx];
    if (energy >= params.coherence_threshold) {
        // Incoherent edge: not part of the sparse structure.
        return;
    }

    let seq_len = params.seq_len;
    let row = idx / seq_len;

    // Claim the next free slot in this row's segment.
    let slot = atomicAdd(&row_counts[row], 1u);
    csr_col_idx[slot] = idx % seq_len;
    csr_values[slot] = energy;
}
|
||||
|
||||
// =============================================================================
|
||||
// BATCHED SPARSE MASK
|
||||
// =============================================================================
|
||||
|
||||
/// Start offset into sparse_indices for each batch's output region
@group(2) @binding(0) var<storage, read> batch_offsets: array<u32>;

/// Per-batch running count of emitted edges (atomic write cursors)
@group(2) @binding(1) var<storage, read_write> batch_edge_counts: array<atomic<u32>>;
|
||||
|
||||
/// Generate sparse (row, col) index lists for several batches at once.
/// Dispatch geometry: x covers edges within a batch; workgroup z selects
/// the batch.
///
/// NOTE(review): the early return below exits lanes before the
/// workgroupBarrier calls. WGSL requires barriers in uniform control
/// flow, so this is invalid whenever edges_per_batch is not a multiple
/// of the workgroup size — confirm dispatch sizes or restructure with a
/// guard flag (as compute_row_pointers should) instead of a return.
@compute @workgroup_size(256)
fn generate_batched_sparse_mask(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let batch_idx = workgroup_id.z;
    let local_idx = global_id.x;
    let tid = local_id.x;

    let seq_len = params.seq_len;
    let edges_per_batch = seq_len * seq_len;
    let threshold = params.coherence_threshold;

    if (local_idx >= edges_per_batch) {
        return;
    }

    // Position of this edge's energy in the flat per-batch layout.
    let global_idx = batch_idx * edges_per_batch + local_idx;

    let energy = edge_energies[global_idx];
    let is_valid = select(0u, 1u, energy < threshold);

    // NOTE(review): shared_valid is written but never read here; the
    // scan below uses shared_prefix only.
    shared_valid[tid] = is_valid;
    workgroupBarrier();

    // Hillis-Steele inclusive prefix sum of validity flags.
    shared_prefix[tid] = is_valid;
    workgroupBarrier();

    for (var offset = 1u; offset < WORKGROUP_SIZE; offset <<= 1u) {
        var val: u32 = 0u;
        if (tid >= offset) {
            val = shared_prefix[tid - offset];
        }
        workgroupBarrier();
        shared_prefix[tid] += val;
        workgroupBarrier();
    }

    // Lane 0 reserves this workgroup's slice of the batch's output
    // region and broadcasts the base offset through workgroup_count.
    if (tid == 0u) {
        let total_valid = shared_prefix[WORKGROUP_SIZE - 1u];
        let offset = atomicAdd(&batch_edge_counts[batch_idx], total_valid);
        atomicStore(&workgroup_count, offset);
    }
    workgroupBarrier();

    let batch_offset = batch_offsets[batch_idx];
    let workgroup_offset = atomicLoad(&workgroup_count);

    // Scatter valid edges; exclusive prefix (lane tid-1's inclusive sum)
    // is this lane's slot within the workgroup's slice.
    if (is_valid == 1u) {
        // NOTE(review): select evaluates both arms, so the tid - 1u index
        // wraps when tid == 0; the result is discarded but some
        // validators flag it.
        let local_pos = select(0u, shared_prefix[tid - 1u], tid > 0u);
        let global_pos = batch_offset + workgroup_offset + local_pos;

        let row = local_idx / seq_len;
        let col = local_idx % seq_len;

        // Drop writes beyond capacity rather than corrupting memory.
        if (global_pos < params.max_edges) {
            sparse_indices[global_pos] = EdgeIndex(row, col);
        }
    }
}
|
||||
|
||||
// =============================================================================
|
||||
// DYNAMIC THRESHOLD ADJUSTMENT
|
||||
// =============================================================================
|
||||
|
||||
/// Statistics output for adaptive thresholding
@group(3) @binding(0) var<storage, read_write> mask_stats: array<f32>;

/// Count below-threshold edges per workgroup for adaptive thresholding.
///
/// NOTE(review): this kernel is incomplete as shipped — the reduction
/// result in shared_prefix[0] is never written anywhere (the thread-0
/// branch is empty, presumably because WGSL has no f32 atomicAdd to
/// accumulate across workgroups into mask_stats). It currently has no
/// observable effect; confirm the intended completion strategy.
@compute @workgroup_size(256)
fn compute_mask_statistics(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let idx = global_id.x;
    let tid = local_id.x;
    let total_edges = params.total_edges;
    let threshold = params.coherence_threshold;

    // Each lane scores its edge: 1 if kept (below threshold), else 0.
    var valid_count: u32 = 0u;

    if (idx < total_edges) {
        let energy = edge_energies[idx];
        valid_count = select(0u, 1u, energy < threshold);
    }

    shared_prefix[tid] = valid_count;
    workgroupBarrier();

    // Tree reduction: shared_prefix[0] ends up holding this workgroup's
    // total count of valid edges.
    for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            shared_prefix[tid] += shared_prefix[tid + stride];
        }
        workgroupBarrier();
    }

    // Thread 0 updates global statistics
    if (tid == 0u) {
        // Atomic add to global counter
        // mask_stats[0] = total valid edges
        // mask_stats[1] = sparsity ratio (computed after all workgroups)
    }
}
|
||||
|
||||
// =============================================================================
|
||||
// CAUSAL MASK COMBINATION
|
||||
// =============================================================================
|
||||
|
||||
/// Combine the energy-based sparse mask with a causal (lower-triangular)
/// mask: position (row, col) survives only if its energy is below the
/// threshold AND col <= row (attend only to the past).
/// Uses a 16x16 2D workgroup since the domain is a seq_len x seq_len grid.
@compute @workgroup_size(16, 16)
fn combine_with_causal_mask(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let row = global_id.y;
    let col = global_id.x;
    let seq_len = params.seq_len;
    let threshold = params.coherence_threshold;

    if (row >= seq_len || col >= seq_len) {
        return;
    }

    let idx = row * seq_len + col;
    let energy = edge_energies[idx];

    // Valid if: (1) below energy threshold AND (2) satisfies causal constraint
    let energy_valid = energy < threshold;
    let causal_valid = col <= row; // Can only attend to past

    let is_valid = energy_valid && causal_valid;

    // Set the position's bit in the packed mask (32 entries per word).
    // NOTE(review): atomicOr requires dense_mask to be declared as
    // array<atomic<u32>>; a plain array<u32> binding will not compile.
    // The host must also zero the mask before dispatch — bits are only
    // ever set here.
    let word_idx = idx / 32u;
    let bit_idx = idx % 32u;

    if (is_valid) {
        atomicOr(&dense_mask[word_idx], 1u << bit_idx);
    }
}
|
||||
253
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/token_routing.wgsl
vendored
Normal file
253
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/token_routing.wgsl
vendored
Normal file
@@ -0,0 +1,253 @@
|
||||
// =============================================================================
|
||||
// Prime-Radiant GPU Compute Shaders - Token Routing
|
||||
// =============================================================================
|
||||
//
|
||||
// Parallel lane assignment for tokens based on coherence energy thresholds.
|
||||
// Routes tokens to different processing lanes (experts) based on their
|
||||
// local coherence energy, enabling adaptive computation.
|
||||
//
|
||||
// Lane Semantics:
|
||||
// - Lane 0: Coherent (energy < tau_0) - Fast path, minimal processing
|
||||
// - Lane 1: Semi-coherent (tau_0 <= energy < tau_1) - Normal processing
|
||||
// - Lane 2: Incoherent (tau_1 <= energy < tau_2) - Enhanced processing
|
||||
// - Lane 3: Critical (energy >= tau_2) - Special handling required
|
||||
|
||||
// =============================================================================
|
||||
// TYPE DEFINITIONS
|
||||
// =============================================================================
|
||||
|
||||
/// Uniform parameters for the token-routing dispatch.
///
/// The three thresholds partition local energies into the four lanes described
/// in the file header (presumably threshold_0 <= threshold_1 <= threshold_2 —
/// TODO confirm the host enforces this ordering). The trailing padding keeps
/// the struct a multiple of 16 bytes to match the host-side layout.
struct RoutingParams {
    /// Number of entries in the `tokens` buffer
    num_tokens: u32,
    /// Number of nodes (bound for the node_* CSR arrays)
    num_nodes: u32,
    /// Lane 0 / lane 1 boundary
    threshold_0: f32,
    /// Lane 1 / lane 2 boundary
    threshold_1: f32,
    /// Lane 2 / lane 3 boundary
    threshold_2: f32,
    /// Per-edge threshold used when counting escalation-worthy edges
    high_energy_threshold: f32,
    /// Padding for 16-byte alignment
    _padding0: u32,
    _padding1: u32,
}
|
||||
|
||||
/// One input token to be routed.
struct Token {
    /// Caller-assigned identifier; echoed into the RoutingDecision
    token_id: u32,
    /// Index of the graph node this token is attached to (indexes the
    /// local_energies and node_* CSR buffers)
    node_idx: u32,
    /// Not read by route_tokens in this file — presumably consumed by the
    /// host or other kernels; TODO confirm semantics
    action_type: u32,
    /// Not read by route_tokens in this file; reserved for scheduling —
    /// TODO confirm
    priority: f32,
}
|
||||
|
||||
/// Per-token routing output (written at the same index as the input token).
struct RoutingDecision {
    /// Copied from the input Token
    token_id: u32,
    /// Assigned lane, 0..3 (see lane semantics in the file header)
    assigned_lane: u32,
    /// Node-local coherence energy that determined the lane
    local_energy: f32,
    /// Distance-from-threshold score in [0, 1]; higher = further from a boundary
    confidence: f32,
    /// 0 = none, 1 = multiple high-energy edges, 2 = single edge above threshold_2
    escalation_reason: u32,
    /// Count of incident edges with energy above high_energy_threshold
    num_high_energy_edges: u32,
    /// Maximum energy over the node's incident edges
    max_edge_energy: f32,
    /// Padding for 16-byte alignment
    _padding: u32,
}
|
||||
|
||||
/// Aggregate per-lane statistics written by the routing kernel.
///
/// NOTE(review): route_tokens currently publishes only workgroup 0's counts,
/// and total_energy_per_lane is never written (WGSL has no atomic f32 add) —
/// treat both as incomplete until a global accumulation pass exists.
struct LaneStats {
    /// Token count per lane (x = lane 0 .. w = lane 3)
    lane_counts: vec4<u32>,
    /// Reserved: summed energy per lane; not yet populated
    total_energy_per_lane: vec4<f32>,
    /// Padding to match the host-side struct layout
    _padding: array<u32, 8>,
}

/// Threads per workgroup; must match the @workgroup_size on route_tokens
const WORKGROUP_SIZE: u32 = 256u;

/// Number of routing lanes/experts
const NUM_LANES: u32 = 4u;
|
||||
|
||||
// =============================================================================
|
||||
// BUFFER BINDINGS
|
||||
// =============================================================================
|
||||
// Layout matches Rust kernel bind group:
|
||||
// binding 0: params (uniform)
|
||||
// binding 1: tokens (storage, read)
|
||||
// binding 2: local_energies (storage, read)
|
||||
// binding 3: edge_energies (storage, read)
|
||||
// binding 4: node_edge_counts (storage, read)
|
||||
// binding 5: node_edge_offsets (storage, read)
|
||||
// binding 6: node_edges (storage, read)
|
||||
// binding 7: routing_decisions (storage, read_write)
|
||||
// binding 8: lane_stats (storage, read_write)
|
||||
|
||||
/// Routing parameters (uniform; see RoutingParams)
@group(0) @binding(0) var<uniform> params: RoutingParams;

/// Input tokens, one per routed token
@group(0) @binding(1) var<storage, read> tokens: array<Token>;

/// Pre-computed local coherence energy per node (indexed by Token.node_idx)
@group(0) @binding(2) var<storage, read> local_energies: array<f32>;

/// All edge energies (indexed by the global edge ids stored in node_edges)
@group(0) @binding(3) var<storage, read> edge_energies: array<f32>;

/// Number of edges per node (CSR format)
@group(0) @binding(4) var<storage, read> node_edge_counts: array<u32>;

/// Edge start offsets per node (CSR format; pairs with node_edge_counts)
@group(0) @binding(5) var<storage, read> node_edge_offsets: array<u32>;

/// Edge indices per node (CSR format; values index edge_energies)
@group(0) @binding(6) var<storage, read> node_edges: array<u32>;

/// Output routing decisions, one per token (same index as `tokens`)
@group(0) @binding(7) var<storage, read_write> routing_decisions: array<RoutingDecision>;

/// Output lane statistics (single struct; currently written by workgroup 0 only)
@group(0) @binding(8) var<storage, read_write> lane_stats: LaneStats;

// =============================================================================
// SHARED MEMORY
// =============================================================================

/// Per-workgroup lane histogram, accumulated with workgroup atomics
var<workgroup> shared_lane_counts: array<atomic<u32>, 4>;

/// Per-workgroup lane energy sums.
/// NOTE(review): declared but never accumulated — WGSL has no atomic f32 add,
/// so populating these requires a separate reduction pass.
var<workgroup> shared_lane_energies: array<f32, 4>;
|
||||
|
||||
// =============================================================================
|
||||
// HELPER FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
/// Branchless lane selection.
///
/// The lane index equals the number of ascending thresholds that `energy`
/// has reached, yielding 0 (below t0) through 3 (at or above t2).
fn compute_lane_branchless(energy: f32, t0: f32, t1: f32, t2: f32) -> u32 {
    // bool -> u32 conversion gives 1u for true, 0u for false.
    let above_t0 = u32(energy >= t0);
    let above_t1 = u32(energy >= t1);
    let above_t2 = u32(energy >= t2);
    return above_t0 + above_t1 + above_t2;
}
|
||||
|
||||
/// Routing confidence: the margin between `energy` and the nearest lane
/// boundary, scaled and clamped to [0, 1].
///
/// A token far from every threshold routes with high confidence; one sitting
/// on a boundary routes with confidence near zero. Interior lanes (1 and 2)
/// have two boundaries, so the smaller of the two margins is used.
fn compute_confidence(energy: f32, lane: u32, t0: f32, t1: f32, t2: f32) -> f32 {
    var margin: f32;

    if (lane == 0u) {
        margin = t0 - energy;
    } else if (lane == 1u) {
        margin = min(energy - t0, t1 - energy);
    } else if (lane == 2u) {
        margin = min(energy - t1, t2 - energy);
    } else {
        // Lane 3 (and any unexpected value): only a lower boundary exists.
        margin = energy - t2;
    }

    // Scale so a margin of 0.1 already saturates to full confidence.
    return clamp(margin * 10.0, 0.0, 1.0);
}
|
||||
|
||||
// =============================================================================
|
||||
// MAIN ROUTING KERNEL
|
||||
// =============================================================================
|
||||
|
||||
/// Route tokens to processing lanes based on local coherence energy.
///
/// One invocation per token. The lane is chosen by comparing the token's
/// node-local energy against the three routing thresholds; the node's incident
/// edges (CSR layout) are then scanned for escalation conditions, and a
/// RoutingDecision is written at the token's index.
///
/// Correctness fix: `workgroupBarrier()` must be executed in uniform control
/// flow — every invocation of the workgroup must reach it. The previous
/// version returned early for out-of-range tokens (`token_idx >= num_tokens`)
/// BETWEEN the two barriers, so the partially-filled last workgroup hit the
/// trailing barrier non-uniformly (undefined behavior / WGSL validation
/// failure). The per-token work is now guarded by an `if` instead, so all
/// threads reach both barriers.
@compute @workgroup_size(256)
fn route_tokens(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let token_idx = global_id.x;
    let local_idx = local_id.x;
    let num_tokens = params.num_tokens;

    // Initialize shared counters (first thread only).
    if (local_idx == 0u) {
        for (var lane = 0u; lane < NUM_LANES; lane++) {
            atomicStore(&shared_lane_counts[lane], 0u);
            shared_lane_energies[lane] = 0.0;
        }
    }
    workgroupBarrier();

    // Guard (rather than early-return) so the barrier below stays uniform.
    if (token_idx < num_tokens) {
        let token = tokens[token_idx];
        let node_idx = token.node_idx;

        // Pre-computed coherence energy local to this token's node.
        let local_energy = local_energies[node_idx];

        // Lane 0..3 from the three ascending thresholds (branchless).
        let lane = compute_lane_branchless(
            local_energy,
            params.threshold_0,
            params.threshold_1,
            params.threshold_2
        );

        // Confidence = distance from the nearest lane boundary, in [0, 1].
        let confidence = compute_confidence(
            local_energy,
            lane,
            params.threshold_0,
            params.threshold_1,
            params.threshold_2
        );

        // Scan this node's incident edges (CSR layout) for escalation signals.
        let edge_count = node_edge_counts[node_idx];
        let edge_offset = node_edge_offsets[node_idx];

        var num_high_energy_edges: u32 = 0u;
        var max_edge_energy: f32 = 0.0;
        var escalation_reason: u32 = 0u;

        for (var i = 0u; i < edge_count; i++) {
            let edge_idx = node_edges[edge_offset + i];
            let edge_energy = edge_energies[edge_idx];

            if (edge_energy > params.high_energy_threshold) {
                num_high_energy_edges += 1u;
            }
            max_edge_energy = max(max_edge_energy, edge_energy);
        }

        // Escalate on several hot edges, or one edge above the critical threshold.
        if (num_high_energy_edges > 2u) {
            escalation_reason = 1u; // Multiple high-energy edges
        } else if (max_edge_energy > params.threshold_2) {
            escalation_reason = 2u; // Single very high energy edge
        }

        // Emit the routing decision for this token.
        var decision: RoutingDecision;
        decision.token_id = token.token_id;
        decision.assigned_lane = lane;
        decision.local_energy = local_energy;
        decision.confidence = confidence;
        decision.escalation_reason = escalation_reason;
        decision.num_high_energy_edges = num_high_energy_edges;
        decision.max_edge_energy = max_edge_energy;
        decision._padding = 0u;

        routing_decisions[token_idx] = decision;

        // Workgroup-local lane histogram.
        atomicAdd(&shared_lane_counts[lane], 1u);
        // Note: no atomic f32 add in WGSL — shared_lane_energies would need a
        // separate reduction pass.
    }

    workgroupBarrier();

    // First thread of workgroup 0 publishes its local histogram.
    // NOTE(review): this reflects only workgroup 0; a global atomic
    // accumulation (or a second pass) is needed for complete statistics.
    if (local_idx == 0u && workgroup_id.x == 0u) {
        lane_stats.lane_counts = vec4<u32>(
            atomicLoad(&shared_lane_counts[0]),
            atomicLoad(&shared_lane_counts[1]),
            atomicLoad(&shared_lane_counts[2]),
            atomicLoad(&shared_lane_counts[3])
        );
    }
}
|
||||
234
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/types.wgsl
vendored
Normal file
234
vendor/ruvector/crates/prime-radiant/src/gpu/shaders/types.wgsl
vendored
Normal file
@@ -0,0 +1,234 @@
|
||||
// =============================================================================
|
||||
// Prime-Radiant GPU Compute Shaders - Shared Types
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains shared struct definitions and constants used across
|
||||
// all compute shaders in the Prime-Radiant coherence engine.
|
||||
//
|
||||
// Memory Layout:
|
||||
// - All structs are aligned to 16 bytes for optimal GPU memory access
|
||||
// - vec4<f32> is used where possible for coalesced memory operations
|
||||
// - Padding fields ensure proper alignment
|
||||
|
||||
// =============================================================================
|
||||
// COMPUTE PARAMETERS
|
||||
// =============================================================================
|
||||
|
||||
/// Parameters for residual computation
struct ComputeParams {
    /// Total number of edges to process
    edge_count: u32,
    /// Dimension of state vectors
    state_dim: u32,
    /// Restriction map type: 0=identity, 1=diagonal, 2=dense, 3=projection, 4=sparse
    /// (see the RESTRICTION_* constants below)
    restriction_type: u32,
    /// Padding for 16-byte alignment
    padding: u32,
}

/// Parameters for parallel reduction operations
struct ReductionParams {
    /// Number of elements to reduce
    element_count: u32,
    /// Stride between elements (for strided access patterns)
    stride: u32,
    /// Whether this is the final reduction pass (0 = intermediate, nonzero = final)
    is_final_pass: u32,
    /// Output offset for multi-pass reductions
    output_offset: u32,
}

/// Parameters for attention computation
struct AttentionParams {
    /// Batch size (number of independent attention operations)
    batch_size: u32,
    /// Sequence length (number of tokens/nodes)
    seq_len: u32,
    /// Dimension per attention head
    head_dim: u32,
    /// Inverse temperature parameter: A_ij = softmax(-beta * E_ij)
    beta: f32,
    /// Number of attention heads (for multi-head attention)
    num_heads: u32,
    /// Whether to use causal masking (0 = off, nonzero = on)
    use_causal_mask: u32,
    /// Energy threshold for sparse attention (skip if E > threshold)
    energy_threshold: f32,
    /// Padding for 16-byte alignment
    padding: u32,
}

/// Parameters for token routing
///
/// NOTE(review): a struct also named `RoutingParams` — with different fields —
/// exists in token_routing.wgsl. The two definitions collide if these shader
/// sources are ever concatenated into one module; confirm the Rust loader
/// compiles each .wgsl file separately.
struct RoutingParams {
    /// Number of tokens to route
    token_count: u32,
    /// Number of lanes/experts
    num_lanes: u32,
    /// Whether to use load balancing (0 = off, nonzero = on)
    use_load_balance: u32,
    /// Top-k selection for MoE
    top_k: u32,
}

/// Parameters for sparse mask generation
struct SparseMaskParams {
    /// Total number of potential edges
    total_edges: u32,
    /// Energy threshold for coherence (keep edges below this)
    coherence_threshold: f32,
    /// Maximum edges to keep (for memory bounds)
    max_edges: u32,
    /// Output format: 0=indices, 1=dense mask
    output_format: u32,
}
|
||||
|
||||
// =============================================================================
|
||||
// EDGE AND NODE DATA STRUCTURES
|
||||
// =============================================================================
|
||||
|
||||
/// Edge descriptor for graph connectivity (16-byte aligned)
struct EdgeDescriptor {
    /// Index of source node
    source_idx: u32,
    /// Index of target node
    target_idx: u32,
    /// Offset into restriction data for this edge
    restriction_offset: u32,
    /// Weight for this edge
    weight: f32,
}

/// Node state with metadata (16-byte aligned).
/// State vectors live in a separate flat buffer; each node owns the slice
/// [state_offset, state_offset + state_dim).
struct NodeState {
    /// Offset into state buffer where this node's state begins
    state_offset: u32,
    /// Dimension of this node's state
    state_dim: u32,
    /// Scope ID for hierarchical energy aggregation
    scope_id: u32,
    /// Flags (bit 0: is_boundary, bit 1: is_fixed, etc.)
    flags: u32,
}

/// Per-edge energy result (16-byte aligned).
/// Per the field definitions, energy = weight * residual_norm_sq.
struct EdgeEnergy {
    /// Weighted energy: w_e * |r_e|^2
    energy: f32,
    /// Raw residual norm squared: |r_e|^2
    residual_norm_sq: f32,
    /// Edge weight that was applied
    weight: f32,
    /// Padding for alignment
    padding: f32,
}
|
||||
|
||||
// =============================================================================
|
||||
// ATTENTION STRUCTURES
|
||||
// =============================================================================
|
||||
|
||||
/// Attention score for a single edge (16-byte aligned)
struct AttentionScore {
    /// Source node index
    source: u32,
    /// Target node index
    target: u32,
    /// Attention weight (after softmax)
    weight: f32,
    /// Raw score (before softmax)
    raw_score: f32,
}

/// Lane assignment result for token routing (16-byte aligned).
/// NOTE(review): token_routing.wgsl emits `RoutingDecision` records instead of
/// this struct — confirm which output layout the host actually consumes.
struct LaneAssignment {
    /// Token index
    token_idx: u32,
    /// Assigned lane (0-3 typically)
    lane: u32,
    /// Confidence score for this assignment
    confidence: f32,
    /// Energy value that determined routing
    energy: f32,
}
|
||||
|
||||
// =============================================================================
|
||||
// CONSTANTS
|
||||
// =============================================================================
|
||||
|
||||
/// Workgroup size for 1D dispatches
const WORKGROUP_SIZE_1D: u32 = 256u;

/// Workgroup dimensions for 2D dispatches (attention)
const WORKGROUP_SIZE_2D_X: u32 = 16u;
const WORKGROUP_SIZE_2D_Y: u32 = 16u;

/// Maximum supported state dimension (for stack allocation)
const MAX_STATE_DIM: u32 = 512u;

/// Epsilon for numerical stability (e.g. denominator floor in safe_div)
const EPSILON: f32 = 1e-8;

/// Negative-infinity stand-in for softmax initialization
/// (most negative finite f32; WGSL has no -inf literal)
const NEG_INF: f32 = -3.402823e+38;

/// Restriction map type constants (must match ComputeParams.restriction_type)
const RESTRICTION_IDENTITY: u32 = 0u;
const RESTRICTION_DIAGONAL: u32 = 1u;
const RESTRICTION_DENSE: u32 = 2u;
const RESTRICTION_PROJECTION: u32 = 3u;
const RESTRICTION_SPARSE: u32 = 4u;

/// Lane thresholds for token routing (default values)
/// Lane 0: energy < 0.1 (coherent, fast path)
/// Lane 1: 0.1 <= energy < 0.5 (semi-coherent, normal path)
/// Lane 2: 0.5 <= energy < 1.0 (incoherent, slow path)
/// Lane 3: energy >= 1.0 (critical, special handling)
/// Note: the .w component (10.0) is not consulted by compute_lane, which only
/// reads .x/.y/.z.
const DEFAULT_LANE_THRESHOLDS: vec4<f32> = vec4<f32>(0.1, 0.5, 1.0, 10.0);
|
||||
|
||||
// =============================================================================
|
||||
// UTILITY FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
/// Squared L2 norm of a vec4: |v|^2 = dot(v, v).
/// No sqrt is taken — callers comparing magnitudes can use the squared value
/// directly and avoid the root.
fn norm_sq_vec4(v: vec4<f32>) -> f32 {
    return dot(v, v);
}
|
||||
|
||||
/// Division with the denominator floored at EPSILON, preventing divide-by-zero
/// and the resulting inf/NaN for tiny positive denominators.
fn safe_div(a: f32, b: f32) -> f32 {
    let denom = max(b, EPSILON);
    return a / denom;
}
|
||||
|
||||
/// Branchless Heaviside step: 1.0 when value >= threshold, otherwise 0.0.
fn step_branchless(threshold: f32, value: f32) -> f32 {
    // bool -> f32 conversion yields 1.0 for true, 0.0 for false.
    return f32(value >= threshold);
}
|
||||
|
||||
/// Lane index from energy: counts how many of the first three threshold
/// components (x, y, z) the energy has reached, yielding 0..3.
/// The .w component of `thresholds` is not consulted.
fn compute_lane(energy: f32, thresholds: vec4<f32>) -> u32 {
    // bool -> u32 conversion (1u/0u) replaces the step-function round trip.
    return u32(energy >= thresholds.x)
        + u32(energy >= thresholds.y)
        + u32(energy >= thresholds.z);
}
|
||||
|
||||
/// One step of the online (streaming) softmax recurrence.
///
/// Maintains a running maximum and a running sum of exp(x - max) so the
/// softmax denominator can be accumulated in a single pass without overflow.
/// Whenever a new value raises the maximum, the previously accumulated sum is
/// rescaled by exp(old_max - new_max) before the new term is added.
///
/// Returns vec2(new_max, new_sum).
fn online_softmax_update(
    old_max: f32,
    old_sum: f32,
    new_val: f32
) -> vec2<f32> {
    let new_max = max(old_max, new_val);
    // Rescale the old sum into the new reference frame; equals 1.0 when the
    // maximum is unchanged.
    let correction = exp(old_max - new_max);
    let new_sum = old_sum * correction + exp(new_val - new_max);
    return vec2<f32>(new_max, new_sum);
}
|
||||
|
||||
/// Approximate exponential for softmax hot paths.
/// Currently delegates to the native exp; kept as a seam so a cheaper
/// polynomial approximation can be swapped in without touching call sites.
fn fast_exp(x: f32) -> f32 {
    // Use native exp for now; can be replaced with polynomial approximation
    return exp(x);
}
|
||||
|
||||
/// Clamp `val` into [min_val, max_val].
/// Keeps the original operand nesting (lower bound applied last), so behavior
/// for a degenerate range (min_val > max_val) matches the original exactly.
fn clamp_f32(val: f32, min_val: f32, max_val: f32) -> f32 {
    let upper_limited = min(val, max_val);
    return max(min_val, upper_limited);
}
|
||||
Reference in New Issue
Block a user