Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
252
vendor/ruvector/examples/onnx-embeddings/src/config.rs
vendored
Normal file
252
vendor/ruvector/examples/onnx-embeddings/src/config.rs
vendored
Normal file
@@ -0,0 +1,252 @@
|
||||
//! Configuration for the ONNX embedder
|
||||
|
||||
use crate::PretrainedModel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Source of the ONNX model
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ModelSource {
|
||||
/// Load from HuggingFace Hub (downloads if not cached)
|
||||
HuggingFace {
|
||||
model_id: String,
|
||||
revision: Option<String>,
|
||||
},
|
||||
/// Load from a local ONNX file
|
||||
Local {
|
||||
model_path: PathBuf,
|
||||
tokenizer_path: PathBuf,
|
||||
},
|
||||
/// Use a pre-configured model
|
||||
Pretrained(PretrainedModel),
|
||||
/// Custom URL for model download
|
||||
Url {
|
||||
model_url: String,
|
||||
tokenizer_url: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for ModelSource {
|
||||
fn default() -> Self {
|
||||
Self::Pretrained(PretrainedModel::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PretrainedModel> for ModelSource {
|
||||
fn from(model: PretrainedModel) -> Self {
|
||||
Self::Pretrained(model)
|
||||
}
|
||||
}
|
||||
|
||||
/// Pooling strategy for combining token embeddings
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub enum PoolingStrategy {
|
||||
/// Mean pooling over all tokens (most common)
|
||||
#[default]
|
||||
Mean,
|
||||
/// Use [CLS] token embedding
|
||||
Cls,
|
||||
/// Max pooling over all tokens
|
||||
Max,
|
||||
/// Mean pooling with sqrt(length) scaling
|
||||
MeanSqrtLen,
|
||||
/// Last token pooling (for decoder models)
|
||||
LastToken,
|
||||
/// Weighted mean based on attention mask
|
||||
WeightedMean,
|
||||
}
|
||||
|
||||
/// Execution provider for ONNX Runtime
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
|
||||
pub enum ExecutionProvider {
|
||||
/// CPU inference (default, always available)
|
||||
#[default]
|
||||
Cpu,
|
||||
/// CUDA GPU acceleration
|
||||
Cuda { device_id: i32 },
|
||||
/// TensorRT optimization
|
||||
TensorRt { device_id: i32 },
|
||||
/// CoreML on macOS
|
||||
CoreMl,
|
||||
/// DirectML on Windows
|
||||
DirectMl,
|
||||
/// ROCm for AMD GPUs
|
||||
Rocm { device_id: i32 },
|
||||
}
|
||||
|
||||
/// Configuration for the embedder
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbedderConfig {
|
||||
/// Model source
|
||||
pub model_source: ModelSource,
|
||||
/// Pooling strategy
|
||||
pub pooling: PoolingStrategy,
|
||||
/// Whether to normalize embeddings to unit length
|
||||
pub normalize: bool,
|
||||
/// Maximum sequence length (truncation)
|
||||
pub max_length: usize,
|
||||
/// Batch size for inference
|
||||
pub batch_size: usize,
|
||||
/// Number of threads for CPU inference
|
||||
pub num_threads: usize,
|
||||
/// Execution provider
|
||||
pub execution_provider: ExecutionProvider,
|
||||
/// Cache directory for downloaded models
|
||||
pub cache_dir: PathBuf,
|
||||
/// Whether to show progress during downloads
|
||||
pub show_progress: bool,
|
||||
/// Use fp16 inference if available
|
||||
pub use_fp16: bool,
|
||||
/// Enable graph optimization
|
||||
pub optimize_graph: bool,
|
||||
}
|
||||
|
||||
impl Default for EmbedderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_source: ModelSource::default(),
|
||||
pooling: PoolingStrategy::default(),
|
||||
normalize: true,
|
||||
max_length: 256,
|
||||
batch_size: 32,
|
||||
num_threads: num_cpus::get(),
|
||||
execution_provider: ExecutionProvider::default(),
|
||||
cache_dir: default_cache_dir(),
|
||||
show_progress: true,
|
||||
use_fp16: false,
|
||||
optimize_graph: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbedderConfig {
|
||||
/// Create a new config builder
|
||||
pub fn builder() -> EmbedderConfigBuilder {
|
||||
EmbedderConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Create config for a pretrained model
|
||||
pub fn pretrained(model: PretrainedModel) -> Self {
|
||||
Self {
|
||||
model_source: ModelSource::Pretrained(model),
|
||||
max_length: model.max_seq_length(),
|
||||
normalize: model.normalize_output(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for a local model
|
||||
pub fn local(model_path: impl Into<PathBuf>, tokenizer_path: impl Into<PathBuf>) -> Self {
|
||||
Self {
|
||||
model_source: ModelSource::Local {
|
||||
model_path: model_path.into(),
|
||||
tokenizer_path: tokenizer_path.into(),
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for a HuggingFace model
|
||||
pub fn huggingface(model_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model_source: ModelSource::HuggingFace {
|
||||
model_id: model_id.into(),
|
||||
revision: None,
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for EmbedderConfig
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EmbedderConfigBuilder {
|
||||
config: EmbedderConfig,
|
||||
}
|
||||
|
||||
impl EmbedderConfigBuilder {
|
||||
pub fn model_source(mut self, source: ModelSource) -> Self {
|
||||
self.config.model_source = source;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn pretrained(mut self, model: PretrainedModel) -> Self {
|
||||
self.config.model_source = ModelSource::Pretrained(model);
|
||||
self.config.max_length = model.max_seq_length();
|
||||
self.config.normalize = model.normalize_output();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn pooling(mut self, strategy: PoolingStrategy) -> Self {
|
||||
self.config.pooling = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn normalize(mut self, normalize: bool) -> Self {
|
||||
self.config.normalize = normalize;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_length(mut self, length: usize) -> Self {
|
||||
self.config.max_length = length;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn batch_size(mut self, size: usize) -> Self {
|
||||
self.config.batch_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn num_threads(mut self, threads: usize) -> Self {
|
||||
self.config.num_threads = threads;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn execution_provider(mut self, provider: ExecutionProvider) -> Self {
|
||||
self.config.execution_provider = provider;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
|
||||
self.config.cache_dir = dir.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn show_progress(mut self, show: bool) -> Self {
|
||||
self.config.show_progress = show;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn use_fp16(mut self, use_fp16: bool) -> Self {
|
||||
self.config.use_fp16 = use_fp16;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn optimize_graph(mut self, optimize: bool) -> Self {
|
||||
self.config.optimize_graph = optimize;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> EmbedderConfig {
|
||||
self.config
|
||||
}
|
||||
}
|
||||
|
||||
fn default_cache_dir() -> PathBuf {
|
||||
dirs::cache_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join("ruvector")
|
||||
.join("onnx-models")
|
||||
}
|
||||
|
||||
fn num_cpus_get() -> usize {
|
||||
std::thread::available_parallelism()
|
||||
.map(|p| p.get())
|
||||
.unwrap_or(4)
|
||||
}
|
||||
|
||||
mod num_cpus {
|
||||
pub fn get() -> usize {
|
||||
super::num_cpus_get()
|
||||
}
|
||||
}
|
||||
471
vendor/ruvector/examples/onnx-embeddings/src/embedder.rs
vendored
Normal file
471
vendor/ruvector/examples/onnx-embeddings/src/embedder.rs
vendored
Normal file
@@ -0,0 +1,471 @@
|
||||
//! Main embedder implementation combining model, tokenizer, and pooling
|
||||
|
||||
use crate::config::{EmbedderConfig, ModelSource, PoolingStrategy};
|
||||
use crate::model::OnnxModel;
|
||||
use crate::pooling::Pooler;
|
||||
use crate::tokenizer::Tokenizer;
|
||||
use crate::{EmbeddingError, PretrainedModel, Result};
|
||||
use std::path::Path;
|
||||
use tracing::{debug, info, instrument};
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::gpu::{GpuAccelerator, GpuConfig};
|
||||
|
||||
/// High-level embedder combining tokenizer, model, and pooling
|
||||
pub struct Embedder {
|
||||
/// ONNX model for inference
|
||||
model: OnnxModel,
|
||||
/// Tokenizer for text processing
|
||||
tokenizer: Tokenizer,
|
||||
/// Pooler for combining token embeddings
|
||||
pooler: Pooler,
|
||||
/// Configuration
|
||||
config: EmbedderConfig,
|
||||
/// Optional GPU accelerator for similarity operations
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu: Option<GpuAccelerator>,
|
||||
}
|
||||
|
||||
/// Embedding output with metadata
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingOutput {
|
||||
/// The embedding vectors
|
||||
pub embeddings: Vec<Vec<f32>>,
|
||||
/// Original input texts
|
||||
pub texts: Vec<String>,
|
||||
/// Number of tokens per input
|
||||
pub token_counts: Vec<usize>,
|
||||
/// Embedding dimension
|
||||
pub dimension: usize,
|
||||
}
|
||||
|
||||
impl EmbeddingOutput {
|
||||
/// Get the number of embeddings
|
||||
pub fn len(&self) -> usize {
|
||||
self.embeddings.len()
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.embeddings.is_empty()
|
||||
}
|
||||
|
||||
/// Get a single embedding by index
|
||||
pub fn get(&self, index: usize) -> Option<&Vec<f32>> {
|
||||
self.embeddings.get(index)
|
||||
}
|
||||
|
||||
/// Iterate over embeddings
|
||||
pub fn iter(&self) -> impl Iterator<Item = &Vec<f32>> {
|
||||
self.embeddings.iter()
|
||||
}
|
||||
|
||||
/// Convert to owned vectors
|
||||
pub fn into_vecs(self) -> Vec<Vec<f32>> {
|
||||
self.embeddings
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
/// Create a new embedder from configuration
|
||||
#[instrument(skip_all)]
|
||||
pub async fn new(config: EmbedderConfig) -> Result<Self> {
|
||||
info!("Initializing embedder");
|
||||
|
||||
// Load model
|
||||
let model = OnnxModel::from_config(&config).await?;
|
||||
|
||||
// Load tokenizer based on model source
|
||||
let tokenizer = match &config.model_source {
|
||||
ModelSource::Local {
|
||||
tokenizer_path, ..
|
||||
} => Tokenizer::from_file(tokenizer_path, config.max_length)?,
|
||||
|
||||
ModelSource::Pretrained(pretrained) => {
|
||||
Tokenizer::from_pretrained(pretrained.model_id(), config.max_length)?
|
||||
}
|
||||
|
||||
ModelSource::HuggingFace { model_id, .. } => {
|
||||
Tokenizer::from_pretrained(model_id, config.max_length)?
|
||||
}
|
||||
|
||||
ModelSource::Url { tokenizer_url, .. } => {
|
||||
// Download tokenizer
|
||||
let cache_path = config.cache_dir.join("tokenizer.json");
|
||||
if !cache_path.exists() {
|
||||
download_tokenizer(tokenizer_url, &cache_path).await?;
|
||||
}
|
||||
Tokenizer::from_file(&cache_path, config.max_length)?
|
||||
}
|
||||
};
|
||||
|
||||
let pooler = Pooler::new(config.pooling, config.normalize);
|
||||
|
||||
// Initialize GPU accelerator if available
|
||||
#[cfg(feature = "gpu")]
|
||||
let gpu = {
|
||||
match GpuAccelerator::new(GpuConfig::auto()).await {
|
||||
Ok(accel) => {
|
||||
info!("GPU accelerator initialized: {}", accel.device_info().name);
|
||||
Some(accel)
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("GPU not available, using CPU: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
model,
|
||||
tokenizer,
|
||||
pooler,
|
||||
config,
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create embedder with default model (all-MiniLM-L6-v2)
|
||||
pub async fn default_model() -> Result<Self> {
|
||||
Self::new(EmbedderConfig::default()).await
|
||||
}
|
||||
|
||||
/// Create embedder for a specific pretrained model
|
||||
pub async fn pretrained(model: PretrainedModel) -> Result<Self> {
|
||||
Self::new(EmbedderConfig::pretrained(model)).await
|
||||
}
|
||||
|
||||
/// Embed a single text
|
||||
#[instrument(skip(self, text), fields(text_len = text.len()))]
|
||||
pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
|
||||
let output = self.embed(&[text])?;
|
||||
output
|
||||
.embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or(EmbeddingError::EmptyInput)
|
||||
}
|
||||
|
||||
/// Embed multiple texts
|
||||
#[instrument(skip(self, texts), fields(batch_size = texts.len()))]
|
||||
pub fn embed<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
|
||||
if texts.is_empty() {
|
||||
return Err(EmbeddingError::EmptyInput);
|
||||
}
|
||||
|
||||
let texts_owned: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
|
||||
|
||||
// Process in batches
|
||||
let batch_size = self.config.batch_size;
|
||||
let mut all_embeddings = Vec::with_capacity(texts.len());
|
||||
let mut all_token_counts = Vec::with_capacity(texts.len());
|
||||
|
||||
for chunk in texts.chunks(batch_size) {
|
||||
let (embeddings, token_counts) = self.embed_batch(chunk)?;
|
||||
all_embeddings.extend(embeddings);
|
||||
all_token_counts.extend(token_counts);
|
||||
}
|
||||
|
||||
Ok(EmbeddingOutput {
|
||||
embeddings: all_embeddings,
|
||||
texts: texts_owned,
|
||||
token_counts: all_token_counts,
|
||||
dimension: self.model.dimension(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Embed a batch of texts (internal)
|
||||
fn embed_batch<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<(Vec<Vec<f32>>, Vec<usize>)> {
|
||||
debug!("Embedding batch of {} texts", texts.len());
|
||||
|
||||
// Tokenize
|
||||
let encoded = self.tokenizer.encode_batch(texts)?;
|
||||
let (input_ids, attention_mask, token_type_ids, shape) = encoded.to_onnx_inputs();
|
||||
|
||||
// Run model
|
||||
let token_embeddings = self.model.run(
|
||||
&input_ids,
|
||||
&attention_mask,
|
||||
&token_type_ids,
|
||||
&shape,
|
||||
)?;
|
||||
|
||||
let seq_length = shape[1];
|
||||
let hidden_size = self.model.dimension();
|
||||
|
||||
// Pool embeddings
|
||||
let attention_masks: Vec<Vec<i64>> = encoded.attention_mask;
|
||||
let embeddings = self.pooler.pool(
|
||||
&token_embeddings,
|
||||
&attention_masks,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
);
|
||||
|
||||
let token_counts = encoded.original_lengths;
|
||||
|
||||
Ok((embeddings, token_counts))
|
||||
}
|
||||
|
||||
/// Embed texts (sequential processing)
|
||||
/// Note: For parallel processing, consider using tokio::spawn with multiple Embedder instances
|
||||
#[instrument(skip(self, texts), fields(total_texts = texts.len()))]
|
||||
pub fn embed_parallel<S: AsRef<str> + Sync>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
|
||||
// Use sequential processing since ONNX session requires mutable access
|
||||
self.embed(texts)
|
||||
}
|
||||
|
||||
/// Process texts one at a time (use embed for batch processing)
|
||||
pub fn embed_each<S: AsRef<str>>(&mut self, texts: &[S]) -> Vec<Result<Vec<f32>>> {
|
||||
texts.iter().map(|text| self.embed_one(text.as_ref())).collect()
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.model.dimension()
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
pub fn model_info(&self) -> &crate::model::ModelInfo {
|
||||
self.model.info()
|
||||
}
|
||||
|
||||
/// Get the pooling strategy
|
||||
pub fn pooling_strategy(&self) -> PoolingStrategy {
|
||||
self.config.pooling
|
||||
}
|
||||
|
||||
/// Get max sequence length
|
||||
pub fn max_length(&self) -> usize {
|
||||
self.config.max_length
|
||||
}
|
||||
|
||||
/// Compute similarity between two texts
|
||||
pub fn similarity(&mut self, text1: &str, text2: &str) -> Result<f32> {
|
||||
let emb1 = self.embed_one(text1)?;
|
||||
let emb2 = self.embed_one(text2)?;
|
||||
Ok(Pooler::cosine_similarity(&emb1, &emb2))
|
||||
}
|
||||
|
||||
/// Find most similar texts from a corpus
|
||||
/// Uses GPU acceleration when available and corpus is large enough
|
||||
#[instrument(skip(self, query, corpus), fields(corpus_size = corpus.len()))]
|
||||
pub fn most_similar<S: AsRef<str>>(
|
||||
&mut self,
|
||||
query: &str,
|
||||
corpus: &[S],
|
||||
top_k: usize,
|
||||
) -> Result<Vec<(usize, f32, String)>> {
|
||||
let query_emb = self.embed_one(query)?;
|
||||
let corpus_embs = self.embed(corpus)?;
|
||||
|
||||
// Try GPU-accelerated similarity if available
|
||||
#[cfg(feature = "gpu")]
|
||||
if let Some(ref gpu) = self.gpu {
|
||||
if corpus.len() >= 64 {
|
||||
let candidates: Vec<&[f32]> = corpus_embs.embeddings.iter().map(|v| v.as_slice()).collect();
|
||||
if let Ok(results) = gpu.top_k_similar(&query_emb, &candidates, top_k) {
|
||||
return Ok(results
|
||||
.into_iter()
|
||||
.map(|(idx, score)| (idx, score, corpus[idx].as_ref().to_string()))
|
||||
.collect());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CPU fallback
|
||||
let mut similarities: Vec<(usize, f32, String)> = corpus_embs
|
||||
.embeddings
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, emb)| {
|
||||
let sim = Pooler::cosine_similarity(&query_emb, emb);
|
||||
(i, sim, corpus[i].as_ref().to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
similarities.truncate(top_k);
|
||||
|
||||
Ok(similarities)
|
||||
}
|
||||
|
||||
/// Check if GPU acceleration is available
|
||||
pub fn has_gpu(&self) -> bool {
|
||||
#[cfg(feature = "gpu")]
|
||||
{
|
||||
self.gpu.is_some()
|
||||
}
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get GPU device info if available
|
||||
#[cfg(feature = "gpu")]
|
||||
pub fn gpu_info(&self) -> Option<crate::gpu::GpuInfo> {
|
||||
self.gpu.as_ref().map(|g| g.device_info())
|
||||
}
|
||||
|
||||
/// Cluster texts by similarity (simple k-means-like approach)
|
||||
#[instrument(skip(self, texts), fields(n_texts = texts.len(), n_clusters))]
|
||||
pub fn cluster<S: AsRef<str>>(
|
||||
&mut self,
|
||||
texts: &[S],
|
||||
n_clusters: usize,
|
||||
) -> Result<Vec<usize>> {
|
||||
let embeddings = self.embed(texts)?;
|
||||
let dim = self.dimension();
|
||||
|
||||
// Initialize centroids with first k embeddings
|
||||
let mut centroids: Vec<Vec<f32>> = embeddings
|
||||
.embeddings
|
||||
.iter()
|
||||
.take(n_clusters)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut assignments = vec![0usize; texts.len()];
|
||||
let max_iterations = 100;
|
||||
|
||||
for _ in 0..max_iterations {
|
||||
let old_assignments = assignments.clone();
|
||||
|
||||
// Assign to nearest centroid
|
||||
for (i, emb) in embeddings.embeddings.iter().enumerate() {
|
||||
let mut min_dist = f32::MAX;
|
||||
let mut min_idx = 0;
|
||||
|
||||
for (j, centroid) in centroids.iter().enumerate() {
|
||||
let dist = Pooler::euclidean_distance(emb, centroid);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
min_idx = j;
|
||||
}
|
||||
}
|
||||
|
||||
assignments[i] = min_idx;
|
||||
}
|
||||
|
||||
// Check convergence
|
||||
if assignments == old_assignments {
|
||||
break;
|
||||
}
|
||||
|
||||
// Update centroids
|
||||
for (j, centroid) in centroids.iter_mut().enumerate() {
|
||||
let cluster_points: Vec<&Vec<f32>> = embeddings
|
||||
.embeddings
|
||||
.iter()
|
||||
.zip(assignments.iter())
|
||||
.filter(|(_, &a)| a == j)
|
||||
.map(|(e, _)| e)
|
||||
.collect();
|
||||
|
||||
if !cluster_points.is_empty() {
|
||||
*centroid = vec![0.0; dim];
|
||||
for point in &cluster_points {
|
||||
for (k, &val) in point.iter().enumerate() {
|
||||
centroid[k] += val;
|
||||
}
|
||||
}
|
||||
let count = cluster_points.len() as f32;
|
||||
for val in centroid.iter_mut() {
|
||||
*val /= count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(assignments)
|
||||
}
|
||||
}
|
||||
|
||||
/// Download tokenizer from URL
|
||||
async fn download_tokenizer(url: &str, path: &Path) -> Result<()> {
|
||||
use std::io::Write;
|
||||
|
||||
let response = reqwest::get(url).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(EmbeddingError::download_failed(format!(
|
||||
"Failed to download tokenizer: HTTP {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let bytes = response.bytes().await?;
|
||||
let mut file = std::fs::File::create(path)?;
|
||||
file.write_all(&bytes)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Builder for creating embedders with custom configurations
|
||||
pub struct EmbedderBuilder {
|
||||
config: EmbedderConfig,
|
||||
}
|
||||
|
||||
impl EmbedderBuilder {
|
||||
/// Start building an embedder
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: EmbedderConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Use a pretrained model
|
||||
pub fn pretrained(mut self, model: PretrainedModel) -> Self {
|
||||
self.config = EmbedderConfig::pretrained(model);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set pooling strategy
|
||||
pub fn pooling(mut self, strategy: PoolingStrategy) -> Self {
|
||||
self.config.pooling = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set normalization
|
||||
pub fn normalize(mut self, normalize: bool) -> Self {
|
||||
self.config.normalize = normalize;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set batch size
|
||||
pub fn batch_size(mut self, size: usize) -> Self {
|
||||
self.config.batch_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max sequence length
|
||||
pub fn max_length(mut self, length: usize) -> Self {
|
||||
self.config.max_length = length;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the embedder
|
||||
pub async fn build(self) -> Result<Embedder> {
|
||||
Embedder::new(self.config).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbedderBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = EmbedderConfig::default();
|
||||
assert_eq!(config.pooling, PoolingStrategy::Mean);
|
||||
assert!(config.normalize);
|
||||
}
|
||||
}
|
||||
233
vendor/ruvector/examples/onnx-embeddings/src/error.rs
vendored
Normal file
233
vendor/ruvector/examples/onnx-embeddings/src/error.rs
vendored
Normal file
@@ -0,0 +1,233 @@
|
||||
//! Error types for ONNX embeddings
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias for embedding operations
|
||||
pub type Result<T> = std::result::Result<T, EmbeddingError>;
|
||||
|
||||
/// Errors that can occur during embedding operations
|
||||
#[derive(Error, Debug)]
|
||||
pub enum EmbeddingError {
|
||||
/// ONNX Runtime error
|
||||
#[error("ONNX Runtime error: {0}")]
|
||||
OnnxRuntime(#[from] ort::Error),
|
||||
|
||||
/// Tokenizer error
|
||||
#[error("Tokenizer error: {0}")]
|
||||
Tokenizer(#[from] tokenizers::tokenizer::Error),
|
||||
|
||||
/// IO error
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// HTTP request error
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] reqwest::Error),
|
||||
|
||||
/// Model not found
|
||||
#[error("Model not found: {path}")]
|
||||
ModelNotFound { path: String },
|
||||
|
||||
/// Tokenizer not found
|
||||
#[error("Tokenizer not found: {path}")]
|
||||
TokenizerNotFound { path: String },
|
||||
|
||||
/// Invalid model format
|
||||
#[error("Invalid model format: {reason}")]
|
||||
InvalidModel { reason: String },
|
||||
|
||||
/// Dimension mismatch
|
||||
#[error("Dimension mismatch: expected {expected}, got {actual}")]
|
||||
DimensionMismatch { expected: usize, actual: usize },
|
||||
|
||||
/// Empty input
|
||||
#[error("Empty input provided")]
|
||||
EmptyInput,
|
||||
|
||||
/// Batch size exceeded
|
||||
#[error("Batch size {size} exceeds maximum {max}")]
|
||||
BatchSizeExceeded { size: usize, max: usize },
|
||||
|
||||
/// Sequence too long
|
||||
#[error("Sequence length {length} exceeds maximum {max}")]
|
||||
SequenceTooLong { length: usize, max: usize },
|
||||
|
||||
/// Download failed
|
||||
#[error("Failed to download model: {reason}")]
|
||||
DownloadFailed { reason: String },
|
||||
|
||||
/// Cache error
|
||||
#[error("Cache error: {reason}")]
|
||||
CacheError { reason: String },
|
||||
|
||||
/// Checksum mismatch
|
||||
#[error("Checksum mismatch: expected {expected}, got {actual}")]
|
||||
ChecksumMismatch { expected: String, actual: String },
|
||||
|
||||
/// Invalid configuration
|
||||
#[error("Invalid configuration: {reason}")]
|
||||
InvalidConfig { reason: String },
|
||||
|
||||
/// Execution provider not available
|
||||
#[error("Execution provider not available: {provider}")]
|
||||
ExecutionProviderNotAvailable { provider: String },
|
||||
|
||||
/// RuVector integration error
|
||||
#[error("RuVector error: {0}")]
|
||||
RuVector(String),
|
||||
|
||||
/// Serialization error
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(#[from] serde_json::Error),
|
||||
|
||||
/// Shape error from ndarray
|
||||
#[error("Shape error: {0}")]
|
||||
Shape(#[from] ndarray::ShapeError),
|
||||
|
||||
/// Generic error
|
||||
#[error("{0}")]
|
||||
Other(String),
|
||||
|
||||
/// GPU initialization error
|
||||
#[error("GPU initialization failed: {reason}")]
|
||||
GpuInitFailed { reason: String },
|
||||
|
||||
/// GPU operation error
|
||||
#[error("GPU operation failed: {operation} - {reason}")]
|
||||
GpuOperationFailed { operation: String, reason: String },
|
||||
|
||||
/// Shader compilation error
|
||||
#[error("Shader compilation failed: {shader} - {reason}")]
|
||||
ShaderCompilationFailed { shader: String, reason: String },
|
||||
|
||||
/// GPU buffer error
|
||||
#[error("GPU buffer error: {reason}")]
|
||||
GpuBufferError { reason: String },
|
||||
|
||||
/// GPU not available
|
||||
#[error("GPU not available: {reason}")]
|
||||
GpuNotAvailable { reason: String },
|
||||
}
|
||||
|
||||
impl EmbeddingError {
|
||||
/// Create a model not found error
|
||||
pub fn model_not_found(path: impl Into<String>) -> Self {
|
||||
Self::ModelNotFound { path: path.into() }
|
||||
}
|
||||
|
||||
/// Create a tokenizer not found error
|
||||
pub fn tokenizer_not_found(path: impl Into<String>) -> Self {
|
||||
Self::TokenizerNotFound { path: path.into() }
|
||||
}
|
||||
|
||||
/// Create an invalid model error
|
||||
pub fn invalid_model(reason: impl Into<String>) -> Self {
|
||||
Self::InvalidModel {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a dimension mismatch error
|
||||
pub fn dimension_mismatch(expected: usize, actual: usize) -> Self {
|
||||
Self::DimensionMismatch { expected, actual }
|
||||
}
|
||||
|
||||
/// Create a download failed error
|
||||
pub fn download_failed(reason: impl Into<String>) -> Self {
|
||||
Self::DownloadFailed {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a cache error
|
||||
pub fn cache_error(reason: impl Into<String>) -> Self {
|
||||
Self::CacheError {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an invalid config error
|
||||
pub fn invalid_config(reason: impl Into<String>) -> Self {
|
||||
Self::InvalidConfig {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an execution provider error
|
||||
pub fn execution_provider_not_available(provider: impl Into<String>) -> Self {
|
||||
Self::ExecutionProviderNotAvailable {
|
||||
provider: provider.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a RuVector error
|
||||
pub fn ruvector(msg: impl Into<String>) -> Self {
|
||||
Self::RuVector(msg.into())
|
||||
}
|
||||
|
||||
/// Create a generic error
|
||||
pub fn other(msg: impl Into<String>) -> Self {
|
||||
Self::Other(msg.into())
|
||||
}
|
||||
|
||||
/// Create a GPU initialization error
|
||||
pub fn gpu_init_failed(reason: impl Into<String>) -> Self {
|
||||
Self::GpuInitFailed { reason: reason.into() }
|
||||
}
|
||||
|
||||
/// Create a GPU operation error
|
||||
pub fn gpu_operation_failed(operation: impl Into<String>, reason: impl Into<String>) -> Self {
|
||||
Self::GpuOperationFailed {
|
||||
operation: operation.into(),
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a shader compilation error
|
||||
pub fn shader_compilation_failed(shader: impl Into<String>, reason: impl Into<String>) -> Self {
|
||||
Self::ShaderCompilationFailed {
|
||||
shader: shader.into(),
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a GPU buffer error
|
||||
pub fn gpu_buffer_error(reason: impl Into<String>) -> Self {
|
||||
Self::GpuBufferError { reason: reason.into() }
|
||||
}
|
||||
|
||||
/// Create a GPU not available error
|
||||
pub fn gpu_not_available(reason: impl Into<String>) -> Self {
|
||||
Self::GpuNotAvailable { reason: reason.into() }
|
||||
}
|
||||
|
||||
/// Check if this error is a GPU error
|
||||
pub fn is_gpu_error(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::GpuInitFailed { .. }
|
||||
| Self::GpuOperationFailed { .. }
|
||||
| Self::ShaderCompilationFailed { .. }
|
||||
| Self::GpuBufferError { .. }
|
||||
| Self::GpuNotAvailable { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this error is recoverable
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::Http(_) | Self::DownloadFailed { .. } | Self::CacheError { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this error is a configuration error
|
||||
pub fn is_config_error(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::InvalidConfig { .. }
|
||||
| Self::InvalidModel { .. }
|
||||
| Self::DimensionMismatch { .. }
|
||||
)
|
||||
}
|
||||
}
|
||||
1323
vendor/ruvector/examples/onnx-embeddings/src/gpu/backend.rs
vendored
Normal file
1323
vendor/ruvector/examples/onnx-embeddings/src/gpu/backend.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
293
vendor/ruvector/examples/onnx-embeddings/src/gpu/config.rs
vendored
Normal file
293
vendor/ruvector/examples/onnx-embeddings/src/gpu/config.rs
vendored
Normal file
@@ -0,0 +1,293 @@
|
||||
//! GPU Configuration for RuVector ONNX Embeddings
|
||||
//!
|
||||
//! Provides configuration options for GPU acceleration including
|
||||
//! device selection, memory limits, and performance tuning.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// GPU execution mode
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GpuMode {
|
||||
/// Automatically select best available backend
|
||||
#[default]
|
||||
Auto,
|
||||
/// Force WebGPU backend
|
||||
WebGpu,
|
||||
/// Force CUDA-WASM transpiled backend
|
||||
CudaWasm,
|
||||
/// CPU-only (disable GPU)
|
||||
CpuOnly,
|
||||
}
|
||||
|
||||
/// Power preference for GPU device selection
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum PowerPreference {
|
||||
/// Prefer low power consumption (integrated GPU)
|
||||
LowPower,
|
||||
/// Prefer high performance (discrete GPU)
|
||||
#[default]
|
||||
HighPerformance,
|
||||
/// No preference
|
||||
None,
|
||||
}
|
||||
|
||||
/// GPU acceleration configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GpuConfig {
|
||||
/// GPU execution mode
|
||||
pub mode: GpuMode,
|
||||
|
||||
/// Power preference for device selection
|
||||
pub power_preference: PowerPreference,
|
||||
|
||||
/// Maximum GPU memory usage (bytes, 0 = unlimited)
|
||||
pub max_memory: u64,
|
||||
|
||||
/// Workgroup size for compute shaders (0 = auto)
|
||||
pub workgroup_size: u32,
|
||||
|
||||
/// Enable async GPU operations
|
||||
pub async_compute: bool,
|
||||
|
||||
/// Minimum batch size to use GPU (smaller batches use CPU)
|
||||
pub min_batch_size: usize,
|
||||
|
||||
/// Minimum vector dimension to use GPU
|
||||
pub min_dimension: usize,
|
||||
|
||||
/// Enable shader caching
|
||||
pub cache_shaders: bool,
|
||||
|
||||
/// Enable profiling and timing
|
||||
pub enable_profiling: bool,
|
||||
|
||||
/// Fallback to CPU on GPU error
|
||||
pub fallback_to_cpu: bool,
|
||||
|
||||
/// Device index (for multi-GPU systems)
|
||||
pub device_index: u32,
|
||||
}
|
||||
|
||||
impl Default for GpuConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GpuMode::Auto,
|
||||
power_preference: PowerPreference::HighPerformance,
|
||||
max_memory: 0, // unlimited
|
||||
workgroup_size: 256,
|
||||
async_compute: true,
|
||||
min_batch_size: 16,
|
||||
min_dimension: 128,
|
||||
cache_shaders: true,
|
||||
enable_profiling: false,
|
||||
fallback_to_cpu: true,
|
||||
device_index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GpuConfig {
|
||||
/// Create configuration with automatic settings
|
||||
pub fn auto() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create configuration for high performance
|
||||
pub fn high_performance() -> Self {
|
||||
Self {
|
||||
mode: GpuMode::Auto,
|
||||
power_preference: PowerPreference::HighPerformance,
|
||||
workgroup_size: 512,
|
||||
async_compute: true,
|
||||
min_batch_size: 8,
|
||||
min_dimension: 64,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create configuration for low power usage
|
||||
pub fn low_power() -> Self {
|
||||
Self {
|
||||
mode: GpuMode::Auto,
|
||||
power_preference: PowerPreference::LowPower,
|
||||
workgroup_size: 128,
|
||||
async_compute: false,
|
||||
min_batch_size: 32,
|
||||
min_dimension: 256,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create CPU-only configuration
|
||||
pub fn cpu_only() -> Self {
|
||||
Self {
|
||||
mode: GpuMode::CpuOnly,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create WebGPU-specific configuration
|
||||
pub fn webgpu() -> Self {
|
||||
Self {
|
||||
mode: GpuMode::WebGpu,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create CUDA-WASM specific configuration
|
||||
#[cfg(feature = "cuda-wasm")]
|
||||
pub fn cuda_wasm() -> Self {
|
||||
Self {
|
||||
mode: GpuMode::CudaWasm,
|
||||
workgroup_size: 256,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// Builder methods
|
||||
|
||||
/// Set GPU mode
|
||||
pub fn with_mode(mut self, mode: GpuMode) -> Self {
|
||||
self.mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set power preference
|
||||
pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
|
||||
self.power_preference = pref;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set maximum memory
|
||||
pub fn with_max_memory(mut self, bytes: u64) -> Self {
|
||||
self.max_memory = bytes;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set workgroup size
|
||||
pub fn with_workgroup_size(mut self, size: u32) -> Self {
|
||||
self.workgroup_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set minimum batch size for GPU usage
|
||||
pub fn with_min_batch_size(mut self, size: usize) -> Self {
|
||||
self.min_batch_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set minimum dimension for GPU usage
|
||||
pub fn with_min_dimension(mut self, dim: usize) -> Self {
|
||||
self.min_dimension = dim;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable profiling
|
||||
pub fn with_profiling(mut self, enable: bool) -> Self {
|
||||
self.enable_profiling = enable;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable CPU fallback
|
||||
pub fn with_fallback(mut self, enable: bool) -> Self {
|
||||
self.fallback_to_cpu = enable;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set device index
|
||||
pub fn with_device(mut self, index: u32) -> Self {
|
||||
self.device_index = index;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if GPU should be used for given workload
|
||||
pub fn should_use_gpu(&self, batch_size: usize, dimension: usize) -> bool {
|
||||
self.mode != GpuMode::CpuOnly
|
||||
&& batch_size >= self.min_batch_size
|
||||
&& dimension >= self.min_dimension
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU memory statistics
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct GpuMemoryStats {
|
||||
/// Total GPU memory (bytes)
|
||||
pub total: u64,
|
||||
/// Used GPU memory (bytes)
|
||||
pub used: u64,
|
||||
/// Free GPU memory (bytes)
|
||||
pub free: u64,
|
||||
/// Peak usage (bytes)
|
||||
pub peak: u64,
|
||||
}
|
||||
|
||||
impl GpuMemoryStats {
|
||||
/// Get usage percentage
|
||||
pub fn usage_percent(&self) -> f32 {
|
||||
if self.total > 0 {
|
||||
(self.used as f32 / self.total as f32) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU profiling data
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct GpuProfilingData {
|
||||
/// Total operations executed
|
||||
pub operations: u64,
|
||||
/// Total GPU time (microseconds)
|
||||
pub gpu_time_us: u64,
|
||||
/// Total CPU time (microseconds)
|
||||
pub cpu_time_us: u64,
|
||||
/// GPU speedup over CPU
|
||||
pub speedup: f32,
|
||||
/// Memory transfers (bytes)
|
||||
pub memory_transferred: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let cfg = GpuConfig::default();
        assert_eq!(cfg.mode, GpuMode::Auto);
        assert_eq!(cfg.power_preference, PowerPreference::HighPerformance);
        assert!(cfg.fallback_to_cpu);
    }

    #[test]
    fn test_should_use_gpu() {
        let cfg = GpuConfig::default()
            .with_min_batch_size(16)
            .with_min_dimension(128);

        // Below the batch threshold.
        assert!(!cfg.should_use_gpu(8, 384));
        // Below the dimension threshold.
        assert!(!cfg.should_use_gpu(32, 64));
        // Both thresholds satisfied.
        assert!(cfg.should_use_gpu(32, 384));
    }

    #[test]
    fn test_cpu_only() {
        // CPU-only mode refuses the GPU regardless of workload size.
        let cfg = GpuConfig::cpu_only();
        assert!(!cfg.should_use_gpu(1000, 1000));
    }

    #[test]
    fn test_builder() {
        let cfg = GpuConfig::auto()
            .with_mode(GpuMode::WebGpu)
            .with_max_memory(1024 * 1024 * 1024)
            .with_workgroup_size(512)
            .with_profiling(true);

        assert_eq!(cfg.mode, GpuMode::WebGpu);
        assert_eq!(cfg.max_memory, 1024 * 1024 * 1024);
        assert_eq!(cfg.workgroup_size, 512);
        assert!(cfg.enable_profiling);
    }
}
|
||||
298
vendor/ruvector/examples/onnx-embeddings/src/gpu/mod.rs
vendored
Normal file
298
vendor/ruvector/examples/onnx-embeddings/src/gpu/mod.rs
vendored
Normal file
@@ -0,0 +1,298 @@
|
||||
//! GPU Acceleration Module for RuVector ONNX Embeddings
|
||||
//!
|
||||
//! This module provides optional GPU acceleration using cuda-wasm for:
|
||||
//! - Pooling operations
|
||||
//! - Similarity computations
|
||||
//! - Batch vector operations
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌─────────────────────────────────────────────────────────────────┐
|
||||
//! │ GPU Acceleration Layer │
|
||||
//! ├─────────────────────────────────────────────────────────────────┤
|
||||
//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
|
||||
//! │ │ GpuBackend │ -> │ Shaders │ -> │ WebGPU Runtime │ │
|
||||
//! │ │ (Trait) │ │ (WGSL) │ │ (wgpu) │ │
|
||||
//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │
|
||||
//! │ │ │ │
|
||||
//! │ v v │
|
||||
//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
|
||||
//! │ │ GpuPooler │ │ GpuSimilar │ │ GpuVectorOps │ │
|
||||
//! │ │ │ │ │ │ │ │
|
||||
//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │
|
||||
//! └─────────────────────────────────────────────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! ## Feature Flags
|
||||
//!
|
||||
//! - `gpu`: Enable GPU acceleration (WebGPU backend)
|
||||
//! - `cuda-wasm`: Enable CUDA-WASM transpilation support
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_onnx_embeddings::gpu::{GpuAccelerator, GpuConfig};
|
||||
//!
|
||||
//! // Create GPU accelerator with auto-detection
|
||||
//! let gpu = GpuAccelerator::new(GpuConfig::auto()).await?;
|
||||
//!
|
||||
//! // GPU-accelerated similarity search
|
||||
//! let similarities = gpu.batch_cosine_similarity(&query, &candidates)?;
|
||||
//!
|
||||
//! // GPU-accelerated pooling
|
||||
//! let pooled = gpu.mean_pool(&token_embeddings, &attention_mask)?;
|
||||
//! ```
|
||||
|
||||
mod backend;
|
||||
mod config;
|
||||
mod operations;
|
||||
mod shaders;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use backend::{GpuBackend, GpuDevice, GpuInfo};
|
||||
pub use config::{GpuConfig, GpuMode, PowerPreference};
|
||||
pub use operations::{
|
||||
GpuPooler, GpuSimilarity, GpuVectorOps,
|
||||
batch_cosine_similarity_gpu, batch_dot_product_gpu, batch_euclidean_gpu,
|
||||
};
|
||||
pub use shaders::ShaderRegistry;
|
||||
|
||||
use crate::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// GPU Accelerator - Main entry point for GPU operations
|
||||
///
|
||||
/// Provides unified access to GPU-accelerated operations with automatic
|
||||
/// fallback to CPU when GPU is unavailable.
|
||||
pub struct GpuAccelerator {
|
||||
backend: Arc<dyn GpuBackend>,
|
||||
config: GpuConfig,
|
||||
pooler: GpuPooler,
|
||||
similarity: GpuSimilarity,
|
||||
vector_ops: GpuVectorOps,
|
||||
}
|
||||
|
||||
impl GpuAccelerator {
|
||||
/// Create a new GPU accelerator with the given configuration
|
||||
pub async fn new(config: GpuConfig) -> Result<Self> {
|
||||
let backend: Arc<dyn GpuBackend> = Arc::from(backend::create_backend(&config).await?);
|
||||
let shader_registry = ShaderRegistry::new();
|
||||
|
||||
let mut pooler = GpuPooler::new(backend.as_ref(), &shader_registry)?;
|
||||
let mut similarity = GpuSimilarity::new(backend.as_ref(), &shader_registry)?;
|
||||
let mut vector_ops = GpuVectorOps::new(backend.as_ref(), &shader_registry)?;
|
||||
|
||||
// Wire up the backend to all components for GPU dispatch
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
{
|
||||
pooler.set_backend(Arc::clone(&backend));
|
||||
similarity.set_backend(Arc::clone(&backend));
|
||||
vector_ops.set_backend(Arc::clone(&backend));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
backend,
|
||||
config,
|
||||
pooler,
|
||||
similarity,
|
||||
vector_ops,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with automatic configuration
|
||||
pub async fn auto() -> Result<Self> {
|
||||
Self::new(GpuConfig::auto()).await
|
||||
}
|
||||
|
||||
/// Check if GPU acceleration is available
|
||||
pub fn is_available(&self) -> bool {
|
||||
self.backend.is_available()
|
||||
}
|
||||
|
||||
/// Get GPU device information
|
||||
pub fn device_info(&self) -> GpuInfo {
|
||||
self.backend.device_info()
|
||||
}
|
||||
|
||||
/// Get the current configuration
|
||||
pub fn config(&self) -> &GpuConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
// ==================== Pooling Operations ====================
|
||||
|
||||
/// Mean pooling over token embeddings (GPU-accelerated)
|
||||
pub fn mean_pool(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
self.pooler.mean_pool(
|
||||
token_embeddings,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
)
|
||||
}
|
||||
|
||||
/// CLS token pooling (GPU-accelerated)
|
||||
pub fn cls_pool(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
batch_size: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
self.pooler.cls_pool(token_embeddings, batch_size, hidden_size)
|
||||
}
|
||||
|
||||
/// Max pooling over token embeddings (GPU-accelerated)
|
||||
pub fn max_pool(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
self.pooler.max_pool(
|
||||
token_embeddings,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
)
|
||||
}
|
||||
|
||||
// ==================== Similarity Operations ====================
|
||||
|
||||
/// Batch cosine similarity (GPU-accelerated)
|
||||
pub fn batch_cosine_similarity(
|
||||
&self,
|
||||
query: &[f32],
|
||||
candidates: &[&[f32]],
|
||||
) -> Result<Vec<f32>> {
|
||||
self.similarity.batch_cosine(query, candidates)
|
||||
}
|
||||
|
||||
/// Batch dot product (GPU-accelerated)
|
||||
pub fn batch_dot_product(
|
||||
&self,
|
||||
query: &[f32],
|
||||
candidates: &[&[f32]],
|
||||
) -> Result<Vec<f32>> {
|
||||
self.similarity.batch_dot_product(query, candidates)
|
||||
}
|
||||
|
||||
/// Batch Euclidean distance (GPU-accelerated)
|
||||
pub fn batch_euclidean_distance(
|
||||
&self,
|
||||
query: &[f32],
|
||||
candidates: &[&[f32]],
|
||||
) -> Result<Vec<f32>> {
|
||||
self.similarity.batch_euclidean(query, candidates)
|
||||
}
|
||||
|
||||
/// Find top-k most similar vectors (GPU-accelerated)
|
||||
pub fn top_k_similar(
|
||||
&self,
|
||||
query: &[f32],
|
||||
candidates: &[&[f32]],
|
||||
k: usize,
|
||||
) -> Result<Vec<(usize, f32)>> {
|
||||
self.similarity.top_k(query, candidates, k)
|
||||
}
|
||||
|
||||
// ==================== Vector Operations ====================
|
||||
|
||||
/// L2 normalize vectors (GPU-accelerated)
|
||||
pub fn normalize_batch(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
|
||||
self.vector_ops.normalize_batch(vectors, dimension)
|
||||
}
|
||||
|
||||
/// Matrix-vector multiplication (GPU-accelerated)
|
||||
pub fn matmul(
|
||||
&self,
|
||||
matrix: &[f32],
|
||||
vector: &[f32],
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
self.vector_ops.matmul(matrix, vector, rows, cols)
|
||||
}
|
||||
|
||||
/// Batch vector addition (GPU-accelerated)
|
||||
pub fn batch_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
|
||||
self.vector_ops.batch_add(a, b)
|
||||
}
|
||||
|
||||
/// Batch vector scaling (GPU-accelerated)
|
||||
pub fn batch_scale(&self, vectors: &mut [f32], scale: f32) -> Result<()> {
|
||||
self.vector_ops.batch_scale(vectors, scale)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience function to check GPU availability without creating accelerator
|
||||
pub async fn is_gpu_available() -> bool {
|
||||
backend::probe_gpu().await
|
||||
}
|
||||
|
||||
/// Get GPU device info without full initialization
|
||||
pub async fn get_gpu_info() -> Option<GpuInfo> {
|
||||
backend::get_device_info().await
|
||||
}
|
||||
|
||||
/// Fallback wrapper that tries GPU first, then CPU
|
||||
pub struct HybridAccelerator {
|
||||
gpu: Option<GpuAccelerator>,
|
||||
use_gpu: bool,
|
||||
}
|
||||
|
||||
impl HybridAccelerator {
|
||||
/// Create hybrid accelerator with GPU if available
|
||||
pub async fn new() -> Self {
|
||||
let gpu = GpuAccelerator::auto().await.ok();
|
||||
let use_gpu = gpu.is_some();
|
||||
Self { gpu, use_gpu }
|
||||
}
|
||||
|
||||
/// Check if GPU is being used
|
||||
pub fn using_gpu(&self) -> bool {
|
||||
self.use_gpu && self.gpu.is_some()
|
||||
}
|
||||
|
||||
/// Disable GPU (use CPU only)
|
||||
pub fn disable_gpu(&mut self) {
|
||||
self.use_gpu = false;
|
||||
}
|
||||
|
||||
/// Enable GPU if available
|
||||
pub fn enable_gpu(&mut self) {
|
||||
self.use_gpu = self.gpu.is_some();
|
||||
}
|
||||
|
||||
/// Batch cosine similarity with automatic backend selection
|
||||
pub fn batch_cosine_similarity(
|
||||
&self,
|
||||
query: &[f32],
|
||||
candidates: &[Vec<f32>],
|
||||
) -> Vec<f32> {
|
||||
if self.use_gpu {
|
||||
if let Some(ref gpu) = self.gpu {
|
||||
let refs: Vec<&[f32]> = candidates.iter().map(|v| v.as_slice()).collect();
|
||||
if let Ok(result) = gpu.batch_cosine_similarity(query, &refs) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CPU fallback
|
||||
crate::pooling::batch_cosine_similarity(query, candidates)
|
||||
}
|
||||
}
|
||||
934
vendor/ruvector/examples/onnx-embeddings/src/gpu/operations.rs
vendored
Normal file
934
vendor/ruvector/examples/onnx-embeddings/src/gpu/operations.rs
vendored
Normal file
@@ -0,0 +1,934 @@
|
||||
//! GPU-Accelerated Operations
|
||||
//!
|
||||
//! High-level GPU operations for embeddings with automatic fallback to CPU.
|
||||
|
||||
use crate::{EmbeddingError, Result};
|
||||
use super::backend::{GpuBackend, BufferUsage};
|
||||
use super::shaders::ShaderRegistry;
|
||||
use rayon::prelude::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
use bytemuck;
|
||||
|
||||
// ==================== GPU Pooler ====================
|
||||
|
||||
/// Pooling operations with optional GPU dispatch.
///
/// When GPU features are compiled in, a backend handle is injected after
/// construction (see `set_backend`); otherwise only the CPU paths exist.
pub struct GpuPooler {
    use_gpu: bool,
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
|
||||
|
||||
impl GpuPooler {
|
||||
/// Create new GPU pooler
|
||||
pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
|
||||
let use_gpu = backend.is_available() && backend.device_info().supports_compute;
|
||||
|
||||
Ok(Self {
|
||||
use_gpu,
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
backend: None, // Will be set by GpuAccelerator
|
||||
})
|
||||
}
|
||||
|
||||
/// Set the backend for GPU operations
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
|
||||
self.backend = Some(backend);
|
||||
}
|
||||
|
||||
/// Mean pooling (GPU or CPU fallback)
|
||||
pub fn mean_pool(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
// GPU implementation requires minimum batch size for efficiency
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
if self.use_gpu && batch_size >= 8 && self.backend.is_some() {
|
||||
return self.mean_pool_gpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size);
|
||||
}
|
||||
|
||||
Ok(self.mean_pool_cpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size))
|
||||
}
|
||||
|
||||
/// CLS pooling (GPU or CPU fallback)
|
||||
pub fn cls_pool(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
batch_size: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
// CLS pooling is simple copy, CPU is often faster
|
||||
Ok(self.cls_pool_cpu(token_embeddings, batch_size, hidden_size))
|
||||
}
|
||||
|
||||
/// Max pooling (GPU or CPU fallback)
|
||||
pub fn max_pool(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
if self.use_gpu && batch_size >= 8 && self.backend.is_some() {
|
||||
return self.max_pool_gpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size);
|
||||
}
|
||||
|
||||
Ok(self.max_pool_cpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size))
|
||||
}
|
||||
|
||||
// GPU implementations
|
||||
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
fn mean_pool_gpu(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let backend = self.backend.as_ref().ok_or_else(|| {
|
||||
EmbeddingError::GpuOperationFailed {
|
||||
operation: "mean_pool".to_string(),
|
||||
reason: "Backend not initialized".to_string(),
|
||||
}
|
||||
})?;
|
||||
|
||||
// Create buffers
|
||||
let token_buf = backend.create_buffer(
|
||||
(token_embeddings.len() * 4) as u64,
|
||||
BufferUsage::Storage,
|
||||
)?;
|
||||
let mask_buf = backend.create_buffer(
|
||||
(attention_mask.len() * 8) as u64,
|
||||
BufferUsage::Storage,
|
||||
)?;
|
||||
let output_buf = backend.create_buffer(
|
||||
(batch_size * hidden_size * 4) as u64,
|
||||
BufferUsage::Storage,
|
||||
)?;
|
||||
|
||||
// Create params buffer (batch_size, seq_length, hidden_size)
|
||||
let params: [u32; 3] = [batch_size as u32, seq_length as u32, hidden_size as u32];
|
||||
let params_buf = backend.create_buffer(16, BufferUsage::Uniform)?; // 16 bytes aligned
|
||||
backend.write_buffer(¶ms_buf, bytemuck::cast_slice(¶ms))?;
|
||||
|
||||
// Write input data
|
||||
backend.write_buffer(&token_buf, bytemuck::cast_slice(token_embeddings))?;
|
||||
backend.write_buffer(&mask_buf, bytemuck::cast_slice(attention_mask))?;
|
||||
|
||||
// Create pipeline with mean pool shader
|
||||
let shader = super::shaders::MEAN_POOL_SHADER;
|
||||
let pipeline = backend.create_pipeline(shader, "mean_pool", [64, 1, 1])?;
|
||||
|
||||
// Dispatch with params buffer as 4th binding
|
||||
let total_outputs = batch_size * hidden_size;
|
||||
let workgroups = [total_outputs.div_ceil(64) as u32, 1, 1];
|
||||
backend.dispatch(&pipeline, &[&token_buf, &mask_buf, &output_buf, ¶ms_buf], workgroups)?;
|
||||
backend.sync()?;
|
||||
|
||||
// Read output
|
||||
let output_bytes = backend.read_buffer(&output_buf, (batch_size * hidden_size * 4) as u64)?;
|
||||
let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
|
||||
|
||||
// Cleanup
|
||||
backend.release_buffer(token_buf)?;
|
||||
backend.release_buffer(mask_buf)?;
|
||||
backend.release_buffer(output_buf)?;
|
||||
backend.release_buffer(params_buf)?;
|
||||
backend.release_pipeline(pipeline)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
fn max_pool_gpu(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let backend = self.backend.as_ref().ok_or_else(|| {
|
||||
EmbeddingError::GpuOperationFailed {
|
||||
operation: "max_pool".to_string(),
|
||||
reason: "Backend not initialized".to_string(),
|
||||
}
|
||||
})?;
|
||||
|
||||
// Create buffers
|
||||
let token_buf = backend.create_buffer(
|
||||
(token_embeddings.len() * 4) as u64,
|
||||
BufferUsage::Storage,
|
||||
)?;
|
||||
let mask_buf = backend.create_buffer(
|
||||
(attention_mask.len() * 8) as u64,
|
||||
BufferUsage::Storage,
|
||||
)?;
|
||||
let output_buf = backend.create_buffer(
|
||||
(batch_size * hidden_size * 4) as u64,
|
||||
BufferUsage::Storage,
|
||||
)?;
|
||||
|
||||
// Create params buffer (batch_size, seq_length, hidden_size)
|
||||
let params: [u32; 3] = [batch_size as u32, seq_length as u32, hidden_size as u32];
|
||||
let params_buf = backend.create_buffer(16, BufferUsage::Uniform)?;
|
||||
backend.write_buffer(¶ms_buf, bytemuck::cast_slice(¶ms))?;
|
||||
|
||||
// Write input data
|
||||
backend.write_buffer(&token_buf, bytemuck::cast_slice(token_embeddings))?;
|
||||
backend.write_buffer(&mask_buf, bytemuck::cast_slice(attention_mask))?;
|
||||
|
||||
// Create pipeline with max pool shader
|
||||
let shader = super::shaders::MAX_POOL_SHADER;
|
||||
let pipeline = backend.create_pipeline(shader, "max_pool", [64, 1, 1])?;
|
||||
|
||||
// Dispatch with params buffer as 4th binding
|
||||
let total_outputs = batch_size * hidden_size;
|
||||
let workgroups = [total_outputs.div_ceil(64) as u32, 1, 1];
|
||||
backend.dispatch(&pipeline, &[&token_buf, &mask_buf, &output_buf, ¶ms_buf], workgroups)?;
|
||||
backend.sync()?;
|
||||
|
||||
// Read output
|
||||
let output_bytes = backend.read_buffer(&output_buf, (batch_size * hidden_size * 4) as u64)?;
|
||||
let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
|
||||
|
||||
// Cleanup
|
||||
backend.release_buffer(token_buf)?;
|
||||
backend.release_buffer(mask_buf)?;
|
||||
backend.release_buffer(output_buf)?;
|
||||
backend.release_buffer(params_buf)?;
|
||||
backend.release_pipeline(pipeline)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
// CPU implementations
|
||||
|
||||
fn mean_pool_cpu(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut output = vec![0.0f32; batch_size * hidden_size];
|
||||
|
||||
output
|
||||
.par_chunks_mut(hidden_size)
|
||||
.enumerate()
|
||||
.for_each(|(batch_idx, out_chunk)| {
|
||||
let tokens_base = batch_idx * seq_length * hidden_size;
|
||||
let mask_base = batch_idx * seq_length;
|
||||
|
||||
let mut count = 0.0f32;
|
||||
|
||||
for seq_idx in 0..seq_length {
|
||||
if attention_mask[mask_base + seq_idx] == 1 {
|
||||
let start = tokens_base + seq_idx * hidden_size;
|
||||
for (j, out_val) in out_chunk.iter_mut().enumerate() {
|
||||
*out_val += token_embeddings[start + j];
|
||||
}
|
||||
count += 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0.0 {
|
||||
for val in out_chunk.iter_mut() {
|
||||
*val /= count;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn cls_pool_cpu(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
batch_size: usize,
|
||||
hidden_size: usize,
|
||||
) -> Vec<f32> {
|
||||
let seq_length = token_embeddings.len() / (batch_size * hidden_size);
|
||||
let mut output = vec![0.0f32; batch_size * hidden_size];
|
||||
|
||||
for batch_idx in 0..batch_size {
|
||||
let src_start = batch_idx * seq_length * hidden_size;
|
||||
let dst_start = batch_idx * hidden_size;
|
||||
output[dst_start..dst_start + hidden_size]
|
||||
.copy_from_slice(&token_embeddings[src_start..src_start + hidden_size]);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn max_pool_cpu(
|
||||
&self,
|
||||
token_embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
batch_size: usize,
|
||||
seq_length: usize,
|
||||
hidden_size: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut output = vec![f32::NEG_INFINITY; batch_size * hidden_size];
|
||||
|
||||
output
|
||||
.par_chunks_mut(hidden_size)
|
||||
.enumerate()
|
||||
.for_each(|(batch_idx, out_chunk)| {
|
||||
let tokens_base = batch_idx * seq_length * hidden_size;
|
||||
let mask_base = batch_idx * seq_length;
|
||||
|
||||
for seq_idx in 0..seq_length {
|
||||
if attention_mask[mask_base + seq_idx] == 1 {
|
||||
let start = tokens_base + seq_idx * hidden_size;
|
||||
for (j, out_val) in out_chunk.iter_mut().enumerate() {
|
||||
let val = token_embeddings[start + j];
|
||||
if val > *out_val {
|
||||
*out_val = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Replace -inf with 0
|
||||
for val in out_chunk.iter_mut() {
|
||||
if val.is_infinite() {
|
||||
*val = 0.0;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== GPU Similarity ====================
|
||||
|
||||
/// Similarity computations with optional GPU dispatch.
///
/// The GPU is only used above `min_candidates` results; below that the
/// CPU path is cheaper than the transfer overhead.
pub struct GpuSimilarity {
    use_gpu: bool,
    min_candidates: usize,
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
|
||||
|
||||
impl GpuSimilarity {
|
||||
/// Create new GPU similarity calculator
|
||||
pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
|
||||
Ok(Self {
|
||||
use_gpu: backend.is_available() && backend.device_info().supports_compute,
|
||||
min_candidates: 64, // Minimum candidates to use GPU
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
backend: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Inject the shared backend handle used for GPU dispatch.
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
    self.backend = Some(backend);
}
|
||||
|
||||
/// Batch cosine similarity
|
||||
pub fn batch_cosine(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
|
||||
return self.batch_cosine_gpu(query, candidates);
|
||||
}
|
||||
|
||||
Ok(self.batch_cosine_cpu(query, candidates))
|
||||
}
|
||||
|
||||
/// Batch dot product
|
||||
pub fn batch_dot_product(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
|
||||
return self.batch_dot_product_gpu(query, candidates);
|
||||
}
|
||||
|
||||
Ok(self.batch_dot_product_cpu(query, candidates))
|
||||
}
|
||||
|
||||
/// Batch Euclidean distance
|
||||
pub fn batch_euclidean(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
|
||||
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
|
||||
if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
|
||||
return self.batch_euclidean_gpu(query, candidates);
|
||||
}
|
||||
|
||||
Ok(self.batch_euclidean_cpu(query, candidates))
|
||||
}
|
||||
|
||||
/// Find top-k most similar
|
||||
pub fn top_k(&self, query: &[f32], candidates: &[&[f32]], k: usize) -> Result<Vec<(usize, f32)>> {
|
||||
let similarities = self.batch_cosine(query, candidates)?;
|
||||
|
||||
let mut indexed: Vec<(usize, f32)> = similarities.into_iter().enumerate().collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
indexed.truncate(k);
|
||||
|
||||
Ok(indexed)
|
||||
}
|
||||
|
||||
// GPU implementations
|
||||
|
||||
/// GPU cosine similarity: flatten candidates into one contiguous device
/// buffer, run the batch-cosine shader with one invocation per candidate,
/// and read the scores back.
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
fn batch_cosine_gpu(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
    let backend = self.backend.as_ref().ok_or_else(|| {
        EmbeddingError::GpuOperationFailed {
            operation: "batch_cosine".to_string(),
            reason: "Backend not initialized".to_string(),
        }
    })?;

    let dimension = query.len();
    let num_candidates = candidates.len();

    // The shader expects candidates as a single contiguous array.
    let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();

    // Allocate device buffers (f32 = 4 bytes each).
    let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
    let candidates_buf = backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
    let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;

    // Uniform params: (dimension, num_candidates) = two u32 = 8 bytes.
    let params: [u32; 2] = [dimension as u32, num_candidates as u32];
    let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
    backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;

    // Upload inputs.
    backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
    backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;

    // Compile the batch-cosine pipeline (workgroup size 256).
    let shader = super::shaders::BATCH_COSINE_SIMILARITY_SHADER;
    let pipeline = backend.create_pipeline(shader, "batch_cosine_similarity", [256, 1, 1])?;

    // One shader invocation per candidate.
    let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
    backend.dispatch(&pipeline, &[&query_buf, &candidates_buf, &output_buf, &params_buf], workgroups)?;
    backend.sync()?;

    // Read back the similarity scores.
    let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
    let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();

    // Release GPU resources before returning.
    backend.release_buffer(query_buf)?;
    backend.release_buffer(candidates_buf)?;
    backend.release_buffer(output_buf)?;
    backend.release_buffer(params_buf)?;
    backend.release_pipeline(pipeline)?;

    Ok(output)
}
|
||||
|
||||
    /// Batch dot product on the GPU: `query · candidates[i]` for every candidate.
    ///
    /// Flattens all candidate vectors into one contiguous storage buffer, runs the
    /// `dot_product` compute shader (one invocation per candidate, workgroups of
    /// 256 threads), then reads the per-candidate scores back to the host.
    ///
    /// # Errors
    /// Returns `GpuOperationFailed` when no backend has been set; any buffer,
    /// pipeline, or dispatch failure from the backend is propagated via `?`.
    ///
    /// NOTE(review): an early `?` return (e.g. if `candidates_buf` creation fails)
    /// skips the explicit `release_buffer`/`release_pipeline` calls at the end —
    /// confirm that backend resources are reclaimed on drop, otherwise error
    /// paths can leak GPU memory.
    /// Assumes every candidate slice has `query.len()` elements — TODO confirm
    /// callers guarantee this; the shader indexes the flat buffer as
    /// `candidate_idx * dimension + i`.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_dot_product_gpu(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_dot_product".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;

        let dimension = query.len();
        let num_candidates = candidates.len();

        // Flatten candidates into contiguous buffer
        let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();

        // Create buffers (byte sizes: 4 bytes per f32)
        let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
        let candidates_buf = backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;

        // Create params buffer (dimension, num_candidates) — matches the WGSL `Params` struct
        let params: [u32; 2] = [dimension as u32, num_candidates as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;

        // Write input data
        backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
        backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;

        // Create pipeline
        let shader = super::shaders::DOT_PRODUCT_SHADER;
        let pipeline = backend.create_pipeline(shader, "dot_product", [256, 1, 1])?;

        // Dispatch with params buffer as 4th binding; one thread per candidate
        let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&query_buf, &candidates_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;

        // Read output (one f32 score per candidate)
        let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();

        // Cleanup
        backend.release_buffer(query_buf)?;
        backend.release_buffer(candidates_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;

        Ok(output)
    }
|
||||
|
||||
    /// Batch Euclidean distance on the GPU: `|query - candidates[i]|` for each candidate.
    ///
    /// Same buffer layout and dispatch shape as `batch_dot_product_gpu`, but using
    /// the `euclidean_distance` shader (sqrt of summed squared differences).
    ///
    /// # Errors
    /// Returns `GpuOperationFailed` when no backend has been set; backend buffer,
    /// pipeline, and dispatch failures are propagated via `?`.
    ///
    /// NOTE(review): early `?` returns skip the trailing `release_*` calls —
    /// confirm backend resources are reclaimed on drop, otherwise error paths
    /// can leak GPU memory.
    /// Assumes each candidate slice has `query.len()` elements — TODO confirm.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_euclidean_gpu(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_euclidean".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;

        let dimension = query.len();
        let num_candidates = candidates.len();

        // Flatten candidates into contiguous buffer
        let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();

        // Create buffers (byte sizes: 4 bytes per f32)
        let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
        let candidates_buf = backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;

        // Create params buffer (dimension, num_candidates) — matches the WGSL `Params` struct
        let params: [u32; 2] = [dimension as u32, num_candidates as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;

        // Write input data
        backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
        backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;

        // Create pipeline
        let shader = super::shaders::EUCLIDEAN_DISTANCE_SHADER;
        let pipeline = backend.create_pipeline(shader, "euclidean_distance", [256, 1, 1])?;

        // Dispatch with params buffer as 4th binding; one thread per candidate
        let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&query_buf, &candidates_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;

        // Read output (one f32 distance per candidate)
        let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();

        // Cleanup
        backend.release_buffer(query_buf)?;
        backend.release_buffer(candidates_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;

        Ok(output)
    }
|
||||
|
||||
// CPU implementations
|
||||
|
||||
fn batch_cosine_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
|
||||
candidates
|
||||
.par_iter()
|
||||
.map(|c| cosine_similarity_cpu(query, c))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn batch_dot_product_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
|
||||
candidates
|
||||
.par_iter()
|
||||
.map(|c| dot_product_cpu(query, c))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn batch_euclidean_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
|
||||
candidates
|
||||
.par_iter()
|
||||
.map(|c| euclidean_distance_cpu(query, c))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== GPU Vector Operations ====================
|
||||
|
||||
/// GPU-accelerated vector operations
///
/// Each public method dispatches to a GPU kernel when the workload is large
/// enough and a backend has been attached, and otherwise falls back to a
/// rayon-parallel CPU implementation.
pub struct GpuVectorOps {
    /// True when the backend probed at construction reported availability and
    /// compute-shader support.
    use_gpu: bool,
    /// Owning handle to the compute backend; stays `None` until `set_backend`
    /// is called, which keeps all GPU paths disabled.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
|
||||
|
||||
impl GpuVectorOps {
    /// Create new GPU vector operations
    ///
    /// Probes `backend` for availability and compute support to decide whether
    /// GPU paths may ever be used.
    ///
    /// NOTE(review): the probed `backend` reference is NOT stored — the
    /// `backend` field is initialized to `None`, so every GPU path stays
    /// disabled until `set_backend` is called with an owning `Arc`. Confirm
    /// this two-step initialization is intentional.
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        Ok(Self {
            use_gpu: backend.is_available() && backend.device_info().supports_compute,
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None,
        })
    }

    /// Set the backend for GPU operations
    ///
    /// Required before any `*_gpu` method can run; see the note on `new`.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }

    /// L2 normalize batch of vectors
    ///
    /// `vectors` is a flat row-major buffer of `vectors.len() / dimension`
    /// rows; each row is scaled to unit L2 norm in place. Uses the GPU only
    /// for batches of at least 64 vectors (heuristic threshold) when a
    /// backend is attached.
    pub fn normalize_batch(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && vectors.len() >= dimension * 64 && self.backend.is_some() {
            return self.normalize_batch_gpu(vectors, dimension);
        }

        self.normalize_batch_cpu(vectors, dimension);
        Ok(())
    }

    /// Matrix-vector multiplication
    ///
    /// `matrix` is row-major `rows x cols`; returns a vector of `rows` dot
    /// products. GPU path is used for at least 64 rows (heuristic) when a
    /// backend is attached.
    pub fn matmul(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && rows >= 64 && self.backend.is_some() {
            return self.matmul_gpu(matrix, vector, rows, cols);
        }

        Ok(self.matmul_cpu(matrix, vector, rows, cols))
    }

    /// Batch vector addition
    ///
    /// Element-wise `a + b`. Errors with a dimension mismatch when the slices
    /// differ in length; GPU path used for 1024+ elements (heuristic).
    pub fn batch_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        if a.len() != b.len() {
            return Err(EmbeddingError::dimension_mismatch(a.len(), b.len()));
        }

        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && a.len() >= 1024 && self.backend.is_some() {
            return self.batch_add_gpu(a, b);
        }

        Ok(a.par_iter().zip(b.par_iter()).map(|(x, y)| x + y).collect())
    }

    /// Batch vector scaling
    ///
    /// In-place multiply of every element by `scale`. Always runs on the CPU
    /// (rayon); no GPU path is implemented for this operation.
    pub fn batch_scale(&self, vectors: &mut [f32], scale: f32) -> Result<()> {
        vectors.par_iter_mut().for_each(|v| *v *= scale);
        Ok(())
    }

    // GPU implementations
    //
    // NOTE(review): all three GPU methods below release their buffers/pipeline
    // only on the success path; an early `?` return skips cleanup. Confirm
    // the backend reclaims resources on drop.

    /// GPU L2 normalization: one shader invocation per vector; results are
    /// copied back over `vectors` in place.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn normalize_batch_gpu(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "normalize_batch".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;

        let num_vectors = vectors.len() / dimension;

        // Create buffers (input, dummy, output, params). The 4-byte dummy
        // fills binding(1), which the l2_normalize shader declares but ignores.
        let input_buf = backend.create_buffer((vectors.len() * 4) as u64, BufferUsage::Storage)?;
        let dummy_buf = backend.create_buffer(4, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((vectors.len() * 4) as u64, BufferUsage::Storage)?;

        // Create params buffer (dimension, num_vectors)
        let params: [u32; 2] = [dimension as u32, num_vectors as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;

        // Write input data
        backend.write_buffer(&input_buf, bytemuck::cast_slice(vectors))?;

        // Create pipeline
        let shader = super::shaders::L2_NORMALIZE_SHADER;
        let pipeline = backend.create_pipeline(shader, "l2_normalize", [256, 1, 1])?;

        // Dispatch with 4 bindings; one thread per vector
        let workgroups = [num_vectors.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&input_buf, &dummy_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;

        // Read output and overwrite the caller's buffer in place
        let output_bytes = backend.read_buffer(&output_buf, (vectors.len() * 4) as u64)?;
        let output: &[f32] = bytemuck::cast_slice(&output_bytes);
        vectors.copy_from_slice(output);

        // Cleanup
        backend.release_buffer(input_buf)?;
        backend.release_buffer(dummy_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;

        Ok(())
    }

    /// GPU matrix-vector multiply using the `matmul` shader.
    ///
    /// NOTE(review): the pipeline declares a [16, 16, 1] workgroup (256
    /// threads) but the dispatch only spans the X axis with
    /// `rows.div_ceil(16)` groups — verify this geometry against
    /// SHADER_MATMUL's indexing (shader source not shown here).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn matmul_gpu(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "matmul".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;

        // Create buffers
        let mat_buf = backend.create_buffer((matrix.len() * 4) as u64, BufferUsage::Storage)?;
        let vec_buf = backend.create_buffer((vector.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((rows * 4) as u64, BufferUsage::Storage)?;

        // Create params buffer (rows, cols)
        let params: [u32; 2] = [rows as u32, cols as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;

        // Write input data
        backend.write_buffer(&mat_buf, bytemuck::cast_slice(matrix))?;
        backend.write_buffer(&vec_buf, bytemuck::cast_slice(vector))?;

        // Create pipeline
        let shader = super::shaders::MATMUL_SHADER;
        let pipeline = backend.create_pipeline(shader, "matmul", [16, 16, 1])?;

        // Dispatch with params buffer as 4th binding
        let workgroups = [rows.div_ceil(16) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&mat_buf, &vec_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;

        // Read output (`rows` f32 results)
        let output_bytes = backend.read_buffer(&output_buf, (rows * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();

        // Cleanup
        backend.release_buffer(mat_buf)?;
        backend.release_buffer(vec_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;

        Ok(output)
    }

    /// GPU element-wise addition using the `vector_add` shader; one thread
    /// per element. Caller (`batch_add`) has already checked a.len() == b.len().
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_add_gpu(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_add".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;

        // Create buffers
        let buf_a = backend.create_buffer((a.len() * 4) as u64, BufferUsage::Storage)?;
        let buf_b = backend.create_buffer((b.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((a.len() * 4) as u64, BufferUsage::Storage)?;

        // Create params buffer (length)
        let params: [u32; 1] = [a.len() as u32];
        let params_buf = backend.create_buffer(4, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;

        // Write input data
        backend.write_buffer(&buf_a, bytemuck::cast_slice(a))?;
        backend.write_buffer(&buf_b, bytemuck::cast_slice(b))?;

        // Create pipeline
        let shader = super::shaders::VECTOR_ADD_SHADER;
        let pipeline = backend.create_pipeline(shader, "vector_add", [256, 1, 1])?;

        // Dispatch with params buffer as 4th binding
        let workgroups = [a.len().div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&buf_a, &buf_b, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;

        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (a.len() * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();

        // Cleanup
        backend.release_buffer(buf_a)?;
        backend.release_buffer(buf_b)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;

        Ok(output)
    }

    // CPU implementations

    /// Normalize each `dimension`-length row in place; rows with near-zero
    /// norm (<= 1e-12) are left untouched to avoid division by zero.
    fn normalize_batch_cpu(&self, vectors: &mut [f32], dimension: usize) {
        vectors
            .par_chunks_mut(dimension)
            .for_each(|chunk| {
                let norm: f32 = chunk.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 1e-12 {
                    for val in chunk.iter_mut() {
                        *val /= norm;
                    }
                }
            });
    }

    /// Row-major matrix-vector multiply; each output row is an independent
    /// dot product, parallelized over rows with rayon.
    fn matmul_cpu(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; rows];

        result
            .par_iter_mut()
            .enumerate()
            .for_each(|(row, out)| {
                let row_start = row * cols;
                *out = matrix[row_start..row_start + cols]
                    .iter()
                    .zip(vector.iter())
                    .map(|(m, v)| m * v)
                    .sum();
            });

        result
    }
}
|
||||
|
||||
// ==================== Standalone Functions ====================
|
||||
|
||||
/// Batch cosine similarity of `query` against each candidate.
///
/// NOTE(review): despite the `_gpu` suffix, this free function is a pure CPU
/// implementation parallelized with rayon (`par_iter`); no GPU backend is
/// consulted. Presumably the name is kept for API symmetry with the
/// `GpuVectorOps` methods — confirm with callers.
pub fn batch_cosine_similarity_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|c| cosine_similarity_cpu(query, c))
        .collect()
}
|
||||
|
||||
/// Batch dot product of `query` with each candidate.
///
/// NOTE(review): despite the `_gpu` suffix, this free function runs entirely
/// on the CPU via rayon (`par_iter`); no GPU backend is consulted. Presumably
/// named for API symmetry — confirm.
pub fn batch_dot_product_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|c| dot_product_cpu(query, c))
        .collect()
}
|
||||
|
||||
/// Batch Euclidean distance from `query` to each candidate.
///
/// NOTE(review): despite the `_gpu` suffix, this free function runs entirely
/// on the CPU via rayon (`par_iter`); no GPU backend is consulted. Presumably
/// named for API symmetry — confirm.
pub fn batch_euclidean_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|c| euclidean_distance_cpu(query, c))
        .collect()
}
|
||||
|
||||
// ==================== CPU Helper Functions ====================
|
||||
|
||||
/// Cosine similarity of two equal-length vectors; 0.0 when either vector has
/// a near-zero norm (undefined direction).
#[inline]
fn cosine_similarity_cpu(a: &[f32], b: &[f32]) -> f32 {
    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a > 1e-12 && norm_b > 1e-12 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
|
||||
|
||||
/// Dot product of two vectors; `zip` stops at the shorter slice.
#[inline]
fn dot_product_cpu(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
}
|
||||
|
||||
/// Euclidean (L2) distance between two vectors: sqrt of summed squared
/// component differences.
#[inline]
fn euclidean_distance_cpu(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    sum_sq.sqrt()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical vectors score 1.0; orthogonal vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((cosine_similarity_cpu(&a, &b) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity_cpu(&a, &c).abs() < 1e-6);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        assert!((dot_product_cpu(&a, &b) - 32.0).abs() < 1e-6);
    }

    /// Classic 3-4-5 right triangle: distance is 5.
    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];

        assert!((euclidean_distance_cpu(&a, &b) - 5.0).abs() < 1e-6);
    }

    /// The free `batch_cosine_similarity_gpu` (CPU-backed) returns one score
    /// per candidate in input order.
    #[test]
    fn test_batch_cosine() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates: Vec<&[f32]> = vec![
            &[1.0, 0.0, 0.0][..],
            &[0.0, 1.0, 0.0][..],
            &[0.707, 0.707, 0.0][..],
        ];

        let results = batch_cosine_similarity_gpu(&query, &candidates);

        assert_eq!(results.len(), 3);
        assert!((results[0] - 1.0).abs() < 1e-6);
        assert!(results[1].abs() < 1e-6);
    }

    /// CPU mean pooling averages token embeddings per batch element over the
    /// unmasked sequence positions.
    #[test]
    fn test_mean_pool_cpu() {
        // Construct the pooler directly with GPU disabled so the CPU path runs.
        let pooler = GpuPooler {
            use_gpu: false,
            #[cfg(feature = "gpu")]
            backend: None,
        };

        // batch=2, seq=2, hidden=3
        let tokens = vec![
            1.0, 2.0, 3.0, // batch 0, seq 0
            4.0, 5.0, 6.0, // batch 0, seq 1
            7.0, 8.0, 9.0, // batch 1, seq 0
            10.0, 11.0, 12.0, // batch 1, seq 1
        ];
        // All positions unmasked.
        let mask = vec![1i64, 1, 1, 1];

        let result = pooler.mean_pool_cpu(&tokens, &mask, 2, 2, 3);

        assert_eq!(result.len(), 6);
        // Batch 0: mean of [1,2,3] and [4,5,6] = [2.5, 3.5, 4.5]
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }
}
|
||||
613
vendor/ruvector/examples/onnx-embeddings/src/gpu/shaders.rs
vendored
Normal file
613
vendor/ruvector/examples/onnx-embeddings/src/gpu/shaders.rs
vendored
Normal file
@@ -0,0 +1,613 @@
|
||||
//! GPU Compute Shaders for RuVector Operations
|
||||
//!
|
||||
//! WGSL (WebGPU Shading Language) implementations for:
|
||||
//! - Pooling operations
|
||||
//! - Similarity computations
|
||||
//! - Vector normalization
|
||||
//! - Matrix operations
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Shader registry for managing compute shaders
///
/// Maps shader names to their [`ShaderModule`] definitions; pre-populated
/// with the built-in shaders by [`ShaderRegistry::new`].
#[derive(Debug)]
pub struct ShaderRegistry {
    /// Keyed by `ShaderModule::name`.
    shaders: HashMap<String, ShaderModule>,
}
|
||||
|
||||
/// Shader module information
///
/// A named WGSL compute shader plus the metadata needed to create a pipeline.
#[derive(Debug, Clone)]
pub struct ShaderModule {
    /// Shader name (also used as the registry key — see `ShaderRegistry::register`)
    pub name: String,
    /// WGSL source code
    pub source: String,
    /// Entry point function (for the built-ins this equals `name`)
    pub entry_point: String,
    /// Default workgroup size
    pub workgroup_size: [u32; 3],
}
|
||||
|
||||
impl ShaderRegistry {
|
||||
/// Create new registry with built-in shaders
|
||||
pub fn new() -> Self {
|
||||
let mut shaders = HashMap::new();
|
||||
|
||||
// Register all built-in shaders
|
||||
for shader in Self::builtin_shaders() {
|
||||
shaders.insert(shader.name.clone(), shader);
|
||||
}
|
||||
|
||||
Self { shaders }
|
||||
}
|
||||
|
||||
/// Get shader by name
|
||||
pub fn get(&self, name: &str) -> Option<&ShaderModule> {
|
||||
self.shaders.get(name)
|
||||
}
|
||||
|
||||
/// Register custom shader
|
||||
pub fn register(&mut self, shader: ShaderModule) {
|
||||
self.shaders.insert(shader.name.clone(), shader);
|
||||
}
|
||||
|
||||
/// List all available shaders
|
||||
pub fn list(&self) -> Vec<&str> {
|
||||
self.shaders.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
|
||||
/// Get built-in shader definitions
|
||||
fn builtin_shaders() -> Vec<ShaderModule> {
|
||||
vec![
|
||||
// Cosine Similarity
|
||||
ShaderModule {
|
||||
name: "cosine_similarity".to_string(),
|
||||
source: SHADER_COSINE_SIMILARITY.to_string(),
|
||||
entry_point: "cosine_similarity".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
// Batch Cosine Similarity
|
||||
ShaderModule {
|
||||
name: "batch_cosine_similarity".to_string(),
|
||||
source: SHADER_BATCH_COSINE_SIMILARITY.to_string(),
|
||||
entry_point: "batch_cosine_similarity".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
// Dot Product
|
||||
ShaderModule {
|
||||
name: "dot_product".to_string(),
|
||||
source: SHADER_DOT_PRODUCT.to_string(),
|
||||
entry_point: "dot_product".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
// Euclidean Distance
|
||||
ShaderModule {
|
||||
name: "euclidean_distance".to_string(),
|
||||
source: SHADER_EUCLIDEAN_DISTANCE.to_string(),
|
||||
entry_point: "euclidean_distance".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
// L2 Normalize
|
||||
ShaderModule {
|
||||
name: "l2_normalize".to_string(),
|
||||
source: SHADER_L2_NORMALIZE.to_string(),
|
||||
entry_point: "l2_normalize".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
// Mean Pooling
|
||||
ShaderModule {
|
||||
name: "mean_pool".to_string(),
|
||||
source: SHADER_MEAN_POOL.to_string(),
|
||||
entry_point: "mean_pool".to_string(),
|
||||
workgroup_size: [64, 1, 1],
|
||||
},
|
||||
// Max Pooling
|
||||
ShaderModule {
|
||||
name: "max_pool".to_string(),
|
||||
source: SHADER_MAX_POOL.to_string(),
|
||||
entry_point: "max_pool".to_string(),
|
||||
workgroup_size: [64, 1, 1],
|
||||
},
|
||||
// CLS Pooling
|
||||
ShaderModule {
|
||||
name: "cls_pool".to_string(),
|
||||
source: SHADER_CLS_POOL.to_string(),
|
||||
entry_point: "cls_pool".to_string(),
|
||||
workgroup_size: [64, 1, 1],
|
||||
},
|
||||
// Matrix-Vector Multiplication
|
||||
ShaderModule {
|
||||
name: "matmul".to_string(),
|
||||
source: SHADER_MATMUL.to_string(),
|
||||
entry_point: "matmul".to_string(),
|
||||
workgroup_size: [16, 16, 1],
|
||||
},
|
||||
// Vector Addition
|
||||
ShaderModule {
|
||||
name: "vector_add".to_string(),
|
||||
source: SHADER_VECTOR_ADD.to_string(),
|
||||
entry_point: "vector_add".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
// Vector Scale
|
||||
ShaderModule {
|
||||
name: "vector_scale".to_string(),
|
||||
source: SHADER_VECTOR_SCALE.to_string(),
|
||||
entry_point: "vector_scale".to_string(),
|
||||
workgroup_size: [256, 1, 1],
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ShaderRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Shader Source Code ====================
|
||||
|
||||
// Public aliases for operations.rs
//
// operations.rs refers to shader sources via these `*_SHADER` names; each
// alias simply re-exports the corresponding `SHADER_*` constant below.
pub const MEAN_POOL_SHADER: &str = SHADER_MEAN_POOL;
pub const MAX_POOL_SHADER: &str = SHADER_MAX_POOL;
pub const BATCH_COSINE_SIMILARITY_SHADER: &str = SHADER_BATCH_COSINE_SIMILARITY;
pub const DOT_PRODUCT_SHADER: &str = SHADER_DOT_PRODUCT;
pub const EUCLIDEAN_DISTANCE_SHADER: &str = SHADER_EUCLIDEAN_DISTANCE;
pub const L2_NORMALIZE_SHADER: &str = SHADER_L2_NORMALIZE;
pub const MATMUL_SHADER: &str = SHADER_MATMUL;
pub const VECTOR_ADD_SHADER: &str = SHADER_VECTOR_ADD;
|
||||
|
||||
/// Cosine similarity between two vectors
///
/// Single-pair kernel: one 256-thread workgroup cooperatively accumulates the
/// dot product and both squared norms (each thread strides the dimension by
/// 256), tree-reduces them in workgroup memory, and thread 0 writes the final
/// score to `result[0]`. A near-zero norm product yields 0.0 to avoid
/// division by zero.
///
/// NOTE(review): `idx` (`gid.x`) is computed but never used, and only
/// `result[0]` is written — this kernel appears intended for a single
/// workgroup dispatch; confirm callers dispatch [1, 1, 1].
pub const SHADER_COSINE_SIMILARITY: &str = r#"
struct Params {
    dimension: u32,
    count: u32,
}

@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidate: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> shared_dot: array<f32, 256>;
var<workgroup> shared_norm_a: array<f32, 256>;
var<workgroup> shared_norm_b: array<f32, 256>;

@compute @workgroup_size(256)
fn cosine_similarity(@builtin(global_invocation_id) gid: vec3<u32>,
                     @builtin(local_invocation_id) lid: vec3<u32>) {
    let idx = gid.x;
    let local_idx = lid.x;

    var dot: f32 = 0.0;
    var norm_a: f32 = 0.0;
    var norm_b: f32 = 0.0;

    // Compute partial sums
    var i = local_idx;
    while (i < params.dimension) {
        let a = query[i];
        let b = candidate[i];
        dot += a * b;
        norm_a += a * a;
        norm_b += b * b;
        i += 256u;
    }

    // Store in shared memory
    shared_dot[local_idx] = dot;
    shared_norm_a[local_idx] = norm_a;
    shared_norm_b[local_idx] = norm_b;
    workgroupBarrier();

    // Reduction
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (local_idx < stride) {
            shared_dot[local_idx] += shared_dot[local_idx + stride];
            shared_norm_a[local_idx] += shared_norm_a[local_idx + stride];
            shared_norm_b[local_idx] += shared_norm_b[local_idx + stride];
        }
        workgroupBarrier();
    }

    // Write result
    if (local_idx == 0u) {
        let norm_product = sqrt(shared_norm_a[0] * shared_norm_b[0]);
        if (norm_product > 1e-12) {
            result[0] = shared_dot[0] / norm_product;
        } else {
            result[0] = 0.0;
        }
    }
}
"#;
|
||||
|
||||
/// Batch cosine similarity - one query vs many candidates
///
/// One invocation per candidate (`gid.x`); each thread loops over the whole
/// dimension, accumulating the dot product and both squared norms, and writes
/// its score to `results[candidate_idx]`. Candidates are laid out flat as
/// `candidate_idx * dimension + i`. A near-zero norm product yields 0.0.
pub const SHADER_BATCH_COSINE_SIMILARITY: &str = r#"
struct Params {
    dimension: u32,
    num_candidates: u32,
}

@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn batch_cosine_similarity(@builtin(global_invocation_id) gid: vec3<u32>) {
    let candidate_idx = gid.x;

    if (candidate_idx >= params.num_candidates) {
        return;
    }

    let base = candidate_idx * params.dimension;

    var dot: f32 = 0.0;
    var norm_a: f32 = 0.0;
    var norm_b: f32 = 0.0;

    for (var i = 0u; i < params.dimension; i++) {
        let a = query[i];
        let b = candidates[base + i];
        dot += a * b;
        norm_a += a * a;
        norm_b += b * b;
    }

    let norm_product = sqrt(norm_a * norm_b);
    if (norm_product > 1e-12) {
        results[candidate_idx] = dot / norm_product;
    } else {
        results[candidate_idx] = 0.0;
    }
}
"#;
|
||||
|
||||
/// Dot product computation
///
/// One invocation per candidate (`gid.x`); each thread accumulates
/// `query[i] * candidates[base + i]` over the full dimension and writes the
/// sum to `results[candidate_idx]`. Same flat candidate layout as the batch
/// cosine shader.
pub const SHADER_DOT_PRODUCT: &str = r#"
struct Params {
    dimension: u32,
    num_candidates: u32,
}

@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn dot_product(@builtin(global_invocation_id) gid: vec3<u32>) {
    let candidate_idx = gid.x;

    if (candidate_idx >= params.num_candidates) {
        return;
    }

    let base = candidate_idx * params.dimension;

    var dot: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        dot += query[i] * candidates[base + i];
    }

    results[candidate_idx] = dot;
}
"#;
|
||||
|
||||
/// Euclidean distance computation
///
/// One invocation per candidate (`gid.x`); accumulates squared component
/// differences against `query` and writes `sqrt(sum)` to
/// `results[candidate_idx]`. Same flat candidate layout as the other batch
/// shaders.
pub const SHADER_EUCLIDEAN_DISTANCE: &str = r#"
struct Params {
    dimension: u32,
    num_candidates: u32,
}

@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn euclidean_distance(@builtin(global_invocation_id) gid: vec3<u32>) {
    let candidate_idx = gid.x;

    if (candidate_idx >= params.num_candidates) {
        return;
    }

    let base = candidate_idx * params.dimension;

    var sum_sq: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        let diff = query[i] - candidates[base + i];
        sum_sq += diff * diff;
    }

    results[candidate_idx] = sqrt(sum_sq);
}
"#;
|
||||
|
||||
/// L2 normalization
///
/// One invocation per vector (`gid.x`); computes the vector's L2 norm, then
/// writes the scaled (or, for near-zero norms, unmodified) components to the
/// output buffer. Binding(1) `_dummy` exists only to keep the 4-binding
/// layout shared with the other shaders; it is never read.
pub const SHADER_L2_NORMALIZE: &str = r#"
struct Params {
    dimension: u32,
    num_vectors: u32,
}

@group(0) @binding(0) var<storage, read> input_vectors: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output_vectors: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn l2_normalize(@builtin(global_invocation_id) gid: vec3<u32>) {
    let vec_idx = gid.x;

    if (vec_idx >= params.num_vectors) {
        return;
    }

    let base = vec_idx * params.dimension;

    // Compute norm
    var norm_sq: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        let val = input_vectors[base + i];
        norm_sq += val * val;
    }

    let norm = sqrt(norm_sq);

    // Normalize and write to output
    if (norm > 1e-12) {
        for (var i = 0u; i < params.dimension; i++) {
            output_vectors[base + i] = input_vectors[base + i] / norm;
        }
    } else {
        for (var i = 0u; i < params.dimension; i++) {
            output_vectors[base + i] = input_vectors[base + i];
        }
    }
}
"#;
|
||||
|
||||
/// Mean pooling over sequence
///
/// One invocation per (batch, hidden) pair, recovered from `gid.x` by
/// division/modulo against `hidden_size`. Averages each hidden component over
/// the sequence positions whose attention mask is 1; a fully masked sequence
/// produces 0.0. Token layout: `batch * seq * hidden + seq_pos * hidden + h`.
pub const SHADER_MEAN_POOL: &str = r#"
struct Params {
    batch_size: u32,
    seq_length: u32,
    hidden_size: u32,
}

@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> attention_mask: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(64)
fn mean_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_idx = gid.x / params.hidden_size;
    let hidden_idx = gid.x % params.hidden_size;

    if (batch_idx >= params.batch_size) {
        return;
    }

    let tokens_base = batch_idx * params.seq_length * params.hidden_size;
    let mask_base = batch_idx * params.seq_length;

    var sum: f32 = 0.0;
    var count: f32 = 0.0;

    for (var i = 0u; i < params.seq_length; i++) {
        if (attention_mask[mask_base + i] == 1) {
            sum += tokens[tokens_base + i * params.hidden_size + hidden_idx];
            count += 1.0;
        }
    }

    let out_idx = batch_idx * params.hidden_size + hidden_idx;
    if (count > 0.0) {
        output[out_idx] = sum / count;
    } else {
        output[out_idx] = 0.0;
    }
}
"#;
|
||||
|
||||
/// Max pooling over sequence
///
/// One invocation per (batch, hidden) pair, same index decomposition as the
/// mean-pool shader. Takes the maximum of each hidden component across
/// unmasked sequence positions; a fully masked sequence yields 0.0 (the
/// `found` flag distinguishes "no unmasked token" from a real maximum, so
/// -FLT_MAX never leaks into the output).
pub const SHADER_MAX_POOL: &str = r#"
struct Params {
    batch_size: u32,
    seq_length: u32,
    hidden_size: u32,
}

@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> attention_mask: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(64)
fn max_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_idx = gid.x / params.hidden_size;
    let hidden_idx = gid.x % params.hidden_size;

    if (batch_idx >= params.batch_size) {
        return;
    }

    let tokens_base = batch_idx * params.seq_length * params.hidden_size;
    let mask_base = batch_idx * params.seq_length;

    var max_val: f32 = -3.402823e+38; // -FLT_MAX
    var found: bool = false;

    for (var i = 0u; i < params.seq_length; i++) {
        if (attention_mask[mask_base + i] == 1) {
            let val = tokens[tokens_base + i * params.hidden_size + hidden_idx];
            if (!found || val > max_val) {
                max_val = val;
                found = true;
            }
        }
    }

    let out_idx = batch_idx * params.hidden_size + hidden_idx;
    output[out_idx] = select(0.0, max_val, found);
}
"#;
|
||||
|
||||
/// CLS token pooling (first token)
///
/// Copies the hidden vector of sequence position 0 of each batch row into the
/// `[batch, hidden]` output. Binding (1) `_dummy` is read but unused —
/// presumably it keeps the 4-binding group layout uniform across all kernels
/// in this registry (confirm against the bind-group setup code).
pub const SHADER_CLS_POOL: &str = r#"
struct Params {
    batch_size: u32,
    seq_length: u32,
    hidden_size: u32,
}

@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(64)
fn cls_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_idx = gid.x / params.hidden_size;
    let hidden_idx = gid.x % params.hidden_size;

    if (batch_idx >= params.batch_size) {
        return;
    }

    // CLS is first token
    let tokens_base = batch_idx * params.seq_length * params.hidden_size;
    let out_idx = batch_idx * params.hidden_size + hidden_idx;

    output[out_idx] = tokens[tokens_base + hidden_idx];
}
"#;
|
||||
|
||||
/// Matrix-vector multiplication
///
/// Computes `result = matrix * vector` for a row-major `rows x cols` matrix:
/// each invocation reduces one row into one output element.
///
/// NOTE(review): the workgroup is declared (16, 16) but only `gid.x` is used,
/// so every invocation along y recomputes (and rewrites) the same row. The
/// result is still correct, just redundant — confirm the dispatch shape on
/// the host side, or collapse this to a 1-D workgroup.
pub const SHADER_MATMUL: &str = r#"
struct Params {
    rows: u32,
    cols: u32,
}

@group(0) @binding(0) var<storage, read> matrix: array<f32>;
@group(0) @binding(1) var<storage, read> vector: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(16, 16)
fn matmul(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;

    if (row >= params.rows) {
        return;
    }

    var sum: f32 = 0.0;
    for (var col = 0u; col < params.cols; col++) {
        sum += matrix[row * params.cols + col] * vector[col];
    }

    result[row] = sum;
}
"#;
|
||||
|
||||
/// Vector addition
///
/// Element-wise `result[i] = a[i] + b[i]` over `length` elements, one
/// invocation per element; out-of-range invocations return early.
pub const SHADER_VECTOR_ADD: &str = r#"
struct Params {
    length: u32,
}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn vector_add(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;

    if (idx >= params.length) {
        return;
    }

    result[idx] = a[idx] + b[idx];
}
"#;
|
||||
|
||||
/// Vector scaling
///
/// Element-wise `output[i] = input[i] * scale` over `length` elements.
/// Binding (1) `_dummy` is unused — presumably padding to keep the 4-binding
/// group layout consistent with the other kernels (confirm on the host side).
pub const SHADER_VECTOR_SCALE: &str = r#"
struct Params {
    length: u32,
    scale: f32,
}

@group(0) @binding(0) var<storage, read> input_vector: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output_vector: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn vector_scale(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;

    if (idx >= params.length) {
        return;
    }

    output_vector[idx] = input_vector[idx] * params.scale;
}
"#;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in kernel must be retrievable from a fresh registry.
    #[test]
    fn test_shader_registry() {
        let registry = ShaderRegistry::new();

        // Check all built-in shaders are registered
        for name in [
            "cosine_similarity",
            "batch_cosine_similarity",
            "dot_product",
            "euclidean_distance",
            "l2_normalize",
            "mean_pool",
            "max_pool",
            "cls_pool",
            "matmul",
            "vector_add",
            "vector_scale",
        ] {
            assert!(registry.get(name).is_some());
        }
    }

    /// A built-in module carries real WGSL and the expected entry point.
    #[test]
    fn test_shader_content() {
        let registry = ShaderRegistry::new();
        let cosine = registry.get("cosine_similarity").unwrap();

        for needle in ["@compute", "workgroup_size"] {
            assert!(cosine.source.contains(needle));
        }
        assert_eq!(cosine.entry_point, "cosine_similarity");
    }

    /// User-registered modules round-trip through the registry by name.
    #[test]
    fn test_custom_shader() {
        let mut registry = ShaderRegistry::new();

        let module = ShaderModule {
            name: "custom_op".to_string(),
            source: "// custom shader".to_string(),
            entry_point: "custom".to_string(),
            workgroup_size: [128, 1, 1],
        };
        registry.register(module);

        assert!(registry.get("custom_op").is_some());
    }
}
|
||||
424
vendor/ruvector/examples/onnx-embeddings/src/gpu/tests.rs
vendored
Normal file
424
vendor/ruvector/examples/onnx-embeddings/src/gpu/tests.rs
vendored
Normal file
@@ -0,0 +1,424 @@
|
||||
//! GPU Module Tests
|
||||
//!
|
||||
//! Comprehensive tests for GPU acceleration functionality.
|
||||
|
||||
use super::*;
|
||||
use super::config::{GpuConfig, GpuMode, PowerPreference, GpuMemoryStats};
|
||||
use super::backend::CpuBackend;
|
||||
use super::shaders::ShaderModule;
|
||||
|
||||
// ==================== Configuration Tests ====================
|
||||
|
||||
#[test]
fn test_gpu_config_default() {
    // Defaults: auto mode, high-performance adapter, 256-wide workgroups,
    // with CPU fallback and shader caching both enabled.
    let cfg = GpuConfig::default();

    assert!(matches!(cfg.mode, GpuMode::Auto));
    assert!(matches!(cfg.power_preference, PowerPreference::HighPerformance));
    assert_eq!(cfg.workgroup_size, 256);
    assert!(cfg.fallback_to_cpu);
    assert!(cfg.cache_shaders);
}

#[test]
fn test_gpu_config_builder() {
    // Every `with_*` setter must be reflected in the resulting config.
    let cfg = GpuConfig::auto()
        .with_mode(GpuMode::WebGpu)
        .with_power_preference(PowerPreference::LowPower)
        .with_workgroup_size(512)
        .with_min_batch_size(32)
        .with_min_dimension(256)
        .with_profiling(true);

    assert!(matches!(cfg.mode, GpuMode::WebGpu));
    assert!(matches!(cfg.power_preference, PowerPreference::LowPower));
    assert_eq!(cfg.workgroup_size, 512);
    assert_eq!(cfg.min_batch_size, 32);
    assert_eq!(cfg.min_dimension, 256);
    assert!(cfg.enable_profiling);
}

#[test]
fn test_should_use_gpu() {
    let cfg = GpuConfig::default()
        .with_min_batch_size(16)
        .with_min_dimension(128);

    // GPU is only worthwhile above BOTH thresholds.
    assert!(!cfg.should_use_gpu(8, 384)); // batch too small
    assert!(!cfg.should_use_gpu(32, 64)); // dimension too small
    assert!(cfg.should_use_gpu(32, 384)); // both thresholds met

    // CPU-only mode never elects the GPU, regardless of workload size.
    assert!(!GpuConfig::cpu_only().should_use_gpu(1000, 1000));
}

#[test]
fn test_preset_configs() {
    // high_performance: wide workgroups, eager GPU use.
    let hp = GpuConfig::high_performance();
    assert_eq!(hp.workgroup_size, 512);
    assert_eq!(hp.min_batch_size, 8);

    // low_power: battery-friendly adapter with narrow workgroups.
    let lp = GpuConfig::low_power();
    assert!(matches!(lp.power_preference, PowerPreference::LowPower));
    assert_eq!(lp.workgroup_size, 128);

    // cpu_only: pinned to the CPU path.
    assert!(matches!(GpuConfig::cpu_only().mode, GpuMode::CpuOnly));
}
|
||||
|
||||
// ==================== Shader Tests ====================
|
||||
|
||||
#[test]
fn test_shader_registry_initialization() {
    let registry = ShaderRegistry::new();

    // Every built-in kernel must be present in a fresh registry.
    for name in [
        "cosine_similarity",
        "batch_cosine_similarity",
        "dot_product",
        "euclidean_distance",
        "l2_normalize",
        "mean_pool",
        "max_pool",
        "cls_pool",
        "matmul",
        "vector_add",
        "vector_scale",
    ] {
        assert!(registry.get(name).is_some(), "Missing shader: {}", name);
    }
}

#[test]
fn test_shader_module_content() {
    let registry = ShaderRegistry::new();

    // Cosine similarity: a real compute kernel with the expected entry point.
    let cosine = registry.get("cosine_similarity").unwrap();
    for needle in ["@compute", "workgroup_size", "cosine_similarity"] {
        assert!(cosine.source.contains(needle));
    }
    assert_eq!(cosine.entry_point, "cosine_similarity");
    assert_eq!(cosine.workgroup_size, [256, 1, 1]);

    // Mean pooling: must consume the attention mask and the hidden size.
    let mean_pool = registry.get("mean_pool").unwrap();
    assert!(mean_pool.source.contains("attention_mask"));
    assert!(mean_pool.source.contains("hidden_size"));
    assert_eq!(mean_pool.entry_point, "mean_pool");
}

#[test]
fn test_custom_shader_registration() {
    let mut registry = ShaderRegistry::new();

    registry.register(ShaderModule {
        name: "custom_kernel".to_string(),
        source: "@compute @workgroup_size(64) fn custom() {}".to_string(),
        entry_point: "custom".to_string(),
        workgroup_size: [64, 1, 1],
    });

    // Round-trip: the module comes back under its registered name.
    assert!(registry.get("custom_kernel").is_some());
    assert_eq!(registry.get("custom_kernel").unwrap().entry_point, "custom");
}
|
||||
|
||||
// ==================== Batch Operations Tests ====================
|
||||
|
||||
#[test]
fn test_batch_cosine_similarity() {
    // Identical, orthogonal, and opposite vectors relative to the query.
    let query = vec![1.0, 0.0, 0.0];
    let same = [1.0f32, 0.0, 0.0];
    let ortho = [0.0f32, 1.0, 0.0];
    let opposite = [-1.0f32, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![&same[..], &ortho[..], &opposite[..]];

    let results = batch_cosine_similarity_gpu(&query, &candidates);

    assert_eq!(results.len(), 3);
    // Expected similarities: 1.0, 0.0, -1.0 respectively.
    for (got, want) in results.iter().zip([1.0f32, 0.0, -1.0]) {
        assert!((got - want).abs() < 1e-6);
    }
}

#[test]
fn test_batch_dot_product() {
    let query = vec![1.0, 1.0, 1.0];
    let ones = [1.0f32, 1.0, 1.0];
    let twos = [2.0f32, 2.0, 2.0];
    let zeros = [0.0f32, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![&ones[..], &twos[..], &zeros[..]];

    let results = batch_dot_product_gpu(&query, &candidates);

    assert_eq!(results.len(), 3);
    // Expected dot products: 3.0, 6.0, 0.0 respectively.
    for (got, want) in results.iter().zip([3.0f32, 6.0, 0.0]) {
        assert!((got - want).abs() < 1e-6);
    }
}

#[test]
fn test_batch_euclidean() {
    // Distances from the origin: a 3-4-5 triangle, a unit step, and zero.
    let query = vec![0.0, 0.0, 0.0];
    let triangle = [3.0f32, 4.0, 0.0];
    let unit = [1.0f32, 0.0, 0.0];
    let origin = [0.0f32, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![&triangle[..], &unit[..], &origin[..]];

    let results = batch_euclidean_gpu(&query, &candidates);

    assert_eq!(results.len(), 3);
    for (got, want) in results.iter().zip([5.0f32, 1.0, 0.0]) {
        assert!((got - want).abs() < 1e-6);
    }
}
|
||||
|
||||
// ==================== Pooling Tests (using public API) ====================
|
||||
|
||||
#[test]
fn test_mean_pool_via_api() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &registry).unwrap();

    // batch=2, seq=2, hidden=3 — values 1..=12 laid out row-major.
    let tokens: Vec<f32> = (1..=12).map(|v| v as f32).collect();
    // No padding anywhere: every position is a real token.
    let mask = vec![1i64; 4];

    let pooled = pooler.mean_pool(&tokens, &mask, 2, 2, 3).unwrap();

    assert_eq!(pooled.len(), 6);
    // Batch 0 averages [1,2,3] and [4,5,6] -> [2.5, 3.5, 4.5].
    for (got, want) in pooled.iter().take(3).zip([2.5f32, 3.5, 4.5]) {
        assert!((got - want).abs() < 1e-6);
    }
}

#[test]
fn test_cls_pool_via_api() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &registry).unwrap();

    // batch=2, seq=3, hidden=4; the first row of each batch is the CLS token.
    let tokens = vec![
        1.0, 2.0, 3.0, 4.0, // batch 0, CLS
        5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0,
        10.0, 20.0, 30.0, 40.0, // batch 1, CLS
        50.0, 60.0, 70.0, 80.0,
        90.0, 100.0, 110.0, 120.0,
    ];

    let pooled = pooler.cls_pool(&tokens, 2, 4).unwrap();

    assert_eq!(pooled.len(), 8);
    // Output is exactly the two CLS rows, concatenated.
    let expected = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    for (got, want) in pooled.iter().zip(expected) {
        assert!((got - want).abs() < 1e-6);
    }
}

#[test]
fn test_max_pool_via_api() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &registry).unwrap();

    // batch=1, seq=3, hidden=4.
    let tokens = vec![
        1.0, 10.0, 3.0, 4.0, // seq 0
        5.0, 2.0, 7.0, 8.0, // seq 1
        9.0, 6.0, 11.0, 0.0, // seq 2
    ];
    let mask = vec![1i64; 3];

    let pooled = pooler.max_pool(&tokens, &mask, 1, 3, 4).unwrap();

    assert_eq!(pooled.len(), 4);
    // Column-wise maxima: max(1,5,9), max(10,2,6), max(3,7,11), max(4,8,0).
    for (got, want) in pooled.iter().zip([9.0f32, 10.0, 11.0, 8.0]) {
        assert!((got - want).abs() < 1e-6);
    }
}
|
||||
|
||||
// ==================== Vector Operations Tests ====================
|
||||
|
||||
#[test]
fn test_normalize_batch() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &registry).unwrap();

    // Two 3-d vectors, both with L2 norm 5, normalized in place.
    let mut data = vec![3.0, 4.0, 0.0, 0.0, 0.0, 5.0];
    ops.normalize_batch(&mut data, 3).unwrap();

    let expected = [0.6f32, 0.8, 0.0, 0.0, 0.0, 1.0];
    for (got, want) in data.iter().zip(expected) {
        assert!((got - want).abs() < 1e-6);
    }
}

#[test]
fn test_matmul() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &registry).unwrap();

    // A 2x3 row-major matrix times the all-ones vector sums each row.
    let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let vector = vec![1.0, 1.0, 1.0];

    let product = ops.matmul(&matrix, &vector, 2, 3).unwrap();

    assert_eq!(product.len(), 2);
    assert!((product[0] - 6.0).abs() < 1e-6); // 1 + 2 + 3
    assert!((product[1] - 15.0).abs() < 1e-6); // 4 + 5 + 6
}

#[test]
fn test_batch_add() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &registry).unwrap();

    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];

    // Element-wise sum.
    let sum = ops.batch_add(&lhs, &rhs).unwrap();
    assert_eq!(sum, vec![6.0, 8.0, 10.0, 12.0]);
}

#[test]
fn test_batch_scale() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &registry).unwrap();

    // In-place doubling of every element.
    let mut data = vec![1.0, 2.0, 3.0, 4.0];
    ops.batch_scale(&mut data, 2.0).unwrap();

    assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
}
|
||||
|
||||
// ==================== Integration Tests ====================
|
||||
|
||||
#[test]
fn test_gpu_similarity_with_backend() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let similarity = GpuSimilarity::new(&backend, &registry).unwrap();

    let query = vec![1.0, 0.0, 0.0];
    let same = [1.0f32, 0.0, 0.0];
    let ortho = [0.0f32, 1.0, 0.0];
    let candidates: Vec<&[f32]> = vec![&same[..], &ortho[..]];

    let results = similarity.batch_cosine(&query, &candidates).unwrap();

    assert_eq!(results.len(), 2);
    // The identical vector scores a perfect 1.0.
    assert!((results[0] - 1.0).abs() < 1e-6);
}

#[test]
fn test_top_k_similar() {
    let backend = CpuBackend;
    let registry = ShaderRegistry::new();
    let similarity = GpuSimilarity::new(&backend, &registry).unwrap();

    let query = vec![1.0, 0.0, 0.0];
    let ortho = [0.0f32, 1.0, 0.0]; // sim = 0
    let same = [1.0f32, 0.0, 0.0]; // sim = 1 (best)
    let diagonal = [0.5f32, 0.5, 0.0]; // sim ≈ 0.707
    let opposite = [-1.0f32, 0.0, 0.0]; // sim = -1 (worst)
    let candidates: Vec<&[f32]> =
        vec![&ortho[..], &same[..], &diagonal[..], &opposite[..]];

    let top2 = similarity.top_k(&query, &candidates, 2).unwrap();

    // Best two, in order: the exact match (index 1), then the diagonal (index 2).
    assert_eq!(top2.len(), 2);
    assert_eq!(top2[0].0, 1);
    assert_eq!(top2[1].0, 2);
}
|
||||
|
||||
// ==================== Memory Stats Tests ====================
|
||||
|
||||
#[test]
fn test_memory_stats() {
    // A 1 GiB device with half in use should report ~50% usage.
    let gib = 1024 * 1024 * 1024;
    let stats = GpuMemoryStats {
        total: gib,
        used: gib / 2,
        free: gib / 2,
        peak: gib / 4 * 3, // 768 MiB high-water mark
    };

    assert!((stats.usage_percent() - 50.0).abs() < 0.1);
}

#[test]
fn test_empty_memory_stats() {
    // A zeroed stats struct must not divide by zero: usage is defined as 0%.
    assert_eq!(GpuMemoryStats::default().usage_percent(), 0.0);
}
|
||||
|
||||
// ==================== Backend Tests ====================
|
||||
|
||||
#[test]
fn test_cpu_backend_info() {
    // The CPU fallback is always available but advertises no compute support.
    let backend = CpuBackend;
    assert!(backend.is_available());

    let info = backend.device_info();
    assert_eq!(info.backend, "CPU");
    assert!(!info.supports_compute);
}
|
||||
187
vendor/ruvector/examples/onnx-embeddings/src/lib.rs
vendored
Normal file
187
vendor/ruvector/examples/onnx-embeddings/src/lib.rs
vendored
Normal file
@@ -0,0 +1,187 @@
|
||||
//! # RuVector ONNX Embeddings
|
||||
//!
|
||||
//! A reimagined embedding pipeline for RuVector using ONNX Runtime in pure Rust.
|
||||
//!
|
||||
//! This crate provides:
|
||||
//! - Native ONNX model inference for embedding generation
|
||||
//! - HuggingFace tokenizer integration
|
||||
//! - Batch processing with SIMD optimization
|
||||
//! - Direct RuVector vector database integration
|
||||
//! - Model management and caching
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌─────────────────────────────────────────────────────────────────┐
|
||||
//! │ RuVector ONNX Embeddings │
|
||||
//! ├─────────────────────────────────────────────────────────────────┤
|
||||
//! │ │
|
||||
//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
//! │ │ Text Input │ -> │ Tokenizer │ -> │ Token IDs │ │
|
||||
//! │ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
//! │ │ │
|
||||
//! │ v │
|
||||
//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
//! │ │ Embeddings │ <- │ ONNX Runtime │ <- │ Input Tensor │ │
|
||||
//! │ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
//! │ │ │
|
||||
//! │ v │
|
||||
//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
//! │ │ Normalize │ -> │ Mean Pooling │ -> │ RuVector DB │ │
|
||||
//! │ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
//! │ │
|
||||
//! └─────────────────────────────────────────────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_onnx_embeddings::{Embedder, EmbedderConfig, ModelSource};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! // Create embedder with default model (all-MiniLM-L6-v2)
|
||||
//! let embedder = Embedder::new(EmbedderConfig::default()).await?;
|
||||
//!
|
||||
//! // Generate embeddings
|
||||
//! let texts = vec!["Hello, world!", "Rust is awesome!"];
|
||||
//! let embeddings = embedder.embed(&texts)?;
|
||||
//!
|
||||
//! // Use with RuVector
|
||||
//! let db = embedder.create_ruvector_index("my_index")?;
|
||||
//! db.insert_with_embeddings(&texts, &embeddings)?;
|
||||
//!
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod config;
|
||||
pub mod embedder;
|
||||
pub mod error;
|
||||
pub mod model;
|
||||
pub mod pooling;
|
||||
pub mod ruvector_integration;
|
||||
pub mod tokenizer;
|
||||
|
||||
/// GPU acceleration module (optional, requires `gpu` feature)
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod gpu;
|
||||
|
||||
/// GPU module stub for when feature is disabled
#[cfg(not(feature = "gpu"))]
pub mod gpu {
    //! GPU acceleration is not available without the `gpu` feature.
    //!
    //! Enable with: `cargo build --features gpu`

    /// Placeholder for GpuConfig when GPU feature is disabled
    #[derive(Debug, Clone, Default)]
    pub struct GpuConfig;

    impl GpuConfig {
        /// Create default config (no-op without GPU feature)
        pub fn auto() -> Self {
            Self::default()
        }

        /// CPU-only config
        pub fn cpu_only() -> Self {
            Self::default()
        }
    }

    /// Check if GPU is available (always false without feature)
    pub async fn is_gpu_available() -> bool {
        false
    }
}
|
||||
|
||||
// Re-exports
|
||||
pub use config::{EmbedderConfig, ModelSource, PoolingStrategy};
|
||||
pub use embedder::{Embedder, EmbedderBuilder, EmbeddingOutput};
|
||||
pub use error::{EmbeddingError, Result};
|
||||
pub use model::{OnnxModel, ModelInfo};
|
||||
pub use pooling::Pooler;
|
||||
pub use ruvector_integration::{
|
||||
Distance, IndexConfig, RagPipeline, RuVectorBuilder, RuVectorEmbeddings, SearchResult, VectorId,
|
||||
};
|
||||
pub use tokenizer::Tokenizer;
|
||||
|
||||
// GPU exports (conditional)
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use gpu::{
|
||||
GpuAccelerator, GpuConfig, GpuMode, GpuInfo, GpuBackend,
|
||||
HybridAccelerator, is_gpu_available,
|
||||
};
|
||||
|
||||
/// Prelude module for convenient imports
///
/// Intended as a glob-import target:
/// `use ruvector_onnx_embeddings::prelude::*;`
pub mod prelude {
    pub use crate::{
        Distance, Embedder, EmbedderBuilder, EmbedderConfig, EmbeddingError,
        IndexConfig, ModelSource, PoolingStrategy, RagPipeline, Result,
        RuVectorBuilder, RuVectorEmbeddings, SearchResult, VectorId,
    };
}
|
||||
|
||||
/// Supported embedding models with pre-configured settings
///
/// Each variant maps to a HuggingFace repository via [`PretrainedModel::model_id`];
/// the per-model embedding dimension and sequence limit come from the
/// accessor methods on the `impl` block.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum PretrainedModel {
    /// all-MiniLM-L6-v2: 384 dimensions, fast inference
    #[default]
    AllMiniLmL6V2,
    /// all-MiniLM-L12-v2: 384 dimensions, better quality
    AllMiniLmL12V2,
    /// all-mpnet-base-v2: 768 dimensions, high quality
    AllMpnetBaseV2,
    /// multi-qa-MiniLM-L6: 384 dimensions, optimized for QA
    MultiQaMiniLmL6,
    /// paraphrase-MiniLM-L6-v2: 384 dimensions, paraphrase detection
    ParaphraseMiniLmL6V2,
    /// BGE-small-en-v1.5: 384 dimensions, BAAI General Embeddings
    BgeSmallEnV15,
    /// E5-small-v2: 384 dimensions, Microsoft E5 model
    E5SmallV2,
    /// GTE-small: 384 dimensions, Alibaba GTE model
    GteSmall,
}
|
||||
|
||||
impl PretrainedModel {
|
||||
/// Get the HuggingFace model ID
|
||||
pub fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
|
||||
Self::AllMiniLmL12V2 => "sentence-transformers/all-MiniLM-L12-v2",
|
||||
Self::AllMpnetBaseV2 => "sentence-transformers/all-mpnet-base-v2",
|
||||
Self::MultiQaMiniLmL6 => "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
Self::ParaphraseMiniLmL6V2 => "sentence-transformers/paraphrase-MiniLM-L6-v2",
|
||||
Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
|
||||
Self::E5SmallV2 => "intfloat/e5-small-v2",
|
||||
Self::GteSmall => "thenlper/gte-small",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
pub fn dimension(&self) -> usize {
|
||||
match self {
|
||||
Self::AllMiniLmL6V2
|
||||
| Self::AllMiniLmL12V2
|
||||
| Self::MultiQaMiniLmL6
|
||||
| Self::ParaphraseMiniLmL6V2
|
||||
| Self::BgeSmallEnV15
|
||||
| Self::E5SmallV2
|
||||
| Self::GteSmall => 384,
|
||||
Self::AllMpnetBaseV2 => 768,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recommended max sequence length
|
||||
pub fn max_seq_length(&self) -> usize {
|
||||
match self {
|
||||
Self::AllMiniLmL6V2
|
||||
| Self::AllMiniLmL12V2
|
||||
| Self::MultiQaMiniLmL6
|
||||
| Self::ParaphraseMiniLmL6V2 => 256,
|
||||
Self::AllMpnetBaseV2 => 384,
|
||||
Self::BgeSmallEnV15 | Self::E5SmallV2 | Self::GteSmall => 512,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the model requires normalized outputs
|
||||
pub fn normalize_output(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
265
vendor/ruvector/examples/onnx-embeddings/src/main.rs
vendored
Normal file
265
vendor/ruvector/examples/onnx-embeddings/src/main.rs
vendored
Normal file
@@ -0,0 +1,265 @@
|
||||
//! RuVector ONNX Embeddings - Example Usage
|
||||
//!
|
||||
//! This example demonstrates how to use ONNX-based embedding generation
|
||||
//! with RuVector for semantic search and RAG pipelines.
|
||||
|
||||
use anyhow::Result;
|
||||
use ruvector_onnx_embeddings::{
|
||||
prelude::*, EmbedderBuilder, PretrainedModel, PoolingStrategy,
|
||||
RuVectorBuilder, RagPipeline, Distance,
|
||||
};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize logging: honor RUST_LOG if set, otherwise default to "info".
    // Must happen before any example runs so their tracing output is captured.
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "info".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    println!("╔═══════════════════════════════════════════════════════════════╗");
    println!("║ RuVector ONNX Embeddings - Reimagined for Rust ║");
    println!("╚═══════════════════════════════════════════════════════════════╝");
    println!();

    // Run examples sequentially; `?` aborts the demo on the first failure.
    basic_embedding_example().await?;
    batch_embedding_example().await?;
    semantic_search_example().await?;
    rag_pipeline_example().await?;
    clustering_example().await?;

    println!("\n✅ All examples completed successfully!");

    Ok(())
}
|
||||
|
||||
/// Basic embedding generation
///
/// Demonstrates single-text embedding and pairwise similarity scoring.
/// Downloads/loads the default model on first use, so it may touch the network.
async fn basic_embedding_example() -> Result<()> {
    println!("\n━━━ Example 1: Basic Embedding Generation ━━━");

    // Create embedder with default model (all-MiniLM-L6-v2)
    let mut embedder = Embedder::default_model().await?;

    println!("Model: {}", embedder.model_info().name);
    println!("Dimension: {}", embedder.dimension());

    // Embed a single sentence
    let text = "The quick brown fox jumps over the lazy dog.";
    let embedding = embedder.embed_one(text)?;

    println!("Input: \"{}\"", text);
    println!("Embedding shape: [{}]", embedding.len());
    // NOTE(review): indexing [0..=4] assumes dimension >= 5 — true for every
    // pre-configured model (384+), but would panic on a tiny custom model.
    println!(
        "First 5 values: [{:.4}, {:.4}, {:.4}, {:.4}, {:.4}]",
        embedding[0], embedding[1], embedding[2], embedding[3], embedding[4]
    );

    // Compute similarity between two sentences: a semantically related pair
    // and an unrelated pair, to show the score contrast.
    let text1 = "I love programming in Rust.";
    let text2 = "Rust is my favorite programming language.";
    let text3 = "The weather is nice today.";

    let sim_related = embedder.similarity(text1, text2)?;
    let sim_unrelated = embedder.similarity(text1, text3)?;

    println!("\nSimilarity comparisons:");
    println!(" \"{}\"\n vs\n \"{}\"", text1, text2);
    println!(" Similarity: {:.4}", sim_related);
    println!();
    println!(" \"{}\"\n vs\n \"{}\"", text1, text3);
    println!(" Similarity: {:.4}", sim_unrelated);

    Ok(())
}
|
||||
|
||||
/// Batch embedding with parallel processing
|
||||
async fn batch_embedding_example() -> Result<()> {
|
||||
println!("\n━━━ Example 2: Batch Embedding ━━━");
|
||||
|
||||
// Create embedder with custom configuration
|
||||
let mut embedder = EmbedderBuilder::new()
|
||||
.pretrained(PretrainedModel::AllMiniLmL6V2)
|
||||
.pooling(PoolingStrategy::Mean)
|
||||
.normalize(true)
|
||||
.batch_size(64)
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
let texts = vec![
|
||||
"Artificial intelligence is transforming technology.",
|
||||
"Machine learning models learn from data.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Natural language processing understands text.",
|
||||
"Computer vision analyzes images.",
|
||||
"Reinforcement learning optimizes decisions.",
|
||||
"Vector databases enable semantic search.",
|
||||
"Embeddings capture semantic meaning.",
|
||||
];
|
||||
|
||||
println!("Embedding {} texts...", texts.len());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let output = embedder.embed(&texts)?;
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
println!("Completed in {:?}", elapsed);
|
||||
println!("Total embeddings: {}", output.len());
|
||||
println!("Embedding dimension: {}", output.dimension);
|
||||
|
||||
// Show token counts
|
||||
println!("\nToken counts per text:");
|
||||
for (i, (text, tokens)) in texts.iter().zip(output.token_counts.iter()).enumerate() {
|
||||
println!(" [{}] {} tokens: \"{}...\"", i, tokens, &text[..40.min(text.len())]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Semantic search with RuVector
///
/// Indexes a small knowledge base of programming-language facts and runs
/// natural-language queries against it, printing the top-3 hits per query.
async fn semantic_search_example() -> Result<()> {
    println!("\n━━━ Example 3: Semantic Search with RuVector ━━━");

    // Create embedder
    let embedder = Embedder::default_model().await?;

    // Create RuVector index; cosine distance matches the normalized embeddings.
    let index = RuVectorBuilder::new("semantic_search")
        .embedder(embedder)
        .distance(Distance::Cosine)
        .max_elements(10_000)
        .build()?;

    // Knowledge base about programming languages
    let documents = vec![
        "Rust is a systems programming language focused on safety and performance.",
        "Python is widely used for machine learning and data science applications.",
        "JavaScript is the language of the web, running in browsers everywhere.",
        "Go is designed for building scalable and efficient server applications.",
        "TypeScript adds static typing to JavaScript for better developer experience.",
        "C++ provides low-level control and high performance for system software.",
        "Java is a mature, object-oriented language popular in enterprise software.",
        "Swift is Apple's modern language for iOS and macOS development.",
        "Kotlin is a concise language that runs on the JVM, popular for Android.",
        "Haskell is a purely functional programming language with strong typing.",
    ];

    println!("Indexing {} documents...", documents.len());
    index.insert_batch(&documents)?;

    println!("Index size: {} vectors", index.len());

    // Perform searches: each query is embedded and matched against the index.
    let queries = vec![
        "What language is best for web development?",
        "I want to build a high-performance system application",
        "Which language should I learn for machine learning?",
        "I need a language for mobile app development",
    ];

    for query in queries {
        println!("\n🔍 Query: \"{}\"", query);
        let results = index.search(query, 3)?;

        for (i, result) in results.iter().enumerate() {
            println!(
                " {}. (score: {:.4}) {}",
                i + 1,
                result.score,
                result.text
            );
        }
    }

    Ok(())
}
|
||||
|
||||
/// RAG (Retrieval-Augmented Generation) pipeline
///
/// Loads a small RuVector knowledge base into a `RagPipeline` (top-3
/// retrieval) and, for each question, prints the retrieved context block
/// that would be handed to a language model.
async fn rag_pipeline_example() -> Result<()> {
    println!("\n━━━ Example 4: RAG Pipeline ━━━");

    let embedder = Embedder::default_model().await?;

    // Default-configured index; `3` is the number of documents retrieved
    // per question.
    let index = RuVectorEmbeddings::new_default("rag_index", embedder)?;
    let rag = RagPipeline::new(index, 3);

    // Add knowledge base
    let knowledge = vec![
        "RuVector is a distributed vector database that learns and adapts.",
        "RuVector uses HNSW indexing for fast approximate nearest neighbor search.",
        "The embedding dimension in RuVector is configurable based on your model.",
        "RuVector supports multiple distance metrics: Cosine, Euclidean, and Dot Product.",
        "Graph Neural Networks in RuVector improve search quality over time.",
        "RuVector integrates with ONNX models for native embedding generation.",
        "The NAPI-RS bindings allow using RuVector from Node.js applications.",
        "RuVector supports WebAssembly for running in web browsers.",
        "Raft consensus enables distributed deployment of RuVector clusters.",
        "Quantization in RuVector provides 2-32x memory compression.",
    ];

    println!("Loading {} documents into RAG pipeline...", knowledge.len());
    rag.add_documents(&knowledge)?;

    // Generate context for questions: retrieval + formatting only, no LLM call.
    let questions = vec![
        "How does RuVector achieve fast search?",
        "Can I use RuVector in a web browser?",
        "What compression options does RuVector have?",
    ];

    for question in questions {
        println!("\n❓ Question: {}", question);
        let context = rag.format_context(question)?;
        println!("Generated Context:\n{}", context);
        // Visual separator between questions.
        println!("{}", "─".repeat(60));
    }

    Ok(())
}
|
||||
|
||||
/// Text clustering example
///
/// Embeds nine texts drawn from three topics (technology, sports, food),
/// asks the embedder to cluster them into 3 groups, and prints each
/// cluster's members.
async fn clustering_example() -> Result<()> {
    println!("\n━━━ Example 5: Text Clustering ━━━");

    // `mut` is required because `cluster` below takes `&mut self`.
    let mut embedder = Embedder::default_model().await?;

    // Texts from different categories — three per topic, so a good
    // clustering should recover the topic boundaries.
    let texts = vec![
        // Technology
        "Artificial intelligence is revolutionizing industries.",
        "Machine learning algorithms process large datasets.",
        "Neural networks mimic the human brain.",
        // Sports
        "Football is the most popular sport worldwide.",
        "Basketball requires speed and agility.",
        "Tennis is played on different court surfaces.",
        // Food
        "Italian pasta comes in many shapes and sizes.",
        "Sushi is a traditional Japanese dish.",
        "French cuisine is known for its elegance.",
    ];

    println!("Clustering {} texts into 3 categories...", texts.len());

    // `clusters[i]` is the cluster id assigned to `texts[i]`.
    let clusters = embedder.cluster(&texts, 3)?;

    // Group texts by cluster id for display.
    let mut groups: std::collections::HashMap<usize, Vec<&str>> = std::collections::HashMap::new();
    for (i, &cluster) in clusters.iter().enumerate() {
        groups.entry(cluster).or_default().push(texts[i]);
    }

    println!("\nCluster assignments:");
    for (cluster_id, members) in groups.iter() {
        println!("\n📁 Cluster {}:", cluster_id);
        for text in members {
            println!(" • {}", text);
        }
    }

    Ok(())
}
|
||||
470
vendor/ruvector/examples/onnx-embeddings/src/model.rs
vendored
Normal file
470
vendor/ruvector/examples/onnx-embeddings/src/model.rs
vendored
Normal file
@@ -0,0 +1,470 @@
|
||||
//! ONNX model loading and management
|
||||
|
||||
use crate::config::{EmbedderConfig, ExecutionProvider, ModelSource};
|
||||
use crate::{EmbeddingError, PretrainedModel, Result};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
/// Information about a loaded model
///
/// Populated by `OnnxModel::extract_model_info` when a session is created.
/// Note that `dimension` and `max_seq_length` are defaults (see that
/// function), not values read from the model graph.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model name or identifier (derived from the model file stem)
    pub name: String,
    /// Embedding dimension (defaulted to 384; the true value is only known
    /// at inference time from the output tensor shape)
    pub dimension: usize,
    /// Maximum sequence length (defaulted to 512)
    pub max_seq_length: usize,
    /// Model file size in bytes
    pub file_size: u64,
    /// Model input names, as reported by the ONNX session
    pub input_names: Vec<String>,
    /// Model output names, as reported by the ONNX session
    pub output_names: Vec<String>,
}
|
||||
|
||||
/// ONNX model wrapper with inference capabilities
///
/// Owns the ONNX Runtime session plus the metadata captured at load time.
/// Construct via `from_config` / `from_file` / `from_pretrained` /
/// `from_huggingface` / `from_url`; run inference with [`OnnxModel::run`].
pub struct OnnxModel {
    // ort session; `run` needs `&mut self` because `Session::run` is `&mut`.
    session: Session,
    // Static metadata captured when the session was created.
    info: ModelInfo,
}
|
||||
|
||||
impl OnnxModel {
    /// Load model from configuration
    ///
    /// Dispatches on `config.model_source`; every variant ultimately goes
    /// through [`Self::from_file`] once an on-disk `.onnx` file exists.
    #[instrument(skip_all)]
    pub async fn from_config(config: &EmbedderConfig) -> Result<Self> {
        match &config.model_source {
            // Tokenizer paths/URLs are intentionally ignored here; this type
            // only manages the model file.
            ModelSource::Local {
                model_path,
                tokenizer_path: _,
            } => Self::from_file(model_path, config).await,

            ModelSource::Pretrained(model) => Self::from_pretrained(*model, config).await,

            ModelSource::HuggingFace { model_id, revision } => {
                Self::from_huggingface(model_id, revision.as_deref(), config).await
            }

            ModelSource::Url {
                model_url,
                tokenizer_url: _,
            } => Self::from_url(model_url, config).await,
        }
    }

    /// Load model from a local ONNX file
    ///
    /// # Errors
    /// Returns `EmbeddingError::model_not_found` when `path` does not exist,
    /// or propagates I/O / ONNX Runtime errors from session creation.
    #[instrument(skip_all, fields(path = %path.as_ref().display()))]
    pub async fn from_file(path: impl AsRef<Path>, config: &EmbedderConfig) -> Result<Self> {
        let path = path.as_ref();
        info!("Loading ONNX model from file: {}", path.display());

        if !path.exists() {
            return Err(EmbeddingError::model_not_found(path.display().to_string()));
        }

        let file_size = fs::metadata(path)?.len();
        let session = Self::create_session(path, config)?;
        let info = Self::extract_model_info(&session, path, file_size)?;

        Ok(Self { session, info })
    }

    /// Load a pretrained model (downloads if not cached)
    ///
    /// Cache layout: `<cache_dir>/<sanitized model id>/model.onnx`.
    #[instrument(skip_all, fields(model = ?model))]
    pub async fn from_pretrained(model: PretrainedModel, config: &EmbedderConfig) -> Result<Self> {
        let model_id = model.model_id();
        info!("Loading pretrained model: {}", model_id);

        // Check cache first
        let cache_path = config.cache_dir.join(sanitize_model_id(model_id));
        let model_path = cache_path.join("model.onnx");

        if model_path.exists() {
            debug!("Found cached model at {}", model_path.display());
            return Self::from_file(&model_path, config).await;
        }

        // Download from HuggingFace (pretrained ids are HF repo ids).
        Self::from_huggingface(model_id, None, config).await
    }

    /// Load model from HuggingFace Hub
    ///
    /// Downloads `model.onnx` into the per-model cache directory on first
    /// use; subsequent calls load straight from the cache.
    #[instrument(skip_all, fields(model_id = %model_id))]
    pub async fn from_huggingface(
        model_id: &str,
        revision: Option<&str>,
        config: &EmbedderConfig,
    ) -> Result<Self> {
        let cache_path = config.cache_dir.join(sanitize_model_id(model_id));
        fs::create_dir_all(&cache_path)?;

        let model_path = cache_path.join("model.onnx");

        if !model_path.exists() {
            info!("Downloading model from HuggingFace: {}", model_id);
            download_from_huggingface(model_id, revision, &cache_path, config.show_progress)
                .await?;
        }

        Self::from_file(&model_path, config).await
    }

    /// Load model from a URL
    ///
    /// The cache directory name is derived from a hash of the URL, so
    /// different URLs never collide in the cache.
    #[instrument(skip_all, fields(url = %url))]
    pub async fn from_url(url: &str, config: &EmbedderConfig) -> Result<Self> {
        let hash = hash_url(url);
        let cache_path = config.cache_dir.join(&hash);
        fs::create_dir_all(&cache_path)?;

        let model_path = cache_path.join("model.onnx");

        if !model_path.exists() {
            info!("Downloading model from URL: {}", url);
            download_file(url, &model_path, config.show_progress).await?;
        }

        Self::from_file(&model_path, config).await
    }

    /// Create an ONNX session with the specified configuration
    ///
    /// Applies graph optimization, thread count, and the requested
    /// execution provider. Providers gated behind disabled cargo features
    /// fall through to the `_` arm and run on CPU with a warning.
    fn create_session(path: &Path, config: &EmbedderConfig) -> Result<Session> {
        let mut builder = Session::builder()?;

        // Set optimization level
        if config.optimize_graph {
            builder = builder.with_optimization_level(GraphOptimizationLevel::Level3)?;
        }

        // Set number of threads (intra-op parallelism)
        builder = builder.with_intra_threads(config.num_threads)?;

        // Configure execution provider
        match config.execution_provider {
            ExecutionProvider::Cpu => {
                // Default CPU provider — nothing to configure.
            }
            #[cfg(feature = "cuda")]
            ExecutionProvider::Cuda { device_id } => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::CUDAExecutionProvider::default()
                        .with_device_id(device_id)
                        .build(),
                ])?;
            }
            #[cfg(feature = "tensorrt")]
            ExecutionProvider::TensorRt { device_id } => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::TensorRTExecutionProvider::default()
                        .with_device_id(device_id)
                        .build(),
                ])?;
            }
            #[cfg(feature = "coreml")]
            ExecutionProvider::CoreMl => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::CoreMLExecutionProvider::default().build(),
                ])?;
            }
            // Variants whose feature is not compiled in end up here.
            _ => {
                warn!(
                    "Requested execution provider not available, falling back to CPU"
                );
            }
        }

        let session = builder.commit_from_file(path)?;
        Ok(session)
    }

    /// Extract model information from the session
    ///
    /// Only input/output names and the file size come from the session;
    /// `dimension` (384) and `max_seq_length` (512) are hard-coded defaults
    /// matching common sentence-transformers models — the true dimension is
    /// only discovered from the output tensor shape in [`Self::run`].
    fn extract_model_info(session: &Session, path: &Path, file_size: u64) -> Result<ModelInfo> {
        let inputs: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
        let outputs: Vec<String> = session.outputs.iter().map(|o| o.name.clone()).collect();

        // Default embedding dimension (will be determined at runtime from actual output)
        // Most sentence-transformers models output 384 dimensions
        let dimension = 384;

        // Use the file stem (e.g. "model" for "model.onnx") as the display name.
        let name = path
            .file_stem()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_else(|| "unknown".to_string());

        Ok(ModelInfo {
            name,
            dimension,
            max_seq_length: 512,
            file_size,
            input_names: inputs,
            output_names: outputs,
        })
    }

    /// Run inference on encoded inputs
    ///
    /// # Arguments
    /// * `input_ids`, `attention_mask`, `token_type_ids` — flattened
    ///   row-major `[batch * seq_length]` token buffers
    /// * `shape` — `[batch_size, seq_length]`
    ///
    /// # Returns
    /// One `Vec<f32>` per batch item. For a 3-D model output
    /// `[batch, seq_len, hidden]` each entry is the *un-pooled* token
    /// embedding block of length `seq_len * hidden` (pooling happens
    /// elsewhere); for a 2-D output `[batch, hidden]` each entry is the
    /// already-pooled embedding of length `hidden`.
    ///
    /// # Errors
    /// Fails on tensor construction, ONNX Runtime errors, or an output
    /// tensor with an unexpected rank.
    #[instrument(skip_all, fields(batch_size, seq_length))]
    pub fn run(
        &mut self,
        input_ids: &[i64],
        attention_mask: &[i64],
        token_type_ids: &[i64],
        shape: &[usize],
    ) -> Result<Vec<Vec<f32>>> {
        use ort::value::Tensor;

        let batch_size = shape[0];
        let seq_length = shape[1];

        debug!(
            "Running inference: batch_size={}, seq_length={}",
            batch_size, seq_length
        );

        // Create input tensors using ort's Tensor type
        let input_ids_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            input_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;

        let attention_mask_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            attention_mask.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;

        let token_type_ids_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            token_type_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;

        // Build inputs vector; names must match the BERT-style model inputs.
        let inputs = vec![
            ("input_ids", input_ids_tensor.into_dyn()),
            ("attention_mask", attention_mask_tensor.into_dyn()),
            ("token_type_ids", token_type_ids_tensor.into_dyn()),
        ];

        // Run inference
        let outputs = self.session.run(inputs)
            .map_err(EmbeddingError::OnnxRuntime)?;

        // Extract output tensor
        // Usually the output is [batch, seq_len, hidden_size] or [batch, hidden_size]
        let output_names = ["last_hidden_state", "output", "sentence_embedding"];

        // Find the appropriate output by name, or use the first one
        let output_iter: Vec<_> = outputs.iter().collect();
        let output = output_iter
            .iter()
            .find(|(name, _)| output_names.contains(name))
            .or_else(|| output_iter.first())
            .map(|(_, v)| v)
            .ok_or_else(|| EmbeddingError::invalid_model("No output tensor found"))?;

        // In ort 2.0, try_extract_tensor returns (&Shape, &[f32])
        let (tensor_shape, tensor_data) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;

        // Convert Shape to Vec<usize> - Shape yields i64
        let dims: Vec<usize> = tensor_shape.iter().map(|&d| d as usize).collect();

        // Handle different output shapes
        let embeddings = if dims.len() == 3 {
            // [batch, seq_len, hidden] - need pooling; copy out each batch
            // item's contiguous seq_len*hidden block.
            let hidden_size = dims[2];
            (0..batch_size)
                .map(|i| {
                    let start = i * seq_length * hidden_size;
                    let end = start + seq_length * hidden_size;
                    tensor_data[start..end].to_vec()
                })
                .collect()
        } else if dims.len() == 2 {
            // [batch, hidden] - already pooled
            let hidden_size = dims[1];
            (0..batch_size)
                .map(|i| {
                    let start = i * hidden_size;
                    let end = start + hidden_size;
                    tensor_data[start..end].to_vec()
                })
                .collect()
        } else {
            return Err(EmbeddingError::invalid_model(format!(
                "Unexpected output shape: {:?}",
                dims
            )));
        };

        Ok(embeddings)
    }

    /// Get model info
    pub fn info(&self) -> &ModelInfo {
        &self.info
    }

    /// Get embedding dimension (the load-time default — see
    /// `extract_model_info` — not necessarily the model's true dimension)
    pub fn dimension(&self) -> usize {
        self.info.dimension
    }
}
|
||||
|
||||
/// Download model files from HuggingFace Hub
|
||||
async fn download_from_huggingface(
|
||||
model_id: &str,
|
||||
revision: Option<&str>,
|
||||
cache_path: &Path,
|
||||
show_progress: bool,
|
||||
) -> Result<()> {
|
||||
let revision = revision.unwrap_or("main");
|
||||
let base_url = format!(
|
||||
"https://huggingface.co/{}/resolve/{}",
|
||||
model_id, revision
|
||||
);
|
||||
|
||||
let model_path = cache_path.join("model.onnx");
|
||||
|
||||
// Try to download model.onnx - check multiple locations
|
||||
if !model_path.exists() {
|
||||
// Location 1: Root directory (model.onnx)
|
||||
let root_url = format!("{}/model.onnx", base_url);
|
||||
debug!("Trying to download model from root: {}", root_url);
|
||||
|
||||
let root_result = download_file(&root_url, &model_path, show_progress).await;
|
||||
|
||||
// Location 2: ONNX subfolder (onnx/model.onnx) - common for sentence-transformers
|
||||
if root_result.is_err() && !model_path.exists() {
|
||||
let onnx_url = format!("{}/onnx/model.onnx", base_url);
|
||||
debug!("Root download failed, trying onnx subfolder: {}", onnx_url);
|
||||
|
||||
match download_file(&onnx_url, &model_path, show_progress).await {
|
||||
Ok(_) => debug!("Downloaded model.onnx from onnx/ subfolder"),
|
||||
Err(e) => {
|
||||
// Both locations failed
|
||||
return Err(EmbeddingError::download_failed(format!(
|
||||
"Failed to download model.onnx from {} - tried both root and onnx/ subfolder: {}",
|
||||
model_id, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else if let Err(e) = root_result {
|
||||
// Root failed but model exists (shouldn't happen, but handle gracefully)
|
||||
if !model_path.exists() {
|
||||
return Err(e);
|
||||
}
|
||||
} else {
|
||||
debug!("Downloaded model.onnx from root");
|
||||
}
|
||||
}
|
||||
|
||||
// Download auxiliary files (tokenizer.json, config.json) - these are optional
|
||||
let aux_files = ["tokenizer.json", "config.json"];
|
||||
for file in aux_files {
|
||||
let path = cache_path.join(file);
|
||||
if !path.exists() {
|
||||
// Try root first, then onnx subfolder
|
||||
let root_url = format!("{}/{}", base_url, file);
|
||||
match download_file(&root_url, &path, show_progress).await {
|
||||
Ok(_) => debug!("Downloaded {}", file),
|
||||
Err(_) => {
|
||||
// Try onnx subfolder
|
||||
let onnx_url = format!("{}/onnx/{}", base_url, file);
|
||||
match download_file(&onnx_url, &path, show_progress).await {
|
||||
Ok(_) => debug!("Downloaded {} from onnx/ subfolder", file),
|
||||
Err(e) => warn!("Failed to download {} (optional): {}", file, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Download a file from URL with optional progress bar
|
||||
async fn download_file(url: &str, path: &Path, show_progress: bool) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
let response = client.get(url).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(EmbeddingError::download_failed(format!(
|
||||
"HTTP {}: {}",
|
||||
response.status(),
|
||||
url
|
||||
)));
|
||||
}
|
||||
|
||||
let total_size = response.content_length().unwrap_or(0);
|
||||
|
||||
let pb = if show_progress && total_size > 0 {
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
|
||||
.unwrap()
|
||||
.progress_chars("#>-"),
|
||||
);
|
||||
Some(pb)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut file = fs::File::create(path)?;
|
||||
let mut downloaded = 0u64;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
file.write_all(&chunk)?;
|
||||
downloaded += chunk.len() as u64;
|
||||
if let Some(ref pb) = pb {
|
||||
pb.set_position(downloaded);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pb) = pb {
|
||||
pb.finish_with_message("Downloaded");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sanitize a model identifier so it can be used as a cache directory name.
///
/// Path separators (`/`, `\`) and drive-letter colons (`:`) are mapped to
/// underscores; every other character passes through unchanged.
fn sanitize_model_id(model_id: &str) -> String {
    model_id
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' => '_',
            other => other,
        })
        .collect()
}
|
||||
|
||||
/// Create a hash of a URL for caching
|
||||
fn hash_url(url: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(url.as_bytes());
|
||||
hex::encode(&hasher.finalize()[..8])
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `/` must become `_` so the id is a valid directory name on all
    /// platforms (`\` and `:` are covered by the same replacement).
    #[test]
    fn test_sanitize_model_id() {
        assert_eq!(
            sanitize_model_id("sentence-transformers/all-MiniLM-L6-v2"),
            "sentence-transformers_all-MiniLM-L6-v2"
        );
    }

    /// The URL hash is the first 8 bytes of the SHA-256 digest,
    /// hex-encoded, so its length is fixed at 16 characters.
    #[test]
    fn test_hash_url() {
        let hash = hash_url("https://example.com/model.onnx");
        assert_eq!(hash.len(), 16); // 8 bytes = 16 hex chars
    }
}
|
||||
397
vendor/ruvector/examples/onnx-embeddings/src/pooling.rs
vendored
Normal file
397
vendor/ruvector/examples/onnx-embeddings/src/pooling.rs
vendored
Normal file
@@ -0,0 +1,397 @@
|
||||
//! Pooling strategies for combining token embeddings into sentence embeddings
|
||||
|
||||
use crate::config::PoolingStrategy;
|
||||
use rayon::prelude::*;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
/// Pooler for combining token embeddings
///
/// Reduces per-token embeddings to a single sentence embedding using the
/// configured [`PoolingStrategy`], optionally L2-normalizing the result.
#[derive(Debug, Clone)]
pub struct Pooler {
    // How token-level embeddings are reduced to one vector per sequence.
    strategy: PoolingStrategy,
    // When true, pooled vectors are L2-normalized before being returned.
    normalize: bool,
}
|
||||
|
||||
impl Pooler {
    /// Create a new pooler with the given strategy
    ///
    /// `normalize` controls whether pooled vectors are L2-normalized.
    pub fn new(strategy: PoolingStrategy, normalize: bool) -> Self {
        Self { strategy, normalize }
    }

    /// Pool token embeddings into sentence embeddings
    ///
    /// Sequences are pooled in parallel via rayon; normalization (if
    /// enabled) is a second parallel pass.
    ///
    /// # Arguments
    /// * `token_embeddings` - Token embeddings for each sequence [batch][seq_len * hidden]
    /// * `attention_mask` - Attention mask for each sequence [batch][seq_len]
    /// * `seq_length` - Sequence length
    /// * `hidden_size` - Hidden dimension size
    ///
    /// # Returns
    /// One embedding of length `hidden_size` per input sequence.
    #[instrument(skip_all, fields(batch_size = token_embeddings.len(), strategy = ?self.strategy))]
    pub fn pool(
        &self,
        token_embeddings: &[Vec<f32>],
        attention_mask: &[Vec<i64>],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<Vec<f32>> {
        debug!(
            "Pooling {} sequences with strategy {:?}",
            token_embeddings.len(),
            self.strategy
        );

        let embeddings: Vec<Vec<f32>> = token_embeddings
            .par_iter()
            .zip(attention_mask.par_iter())
            .map(|(tokens, mask)| {
                self.pool_single(tokens, mask, seq_length, hidden_size)
            })
            .collect();

        if self.normalize {
            embeddings
                .into_par_iter()
                .map(|emb| Self::normalize_vector(&emb))
                .collect()
        } else {
            embeddings
        }
    }

    /// Pool a single sequence
    ///
    /// Dispatches to the strategy-specific implementation; normalization is
    /// NOT applied here (it happens in [`Self::pool`]).
    fn pool_single(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        match self.strategy {
            PoolingStrategy::Mean => {
                self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::Cls => {
                self.cls_pool(token_embeddings, hidden_size)
            }
            PoolingStrategy::Max => {
                self.max_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::MeanSqrtLen => {
                self.mean_sqrt_len_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::LastToken => {
                self.last_token_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::WeightedMean => {
                self.weighted_mean_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
        }
    }

    /// Mean pooling over all tokens (weighted by attention mask)
    ///
    /// Only positions with mask == 1 contribute; padding tokens are
    /// excluded from both the sum and the divisor. An all-zero mask yields
    /// the zero vector (the divide-by-count step is skipped).
    fn mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = vec![0.0f32; hidden_size];
        let mut count = 0.0f32;

        for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
            if mask == 1 {
                let start = i * hidden_size;
                let end = start + hidden_size;
                for (j, val) in token_embeddings[start..end].iter().enumerate() {
                    result[j] += val;
                }
                count += 1.0;
            }
        }

        if count > 0.0 {
            for val in &mut result {
                *val /= count;
            }
        }

        result
    }

    /// CLS token pooling (first token)
    ///
    /// Simply copies the first token's embedding; the attention mask is
    /// irrelevant because position 0 is never padding in BERT-style inputs.
    fn cls_pool(&self, token_embeddings: &[f32], hidden_size: usize) -> Vec<f32> {
        token_embeddings[..hidden_size].to_vec()
    }

    /// Max pooling over all tokens
    ///
    /// Element-wise max over non-padding positions. Starts from -inf so
    /// any real value wins; leftover -inf entries (all-zero mask) are
    /// zeroed afterwards.
    fn max_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = vec![f32::NEG_INFINITY; hidden_size];

        for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
            if mask == 1 {
                let start = i * hidden_size;
                let end = start + hidden_size;
                for (j, val) in token_embeddings[start..end].iter().enumerate() {
                    if *val > result[j] {
                        result[j] = *val;
                    }
                }
            }
        }

        // Replace -inf with 0 for empty sequences
        for val in &mut result {
            if val.is_infinite() {
                *val = 0.0;
            }
        }

        result
    }

    /// Mean pooling with sqrt(length) scaling
    ///
    /// Mean pool, then multiply by sqrt(#non-padding tokens) — i.e. the
    /// token sum divided by sqrt(length) rather than length.
    fn mean_sqrt_len_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size);
        let length: f32 = attention_mask.iter().filter(|&&m| m == 1).count() as f32;

        if length > 0.0 {
            let scale = length.sqrt();
            for val in &mut result {
                *val *= scale;
            }
        }

        result
    }

    /// Last token pooling (for decoder models)
    ///
    /// Returns the embedding of the last non-padding token. Falls back to
    /// position 0 when the mask is all zeros (rposition -> 0) or when the
    /// computed slice would run past the buffer.
    fn last_token_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        _seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        // Find last non-padding token
        let last_idx = attention_mask
            .iter()
            .rposition(|&m| m == 1)
            .unwrap_or(0);

        let start = last_idx * hidden_size;
        let end = start + hidden_size;

        if end <= token_embeddings.len() {
            token_embeddings[start..end].to_vec()
        } else {
            self.cls_pool(token_embeddings, hidden_size)
        }
    }

    /// Weighted mean pooling based on position
    ///
    /// Weights are 1/(position+1) over non-padding tokens, so earlier
    /// tokens dominate. NOTE(review): the `PoolingStrategy::WeightedMean`
    /// doc in config.rs says "weighted mean based on attention mask" —
    /// this implementation weights by *position*; confirm which is intended.
    fn weighted_mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = vec![0.0f32; hidden_size];
        let mut total_weight = 0.0f32;

        for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
            if mask == 1 {
                // Weight decreases with position (more weight to early tokens)
                let weight = 1.0 / (i + 1) as f32;
                let start = i * hidden_size;
                let end = start + hidden_size;

                for (j, val) in token_embeddings[start..end].iter().enumerate() {
                    result[j] += val * weight;
                }
                total_weight += weight;
            }
        }

        if total_weight > 0.0 {
            for val in &mut result {
                *val /= total_weight;
            }
        }

        result
    }

    /// L2 normalize a vector
    ///
    /// Returns the input unchanged when its norm is ~0 to avoid dividing
    /// by zero.
    pub fn normalize_vector(vec: &[f32]) -> Vec<f32> {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm > 1e-12 {
            vec.iter().map(|x| x / norm).collect()
        } else {
            vec.to_vec()
        }
    }

    /// Compute cosine similarity between two vectors (SIMD-optimized)
    ///
    /// simsimd path; falls back to 0.0 when the library reports no result.
    #[cfg(feature = "simsimd")]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        use simsimd::SpatialSimilarity;
        f32::cosine(a, b).unwrap_or(0.0) as f32
    }

    /// Portable cosine similarity; returns 0.0 when either vector has
    /// (near-)zero norm.
    #[cfg(not(feature = "simsimd"))]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 1e-12 && norm_b > 1e-12 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }

    /// Compute dot product between two vectors (SIMD-optimized)
    #[cfg(feature = "simsimd")]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        use simsimd::SpatialSimilarity;
        f32::dot(a, b).unwrap_or(0.0) as f32
    }

    /// Portable dot product fallback.
    #[cfg(not(feature = "simsimd"))]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Compute Euclidean distance between two vectors (SIMD-optimized)
    ///
    /// simsimd provides squared distance; sqrt is applied here.
    #[cfg(feature = "simsimd")]
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        use simsimd::SpatialSimilarity;
        (f32::sqeuclidean(a, b).unwrap_or(0.0) as f32).sqrt()
    }

    /// Portable Euclidean distance fallback.
    #[cfg(not(feature = "simsimd"))]
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}
|
||||
|
||||
impl Default for Pooler {
|
||||
fn default() -> Self {
|
||||
Self::new(PoolingStrategy::Mean, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch distance computation using ndarray
|
||||
pub fn batch_cosine_similarity(
|
||||
query: &[f32],
|
||||
candidates: &[Vec<f32>],
|
||||
) -> Vec<f32> {
|
||||
candidates
|
||||
.par_iter()
|
||||
.map(|c| Pooler::cosine_similarity(query, c))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Find top-k most similar vectors
|
||||
pub fn top_k_similar(
|
||||
query: &[f32],
|
||||
candidates: &[Vec<f32>],
|
||||
k: usize,
|
||||
) -> Vec<(usize, f32)> {
|
||||
let mut scores: Vec<(usize, f32)> = candidates
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(i, c)| (i, Pooler::cosine_similarity(query, c)))
|
||||
.collect();
|
||||
|
||||
// Sort by score descending
|
||||
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
scores.truncate(k);
|
||||
scores
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A normalized vector must have unit L2 norm (3-4-5 triangle case).
    #[test]
    fn test_normalize_vector() {
        let vec = vec![3.0, 4.0];
        let normalized = Pooler::normalize_vector(&vec);

        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    /// Identical vectors score 1.0; orthogonal vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((Pooler::cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        assert!((Pooler::cosine_similarity(&a, &c)).abs() < 1e-6);
    }

    /// Mean pooling averages element-wise over unmasked tokens:
    /// mean([1,2,3],[4,5,6]) == [2.5,3.5,4.5].
    #[test]
    fn test_mean_pooling() {
        let pooler = Pooler::new(PoolingStrategy::Mean, false);

        // 2 tokens, 3 dimensions
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1i64, 1];

        let result = pooler.pool_single(&embeddings, &mask, 2, 3);

        assert_eq!(result.len(), 3);
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }

    /// CLS pooling returns exactly the first token's embedding.
    #[test]
    fn test_cls_pooling() {
        let pooler = Pooler::new(PoolingStrategy::Cls, false);

        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1i64, 1];

        let result = pooler.pool_single(&embeddings, &mask, 2, 3);

        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    /// top_k returns `k` results with the exact match (index 0) ranked first.
    #[test]
    fn test_top_k_similar() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];

        let results = top_k_similar(&query, &candidates, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0); // Most similar
    }
}
|
||||
565
vendor/ruvector/examples/onnx-embeddings/src/ruvector_integration.rs
vendored
Normal file
565
vendor/ruvector/examples/onnx-embeddings/src/ruvector_integration.rs
vendored
Normal file
@@ -0,0 +1,565 @@
|
||||
//! Standalone vector database integration for ONNX embeddings
|
||||
//!
|
||||
//! This module provides a lightweight vector database built on top of the
|
||||
//! embedding system, demonstrating how to integrate with RuVector or use
|
||||
//! as a standalone semantic search engine.
|
||||
|
||||
use crate::{Embedder, EmbeddingError, Result};
|
||||
use parking_lot::RwLock;
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Vector ID type (using String for compatibility with RuVector)
|
||||
pub type VectorId = String;
|
||||
|
||||
/// Distance metric for similarity calculation
///
/// Determines both how scores are computed and how results are ordered:
/// higher-is-better for `Cosine`/`DotProduct`, lower-is-better for
/// `Euclidean`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Distance {
    /// Cosine similarity (default, best for normalized embeddings)
    #[default]
    Cosine,
    /// Euclidean (L2) distance
    Euclidean,
    /// Dot product
    DotProduct,
}
|
||||
|
||||
/// Search result with text and score
///
/// One hit returned by index search; carries the stored text and
/// metadata so callers rarely need a follow-up lookup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Vector ID
    pub id: VectorId,
    /// Original text
    pub text: String,
    /// Similarity score (higher is better for cosine, lower for euclidean)
    pub score: f32,
    /// Optional metadata
    pub metadata: Option<serde_json::Value>,
}
|
||||
|
||||
/// Stored vector entry
///
/// Internal record pairing the original text with its embedding and
/// optional caller-supplied metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredEntry {
    /// Unique ID assigned at insert time (UUID v4 string).
    id: VectorId,
    /// Original input text, returned verbatim in search results.
    text: String,
    /// Embedding vector for `text`.
    vector: Vec<f32>,
    /// Optional caller-supplied metadata (arbitrary JSON).
    metadata: Option<serde_json::Value>,
}
|
||||
|
||||
/// Configuration for creating a vector index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Distance metric
    pub distance: Distance,
    /// Maximum number of elements (for pre-allocation)
    /// NOTE(review): only used as a capacity hint — inserts are not
    /// rejected once this count is exceeded.
    pub max_elements: usize,
    /// Number of results to over-fetch for filtering
    /// NOTE(review): not consulted by the brute-force search in this
    /// module; kept for API compatibility with HNSW-style indexes.
    pub ef_search: usize,
}
|
||||
|
||||
/// Defaults: cosine distance, room for 100k vectors, ef_search of 100.
impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            distance: Distance::Cosine,
            max_elements: 100_000,
            ef_search: 100,
        }
    }
}
|
||||
|
||||
/// RuVector-compatible embeddings index
///
/// A lightweight in-memory vector database that integrates ONNX embeddings
/// with similarity search. Compatible with RuVector's API patterns.
///
/// Search is a brute-force linear scan (parallelized with rayon), so it
/// is suitable for small-to-medium collections rather than ANN-scale data.
pub struct RuVectorEmbeddings {
    /// The embedder for generating vectors (wrapped in RwLock for mutable access)
    embedder: Arc<RwLock<Embedder>>,
    /// Stored vectors and metadata
    entries: RwLock<Vec<StoredEntry>>,
    /// Index name
    name: String,
    /// Configuration
    config: IndexConfig,
}
|
||||
|
||||
impl RuVectorEmbeddings {
|
||||
/// Create a new RuVector index with the given embedder
|
||||
#[instrument(skip_all)]
|
||||
pub fn new(
|
||||
name: impl Into<String>,
|
||||
embedder: Embedder,
|
||||
config: IndexConfig,
|
||||
) -> Result<Self> {
|
||||
let name = name.into();
|
||||
let dimension = embedder.dimension();
|
||||
|
||||
info!(
|
||||
"Creating RuVector index '{}' with dimension {} and {:?} distance",
|
||||
name, dimension, config.distance
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
embedder: Arc::new(RwLock::new(embedder)),
|
||||
entries: RwLock::new(Vec::with_capacity(config.max_elements.min(10_000))),
|
||||
name,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn new_default(name: impl Into<String>, embedder: Embedder) -> Result<Self> {
|
||||
Self::new(name, embedder, IndexConfig::default())
|
||||
}
|
||||
|
||||
/// Insert a single text with optional metadata
|
||||
#[instrument(skip(self, text, metadata), fields(text_len = text.len()))]
|
||||
pub fn insert(
|
||||
&self,
|
||||
text: &str,
|
||||
metadata: Option<serde_json::Value>,
|
||||
) -> Result<VectorId> {
|
||||
let embedding = self.embedder.write().embed_one(text)?;
|
||||
self.insert_with_embedding(text, embedding, metadata)
|
||||
}
|
||||
|
||||
/// Insert with pre-computed embedding
|
||||
pub fn insert_with_embedding(
|
||||
&self,
|
||||
text: &str,
|
||||
embedding: Vec<f32>,
|
||||
metadata: Option<serde_json::Value>,
|
||||
) -> Result<VectorId> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
let entry = StoredEntry {
|
||||
id: id.clone(),
|
||||
text: text.to_string(),
|
||||
vector: embedding,
|
||||
metadata,
|
||||
};
|
||||
|
||||
self.entries.write().push(entry);
|
||||
|
||||
debug!("Inserted text with ID {}", id);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Insert multiple texts
|
||||
#[instrument(skip(self, texts), fields(count = texts.len()))]
|
||||
pub fn insert_batch<S: AsRef<str>>(&self, texts: &[S]) -> Result<Vec<VectorId>> {
|
||||
let embeddings = self.embedder.write().embed(texts)?;
|
||||
self.insert_batch_with_embeddings(texts, embeddings.embeddings)
|
||||
}
|
||||
|
||||
/// Insert batch with pre-computed embeddings
|
||||
pub fn insert_batch_with_embeddings<S: AsRef<str>>(
|
||||
&self,
|
||||
texts: &[S],
|
||||
embeddings: Vec<Vec<f32>>,
|
||||
) -> Result<Vec<VectorId>> {
|
||||
if texts.len() != embeddings.len() {
|
||||
return Err(EmbeddingError::dimension_mismatch(
|
||||
texts.len(),
|
||||
embeddings.len(),
|
||||
));
|
||||
}
|
||||
|
||||
let entries: Vec<StoredEntry> = texts
|
||||
.iter()
|
||||
.zip(embeddings)
|
||||
.map(|(text, vector)| StoredEntry {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
text: text.as_ref().to_string(),
|
||||
vector,
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let ids: Vec<VectorId> = entries.iter().map(|e| e.id.clone()).collect();
|
||||
|
||||
self.entries.write().extend(entries);
|
||||
|
||||
info!("Inserted {} vectors", ids.len());
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
/// Search for similar texts
|
||||
#[instrument(skip(self, query), fields(k))]
|
||||
pub fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
|
||||
let query_embedding = self.embedder.write().embed_one(query)?;
|
||||
self.search_with_embedding(&query_embedding, k)
|
||||
}
|
||||
|
||||
/// Search with pre-computed query embedding
|
||||
pub fn search_with_embedding(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
k: usize,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
let entries = self.entries.read();
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Calculate similarities in parallel
|
||||
let mut scored: Vec<(usize, f32)> = entries
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(i, entry)| {
|
||||
let score = self.compute_similarity(query_embedding, &entry.vector);
|
||||
(i, score)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by score (descending for cosine/dot, ascending for euclidean)
|
||||
match self.config.distance {
|
||||
Distance::Cosine | Distance::DotProduct => {
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
}
|
||||
Distance::Euclidean => {
|
||||
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
}
|
||||
}
|
||||
|
||||
// Take top k
|
||||
let results: Vec<SearchResult> = scored
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.map(|(i, score)| {
|
||||
let entry = &entries[i];
|
||||
SearchResult {
|
||||
id: entry.id.clone(),
|
||||
text: entry.text.clone(),
|
||||
score,
|
||||
metadata: entry.metadata.clone(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!("Search returned {} results", results.len());
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Compute similarity/distance between two vectors
|
||||
fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
match self.config.distance {
|
||||
Distance::Cosine => Self::cosine_similarity(a, b),
|
||||
Distance::Euclidean => Self::euclidean_distance(a, b),
|
||||
Distance::DotProduct => Self::dot_product(a, b),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity between two vectors
|
||||
#[inline]
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a > 1e-10 && norm_b > 1e-10 {
|
||||
dot / (norm_a * norm_b)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Euclidean (L2) distance
|
||||
#[inline]
|
||||
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y).powi(2))
|
||||
.sum::<f32>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
/// Dot product
|
||||
#[inline]
|
||||
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
|
||||
}
|
||||
|
||||
/// Search with metadata filter
|
||||
#[instrument(skip(self, query, filter), fields(k))]
|
||||
pub fn search_filtered<F>(&self, query: &str, k: usize, filter: F) -> Result<Vec<SearchResult>>
|
||||
where
|
||||
F: Fn(&serde_json::Value) -> bool + Sync,
|
||||
{
|
||||
let query_embedding = self.embedder.write().embed_one(query)?;
|
||||
let entries = self.entries.read();
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Calculate similarities with filtering
|
||||
let mut scored: Vec<(usize, f32)> = entries
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, entry)| {
|
||||
// Apply filter
|
||||
if let Some(ref meta) = entry.metadata {
|
||||
if !filter(meta) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
let score = self.compute_similarity(&query_embedding, &entry.vector);
|
||||
Some((i, score))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort
|
||||
match self.config.distance {
|
||||
Distance::Cosine | Distance::DotProduct => {
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
}
|
||||
Distance::Euclidean => {
|
||||
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
}
|
||||
}
|
||||
|
||||
let results: Vec<SearchResult> = scored
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.map(|(i, score)| {
|
||||
let entry = &entries[i];
|
||||
SearchResult {
|
||||
id: entry.id.clone(),
|
||||
text: entry.text.clone(),
|
||||
score,
|
||||
metadata: entry.metadata.clone(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Get a vector by ID
|
||||
pub fn get(&self, id: &str) -> Option<(String, Vec<f32>)> {
|
||||
let entries = self.entries.read();
|
||||
entries
|
||||
.iter()
|
||||
.find(|e| e.id == id)
|
||||
.map(|e| (e.text.clone(), e.vector.clone()))
|
||||
}
|
||||
|
||||
/// Delete a vector by ID
|
||||
pub fn delete(&self, id: &str) -> bool {
|
||||
let mut entries = self.entries.write();
|
||||
let len_before = entries.len();
|
||||
entries.retain(|e| e.id != id);
|
||||
entries.len() < len_before
|
||||
}
|
||||
|
||||
/// Get the number of vectors in the index
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.read().len()
|
||||
}
|
||||
|
||||
/// Check if the index is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.read().is_empty()
|
||||
}
|
||||
|
||||
/// Get index name
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.embedder.read().dimension()
|
||||
}
|
||||
|
||||
/// Get reference to the embedder (wrapped in Arc<RwLock>)
|
||||
pub fn embedder(&self) -> &Arc<RwLock<Embedder>> {
|
||||
&self.embedder
|
||||
}
|
||||
|
||||
/// Clear all vectors
|
||||
pub fn clear(&self) {
|
||||
self.entries.write().clear();
|
||||
}
|
||||
|
||||
/// Export all entries for persistence
|
||||
pub fn export(&self) -> Vec<(VectorId, String, Vec<f32>, Option<serde_json::Value>)> {
|
||||
self.entries
|
||||
.read()
|
||||
.iter()
|
||||
.map(|e| (e.id.clone(), e.text.clone(), e.vector.clone(), e.metadata.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Import entries (for loading from persistence)
|
||||
pub fn import(
|
||||
&self,
|
||||
entries: Vec<(VectorId, String, Vec<f32>, Option<serde_json::Value>)>,
|
||||
) {
|
||||
let stored: Vec<StoredEntry> = entries
|
||||
.into_iter()
|
||||
.map(|(id, text, vector, metadata)| StoredEntry {
|
||||
id,
|
||||
text,
|
||||
vector,
|
||||
metadata,
|
||||
})
|
||||
.collect();
|
||||
|
||||
*self.entries.write() = stored;
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating RuVector indexes
///
/// Fluent configuration for [`RuVectorEmbeddings`]; an embedder is the
/// only required field — everything else falls back to `IndexConfig`
/// defaults.
pub struct RuVectorBuilder {
    /// Index name (required, set via `new`).
    name: String,
    /// Embedder to wrap; `build` fails if this is never set.
    embedder: Option<Embedder>,
    /// Accumulated index configuration.
    config: IndexConfig,
}
|
||||
|
||||
impl RuVectorBuilder {
|
||||
/// Create a new builder
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
embedder: None,
|
||||
config: IndexConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the embedder
|
||||
pub fn embedder(mut self, embedder: Embedder) -> Self {
|
||||
self.embedder = Some(embedder);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set distance metric
|
||||
pub fn distance(mut self, distance: Distance) -> Self {
|
||||
self.config.distance = distance;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max elements
|
||||
pub fn max_elements(mut self, max: usize) -> Self {
|
||||
self.config.max_elements = max;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set ef_search parameter
|
||||
pub fn ef_search(mut self, ef: usize) -> Self {
|
||||
self.config.ef_search = ef;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the index
|
||||
pub fn build(self) -> Result<RuVectorEmbeddings> {
|
||||
let embedder = self
|
||||
.embedder
|
||||
.ok_or_else(|| EmbeddingError::invalid_config("Embedder is required"))?;
|
||||
|
||||
RuVectorEmbeddings::new(self.name, embedder, self.config)
|
||||
}
|
||||
}
|
||||
|
||||
/// RAG (Retrieval-Augmented Generation) helper
///
/// Thin wrapper over an index that retrieves the `top_k` most relevant
/// stored texts for a query and formats them into a prompt.
pub struct RagPipeline {
    /// Backing vector index used for retrieval.
    index: RuVectorEmbeddings,
    /// Number of context documents retrieved per query.
    top_k: usize,
}
|
||||
|
||||
impl RagPipeline {
|
||||
/// Create a new RAG pipeline
|
||||
pub fn new(index: RuVectorEmbeddings, top_k: usize) -> Self {
|
||||
Self { index, top_k }
|
||||
}
|
||||
|
||||
/// Retrieve context for a query
|
||||
pub fn retrieve(&self, query: &str) -> Result<Vec<String>> {
|
||||
let results = self.index.search(query, self.top_k)?;
|
||||
Ok(results.into_iter().map(|r| r.text).collect())
|
||||
}
|
||||
|
||||
/// Retrieve with scores
|
||||
pub fn retrieve_with_scores(&self, query: &str) -> Result<Vec<(String, f32)>> {
|
||||
let results = self.index.search(query, self.top_k)?;
|
||||
Ok(results.into_iter().map(|r| (r.text, r.score)).collect())
|
||||
}
|
||||
|
||||
/// Format retrieved context as a prompt
|
||||
pub fn format_context(&self, query: &str) -> Result<String> {
|
||||
let contexts = self.retrieve(query)?;
|
||||
|
||||
let mut prompt = String::from("Context:\n");
|
||||
for (i, ctx) in contexts.iter().enumerate() {
|
||||
prompt.push_str(&format!("[{}] {}\n", i + 1, ctx));
|
||||
}
|
||||
prompt.push_str(&format!("\nQuestion: {}", query));
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Format context with scores
|
||||
pub fn format_context_with_scores(&self, query: &str) -> Result<String> {
|
||||
let results = self.retrieve_with_scores(query)?;
|
||||
|
||||
let mut prompt = String::from("Context (with relevance scores):\n");
|
||||
for (i, (ctx, score)) in results.iter().enumerate() {
|
||||
prompt.push_str(&format!("[{} - {:.3}] {}\n", i + 1, score, ctx));
|
||||
}
|
||||
prompt.push_str(&format!("\nQuestion: {}", query));
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Add documents to the index
|
||||
pub fn add_documents<S: AsRef<str>>(&self, documents: &[S]) -> Result<Vec<VectorId>> {
|
||||
self.index.insert_batch(documents)
|
||||
}
|
||||
|
||||
/// Get reference to the underlying index
|
||||
pub fn index(&self) -> &RuVectorEmbeddings {
|
||||
&self.index
|
||||
}
|
||||
|
||||
/// Get mutable reference to set top_k
|
||||
pub fn set_top_k(&mut self, k: usize) {
|
||||
self.top_k = k;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Identical vectors → 1; orthogonal → 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((RuVectorEmbeddings::cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        assert!(RuVectorEmbeddings::cosine_similarity(&a, &c).abs() < 1e-6);
    }

    // 3-4-5 right triangle: distance from origin to (3,4) is 5.
    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];

        let dist = RuVectorEmbeddings::euclidean_distance(&a, &b);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    // 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        let dot = RuVectorEmbeddings::dot_product(&a, &b);
        assert!((dot - 32.0).abs() < 1e-6);
    }
}
|
||||
260
vendor/ruvector/examples/onnx-embeddings/src/tokenizer.rs
vendored
Normal file
260
vendor/ruvector/examples/onnx-embeddings/src/tokenizer.rs
vendored
Normal file
@@ -0,0 +1,260 @@
|
||||
//! Text tokenization using HuggingFace tokenizers
|
||||
|
||||
use crate::{EmbeddingError, Result};
|
||||
use std::path::Path;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
/// Wrapper around HuggingFace tokenizer with batch processing
///
/// Adds truncation to `max_length` and right-padding (with the model's
/// pad token) so a batch of texts shares a uniform sequence length.
pub struct Tokenizer {
    /// Underlying HuggingFace tokenizer.
    inner: HfTokenizer,
    /// Hard cap on token-sequence length; longer inputs are truncated.
    max_length: usize,
    /// ID used to right-pad shorter sequences (resolved from the vocab,
    /// falling back to 0).
    pad_token_id: u32,
}
|
||||
|
||||
/// Encoded batch output
///
/// Row-major token tensors for a batch; every inner `Vec` has the same
/// (padded) length.
#[derive(Debug, Clone)]
pub struct EncodedBatch {
    /// Token IDs [batch_size, seq_length]
    pub input_ids: Vec<Vec<i64>>,
    /// Attention mask [batch_size, seq_length] (1 = real token, 0 = padding)
    pub attention_mask: Vec<Vec<i64>>,
    /// Token type IDs [batch_size, seq_length]
    pub token_type_ids: Vec<Vec<i64>>,
    /// Original sequence lengths before padding
    pub original_lengths: Vec<usize>,
}
|
||||
|
||||
impl EncodedBatch {
|
||||
/// Get batch size
|
||||
pub fn batch_size(&self) -> usize {
|
||||
self.input_ids.len()
|
||||
}
|
||||
|
||||
/// Get sequence length (padded)
|
||||
pub fn seq_length(&self) -> usize {
|
||||
self.input_ids.first().map(|v| v.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Convert to flat arrays for ONNX input
|
||||
pub fn to_onnx_inputs(&self) -> (Vec<i64>, Vec<i64>, Vec<i64>, Vec<usize>) {
|
||||
let batch_size = self.batch_size();
|
||||
let seq_length = self.seq_length();
|
||||
let total_len = batch_size * seq_length;
|
||||
|
||||
let mut flat_input_ids = Vec::with_capacity(total_len);
|
||||
let mut flat_attention_mask = Vec::with_capacity(total_len);
|
||||
let mut flat_token_type_ids = Vec::with_capacity(total_len);
|
||||
|
||||
for i in 0..batch_size {
|
||||
flat_input_ids.extend(&self.input_ids[i]);
|
||||
flat_attention_mask.extend(&self.attention_mask[i]);
|
||||
flat_token_type_ids.extend(&self.token_type_ids[i]);
|
||||
}
|
||||
|
||||
(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
vec![batch_size, seq_length],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to find pad token ID from vocabulary
|
||||
fn find_pad_token_id(tokenizer: &HfTokenizer) -> u32 {
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
vocab
|
||||
.get("[PAD]")
|
||||
.or_else(|| vocab.get("<pad>"))
|
||||
.or_else(|| vocab.get("<|pad|>"))
|
||||
.copied()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
/// Load tokenizer from a local file
|
||||
#[instrument(skip_all, fields(path = %path.as_ref().display()))]
|
||||
pub fn from_file(path: impl AsRef<Path>, max_length: usize) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
debug!("Loading tokenizer from file");
|
||||
|
||||
let inner = HfTokenizer::from_file(path)
|
||||
.map_err(|e| EmbeddingError::tokenizer_not_found(e.to_string()))?;
|
||||
|
||||
let pad_token_id = find_pad_token_id(&inner);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load tokenizer from HuggingFace Hub by downloading tokenizer.json
|
||||
#[instrument(skip_all, fields(model_id = %model_id))]
|
||||
pub fn from_pretrained(model_id: &str, max_length: usize) -> Result<Self> {
|
||||
debug!("Loading tokenizer from HuggingFace Hub: {}", model_id);
|
||||
|
||||
// Download tokenizer.json from HuggingFace Hub
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/resolve/main/tokenizer.json",
|
||||
model_id
|
||||
);
|
||||
|
||||
let response = reqwest::blocking::get(&url)
|
||||
.map_err(|e| EmbeddingError::download_failed(format!("Failed to download tokenizer: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(EmbeddingError::download_failed(format!(
|
||||
"Failed to download tokenizer from {}: HTTP {}",
|
||||
url,
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let bytes = response.bytes()
|
||||
.map_err(|e| EmbeddingError::download_failed(e.to_string()))?;
|
||||
|
||||
let inner = HfTokenizer::from_bytes(&bytes)
|
||||
.map_err(|e| EmbeddingError::tokenizer_not_found(e.to_string()))?;
|
||||
|
||||
let pad_token_id = find_pad_token_id(&inner);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load tokenizer from JSON string
|
||||
pub fn from_json(json: &str, max_length: usize) -> Result<Self> {
|
||||
let inner = HfTokenizer::from_bytes(json.as_bytes())
|
||||
.map_err(|e| EmbeddingError::tokenizer_not_found(e.to_string()))?;
|
||||
|
||||
let pad_token_id = find_pad_token_id(&inner);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode a single text
|
||||
pub fn encode(&self, text: &str) -> Result<EncodedBatch> {
|
||||
self.encode_batch(&[text])
|
||||
}
|
||||
|
||||
/// Encode a batch of texts
|
||||
#[instrument(skip_all, fields(batch_size = texts.len()))]
|
||||
pub fn encode_batch<S: AsRef<str>>(&self, texts: &[S]) -> Result<EncodedBatch> {
|
||||
if texts.is_empty() {
|
||||
return Err(EmbeddingError::EmptyInput);
|
||||
}
|
||||
|
||||
debug!("Encoding batch of {} texts", texts.len());
|
||||
|
||||
// Encode all texts
|
||||
let encodings: Vec<_> = texts
|
||||
.iter()
|
||||
.map(|t| self.inner.encode(t.as_ref(), true))
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(EmbeddingError::from)?;
|
||||
|
||||
// Find max length in batch (capped at max_length)
|
||||
let max_len = encodings
|
||||
.iter()
|
||||
.map(|e| e.get_ids().len().min(self.max_length))
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
// Pad all sequences to the same length
|
||||
let mut input_ids = Vec::with_capacity(texts.len());
|
||||
let mut attention_mask = Vec::with_capacity(texts.len());
|
||||
let mut token_type_ids = Vec::with_capacity(texts.len());
|
||||
let mut original_lengths = Vec::with_capacity(texts.len());
|
||||
|
||||
for encoding in &encodings {
|
||||
let ids = encoding.get_ids();
|
||||
let type_ids = encoding.get_type_ids();
|
||||
let len = ids.len().min(self.max_length);
|
||||
|
||||
original_lengths.push(len);
|
||||
|
||||
// Truncate if necessary and convert to i64
|
||||
let mut ids_vec: Vec<i64> = ids[..len].iter().map(|&x| x as i64).collect();
|
||||
let mut mask_vec: Vec<i64> = vec![1; len];
|
||||
let mut type_vec: Vec<i64> = type_ids[..len].iter().map(|&x| x as i64).collect();
|
||||
|
||||
// Pad to max_len
|
||||
let pad_len = max_len - len;
|
||||
if pad_len > 0 {
|
||||
ids_vec.extend(std::iter::repeat_n(self.pad_token_id as i64, pad_len));
|
||||
mask_vec.extend(std::iter::repeat_n(0i64, pad_len));
|
||||
type_vec.extend(std::iter::repeat_n(0i64, pad_len));
|
||||
}
|
||||
|
||||
input_ids.push(ids_vec);
|
||||
attention_mask.push(mask_vec);
|
||||
token_type_ids.push(type_vec);
|
||||
}
|
||||
|
||||
Ok(EncodedBatch {
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
original_lengths,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the vocabulary size
|
||||
pub fn vocab_size(&self) -> usize {
|
||||
self.inner.get_vocab_size(true)
|
||||
}
|
||||
|
||||
/// Get the max length
|
||||
pub fn max_length(&self) -> usize {
|
||||
self.max_length
|
||||
}
|
||||
|
||||
/// Set the max length
|
||||
pub fn set_max_length(&mut self, max_length: usize) {
|
||||
self.max_length = max_length;
|
||||
}
|
||||
|
||||
/// Decode token IDs back to text
|
||||
pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
self.inner
|
||||
.decode(ids, skip_special_tokens)
|
||||
.map_err(EmbeddingError::from)
|
||||
}
|
||||
|
||||
/// Get the pad token ID
|
||||
pub fn pad_token_id(&self) -> u32 {
|
||||
self.pad_token_id
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Flattening a 2x4 batch yields row-major arrays of 8 elements
    // and the shape [2, 4].
    #[test]
    fn test_encoded_batch_to_onnx() {
        let batch = EncodedBatch {
            input_ids: vec![vec![101, 2054, 2003, 102], vec![101, 2054, 102, 0]],
            attention_mask: vec![vec![1, 1, 1, 1], vec![1, 1, 1, 0]],
            token_type_ids: vec![vec![0, 0, 0, 0], vec![0, 0, 0, 0]],
            original_lengths: vec![4, 3],
        };

        let (ids, mask, types, shape) = batch.to_onnx_inputs();

        assert_eq!(shape, vec![2, 4]);
        assert_eq!(ids.len(), 8);
        assert_eq!(mask.len(), 8);
        assert_eq!(types.len(), 8);
    }
}
|
||||
Reference in New Issue
Block a user