Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
vendor/ruvector/examples/onnx-embeddings/src/embedder.rs (new vendored file, 471 lines added)
@@ -0,0 +1,471 @@
//! Main embedder implementation combining model, tokenizer, and pooling

use crate::config::{EmbedderConfig, ModelSource, PoolingStrategy};
use crate::model::OnnxModel;
use crate::pooling::Pooler;
use crate::tokenizer::Tokenizer;
use crate::{EmbeddingError, PretrainedModel, Result};
use std::path::Path;
use tracing::{debug, info, instrument};

#[cfg(feature = "gpu")]
use crate::gpu::{GpuAccelerator, GpuConfig};

/// High-level embedder combining tokenizer, model, and pooling
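///
/// A minimal usage sketch (assumes an async runtime such as tokio and, on first
/// use, network access to fetch the default model; marked `ignore` so it is not
/// compiled as a doctest):
///
/// ```ignore
/// let mut embedder = Embedder::default_model().await?;
/// let vector = embedder.embed_one("hello world")?;
/// assert_eq!(vector.len(), embedder.dimension());
/// ```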
pub struct Embedder {
    /// ONNX model for inference
    model: OnnxModel,
    /// Tokenizer for text processing
    tokenizer: Tokenizer,
    /// Pooler for combining token embeddings
    pooler: Pooler,
    /// Configuration
    config: EmbedderConfig,
    /// Optional GPU accelerator for similarity operations
    #[cfg(feature = "gpu")]
    gpu: Option<GpuAccelerator>,
}

/// Embedding output with metadata
#[derive(Debug, Clone)]
pub struct EmbeddingOutput {
    /// The embedding vectors
    pub embeddings: Vec<Vec<f32>>,
    /// Original input texts
    pub texts: Vec<String>,
    /// Number of tokens per input
    pub token_counts: Vec<usize>,
    /// Embedding dimension
    pub dimension: usize,
}

impl EmbeddingOutput {
    /// Get the number of embeddings
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Get a single embedding by index
    pub fn get(&self, index: usize) -> Option<&Vec<f32>> {
        self.embeddings.get(index)
    }

    /// Iterate over embeddings
    pub fn iter(&self) -> impl Iterator<Item = &Vec<f32>> {
        self.embeddings.iter()
    }

    /// Convert to owned vectors
    pub fn into_vecs(self) -> Vec<Vec<f32>> {
        self.embeddings
    }
}

impl Embedder {
    /// Create a new embedder from configuration
    #[instrument(skip_all)]
    pub async fn new(config: EmbedderConfig) -> Result<Self> {
        info!("Initializing embedder");

        // Load model
        let model = OnnxModel::from_config(&config).await?;

        // Load tokenizer based on model source
        let tokenizer = match &config.model_source {
            ModelSource::Local {
                tokenizer_path, ..
            } => Tokenizer::from_file(tokenizer_path, config.max_length)?,

            ModelSource::Pretrained(pretrained) => {
                Tokenizer::from_pretrained(pretrained.model_id(), config.max_length)?
            }

            ModelSource::HuggingFace { model_id, .. } => {
                Tokenizer::from_pretrained(model_id, config.max_length)?
            }

            ModelSource::Url { tokenizer_url, .. } => {
                // Download tokenizer
                let cache_path = config.cache_dir.join("tokenizer.json");
                if !cache_path.exists() {
                    download_tokenizer(tokenizer_url, &cache_path).await?;
                }
                Tokenizer::from_file(&cache_path, config.max_length)?
            }
        };

        let pooler = Pooler::new(config.pooling, config.normalize);

        // Initialize GPU accelerator if available
        #[cfg(feature = "gpu")]
        let gpu = {
            match GpuAccelerator::new(GpuConfig::auto()).await {
                Ok(accel) => {
                    info!("GPU accelerator initialized: {}", accel.device_info().name);
                    Some(accel)
                }
                Err(e) => {
                    debug!("GPU not available, using CPU: {}", e);
                    None
                }
            }
        };

        Ok(Self {
            model,
            tokenizer,
            pooler,
            config,
            #[cfg(feature = "gpu")]
            gpu,
        })
    }

    /// Create embedder with default model (all-MiniLM-L6-v2)
    pub async fn default_model() -> Result<Self> {
        Self::new(EmbedderConfig::default()).await
    }

    /// Create embedder for a specific pretrained model
    pub async fn pretrained(model: PretrainedModel) -> Result<Self> {
        Self::new(EmbedderConfig::pretrained(model)).await
    }

    /// Embed a single text
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
        let output = self.embed(&[text])?;
        output
            .embeddings
            .into_iter()
            .next()
            .ok_or(EmbeddingError::EmptyInput)
    }

    /// Embed multiple texts
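    ///
    /// A short sketch of batch embedding (assumes an already constructed
    /// `Embedder`; marked `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let output = embedder.embed(&["first text", "second text"])?;
    /// assert_eq!(output.len(), 2);
    /// assert_eq!(output.dimension, embedder.dimension());
    /// // Per-input token counts are reported alongside the vectors.
    /// for (count, emb) in output.token_counts.iter().zip(output.iter()) {
    ///     println!("{count} tokens -> {} dims", emb.len());
    /// }
    /// ```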
    #[instrument(skip(self, texts), fields(batch_size = texts.len()))]
    pub fn embed<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
        if texts.is_empty() {
            return Err(EmbeddingError::EmptyInput);
        }

        let texts_owned: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();

        // Process in batches
        let batch_size = self.config.batch_size;
        let mut all_embeddings = Vec::with_capacity(texts.len());
        let mut all_token_counts = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(batch_size) {
            let (embeddings, token_counts) = self.embed_batch(chunk)?;
            all_embeddings.extend(embeddings);
            all_token_counts.extend(token_counts);
        }

        Ok(EmbeddingOutput {
            embeddings: all_embeddings,
            texts: texts_owned,
            token_counts: all_token_counts,
            dimension: self.model.dimension(),
        })
    }

    /// Embed a batch of texts (internal)
    fn embed_batch<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<(Vec<Vec<f32>>, Vec<usize>)> {
        debug!("Embedding batch of {} texts", texts.len());

        // Tokenize
        let encoded = self.tokenizer.encode_batch(texts)?;
        let (input_ids, attention_mask, token_type_ids, shape) = encoded.to_onnx_inputs();

        // Run model
        let token_embeddings = self.model.run(
            &input_ids,
            &attention_mask,
            &token_type_ids,
            &shape,
        )?;

        let seq_length = shape[1];
        let hidden_size = self.model.dimension();

        // Pool embeddings
        let attention_masks: Vec<Vec<i64>> = encoded.attention_mask;
        let embeddings = self.pooler.pool(
            &token_embeddings,
            &attention_masks,
            seq_length,
            hidden_size,
        );

        let token_counts = encoded.original_lengths;

        Ok((embeddings, token_counts))
    }

    /// Embed texts (sequential processing)
    /// Note: For parallel processing, consider using tokio::spawn with multiple Embedder instances
    #[instrument(skip(self, texts), fields(total_texts = texts.len()))]
    pub fn embed_parallel<S: AsRef<str> + Sync>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
        // Use sequential processing since ONNX session requires mutable access
        self.embed(texts)
    }

    /// Process texts one at a time (use embed for batch processing)
    pub fn embed_each<S: AsRef<str>>(&mut self, texts: &[S]) -> Vec<Result<Vec<f32>>> {
        texts.iter().map(|text| self.embed_one(text.as_ref())).collect()
    }

    /// Get the embedding dimension
    pub fn dimension(&self) -> usize {
        self.model.dimension()
    }

    /// Get model info
    pub fn model_info(&self) -> &crate::model::ModelInfo {
        self.model.info()
    }

    /// Get the pooling strategy
    pub fn pooling_strategy(&self) -> PoolingStrategy {
        self.config.pooling
    }

    /// Get max sequence length
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Compute similarity between two texts
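    ///
    /// A minimal sketch of comparing two texts (assumes an existing `Embedder`;
    /// marked `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let score = embedder.similarity("a cat sat on the mat", "a feline rested on the rug")?;
    /// assert!((-1.0..=1.0).contains(&score));
    /// ```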
    pub fn similarity(&mut self, text1: &str, text2: &str) -> Result<f32> {
        let emb1 = self.embed_one(text1)?;
        let emb2 = self.embed_one(text2)?;
        Ok(Pooler::cosine_similarity(&emb1, &emb2))
    }

    /// Find most similar texts from a corpus
    /// Uses GPU acceleration when available and corpus is large enough
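    ///
    /// Returned tuples are `(corpus_index, similarity_score, text)`. A usage sketch
    /// (assumes an existing `Embedder`; marked `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let corpus = ["rust is fast", "cats purr", "ownership and borrowing"];
    /// let hits = embedder.most_similar("memory safety in rust", &corpus, 2)?;
    /// for (idx, score, text) in hits {
    ///     println!("#{idx} ({score:.3}): {text}");
    /// }
    /// ```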
    #[instrument(skip(self, query, corpus), fields(corpus_size = corpus.len()))]
    pub fn most_similar<S: AsRef<str>>(
        &mut self,
        query: &str,
        corpus: &[S],
        top_k: usize,
    ) -> Result<Vec<(usize, f32, String)>> {
        let query_emb = self.embed_one(query)?;
        let corpus_embs = self.embed(corpus)?;

        // Try GPU-accelerated similarity if available
        #[cfg(feature = "gpu")]
        if let Some(ref gpu) = self.gpu {
            if corpus.len() >= 64 {
                let candidates: Vec<&[f32]> = corpus_embs.embeddings.iter().map(|v| v.as_slice()).collect();
                if let Ok(results) = gpu.top_k_similar(&query_emb, &candidates, top_k) {
                    return Ok(results
                        .into_iter()
                        .map(|(idx, score)| (idx, score, corpus[idx].as_ref().to_string()))
                        .collect());
                }
            }
        }

        // CPU fallback
        let mut similarities: Vec<(usize, f32, String)> = corpus_embs
            .embeddings
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let sim = Pooler::cosine_similarity(&query_emb, emb);
                (i, sim, corpus[i].as_ref().to_string())
            })
            .collect();

        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        similarities.truncate(top_k);

        Ok(similarities)
    }

    /// Check if GPU acceleration is available
    pub fn has_gpu(&self) -> bool {
        #[cfg(feature = "gpu")]
        {
            self.gpu.is_some()
        }
        #[cfg(not(feature = "gpu"))]
        {
            false
        }
    }

    /// Get GPU device info if available
    #[cfg(feature = "gpu")]
    pub fn gpu_info(&self) -> Option<crate::gpu::GpuInfo> {
        self.gpu.as_ref().map(|g| g.device_info())
    }

    /// Cluster texts by similarity (simple k-means-like approach)
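    ///
    /// Assignments are returned as one cluster index per input text. A small sketch
    /// (assumes an existing `Embedder`; marked `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let texts = ["dog", "puppy", "car", "truck"];
    /// let assignments = embedder.cluster(&texts, 2)?;
    /// assert_eq!(assignments.len(), texts.len());
    /// ```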
    #[instrument(skip(self, texts), fields(n_texts = texts.len(), n_clusters))]
    pub fn cluster<S: AsRef<str>>(
        &mut self,
        texts: &[S],
        n_clusters: usize,
    ) -> Result<Vec<usize>> {
        let embeddings = self.embed(texts)?;
        let dim = self.dimension();

        // Initialize centroids with first k embeddings
        let mut centroids: Vec<Vec<f32>> = embeddings
            .embeddings
            .iter()
            .take(n_clusters)
            .cloned()
            .collect();

        let mut assignments = vec![0usize; texts.len()];
        let max_iterations = 100;

        for _ in 0..max_iterations {
            let old_assignments = assignments.clone();

            // Assign to nearest centroid
            for (i, emb) in embeddings.embeddings.iter().enumerate() {
                let mut min_dist = f32::MAX;
                let mut min_idx = 0;

                for (j, centroid) in centroids.iter().enumerate() {
                    let dist = Pooler::euclidean_distance(emb, centroid);
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = j;
                    }
                }

                assignments[i] = min_idx;
            }

            // Check convergence
            if assignments == old_assignments {
                break;
            }

            // Update centroids
            for (j, centroid) in centroids.iter_mut().enumerate() {
                let cluster_points: Vec<&Vec<f32>> = embeddings
                    .embeddings
                    .iter()
                    .zip(assignments.iter())
                    .filter(|(_, &a)| a == j)
                    .map(|(e, _)| e)
                    .collect();

                if !cluster_points.is_empty() {
                    *centroid = vec![0.0; dim];
                    for point in &cluster_points {
                        for (k, &val) in point.iter().enumerate() {
                            centroid[k] += val;
                        }
                    }
                    let count = cluster_points.len() as f32;
                    for val in centroid.iter_mut() {
                        *val /= count;
                    }
                }
            }
        }

        Ok(assignments)
    }
}

/// Download tokenizer from URL
async fn download_tokenizer(url: &str, path: &Path) -> Result<()> {
    use std::io::Write;

    let response = reqwest::get(url).await?;

    if !response.status().is_success() {
        return Err(EmbeddingError::download_failed(format!(
            "Failed to download tokenizer: HTTP {}",
            response.status()
        )));
    }

    let bytes = response.bytes().await?;
    let mut file = std::fs::File::create(path)?;
    file.write_all(&bytes)?;

    Ok(())
}

/// Builder for creating embedders with custom configurations
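///
/// A builder usage sketch (uses only the setters defined below; assumes an async
/// runtime; marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let embedder = EmbedderBuilder::new()
///     .pooling(PoolingStrategy::Mean)
///     .normalize(true)
///     .batch_size(16)
///     .max_length(256)
///     .build()
///     .await?;
/// ```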
pub struct EmbedderBuilder {
    config: EmbedderConfig,
}

impl EmbedderBuilder {
    /// Start building an embedder
    pub fn new() -> Self {
        Self {
            config: EmbedderConfig::default(),
        }
    }

    /// Use a pretrained model
    pub fn pretrained(mut self, model: PretrainedModel) -> Self {
        self.config = EmbedderConfig::pretrained(model);
        self
    }

    /// Set pooling strategy
    pub fn pooling(mut self, strategy: PoolingStrategy) -> Self {
        self.config.pooling = strategy;
        self
    }

    /// Set normalization
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.config.normalize = normalize;
        self
    }

    /// Set batch size
    pub fn batch_size(mut self, size: usize) -> Self {
        self.config.batch_size = size;
        self
    }

    /// Set max sequence length
    pub fn max_length(mut self, length: usize) -> Self {
        self.config.max_length = length;
        self
    }

    /// Build the embedder
    pub async fn build(self) -> Result<Embedder> {
        Embedder::new(self.config).await
    }
}

impl Default for EmbedderBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = EmbedderConfig::default();
        assert_eq!(config.pooling, PoolingStrategy::Mean);
        assert!(config.normalize);
    }
}