Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
426
vendor/ruvector/crates/prime-radiant/src/gpu/dispatch.rs
vendored
Normal file
426
vendor/ruvector/crates/prime-radiant/src/gpu/dispatch.rs
vendored
Normal file
@@ -0,0 +1,426 @@
|
||||
//! Kernel dispatch and synchronization for GPU compute operations.
|
||||
//!
|
||||
//! This module provides the dispatcher for executing compute kernels on the GPU,
|
||||
//! including support for:
|
||||
//! - Single kernel dispatch
|
||||
//! - Indirect dispatch (workgroup count from GPU buffer)
|
||||
//! - Chained dispatch for fused kernels
|
||||
//! - Synchronization and timing
|
||||
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, trace};
|
||||
use wgpu::{CommandEncoder, Device, Queue};
|
||||
|
||||
use super::buffer::{GpuBuffer, GpuBufferPool};
|
||||
use super::device::GpuDevice;
|
||||
use super::error::{GpuError, GpuResult};
|
||||
use super::pipeline::{ComputePipeline, PipelineCache};
|
||||
|
||||
/// Configuration for a dispatch operation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DispatchConfig {
|
||||
/// Label for debugging
|
||||
pub label: Option<String>,
|
||||
/// Whether to wait for completion
|
||||
pub wait: bool,
|
||||
/// Timeout in milliseconds (0 = no timeout)
|
||||
pub timeout_ms: u64,
|
||||
}
|
||||
|
||||
impl Default for DispatchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
label: None,
|
||||
wait: false,
|
||||
timeout_ms: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DispatchConfig {
|
||||
/// Create a config that waits for completion
|
||||
pub fn wait() -> Self {
|
||||
Self {
|
||||
wait: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a config with a label
|
||||
pub fn with_label(label: impl Into<String>) -> Self {
|
||||
Self {
|
||||
label: Some(label.into()),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the timeout
|
||||
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
|
||||
self.timeout_ms = timeout_ms;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set wait flag
|
||||
pub fn with_wait(mut self, wait: bool) -> Self {
|
||||
self.wait = wait;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU dispatcher for executing compute kernels
|
||||
pub struct GpuDispatcher {
|
||||
device: Arc<GpuDevice>,
|
||||
pipeline_cache: PipelineCache,
|
||||
buffer_pool: GpuBufferPool,
|
||||
}
|
||||
|
||||
impl GpuDispatcher {
|
||||
/// Create a new dispatcher
|
||||
pub fn new(device: Arc<GpuDevice>) -> Self {
|
||||
let pipeline_cache = PipelineCache::new(device.device_arc());
|
||||
let buffer_pool = GpuBufferPool::new(device.device_arc());
|
||||
|
||||
Self {
|
||||
device,
|
||||
pipeline_cache,
|
||||
buffer_pool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the underlying GPU device
|
||||
pub fn device(&self) -> &GpuDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Get the pipeline cache
|
||||
pub fn pipeline_cache(&self) -> &PipelineCache {
|
||||
&self.pipeline_cache
|
||||
}
|
||||
|
||||
/// Get the buffer pool
|
||||
pub fn buffer_pool(&self) -> &GpuBufferPool {
|
||||
&self.buffer_pool
|
||||
}
|
||||
|
||||
/// Dispatch a compute kernel.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pipeline` - The compute pipeline to execute
|
||||
/// * `bind_group` - The bind group with buffer bindings
|
||||
/// * `workgroups` - Number of workgroups [x, y, z]
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// dispatcher.dispatch(&pipeline, &bind_group, [4, 1, 1]).await?;
|
||||
/// ```
|
||||
pub async fn dispatch(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
) -> GpuResult<()> {
|
||||
self.dispatch_with_config(pipeline, bind_group, workgroups, DispatchConfig::default())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Dispatch with custom configuration.
|
||||
pub async fn dispatch_with_config(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
config: DispatchConfig,
|
||||
) -> GpuResult<()> {
|
||||
// Validate workgroup count
|
||||
let limits = &self.device.info().max_workgroups;
|
||||
if workgroups[0] > limits[0] || workgroups[1] > limits[1] || workgroups[2] > limits[2] {
|
||||
return Err(GpuError::InvalidWorkgroupSize {
|
||||
x: workgroups[0],
|
||||
y: workgroups[1],
|
||||
z: workgroups[2],
|
||||
});
|
||||
}
|
||||
|
||||
let label = config.label.as_deref().unwrap_or("dispatch");
|
||||
debug!(
|
||||
"Dispatching '{}' with workgroups [{}, {}, {}]",
|
||||
label, workgroups[0], workgroups[1], workgroups[2]
|
||||
);
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.device()
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
|
||||
|
||||
{
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(bind_group), &[]);
|
||||
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
|
||||
}
|
||||
|
||||
self.device.submit(encoder.finish());
|
||||
|
||||
if config.wait {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dispatch using indirect workgroup count from a buffer.
|
||||
///
|
||||
/// The indirect buffer must contain [x, y, z] workgroup counts as u32.
|
||||
pub async fn dispatch_indirect(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
indirect_buffer: &GpuBuffer,
|
||||
) -> GpuResult<()> {
|
||||
self.dispatch_indirect_with_config(
|
||||
pipeline,
|
||||
bind_group,
|
||||
indirect_buffer,
|
||||
0,
|
||||
DispatchConfig::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Dispatch indirect with offset and configuration.
|
||||
pub async fn dispatch_indirect_with_config(
|
||||
&self,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
indirect_buffer: &GpuBuffer,
|
||||
indirect_offset: u64,
|
||||
config: DispatchConfig,
|
||||
) -> GpuResult<()> {
|
||||
let label = config.label.as_deref().unwrap_or("dispatch_indirect");
|
||||
debug!("Dispatching indirect '{}'", label);
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.device()
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
|
||||
|
||||
{
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(bind_group), &[]);
|
||||
pass.dispatch_workgroups_indirect(indirect_buffer.buffer(), indirect_offset);
|
||||
}
|
||||
|
||||
self.device.submit(encoder.finish());
|
||||
|
||||
if config.wait {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dispatch multiple kernels in a chain (fused execution).
|
||||
///
|
||||
/// All dispatches are recorded into a single command buffer for
|
||||
/// optimal GPU utilization.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dispatches` - List of (pipeline, bind_group, workgroups) tuples
|
||||
pub async fn dispatch_chain(
|
||||
&self,
|
||||
dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
|
||||
) -> GpuResult<()> {
|
||||
self.dispatch_chain_with_config(dispatches, DispatchConfig::default())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Dispatch chain with custom configuration.
|
||||
pub async fn dispatch_chain_with_config(
|
||||
&self,
|
||||
dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
|
||||
config: DispatchConfig,
|
||||
) -> GpuResult<()> {
|
||||
if dispatches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let label = config.label.as_deref().unwrap_or("dispatch_chain");
|
||||
debug!(
|
||||
"Dispatching chain '{}' with {} kernels",
|
||||
label,
|
||||
dispatches.len()
|
||||
);
|
||||
|
||||
let mut encoder = self
|
||||
.device
|
||||
.device()
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
|
||||
|
||||
for (i, (pipeline, bind_group, workgroups)) in dispatches.iter().enumerate() {
|
||||
trace!(
|
||||
"Chain dispatch {}: workgroups [{}, {}, {}]",
|
||||
i,
|
||||
workgroups[0],
|
||||
workgroups[1],
|
||||
workgroups[2]
|
||||
);
|
||||
|
||||
let pass_label = format!("{}_pass_{}", label, i);
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(&pass_label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(*bind_group), &[]);
|
||||
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
|
||||
}
|
||||
|
||||
self.device.submit(encoder.finish());
|
||||
|
||||
if config.wait {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record dispatches to a command encoder without submitting.
|
||||
///
|
||||
/// This is useful when you want to combine compute with other operations.
|
||||
pub fn record_dispatch(
|
||||
&self,
|
||||
encoder: &mut CommandEncoder,
|
||||
pipeline: &ComputePipeline,
|
||||
bind_group: &wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
label: Option<&str>,
|
||||
) {
|
||||
let pass_label = label.unwrap_or("recorded_dispatch");
|
||||
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some(pass_label),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
|
||||
pass.set_pipeline(pipeline.pipeline());
|
||||
pass.set_bind_group(0, Some(bind_group), &[]);
|
||||
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
|
||||
}
|
||||
|
||||
/// Wait for all pending GPU work to complete.
|
||||
pub fn synchronize(&self) {
|
||||
self.device.poll(true);
|
||||
}
|
||||
|
||||
/// Poll for completed work without blocking.
|
||||
pub fn poll(&self) -> bool {
|
||||
self.device.poll(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for constructing complex dispatch operations
|
||||
pub struct DispatchBuilder<'a> {
|
||||
dispatcher: &'a GpuDispatcher,
|
||||
dispatches: Vec<(Arc<ComputePipeline>, wgpu::BindGroup, [u32; 3])>,
|
||||
config: DispatchConfig,
|
||||
}
|
||||
|
||||
impl<'a> DispatchBuilder<'a> {
|
||||
/// Create a new dispatch builder
|
||||
pub fn new(dispatcher: &'a GpuDispatcher) -> Self {
|
||||
Self {
|
||||
dispatcher,
|
||||
dispatches: Vec::new(),
|
||||
config: DispatchConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a dispatch to the chain
|
||||
pub fn add(
|
||||
mut self,
|
||||
pipeline: Arc<ComputePipeline>,
|
||||
bind_group: wgpu::BindGroup,
|
||||
workgroups: [u32; 3],
|
||||
) -> Self {
|
||||
self.dispatches.push((pipeline, bind_group, workgroups));
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the configuration
|
||||
pub fn config(mut self, config: DispatchConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the label
|
||||
pub fn label(mut self, label: impl Into<String>) -> Self {
|
||||
self.config.label = Some(label.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set wait flag
|
||||
pub fn wait(mut self) -> Self {
|
||||
self.config.wait = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute all dispatches
|
||||
pub async fn execute(self) -> GpuResult<()> {
|
||||
if self.dispatches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let refs: Vec<(&ComputePipeline, &wgpu::BindGroup, [u32; 3])> = self
|
||||
.dispatches
|
||||
.iter()
|
||||
.map(|(p, b, w)| (p.as_ref(), b, *w))
|
||||
.collect();
|
||||
|
||||
self.dispatcher
|
||||
.dispatch_chain_with_config(&refs, self.config)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dispatch_config_default() {
|
||||
let config = DispatchConfig::default();
|
||||
assert!(!config.wait);
|
||||
assert!(config.label.is_none());
|
||||
assert_eq!(config.timeout_ms, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dispatch_config_wait() {
|
||||
let config = DispatchConfig::wait();
|
||||
assert!(config.wait);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dispatch_config_builder() {
|
||||
let config = DispatchConfig::with_label("test")
|
||||
.with_timeout(1000)
|
||||
.with_wait(true);
|
||||
|
||||
assert_eq!(config.label.as_deref(), Some("test"));
|
||||
assert_eq!(config.timeout_ms, 1000);
|
||||
assert!(config.wait);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user