//! GPU Coherence Engine
//!
//! Main entry point for GPU-accelerated coherence computation.
//! Provides automatic CPU fallback when the GPU is unavailable.

use super::buffer::{BufferUsage, GpuBufferManager, GpuEdge, GpuParams, GpuRestrictionMap};
use super::error::{GpuError, GpuResult};
use super::kernels::{
    ComputeEnergyKernel, ComputeResidualsKernel, EnergyParams, SheafAttentionKernel,
    TokenRoutingKernel,
};
use crate::coherence::{CoherenceEnergy as CpuCoherenceEnergy, EdgeEnergy};
use crate::substrate::restriction::MatrixStorage;
use crate::substrate::{EdgeId, NodeId, SheafGraph};
use chrono::Utc;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
use wgpu::{
    Adapter, Device, DeviceDescriptor, Features, Instance, InstanceDescriptor, Limits,
    PowerPreference, Queue, RequestAdapterOptions,
};

/// GPU configuration
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Preferred power preference (high performance vs. low power)
    pub power_preference: PowerPreference,
    /// Enable CPU fallback when the GPU is unavailable
    pub enable_fallback: bool,
    /// Maximum buffer size in bytes (0 = no limit)
    pub max_buffer_size: usize,
    /// Beta parameter for attention computation
    pub beta: f32,
    /// Lane 0 (reflex) threshold
    pub threshold_lane0: f32,
    /// Lane 1 (retrieval) threshold
    pub threshold_lane1: f32,
    /// Lane 2 (heavy) threshold
    pub threshold_lane2: f32,
    /// Timeout for GPU operations in milliseconds
    pub timeout_ms: u64,
}

impl Default for GpuConfig {
    fn default() -> Self {
        Self {
            power_preference: PowerPreference::HighPerformance,
            enable_fallback: true,
            max_buffer_size: 0, // No limit
            beta: 1.0,
            threshold_lane0: 0.01,
            threshold_lane1: 0.1,
            threshold_lane2: 1.0,
            timeout_ms: 5000,
        }
    }
}

/// GPU capabilities and limits
#[derive(Debug, Clone)]
pub struct GpuCapabilities {
    /// Device name
    pub device_name: String,
    /// Vendor
    pub vendor: String,
    /// Backend (Vulkan, Metal, DX12, etc.)
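    /// as formatted from wgpu's `AdapterInfo`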
    pub backend: String,
    /// Maximum buffer size
    pub max_buffer_size: u64,
    /// Maximum compute workgroup size
    pub max_workgroup_size: u32,
    /// Maximum compute workgroups per dimension
    pub max_workgroups: [u32; 3],
    /// Whether the GPU supports the required features
    pub supported: bool,
}

/// GPU energy result
#[derive(Debug, Clone)]
pub struct GpuCoherenceEnergy {
    /// Total system energy
    pub total_energy: f32,
    /// Per-edge energies
    pub edge_energies: Vec<f32>,
    /// Edge indices (parallel to `edge_energies`)
    pub edge_indices: Vec<EdgeId>,
    /// Computation time in microseconds
    pub compute_time_us: u64,
    /// Whether the GPU was used (false = CPU fallback)
    pub used_gpu: bool,
}

impl GpuCoherenceEnergy {
    /// Convert to the CPU `CoherenceEnergy` format
    pub fn to_cpu_format(&self, graph: &SheafGraph) -> CpuCoherenceEnergy {
        let mut edge_energy_map = HashMap::new();
        for (i, &edge_id) in self.edge_indices.iter().enumerate() {
            let energy = self.edge_energies[i];
            if let Some(edge) = graph.get_edge(edge_id) {
                let edge_energy = EdgeEnergy::new_lightweight(
                    edge_id.to_string(),
                    edge.source.to_string(),
                    edge.target.to_string(),
                    energy / edge.weight.max(0.001), // Divide out the weight to recover the raw norm_sq
                    edge.weight,
                );
                edge_energy_map.insert(edge_id.to_string(), edge_energy);
            }
        }
        CpuCoherenceEnergy::new(
            edge_energy_map,
            &HashMap::new(),
            graph.node_count(),
            format!("gpu-{}", Utc::now().timestamp()),
        )
    }
}

/// GPU-accelerated coherence engine
pub struct GpuCoherenceEngine {
    device: Arc<Device>,
    queue: Arc<Queue>,
    buffer_manager: GpuBufferManager,
    config: GpuConfig,
    capabilities: GpuCapabilities,
    // Kernels
    residuals_kernel: ComputeResidualsKernel,
    energy_kernel: ComputeEnergyKernel,
    attention_kernel: SheafAttentionKernel,
    routing_kernel: TokenRoutingKernel,
    // Cached graph data
    graph_data: Option<GpuGraphData>,
}

/// Cached graph data on the GPU
struct GpuGraphData {
    num_nodes: u32,
    num_edges: u32,
    state_dim: u32,
    node_id_map: HashMap<NodeId, u32>,
    edge_id_map: HashMap<EdgeId, u32>,
    edge_id_reverse: Vec<EdgeId>,
    // Pre-allocated computation buffers (eliminates per-frame allocations)
    params_buffer: wgpu::Buffer,
    energy_params_buffer: wgpu::Buffer,
    partial_sums_buffer: wgpu::Buffer,
    final_params_buffer: wgpu::Buffer,
    total_energy_buffer: wgpu::Buffer,
    energies_staging: wgpu::Buffer,
    total_staging: wgpu::Buffer,
    // Pre-computed workgroup count for the energy reduction
    num_workgroups: u32,
}

impl GpuCoherenceEngine {
    /// Create a new GPU coherence engine
    pub async fn new(config: GpuConfig) -> GpuResult<Self> {
        // Create the wgpu instance
        let instance = Instance::new(InstanceDescriptor::default());

        // Request an adapter
        let adapter = instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: config.power_preference,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| GpuError::AdapterRequest("No suitable GPU adapter found".into()))?;

        let capabilities = Self::get_capabilities(&adapter);
        if !capabilities.supported {
            return Err(GpuError::UnsupportedFeature(
                "GPU does not support required features".into(),
            ));
        }

        info!(
            "Using GPU: {} ({}) - {}",
            capabilities.device_name, capabilities.vendor, capabilities.backend
        );

        // Request a device
        let (device, queue) = adapter
            .request_device(
                &DeviceDescriptor {
                    label: Some("prime_radiant_gpu"),
                    required_features: Features::empty(),
                    required_limits: Limits::default(),
                    memory_hints: Default::default(),
                },
                None,
            )
            .await
            .map_err(|e| GpuError::DeviceCreation(e.to_string()))?;

        let device = Arc::new(device);
        let queue = Arc::new(queue);

        // Create kernels
        let residuals_kernel = ComputeResidualsKernel::new(&device)?;
        let energy_kernel = ComputeEnergyKernel::new(&device)?;
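        // The attention and routing kernels are built up front as well, so a
        // shader-compilation failure surfaces here rather than at dispatch time.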
        let attention_kernel = SheafAttentionKernel::new(&device)?;
        let routing_kernel = TokenRoutingKernel::new(&device)?;

        // Create the buffer manager
        let buffer_manager = GpuBufferManager::new(device.clone(), queue.clone());

        Ok(Self {
            device,
            queue,
            buffer_manager,
            config,
            capabilities,
            residuals_kernel,
            energy_kernel,
            attention_kernel,
            routing_kernel,
            graph_data: None,
        })
    }

    /// Try to create a GPU engine, returning `None` if the GPU is unavailable
    pub async fn try_new(config: GpuConfig) -> Option<Self> {
        match Self::new(config).await {
            Ok(engine) => Some(engine),
            Err(e) => {
                warn!("GPU initialization failed: {}. Will use CPU fallback.", e);
                None
            }
        }
    }

    /// Get GPU capabilities
    fn get_capabilities(adapter: &Adapter) -> GpuCapabilities {
        let info = adapter.get_info();
        let limits = adapter.limits();
        GpuCapabilities {
            device_name: info.name,
            vendor: format!("{:?}", info.vendor),
            backend: format!("{:?}", info.backend),
            max_buffer_size: limits.max_buffer_size as u64,
            max_workgroup_size: limits.max_compute_workgroup_size_x,
            max_workgroups: [
                limits.max_compute_workgroups_per_dimension,
                limits.max_compute_workgroups_per_dimension,
                limits.max_compute_workgroups_per_dimension,
            ],
            supported: true,
        }
    }

    /// Upload graph data to the GPU
    pub fn upload_graph(&mut self, graph: &SheafGraph) -> GpuResult<()> {
        if graph.edge_count() == 0 {
            return Err(GpuError::EmptyGraph);
        }

        let num_nodes = graph.node_count() as u32;
        let num_edges = graph.edge_count() as u32;

        // Build the node ID mapping
        let mut node_id_map = HashMap::new();
        let node_ids = graph.node_ids();
        for (i, node_id) in node_ids.iter().enumerate() {
            node_id_map.insert(*node_id, i as u32);
        }

        // Determine the state dimension from the first node
        let state_dim = node_ids
            .first()
            .and_then(|id| graph.get_node(*id))
            .map(|n| n.dim())
            .unwrap_or(64) as u32;

        // Flatten node states
        let mut node_states: Vec<f32> = Vec::with_capacity((num_nodes * state_dim) as usize);
        for node_id in &node_ids {
            if let Some(state) = graph.node_state(*node_id) {
                node_states.extend(state.iter().cloned());
                // Pad if needed
                for _ in state.len()..(state_dim as usize) {
                    node_states.push(0.0);
                }
            }
        }

        // Build edge data and restriction maps
        let mut edges: Vec<GpuEdge> = Vec::with_capacity(num_edges as usize);
        let mut restriction_maps: Vec<GpuRestrictionMap> = Vec::new();
        let mut restriction_data: Vec<f32> = Vec::new();
        let mut edge_id_map = HashMap::new();
        let mut edge_id_reverse = Vec::new();

        let edge_ids = graph.edge_ids();
        for (i, edge_id) in edge_ids.iter().enumerate() {
            edge_id_map.insert(*edge_id, i as u32);
            edge_id_reverse.push(*edge_id);

            if let Some(edge) = graph.get_edge(*edge_id) {
                let source_idx = *node_id_map.get(&edge.source).unwrap_or(&0);
                let target_idx = *node_id_map.get(&edge.target).unwrap_or(&0);

                // Convert restriction maps
                let rho_source_idx = restriction_maps.len() as u32;
                let gpu_rho_source =
                    Self::convert_restriction_map(&edge.rho_source, &mut restriction_data);
                restriction_maps.push(gpu_rho_source);

                let rho_target_idx = restriction_maps.len() as u32;
                let gpu_rho_target =
                    Self::convert_restriction_map(&edge.rho_target, &mut restriction_data);
                restriction_maps.push(gpu_rho_target);

                edges.push(GpuEdge {
                    source_idx,
                    target_idx,
                    weight: edge.weight,
                    rho_source_idx,
                    rho_target_idx,
                    comparison_dim: edge.comparison_dim() as u32,
                    _padding: [0; 2],
                });
            }
        }

        // Ensure restriction_data is not empty (GPU buffers cannot be zero-sized)
        if restriction_data.is_empty() {
            restriction_data.push(0.0);
        }

        // Upload to the GPU
        self.buffer_manager.allocate_with_data(
            &node_states,
            BufferUsage::NodeStates,
            "node_states",
        )?;
        self.buffer_manager
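            // Edge topology: one GpuEdge per edge, consumed by the kernels via
            // source_idx/target_idx and the rho_* offsets assigned above.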
            .allocate_with_data(&edges, BufferUsage::EdgeData, "edges")?;
        self.buffer_manager.allocate_with_data(
            &restriction_maps,
            BufferUsage::RestrictionMaps,
            "restriction_maps",
        )?;
        self.buffer_manager.allocate_with_data(
            &restriction_data,
            BufferUsage::RestrictionMaps,
            "restriction_data",
        )?;

        // Allocate output buffers
        let max_comparison_dim = edges
            .iter()
            .map(|e| e.comparison_dim)
            .max()
            .unwrap_or(state_dim);
        let residuals_size =
            (num_edges * max_comparison_dim) as usize * std::mem::size_of::<f32>();
        let energies_size = num_edges as usize * std::mem::size_of::<f32>();

        self.buffer_manager
            .allocate(residuals_size, BufferUsage::Residuals, "residuals")?;
        self.buffer_manager
            .allocate(energies_size, BufferUsage::Energies, "edge_energies")?;

        // Pre-allocate computation buffers to eliminate per-frame allocations
        let num_workgroups = ComputeEnergyKernel::workgroup_count(num_edges);

        // Params buffer (GpuParams)
        let params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("params_preallocated"),
            size: std::mem::size_of::<GpuParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Energy params buffer (EnergyParams)
        let energy_params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("energy_params_preallocated"),
            size: std::mem::size_of::<EnergyParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Partial sums buffer (one f32 per workgroup)
        let partial_sums_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("partial_sums_preallocated"),
            size: ((num_workgroups as usize).max(1) * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Final params buffer (for the second reduction pass)
        let final_params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("final_params_preallocated"),
            size: std::mem::size_of::<EnergyParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Total energy buffer (single f32)
        let total_energy_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("total_energy_preallocated"),
            size: std::mem::size_of::<f32>() as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Staging buffer for edge-energies readback
        let energies_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("energies_staging_preallocated"),
            size: (num_edges as usize * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Staging buffer for total-energy readback
        let total_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("total_staging_preallocated"),
            size: std::mem::size_of::<f32>() as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Store graph data together with the pre-allocated buffers
        self.graph_data = Some(GpuGraphData {
            num_nodes,
            num_edges,
            state_dim,
            node_id_map,
            edge_id_map,
            edge_id_reverse,
            params_buffer,
            energy_params_buffer,
            partial_sums_buffer,
            final_params_buffer,
            total_energy_buffer,
            energies_staging,
            total_staging,
            num_workgroups,
        });

        debug!(
            "Uploaded graph to GPU: {} nodes, {} edges, state_dim={}, workgroups={}",
            num_nodes, num_edges, state_dim, num_workgroups
        );

        Ok(())
    }

    /// Convert a RestrictionMap to GPU format
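    ///
    /// Appends the map's payload to the shared `data` vector and returns a
    /// descriptor whose `data_offset`/`data_len` index into that vector.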
    fn convert_restriction_map(
        map: &crate::substrate::RestrictionMap,
        data: &mut Vec<f32>,
    ) -> GpuRestrictionMap {
        let data_offset = data.len() as u32;
        let (map_type, data_len) = match &map.matrix {
            MatrixStorage::Identity => (0, 0),
            MatrixStorage::Diagonal(scales) => {
                data.extend(scales.iter().cloned());
                (1, scales.len() as u32)
            }
            MatrixStorage::Projection { indices, .. } => {
                data.extend(indices.iter().map(|&i| i as f32));
                (2, indices.len() as u32)
            }
            MatrixStorage::Sparse { values, .. } => {
                // Simplified: store only the values (row/col indices would be needed in practice)
                data.extend(values.iter().cloned());
                (3, values.len() as u32)
            }
            MatrixStorage::Csr(csr) => {
                // CSR format: store the values as in the sparse case.
                // Note: in practice the GPU would need row_ptr and col_indices too.
                data.extend(csr.values.iter().cloned());
                (3, csr.values.len() as u32)
            }
            MatrixStorage::Dense {
                data: matrix_data, ..
            } => {
                data.extend(matrix_data.iter().cloned());
                (3, matrix_data.len() as u32)
            }
        };

        GpuRestrictionMap {
            map_type,
            input_dim: map.input_dim() as u32,
            output_dim: map.output_dim() as u32,
            data_offset,
            data_len,
            _padding: [0; 3],
        }
    }

    /// Compute coherence energy on the GPU.
    /// Uses pre-allocated buffers to eliminate per-frame allocations.
    pub async fn compute_energy(&mut self) -> GpuResult<GpuCoherenceEnergy> {
        let start = std::time::Instant::now();

        let graph_data = self
            .graph_data
            .as_ref()
            .ok_or_else(|| GpuError::Internal("Graph not uploaded".into()))?;
        let num_edges = graph_data.num_edges;
        let num_workgroups = graph_data.num_workgroups;

        // Write params to the pre-allocated buffer (no allocation)
        let params = GpuParams {
            num_edges,
            num_nodes: graph_data.num_nodes,
            state_dim: graph_data.state_dim,
            beta: self.config.beta,
            threshold_lane0: self.config.threshold_lane0,
            threshold_lane1: self.config.threshold_lane1,
            threshold_lane2: self.config.threshold_lane2,
            store_residuals: 1, // Store residuals by default for gradient computation
        };
        self.queue
            .write_buffer(&graph_data.params_buffer, 0, bytemuck::bytes_of(&params));

        // Write energy params to the pre-allocated buffer (no allocation)
        let energy_params = EnergyParams {
            num_elements: num_edges,
            _padding: [0; 7],
        };
        self.queue.write_buffer(
            &graph_data.energy_params_buffer,
            0,
            bytemuck::bytes_of(&energy_params),
        );

        // Get managed buffers for bind group creation
        let node_states_buf = self
            .buffer_manager
            .get("node_states")
            .ok_or_else(|| GpuError::Internal("Node states buffer not found".into()))?;
        let edges_buf = self
            .buffer_manager
            .get("edges")
            .ok_or_else(|| GpuError::Internal("Edges buffer not found".into()))?;
        let restriction_maps_buf = self
            .buffer_manager
            .get("restriction_maps")
            .ok_or_else(|| GpuError::Internal("Restriction maps buffer not found".into()))?;
        let restriction_data_buf = self
            .buffer_manager
            .get("restriction_data")
            .ok_or_else(|| GpuError::Internal("Restriction data buffer not found".into()))?;
        let residuals_buf = self
            .buffer_manager
            .get("residuals")
            .ok_or_else(|| GpuError::Internal("Residuals buffer not found".into()))?;
        let energies_buf = self
            .buffer_manager
            .get("edge_energies")
            .ok_or_else(|| GpuError::Internal("Edge energies buffer not found".into()))?;

        // Create the bind group for the residuals kernel using the pre-allocated params buffer
        let residuals_bind_group = self.residuals_kernel.create_bind_group_raw(
            &self.device,
            &graph_data.params_buffer,
            &node_states_buf.buffer,
            &edges_buf.buffer,
            &restriction_maps_buf.buffer,
            &restriction_data_buf.buffer,
            &residuals_buf.buffer,
            &energies_buf.buffer,
        );

        // Create the bind group for the energy reduction using pre-allocated buffers
        let energy_bind_group =
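            // First reduction pass: per-edge energies -> one partial sum per workgroup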
            self.energy_kernel.create_bind_group_raw(
                &self.device,
                &graph_data.energy_params_buffer,
                &energies_buf.buffer,
                &graph_data.partial_sums_buffer,
            );

        // Create the command encoder
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("compute_energy_encoder"),
            });

        // Dispatch the residuals computation
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("compute_residuals_pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(self.residuals_kernel.pipeline());
            compute_pass.set_bind_group(0, &residuals_bind_group, &[]);
            compute_pass.dispatch_workgroups(
                ComputeResidualsKernel::workgroup_count(num_edges),
                1,
                1,
            );
        }

        // Dispatch the energy reduction
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("compute_energy_pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(self.energy_kernel.main_pipeline());
            compute_pass.set_bind_group(0, &energy_bind_group, &[]);
            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
        }

        // If there are multiple workgroups, run a final reduction pass
        if num_workgroups > 1 {
            // Write final params to the pre-allocated buffer (no allocation)
            let final_params = EnergyParams {
                num_elements: num_workgroups,
                _padding: [0; 7],
            };
            self.queue.write_buffer(
                &graph_data.final_params_buffer,
                0,
                bytemuck::bytes_of(&final_params),
            );

            let final_bind_group = self.energy_kernel.create_bind_group_raw(
                &self.device,
                &graph_data.final_params_buffer,
                &graph_data.partial_sums_buffer,
                &graph_data.total_energy_buffer,
            );

            {
                let mut compute_pass =
                    encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                        label: Some("final_reduce_pass"),
                        timestamp_writes: None,
                    });
                compute_pass.set_pipeline(self.energy_kernel.final_pipeline());
                compute_pass.set_bind_group(0, &final_bind_group, &[]);
                compute_pass.dispatch_workgroups(1, 1, 1);
            }
        }

        // Copy results to the pre-allocated staging buffers (no allocation)
        encoder.copy_buffer_to_buffer(
            &energies_buf.buffer,
            0,
            &graph_data.energies_staging,
            0,
            (num_edges as usize * std::mem::size_of::<f32>()) as u64,
        );
        if num_workgroups > 1 {
            encoder.copy_buffer_to_buffer(
                &graph_data.total_energy_buffer,
                0,
                &graph_data.total_staging,
                0,
                std::mem::size_of::<f32>() as u64,
            );
        } else {
            // With a single workgroup the lone partial sum already is the total
            encoder.copy_buffer_to_buffer(
                &graph_data.partial_sums_buffer,
                0,
                &graph_data.total_staging,
                0,
                std::mem::size_of::<f32>() as u64,
            );
        }

        // Submit commands
        self.queue.submit(std::iter::once(encoder.finish()));

        // Read results back from the pre-allocated staging buffers
        let edge_energies = Self::read_buffer_f32(
            &self.device,
            &graph_data.energies_staging,
            num_edges as usize,
        )
        .await?;
        let total_energy =
            Self::read_buffer_f32(&self.device, &graph_data.total_staging, 1).await?[0];

        let compute_time_us = start.elapsed().as_micros() as u64;
        debug!(
            "GPU energy computation: total={:.6}, {} edges, {}us (pre-allocated buffers)",
            total_energy, num_edges, compute_time_us
        );

        Ok(GpuCoherenceEnergy {
            total_energy,
            edge_energies,
            edge_indices: graph_data.edge_id_reverse.clone(),
            compute_time_us,
            used_gpu: true,
        })
    }

    /// Read an f32 buffer back to the CPU
    async fn read_buffer_f32(
        device: &Device,
        buffer: &wgpu::Buffer,
        count: usize,
    ) -> GpuResult<Vec<f32>> {
        let buffer_slice = buffer.slice(..);
        let (sender, receiver) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });
        device.poll(wgpu::Maintain::Wait);
        receiver
            .await
            .map_err(|_| GpuError::BufferRead("Channel closed".into()))?
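            // The outer Result is the oneshot channel; the inner one is what
            // map_async reported for the mapping itself.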
            .map_err(|e| GpuError::BufferRead(e.to_string()))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<f32> =
            bytemuck::cast_slice(&data[..count * std::mem::size_of::<f32>()]).to_vec();
        drop(data);
        buffer.unmap();

        Ok(result)
    }

    /// Get GPU capabilities
    pub fn capabilities(&self) -> &GpuCapabilities {
        &self.capabilities
    }

    /// Get the configuration
    pub fn config(&self) -> &GpuConfig {
        &self.config
    }

    /// Check whether the GPU is available
    pub fn is_available(&self) -> bool {
        self.capabilities.supported
    }

    /// Release all GPU resources
    pub fn release(&mut self) {
        self.buffer_manager.clear();
        self.graph_data = None;
    }
}

/// Synchronous wrapper for the GPU coherence engine using pollster
pub mod sync {
    use super::*;

    /// Synchronously create a GPU engine
    pub fn create_engine(config: GpuConfig) -> GpuResult<GpuCoherenceEngine> {
        pollster::block_on(GpuCoherenceEngine::new(config))
    }

    /// Try to create a GPU engine synchronously
    pub fn try_create_engine(config: GpuConfig) -> Option<GpuCoherenceEngine> {
        pollster::block_on(GpuCoherenceEngine::try_new(config))
    }

    /// Compute energy synchronously
    pub fn compute_energy(engine: &mut GpuCoherenceEngine) -> GpuResult<GpuCoherenceEnergy> {
        pollster::block_on(engine.compute_energy())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_config_default() {
        let config = GpuConfig::default();
        assert!(config.enable_fallback);
        assert_eq!(config.beta, 1.0);
        assert!(config.threshold_lane0 < config.threshold_lane1);
        assert!(config.threshold_lane1 < config.threshold_lane2);
    }

    #[test]
    fn test_gpu_params_size() {
        assert_eq!(std::mem::size_of::<GpuParams>(), 32);
    }

    #[test]
    fn test_energy_params_size() {
        assert_eq!(std::mem::size_of::<EnergyParams>(), 32);
    }
}
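
// A minimal end-to-end usage sketch for the synchronous wrapper. It is marked
// `#[ignore]` because it needs a real adapter at runtime; on machines without
// one, `try_create_engine` simply returns `None`, so it is also safe on CI.
#[cfg(test)]
mod smoke_tests {
    use super::*;

    #[test]
    #[ignore] // requires a physical GPU; run with `cargo test -- --ignored`
    fn try_create_engine_smoke() {
        if let Some(engine) = sync::try_create_engine(GpuConfig::default()) {
            assert!(engine.is_available());
            assert!(!engine.capabilities().device_name.is_empty());
        }
    }
}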