// NOTE(review): the lines below are a copy/paste artifact from a file-browser
// listing, not Rust source; commented out so the file compiles. The stated
// line count does not match the actual file and should not be trusted.
// Files
// wifi-densepose/vendor/ruvector/examples/onnx-embeddings/src/embedder.rs
// 472 lines, 14 KiB, Rust
//! Main embedder implementation combining model, tokenizer, and pooling
use crate::config::{EmbedderConfig, ModelSource, PoolingStrategy};
use crate::model::OnnxModel;
use crate::pooling::Pooler;
use crate::tokenizer::Tokenizer;
use crate::{EmbeddingError, PretrainedModel, Result};
use std::path::Path;
use tracing::{debug, info, instrument};
#[cfg(feature = "gpu")]
use crate::gpu::{GpuAccelerator, GpuConfig};
/// High-level embedder combining tokenizer, model, and pooling.
///
/// Owns the full text-to-vector pipeline: tokenization, ONNX inference,
/// and pooling of token-level outputs into one embedding per input.
pub struct Embedder {
    /// ONNX model for inference
    model: OnnxModel,
    /// Tokenizer for text processing
    tokenizer: Tokenizer,
    /// Pooler for combining token embeddings into a single vector
    pooler: Pooler,
    /// Configuration (batch size, pooling strategy, max length, model source)
    config: EmbedderConfig,
    /// Optional GPU accelerator for similarity operations; `None` means
    /// all similarity work runs on the CPU.
    #[cfg(feature = "gpu")]
    gpu: Option<GpuAccelerator>,
}
/// Embedding output with metadata.
///
/// All vectors in `embeddings` have length `dimension`; `texts`,
/// `token_counts`, and `embeddings` are index-aligned per input.
#[derive(Debug, Clone)]
pub struct EmbeddingOutput {
    /// The embedding vectors, one per input text
    pub embeddings: Vec<Vec<f32>>,
    /// Original input texts, in the order they were embedded
    pub texts: Vec<String>,
    /// Number of tokens per input (pre-truncation lengths)
    pub token_counts: Vec<usize>,
    /// Embedding dimension (length of each vector in `embeddings`)
    pub dimension: usize,
}
impl EmbeddingOutput {
    /// Number of embeddings contained in this output.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Returns `true` when no embeddings are present.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrow the embedding at `index`, or `None` if out of range.
    pub fn get(&self, index: usize) -> Option<&Vec<f32>> {
        self.embeddings.get(index)
    }

    /// Iterate over the embedding vectors by reference.
    pub fn iter(&self) -> impl Iterator<Item = &Vec<f32>> {
        self.embeddings.iter()
    }

    /// Consume `self`, yielding only the embedding vectors.
    pub fn into_vecs(self) -> Vec<Vec<f32>> {
        self.embeddings
    }
}
impl Embedder {
    /// Create a new embedder from configuration.
    ///
    /// Loads the ONNX model, resolves a tokenizer for the configured model
    /// source (local file, pretrained id, HuggingFace id, or URL), builds
    /// the pooler, and — with the `gpu` feature — attempts to initialize a
    /// GPU accelerator, silently falling back to CPU when unavailable.
    ///
    /// # Errors
    /// Fails when the model cannot be loaded, the tokenizer cannot be
    /// resolved, or a URL-sourced tokenizer cannot be downloaded.
    #[instrument(skip_all)]
    pub async fn new(config: EmbedderConfig) -> Result<Self> {
        info!("Initializing embedder");
        // Load model
        let model = OnnxModel::from_config(&config).await?;
        // Load tokenizer based on model source
        let tokenizer = match &config.model_source {
            ModelSource::Local {
                tokenizer_path, ..
            } => Tokenizer::from_file(tokenizer_path, config.max_length)?,
            ModelSource::Pretrained(pretrained) => {
                Tokenizer::from_pretrained(pretrained.model_id(), config.max_length)?
            }
            ModelSource::HuggingFace { model_id, .. } => {
                Tokenizer::from_pretrained(model_id, config.max_length)?
            }
            ModelSource::Url { tokenizer_url, .. } => {
                // Download the tokenizer once; later runs reuse the cached copy.
                let cache_path = config.cache_dir.join("tokenizer.json");
                if !cache_path.exists() {
                    download_tokenizer(tokenizer_url, &cache_path).await?;
                }
                Tokenizer::from_file(&cache_path, config.max_length)?
            }
        };
        let pooler = Pooler::new(config.pooling, config.normalize);
        // GPU setup is best-effort: any failure just means CPU-only operation.
        #[cfg(feature = "gpu")]
        let gpu = {
            match GpuAccelerator::new(GpuConfig::auto()).await {
                Ok(accel) => {
                    info!("GPU accelerator initialized: {}", accel.device_info().name);
                    Some(accel)
                }
                Err(e) => {
                    debug!("GPU not available, using CPU: {}", e);
                    None
                }
            }
        };
        Ok(Self {
            model,
            tokenizer,
            pooler,
            config,
            #[cfg(feature = "gpu")]
            gpu,
        })
    }

