// NOTE(review): file-listing residue from the extraction tool, preserved as a
// comment so the file remains valid Rust:
//   Files
//   wifi-densepose/vendor/ruvector/examples/onnx-embeddings/src/model.rs
//   471 lines
//   16 KiB
//   Rust
//! ONNX model loading and management
use crate::config::{EmbedderConfig, ExecutionProvider, ModelSource};
use crate::{EmbeddingError, PretrainedModel, Result};
use indicatif::{ProgressBar, ProgressStyle};
use ort::session::{builder::GraphOptimizationLevel, Session};
use sha2::{Digest, Sha256};
use std::fs;
use std::io::Write;
use std::path::Path;
use tracing::{debug, info, instrument, warn};
/// Information about a loaded model.
///
/// Populated by `OnnxModel::extract_model_info` when a session is created;
/// `dimension` and `max_seq_length` are defaults there, not read from the
/// model graph — the true embedding width is only known after inference.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model name or identifier (derived from the model file's stem)
    pub name: String,
    /// Embedding dimension (default 384 — TODO confirm against actual output)
    pub dimension: usize,
    /// Maximum sequence length (default 512 — not read from the model)
    pub max_seq_length: usize,
    /// Model file size in bytes
    pub file_size: u64,
    /// Model input names as reported by the ONNX session
    pub input_names: Vec<String>,
    /// Model output names as reported by the ONNX session
    pub output_names: Vec<String>,
}
/// ONNX model wrapper with inference capabilities.
///
/// Owns the ONNX Runtime session plus metadata captured at load time.
/// `run` takes `&mut self` because `Session::run` requires mutable access.
pub struct OnnxModel {
    /// The underlying ONNX Runtime session used for inference
    session: Session,
    /// Metadata extracted when the model was loaded
    info: ModelInfo,
}
impl OnnxModel {
    /// Load a model according to the configured [`crate::config::ModelSource`].
    ///
    /// Dispatches to the matching loader; tokenizer paths/URLs are ignored
    /// here (the tokenizer is handled elsewhere).
    #[instrument(skip_all)]
    pub async fn from_config(config: &EmbedderConfig) -> Result<Self> {
        match &config.model_source {
            ModelSource::Local {
                model_path,
                tokenizer_path: _,
            } => Self::from_file(model_path, config).await,
            ModelSource::Pretrained(model) => Self::from_pretrained(*model, config).await,
            ModelSource::HuggingFace { model_id, revision } => {
                Self::from_huggingface(model_id, revision.as_deref(), config).await
            }
            ModelSource::Url {
                model_url,
                tokenizer_url: _,
            } => Self::from_url(model_url, config).await,
        }
    }

    /// Load a model from a local ONNX file.
    ///
    /// # Errors
    /// Returns `EmbeddingError::model_not_found` if `path` does not exist,
    /// or an ONNX/IO error if session creation fails.
    #[instrument(skip_all, fields(path = %path.as_ref().display()))]
    pub async fn from_file(path: impl AsRef<Path>, config: &EmbedderConfig) -> Result<Self> {
        let path = path.as_ref();
        info!("Loading ONNX model from file: {}", path.display());
        if !path.exists() {
            return Err(EmbeddingError::model_not_found(path.display().to_string()));
        }
        let file_size = fs::metadata(path)?.len();
        let session = Self::create_session(path, config)?;
        let info = Self::extract_model_info(&session, path, file_size)?;
        Ok(Self { session, info })
    }

    /// Load a pretrained model, downloading it into the cache if needed.
    #[instrument(skip_all, fields(model = ?model))]
    pub async fn from_pretrained(model: PretrainedModel, config: &EmbedderConfig) -> Result<Self> {
        let model_id = model.model_id();
        info!("Loading pretrained model: {}", model_id);
        // Check cache first; the cache directory mirrors the model id.
        let cache_path = config.cache_dir.join(sanitize_model_id(model_id));
        let model_path = cache_path.join("model.onnx");
        if model_path.exists() {
            debug!("Found cached model at {}", model_path.display());
            return Self::from_file(&model_path, config).await;
        }
        // Cache miss: fall through to the HuggingFace download path.
        Self::from_huggingface(model_id, None, config).await
    }

    /// Load a model from HuggingFace Hub (cached on disk after first download).
    #[instrument(skip_all, fields(model_id = %model_id))]
    pub async fn from_huggingface(
        model_id: &str,
        revision: Option<&str>,
        config: &EmbedderConfig,
    ) -> Result<Self> {
        let cache_path = config.cache_dir.join(sanitize_model_id(model_id));
        fs::create_dir_all(&cache_path)?;
        let model_path = cache_path.join("model.onnx");
        if !model_path.exists() {
            info!("Downloading model from HuggingFace: {}", model_id);
            download_from_huggingface(model_id, revision, &cache_path, config.show_progress)
                .await?;
        }
        Self::from_file(&model_path, config).await
    }

