//! ONNX model loading and management use crate::config::{EmbedderConfig, ExecutionProvider, ModelSource}; use crate::{EmbeddingError, PretrainedModel, Result}; use indicatif::{ProgressBar, ProgressStyle}; use ort::session::{builder::GraphOptimizationLevel, Session}; use sha2::{Digest, Sha256}; use std::fs; use std::io::Write; use std::path::Path; use tracing::{debug, info, instrument, warn}; /// Information about a loaded model #[derive(Debug, Clone)] pub struct ModelInfo { /// Model name or identifier pub name: String, /// Embedding dimension pub dimension: usize, /// Maximum sequence length pub max_seq_length: usize, /// Model file size in bytes pub file_size: u64, /// Model input names pub input_names: Vec, /// Model output names pub output_names: Vec, } /// ONNX model wrapper with inference capabilities pub struct OnnxModel { session: Session, info: ModelInfo, } impl OnnxModel { /// Load model from configuration #[instrument(skip_all)] pub async fn from_config(config: &EmbedderConfig) -> Result { match &config.model_source { ModelSource::Local { model_path, tokenizer_path: _, } => Self::from_file(model_path, config).await, ModelSource::Pretrained(model) => Self::from_pretrained(*model, config).await, ModelSource::HuggingFace { model_id, revision } => { Self::from_huggingface(model_id, revision.as_deref(), config).await } ModelSource::Url { model_url, tokenizer_url: _, } => Self::from_url(model_url, config).await, } } /// Load model from a local ONNX file #[instrument(skip_all, fields(path = %path.as_ref().display()))] pub async fn from_file(path: impl AsRef, config: &EmbedderConfig) -> Result { let path = path.as_ref(); info!("Loading ONNX model from file: {}", path.display()); if !path.exists() { return Err(EmbeddingError::model_not_found(path.display().to_string())); } let file_size = fs::metadata(path)?.len(); let session = Self::create_session(path, config)?; let info = Self::extract_model_info(&session, path, file_size)?; Ok(Self { session, info }) } /// Load a pretrained model (downloads if not cached) #[instrument(skip_all, fields(model = ?model))] pub async fn from_pretrained(model: PretrainedModel, config: &EmbedderConfig) -> Result { let model_id = model.model_id(); info!("Loading pretrained model: {}", model_id); // Check cache first let cache_path = config.cache_dir.join(sanitize_model_id(model_id)); let model_path = cache_path.join("model.onnx"); if model_path.exists() { debug!("Found cached model at {}", model_path.display()); return Self::from_file(&model_path, config).await; } // Download from HuggingFace Self::from_huggingface(model_id, None, config).await } /// Load model from HuggingFace Hub #[instrument(skip_all, fields(model_id = %model_id))] pub async fn from_huggingface( model_id: &str, revision: Option<&str>, config: &EmbedderConfig, ) -> Result { let cache_path = config.cache_dir.join(sanitize_model_id(model_id)); fs::create_dir_all(&cache_path)?; let model_path = cache_path.join("model.onnx"); if !model_path.exists() { info!("Downloading model from HuggingFace: {}", model_id); download_from_huggingface(model_id, revision, &cache_path, config.show_progress) .await?; } Self::from_file(&model_path, config).await } /// Load model from a URL #[instrument(skip_all, fields(url = %url))] pub async fn from_url(url: &str, config: &EmbedderConfig) -> Result { let hash = hash_url(url); let cache_path = config.cache_dir.join(&hash); fs::create_dir_all(&cache_path)?; let model_path = cache_path.join("model.onnx"); if !model_path.exists() { info!("Downloading model from URL: {}", url); download_file(url, &model_path, config.show_progress).await?; } Self::from_file(&model_path, config).await } /// Create an ONNX session with the specified configuration fn create_session(path: &Path, config: &EmbedderConfig) -> Result { let mut builder = Session::builder()?; // Set optimization level if config.optimize_graph { builder = builder.with_optimization_level(GraphOptimizationLevel::Level3)?; } // Set number of threads builder = builder.with_intra_threads(config.num_threads)?; // Configure execution provider match config.execution_provider { ExecutionProvider::Cpu => { // Default CPU provider } #[cfg(feature = "cuda")] ExecutionProvider::Cuda { device_id } => { builder = builder.with_execution_providers([ ort::execution_providers::CUDAExecutionProvider::default() .with_device_id(device_id) .build(), ])?; } #[cfg(feature = "tensorrt")] ExecutionProvider::TensorRt { device_id } => { builder = builder.with_execution_providers([ ort::execution_providers::TensorRTExecutionProvider::default() .with_device_id(device_id) .build(), ])?; } #[cfg(feature = "coreml")] ExecutionProvider::CoreMl => { builder = builder.with_execution_providers([ ort::execution_providers::CoreMLExecutionProvider::default().build(), ])?; } _ => { warn!( "Requested execution provider not available, falling back to CPU" ); } } let session = builder.commit_from_file(path)?; Ok(session) } /// Extract model information from the session fn extract_model_info(session: &Session, path: &Path, file_size: u64) -> Result { let inputs: Vec = session.inputs.iter().map(|i| i.name.clone()).collect(); let outputs: Vec = session.outputs.iter().map(|o| o.name.clone()).collect(); // Default embedding dimension (will be determined at runtime from actual output) // Most sentence-transformers models output 384 dimensions let dimension = 384; let name = path .file_stem() .map(|s| s.to_string_lossy().to_string()) .unwrap_or_else(|| "unknown".to_string()); Ok(ModelInfo { name, dimension, max_seq_length: 512, file_size, input_names: inputs, output_names: outputs, }) } /// Run inference on encoded inputs #[instrument(skip_all, fields(batch_size, seq_length))] pub fn run( &mut self, input_ids: &[i64], attention_mask: &[i64], token_type_ids: &[i64], shape: &[usize], ) -> Result>> { use ort::value::Tensor; let batch_size = shape[0]; let seq_length = shape[1]; debug!( "Running inference: batch_size={}, seq_length={}", batch_size, seq_length ); // Create input tensors using ort's Tensor type let input_ids_tensor = Tensor::from_array(( vec![batch_size, seq_length], input_ids.to_vec().into_boxed_slice(), )) .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?; let attention_mask_tensor = Tensor::from_array(( vec![batch_size, seq_length], attention_mask.to_vec().into_boxed_slice(), )) .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?; let token_type_ids_tensor = Tensor::from_array(( vec![batch_size, seq_length], token_type_ids.to_vec().into_boxed_slice(), )) .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?; // Build inputs vector let inputs = vec![ ("input_ids", input_ids_tensor.into_dyn()), ("attention_mask", attention_mask_tensor.into_dyn()), ("token_type_ids", token_type_ids_tensor.into_dyn()), ]; // Run inference let outputs = self.session.run(inputs) .map_err(EmbeddingError::OnnxRuntime)?; // Extract output tensor // Usually the output is [batch, seq_len, hidden_size] or [batch, hidden_size] let output_names = ["last_hidden_state", "output", "sentence_embedding"]; // Find the appropriate output by name, or use the first one let output_iter: Vec<_> = outputs.iter().collect(); let output = output_iter .iter() .find(|(name, _)| output_names.contains(name)) .or_else(|| output_iter.first()) .map(|(_, v)| v) .ok_or_else(|| EmbeddingError::invalid_model("No output tensor found"))?; // In ort 2.0, try_extract_tensor returns (&Shape, &[f32]) let (tensor_shape, tensor_data) = output .try_extract_tensor::() .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?; // Convert Shape to Vec - Shape yields i64 let dims: Vec = tensor_shape.iter().map(|&d| d as usize).collect(); // Handle different output shapes let embeddings = if dims.len() == 3 { // [batch, seq_len, hidden] - need pooling let hidden_size = dims[2]; (0..batch_size) .map(|i| { let start = i * seq_length * hidden_size; let end = start + seq_length * hidden_size; tensor_data[start..end].to_vec() }) .collect() } else if dims.len() == 2 { // [batch, hidden] - already pooled let hidden_size = dims[1]; (0..batch_size) .map(|i| { let start = i * hidden_size; let end = start + hidden_size; tensor_data[start..end].to_vec() }) .collect() } else { return Err(EmbeddingError::invalid_model(format!( "Unexpected output shape: {:?}", dims ))); }; Ok(embeddings) } /// Get model info pub fn info(&self) -> &ModelInfo { &self.info } /// Get embedding dimension pub fn dimension(&self) -> usize { self.info.dimension } } /// Download model files from HuggingFace Hub async fn download_from_huggingface( model_id: &str, revision: Option<&str>, cache_path: &Path, show_progress: bool, ) -> Result<()> { let revision = revision.unwrap_or("main"); let base_url = format!( "https://huggingface.co/{}/resolve/{}", model_id, revision ); let model_path = cache_path.join("model.onnx"); // Try to download model.onnx - check multiple locations if !model_path.exists() { // Location 1: Root directory (model.onnx) let root_url = format!("{}/model.onnx", base_url); debug!("Trying to download model from root: {}", root_url); let root_result = download_file(&root_url, &model_path, show_progress).await; // Location 2: ONNX subfolder (onnx/model.onnx) - common for sentence-transformers if root_result.is_err() && !model_path.exists() { let onnx_url = format!("{}/onnx/model.onnx", base_url); debug!("Root download failed, trying onnx subfolder: {}", onnx_url); match download_file(&onnx_url, &model_path, show_progress).await { Ok(_) => debug!("Downloaded model.onnx from onnx/ subfolder"), Err(e) => { // Both locations failed return Err(EmbeddingError::download_failed(format!( "Failed to download model.onnx from {} - tried both root and onnx/ subfolder: {}", model_id, e ))); } } } else if let Err(e) = root_result { // Root failed but model exists (shouldn't happen, but handle gracefully) if !model_path.exists() { return Err(e); } } else { debug!("Downloaded model.onnx from root"); } } // Download auxiliary files (tokenizer.json, config.json) - these are optional let aux_files = ["tokenizer.json", "config.json"]; for file in aux_files { let path = cache_path.join(file); if !path.exists() { // Try root first, then onnx subfolder let root_url = format!("{}/{}", base_url, file); match download_file(&root_url, &path, show_progress).await { Ok(_) => debug!("Downloaded {}", file), Err(_) => { // Try onnx subfolder let onnx_url = format!("{}/onnx/{}", base_url, file); match download_file(&onnx_url, &path, show_progress).await { Ok(_) => debug!("Downloaded {} from onnx/ subfolder", file), Err(e) => warn!("Failed to download {} (optional): {}", file, e), } } } } } Ok(()) } /// Download a file from URL with optional progress bar async fn download_file(url: &str, path: &Path, show_progress: bool) -> Result<()> { let client = reqwest::Client::new(); let response = client.get(url).send().await?; if !response.status().is_success() { return Err(EmbeddingError::download_failed(format!( "HTTP {}: {}", response.status(), url ))); } let total_size = response.content_length().unwrap_or(0); let pb = if show_progress && total_size > 0 { let pb = ProgressBar::new(total_size); pb.set_style( ProgressStyle::default_bar() .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})") .unwrap() .progress_chars("#>-"), ); Some(pb) } else { None }; let mut file = fs::File::create(path)?; let mut downloaded = 0u64; use futures_util::StreamExt; let mut stream = response.bytes_stream(); while let Some(chunk) = stream.next().await { let chunk = chunk?; file.write_all(&chunk)?; downloaded += chunk.len() as u64; if let Some(ref pb) = pb { pb.set_position(downloaded); } } if let Some(pb) = pb { pb.finish_with_message("Downloaded"); } Ok(()) } /// Sanitize model ID for use as directory name fn sanitize_model_id(model_id: &str) -> String { model_id.replace(['/', '\\', ':'], "_") } /// Create a hash of a URL for caching fn hash_url(url: &str) -> String { let mut hasher = Sha256::new(); hasher.update(url.as_bytes()); hex::encode(&hasher.finalize()[..8]) } #[cfg(test)] mod tests { use super::*; #[test] fn test_sanitize_model_id() { assert_eq!( sanitize_model_id("sentence-transformers/all-MiniLM-L6-v2"), "sentence-transformers_all-MiniLM-L6-v2" ); } #[test] fn test_hash_url() { let hash = hash_url("https://example.com/model.onnx"); assert_eq!(hash.len(), 16); // 8 bytes = 16 hex chars } }