    /// Create embedder with the default model (all-MiniLM-L6-v2).
    pub async fn default_model() -> Result<Self> {
        Self::new(EmbedderConfig::default()).await
    }

    /// Create embedder for a specific pretrained model.
    pub async fn pretrained(model: PretrainedModel) -> Result<Self> {
        Self::new(EmbedderConfig::pretrained(model)).await
    }

    /// Embed a single text, returning one vector.
    ///
    /// # Errors
    /// Returns [`EmbeddingError::EmptyInput`] if no embedding is produced,
    /// or any tokenization/inference error.
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
        let output = self.embed(&[text])?;
        output
            .embeddings
            .into_iter()
            .next()
            .ok_or(EmbeddingError::EmptyInput)
    }

    /// Embed multiple texts, processing them in configured batch sizes.
    ///
    /// # Errors
    /// Returns [`EmbeddingError::EmptyInput`] when `texts` is empty, or any
    /// tokenization/inference error from a batch.
    #[instrument(skip(self, texts), fields(batch_size = texts.len()))]
    pub fn embed<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
        if texts.is_empty() {
            return Err(EmbeddingError::EmptyInput);
        }
        let texts_owned: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
        // Process in fixed-size batches to bound peak memory during inference.
        let batch_size = self.config.batch_size;
        let mut all_embeddings = Vec::with_capacity(texts.len());
        let mut all_token_counts = Vec::with_capacity(texts.len());
        for chunk in texts.chunks(batch_size) {
            let (embeddings, token_counts) = self.embed_batch(chunk)?;
            all_embeddings.extend(embeddings);
            all_token_counts.extend(token_counts);
        }
        Ok(EmbeddingOutput {
            embeddings: all_embeddings,
            texts: texts_owned,
            token_counts: all_token_counts,
            dimension: self.model.dimension(),
        })
    }

    /// Embed one batch: tokenize, run the model, pool token embeddings.
    fn embed_batch<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<(Vec<Vec<f32>>, Vec<usize>)> {
        debug!("Embedding batch of {} texts", texts.len());
        // Tokenize
        let encoded = self.tokenizer.encode_batch(texts)?;
        let (input_ids, attention_mask, token_type_ids, shape) = encoded.to_onnx_inputs();
        // Run model
        let token_embeddings = self
            .model
            .run(&input_ids, &attention_mask, &token_type_ids, &shape)?;
        let seq_length = shape[1];
        let hidden_size = self.model.dimension();
        // Pool token-level embeddings into one vector per input, masking padding.
        let attention_masks: Vec<Vec<i64>> = encoded.attention_mask;
        let embeddings = self
            .pooler
            .pool(&token_embeddings, &attention_masks, seq_length, hidden_size);
        let token_counts = encoded.original_lengths;
        Ok((embeddings, token_counts))
    }

    /// Embed texts (sequential processing).
    /// Note: for parallel processing, consider using tokio::spawn with
    /// multiple Embedder instances — the ONNX session requires `&mut self`.
    #[instrument(skip(self, texts), fields(total_texts = texts.len()))]
    pub fn embed_parallel<S: AsRef<str> + Sync>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
        // Use sequential processing since ONNX session requires mutable access
        self.embed(texts)
    }

    /// Process texts one at a time, collecting a per-text `Result`.
    /// Prefer [`Embedder::embed`] for batch processing.
    pub fn embed_each<S: AsRef<str>>(&mut self, texts: &[S]) -> Vec<Result<Vec<f32>>> {
        texts.iter().map(|text| self.embed_one(text.as_ref())).collect()
    }

    /// Get the embedding dimension.
    pub fn dimension(&self) -> usize {
        self.model.dimension()
    }

    /// Get model info.
    pub fn model_info(&self) -> &crate::model::ModelInfo {
        self.model.info()
    }

    /// Get the pooling strategy.
    pub fn pooling_strategy(&self) -> PoolingStrategy {
        self.config.pooling
    }

    /// Get max sequence length.
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Compute cosine similarity between the embeddings of two texts.
    pub fn similarity(&mut self, text1: &str, text2: &str) -> Result<f32> {
        let emb1 = self.embed_one(text1)?;
        let emb2 = self.embed_one(text2)?;
        Ok(Pooler::cosine_similarity(&emb1, &emb2))
    }

    /// Find the `top_k` corpus entries most similar to `query`.
    ///
    /// Uses GPU acceleration when available and the corpus is large enough
    /// (>= 64 entries); otherwise — or if the GPU call fails — falls back
    /// to a CPU scan. Returns `(corpus_index, similarity, text)` triples,
    /// best match first.
    #[instrument(skip(self, query, corpus), fields(corpus_size = corpus.len()))]
    pub fn most_similar<S: AsRef<str>>(
        &mut self,
        query: &str,
        corpus: &[S],
        top_k: usize,
    ) -> Result<Vec<(usize, f32, String)>> {
        let query_emb = self.embed_one(query)?;
        let corpus_embs = self.embed(corpus)?;
        // Try GPU-accelerated similarity if available
        #[cfg(feature = "gpu")]
        if let Some(ref gpu) = self.gpu {
            if corpus.len() >= 64 {
                let candidates: Vec<&[f32]> =
                    corpus_embs.embeddings.iter().map(|v| v.as_slice()).collect();
                if let Ok(results) = gpu.top_k_similar(&query_emb, &candidates, top_k) {
                    return Ok(results
                        .into_iter()
                        .map(|(idx, score)| (idx, score, corpus[idx].as_ref().to_string()))
                        .collect());
                }
            }
        }
        // CPU fallback
        let mut similarities: Vec<(usize, f32, String)> = corpus_embs
            .embeddings
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let sim = Pooler::cosine_similarity(&query_emb, emb);
                (i, sim, corpus[i].as_ref().to_string())
            })
            .collect();
        // total_cmp is NaN-safe: the previous partial_cmp().unwrap() would
        // panic if any similarity came out NaN (e.g. a zero-norm embedding).
        similarities.sort_by(|a, b| b.1.total_cmp(&a.1));
        similarities.truncate(top_k);
        Ok(similarities)
    }

    /// Check if GPU acceleration is available.
    pub fn has_gpu(&self) -> bool {
        #[cfg(feature = "gpu")]
        {
            self.gpu.is_some()
        }
        #[cfg(not(feature = "gpu"))]
        {
            false
        }
    }

    /// Get GPU device info if available.
    #[cfg(feature = "gpu")]
    pub fn gpu_info(&self) -> Option<crate::gpu::GpuInfo> {
        self.gpu.as_ref().map(|g| g.device_info())
    }

    /// Cluster texts by embedding similarity (simple k-means-like approach).
    ///
    /// Centroids are seeded from the first `n_clusters` embeddings and
    /// refined for at most 100 iterations or until assignments stabilize.
    /// An empty cluster keeps its previous centroid. With `n_clusters == 0`
    /// every text receives the degenerate label 0.
    ///
    /// Returns one cluster index per input text, aligned with `texts`.
    ///
    /// # Errors
    /// Propagates any embedding error (including empty `texts`).
    #[instrument(skip(self, texts), fields(n_texts = texts.len(), n_clusters))]
    pub fn cluster<S: AsRef<str>>(
        &mut self,
        texts: &[S],
        n_clusters: usize,
    ) -> Result<Vec<usize>> {
        let embeddings = self.embed(texts)?;
        let dim = self.dimension();
        // Initialize centroids with the first k embeddings.
        let mut centroids: Vec<Vec<f32>> = embeddings
            .embeddings
            .iter()
            .take(n_clusters)
            .cloned()
            .collect();
        let mut assignments = vec![0usize; texts.len()];
        let max_iterations = 100;
        for _ in 0..max_iterations {
            let old_assignments = assignments.clone();
            // Assign each embedding to its nearest centroid (Euclidean).
            for (i, emb) in embeddings.embeddings.iter().enumerate() {
                let mut min_dist = f32::MAX;
                let mut min_idx = 0;
                for (j, centroid) in centroids.iter().enumerate() {
                    let dist = Pooler::euclidean_distance(emb, centroid);
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = j;
                    }
                }
                assignments[i] = min_idx;
            }
            // Converged once no assignment changed.
            if assignments == old_assignments {
                break;
            }
            // Recompute each centroid as the mean of its member embeddings.
            for (j, centroid) in centroids.iter_mut().enumerate() {
                let cluster_points: Vec<&Vec<f32>> = embeddings
                    .embeddings
                    .iter()
                    .zip(assignments.iter())
                    .filter(|(_, &a)| a == j)
                    .map(|(e, _)| e)
                    .collect();
                if !cluster_points.is_empty() {
                    *centroid = vec![0.0; dim];
                    for point in &cluster_points {
                        for (k, &val) in point.iter().enumerate() {
                            centroid[k] += val;
                        }
                    }
                    let count = cluster_points.len() as f32;
                    for val in centroid.iter_mut() {
                        *val /= count;
                    }
                }
            }
        }
        Ok(assignments)
    }
}
/// Download a tokenizer file from `url` and persist it at `path`.
///
/// Creates any missing parent directories first, so a freshly-configured
/// cache directory does not make the write fail.
///
/// # Errors
/// Returns an error when the HTTP request fails, the server responds with
/// a non-success status, or the file cannot be written.
async fn download_tokenizer(url: &str, path: &Path) -> Result<()> {
    let response = reqwest::get(url).await?;
    if !response.status().is_success() {
        return Err(EmbeddingError::download_failed(format!(
            "Failed to download tokenizer: HTTP {}",
            response.status()
        )));
    }
    let bytes = response.bytes().await?;
    // File::create does not create parent directories; ensure they exist.
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    // fs::write creates/truncates and writes in one call.
    std::fs::write(path, &bytes)?;
    Ok(())
}
/// Builder for creating embedders with custom configurations.
///
/// Wraps an [`EmbedderConfig`] and exposes chainable setters; call
/// [`EmbedderBuilder::build`] to construct the [`Embedder`].
pub struct EmbedderBuilder {
    // Accumulated configuration, consumed by `build()`.
    config: EmbedderConfig,
}
impl EmbedderBuilder {
/// Start building an embedder
pub fn new() -> Self {
Self {
config: EmbedderConfig::default(),
}
}
/// Use a pretrained model
pub fn pretrained(mut self, model: PretrainedModel) -> Self {
self.config = EmbedderConfig::pretrained(model);
self
}
/// Set pooling strategy
pub fn pooling(mut self, strategy: PoolingStrategy) -> Self {
self.config.pooling = strategy;
self
}
/// Set normalization
pub fn normalize(mut self, normalize: bool) -> Self {
self.config.normalize = normalize;
self
}
/// Set batch size
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_size = size;
self
}
/// Set max sequence length
pub fn max_length(mut self, length: usize) -> Self {
self.config.max_length = length;
self
}
/// Build the embedder
pub async fn build(self) -> Result<Embedder> {
Embedder::new(self.config).await
}
}
impl Default for EmbedderBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration should enable normalization and use
    /// mean pooling.
    #[test]
    fn test_default_config() {
        let config = EmbedderConfig::default();
        assert!(config.normalize);
        assert_eq!(config.pooling, PoolingStrategy::Mean);
    }
}