    /// Load a model from a URL, caching it under a hash of the URL.
    #[instrument(skip_all, fields(url = %url))]
    pub async fn from_url(url: &str, config: &EmbedderConfig) -> Result<Self> {
        let hash = hash_url(url);
        let cache_path = config.cache_dir.join(&hash);
        fs::create_dir_all(&cache_path)?;
        let model_path = cache_path.join("model.onnx");
        if !model_path.exists() {
            info!("Downloading model from URL: {}", url);
            download_file(url, &model_path, config.show_progress).await?;
        }
        Self::from_file(&model_path, config).await
    }

    /// Create an ONNX session with the specified configuration.
    ///
    /// Execution providers gated behind disabled cargo features fall back to
    /// CPU with a warning rather than erroring out.
    fn create_session(path: &Path, config: &EmbedderConfig) -> Result<Session> {
        let mut builder = Session::builder()?;
        // Graph optimization is all-or-nothing: Level3 or ORT's default.
        if config.optimize_graph {
            builder = builder.with_optimization_level(GraphOptimizationLevel::Level3)?;
        }
        builder = builder.with_intra_threads(config.num_threads)?;
        match config.execution_provider {
            ExecutionProvider::Cpu => {
                // Default CPU provider — nothing to configure.
            }
            #[cfg(feature = "cuda")]
            ExecutionProvider::Cuda { device_id } => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::CUDAExecutionProvider::default()
                        .with_device_id(device_id)
                        .build(),
                ])?;
            }
            #[cfg(feature = "tensorrt")]
            ExecutionProvider::TensorRt { device_id } => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::TensorRTExecutionProvider::default()
                        .with_device_id(device_id)
                        .build(),
                ])?;
            }
            #[cfg(feature = "coreml")]
            ExecutionProvider::CoreMl => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::CoreMLExecutionProvider::default().build(),
                ])?;
            }
            // Any provider variant whose feature is not compiled in.
            _ => {
                warn!(
                    "Requested execution provider not available, falling back to CPU"
                );
            }
        }
        let session = builder.commit_from_file(path)?;
        Ok(session)
    }

    /// Extract model information from the session.
    ///
    /// Input/output names come from the session; `dimension` and
    /// `max_seq_length` are hard-coded defaults (384/512), not read from the
    /// model graph — the real embedding width is observed at inference time.
    fn extract_model_info(session: &Session, path: &Path, file_size: u64) -> Result<ModelInfo> {
        let inputs: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
        let outputs: Vec<String> = session.outputs.iter().map(|o| o.name.clone()).collect();
        // Default embedding dimension (will be determined at runtime from actual output)
        // Most sentence-transformers models output 384 dimensions
        let dimension = 384;
        let name = path
            .file_stem()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_else(|| "unknown".to_string());
        Ok(ModelInfo {
            name,
            dimension,
            max_seq_length: 512,
            file_size,
            input_names: inputs,
            output_names: outputs,
        })
    }

    /// Run inference on encoded inputs.
    ///
    /// `input_ids`, `attention_mask` and `token_type_ids` are flattened
    /// `[batch, seq_len]` buffers; `shape` must be exactly `[batch, seq_len]`.
    ///
    /// Returns one `Vec<f32>` per batch item. For a 3-D output
    /// `[batch, seq_len, hidden]` the token-level activations are returned
    /// unpooled (`seq_len * hidden` values per item — pooling is the caller's
    /// job); for a 2-D output `[batch, hidden]` the rows are returned as-is.
    ///
    /// # Errors
    /// Returns `invalid_model` for a malformed `shape`, tensor-creation
    /// failures, unexpected output ranks, or an output tensor smaller than
    /// the requested batch implies.
    #[instrument(skip_all, fields(batch_size, seq_length))]
    pub fn run(
        &mut self,
        input_ids: &[i64],
        attention_mask: &[i64],
        token_type_ids: &[i64],
        shape: &[usize],
    ) -> Result<Vec<Vec<f32>>> {
        use ort::value::Tensor;
        // Guard before indexing: a short `shape` slice would otherwise panic.
        if shape.len() != 2 {
            return Err(EmbeddingError::invalid_model(format!(
                "Expected input shape [batch, seq_len], got {:?}",
                shape
            )));
        }
        let batch_size = shape[0];
        let seq_length = shape[1];
        debug!(
            "Running inference: batch_size={}, seq_length={}",
            batch_size, seq_length
        );
        // Create input tensors using ort's Tensor type.
        let input_ids_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            input_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        let attention_mask_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            attention_mask.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        let token_type_ids_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            token_type_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        let inputs = vec![
            ("input_ids", input_ids_tensor.into_dyn()),
            ("attention_mask", attention_mask_tensor.into_dyn()),
            ("token_type_ids", token_type_ids_tensor.into_dyn()),
        ];
        // Run inference.
        let outputs = self.session.run(inputs)
            .map_err(EmbeddingError::OnnxRuntime)?;
        // Extract the output tensor — usually [batch, seq_len, hidden] or
        // [batch, hidden]. Prefer well-known names, fall back to the first.
        let output_names = ["last_hidden_state", "output", "sentence_embedding"];
        let output_iter: Vec<_> = outputs.iter().collect();
        let output = output_iter
            .iter()
            .find(|(name, _)| output_names.contains(name))
            .or_else(|| output_iter.first())
            .map(|(_, v)| v)
            .ok_or_else(|| EmbeddingError::invalid_model("No output tensor found"))?;
        // In ort 2.0, try_extract_tensor returns (&Shape, &[f32]).
        let (tensor_shape, tensor_data) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        // Convert Shape to Vec<usize> — Shape yields i64.
        let dims: Vec<usize> = tensor_shape.iter().map(|&d| d as usize).collect();
        // Validate the flat buffer is large enough before slicing so that a
        // model whose output disagrees with the input shape produces an error
        // instead of a panic.
        let embeddings = if dims.len() == 3 {
            // [batch, seq_len, hidden] — returned unpooled, one flat
            // seq_len*hidden buffer per item. NOTE(review): assumes the
            // output seq_len equals the input seq_length — confirm for
            // models that truncate or pad internally.
            let hidden_size = dims[2];
            let per_item = seq_length * hidden_size;
            if tensor_data.len() < batch_size * per_item {
                return Err(EmbeddingError::invalid_model(format!(
                    "Output tensor has {} values, expected at least {} for shape {:?}",
                    tensor_data.len(),
                    batch_size * per_item,
                    dims
                )));
            }
            (0..batch_size)
                .map(|i| {
                    let start = i * per_item;
                    tensor_data[start..start + per_item].to_vec()
                })
                .collect()
        } else if dims.len() == 2 {
            // [batch, hidden] — already pooled.
            let hidden_size = dims[1];
            if tensor_data.len() < batch_size * hidden_size {
                return Err(EmbeddingError::invalid_model(format!(
                    "Output tensor has {} values, expected at least {} for shape {:?}",
                    tensor_data.len(),
                    batch_size * hidden_size,
                    dims
                )));
            }
            (0..batch_size)
                .map(|i| {
                    let start = i * hidden_size;
                    tensor_data[start..start + hidden_size].to_vec()
                })
                .collect()
        } else {
            return Err(EmbeddingError::invalid_model(format!(
                "Unexpected output shape: {:?}",
                dims
            )));
        };
        Ok(embeddings)
    }

    /// Get model info.
    pub fn info(&self) -> &ModelInfo {
        &self.info
    }

    /// Get embedding dimension (default from load time; see [`ModelInfo`]).
    pub fn dimension(&self) -> usize {
        self.info.dimension
    }
}
/// Download model files from HuggingFace Hub into `cache_path`.
///
/// `model.onnx` is tried first at the repository root, then under the
/// `onnx/` subfolder (the sentence-transformers convention); failure of
/// both is a hard error. The auxiliary files (`tokenizer.json`,
/// `config.json`) follow the same two-location search but are optional —
/// missing ones only log a warning.
async fn download_from_huggingface(
    model_id: &str,
    revision: Option<&str>,
    cache_path: &Path,
    show_progress: bool,
) -> Result<()> {
    let revision = revision.unwrap_or("main");
    let base_url = format!(
        "https://huggingface.co/{}/resolve/{}",
        model_id, revision
    );
    let model_path = cache_path.join("model.onnx");
    if !model_path.exists() {
        // Location 1: repository root (model.onnx).
        let root_url = format!("{}/model.onnx", base_url);
        debug!("Trying to download model from root: {}", root_url);
        match download_file(&root_url, &model_path, show_progress).await {
            Ok(_) => debug!("Downloaded model.onnx from root"),
            Err(_) if model_path.exists() => {
                // Root reported an error yet the file appeared; treat as
                // success, matching the original graceful-recovery behavior.
            }
            Err(_) => {
                // Location 2: onnx/ subfolder — common for sentence-transformers.
                let onnx_url = format!("{}/onnx/model.onnx", base_url);
                debug!("Root download failed, trying onnx subfolder: {}", onnx_url);
                if let Err(e) = download_file(&onnx_url, &model_path, show_progress).await {
                    // Both locations failed.
                    return Err(EmbeddingError::download_failed(format!(
                        "Failed to download model.onnx from {} - tried both root and onnx/ subfolder: {}",
                        model_id, e
                    )));
                }
                debug!("Downloaded model.onnx from onnx/ subfolder");
            }
        }
    }
    // Auxiliary files (tokenizer.json, config.json) are optional.
    for file in ["tokenizer.json", "config.json"] {
        let dest = cache_path.join(file);
        if dest.exists() {
            continue;
        }
        // Root first, then the onnx/ subfolder.
        let root_url = format!("{}/{}", base_url, file);
        if download_file(&root_url, &dest, show_progress).await.is_ok() {
            debug!("Downloaded {}", file);
            continue;
        }
        let onnx_url = format!("{}/onnx/{}", base_url, file);
        match download_file(&onnx_url, &dest, show_progress).await {
            Ok(_) => debug!("Downloaded {} from onnx/ subfolder", file),
            Err(e) => warn!("Failed to download {} (optional): {}", file, e),
        }
    }
    Ok(())
}
/// Download a file from `url` to `path` with an optional progress bar.
///
/// The body is streamed into a sibling `<name>.part` temp file which is
/// renamed into place only once fully received. This fixes a cache-poisoning
/// bug: previously a mid-stream failure left a truncated file at `path`,
/// which callers (`from_url`, `from_huggingface`) then accepted as a cache
/// hit via `path.exists()`.
///
/// # Errors
/// Returns `download_failed` for non-success HTTP status, or propagates
/// network/IO errors; the temp file is removed on mid-stream failure.
async fn download_file(url: &str, path: &Path, show_progress: bool) -> Result<()> {
    let client = reqwest::Client::new();
    let response = client.get(url).send().await?;
    if !response.status().is_success() {
        return Err(EmbeddingError::download_failed(format!(
            "HTTP {}: {}",
            response.status(),
            url
        )));
    }
    // content_length is a hint; when absent (0) the progress bar is disabled.
    let total_size = response.content_length().unwrap_or(0);
    let pb = if show_progress && total_size > 0 {
        let pb = ProgressBar::new(total_size);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );
        Some(pb)
    } else {
        None
    };
    // Stream into a temp file so `path` only ever holds a complete download.
    let tmp_name = path
        .file_name()
        .map(|n| format!("{}.part", n.to_string_lossy()))
        .unwrap_or_else(|| "download.part".to_string());
    let tmp_path = path.with_file_name(tmp_name);
    let mut file = fs::File::create(&tmp_path)?;
    let mut downloaded = 0u64;
    use futures_util::StreamExt;
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        // Clean up the partial temp file on any mid-stream failure.
        let chunk = match chunk {
            Ok(chunk) => chunk,
            Err(e) => {
                let _ = fs::remove_file(&tmp_path);
                return Err(e.into());
            }
        };
        if let Err(e) = file.write_all(&chunk) {
            let _ = fs::remove_file(&tmp_path);
            return Err(e.into());
        }
        downloaded += chunk.len() as u64;
        if let Some(ref pb) = pb {
            pb.set_position(downloaded);
        }
    }
    // Flush explicitly (Drop swallows flush errors), then move into place.
    file.flush()?;
    drop(file);
    fs::rename(&tmp_path, path)?;
    if let Some(pb) = pb {
        pb.finish_with_message("Downloaded");
    }
    Ok(())
}
/// Sanitize a model ID so it can safely be used as a cache directory name.
///
/// Path separators (`/`, `\`) and the drive/namespace separator (`:`) are
/// mapped to underscores; every other character is passed through unchanged.
fn sanitize_model_id(model_id: &str) -> String {
    model_id
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' => '_',
            other => other,
        })
        .collect()
}
/// Derive a short, stable cache key from a URL.
///
/// Returns the first 8 bytes of the URL's SHA-256 digest, hex-encoded
/// (16 characters).
fn hash_url(url: &str) -> String {
    let digest = Sha256::digest(url.as_bytes());
    hex::encode(&digest[..8])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `/`, `\` and `:` must all collapse to `_` in cache directory names.
    #[test]
    fn test_sanitize_model_id() {
        let sanitized = sanitize_model_id("sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(sanitized, "sentence-transformers_all-MiniLM-L6-v2");
    }

    /// The URL hash is 8 bytes of SHA-256, hex-encoded → 16 characters.
    #[test]
    fn test_hash_url() {
        assert_eq!(hash_url("https://example.com/model.onnx").len(), 16);
    }
}