Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,426 @@
//! Kernel dispatch and synchronization for GPU compute operations.
//!
//! This module provides the dispatcher for executing compute kernels on the GPU,
//! including support for:
//! - Single kernel dispatch
//! - Indirect dispatch (workgroup count from GPU buffer)
//! - Chained dispatch for fused kernels
//! - Synchronization and timing
use std::sync::Arc;
use tracing::{debug, trace};
use wgpu::{CommandEncoder, Device, Queue};
use super::buffer::{GpuBuffer, GpuBufferPool};
use super::device::GpuDevice;
use super::error::{GpuError, GpuResult};
use super::pipeline::{ComputePipeline, PipelineCache};
/// Configuration for a dispatch operation
#[derive(Debug, Clone)]
pub struct DispatchConfig {
/// Label for debugging
pub label: Option<String>,
/// Whether to wait for completion
pub wait: bool,
/// Timeout in milliseconds (0 = no timeout)
pub timeout_ms: u64,
}
impl Default for DispatchConfig {
fn default() -> Self {
Self {
label: None,
wait: false,
timeout_ms: 0,
}
}
}
impl DispatchConfig {
/// Create a config that waits for completion
pub fn wait() -> Self {
Self {
wait: true,
..Default::default()
}
}
/// Create a config with a label
pub fn with_label(label: impl Into<String>) -> Self {
Self {
label: Some(label.into()),
..Default::default()
}
}
/// Set the timeout
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
self.timeout_ms = timeout_ms;
self
}
/// Set wait flag
pub fn with_wait(mut self, wait: bool) -> Self {
self.wait = wait;
self
}
}
/// GPU dispatcher for executing compute kernels
pub struct GpuDispatcher {
device: Arc<GpuDevice>,
pipeline_cache: PipelineCache,
buffer_pool: GpuBufferPool,
}
impl GpuDispatcher {
/// Create a new dispatcher
pub fn new(device: Arc<GpuDevice>) -> Self {
let pipeline_cache = PipelineCache::new(device.device_arc());
let buffer_pool = GpuBufferPool::new(device.device_arc());
Self {
device,
pipeline_cache,
buffer_pool,
}
}
/// Get the underlying GPU device
pub fn device(&self) -> &GpuDevice {
&self.device
}
/// Get the pipeline cache
pub fn pipeline_cache(&self) -> &PipelineCache {
&self.pipeline_cache
}
/// Get the buffer pool
pub fn buffer_pool(&self) -> &GpuBufferPool {
&self.buffer_pool
}
/// Dispatch a compute kernel.
///
/// # Arguments
///
/// * `pipeline` - The compute pipeline to execute
/// * `bind_group` - The bind group with buffer bindings
/// * `workgroups` - Number of workgroups [x, y, z]
///
/// # Example
///
/// ```rust,ignore
/// dispatcher.dispatch(&pipeline, &bind_group, [4, 1, 1]).await?;
/// ```
pub async fn dispatch(
&self,
pipeline: &ComputePipeline,
bind_group: &wgpu::BindGroup,
workgroups: [u32; 3],
) -> GpuResult<()> {
self.dispatch_with_config(pipeline, bind_group, workgroups, DispatchConfig::default())
.await
}
/// Dispatch with custom configuration.
pub async fn dispatch_with_config(
&self,
pipeline: &ComputePipeline,
bind_group: &wgpu::BindGroup,
workgroups: [u32; 3],
config: DispatchConfig,
) -> GpuResult<()> {
// Validate workgroup count
let limits = &self.device.info().max_workgroups;
if workgroups[0] > limits[0] || workgroups[1] > limits[1] || workgroups[2] > limits[2] {
return Err(GpuError::InvalidWorkgroupSize {
x: workgroups[0],
y: workgroups[1],
z: workgroups[2],
});
}
let label = config.label.as_deref().unwrap_or("dispatch");
debug!(
"Dispatching '{}' with workgroups [{}, {}, {}]",
label, workgroups[0], workgroups[1], workgroups[2]
);
let mut encoder = self
.device
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(label),
timestamp_writes: None,
});
pass.set_pipeline(pipeline.pipeline());
pass.set_bind_group(0, Some(bind_group), &[]);
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
}
self.device.submit(encoder.finish());
if config.wait {
self.device.poll(true);
}
Ok(())
}
/// Dispatch using indirect workgroup count from a buffer.
///
/// The indirect buffer must contain [x, y, z] workgroup counts as u32.
pub async fn dispatch_indirect(
&self,
pipeline: &ComputePipeline,
bind_group: &wgpu::BindGroup,
indirect_buffer: &GpuBuffer,
) -> GpuResult<()> {
self.dispatch_indirect_with_config(
pipeline,
bind_group,
indirect_buffer,
0,
DispatchConfig::default(),
)
.await
}
/// Dispatch indirect with offset and configuration.
pub async fn dispatch_indirect_with_config(
&self,
pipeline: &ComputePipeline,
bind_group: &wgpu::BindGroup,
indirect_buffer: &GpuBuffer,
indirect_offset: u64,
config: DispatchConfig,
) -> GpuResult<()> {
let label = config.label.as_deref().unwrap_or("dispatch_indirect");
debug!("Dispatching indirect '{}'", label);
let mut encoder = self
.device
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(label),
timestamp_writes: None,
});
pass.set_pipeline(pipeline.pipeline());
pass.set_bind_group(0, Some(bind_group), &[]);
pass.dispatch_workgroups_indirect(indirect_buffer.buffer(), indirect_offset);
}
self.device.submit(encoder.finish());
if config.wait {
self.device.poll(true);
}
Ok(())
}
/// Dispatch multiple kernels in a chain (fused execution).
///
/// All dispatches are recorded into a single command buffer for
/// optimal GPU utilization.
///
/// # Arguments
///
/// * `dispatches` - List of (pipeline, bind_group, workgroups) tuples
pub async fn dispatch_chain(
&self,
dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
) -> GpuResult<()> {
self.dispatch_chain_with_config(dispatches, DispatchConfig::default())
.await
}
/// Dispatch chain with custom configuration.
pub async fn dispatch_chain_with_config(
&self,
dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
config: DispatchConfig,
) -> GpuResult<()> {
if dispatches.is_empty() {
return Ok(());
}
let label = config.label.as_deref().unwrap_or("dispatch_chain");
debug!(
"Dispatching chain '{}' with {} kernels",
label,
dispatches.len()
);
let mut encoder = self
.device
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) });
for (i, (pipeline, bind_group, workgroups)) in dispatches.iter().enumerate() {
trace!(
"Chain dispatch {}: workgroups [{}, {}, {}]",
i,
workgroups[0],
workgroups[1],
workgroups[2]
);
let pass_label = format!("{}_pass_{}", label, i);
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(&pass_label),
timestamp_writes: None,
});
pass.set_pipeline(pipeline.pipeline());
pass.set_bind_group(0, Some(*bind_group), &[]);
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
}
self.device.submit(encoder.finish());
if config.wait {
self.device.poll(true);
}
Ok(())
}
/// Record dispatches to a command encoder without submitting.
///
/// This is useful when you want to combine compute with other operations.
pub fn record_dispatch(
&self,
encoder: &mut CommandEncoder,
pipeline: &ComputePipeline,
bind_group: &wgpu::BindGroup,
workgroups: [u32; 3],
label: Option<&str>,
) {
let pass_label = label.unwrap_or("recorded_dispatch");
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(pass_label),
timestamp_writes: None,
});
pass.set_pipeline(pipeline.pipeline());
pass.set_bind_group(0, Some(bind_group), &[]);
pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
}
/// Wait for all pending GPU work to complete.
pub fn synchronize(&self) {
self.device.poll(true);
}
/// Poll for completed work without blocking.
pub fn poll(&self) -> bool {
self.device.poll(false)
}
}
/// Builder for constructing complex dispatch operations
pub struct DispatchBuilder<'a> {
dispatcher: &'a GpuDispatcher,
dispatches: Vec<(Arc<ComputePipeline>, wgpu::BindGroup, [u32; 3])>,
config: DispatchConfig,
}
impl<'a> DispatchBuilder<'a> {
/// Create a new dispatch builder
pub fn new(dispatcher: &'a GpuDispatcher) -> Self {
Self {
dispatcher,
dispatches: Vec::new(),
config: DispatchConfig::default(),
}
}
/// Add a dispatch to the chain
pub fn add(
mut self,
pipeline: Arc<ComputePipeline>,
bind_group: wgpu::BindGroup,
workgroups: [u32; 3],
) -> Self {
self.dispatches.push((pipeline, bind_group, workgroups));
self
}
/// Set the configuration
pub fn config(mut self, config: DispatchConfig) -> Self {
self.config = config;
self
}
/// Set the label
pub fn label(mut self, label: impl Into<String>) -> Self {
self.config.label = Some(label.into());
self
}
/// Set wait flag
pub fn wait(mut self) -> Self {
self.config.wait = true;
self
}
/// Execute all dispatches
pub async fn execute(self) -> GpuResult<()> {
if self.dispatches.is_empty() {
return Ok(());
}
let refs: Vec<(&ComputePipeline, &wgpu::BindGroup, [u32; 3])> = self
.dispatches
.iter()
.map(|(p, b, w)| (p.as_ref(), b, *w))
.collect();
self.dispatcher
.dispatch_chain_with_config(&refs, self.config)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dispatch_config_default() {
let config = DispatchConfig::default();
assert!(!config.wait);
assert!(config.label.is_none());
assert_eq!(config.timeout_ms, 0);
}
#[test]
fn test_dispatch_config_wait() {
let config = DispatchConfig::wait();
assert!(config.wait);
}
#[test]
fn test_dispatch_config_builder() {
let config = DispatchConfig::with_label("test")
.with_timeout(1000)
.with_wait(true);
assert_eq!(config.label.as_deref(), Some("test"));
assert_eq!(config.timeout_ms, 1000);
assert!(config.wait);
}
}