//! ONNX Runtime backend for neural network inference. //! //! This module provides ONNX model loading and execution using the `ort` crate. //! It supports CPU and GPU (CUDA/TensorRT) execution providers. use crate::error::{NnError, NnResult}; use crate::inference::{Backend, InferenceOptions}; use crate::tensor::{Tensor, TensorShape}; use ort::session::Session; use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use tracing::info; /// ONNX Runtime session wrapper pub struct OnnxSession { session: Session, input_names: Vec, output_names: Vec, input_shapes: HashMap, output_shapes: HashMap, } impl std::fmt::Debug for OnnxSession { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OnnxSession") .field("input_names", &self.input_names) .field("output_names", &self.output_names) .field("input_shapes", &self.input_shapes) .field("output_shapes", &self.output_shapes) .finish() } } impl OnnxSession { /// Create a new ONNX session from a file pub fn from_file>(path: P, _options: &InferenceOptions) -> NnResult { let path = path.as_ref(); info!(?path, "Loading ONNX model"); // Build session using ort 2.0 API let session = Session::builder() .map_err(|e| NnError::model_load(format!("Failed to create session builder: {}", e)))? .commit_from_file(path) .map_err(|e| NnError::model_load(format!("Failed to load model: {}", e)))?; // Extract metadata using ort 2.0 API let input_names: Vec = session .inputs() .iter() .map(|input| input.name().to_string()) .collect(); let output_names: Vec = session .outputs() .iter() .map(|output| output.name().to_string()) .collect(); // For now, leave shapes empty - they can be populated when needed let input_shapes = HashMap::new(); let output_shapes = HashMap::new(); info!( inputs = ?input_names, outputs = ?output_names, "ONNX model loaded successfully" ); Ok(Self { session, input_names, output_names, input_shapes, output_shapes, }) } /// Create from in-memory bytes pub fn from_bytes(bytes: &[u8], _options: &InferenceOptions) -> NnResult { info!("Loading ONNX model from bytes"); let session = Session::builder() .map_err(|e| NnError::model_load(format!("Failed to create session builder: {}", e)))? .commit_from_memory(bytes) .map_err(|e| NnError::model_load(format!("Failed to load model from bytes: {}", e)))?; let input_names: Vec = session .inputs() .iter() .map(|input| input.name().to_string()) .collect(); let output_names: Vec = session .outputs() .iter() .map(|output| output.name().to_string()) .collect(); let input_shapes = HashMap::new(); let output_shapes = HashMap::new(); Ok(Self { session, input_names, output_names, input_shapes, output_shapes, }) } /// Get input names pub fn input_names(&self) -> &[String] { &self.input_names } /// Get output names pub fn output_names(&self) -> &[String] { &self.output_names } /// Run inference pub fn run(&mut self, inputs: HashMap) -> NnResult> { // Get the first input tensor let first_input_name = self.input_names.first() .ok_or_else(|| NnError::inference("No input names defined"))?; let tensor = inputs .get(first_input_name) .ok_or_else(|| NnError::invalid_input(format!("Missing input: {}", first_input_name)))?; let arr = tensor.as_array4()?; // Get shape and data for ort tensor creation let shape: Vec = arr.shape().iter().map(|&d| d as i64).collect(); let data: Vec = arr.iter().cloned().collect(); // Create ORT tensor from shape and data let ort_tensor = ort::value::Tensor::from_array((shape, data)) .map_err(|e| NnError::tensor_op(format!("Failed to create ORT tensor: {}", e)))?; // Build input map - inputs! macro returns Vec directly let session_inputs = ort::inputs![first_input_name.as_str() => ort_tensor]; // Run session let session_outputs = self.session .run(session_inputs) .map_err(|e| NnError::inference(format!("Inference failed: {}", e)))?; // Extract outputs let mut result = HashMap::new(); for name in self.output_names.iter() { if let Some(output) = session_outputs.get(name.as_str()) { // Try to extract tensor - returns (shape, data) tuple in ort 2.0 if let Ok((shape, data)) = output.try_extract_tensor::() { let dims: Vec = shape.iter().map(|&d| d as usize).collect(); if dims.len() == 4 { // Convert to 4D array let arr4 = ndarray::Array4::from_shape_vec( (dims[0], dims[1], dims[2], dims[3]), data.to_vec(), ).map_err(|e| NnError::tensor_op(format!("Shape error: {}", e)))?; result.insert(name.clone(), Tensor::Float4D(arr4)); } else { // Handle other dimensionalities let arr_dyn = ndarray::ArrayD::from_shape_vec( ndarray::IxDyn(&dims), data.to_vec(), ).map_err(|e| NnError::tensor_op(format!("Shape error: {}", e)))?; result.insert(name.clone(), Tensor::FloatND(arr_dyn)); } } } } Ok(result) } } /// ONNX Runtime backend implementation pub struct OnnxBackend { session: Arc>, options: InferenceOptions, } impl std::fmt::Debug for OnnxBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OnnxBackend") .field("options", &self.options) .finish() } } impl OnnxBackend { /// Create backend from file pub fn from_file>(path: P) -> NnResult { let options = InferenceOptions::default(); let session = OnnxSession::from_file(path, &options)?; Ok(Self { session: Arc::new(parking_lot::RwLock::new(session)), options, }) } /// Create backend from file with options pub fn from_file_with_options>(path: P, options: InferenceOptions) -> NnResult { let session = OnnxSession::from_file(path, &options)?; Ok(Self { session: Arc::new(parking_lot::RwLock::new(session)), options, }) } /// Create backend from bytes pub fn from_bytes(bytes: &[u8]) -> NnResult { let options = InferenceOptions::default(); let session = OnnxSession::from_bytes(bytes, &options)?; Ok(Self { session: Arc::new(parking_lot::RwLock::new(session)), options, }) } /// Create backend from bytes with options pub fn from_bytes_with_options(bytes: &[u8], options: InferenceOptions) -> NnResult { let session = OnnxSession::from_bytes(bytes, &options)?; Ok(Self { session: Arc::new(parking_lot::RwLock::new(session)), options, }) } /// Get options pub fn options(&self) -> &InferenceOptions { &self.options } } impl Backend for OnnxBackend { fn name(&self) -> &str { "onnxruntime" } fn is_available(&self) -> bool { true } fn input_names(&self) -> Vec { self.session.read().input_names.clone() } fn output_names(&self) -> Vec { self.session.read().output_names.clone() } fn input_shape(&self, name: &str) -> Option { self.session.read().input_shapes.get(name).cloned() } fn output_shape(&self, name: &str) -> Option { self.session.read().output_shapes.get(name).cloned() } fn run(&self, inputs: HashMap) -> NnResult> { self.session.write().run(inputs) } fn warmup(&self) -> NnResult<()> { let session = self.session.read(); let mut dummy_inputs = HashMap::new(); for name in &session.input_names { if let Some(shape) = session.input_shapes.get(name) { let dims = shape.dims(); if dims.len() == 4 { dummy_inputs.insert( name.clone(), Tensor::zeros_4d([dims[0], dims[1], dims[2], dims[3]]), ); } } } drop(session); // Release read lock before running if !dummy_inputs.is_empty() { let _ = self.run(dummy_inputs)?; info!("ONNX warmup completed"); } Ok(()) } } /// Model metadata from ONNX file #[derive(Debug, Clone)] pub struct OnnxModelInfo { /// Model producer name pub producer_name: Option, /// Model version pub model_version: Option, /// Domain pub domain: Option, /// Description pub description: Option, /// Input specifications pub inputs: Vec, /// Output specifications pub outputs: Vec, } /// Tensor specification #[derive(Debug, Clone)] pub struct TensorSpec { /// Name of the tensor pub name: String, /// Shape (may contain dynamic dimensions as -1) pub shape: Vec, /// Data type pub dtype: String, } /// Load model info without creating a full session pub fn load_model_info>(path: P) -> NnResult { let session = Session::builder() .map_err(|e| NnError::model_load(format!("Failed to create session builder: {}", e)))? .commit_from_file(path.as_ref()) .map_err(|e| NnError::model_load(format!("Failed to load model: {}", e)))?; let inputs: Vec = session .inputs() .iter() .map(|input| { TensorSpec { name: input.name().to_string(), shape: vec![], dtype: "float32".to_string(), } }) .collect(); let outputs: Vec = session .outputs() .iter() .map(|output| { TensorSpec { name: output.name().to_string(), shape: vec![], dtype: "float32".to_string(), } }) .collect(); Ok(OnnxModelInfo { producer_name: None, model_version: None, domain: None, description: None, inputs, outputs, }) } /// Builder for ONNX backend pub struct OnnxBackendBuilder { model_path: Option, model_bytes: Option>, options: InferenceOptions, } impl OnnxBackendBuilder { /// Create a new builder pub fn new() -> Self { Self { model_path: None, model_bytes: None, options: InferenceOptions::default(), } } /// Set model path pub fn model_path>(mut self, path: P) -> Self { self.model_path = Some(path.into()); self } /// Set model bytes pub fn model_bytes(mut self, bytes: Vec) -> Self { self.model_bytes = Some(bytes); self } /// Use GPU pub fn gpu(mut self, device_id: usize) -> Self { self.options.use_gpu = true; self.options.gpu_device_id = device_id; self } /// Use CPU pub fn cpu(mut self) -> Self { self.options.use_gpu = false; self } /// Set number of threads pub fn threads(mut self, n: usize) -> Self { self.options.num_threads = n; self } /// Enable optimization pub fn optimize(mut self, enabled: bool) -> Self { self.options.optimize = enabled; self } /// Build the backend pub fn build(self) -> NnResult { if let Some(path) = self.model_path { OnnxBackend::from_file_with_options(path, self.options) } else if let Some(bytes) = self.model_bytes { OnnxBackend::from_bytes_with_options(&bytes, self.options) } else { Err(NnError::config("No model path or bytes provided")) } } } impl Default for OnnxBackendBuilder { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_onnx_backend_builder() { let builder = OnnxBackendBuilder::new() .cpu() .threads(4) .optimize(true); // Can't test build without a real model assert!(builder.model_path.is_none()); } #[test] fn test_tensor_spec() { let spec = TensorSpec { name: "input".to_string(), shape: vec![1, 3, 224, 224], dtype: "float32".to_string(), }; assert_eq!(spec.name, "input"); assert_eq!(spec.shape.len(), 4); } }