Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,252 @@
//! Configuration for the ONNX embedder
use crate::PretrainedModel;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Source of the ONNX model
///
/// Determines both where the model weights come from and how the matching
/// tokenizer is resolved (see `Embedder::new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
    /// Load from HuggingFace Hub (downloads if not cached)
    HuggingFace {
        /// Hub model identifier (also used to resolve the tokenizer)
        model_id: String,
        /// Optional revision/branch; `None` selects the default
        revision: Option<String>,
    },
    /// Load from a local ONNX file
    Local {
        /// Path to the ONNX model file
        model_path: PathBuf,
        /// Path to the tokenizer file loaded via `Tokenizer::from_file`
        tokenizer_path: PathBuf,
    },
    /// Use a pre-configured model
    Pretrained(PretrainedModel),
    /// Custom URL for model download
    Url {
        /// Direct download URL for the model
        model_url: String,
        /// Direct download URL for the tokenizer (cached as `tokenizer.json`)
        tokenizer_url: String,
    },
}
impl Default for ModelSource {
fn default() -> Self {
Self::Pretrained(PretrainedModel::default())
}
}
impl From<PretrainedModel> for ModelSource {
fn from(model: PretrainedModel) -> Self {
Self::Pretrained(model)
}
}
/// Pooling strategy for combining token embeddings
///
/// Selects how per-token model outputs are reduced to one vector per input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PoolingStrategy {
    /// Mean pooling over all tokens (most common)
    #[default]
    Mean,
    /// Use [CLS] token embedding
    Cls,
    /// Max pooling over all tokens
    Max,
    /// Mean pooling with sqrt(length) scaling
    MeanSqrtLen,
    /// Last token pooling (for decoder models)
    LastToken,
    /// Weighted mean based on attention mask
    WeightedMean,
}
/// Execution provider for ONNX Runtime
///
/// `PartialEq`/`Eq` are derived for consistency with the other config enums
/// (`PoolingStrategy`, `GpuMode`) so callers can compare configured providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ExecutionProvider {
    /// CPU inference (default, always available)
    #[default]
    Cpu,
    /// CUDA GPU acceleration
    Cuda { device_id: i32 },
    /// TensorRT optimization
    TensorRt { device_id: i32 },
    /// CoreML on macOS
    CoreMl,
    /// DirectML on Windows
    DirectMl,
    /// ROCm for AMD GPUs
    Rocm { device_id: i32 },
}
/// Configuration for the embedder
///
/// Built via `EmbedderConfig::builder()`, one of the convenience
/// constructors (`pretrained`, `local`, `huggingface`), or `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedderConfig {
    /// Model source
    pub model_source: ModelSource,
    /// Pooling strategy
    pub pooling: PoolingStrategy,
    /// Whether to normalize embeddings to unit length
    pub normalize: bool,
    /// Maximum sequence length (truncation), in tokens
    pub max_length: usize,
    /// Batch size for inference (inputs are chunked to this size)
    pub batch_size: usize,
    /// Number of threads for CPU inference
    pub num_threads: usize,
    /// Execution provider
    pub execution_provider: ExecutionProvider,
    /// Cache directory for downloaded models
    pub cache_dir: PathBuf,
    /// Whether to show progress during downloads
    pub show_progress: bool,
    /// Use fp16 inference if available
    pub use_fp16: bool,
    /// Enable graph optimization
    pub optimize_graph: bool,
}
impl Default for EmbedderConfig {
fn default() -> Self {
Self {
model_source: ModelSource::default(),
pooling: PoolingStrategy::default(),
normalize: true,
max_length: 256,
batch_size: 32,
num_threads: num_cpus::get(),
execution_provider: ExecutionProvider::default(),
cache_dir: default_cache_dir(),
show_progress: true,
use_fp16: false,
optimize_graph: true,
}
}
}
impl EmbedderConfig {
    /// Create a new config builder
    pub fn builder() -> EmbedderConfigBuilder {
        EmbedderConfigBuilder::default()
    }

    /// Create config for a pretrained model.
    ///
    /// The model's recommended max sequence length and normalization setting
    /// are read *before* the value is moved into `ModelSource::Pretrained` —
    /// struct-literal fields are evaluated in order, so the original
    /// (move-then-read) ordering only compiled when `PretrainedModel: Copy`.
    pub fn pretrained(model: PretrainedModel) -> Self {
        let max_length = model.max_seq_length();
        let normalize = model.normalize_output();
        Self {
            model_source: ModelSource::Pretrained(model),
            max_length,
            normalize,
            ..Default::default()
        }
    }

    /// Create config for a local model given the model and tokenizer paths.
    pub fn local(model_path: impl Into<PathBuf>, tokenizer_path: impl Into<PathBuf>) -> Self {
        Self {
            model_source: ModelSource::Local {
                model_path: model_path.into(),
                tokenizer_path: tokenizer_path.into(),
            },
            ..Default::default()
        }
    }

    /// Create config for a HuggingFace model (default revision).
    pub fn huggingface(model_id: impl Into<String>) -> Self {
        Self {
            model_source: ModelSource::HuggingFace {
                model_id: model_id.into(),
                revision: None,
            },
            ..Default::default()
        }
    }
}
/// Builder for EmbedderConfig
///
/// Starts from `EmbedderConfig::default()` (via `Default`) and mutates one
/// field per setter; finish with `build()`.
#[derive(Debug, Default)]
pub struct EmbedderConfigBuilder {
    // The config being assembled; returned as-is by `build`.
    config: EmbedderConfig,
}
impl EmbedderConfigBuilder {
    /// Set the model source directly.
    pub fn model_source(mut self, source: ModelSource) -> Self {
        self.config.model_source = source;
        self
    }

    /// Configure for a pretrained model, also adopting its recommended
    /// max sequence length and normalization setting.
    ///
    /// Model attributes are read before the value is moved into
    /// `ModelSource::Pretrained`, so `PretrainedModel` need not be `Copy`
    /// (the original read them after the move).
    pub fn pretrained(mut self, model: PretrainedModel) -> Self {
        self.config.max_length = model.max_seq_length();
        self.config.normalize = model.normalize_output();
        self.config.model_source = ModelSource::Pretrained(model);
        self
    }

    /// Set the pooling strategy.
    pub fn pooling(mut self, strategy: PoolingStrategy) -> Self {
        self.config.pooling = strategy;
        self
    }

    /// Enable or disable unit-length normalization of embeddings.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.config.normalize = normalize;
        self
    }

    /// Set the maximum sequence length (tokens beyond this are truncated).
    pub fn max_length(mut self, length: usize) -> Self {
        self.config.max_length = length;
        self
    }

    /// Set the inference batch size.
    pub fn batch_size(mut self, size: usize) -> Self {
        self.config.batch_size = size;
        self
    }

    /// Set the number of CPU inference threads.
    pub fn num_threads(mut self, threads: usize) -> Self {
        self.config.num_threads = threads;
        self
    }

    /// Set the ONNX Runtime execution provider.
    pub fn execution_provider(mut self, provider: ExecutionProvider) -> Self {
        self.config.execution_provider = provider;
        self
    }

    /// Set the cache directory for downloaded models.
    pub fn cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.config.cache_dir = dir.into();
        self
    }

    /// Enable or disable download progress output.
    pub fn show_progress(mut self, show: bool) -> Self {
        self.config.show_progress = show;
        self
    }

    /// Enable or disable fp16 inference.
    pub fn use_fp16(mut self, use_fp16: bool) -> Self {
        self.config.use_fp16 = use_fp16;
        self
    }

    /// Enable or disable ONNX graph optimization.
    pub fn optimize_graph(mut self, optimize: bool) -> Self {
        self.config.optimize_graph = optimize;
        self
    }

    /// Finish building and return the configuration.
    pub fn build(self) -> EmbedderConfig {
        self.config
    }
}
/// Default on-disk location for downloaded models:
/// `<platform-cache-dir>/ruvector/onnx-models`, falling back to the current
/// directory when the platform cache directory cannot be determined.
fn default_cache_dir() -> PathBuf {
    let base = match dirs::cache_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    base.join("ruvector").join("onnx-models")
}
/// Number of logical CPUs available to this process, defaulting to 4 when
/// detection fails (e.g. unsupported platform).
fn num_cpus_get() -> usize {
    match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 4,
    }
}
// In-crate shim mirroring the `num_cpus` crate's `num_cpus::get()` API,
// backed by std's `available_parallelism` (via `num_cpus_get` above), so the
// crate avoids the external dependency.
mod num_cpus {
    /// Logical CPU count; falls back to 4 if detection fails.
    pub fn get() -> usize {
        super::num_cpus_get()
    }
}

View File

@@ -0,0 +1,471 @@
//! Main embedder implementation combining model, tokenizer, and pooling
use crate::config::{EmbedderConfig, ModelSource, PoolingStrategy};
use crate::model::OnnxModel;
use crate::pooling::Pooler;
use crate::tokenizer::Tokenizer;
use crate::{EmbeddingError, PretrainedModel, Result};
use std::path::Path;
use tracing::{debug, info, instrument};
#[cfg(feature = "gpu")]
use crate::gpu::{GpuAccelerator, GpuConfig};
/// High-level embedder combining tokenizer, model, and pooling
///
/// Construct with [`Embedder::new`], [`Embedder::default_model`],
/// [`Embedder::pretrained`], or via [`EmbedderBuilder`].
pub struct Embedder {
    /// ONNX model for inference
    model: OnnxModel,
    /// Tokenizer for text processing
    tokenizer: Tokenizer,
    /// Pooler for combining token embeddings
    pooler: Pooler,
    /// Configuration
    config: EmbedderConfig,
    /// Optional GPU accelerator for similarity operations
    /// (`None` when GPU initialization failed at construction time)
    #[cfg(feature = "gpu")]
    gpu: Option<GpuAccelerator>,
}
/// Embedding output with metadata
///
/// Parallel vectors: `embeddings[i]`, `texts[i]`, and `token_counts[i]`
/// all describe input `i`.
#[derive(Debug, Clone)]
pub struct EmbeddingOutput {
    /// The embedding vectors
    pub embeddings: Vec<Vec<f32>>,
    /// Original input texts
    pub texts: Vec<String>,
    /// Number of tokens per input
    pub token_counts: Vec<usize>,
    /// Embedding dimension
    pub dimension: usize,
}

impl EmbeddingOutput {
    /// Number of embeddings in this output.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// True when the output holds no embeddings.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrow a single embedding by position, if present.
    pub fn get(&self, index: usize) -> Option<&Vec<f32>> {
        self.embeddings.get(index)
    }

    /// Iterate over the embedding vectors.
    pub fn iter(&self) -> impl Iterator<Item = &Vec<f32>> {
        self.embeddings.iter()
    }

    /// Consume the output, keeping only the raw vectors.
    pub fn into_vecs(self) -> Vec<Vec<f32>> {
        self.embeddings
    }
}
impl Embedder {
    /// Create a new embedder from configuration.
    ///
    /// Loads the ONNX model, resolves the tokenizer according to the
    /// configured `ModelSource` (downloading and caching it for `Url`
    /// sources), builds the pooler, and — with the `gpu` feature — attempts
    /// GPU initialization, falling back to CPU silently on failure.
    #[instrument(skip_all)]
    pub async fn new(config: EmbedderConfig) -> Result<Self> {
        info!("Initializing embedder");
        // Load model
        let model = OnnxModel::from_config(&config).await?;
        // Load tokenizer based on model source
        let tokenizer = match &config.model_source {
            ModelSource::Local {
                tokenizer_path, ..
            } => Tokenizer::from_file(tokenizer_path, config.max_length)?,
            ModelSource::Pretrained(pretrained) => {
                Tokenizer::from_pretrained(pretrained.model_id(), config.max_length)?
            }
            ModelSource::HuggingFace { model_id, .. } => {
                Tokenizer::from_pretrained(model_id, config.max_length)?
            }
            ModelSource::Url { tokenizer_url, .. } => {
                // Download tokenizer into the cache dir, once.
                let cache_path = config.cache_dir.join("tokenizer.json");
                if !cache_path.exists() {
                    download_tokenizer(tokenizer_url, &cache_path).await?;
                }
                Tokenizer::from_file(&cache_path, config.max_length)?
            }
        };
        let pooler = Pooler::new(config.pooling, config.normalize);
        // Initialize GPU accelerator if available.
        // GPU init failure is non-fatal: log at debug level and run on CPU.
        #[cfg(feature = "gpu")]
        let gpu = {
            match GpuAccelerator::new(GpuConfig::auto()).await {
                Ok(accel) => {
                    info!("GPU accelerator initialized: {}", accel.device_info().name);
                    Some(accel)
                }
                Err(e) => {
                    debug!("GPU not available, using CPU: {}", e);
                    None
                }
            }
        };
        Ok(Self {
            model,
            tokenizer,
            pooler,
            config,
            #[cfg(feature = "gpu")]
            gpu,
        })
    }

    /// Create embedder with default model (all-MiniLM-L6-v2)
    pub async fn default_model() -> Result<Self> {
        Self::new(EmbedderConfig::default()).await
    }

    /// Create embedder for a specific pretrained model
    pub async fn pretrained(model: PretrainedModel) -> Result<Self> {
        Self::new(EmbedderConfig::pretrained(model)).await
    }

    /// Embed a single text.
    ///
    /// Returns `EmbeddingError::EmptyInput` if the batch path produced no
    /// embedding (should not happen for a non-empty batch).
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
        let output = self.embed(&[text])?;
        output
            .embeddings
            .into_iter()
            .next()
            .ok_or(EmbeddingError::EmptyInput)
    }

    /// Embed multiple texts.
    ///
    /// Inputs are processed in chunks of `config.batch_size`; results are
    /// returned in input order. Errors with `EmptyInput` on an empty slice.
    #[instrument(skip(self, texts), fields(batch_size = texts.len()))]
    pub fn embed<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
        if texts.is_empty() {
            return Err(EmbeddingError::EmptyInput);
        }
        let texts_owned: Vec<String> = texts.iter().map(|t| t.as_ref().to_string()).collect();
        // Process in batches
        let batch_size = self.config.batch_size;
        let mut all_embeddings = Vec::with_capacity(texts.len());
        let mut all_token_counts = Vec::with_capacity(texts.len());
        for chunk in texts.chunks(batch_size) {
            let (embeddings, token_counts) = self.embed_batch(chunk)?;
            all_embeddings.extend(embeddings);
            all_token_counts.extend(token_counts);
        }
        Ok(EmbeddingOutput {
            embeddings: all_embeddings,
            texts: texts_owned,
            token_counts: all_token_counts,
            dimension: self.model.dimension(),
        })
    }

    /// Embed a batch of texts (internal): tokenize, run the model, then pool
    /// the per-token outputs into one vector per input.
    fn embed_batch<S: AsRef<str>>(&mut self, texts: &[S]) -> Result<(Vec<Vec<f32>>, Vec<usize>)> {
        debug!("Embedding batch of {} texts", texts.len());
        // Tokenize
        let encoded = self.tokenizer.encode_batch(texts)?;
        let (input_ids, attention_mask, token_type_ids, shape) = encoded.to_onnx_inputs();
        // Run model
        let token_embeddings = self.model.run(
            &input_ids,
            &attention_mask,
            &token_type_ids,
            &shape,
        )?;
        // shape is [batch, seq_length]; hidden size comes from the model.
        let seq_length = shape[1];
        let hidden_size = self.model.dimension();
        // Pool embeddings (the pooler also applies normalization when configured)
        let attention_masks: Vec<Vec<i64>> = encoded.attention_mask;
        let embeddings = self.pooler.pool(
            &token_embeddings,
            &attention_masks,
            seq_length,
            hidden_size,
        );
        let token_counts = encoded.original_lengths;
        Ok((embeddings, token_counts))
    }

    /// Embed texts (sequential processing)
    /// Note: For parallel processing, consider using tokio::spawn with multiple Embedder instances
    #[instrument(skip(self, texts), fields(total_texts = texts.len()))]
    pub fn embed_parallel<S: AsRef<str> + Sync>(&mut self, texts: &[S]) -> Result<EmbeddingOutput> {
        // Use sequential processing since ONNX session requires mutable access
        self.embed(texts)
    }

    /// Process texts one at a time, collecting per-text results
    /// (use embed for batch processing)
    pub fn embed_each<S: AsRef<str>>(&mut self, texts: &[S]) -> Vec<Result<Vec<f32>>> {
        texts.iter().map(|text| self.embed_one(text.as_ref())).collect()
    }

    /// Get the embedding dimension
    pub fn dimension(&self) -> usize {
        self.model.dimension()
    }

    /// Get model info
    pub fn model_info(&self) -> &crate::model::ModelInfo {
        self.model.info()
    }

    /// Get the pooling strategy
    pub fn pooling_strategy(&self) -> PoolingStrategy {
        self.config.pooling
    }

    /// Get max sequence length
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Compute cosine similarity between two texts (embeds both).
    pub fn similarity(&mut self, text1: &str, text2: &str) -> Result<f32> {
        let emb1 = self.embed_one(text1)?;
        let emb2 = self.embed_one(text2)?;
        Ok(Pooler::cosine_similarity(&emb1, &emb2))
    }

    /// Find most similar texts from a corpus
    /// Uses GPU acceleration when available and corpus is large enough
    /// (>= 64 entries); any GPU failure silently falls through to the CPU path.
    #[instrument(skip(self, query, corpus), fields(corpus_size = corpus.len()))]
    pub fn most_similar<S: AsRef<str>>(
        &mut self,
        query: &str,
        corpus: &[S],
        top_k: usize,
    ) -> Result<Vec<(usize, f32, String)>> {
        let query_emb = self.embed_one(query)?;
        let corpus_embs = self.embed(corpus)?;
        // Try GPU-accelerated similarity if available
        #[cfg(feature = "gpu")]
        if let Some(ref gpu) = self.gpu {
            if corpus.len() >= 64 {
                let candidates: Vec<&[f32]> = corpus_embs.embeddings.iter().map(|v| v.as_slice()).collect();
                if let Ok(results) = gpu.top_k_similar(&query_emb, &candidates, top_k) {
                    return Ok(results
                        .into_iter()
                        .map(|(idx, score)| (idx, score, corpus[idx].as_ref().to_string()))
                        .collect());
                }
            }
        }
        // CPU fallback: score everything, sort descending, keep top_k.
        let mut similarities: Vec<(usize, f32, String)> = corpus_embs
            .embeddings
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let sim = Pooler::cosine_similarity(&query_emb, emb);
                (i, sim, corpus[i].as_ref().to_string())
            })
            .collect();
        // NOTE(review): this unwrap panics if any similarity is NaN (e.g. a
        // zero-norm embedding); `f32::total_cmp` would be panic-free — confirm
        // cosine_similarity can never return NaN before relying on this.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        similarities.truncate(top_k);
        Ok(similarities)
    }

    /// Check if GPU acceleration is available
    /// (always false when built without the `gpu` feature)
    pub fn has_gpu(&self) -> bool {
        #[cfg(feature = "gpu")]
        {
            self.gpu.is_some()
        }
        #[cfg(not(feature = "gpu"))]
        {
            false
        }
    }

    /// Get GPU device info if available
    #[cfg(feature = "gpu")]
    pub fn gpu_info(&self) -> Option<crate::gpu::GpuInfo> {
        self.gpu.as_ref().map(|g| g.device_info())
    }

    /// Cluster texts by similarity (simple k-means-like approach)
    ///
    /// Centroids are seeded from the first `n_clusters` embeddings; iterates
    /// assign/update (Euclidean distance, at most 100 rounds) until
    /// assignments stop changing. Returns one cluster index per input text.
    #[instrument(skip(self, texts), fields(n_texts = texts.len(), n_clusters))]
    pub fn cluster<S: AsRef<str>>(
        &mut self,
        texts: &[S],
        n_clusters: usize,
    ) -> Result<Vec<usize>> {
        let embeddings = self.embed(texts)?;
        let dim = self.dimension();
        // Initialize centroids with first k embeddings
        // (if texts.len() < n_clusters, fewer centroids are created)
        let mut centroids: Vec<Vec<f32>> = embeddings
            .embeddings
            .iter()
            .take(n_clusters)
            .cloned()
            .collect();
        let mut assignments = vec![0usize; texts.len()];
        let max_iterations = 100;
        for _ in 0..max_iterations {
            let old_assignments = assignments.clone();
            // Assign to nearest centroid
            for (i, emb) in embeddings.embeddings.iter().enumerate() {
                let mut min_dist = f32::MAX;
                let mut min_idx = 0;
                for (j, centroid) in centroids.iter().enumerate() {
                    let dist = Pooler::euclidean_distance(emb, centroid);
                    if dist < min_dist {
                        min_dist = dist;
                        min_idx = j;
                    }
                }
                assignments[i] = min_idx;
            }
            // Check convergence
            if assignments == old_assignments {
                break;
            }
            // Update centroids: mean of assigned points; empty clusters
            // keep their previous centroid.
            for (j, centroid) in centroids.iter_mut().enumerate() {
                let cluster_points: Vec<&Vec<f32>> = embeddings
                    .embeddings
                    .iter()
                    .zip(assignments.iter())
                    .filter(|(_, &a)| a == j)
                    .map(|(e, _)| e)
                    .collect();
                if !cluster_points.is_empty() {
                    *centroid = vec![0.0; dim];
                    for point in &cluster_points {
                        for (k, &val) in point.iter().enumerate() {
                            centroid[k] += val;
                        }
                    }
                    let count = cluster_points.len() as f32;
                    for val in centroid.iter_mut() {
                        *val /= count;
                    }
                }
            }
        }
        Ok(assignments)
    }
}
/// Download tokenizer from URL and write it to `path`.
///
/// Returns `DownloadFailed` on a non-success HTTP status; propagates
/// network errors (`Http`) and filesystem errors (`Io`).
async fn download_tokenizer(url: &str, path: &Path) -> Result<()> {
    let response = reqwest::get(url).await?;
    if !response.status().is_success() {
        return Err(EmbeddingError::download_failed(format!(
            "Failed to download tokenizer: HTTP {}",
            response.status()
        )));
    }
    let bytes = response.bytes().await?;
    // The cache directory may not exist on first run; File::create does not
    // create intermediate directories, so make them explicitly.
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    std::fs::write(path, &bytes)?;
    Ok(())
}
/// Builder for creating embedders with custom configurations
///
/// Thin wrapper over `EmbedderConfig`; `build()` awaits `Embedder::new`.
pub struct EmbedderBuilder {
    // Config being assembled; passed to Embedder::new on build().
    config: EmbedderConfig,
}
impl EmbedderBuilder {
/// Start building an embedder
pub fn new() -> Self {
Self {
config: EmbedderConfig::default(),
}
}
/// Use a pretrained model
pub fn pretrained(mut self, model: PretrainedModel) -> Self {
self.config = EmbedderConfig::pretrained(model);
self
}
/// Set pooling strategy
pub fn pooling(mut self, strategy: PoolingStrategy) -> Self {
self.config.pooling = strategy;
self
}
/// Set normalization
pub fn normalize(mut self, normalize: bool) -> Self {
self.config.normalize = normalize;
self
}
/// Set batch size
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_size = size;
self
}
/// Set max sequence length
pub fn max_length(mut self, length: usize) -> Self {
self.config.max_length = length;
self
}
/// Build the embedder
pub async fn build(self) -> Result<Embedder> {
Embedder::new(self.config).await
}
}
impl Default for EmbedderBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the default configuration: mean pooling with
    /// unit-length normalization enabled.
    #[test]
    fn test_default_config() {
        let config = EmbedderConfig::default();
        assert_eq!(config.pooling, PoolingStrategy::Mean);
        assert!(config.normalize);
    }
}

View File

@@ -0,0 +1,233 @@
//! Error types for ONNX embeddings
use thiserror::Error;
/// Result type alias for embedding operations
pub type Result<T> = std::result::Result<T, EmbeddingError>;
/// Errors that can occur during embedding operations
///
/// Wrapped third-party errors (`ort`, `tokenizers`, `reqwest`, `serde_json`,
/// `ndarray`) and std I/O convert automatically via `#[from]`; the remaining
/// variants are constructed through the helper methods on `impl EmbeddingError`.
#[derive(Error, Debug)]
pub enum EmbeddingError {
    /// ONNX Runtime error
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(#[from] ort::Error),
    /// Tokenizer error
    #[error("Tokenizer error: {0}")]
    Tokenizer(#[from] tokenizers::tokenizer::Error),
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// HTTP request error
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Model not found
    #[error("Model not found: {path}")]
    ModelNotFound { path: String },
    /// Tokenizer not found
    #[error("Tokenizer not found: {path}")]
    TokenizerNotFound { path: String },
    /// Invalid model format
    #[error("Invalid model format: {reason}")]
    InvalidModel { reason: String },
    /// Dimension mismatch
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },
    /// Empty input
    #[error("Empty input provided")]
    EmptyInput,
    /// Batch size exceeded
    #[error("Batch size {size} exceeds maximum {max}")]
    BatchSizeExceeded { size: usize, max: usize },
    /// Sequence too long
    #[error("Sequence length {length} exceeds maximum {max}")]
    SequenceTooLong { length: usize, max: usize },
    /// Download failed
    #[error("Failed to download model: {reason}")]
    DownloadFailed { reason: String },
    /// Cache error
    #[error("Cache error: {reason}")]
    CacheError { reason: String },
    /// Checksum mismatch
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: String, actual: String },
    /// Invalid configuration
    #[error("Invalid configuration: {reason}")]
    InvalidConfig { reason: String },
    /// Execution provider not available
    #[error("Execution provider not available: {provider}")]
    ExecutionProviderNotAvailable { provider: String },
    /// RuVector integration error
    #[error("RuVector error: {0}")]
    RuVector(String),
    /// Serialization error
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Shape error from ndarray
    #[error("Shape error: {0}")]
    Shape(#[from] ndarray::ShapeError),
    /// Generic error
    #[error("{0}")]
    Other(String),
    /// GPU initialization error
    #[error("GPU initialization failed: {reason}")]
    GpuInitFailed { reason: String },
    /// GPU operation error
    #[error("GPU operation failed: {operation} - {reason}")]
    GpuOperationFailed { operation: String, reason: String },
    /// Shader compilation error
    #[error("Shader compilation failed: {shader} - {reason}")]
    ShaderCompilationFailed { shader: String, reason: String },
    /// GPU buffer error
    #[error("GPU buffer error: {reason}")]
    GpuBufferError { reason: String },
    /// GPU not available
    #[error("GPU not available: {reason}")]
    GpuNotAvailable { reason: String },
}
impl EmbeddingError {
    // ---- Constructor helpers (accept anything convertible to String) ----

    /// Create a model not found error.
    pub fn model_not_found(path: impl Into<String>) -> Self {
        Self::ModelNotFound { path: path.into() }
    }

    /// Create a tokenizer not found error.
    pub fn tokenizer_not_found(path: impl Into<String>) -> Self {
        Self::TokenizerNotFound { path: path.into() }
    }

    /// Create an invalid model error.
    pub fn invalid_model(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::InvalidModel { reason }
    }

    /// Create a dimension mismatch error.
    pub fn dimension_mismatch(expected: usize, actual: usize) -> Self {
        Self::DimensionMismatch { expected, actual }
    }

    /// Create a download failed error.
    pub fn download_failed(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::DownloadFailed { reason }
    }

    /// Create a cache error.
    pub fn cache_error(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::CacheError { reason }
    }

    /// Create an invalid config error.
    pub fn invalid_config(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::InvalidConfig { reason }
    }

    /// Create an execution provider error.
    pub fn execution_provider_not_available(provider: impl Into<String>) -> Self {
        let provider = provider.into();
        Self::ExecutionProviderNotAvailable { provider }
    }

    /// Create a RuVector error.
    pub fn ruvector(msg: impl Into<String>) -> Self {
        Self::RuVector(msg.into())
    }

    /// Create a generic error.
    pub fn other(msg: impl Into<String>) -> Self {
        Self::Other(msg.into())
    }

    /// Create a GPU initialization error.
    pub fn gpu_init_failed(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::GpuInitFailed { reason }
    }

    /// Create a GPU operation error.
    pub fn gpu_operation_failed(operation: impl Into<String>, reason: impl Into<String>) -> Self {
        let operation = operation.into();
        let reason = reason.into();
        Self::GpuOperationFailed { operation, reason }
    }

    /// Create a shader compilation error.
    pub fn shader_compilation_failed(shader: impl Into<String>, reason: impl Into<String>) -> Self {
        let shader = shader.into();
        let reason = reason.into();
        Self::ShaderCompilationFailed { shader, reason }
    }

    /// Create a GPU buffer error.
    pub fn gpu_buffer_error(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::GpuBufferError { reason }
    }

    /// Create a GPU not available error.
    pub fn gpu_not_available(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::GpuNotAvailable { reason }
    }

    // ---- Classification predicates ----

    /// True for any GPU-related variant.
    pub fn is_gpu_error(&self) -> bool {
        match self {
            Self::GpuInitFailed { .. }
            | Self::GpuOperationFailed { .. }
            | Self::ShaderCompilationFailed { .. }
            | Self::GpuBufferError { .. }
            | Self::GpuNotAvailable { .. } => true,
            _ => false,
        }
    }

    /// True for transient errors that a retry may fix (network/cache).
    pub fn is_recoverable(&self) -> bool {
        match self {
            Self::Http(_) | Self::DownloadFailed { .. } | Self::CacheError { .. } => true,
            _ => false,
        }
    }

    /// True for errors caused by bad configuration or model setup.
    pub fn is_config_error(&self) -> bool {
        match self {
            Self::InvalidConfig { .. }
            | Self::InvalidModel { .. }
            | Self::DimensionMismatch { .. } => true,
            _ => false,
        }
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,293 @@
//! GPU Configuration for RuVector ONNX Embeddings
//!
//! Provides configuration options for GPU acceleration including
//! device selection, memory limits, and performance tuning.
use serde::{Deserialize, Serialize};
/// GPU execution mode
///
/// Chooses the compute backend; `CpuOnly` disables GPU dispatch entirely
/// (see `GpuConfig::should_use_gpu`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuMode {
    /// Automatically select best available backend
    #[default]
    Auto,
    /// Force WebGPU backend
    WebGpu,
    /// Force CUDA-WASM transpiled backend
    CudaWasm,
    /// CPU-only (disable GPU)
    CpuOnly,
}
/// Power preference for GPU device selection
///
/// Passed to the backend when choosing among multiple adapters.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum PowerPreference {
    /// Prefer low power consumption (integrated GPU)
    LowPower,
    /// Prefer high performance (discrete GPU)
    #[default]
    HighPerformance,
    /// No preference
    None,
}
/// GPU acceleration configuration
///
/// Create via `GpuConfig::auto()`, a preset (`high_performance`, `low_power`,
/// `cpu_only`, `webgpu`), or `Default`, then refine with the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// GPU execution mode
    pub mode: GpuMode,
    /// Power preference for device selection
    pub power_preference: PowerPreference,
    /// Maximum GPU memory usage (bytes, 0 = unlimited)
    pub max_memory: u64,
    /// Workgroup size for compute shaders (0 = auto)
    pub workgroup_size: u32,
    /// Enable async GPU operations
    pub async_compute: bool,
    /// Minimum batch size to use GPU (smaller batches use CPU)
    pub min_batch_size: usize,
    /// Minimum vector dimension to use GPU
    pub min_dimension: usize,
    /// Enable shader caching
    pub cache_shaders: bool,
    /// Enable profiling and timing
    pub enable_profiling: bool,
    /// Fallback to CPU on GPU error
    pub fallback_to_cpu: bool,
    /// Device index (for multi-GPU systems)
    pub device_index: u32,
}
impl Default for GpuConfig {
fn default() -> Self {
Self {
mode: GpuMode::Auto,
power_preference: PowerPreference::HighPerformance,
max_memory: 0, // unlimited
workgroup_size: 256,
async_compute: true,
min_batch_size: 16,
min_dimension: 128,
cache_shaders: true,
enable_profiling: false,
fallback_to_cpu: true,
device_index: 0,
}
}
}
impl GpuConfig {
/// Create configuration with automatic settings
pub fn auto() -> Self {
Self::default()
}
/// Create configuration for high performance
pub fn high_performance() -> Self {
Self {
mode: GpuMode::Auto,
power_preference: PowerPreference::HighPerformance,
workgroup_size: 512,
async_compute: true,
min_batch_size: 8,
min_dimension: 64,
..Default::default()
}
}
/// Create configuration for low power usage
pub fn low_power() -> Self {
Self {
mode: GpuMode::Auto,
power_preference: PowerPreference::LowPower,
workgroup_size: 128,
async_compute: false,
min_batch_size: 32,
min_dimension: 256,
..Default::default()
}
}
/// Create CPU-only configuration
pub fn cpu_only() -> Self {
Self {
mode: GpuMode::CpuOnly,
..Default::default()
}
}
/// Create WebGPU-specific configuration
pub fn webgpu() -> Self {
Self {
mode: GpuMode::WebGpu,
..Default::default()
}
}
/// Create CUDA-WASM specific configuration
#[cfg(feature = "cuda-wasm")]
pub fn cuda_wasm() -> Self {
Self {
mode: GpuMode::CudaWasm,
workgroup_size: 256,
..Default::default()
}
}
// Builder methods
/// Set GPU mode
pub fn with_mode(mut self, mode: GpuMode) -> Self {
self.mode = mode;
self
}
/// Set power preference
pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
self.power_preference = pref;
self
}
/// Set maximum memory
pub fn with_max_memory(mut self, bytes: u64) -> Self {
self.max_memory = bytes;
self
}
/// Set workgroup size
pub fn with_workgroup_size(mut self, size: u32) -> Self {
self.workgroup_size = size;
self
}
/// Set minimum batch size for GPU usage
pub fn with_min_batch_size(mut self, size: usize) -> Self {
self.min_batch_size = size;
self
}
/// Set minimum dimension for GPU usage
pub fn with_min_dimension(mut self, dim: usize) -> Self {
self.min_dimension = dim;
self
}
/// Enable or disable profiling
pub fn with_profiling(mut self, enable: bool) -> Self {
self.enable_profiling = enable;
self
}
/// Enable or disable CPU fallback
pub fn with_fallback(mut self, enable: bool) -> Self {
self.fallback_to_cpu = enable;
self
}
/// Set device index
pub fn with_device(mut self, index: u32) -> Self {
self.device_index = index;
self
}
/// Check if GPU should be used for given workload
pub fn should_use_gpu(&self, batch_size: usize, dimension: usize) -> bool {
self.mode != GpuMode::CpuOnly
&& batch_size >= self.min_batch_size
&& dimension >= self.min_dimension
}
}
/// GPU memory statistics
///
/// All quantities are in bytes; see `usage_percent` for a derived ratio.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuMemoryStats {
    /// Total GPU memory (bytes)
    pub total: u64,
    /// Used GPU memory (bytes)
    pub used: u64,
    /// Free GPU memory (bytes)
    pub free: u64,
    /// Peak usage (bytes)
    pub peak: u64,
}
impl GpuMemoryStats {
    /// Percentage of total GPU memory currently in use
    /// (0.0 when `total` is zero, avoiding division by zero).
    pub fn usage_percent(&self) -> f32 {
        match self.total {
            0 => 0.0,
            total => self.used as f32 / total as f32 * 100.0,
        }
    }
}
/// GPU profiling data
///
/// Populated only when `GpuConfig::enable_profiling` is on.
#[allow(dead_code)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuProfilingData {
    /// Total operations executed
    pub operations: u64,
    /// Total GPU time (microseconds)
    pub gpu_time_us: u64,
    /// Total CPU time (microseconds)
    pub cpu_time_us: u64,
    /// GPU speedup over CPU
    pub speedup: f32,
    /// Memory transfers (bytes)
    pub memory_transferred: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: auto backend, high-performance device, CPU fallback on.
    #[test]
    fn test_default_config() {
        let config = GpuConfig::default();
        assert_eq!(config.mode, GpuMode::Auto);
        assert_eq!(config.power_preference, PowerPreference::HighPerformance);
        assert!(config.fallback_to_cpu);
    }

    /// `should_use_gpu` requires BOTH batch size and dimension to meet
    /// their configured minimums.
    #[test]
    fn test_should_use_gpu() {
        let config = GpuConfig::default()
            .with_min_batch_size(16)
            .with_min_dimension(128);
        assert!(!config.should_use_gpu(8, 384)); // batch too small
        assert!(!config.should_use_gpu(32, 64)); // dimension too small
        assert!(config.should_use_gpu(32, 384)); // both ok
    }

    /// `CpuOnly` mode rejects GPU dispatch regardless of workload size.
    #[test]
    fn test_cpu_only() {
        let config = GpuConfig::cpu_only();
        assert!(!config.should_use_gpu(1000, 1000));
    }

    /// Builder setters store exactly the values they are given.
    #[test]
    fn test_builder() {
        let config = GpuConfig::auto()
            .with_mode(GpuMode::WebGpu)
            .with_max_memory(1024 * 1024 * 1024)
            .with_workgroup_size(512)
            .with_profiling(true);
        assert_eq!(config.mode, GpuMode::WebGpu);
        assert_eq!(config.max_memory, 1024 * 1024 * 1024);
        assert_eq!(config.workgroup_size, 512);
        assert!(config.enable_profiling);
    }
}

View File

@@ -0,0 +1,298 @@
//! GPU Acceleration Module for RuVector ONNX Embeddings
//!
//! This module provides optional GPU acceleration using cuda-wasm for:
//! - Pooling operations
//! - Similarity computations
//! - Batch vector operations
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │ GPU Acceleration Layer │
//! ├─────────────────────────────────────────────────────────────────┤
//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
//! │ │ GpuBackend │ -> │ Shaders │ -> │ WebGPU Runtime │ │
//! │ │ (Trait) │ │ (WGSL) │ │ (wgpu) │ │
//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │
//! │ │ │ │
//! │ v v │
//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
//! │ │ GpuPooler │ │ GpuSimilar │ │ GpuVectorOps │ │
//! │ │ │ │ │ │ │ │
//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! ## Feature Flags
//!
//! - `gpu`: Enable GPU acceleration (WebGPU backend)
//! - `cuda-wasm`: Enable CUDA-WASM transpilation support
//!
//! ## Usage
//!
//! ```rust,ignore
//! use ruvector_onnx_embeddings::gpu::{GpuAccelerator, GpuConfig};
//!
//! // Create GPU accelerator with auto-detection
//! let gpu = GpuAccelerator::new(GpuConfig::auto()).await?;
//!
//! // GPU-accelerated similarity search
//! let similarities = gpu.batch_cosine_similarity(&query, &candidates)?;
//!
//! // GPU-accelerated pooling
//! let pooled = gpu.mean_pool(&token_embeddings, &attention_mask)?;
//! ```
mod backend;
mod config;
mod operations;
mod shaders;
#[cfg(test)]
mod tests;
pub use backend::{GpuBackend, GpuDevice, GpuInfo};
pub use config::{GpuConfig, GpuMode, PowerPreference};
pub use operations::{
GpuPooler, GpuSimilarity, GpuVectorOps,
batch_cosine_similarity_gpu, batch_dot_product_gpu, batch_euclidean_gpu,
};
pub use shaders::ShaderRegistry;
use crate::Result;
use std::sync::Arc;
/// GPU Accelerator - Main entry point for GPU operations
///
/// Provides unified access to GPU-accelerated operations with automatic
/// fallback to CPU when GPU is unavailable.
pub struct GpuAccelerator {
    /// Backend shared (via `Arc`) with the component operators below.
    backend: Arc<dyn GpuBackend>,
    /// Configuration this accelerator was created with.
    config: GpuConfig,
    /// Pooling operations (mean / CLS / max).
    pooler: GpuPooler,
    /// Batch similarity operations (cosine / dot / Euclidean).
    similarity: GpuSimilarity,
    /// General batch vector operations.
    vector_ops: GpuVectorOps,
}
impl GpuAccelerator {
    /// Create a new GPU accelerator with the given configuration.
    ///
    /// Builds the backend for `config`, then hands an `Arc` of it to the
    /// pooling, similarity and vector-op components so they can dispatch
    /// compute work on the GPU.
    ///
    /// # Errors
    /// Returns an error if the backend cannot be created for `config`.
    pub async fn new(config: GpuConfig) -> Result<Self> {
        let backend: Arc<dyn GpuBackend> = Arc::from(backend::create_backend(&config).await?);
        let shader_registry = ShaderRegistry::new();
        // `mut` is only needed for `set_backend` in the cfg block below; the
        // cfg_attr keeps CPU-only builds free of `unused_mut` warnings.
        #[cfg_attr(not(any(feature = "gpu", feature = "cuda-wasm")), allow(unused_mut))]
        let mut pooler = GpuPooler::new(backend.as_ref(), &shader_registry)?;
        #[cfg_attr(not(any(feature = "gpu", feature = "cuda-wasm")), allow(unused_mut))]
        let mut similarity = GpuSimilarity::new(backend.as_ref(), &shader_registry)?;
        #[cfg_attr(not(any(feature = "gpu", feature = "cuda-wasm")), allow(unused_mut))]
        let mut vector_ops = GpuVectorOps::new(backend.as_ref(), &shader_registry)?;
        // Wire up the backend to all components for GPU dispatch
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        {
            pooler.set_backend(Arc::clone(&backend));
            similarity.set_backend(Arc::clone(&backend));
            vector_ops.set_backend(Arc::clone(&backend));
        }
        Ok(Self {
            backend,
            config,
            pooler,
            similarity,
            vector_ops,
        })
    }
    /// Create with automatic configuration (`GpuConfig::auto`).
    pub async fn auto() -> Result<Self> {
        Self::new(GpuConfig::auto()).await
    }
    /// Check if GPU acceleration is available on the created backend.
    pub fn is_available(&self) -> bool {
        self.backend.is_available()
    }
    /// Get GPU device information from the backend.
    pub fn device_info(&self) -> GpuInfo {
        self.backend.device_info()
    }
    /// Get the configuration this accelerator was created with.
    pub fn config(&self) -> &GpuConfig {
        &self.config
    }
    // ==================== Pooling Operations ====================
    /// Mean pooling over token embeddings (GPU-accelerated with CPU fallback).
    ///
    /// `token_embeddings` is a flat row-major `[batch, seq, hidden]` tensor,
    /// `attention_mask` a flat `[batch, seq]` mask; returns a flat
    /// `batch_size * hidden_size` result.
    pub fn mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        self.pooler.mean_pool(
            token_embeddings,
            attention_mask,
            batch_size,
            seq_length,
            hidden_size,
        )
    }
    /// CLS token pooling: the first token vector of each sequence.
    pub fn cls_pool(
        &self,
        token_embeddings: &[f32],
        batch_size: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        self.pooler.cls_pool(token_embeddings, batch_size, hidden_size)
    }
    /// Max pooling over token embeddings (GPU-accelerated with CPU fallback).
    pub fn max_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        self.pooler.max_pool(
            token_embeddings,
            attention_mask,
            batch_size,
            seq_length,
            hidden_size,
        )
    }
    // ==================== Similarity Operations ====================
    /// Batch cosine similarity of `query` against each candidate.
    pub fn batch_cosine_similarity(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
    ) -> Result<Vec<f32>> {
        self.similarity.batch_cosine(query, candidates)
    }
    /// Batch dot product of `query` against each candidate.
    pub fn batch_dot_product(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
    ) -> Result<Vec<f32>> {
        self.similarity.batch_dot_product(query, candidates)
    }
    /// Batch Euclidean distance of `query` to each candidate.
    pub fn batch_euclidean_distance(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
    ) -> Result<Vec<f32>> {
        self.similarity.batch_euclidean(query, candidates)
    }
    /// Find the top-k candidates by cosine similarity, as
    /// `(candidate_index, score)` pairs sorted by descending score.
    pub fn top_k_similar(
        &self,
        query: &[f32],
        candidates: &[&[f32]],
        k: usize,
    ) -> Result<Vec<(usize, f32)>> {
        self.similarity.top_k(query, candidates, k)
    }
    // ==================== Vector Operations ====================
    /// L2-normalize a flat batch of `dimension`-sized vectors in place.
    pub fn normalize_batch(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        self.vector_ops.normalize_batch(vectors, dimension)
    }
    /// Matrix-vector multiplication: `matrix` is flat row-major `rows x cols`.
    pub fn matmul(
        &self,
        matrix: &[f32],
        vector: &[f32],
        rows: usize,
        cols: usize,
    ) -> Result<Vec<f32>> {
        self.vector_ops.matmul(matrix, vector, rows, cols)
    }
    /// Element-wise vector addition (`a` and `b` must be the same length).
    pub fn batch_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        self.vector_ops.batch_add(a, b)
    }
    /// Multiply every element of `vectors` by `scale`, in place.
    pub fn batch_scale(&self, vectors: &mut [f32], scale: f32) -> Result<()> {
        self.vector_ops.batch_scale(vectors, scale)
    }
}
/// Convenience function to check GPU availability without creating a full
/// accelerator (no pipelines or buffers are constructed).
pub async fn is_gpu_available() -> bool {
    // Delegates to the backend's lightweight probe.
    backend::probe_gpu().await
}
/// Get GPU device info without full initialization; `None` when no device
/// information is available.
pub async fn get_gpu_info() -> Option<GpuInfo> {
    backend::get_device_info().await
}
/// Fallback wrapper that tries GPU first, then CPU
pub struct HybridAccelerator {
    /// GPU accelerator, present only if initialization succeeded.
    gpu: Option<GpuAccelerator>,
    /// Runtime toggle: when false, the CPU path is used even if `gpu` is Some.
    use_gpu: bool,
}
impl HybridAccelerator {
/// Create hybrid accelerator with GPU if available
pub async fn new() -> Self {
let gpu = GpuAccelerator::auto().await.ok();
let use_gpu = gpu.is_some();
Self { gpu, use_gpu }
}
/// Check if GPU is being used
pub fn using_gpu(&self) -> bool {
self.use_gpu && self.gpu.is_some()
}
/// Disable GPU (use CPU only)
pub fn disable_gpu(&mut self) {
self.use_gpu = false;
}
/// Enable GPU if available
pub fn enable_gpu(&mut self) {
self.use_gpu = self.gpu.is_some();
}
/// Batch cosine similarity with automatic backend selection
pub fn batch_cosine_similarity(
&self,
query: &[f32],
candidates: &[Vec<f32>],
) -> Vec<f32> {
if self.use_gpu {
if let Some(ref gpu) = self.gpu {
let refs: Vec<&[f32]> = candidates.iter().map(|v| v.as_slice()).collect();
if let Ok(result) = gpu.batch_cosine_similarity(query, &refs) {
return result;
}
}
}
// CPU fallback
crate::pooling::batch_cosine_similarity(query, candidates)
}
}

View File

@@ -0,0 +1,934 @@
//! GPU-Accelerated Operations
//!
//! High-level GPU operations for embeddings with automatic fallback to CPU.
use crate::{EmbeddingError, Result};
use super::backend::{GpuBackend, BufferUsage};
use super::shaders::ShaderRegistry;
use rayon::prelude::*;
use std::sync::Arc;
#[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
use bytemuck;
// ==================== GPU Pooler ====================
/// GPU-accelerated pooling operations
pub struct GpuPooler {
    /// Whether the backend reported an available, compute-capable device.
    use_gpu: bool,
    /// Backend handle; injected by `GpuAccelerator` via `set_backend`.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
impl GpuPooler {
    /// Create new GPU pooler.
    ///
    /// The GPU path is enabled only when the backend reports an available,
    /// compute-capable device; the backend handle itself is injected later
    /// via `set_backend`.
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        let use_gpu = backend.is_available() && backend.device_info().supports_compute;
        Ok(Self {
            use_gpu,
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None, // Will be set by GpuAccelerator
        })
    }
    /// Set the backend for GPU operations (called by `GpuAccelerator::new`).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }
    /// Mean pooling (GPU when wired and `batch_size >= 8`, CPU otherwise).
    ///
    /// `token_embeddings` is flat row-major `[batch, seq, hidden]`;
    /// `attention_mask` is flat `[batch, seq]`. Returns `batch * hidden` floats.
    pub fn mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        // GPU implementation requires minimum batch size for efficiency
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && batch_size >= 8 && self.backend.is_some() {
            return self.mean_pool_gpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size);
        }
        Ok(self.mean_pool_cpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size))
    }
    /// CLS pooling: copies the first token vector of each sequence.
    pub fn cls_pool(
        &self,
        token_embeddings: &[f32],
        batch_size: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        // CLS pooling is simple copy, CPU is often faster
        Ok(self.cls_pool_cpu(token_embeddings, batch_size, hidden_size))
    }
    /// Max pooling (GPU when wired and `batch_size >= 8`, CPU otherwise).
    pub fn max_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && batch_size >= 8 && self.backend.is_some() {
            return self.max_pool_gpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size);
        }
        Ok(self.max_pool_cpu(token_embeddings, attention_mask, batch_size, seq_length, hidden_size))
    }
    // GPU implementations
    //
    // Both GPU paths share the same protocol: upload tokens + mask, bind
    // (tokens, mask, output, params), dispatch ceil(batch*hidden / 64)
    // workgroups of a 64-wide pipeline, sync, read the result back, then
    // release every buffer and the pipeline. Buffer sizes are in bytes
    // (4 per f32, 8 per i64).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn mean_pool_gpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "mean_pool".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Create buffers
        let token_buf = backend.create_buffer(
            (token_embeddings.len() * 4) as u64,
            BufferUsage::Storage,
        )?;
        let mask_buf = backend.create_buffer(
            (attention_mask.len() * 8) as u64,
            BufferUsage::Storage,
        )?;
        let output_buf = backend.create_buffer(
            (batch_size * hidden_size * 4) as u64,
            BufferUsage::Storage,
        )?;
        // Create params buffer (batch_size, seq_length, hidden_size)
        let params: [u32; 3] = [batch_size as u32, seq_length as u32, hidden_size as u32];
        let params_buf = backend.create_buffer(16, BufferUsage::Uniform)?; // 16 bytes aligned
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&token_buf, bytemuck::cast_slice(token_embeddings))?;
        backend.write_buffer(&mask_buf, bytemuck::cast_slice(attention_mask))?;
        // Create pipeline with mean pool shader
        let shader = super::shaders::MEAN_POOL_SHADER;
        let pipeline = backend.create_pipeline(shader, "mean_pool", [64, 1, 1])?;
        // Dispatch with params buffer as 4th binding
        let total_outputs = batch_size * hidden_size;
        let workgroups = [total_outputs.div_ceil(64) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&token_buf, &mask_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (batch_size * hidden_size * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(token_buf)?;
        backend.release_buffer(mask_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn max_pool_gpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "max_pool".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Create buffers
        let token_buf = backend.create_buffer(
            (token_embeddings.len() * 4) as u64,
            BufferUsage::Storage,
        )?;
        let mask_buf = backend.create_buffer(
            (attention_mask.len() * 8) as u64,
            BufferUsage::Storage,
        )?;
        let output_buf = backend.create_buffer(
            (batch_size * hidden_size * 4) as u64,
            BufferUsage::Storage,
        )?;
        // Create params buffer (batch_size, seq_length, hidden_size)
        let params: [u32; 3] = [batch_size as u32, seq_length as u32, hidden_size as u32];
        let params_buf = backend.create_buffer(16, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&token_buf, bytemuck::cast_slice(token_embeddings))?;
        backend.write_buffer(&mask_buf, bytemuck::cast_slice(attention_mask))?;
        // Create pipeline with max pool shader
        let shader = super::shaders::MAX_POOL_SHADER;
        let pipeline = backend.create_pipeline(shader, "max_pool", [64, 1, 1])?;
        // Dispatch with params buffer as 4th binding
        let total_outputs = batch_size * hidden_size;
        let workgroups = [total_outputs.div_ceil(64) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&token_buf, &mask_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (batch_size * hidden_size * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(token_buf)?;
        backend.release_buffer(mask_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    // CPU implementations
    /// Mean over attended tokens, parallelized per batch row.
    /// Only mask value 1 counts as attended; fully-masked rows stay zero.
    fn mean_pool_cpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut output = vec![0.0f32; batch_size * hidden_size];
        output
            .par_chunks_mut(hidden_size)
            .enumerate()
            .for_each(|(batch_idx, out_chunk)| {
                let tokens_base = batch_idx * seq_length * hidden_size;
                let mask_base = batch_idx * seq_length;
                let mut count = 0.0f32;
                for seq_idx in 0..seq_length {
                    if attention_mask[mask_base + seq_idx] == 1 {
                        let start = tokens_base + seq_idx * hidden_size;
                        for (j, out_val) in out_chunk.iter_mut().enumerate() {
                            *out_val += token_embeddings[start + j];
                        }
                        count += 1.0;
                    }
                }
                // Guard against division by zero when a row is fully masked.
                if count > 0.0 {
                    for val in out_chunk.iter_mut() {
                        *val /= count;
                    }
                }
            });
        output
    }
    /// Copies the first (CLS) token vector of each sequence.
    fn cls_pool_cpu(
        &self,
        token_embeddings: &[f32],
        batch_size: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        // Empty batch / zero hidden size: the seq_length division below would
        // otherwise divide by zero. An empty output is the consistent result.
        if batch_size == 0 || hidden_size == 0 {
            return Vec::new();
        }
        // Sequence length is inferred from the flat [batch, seq, hidden] layout.
        let seq_length = token_embeddings.len() / (batch_size * hidden_size);
        let mut output = vec![0.0f32; batch_size * hidden_size];
        for batch_idx in 0..batch_size {
            let src_start = batch_idx * seq_length * hidden_size;
            let dst_start = batch_idx * hidden_size;
            output[dst_start..dst_start + hidden_size]
                .copy_from_slice(&token_embeddings[src_start..src_start + hidden_size]);
        }
        output
    }
    /// Element-wise max over attended tokens, parallelized per batch row.
    fn max_pool_cpu(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        batch_size: usize,
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut output = vec![f32::NEG_INFINITY; batch_size * hidden_size];
        output
            .par_chunks_mut(hidden_size)
            .enumerate()
            .for_each(|(batch_idx, out_chunk)| {
                let tokens_base = batch_idx * seq_length * hidden_size;
                let mask_base = batch_idx * seq_length;
                for seq_idx in 0..seq_length {
                    if attention_mask[mask_base + seq_idx] == 1 {
                        let start = tokens_base + seq_idx * hidden_size;
                        for (j, out_val) in out_chunk.iter_mut().enumerate() {
                            let val = token_embeddings[start + j];
                            if val > *out_val {
                                *out_val = val;
                            }
                        }
                    }
                }
                // Replace -inf (fully-masked rows) with 0
                for val in out_chunk.iter_mut() {
                    if val.is_infinite() {
                        *val = 0.0;
                    }
                }
            });
        output
    }
}
// ==================== GPU Similarity ====================
/// GPU-accelerated similarity computations
pub struct GpuSimilarity {
    /// Whether the backend reported an available, compute-capable device.
    use_gpu: bool,
    /// Below this candidate count the CPU path is used (GPU setup overhead).
    min_candidates: usize,
    /// Backend handle; injected by `GpuAccelerator` via `set_backend`.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
impl GpuSimilarity {
    /// Create new GPU similarity calculator.
    ///
    /// The GPU path is enabled only when the backend reports an available,
    /// compute-capable device; the backend handle itself is injected later
    /// via `set_backend`.
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        Ok(Self {
            use_gpu: backend.is_available() && backend.device_info().supports_compute,
            min_candidates: 64, // Minimum candidates to use GPU
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None,
        })
    }
    /// Set the backend for GPU operations (called by `GpuAccelerator::new`).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }
    /// Batch cosine similarity of `query` against each candidate.
    ///
    /// Uses the GPU only when enabled, wired, and the candidate count reaches
    /// `min_candidates`; otherwise runs the rayon CPU path.
    pub fn batch_cosine(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
            return self.batch_cosine_gpu(query, candidates);
        }
        Ok(self.batch_cosine_cpu(query, candidates))
    }
    /// Batch dot product (same GPU/CPU selection as `batch_cosine`).
    pub fn batch_dot_product(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
            return self.batch_dot_product_gpu(query, candidates);
        }
        Ok(self.batch_dot_product_cpu(query, candidates))
    }
    /// Batch Euclidean distance (smaller values mean closer vectors).
    pub fn batch_euclidean(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && candidates.len() >= self.min_candidates && self.backend.is_some() {
            return self.batch_euclidean_gpu(query, candidates);
        }
        Ok(self.batch_euclidean_cpu(query, candidates))
    }
    /// Find top-k candidates by cosine similarity, returned as
    /// `(candidate_index, score)` pairs sorted by descending score.
    ///
    /// `k` larger than the candidate count simply returns everything.
    /// NaN scores compare as equal (`partial_cmp` fallback), so their
    /// relative order is unspecified.
    pub fn top_k(&self, query: &[f32], candidates: &[&[f32]], k: usize) -> Result<Vec<(usize, f32)>> {
        let similarities = self.batch_cosine(query, candidates)?;
        let mut indexed: Vec<(usize, f32)> = similarities.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.truncate(k);
        Ok(indexed)
    }
    // GPU implementations
    //
    // All three GPU paths share the same protocol: upload the query and the
    // flattened candidate matrix, bind (query, candidates, output, params),
    // dispatch ceil(n/256) workgroups of a 256-wide pipeline, sync, read the
    // scores back, then release every buffer and the pipeline. Buffer sizes
    // are in bytes (4 per f32).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_cosine_gpu(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_cosine".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        let dimension = query.len();
        let num_candidates = candidates.len();
        // Flatten candidates into contiguous buffer
        // (assumes every candidate has `dimension` elements — TODO confirm at call sites)
        let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();
        // Create buffers
        let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
        let candidates_buf = backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (dimension, num_candidates)
        let params: [u32; 2] = [dimension as u32, num_candidates as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
        backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;
        // Create pipeline with batch cosine shader
        let shader = super::shaders::BATCH_COSINE_SIMILARITY_SHADER;
        let pipeline = backend.create_pipeline(shader, "batch_cosine_similarity", [256, 1, 1])?;
        // Dispatch with params buffer as 4th binding
        let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&query_buf, &candidates_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(query_buf)?;
        backend.release_buffer(candidates_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    /// GPU path for dot products; identical protocol to `batch_cosine_gpu`
    /// with the dot-product shader.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_dot_product_gpu(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_dot_product".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        let dimension = query.len();
        let num_candidates = candidates.len();
        // Flatten candidates into contiguous buffer
        let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();
        // Create buffers
        let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
        let candidates_buf = backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (dimension, num_candidates)
        let params: [u32; 2] = [dimension as u32, num_candidates as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
        backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;
        // Create pipeline
        let shader = super::shaders::DOT_PRODUCT_SHADER;
        let pipeline = backend.create_pipeline(shader, "dot_product", [256, 1, 1])?;
        // Dispatch with params buffer as 4th binding
        let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&query_buf, &candidates_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(query_buf)?;
        backend.release_buffer(candidates_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    /// GPU path for Euclidean distances; identical protocol to
    /// `batch_cosine_gpu` with the Euclidean shader.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_euclidean_gpu(&self, query: &[f32], candidates: &[&[f32]]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_euclidean".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        let dimension = query.len();
        let num_candidates = candidates.len();
        // Flatten candidates into contiguous buffer
        let candidates_flat: Vec<f32> = candidates.iter().flat_map(|c| c.iter().copied()).collect();
        // Create buffers
        let query_buf = backend.create_buffer((dimension * 4) as u64, BufferUsage::Storage)?;
        let candidates_buf = backend.create_buffer((candidates_flat.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((num_candidates * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (dimension, num_candidates)
        let params: [u32; 2] = [dimension as u32, num_candidates as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&query_buf, bytemuck::cast_slice(query))?;
        backend.write_buffer(&candidates_buf, bytemuck::cast_slice(&candidates_flat))?;
        // Create pipeline
        let shader = super::shaders::EUCLIDEAN_DISTANCE_SHADER;
        let pipeline = backend.create_pipeline(shader, "euclidean_distance", [256, 1, 1])?;
        // Dispatch with params buffer as 4th binding
        let workgroups = [num_candidates.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&query_buf, &candidates_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (num_candidates * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(query_buf)?;
        backend.release_buffer(candidates_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    // CPU implementations (rayon data-parallel over candidates)
    fn batch_cosine_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        candidates
            .par_iter()
            .map(|c| cosine_similarity_cpu(query, c))
            .collect()
    }
    fn batch_dot_product_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        candidates
            .par_iter()
            .map(|c| dot_product_cpu(query, c))
            .collect()
    }
    fn batch_euclidean_cpu(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        candidates
            .par_iter()
            .map(|c| euclidean_distance_cpu(query, c))
            .collect()
    }
}
// ==================== GPU Vector Operations ====================
/// GPU-accelerated vector operations
pub struct GpuVectorOps {
    /// Whether the backend reported an available, compute-capable device.
    use_gpu: bool,
    /// Backend handle; injected by `GpuAccelerator` via `set_backend`.
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    backend: Option<Arc<dyn GpuBackend>>,
}
impl GpuVectorOps {
    /// Create new GPU vector operations.
    ///
    /// The GPU path is enabled only when the backend reports an available,
    /// compute-capable device; the backend handle itself is injected later
    /// via `set_backend`.
    pub fn new(backend: &dyn GpuBackend, _shaders: &ShaderRegistry) -> Result<Self> {
        Ok(Self {
            use_gpu: backend.is_available() && backend.device_info().supports_compute,
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None,
        })
    }
    /// Set the backend for GPU operations (called by `GpuAccelerator::new`).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    pub fn set_backend(&mut self, backend: Arc<dyn GpuBackend>) {
        self.backend = Some(backend);
    }
    /// L2 normalize batch of vectors in place.
    ///
    /// `vectors` is a flat batch of `dimension`-sized rows; the GPU path is
    /// taken only for at least 64 rows.
    pub fn normalize_batch(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && vectors.len() >= dimension * 64 && self.backend.is_some() {
            return self.normalize_batch_gpu(vectors, dimension);
        }
        self.normalize_batch_cpu(vectors, dimension);
        Ok(())
    }
    /// Matrix-vector multiplication; `matrix` is flat row-major `rows x cols`.
    /// GPU is used only from 64 rows upward.
    pub fn matmul(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>> {
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && rows >= 64 && self.backend.is_some() {
            return self.matmul_gpu(matrix, vector, rows, cols);
        }
        Ok(self.matmul_cpu(matrix, vector, rows, cols))
    }
    /// Element-wise vector addition.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error when `a` and `b` differ in length.
    pub fn batch_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        if a.len() != b.len() {
            return Err(EmbeddingError::dimension_mismatch(a.len(), b.len()));
        }
        #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
        if self.use_gpu && a.len() >= 1024 && self.backend.is_some() {
            return self.batch_add_gpu(a, b);
        }
        Ok(a.par_iter().zip(b.par_iter()).map(|(x, y)| x + y).collect())
    }
    /// Multiply every element by `scale`, in place.
    /// Always runs on the CPU (rayon); there is no GPU path for this op.
    pub fn batch_scale(&self, vectors: &mut [f32], scale: f32) -> Result<()> {
        vectors.par_iter_mut().for_each(|v| *v *= scale);
        Ok(())
    }
    // GPU implementations
    //
    // Shared protocol: upload inputs, bind four buffers (the shaders expect a
    // 4-binding layout ending in a uniform params buffer), dispatch, sync,
    // read back, release everything. Buffer sizes are in bytes (4 per f32).
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn normalize_batch_gpu(&self, vectors: &mut [f32], dimension: usize) -> Result<()> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "normalize_batch".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        let num_vectors = vectors.len() / dimension;
        // Create buffers (input, dummy, output, params)
        // NOTE(review): `dummy_buf` appears to exist only to satisfy the
        // shader's 4-binding layout — confirm against L2_NORMALIZE_SHADER.
        let input_buf = backend.create_buffer((vectors.len() * 4) as u64, BufferUsage::Storage)?;
        let dummy_buf = backend.create_buffer(4, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((vectors.len() * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (dimension, num_vectors)
        let params: [u32; 2] = [dimension as u32, num_vectors as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&input_buf, bytemuck::cast_slice(vectors))?;
        // Create pipeline
        let shader = super::shaders::L2_NORMALIZE_SHADER;
        let pipeline = backend.create_pipeline(shader, "l2_normalize", [256, 1, 1])?;
        // Dispatch with 4 bindings
        let workgroups = [num_vectors.div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&input_buf, &dummy_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output and overwrite the caller's slice in place
        let output_bytes = backend.read_buffer(&output_buf, (vectors.len() * 4) as u64)?;
        let output: &[f32] = bytemuck::cast_slice(&output_bytes);
        vectors.copy_from_slice(output);
        // Cleanup
        backend.release_buffer(input_buf)?;
        backend.release_buffer(dummy_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(())
    }
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn matmul_gpu(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "matmul".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Create buffers
        let mat_buf = backend.create_buffer((matrix.len() * 4) as u64, BufferUsage::Storage)?;
        let vec_buf = backend.create_buffer((vector.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((rows * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (rows, cols)
        let params: [u32; 2] = [rows as u32, cols as u32];
        let params_buf = backend.create_buffer(8, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&mat_buf, bytemuck::cast_slice(matrix))?;
        backend.write_buffer(&vec_buf, bytemuck::cast_slice(vector))?;
        // Create pipeline
        let shader = super::shaders::MATMUL_SHADER;
        let pipeline = backend.create_pipeline(shader, "matmul", [16, 16, 1])?;
        // Dispatch with params buffer as 4th binding
        // NOTE(review): the pipeline declares a 16x16 workgroup but the
        // dispatch is 1-D over rows/16 — confirm this matches MATMUL_SHADER's
        // expected geometry.
        let workgroups = [rows.div_ceil(16) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&mat_buf, &vec_buf, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (rows * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(mat_buf)?;
        backend.release_buffer(vec_buf)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
    fn batch_add_gpu(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
        let backend = self.backend.as_ref().ok_or_else(|| {
            EmbeddingError::GpuOperationFailed {
                operation: "batch_add".to_string(),
                reason: "Backend not initialized".to_string(),
            }
        })?;
        // Create buffers
        let buf_a = backend.create_buffer((a.len() * 4) as u64, BufferUsage::Storage)?;
        let buf_b = backend.create_buffer((b.len() * 4) as u64, BufferUsage::Storage)?;
        let output_buf = backend.create_buffer((a.len() * 4) as u64, BufferUsage::Storage)?;
        // Create params buffer (length)
        let params: [u32; 1] = [a.len() as u32];
        let params_buf = backend.create_buffer(4, BufferUsage::Uniform)?;
        backend.write_buffer(&params_buf, bytemuck::cast_slice(&params))?;
        // Write input data
        backend.write_buffer(&buf_a, bytemuck::cast_slice(a))?;
        backend.write_buffer(&buf_b, bytemuck::cast_slice(b))?;
        // Create pipeline
        let shader = super::shaders::VECTOR_ADD_SHADER;
        let pipeline = backend.create_pipeline(shader, "vector_add", [256, 1, 1])?;
        // Dispatch with params buffer as 4th binding
        let workgroups = [a.len().div_ceil(256) as u32, 1, 1];
        backend.dispatch(&pipeline, &[&buf_a, &buf_b, &output_buf, &params_buf], workgroups)?;
        backend.sync()?;
        // Read output
        let output_bytes = backend.read_buffer(&output_buf, (a.len() * 4) as u64)?;
        let output: Vec<f32> = bytemuck::cast_slice(&output_bytes).to_vec();
        // Cleanup
        backend.release_buffer(buf_a)?;
        backend.release_buffer(buf_b)?;
        backend.release_buffer(output_buf)?;
        backend.release_buffer(params_buf)?;
        backend.release_pipeline(pipeline)?;
        Ok(output)
    }
    // CPU implementations
    /// Per-row L2 normalization, parallelized per row; near-zero rows
    /// (norm <= 1e-12) are left untouched to avoid division blow-up.
    fn normalize_batch_cpu(&self, vectors: &mut [f32], dimension: usize) {
        vectors
            .par_chunks_mut(dimension)
            .for_each(|chunk| {
                let norm: f32 = chunk.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 1e-12 {
                    for val in chunk.iter_mut() {
                        *val /= norm;
                    }
                }
            });
    }
    /// Row-parallel matrix-vector product over the flat row-major matrix.
    fn matmul_cpu(&self, matrix: &[f32], vector: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0f32; rows];
        result
            .par_iter_mut()
            .enumerate()
            .for_each(|(row, out)| {
                let row_start = row * cols;
                *out = matrix[row_start..row_start + cols]
                    .iter()
                    .zip(vector.iter())
                    .map(|(m, v)| m * v)
                    .sum();
            });
        result
    }
}
// ==================== Standalone Functions ====================
/// Batch cosine similarity computed on the CPU with rayon.
///
/// NOTE(review): despite the `_gpu` suffix this standalone function never
/// dispatches to a GPU; use `GpuSimilarity`/`GpuAccelerator` for GPU paths.
pub fn batch_cosine_similarity_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|c| cosine_similarity_cpu(query, c))
        .collect()
}
/// Batch dot product computed on the CPU with rayon.
///
/// NOTE(review): despite the `_gpu` suffix this standalone function never
/// dispatches to a GPU; use `GpuSimilarity`/`GpuAccelerator` for GPU paths.
pub fn batch_dot_product_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|c| dot_product_cpu(query, c))
        .collect()
}
/// Batch Euclidean distance computed on the CPU with rayon.
///
/// NOTE(review): despite the `_gpu` suffix this standalone function never
/// dispatches to a GPU; use `GpuSimilarity`/`GpuAccelerator` for GPU paths.
pub fn batch_euclidean_gpu(query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|c| euclidean_distance_cpu(query, c))
        .collect()
}
// ==================== CPU Helper Functions ====================
#[inline]
fn cosine_similarity_cpu(a: &[f32], b: &[f32]) -> f32 {
    // Single pass: accumulate the dot product and both squared norms together.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());
    // Degenerate (near-zero) vectors yield 0 rather than NaN/inf.
    if norm_a <= 1e-12 || norm_b <= 1e-12 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
#[inline]
fn dot_product_cpu(a: &[f32], b: &[f32]) -> f32 {
    // Fold over paired elements; stops at the shorter of the two slices.
    a.iter()
        .zip(b.iter())
        .fold(0.0f32, |acc, (x, y)| acc + x * y)
}
#[inline]
fn euclidean_distance_cpu(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate squared differences, then take a single square root.
    let sum_sq = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    });
    sum_sq.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity_cpu(&a, &b) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity_cpu(&a, &c).abs() < 1e-6);
    }
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((dot_product_cpu(&a, &b) - 32.0).abs() < 1e-6);
    }
    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!((euclidean_distance_cpu(&a, &b) - 5.0).abs() < 1e-6);
    }
    #[test]
    fn test_batch_cosine() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates: Vec<&[f32]> = vec![
            &[1.0, 0.0, 0.0][..],
            &[0.0, 1.0, 0.0][..],
            &[0.707, 0.707, 0.0][..],
        ];
        let results = batch_cosine_similarity_gpu(&query, &candidates);
        assert_eq!(results.len(), 3);
        assert!((results[0] - 1.0).abs() < 1e-6);
        assert!(results[1].abs() < 1e-6);
        // The 45-degree candidate should score ~cos(45°) = 0.7071.
        assert!((results[2] - 0.7071).abs() < 1e-3);
    }
    #[test]
    fn test_mean_pool_cpu() {
        let pooler = GpuPooler {
            use_gpu: false,
            // Must mirror the cfg gate on the struct's `backend` field, otherwise
            // builds with `cuda-wasm` but without `gpu` fail to compile this test.
            #[cfg(any(feature = "gpu", feature = "cuda-wasm"))]
            backend: None,
        };
        // batch=2, seq=2, hidden=3
        let tokens = vec![
            1.0, 2.0, 3.0, // batch 0, seq 0
            4.0, 5.0, 6.0, // batch 0, seq 1
            7.0, 8.0, 9.0, // batch 1, seq 0
            10.0, 11.0, 12.0, // batch 1, seq 1
        ];
        let mask = vec![1i64, 1, 1, 1];
        let result = pooler.mean_pool_cpu(&tokens, &mask, 2, 2, 3);
        assert_eq!(result.len(), 6);
        // Batch 0: mean of [1,2,3] and [4,5,6] = [2.5, 3.5, 4.5]
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,613 @@
//! GPU Compute Shaders for RuVector Operations
//!
//! WGSL (WebGPU Shading Language) implementations for:
//! - Pooling operations
//! - Similarity computations
//! - Vector normalization
//! - Matrix operations
use std::collections::HashMap;
/// Shader registry for managing compute shaders
///
/// Maps shader name -> [`ShaderModule`]. Pre-populated with the built-in
/// WGSL kernels by [`ShaderRegistry::new`] and extensible at runtime via
/// [`ShaderRegistry::register`].
#[derive(Debug)]
pub struct ShaderRegistry {
    // Keyed by `ShaderModule::name`; a later insert with the same name
    // replaces the earlier entry.
    shaders: HashMap<String, ShaderModule>,
}
/// Shader module information
///
/// Bundles a WGSL compute kernel's source with the metadata needed to
/// create and dispatch it.
#[derive(Debug, Clone)]
pub struct ShaderModule {
    /// Shader name (also used as the registry key)
    pub name: String,
    /// WGSL source code
    pub source: String,
    /// Entry point function (the `@compute` function to invoke)
    pub entry_point: String,
    /// Default workgroup size as [x, y, z] invocation counts
    pub workgroup_size: [u32; 3],
}
impl ShaderRegistry {
    /// Create new registry with built-in shaders
    pub fn new() -> Self {
        // Index every built-in shader by its name.
        let shaders = Self::builtin_shaders()
            .into_iter()
            .map(|shader| (shader.name.clone(), shader))
            .collect();
        Self { shaders }
    }
    /// Get shader by name
    pub fn get(&self, name: &str) -> Option<&ShaderModule> {
        self.shaders.get(name)
    }
    /// Register custom shader
    ///
    /// Replaces any previously registered shader with the same name.
    pub fn register(&mut self, shader: ShaderModule) {
        let key = shader.name.clone();
        self.shaders.insert(key, shader);
    }
    /// List all available shaders
    pub fn list(&self) -> Vec<&str> {
        self.shaders.keys().map(String::as_str).collect()
    }
    /// Get built-in shader definitions
    fn builtin_shaders() -> Vec<ShaderModule> {
        // Every built-in kernel uses its registry name as the WGSL entry
        // point, so a single constructor covers the whole table.
        let module = |name: &str, source: &str, workgroup_size: [u32; 3]| ShaderModule {
            name: name.to_string(),
            source: source.to_string(),
            entry_point: name.to_string(),
            workgroup_size,
        };
        vec![
            module("cosine_similarity", SHADER_COSINE_SIMILARITY, [256, 1, 1]),
            module(
                "batch_cosine_similarity",
                SHADER_BATCH_COSINE_SIMILARITY,
                [256, 1, 1],
            ),
            module("dot_product", SHADER_DOT_PRODUCT, [256, 1, 1]),
            module("euclidean_distance", SHADER_EUCLIDEAN_DISTANCE, [256, 1, 1]),
            module("l2_normalize", SHADER_L2_NORMALIZE, [256, 1, 1]),
            module("mean_pool", SHADER_MEAN_POOL, [64, 1, 1]),
            module("max_pool", SHADER_MAX_POOL, [64, 1, 1]),
            module("cls_pool", SHADER_CLS_POOL, [64, 1, 1]),
            module("matmul", SHADER_MATMUL, [16, 16, 1]),
            module("vector_add", SHADER_VECTOR_ADD, [256, 1, 1]),
            module("vector_scale", SHADER_VECTOR_SCALE, [256, 1, 1]),
        ]
    }
}
impl Default for ShaderRegistry {
fn default() -> Self {
Self::new()
}
}
// ==================== Shader Source Code ====================
// Public aliases for operations.rs
// NOTE(review): only a subset of the SHADER_* constants is aliased here
// (no COSINE_SIMILARITY, CLS_POOL, or VECTOR_SCALE alias) — presumably
// operations.rs only needs these; confirm the alias set is intentional.
pub const MEAN_POOL_SHADER: &str = SHADER_MEAN_POOL;
pub const MAX_POOL_SHADER: &str = SHADER_MAX_POOL;
pub const BATCH_COSINE_SIMILARITY_SHADER: &str = SHADER_BATCH_COSINE_SIMILARITY;
pub const DOT_PRODUCT_SHADER: &str = SHADER_DOT_PRODUCT;
pub const EUCLIDEAN_DISTANCE_SHADER: &str = SHADER_EUCLIDEAN_DISTANCE;
pub const L2_NORMALIZE_SHADER: &str = SHADER_L2_NORMALIZE;
pub const MATMUL_SHADER: &str = SHADER_MATMUL;
pub const VECTOR_ADD_SHADER: &str = SHADER_VECTOR_ADD;
/// Cosine similarity between two vectors
///
/// Workgroup-parallel WGSL kernel: each of the 256 invocations accumulates
/// a strided (stride 256) slice of the dot product and of both squared
/// norms, the partials are combined with a shared-memory tree reduction,
/// and invocation 0 writes the single scalar to `result[0]`. Near-zero
/// norms (product <= 1e-12) yield 0.0.
/// NOTE(review): `idx` (`gid.x`) is computed but never used — looks like
/// leftover scaffolding; confirm before cleaning up the WGSL.
pub const SHADER_COSINE_SIMILARITY: &str = r#"
struct Params {
    dimension: u32,
    count: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidate: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
var<workgroup> shared_dot: array<f32, 256>;
var<workgroup> shared_norm_a: array<f32, 256>;
var<workgroup> shared_norm_b: array<f32, 256>;
@compute @workgroup_size(256)
fn cosine_similarity(@builtin(global_invocation_id) gid: vec3<u32>,
                     @builtin(local_invocation_id) lid: vec3<u32>) {
    let idx = gid.x;
    let local_idx = lid.x;
    var dot: f32 = 0.0;
    var norm_a: f32 = 0.0;
    var norm_b: f32 = 0.0;
    // Compute partial sums
    var i = local_idx;
    while (i < params.dimension) {
        let a = query[i];
        let b = candidate[i];
        dot += a * b;
        norm_a += a * a;
        norm_b += b * b;
        i += 256u;
    }
    // Store in shared memory
    shared_dot[local_idx] = dot;
    shared_norm_a[local_idx] = norm_a;
    shared_norm_b[local_idx] = norm_b;
    workgroupBarrier();
    // Reduction
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (local_idx < stride) {
            shared_dot[local_idx] += shared_dot[local_idx + stride];
            shared_norm_a[local_idx] += shared_norm_a[local_idx + stride];
            shared_norm_b[local_idx] += shared_norm_b[local_idx + stride];
        }
        workgroupBarrier();
    }
    // Write result
    if (local_idx == 0u) {
        let norm_product = sqrt(shared_norm_a[0] * shared_norm_b[0]);
        if (norm_product > 1e-12) {
            result[0] = shared_dot[0] / norm_product;
        } else {
            result[0] = 0.0;
        }
    }
}
"#;
/// Batch cosine similarity - one query vs many candidates
///
/// One invocation per candidate row: candidate `c` occupies
/// `candidates[c*dimension .. (c+1)*dimension]` and each thread loops over
/// the full dimension, writing its score to `results[c]`. Near-zero norm
/// products (<= 1e-12) produce 0.0.
pub const SHADER_BATCH_COSINE_SIMILARITY: &str = r#"
struct Params {
    dimension: u32,
    num_candidates: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn batch_cosine_similarity(@builtin(global_invocation_id) gid: vec3<u32>) {
    let candidate_idx = gid.x;
    if (candidate_idx >= params.num_candidates) {
        return;
    }
    let base = candidate_idx * params.dimension;
    var dot: f32 = 0.0;
    var norm_a: f32 = 0.0;
    var norm_b: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        let a = query[i];
        let b = candidates[base + i];
        dot += a * b;
        norm_a += a * a;
        norm_b += b * b;
    }
    let norm_product = sqrt(norm_a * norm_b);
    if (norm_product > 1e-12) {
        results[candidate_idx] = dot / norm_product;
    } else {
        results[candidate_idx] = 0.0;
    }
}
"#;
/// Dot product computation
///
/// One invocation per candidate row of the flattened `candidates` buffer;
/// writes `dot(query, candidates[c])` to `results[c]`.
pub const SHADER_DOT_PRODUCT: &str = r#"
struct Params {
    dimension: u32,
    num_candidates: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn dot_product(@builtin(global_invocation_id) gid: vec3<u32>) {
    let candidate_idx = gid.x;
    if (candidate_idx >= params.num_candidates) {
        return;
    }
    let base = candidate_idx * params.dimension;
    var dot: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        dot += query[i] * candidates[base + i];
    }
    results[candidate_idx] = dot;
}
"#;
/// Euclidean distance computation
///
/// One invocation per candidate row; accumulates squared differences over
/// the full dimension and writes `sqrt(sum)` to `results[c]`.
pub const SHADER_EUCLIDEAN_DISTANCE: &str = r#"
struct Params {
    dimension: u32,
    num_candidates: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> candidates: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn euclidean_distance(@builtin(global_invocation_id) gid: vec3<u32>) {
    let candidate_idx = gid.x;
    if (candidate_idx >= params.num_candidates) {
        return;
    }
    let base = candidate_idx * params.dimension;
    var sum_sq: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        let diff = query[i] - candidates[base + i];
        sum_sq += diff * diff;
    }
    results[candidate_idx] = sqrt(sum_sq);
}
"#;
/// L2 normalization
///
/// One invocation per vector row: divides every component by the row's L2
/// norm. Rows with norm <= 1e-12 are copied through unchanged.
/// NOTE(review): the unused `_dummy` binding at slot 1 appears to keep the
/// 4-binding layout uniform with the other kernels — confirm against the
/// bind-group layout in the dispatch code.
pub const SHADER_L2_NORMALIZE: &str = r#"
struct Params {
    dimension: u32,
    num_vectors: u32,
}
@group(0) @binding(0) var<storage, read> input_vectors: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output_vectors: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn l2_normalize(@builtin(global_invocation_id) gid: vec3<u32>) {
    let vec_idx = gid.x;
    if (vec_idx >= params.num_vectors) {
        return;
    }
    let base = vec_idx * params.dimension;
    // Compute norm
    var norm_sq: f32 = 0.0;
    for (var i = 0u; i < params.dimension; i++) {
        let val = input_vectors[base + i];
        norm_sq += val * val;
    }
    let norm = sqrt(norm_sq);
    // Normalize and write to output
    if (norm > 1e-12) {
        for (var i = 0u; i < params.dimension; i++) {
            output_vectors[base + i] = input_vectors[base + i] / norm;
        }
    } else {
        for (var i = 0u; i < params.dimension; i++) {
            output_vectors[base + i] = input_vectors[base + i];
        }
    }
}
"#;
/// Mean pooling over sequence
///
/// One invocation per (batch, hidden) pair, derived from `gid.x`. Averages
/// the token embeddings at positions whose `attention_mask` entry equals 1;
/// a row with no active positions yields 0.0 for that output element.
pub const SHADER_MEAN_POOL: &str = r#"
struct Params {
    batch_size: u32,
    seq_length: u32,
    hidden_size: u32,
}
@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> attention_mask: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(64)
fn mean_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_idx = gid.x / params.hidden_size;
    let hidden_idx = gid.x % params.hidden_size;
    if (batch_idx >= params.batch_size) {
        return;
    }
    let tokens_base = batch_idx * params.seq_length * params.hidden_size;
    let mask_base = batch_idx * params.seq_length;
    var sum: f32 = 0.0;
    var count: f32 = 0.0;
    for (var i = 0u; i < params.seq_length; i++) {
        if (attention_mask[mask_base + i] == 1) {
            sum += tokens[tokens_base + i * params.hidden_size + hidden_idx];
            count += 1.0;
        }
    }
    let out_idx = batch_idx * params.hidden_size + hidden_idx;
    if (count > 0.0) {
        output[out_idx] = sum / count;
    } else {
        output[out_idx] = 0.0;
    }
}
"#;
/// Max pooling over sequence
///
/// One invocation per (batch, hidden) pair. Takes the per-dimension maximum
/// over positions where `attention_mask` equals 1; `select` emits 0.0 when
/// a row has no active positions at all. Mask values other than exactly 1
/// are skipped.
pub const SHADER_MAX_POOL: &str = r#"
struct Params {
    batch_size: u32,
    seq_length: u32,
    hidden_size: u32,
}
@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> attention_mask: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(64)
fn max_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_idx = gid.x / params.hidden_size;
    let hidden_idx = gid.x % params.hidden_size;
    if (batch_idx >= params.batch_size) {
        return;
    }
    let tokens_base = batch_idx * params.seq_length * params.hidden_size;
    let mask_base = batch_idx * params.seq_length;
    var max_val: f32 = -3.402823e+38; // -FLT_MAX
    var found: bool = false;
    for (var i = 0u; i < params.seq_length; i++) {
        if (attention_mask[mask_base + i] == 1) {
            let val = tokens[tokens_base + i * params.hidden_size + hidden_idx];
            if (!found || val > max_val) {
                max_val = val;
                found = true;
            }
        }
    }
    let out_idx = batch_idx * params.hidden_size + hidden_idx;
    output[out_idx] = select(0.0, max_val, found);
}
"#;
/// CLS token pooling (first token)
///
/// One invocation per (batch, hidden) pair; copies the embedding of token 0
/// of each sequence straight to the output, ignoring the attention mask.
/// NOTE(review): `_dummy` at binding 1 appears to exist only to match the
/// other kernels' 4-binding layout — confirm in the dispatch code.
pub const SHADER_CLS_POOL: &str = r#"
struct Params {
    batch_size: u32,
    seq_length: u32,
    hidden_size: u32,
}
@group(0) @binding(0) var<storage, read> tokens: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(64)
fn cls_pool(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_idx = gid.x / params.hidden_size;
    let hidden_idx = gid.x % params.hidden_size;
    if (batch_idx >= params.batch_size) {
        return;
    }
    // CLS is first token
    let tokens_base = batch_idx * params.seq_length * params.hidden_size;
    let out_idx = batch_idx * params.hidden_size + hidden_idx;
    output[out_idx] = tokens[tokens_base + hidden_idx];
}
"#;
/// Matrix-vector multiplication
///
/// Row-major `matrix` (rows x cols) times `vector` (cols); each invocation
/// computes one output row and writes `result[row]`.
/// NOTE(review): the workgroup is declared 16x16 but only `gid.x` is used,
/// so the 16 y-invocations for each x redundantly recompute and write the
/// same row — benign (identical value) but wasteful; confirm whether a
/// 1-D workgroup was intended.
pub const SHADER_MATMUL: &str = r#"
struct Params {
    rows: u32,
    cols: u32,
}
@group(0) @binding(0) var<storage, read> matrix: array<f32>;
@group(0) @binding(1) var<storage, read> vector: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(16, 16)
fn matmul(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    if (row >= params.rows) {
        return;
    }
    var sum: f32 = 0.0;
    for (var col = 0u; col < params.cols; col++) {
        sum += matrix[row * params.cols + col] * vector[col];
    }
    result[row] = sum;
}
"#;
/// Vector addition
///
/// Element-wise `result[i] = a[i] + b[i]`; one invocation per element,
/// guarded against out-of-range indices.
pub const SHADER_VECTOR_ADD: &str = r#"
struct Params {
    length: u32,
}
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn vector_add(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.length) {
        return;
    }
    result[idx] = a[idx] + b[idx];
}
"#;
/// Vector scaling
///
/// Element-wise `output[i] = input[i] * scale`; the scale factor is passed
/// through the uniform `Params` alongside the length. `_dummy` at binding 1
/// keeps the shared 4-binding layout.
pub const SHADER_VECTOR_SCALE: &str = r#"
struct Params {
    length: u32,
    scale: f32,
}
@group(0) @binding(0) var<storage, read> input_vector: array<f32>;
@group(0) @binding(1) var<storage, read> _dummy: array<f32>;
@group(0) @binding(2) var<storage, read_write> output_vector: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn vector_scale(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.length) {
        return;
    }
    output_vector[idx] = input_vector[idx] * params.scale;
}
"#;
#[cfg(test)]
mod tests {
    use super::*;
    // Every built-in kernel must be retrievable by name after `new()`.
    #[test]
    fn test_shader_registry() {
        let registry = ShaderRegistry::new();
        // Check all built-in shaders are registered
        assert!(registry.get("cosine_similarity").is_some());
        assert!(registry.get("batch_cosine_similarity").is_some());
        assert!(registry.get("dot_product").is_some());
        assert!(registry.get("euclidean_distance").is_some());
        assert!(registry.get("l2_normalize").is_some());
        assert!(registry.get("mean_pool").is_some());
        assert!(registry.get("max_pool").is_some());
        assert!(registry.get("cls_pool").is_some());
        assert!(registry.get("matmul").is_some());
        assert!(registry.get("vector_add").is_some());
        assert!(registry.get("vector_scale").is_some());
    }
    // Spot-check that a built-in module carries real WGSL and the expected
    // entry point.
    #[test]
    fn test_shader_content() {
        let registry = ShaderRegistry::new();
        let cosine = registry.get("cosine_similarity").unwrap();
        assert!(cosine.source.contains("@compute"));
        assert!(cosine.source.contains("workgroup_size"));
        assert_eq!(cosine.entry_point, "cosine_similarity");
    }
    // User-supplied shaders must be retrievable after registration.
    #[test]
    fn test_custom_shader() {
        let mut registry = ShaderRegistry::new();
        registry.register(ShaderModule {
            name: "custom_op".to_string(),
            source: "// custom shader".to_string(),
            entry_point: "custom".to_string(),
            workgroup_size: [128, 1, 1],
        });
        assert!(registry.get("custom_op").is_some());
    }
}

View File

@@ -0,0 +1,424 @@
//! GPU Module Tests
//!
//! Comprehensive tests for GPU acceleration functionality.
use super::*;
use super::config::{GpuConfig, GpuMode, PowerPreference, GpuMemoryStats};
use super::backend::CpuBackend;
use super::shaders::ShaderModule;
// ==================== Configuration Tests ====================
// Defaults: auto GPU detection, high-performance power hint, 256-wide
// workgroups, CPU fallback and shader caching enabled.
#[test]
fn test_gpu_config_default() {
    let config = GpuConfig::default();
    assert_eq!(config.mode, GpuMode::Auto);
    assert_eq!(config.power_preference, PowerPreference::HighPerformance);
    assert_eq!(config.workgroup_size, 256);
    assert!(config.fallback_to_cpu);
    assert!(config.cache_shaders);
}
// Each `with_*` builder call must be reflected in the final config.
#[test]
fn test_gpu_config_builder() {
    let config = GpuConfig::auto()
        .with_mode(GpuMode::WebGpu)
        .with_power_preference(PowerPreference::LowPower)
        .with_workgroup_size(512)
        .with_min_batch_size(32)
        .with_min_dimension(256)
        .with_profiling(true);
    assert_eq!(config.mode, GpuMode::WebGpu);
    assert_eq!(config.power_preference, PowerPreference::LowPower);
    assert_eq!(config.workgroup_size, 512);
    assert_eq!(config.min_batch_size, 32);
    assert_eq!(config.min_dimension, 256);
    assert!(config.enable_profiling);
}
// GPU dispatch requires BOTH batch-size and dimension thresholds to be met,
// and is always off in CPU-only mode.
#[test]
fn test_should_use_gpu() {
    let config = GpuConfig::default()
        .with_min_batch_size(16)
        .with_min_dimension(128);
    // Below minimum batch size
    assert!(!config.should_use_gpu(8, 384));
    // Below minimum dimension
    assert!(!config.should_use_gpu(32, 64));
    // Both conditions met
    assert!(config.should_use_gpu(32, 384));
    // CPU only mode
    let cpu_config = GpuConfig::cpu_only();
    assert!(!cpu_config.should_use_gpu(1000, 1000));
}
// Spot-check the named preset constructors.
#[test]
fn test_preset_configs() {
    let high_perf = GpuConfig::high_performance();
    assert_eq!(high_perf.workgroup_size, 512);
    assert_eq!(high_perf.min_batch_size, 8);
    let low_power = GpuConfig::low_power();
    assert_eq!(low_power.power_preference, PowerPreference::LowPower);
    assert_eq!(low_power.workgroup_size, 128);
    let cpu_only = GpuConfig::cpu_only();
    assert_eq!(cpu_only.mode, GpuMode::CpuOnly);
}
// ==================== Shader Tests ====================
// All built-in kernels must be registered under their canonical names.
#[test]
fn test_shader_registry_initialization() {
    let registry = ShaderRegistry::new();
    let expected_shaders = vec![
        "cosine_similarity",
        "batch_cosine_similarity",
        "dot_product",
        "euclidean_distance",
        "l2_normalize",
        "mean_pool",
        "max_pool",
        "cls_pool",
        "matmul",
        "vector_add",
        "vector_scale",
    ];
    for name in expected_shaders {
        assert!(registry.get(name).is_some(), "Missing shader: {}", name);
    }
}
// Sanity-check WGSL source, entry point, and workgroup size for two kernels.
#[test]
fn test_shader_module_content() {
    let registry = ShaderRegistry::new();
    // Check cosine similarity shader
    let cosine = registry.get("cosine_similarity").unwrap();
    assert!(cosine.source.contains("@compute"));
    assert!(cosine.source.contains("workgroup_size"));
    assert!(cosine.source.contains("cosine_similarity"));
    assert_eq!(cosine.entry_point, "cosine_similarity");
    assert_eq!(cosine.workgroup_size, [256, 1, 1]);
    // Check mean pool shader
    let mean_pool = registry.get("mean_pool").unwrap();
    assert!(mean_pool.source.contains("attention_mask"));
    assert!(mean_pool.source.contains("hidden_size"));
    assert_eq!(mean_pool.entry_point, "mean_pool");
}
// User-supplied shaders must be retrievable after registration.
#[test]
fn test_custom_shader_registration() {
    let mut registry = ShaderRegistry::new();
    let custom = ShaderModule {
        name: "custom_kernel".to_string(),
        source: "@compute @workgroup_size(64) fn custom() {}".to_string(),
        entry_point: "custom".to_string(),
        workgroup_size: [64, 1, 1],
    };
    registry.register(custom);
    assert!(registry.get("custom_kernel").is_some());
    let retrieved = registry.get("custom_kernel").unwrap();
    assert_eq!(retrieved.entry_point, "custom");
}
// ==================== Batch Operations Tests ====================
// Parallel, orthogonal, and anti-parallel candidates cover the full
// cosine range [-1, 1].
#[test]
fn test_batch_cosine_similarity() {
    let query = vec![1.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[1.0, 0.0, 0.0][..],  // similarity = 1.0
        &[0.0, 1.0, 0.0][..],  // similarity = 0.0
        &[-1.0, 0.0, 0.0][..], // similarity = -1.0
    ];
    let results = batch_cosine_similarity_gpu(&query, &candidates);
    assert_eq!(results.len(), 3);
    assert!((results[0] - 1.0).abs() < 1e-6);
    assert!(results[1].abs() < 1e-6);
    assert!((results[2] - (-1.0)).abs() < 1e-6);
}
// One dot product per candidate row, including the zero vector.
#[test]
fn test_batch_dot_product() {
    let query = vec![1.0, 1.0, 1.0];
    let candidates: Vec<&[f32]> = vec![
        &[1.0, 1.0, 1.0][..], // dot = 3.0
        &[2.0, 2.0, 2.0][..], // dot = 6.0
        &[0.0, 0.0, 0.0][..], // dot = 0.0
    ];
    let results = batch_dot_product_gpu(&query, &candidates);
    assert_eq!(results.len(), 3);
    assert!((results[0] - 3.0).abs() < 1e-6);
    assert!((results[1] - 6.0).abs() < 1e-6);
    assert!(results[2].abs() < 1e-6);
}
// 3-4-5 triangle, unit step, and identical points.
#[test]
fn test_batch_euclidean() {
    let query = vec![0.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[3.0, 4.0, 0.0][..], // dist = 5.0
        &[1.0, 0.0, 0.0][..], // dist = 1.0
        &[0.0, 0.0, 0.0][..], // dist = 0.0
    ];
    let results = batch_euclidean_gpu(&query, &candidates);
    assert_eq!(results.len(), 3);
    assert!((results[0] - 5.0).abs() < 1e-6);
    assert!((results[1] - 1.0).abs() < 1e-6);
    assert!(results[2].abs() < 1e-6);
}
// ==================== Pooling Tests (using public API) ====================
// Mean over unmasked tokens for each batch row (CPU backend).
#[test]
fn test_mean_pool_via_api() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &shaders).unwrap();
    // batch=2, seq=2, hidden=3
    let tokens = vec![
        1.0, 2.0, 3.0, // batch 0, seq 0
        4.0, 5.0, 6.0, // batch 0, seq 1
        7.0, 8.0, 9.0, // batch 1, seq 0
        10.0, 11.0, 12.0, // batch 1, seq 1
    ];
    let mask = vec![1i64, 1, 1, 1];
    let result = pooler.mean_pool(&tokens, &mask, 2, 2, 3).unwrap();
    assert_eq!(result.len(), 6);
    // Batch 0: mean of [1,2,3] and [4,5,6] = [2.5, 3.5, 4.5]
    assert!((result[0] - 2.5).abs() < 1e-6);
    assert!((result[1] - 3.5).abs() < 1e-6);
    assert!((result[2] - 4.5).abs() < 1e-6);
}
// CLS pooling copies each row's first token verbatim.
#[test]
fn test_cls_pool_via_api() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &shaders).unwrap();
    // batch=2, seq=3, hidden=4
    let tokens = vec![
        // Batch 0
        1.0, 2.0, 3.0, 4.0, // CLS token
        5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0,
        // Batch 1
        10.0, 20.0, 30.0, 40.0, // CLS token
        50.0, 60.0, 70.0, 80.0,
        90.0, 100.0, 110.0, 120.0,
    ];
    let result = pooler.cls_pool(&tokens, 2, 4).unwrap();
    assert_eq!(result.len(), 8);
    // Batch 0: first token
    assert!((result[0] - 1.0).abs() < 1e-6);
    assert!((result[1] - 2.0).abs() < 1e-6);
    assert!((result[2] - 3.0).abs() < 1e-6);
    assert!((result[3] - 4.0).abs() < 1e-6);
    // Batch 1: first token
    assert!((result[4] - 10.0).abs() < 1e-6);
    assert!((result[5] - 20.0).abs() < 1e-6);
    assert!((result[6] - 30.0).abs() < 1e-6);
    assert!((result[7] - 40.0).abs() < 1e-6);
}
// Per-dimension maximum across the unmasked positions.
#[test]
fn test_max_pool_via_api() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let pooler = GpuPooler::new(&backend, &shaders).unwrap();
    // batch=1, seq=3, hidden=4
    let tokens = vec![
        1.0, 10.0, 3.0, 4.0, // seq 0
        5.0, 2.0, 7.0, 8.0, // seq 1
        9.0, 6.0, 11.0, 0.0, // seq 2
    ];
    let mask = vec![1i64, 1, 1];
    let result = pooler.max_pool(&tokens, &mask, 1, 3, 4).unwrap();
    assert_eq!(result.len(), 4);
    // Max across all sequences for each dimension
    assert!((result[0] - 9.0).abs() < 1e-6); // max(1, 5, 9)
    assert!((result[1] - 10.0).abs() < 1e-6); // max(10, 2, 6)
    assert!((result[2] - 11.0).abs() < 1e-6); // max(3, 7, 11)
    assert!((result[3] - 8.0).abs() < 1e-6); // max(4, 8, 0)
}
// ==================== Vector Operations Tests ====================
// In-place L2 normalization of two packed 3-dim vectors.
#[test]
fn test_normalize_batch() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    let mut vectors = vec![
        3.0, 4.0, 0.0, // norm = 5, normalized = [0.6, 0.8, 0]
        0.0, 0.0, 5.0, // norm = 5, normalized = [0, 0, 1]
    ];
    ops.normalize_batch(&mut vectors, 3).unwrap();
    // Check first vector
    assert!((vectors[0] - 0.6).abs() < 1e-6);
    assert!((vectors[1] - 0.8).abs() < 1e-6);
    assert!(vectors[2].abs() < 1e-6);
    // Check second vector
    assert!(vectors[3].abs() < 1e-6);
    assert!(vectors[4].abs() < 1e-6);
    assert!((vectors[5] - 1.0).abs() < 1e-6);
}
// Row-major matrix (2x3) times vector (3) -> each result is a row sum here.
#[test]
fn test_matmul() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    // 2x3 matrix
    let matrix = vec![
        1.0, 2.0, 3.0,
        4.0, 5.0, 6.0,
    ];
    // 3x1 vector
    let vector = vec![1.0, 1.0, 1.0];
    let result = ops.matmul(&matrix, &vector, 2, 3).unwrap();
    assert_eq!(result.len(), 2);
    assert!((result[0] - 6.0).abs() < 1e-6); // 1+2+3
    assert!((result[1] - 15.0).abs() < 1e-6); // 4+5+6
}
// Element-wise addition returning a new vector.
#[test]
fn test_batch_add() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];
    let result = ops.batch_add(&a, &b).unwrap();
    assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
}
// In-place scalar multiplication.
#[test]
fn test_batch_scale() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let ops = GpuVectorOps::new(&backend, &shaders).unwrap();
    let mut vectors = vec![1.0, 2.0, 3.0, 4.0];
    ops.batch_scale(&mut vectors, 2.0).unwrap();
    assert_eq!(vectors, vec![2.0, 4.0, 6.0, 8.0]);
}
// ==================== Integration Tests ====================
// GpuSimilarity wired up with the CPU backend and stock shader registry.
#[test]
fn test_gpu_similarity_with_backend() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let similarity = GpuSimilarity::new(&backend, &shaders).unwrap();
    let query = vec![1.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[1.0, 0.0, 0.0][..],
        &[0.0, 1.0, 0.0][..],
    ];
    let results = similarity.batch_cosine(&query, &candidates).unwrap();
    assert_eq!(results.len(), 2);
    assert!((results[0] - 1.0).abs() < 1e-6);
}
// top_k must return (index, score) pairs ranked by descending similarity.
#[test]
fn test_top_k_similar() {
    let backend = CpuBackend;
    let shaders = ShaderRegistry::new();
    let similarity = GpuSimilarity::new(&backend, &shaders).unwrap();
    let query = vec![1.0, 0.0, 0.0];
    let candidates: Vec<&[f32]> = vec![
        &[0.0, 1.0, 0.0][..], // sim = 0
        &[1.0, 0.0, 0.0][..], // sim = 1 (best)
        &[0.5, 0.5, 0.0][..], // sim ≈ 0.707
        &[-1.0, 0.0, 0.0][..], // sim = -1 (worst)
    ];
    let top2 = similarity.top_k(&query, &candidates, 2).unwrap();
    assert_eq!(top2.len(), 2);
    assert_eq!(top2[0].0, 1); // Index of [1,0,0]
    assert_eq!(top2[1].0, 2); // Index of [0.5,0.5,0]
}
// ==================== Memory Stats Tests ====================
// 512MB used of 1GB total should report ~50% usage.
#[test]
fn test_memory_stats() {
    let stats = GpuMemoryStats {
        total: 1024 * 1024 * 1024, // 1GB
        used: 512 * 1024 * 1024,   // 512MB
        free: 512 * 1024 * 1024,
        peak: 768 * 1024 * 1024,
    };
    assert!((stats.usage_percent() - 50.0).abs() < 0.1);
}
// Zeroed stats must not divide by zero; usage reports 0%.
#[test]
fn test_empty_memory_stats() {
    let stats = GpuMemoryStats::default();
    assert_eq!(stats.usage_percent(), 0.0);
}
// ==================== Backend Tests ====================
// The CPU fallback backend is always available but advertises no
// compute-shader support.
#[test]
fn test_cpu_backend_info() {
    let backend = CpuBackend;
    assert!(backend.is_available());
    let info = backend.device_info();
    assert_eq!(info.backend, "CPU");
    assert!(!info.supports_compute);
}

View File

@@ -0,0 +1,187 @@
//! # RuVector ONNX Embeddings
//!
//! A reimagined embedding pipeline for RuVector using ONNX Runtime in pure Rust.
//!
//! This crate provides:
//! - Native ONNX model inference for embedding generation
//! - HuggingFace tokenizer integration
//! - Batch processing with SIMD optimization
//! - Direct RuVector vector database integration
//! - Model management and caching
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │ RuVector ONNX Embeddings │
//! ├─────────────────────────────────────────────────────────────────┤
//! │ │
//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
//! │ │ Text Input │ -> │ Tokenizer │ -> │ Token IDs │ │
//! │ └──────────────┘ └──────────────┘ └──────────────┘ │
//! │ │ │
//! │ v │
//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
//! │ │ Embeddings │ <- │ ONNX Runtime │ <- │ Input Tensor │ │
//! │ └──────────────┘ └──────────────┘ └──────────────┘ │
//! │ │ │
//! │ v │
//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
//! │ │ Normalize │ -> │ Mean Pooling │ -> │ RuVector DB │ │
//! │ └──────────────┘ └──────────────┘ └──────────────┘ │
//! │ │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! ## Example
//!
//! ```rust,ignore
//! use ruvector_onnx_embeddings::{Embedder, EmbedderConfig, ModelSource};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! // Create embedder with default model (all-MiniLM-L6-v2)
//! let embedder = Embedder::new(EmbedderConfig::default()).await?;
//!
//! // Generate embeddings
//! let texts = vec!["Hello, world!", "Rust is awesome!"];
//! let embeddings = embedder.embed(&texts)?;
//!
//! // Use with RuVector
//! let db = embedder.create_ruvector_index("my_index")?;
//! db.insert_with_embeddings(&texts, &embeddings)?;
//!
//! Ok(())
//! }
//! ```
pub mod config;
pub mod embedder;
pub mod error;
pub mod model;
pub mod pooling;
pub mod ruvector_integration;
pub mod tokenizer;
/// GPU acceleration module (optional, requires `gpu` feature)
#[cfg(feature = "gpu")]
pub mod gpu;
/// GPU module stub for when feature is disabled
#[cfg(not(feature = "gpu"))]
pub mod gpu {
    //! Stub implementations compiled in when the `gpu` feature is off.
    //!
    //! Enable real GPU support with: `cargo build --features gpu`
    /// Zero-sized stand-in for the real `GpuConfig`.
    #[derive(Debug, Clone, Default)]
    pub struct GpuConfig;
    impl GpuConfig {
        /// Mirrors `GpuConfig::auto`; a no-op without the GPU feature.
        pub fn auto() -> Self {
            Self
        }
        /// Mirrors `GpuConfig::cpu_only`; a no-op without the GPU feature.
        pub fn cpu_only() -> Self {
            Self
        }
    }
    /// GPU detection stub; always reports `false` without the feature.
    pub async fn is_gpu_available() -> bool {
        false
    }
}
// Re-exports
pub use config::{EmbedderConfig, ModelSource, PoolingStrategy};
pub use embedder::{Embedder, EmbedderBuilder, EmbeddingOutput};
pub use error::{EmbeddingError, Result};
pub use model::{OnnxModel, ModelInfo};
pub use pooling::Pooler;
pub use ruvector_integration::{
Distance, IndexConfig, RagPipeline, RuVectorBuilder, RuVectorEmbeddings, SearchResult, VectorId,
};
pub use tokenizer::Tokenizer;
// GPU exports (conditional)
#[cfg(feature = "gpu")]
pub use gpu::{
GpuAccelerator, GpuConfig, GpuMode, GpuInfo, GpuBackend,
HybridAccelerator, is_gpu_available,
};
/// Prelude module for convenient imports
///
/// `use ruvector_onnx_embeddings::prelude::*;` pulls in the common
/// embedding, configuration, and RuVector-integration types in one line.
pub mod prelude {
    pub use crate::{
        Distance, Embedder, EmbedderBuilder, EmbedderConfig, EmbeddingError,
        IndexConfig, ModelSource, PoolingStrategy, RagPipeline, Result,
        RuVectorBuilder, RuVectorEmbeddings, SearchResult, VectorId,
    };
}
/// Supported embedding models with pre-configured settings
///
/// Serializable (serde) so the chosen model can be persisted in configs;
/// `AllMiniLmL6V2` is the default. Per-model metadata lives on the
/// `impl` block (`model_id`, `dimension`, `max_seq_length`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum PretrainedModel {
    /// all-MiniLM-L6-v2: 384 dimensions, fast inference
    #[default]
    AllMiniLmL6V2,
    /// all-MiniLM-L12-v2: 384 dimensions, better quality
    AllMiniLmL12V2,
    /// all-mpnet-base-v2: 768 dimensions, high quality
    AllMpnetBaseV2,
    /// multi-qa-MiniLM-L6: 384 dimensions, optimized for QA
    MultiQaMiniLmL6,
    /// paraphrase-MiniLM-L6-v2: 384 dimensions, paraphrase detection
    ParaphraseMiniLmL6V2,
    /// BGE-small-en-v1.5: 384 dimensions, BAAI General Embeddings
    BgeSmallEnV15,
    /// E5-small-v2: 384 dimensions, Microsoft E5 model
    E5SmallV2,
    /// GTE-small: 384 dimensions, Alibaba GTE model
    GteSmall,
}
impl PretrainedModel {
    /// HuggingFace Hub repository identifier for this model.
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMiniLmL12V2 => "sentence-transformers/all-MiniLM-L12-v2",
            Self::AllMpnetBaseV2 => "sentence-transformers/all-mpnet-base-v2",
            Self::MultiQaMiniLmL6 => "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
            Self::ParaphraseMiniLmL6V2 => "sentence-transformers/paraphrase-MiniLM-L6-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::E5SmallV2 => "intfloat/e5-small-v2",
            Self::GteSmall => "thenlper/gte-small",
        }
    }
    /// Output embedding width in floats.
    pub fn dimension(&self) -> usize {
        // mpnet-base is the only 768-dim model in the catalogue; every
        // other supported model emits 384-dim vectors.
        match self {
            Self::AllMpnetBaseV2 => 768,
            Self::AllMiniLmL6V2
            | Self::AllMiniLmL12V2
            | Self::MultiQaMiniLmL6
            | Self::ParaphraseMiniLmL6V2
            | Self::BgeSmallEnV15
            | Self::E5SmallV2
            | Self::GteSmall => 384,
        }
    }
    /// Recommended maximum input sequence length, in tokens.
    pub fn max_seq_length(&self) -> usize {
        match self {
            Self::AllMpnetBaseV2 => 384,
            Self::BgeSmallEnV15 | Self::E5SmallV2 | Self::GteSmall => 512,
            Self::AllMiniLmL6V2
            | Self::AllMiniLmL12V2
            | Self::MultiQaMiniLmL6
            | Self::ParaphraseMiniLmL6V2 => 256,
        }
    }
    /// Whether embeddings should be L2-normalized after pooling.
    ///
    /// Currently `true` for every supported model.
    pub fn normalize_output(&self) -> bool {
        true
    }
}

View File

@@ -0,0 +1,265 @@
//! RuVector ONNX Embeddings - Example Usage
//!
//! This example demonstrates how to use ONNX-based embedding generation
//! with RuVector for semantic search and RAG pipelines.
use anyhow::Result;
use ruvector_onnx_embeddings::{
prelude::*, EmbedderBuilder, PretrainedModel, PoolingStrategy,
RuVectorBuilder, RagPipeline, Distance,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize logging: honour RUST_LOG if set, otherwise default to "info".
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "info".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    println!("╔═══════════════════════════════════════════════════════════════╗");
    println!("║ RuVector ONNX Embeddings - Reimagined for Rust ║");
    println!("╚═══════════════════════════════════════════════════════════════╝");
    println!();
    // Run examples sequentially; the first failure aborts the demo via `?`.
    // Each example loads (and on first use downloads) the default model.
    basic_embedding_example().await?;
    batch_embedding_example().await?;
    semantic_search_example().await?;
    rag_pipeline_example().await?;
    clustering_example().await?;
    println!("\n✅ All examples completed successfully!");
    Ok(())
}
/// Basic embedding generation
///
/// Loads the default model, embeds a single sentence, and contrasts the
/// similarity score of a semantically related sentence pair against an
/// unrelated one.
async fn basic_embedding_example() -> Result<()> {
    println!("\n━━━ Example 1: Basic Embedding Generation ━━━");
    // Create embedder with default model (all-MiniLM-L6-v2)
    let mut embedder = Embedder::default_model().await?;
    println!("Model: {}", embedder.model_info().name);
    println!("Dimension: {}", embedder.dimension());
    // Embed a single sentence
    let text = "The quick brown fox jumps over the lazy dog.";
    let embedding = embedder.embed_one(text)?;
    println!("Input: \"{}\"", text);
    println!("Embedding shape: [{}]", embedding.len());
    println!(
        "First 5 values: [{:.4}, {:.4}, {:.4}, {:.4}, {:.4}]",
        embedding[0], embedding[1], embedding[2], embedding[3], embedding[4]
    );
    // Compute similarity between two sentences.
    // text1/text2 are semantically close, text1/text3 are not — the related
    // pair should score noticeably higher.
    let text1 = "I love programming in Rust.";
    let text2 = "Rust is my favorite programming language.";
    let text3 = "The weather is nice today.";
    let sim_related = embedder.similarity(text1, text2)?;
    let sim_unrelated = embedder.similarity(text1, text3)?;
    println!("\nSimilarity comparisons:");
    println!(" \"{}\"\n vs\n \"{}\"", text1, text2);
    println!(" Similarity: {:.4}", sim_related);
    println!();
    println!(" \"{}\"\n vs\n \"{}\"", text1, text3);
    println!(" Similarity: {:.4}", sim_unrelated);
    Ok(())
}
/// Batch embedding with parallel processing
///
/// Builds an embedder with explicit configuration, embeds a batch of texts
/// in one call, and reports timing plus per-text token counts.
async fn batch_embedding_example() -> Result<()> {
    println!("\n━━━ Example 2: Batch Embedding ━━━");
    // Create embedder with custom configuration
    let mut embedder = EmbedderBuilder::new()
        .pretrained(PretrainedModel::AllMiniLmL6V2)
        .pooling(PoolingStrategy::Mean)
        .normalize(true)
        .batch_size(64)
        .build()
        .await?;
    let texts = vec![
        "Artificial intelligence is transforming technology.",
        "Machine learning models learn from data.",
        "Deep learning uses neural networks.",
        "Natural language processing understands text.",
        "Computer vision analyzes images.",
        "Reinforcement learning optimizes decisions.",
        "Vector databases enable semantic search.",
        "Embeddings capture semantic meaning.",
    ];
    println!("Embedding {} texts...", texts.len());
    let start = std::time::Instant::now();
    let output = embedder.embed(&texts)?;
    let elapsed = start.elapsed();
    println!("Completed in {:?}", elapsed);
    println!("Total embeddings: {}", output.len());
    println!("Embedding dimension: {}", output.dimension);
    // Show token counts
    println!("\nToken counts per text:");
    for (i, (text, tokens)) in texts.iter().zip(output.token_counts.iter()).enumerate() {
        // Truncate by *characters*, not bytes: a raw byte slice such as
        // `&text[..40]` panics if index 40 falls inside a multi-byte UTF-8
        // character. chars().take() is always safe.
        let preview: String = text.chars().take(40).collect();
        println!(" [{}] {} tokens: \"{}...\"", i, tokens, preview);
    }
    Ok(())
}
/// Semantic search with RuVector
///
/// Indexes a small knowledge base of programming-language descriptions and
/// runs natural-language queries against it, printing the top-3 hits per
/// query with their similarity scores.
async fn semantic_search_example() -> Result<()> {
    println!("\n━━━ Example 3: Semantic Search with RuVector ━━━");
    // Create embedder
    let embedder = Embedder::default_model().await?;
    // Create RuVector index (cosine distance, capacity pre-sized to 10k)
    let index = RuVectorBuilder::new("semantic_search")
        .embedder(embedder)
        .distance(Distance::Cosine)
        .max_elements(10_000)
        .build()?;
    // Knowledge base about programming languages
    let documents = vec![
        "Rust is a systems programming language focused on safety and performance.",
        "Python is widely used for machine learning and data science applications.",
        "JavaScript is the language of the web, running in browsers everywhere.",
        "Go is designed for building scalable and efficient server applications.",
        "TypeScript adds static typing to JavaScript for better developer experience.",
        "C++ provides low-level control and high performance for system software.",
        "Java is a mature, object-oriented language popular in enterprise software.",
        "Swift is Apple's modern language for iOS and macOS development.",
        "Kotlin is a concise language that runs on the JVM, popular for Android.",
        "Haskell is a purely functional programming language with strong typing.",
    ];
    println!("Indexing {} documents...", documents.len());
    index.insert_batch(&documents)?;
    println!("Index size: {} vectors", index.len());
    // Perform searches: each query is embedded and matched against the
    // indexed documents; top 3 results are printed per query.
    let queries = vec![
        "What language is best for web development?",
        "I want to build a high-performance system application",
        "Which language should I learn for machine learning?",
        "I need a language for mobile app development",
    ];
    for query in queries {
        println!("\n🔍 Query: \"{}\"", query);
        let results = index.search(query, 3)?;
        for (i, result) in results.iter().enumerate() {
            println!(
                " {}. (score: {:.4}) {}",
                i + 1,
                result.score,
                result.text
            );
        }
    }
    Ok(())
}
/// RAG (Retrieval-Augmented Generation) pipeline
///
/// Loads a small knowledge base, then for each question retrieves the top-3
/// most relevant documents and prints them as LLM-ready context.
async fn rag_pipeline_example() -> Result<()> {
    println!("\n━━━ Example 4: RAG Pipeline ━━━");
    let embedder = Embedder::default_model().await?;
    let index = RuVectorEmbeddings::new_default("rag_index", embedder)?;
    // Retrieve up to 3 context documents per question.
    let rag = RagPipeline::new(index, 3);
    // Add knowledge base
    let knowledge = vec![
        "RuVector is a distributed vector database that learns and adapts.",
        "RuVector uses HNSW indexing for fast approximate nearest neighbor search.",
        "The embedding dimension in RuVector is configurable based on your model.",
        "RuVector supports multiple distance metrics: Cosine, Euclidean, and Dot Product.",
        "Graph Neural Networks in RuVector improve search quality over time.",
        "RuVector integrates with ONNX models for native embedding generation.",
        "The NAPI-RS bindings allow using RuVector from Node.js applications.",
        "RuVector supports WebAssembly for running in web browsers.",
        "Raft consensus enables distributed deployment of RuVector clusters.",
        "Quantization in RuVector provides 2-32x memory compression.",
    ];
    println!("Loading {} documents into RAG pipeline...", knowledge.len());
    rag.add_documents(&knowledge)?;
    // Generate context for questions
    let questions = vec![
        "How does RuVector achieve fast search?",
        "Can I use RuVector in a web browser?",
        "What compression options does RuVector have?",
    ];
    for question in questions {
        println!("\n❓ Question: {}", question);
        let context = rag.format_context(question)?;
        println!("Generated Context:\n{}", context);
        // Fix: the separator previously repeated an *empty* string
        // (`"".repeat(60)`), which always prints a blank line — the rule
        // character was evidently lost in an encoding round-trip.
        println!("{}", "─".repeat(60));
    }
    Ok(())
}
/// Text clustering example
///
/// Embeds nine sentences drawn from three topics (technology, sports, food)
/// and asks the embedder to partition them into three clusters, then prints
/// the members of each cluster.
async fn clustering_example() -> Result<()> {
    println!("\n━━━ Example 5: Text Clustering ━━━");
    let mut embedder = Embedder::default_model().await?;
    // Texts from different categories
    let texts = vec![
        // Technology
        "Artificial intelligence is revolutionizing industries.",
        "Machine learning algorithms process large datasets.",
        "Neural networks mimic the human brain.",
        // Sports
        "Football is the most popular sport worldwide.",
        "Basketball requires speed and agility.",
        "Tennis is played on different court surfaces.",
        // Food
        "Italian pasta comes in many shapes and sizes.",
        "Sushi is a traditional Japanese dish.",
        "French cuisine is known for its elegance.",
    ];
    println!("Clustering {} texts into 3 categories...", texts.len());
    let assignments = embedder.cluster(&texts, 3)?;
    // Bucket the input texts by their assigned cluster id.
    let mut buckets: std::collections::HashMap<usize, Vec<&str>> =
        std::collections::HashMap::new();
    for (i, &cluster_id) in assignments.iter().enumerate() {
        buckets.entry(cluster_id).or_default().push(texts[i]);
    }
    println!("\nCluster assignments:");
    for (cluster_id, members) in &buckets {
        println!("\n📁 Cluster {}:", cluster_id);
        for member in members {
            println!("{}", member);
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,470 @@
//! ONNX model loading and management
use crate::config::{EmbedderConfig, ExecutionProvider, ModelSource};
use crate::{EmbeddingError, PretrainedModel, Result};
use indicatif::{ProgressBar, ProgressStyle};
use ort::session::{builder::GraphOptimizationLevel, Session};
use sha2::{Digest, Sha256};
use std::fs;
use std::io::Write;
use std::path::Path;
use tracing::{debug, info, instrument, warn};
/// Information about a loaded model
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model name or identifier (derived from the ONNX file stem on load)
    pub name: String,
    /// Embedding dimension
    pub dimension: usize,
    /// Maximum sequence length
    pub max_seq_length: usize,
    /// Model file size in bytes
    pub file_size: u64,
    /// Model input names, as reported by the ONNX session
    pub input_names: Vec<String>,
    /// Model output names, as reported by the ONNX session
    pub output_names: Vec<String>,
}
/// ONNX model wrapper with inference capabilities
pub struct OnnxModel {
    /// Underlying ONNX Runtime session used to execute the graph.
    session: Session,
    /// Static metadata captured when the model was loaded.
    info: ModelInfo,
}
impl OnnxModel {
    /// Load model from configuration
    ///
    /// Dispatches on `config.model_source`. Tokenizer paths/URLs are ignored
    /// here — this type only loads the ONNX graph itself.
    #[instrument(skip_all)]
    pub async fn from_config(config: &EmbedderConfig) -> Result<Self> {
        match &config.model_source {
            ModelSource::Local {
                model_path,
                tokenizer_path: _,
            } => Self::from_file(model_path, config).await,
            ModelSource::Pretrained(model) => Self::from_pretrained(*model, config).await,
            ModelSource::HuggingFace { model_id, revision } => {
                Self::from_huggingface(model_id, revision.as_deref(), config).await
            }
            ModelSource::Url {
                model_url,
                tokenizer_url: _,
            } => Self::from_url(model_url, config).await,
        }
    }
    /// Load model from a local ONNX file
    ///
    /// # Errors
    /// Returns `EmbeddingError::model_not_found` when `path` does not exist,
    /// or a session-creation/IO error if the file cannot be loaded.
    #[instrument(skip_all, fields(path = %path.as_ref().display()))]
    pub async fn from_file(path: impl AsRef<Path>, config: &EmbedderConfig) -> Result<Self> {
        let path = path.as_ref();
        info!("Loading ONNX model from file: {}", path.display());
        if !path.exists() {
            return Err(EmbeddingError::model_not_found(path.display().to_string()));
        }
        let file_size = fs::metadata(path)?.len();
        let session = Self::create_session(path, config)?;
        let info = Self::extract_model_info(&session, path, file_size)?;
        Ok(Self { session, info })
    }
    /// Load a pretrained model (downloads if not cached)
    #[instrument(skip_all, fields(model = ?model))]
    pub async fn from_pretrained(model: PretrainedModel, config: &EmbedderConfig) -> Result<Self> {
        let model_id = model.model_id();
        info!("Loading pretrained model: {}", model_id);
        // Check cache first: cache key is the sanitized HF model ID.
        let cache_path = config.cache_dir.join(sanitize_model_id(model_id));
        let model_path = cache_path.join("model.onnx");
        if model_path.exists() {
            debug!("Found cached model at {}", model_path.display());
            return Self::from_file(&model_path, config).await;
        }
        // Download from HuggingFace
        Self::from_huggingface(model_id, None, config).await
    }
    /// Load model from HuggingFace Hub
    ///
    /// Downloads `model.onnx` into the cache directory on first use;
    /// subsequent calls reuse the cached file.
    #[instrument(skip_all, fields(model_id = %model_id))]
    pub async fn from_huggingface(
        model_id: &str,
        revision: Option<&str>,
        config: &EmbedderConfig,
    ) -> Result<Self> {
        let cache_path = config.cache_dir.join(sanitize_model_id(model_id));
        fs::create_dir_all(&cache_path)?;
        let model_path = cache_path.join("model.onnx");
        if !model_path.exists() {
            info!("Downloading model from HuggingFace: {}", model_id);
            download_from_huggingface(model_id, revision, &cache_path, config.show_progress)
                .await?;
        }
        Self::from_file(&model_path, config).await
    }
    /// Load model from a URL
    ///
    /// The cache directory is keyed by a hash of the URL, so distinct URLs
    /// never collide in the cache.
    #[instrument(skip_all, fields(url = %url))]
    pub async fn from_url(url: &str, config: &EmbedderConfig) -> Result<Self> {
        let hash = hash_url(url);
        let cache_path = config.cache_dir.join(&hash);
        fs::create_dir_all(&cache_path)?;
        let model_path = cache_path.join("model.onnx");
        if !model_path.exists() {
            info!("Downloading model from URL: {}", url);
            download_file(url, &model_path, config.show_progress).await?;
        }
        Self::from_file(&model_path, config).await
    }
    /// Create an ONNX session with the specified configuration
    fn create_session(path: &Path, config: &EmbedderConfig) -> Result<Session> {
        let mut builder = Session::builder()?;
        // Set optimization level
        if config.optimize_graph {
            builder = builder.with_optimization_level(GraphOptimizationLevel::Level3)?;
        }
        // Set number of threads
        builder = builder.with_intra_threads(config.num_threads)?;
        // Configure execution provider. GPU/accelerator providers are only
        // compiled in behind their cargo features.
        match config.execution_provider {
            ExecutionProvider::Cpu => {
                // Default CPU provider
            }
            #[cfg(feature = "cuda")]
            ExecutionProvider::Cuda { device_id } => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::CUDAExecutionProvider::default()
                        .with_device_id(device_id)
                        .build(),
                ])?;
            }
            #[cfg(feature = "tensorrt")]
            ExecutionProvider::TensorRt { device_id } => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::TensorRTExecutionProvider::default()
                        .with_device_id(device_id)
                        .build(),
                ])?;
            }
            #[cfg(feature = "coreml")]
            ExecutionProvider::CoreMl => {
                builder = builder.with_execution_providers([
                    ort::execution_providers::CoreMLExecutionProvider::default().build(),
                ])?;
            }
            // Any provider whose cargo feature is disabled falls through here.
            _ => {
                warn!(
                    "Requested execution provider not available, falling back to CPU"
                );
            }
        }
        let session = builder.commit_from_file(path)?;
        Ok(session)
    }
    /// Extract model information from the session
    fn extract_model_info(session: &Session, path: &Path, file_size: u64) -> Result<ModelInfo> {
        let inputs: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
        let outputs: Vec<String> = session.outputs.iter().map(|o| o.name.clone()).collect();
        // Default embedding dimension (will be determined at runtime from actual output)
        // Most sentence-transformers models output 384 dimensions
        // NOTE(review): this default is wrong for 768-dim models (e.g.
        // all-mpnet-base-v2) until corrected from a real inference output —
        // verify nothing relies on `info.dimension` before the first run.
        let dimension = 384;
        let name = path
            .file_stem()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_else(|| "unknown".to_string());
        Ok(ModelInfo {
            name,
            dimension,
            max_seq_length: 512,
            file_size,
            input_names: inputs,
            output_names: outputs,
        })
    }
    /// Run inference on encoded inputs
    ///
    /// `input_ids`, `attention_mask` and `token_type_ids` are flattened
    /// row-major `[batch, seq]` buffers; `shape` is `[batch_size, seq_length]`.
    /// Returns one flat `f32` buffer per batch item: `seq_len * hidden`
    /// values (token-level, caller must pool) for rank-3 outputs, or
    /// `hidden` values (already pooled) for rank-2 outputs.
    #[instrument(skip_all, fields(batch_size, seq_length))]
    pub fn run(
        &mut self,
        input_ids: &[i64],
        attention_mask: &[i64],
        token_type_ids: &[i64],
        shape: &[usize],
    ) -> Result<Vec<Vec<f32>>> {
        use ort::value::Tensor;
        let batch_size = shape[0];
        let seq_length = shape[1];
        debug!(
            "Running inference: batch_size={}, seq_length={}",
            batch_size, seq_length
        );
        // Create input tensors using ort's Tensor type
        let input_ids_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            input_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        let attention_mask_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            attention_mask.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        let token_type_ids_tensor = Tensor::from_array((
            vec![batch_size, seq_length],
            token_type_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        // Build inputs vector (BERT-style input triple, fed by name)
        let inputs = vec![
            ("input_ids", input_ids_tensor.into_dyn()),
            ("attention_mask", attention_mask_tensor.into_dyn()),
            ("token_type_ids", token_type_ids_tensor.into_dyn()),
        ];
        // Run inference
        let outputs = self.session.run(inputs)
            .map_err(EmbeddingError::OnnxRuntime)?;
        // Extract output tensor
        // Usually the output is [batch, seq_len, hidden_size] or [batch, hidden_size]
        let output_names = ["last_hidden_state", "output", "sentence_embedding"];
        // Find the appropriate output by name, or use the first one
        let output_iter: Vec<_> = outputs.iter().collect();
        let output = output_iter
            .iter()
            .find(|(name, _)| output_names.contains(name))
            .or_else(|| output_iter.first())
            .map(|(_, v)| v)
            .ok_or_else(|| EmbeddingError::invalid_model("No output tensor found"))?;
        // In ort 2.0, try_extract_tensor returns (&Shape, &[f32])
        let (tensor_shape, tensor_data) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
        // Convert Shape to Vec<usize> - Shape yields i64
        let dims: Vec<usize> = tensor_shape.iter().map(|&d| d as usize).collect();
        // Handle different output shapes
        let embeddings = if dims.len() == 3 {
            // [batch, seq_len, hidden] - need pooling
            // NOTE(review): slicing assumes the output seq_len equals the
            // input `seq_length` — true for the supported BERT-style models.
            let hidden_size = dims[2];
            (0..batch_size)
                .map(|i| {
                    let start = i * seq_length * hidden_size;
                    let end = start + seq_length * hidden_size;
                    tensor_data[start..end].to_vec()
                })
                .collect()
        } else if dims.len() == 2 {
            // [batch, hidden] - already pooled
            let hidden_size = dims[1];
            (0..batch_size)
                .map(|i| {
                    let start = i * hidden_size;
                    let end = start + hidden_size;
                    tensor_data[start..end].to_vec()
                })
                .collect()
        } else {
            return Err(EmbeddingError::invalid_model(format!(
                "Unexpected output shape: {:?}",
                dims
            )));
        };
        Ok(embeddings)
    }
    /// Get model info
    pub fn info(&self) -> &ModelInfo {
        &self.info
    }
    /// Get embedding dimension
    pub fn dimension(&self) -> usize {
        self.info.dimension
    }
}
/// Download model files from HuggingFace Hub
///
/// Tries `model.onnx` at the repository root first, then in the `onnx/`
/// subfolder (the layout used by many sentence-transformers exports).
/// Auxiliary files (`tokenizer.json`, `config.json`) use the same fallback,
/// but their failures are only logged — they are treated as optional.
async fn download_from_huggingface(
    model_id: &str,
    revision: Option<&str>,
    cache_path: &Path,
    show_progress: bool,
) -> Result<()> {
    let revision = revision.unwrap_or("main");
    let base_url = format!(
        "https://huggingface.co/{}/resolve/{}",
        model_id, revision
    );
    let model_path = cache_path.join("model.onnx");
    // Try to download model.onnx - check multiple locations
    if !model_path.exists() {
        // Location 1: Root directory (model.onnx)
        let root_url = format!("{}/model.onnx", base_url);
        debug!("Trying to download model from root: {}", root_url);
        let root_result = download_file(&root_url, &model_path, show_progress).await;
        // Location 2: ONNX subfolder (onnx/model.onnx) - common for sentence-transformers
        if root_result.is_err() && !model_path.exists() {
            let onnx_url = format!("{}/onnx/model.onnx", base_url);
            debug!("Root download failed, trying onnx subfolder: {}", onnx_url);
            match download_file(&onnx_url, &model_path, show_progress).await {
                Ok(_) => debug!("Downloaded model.onnx from onnx/ subfolder"),
                Err(e) => {
                    // Both locations failed
                    return Err(EmbeddingError::download_failed(format!(
                        "Failed to download model.onnx from {} - tried both root and onnx/ subfolder: {}",
                        model_id, e
                    )));
                }
            }
        } else if let Err(e) = root_result {
            // Root failed but model exists (shouldn't happen, but handle gracefully)
            if !model_path.exists() {
                return Err(e);
            }
        } else {
            debug!("Downloaded model.onnx from root");
        }
    }
    // Download auxiliary files (tokenizer.json, config.json) - these are optional
    let aux_files = ["tokenizer.json", "config.json"];
    for file in aux_files {
        let path = cache_path.join(file);
        if !path.exists() {
            // Try root first, then onnx subfolder
            let root_url = format!("{}/{}", base_url, file);
            match download_file(&root_url, &path, show_progress).await {
                Ok(_) => debug!("Downloaded {}", file),
                Err(_) => {
                    // Try onnx subfolder
                    let onnx_url = format!("{}/onnx/{}", base_url, file);
                    match download_file(&onnx_url, &path, show_progress).await {
                        Ok(_) => debug!("Downloaded {} from onnx/ subfolder", file),
                        Err(e) => warn!("Failed to download {} (optional): {}", file, e),
                    }
                }
            }
        }
    }
    Ok(())
}
/// Download a file from URL with optional progress bar
///
/// The payload is streamed into a sibling `<path>.part` file and renamed
/// into place only on success. Previously the stream was written directly
/// to `path`, so a network error mid-download left a truncated file that
/// callers (which check `path.exists()`) treated as a valid cache hit.
///
/// # Errors
/// Returns an error for non-success HTTP statuses, network stream failures,
/// or local I/O failures; the partial `.part` file is removed best-effort
/// on stream errors.
async fn download_file(url: &str, path: &Path, show_progress: bool) -> Result<()> {
    let client = reqwest::Client::new();
    let response = client.get(url).send().await?;
    if !response.status().is_success() {
        return Err(EmbeddingError::download_failed(format!(
            "HTTP {}: {}",
            response.status(),
            url
        )));
    }
    let total_size = response.content_length().unwrap_or(0);
    // Progress bar only when requested and the server reported a length.
    let pb = if show_progress && total_size > 0 {
        let pb = ProgressBar::new(total_size);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );
        Some(pb)
    } else {
        None
    };
    // Stream into a temporary file first so an interrupted download never
    // leaves a truncated file at the final path.
    let tmp_path = path.with_extension("part");
    let mut file = fs::File::create(&tmp_path)?;
    let mut downloaded = 0u64;
    use futures_util::StreamExt;
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let chunk = match chunk {
            Ok(c) => c,
            Err(e) => {
                // Best-effort cleanup of the partial file before bailing.
                let _ = fs::remove_file(&tmp_path);
                return Err(e.into());
            }
        };
        file.write_all(&chunk)?;
        downloaded += chunk.len() as u64;
        if let Some(ref pb) = pb {
            pb.set_position(downloaded);
        }
    }
    // Flush and close before the rename makes the file visible at `path`.
    file.flush()?;
    drop(file);
    fs::rename(&tmp_path, path)?;
    if let Some(pb) = pb {
        pb.finish_with_message("Downloaded");
    }
    Ok(())
}
/// Sanitize model ID for use as directory name
///
/// Path separators and drive-colon characters are flattened to `_` so the
/// ID is a safe single path component on every platform.
fn sanitize_model_id(model_id: &str) -> String {
    model_id
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' => '_',
            other => other,
        })
        .collect()
}
/// Create a hash of a URL for caching
///
/// Returns the first 8 bytes of the SHA-256 digest as 16 lowercase hex
/// characters — a stable, collision-resistant directory name for the URL.
fn hash_url(url: &str) -> String {
    let digest = Sha256::digest(url.as_bytes());
    hex::encode(&digest[..8])
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_sanitize_model_id() {
        // Slashes must be flattened so the HF ID becomes a single safe
        // directory name.
        assert_eq!(
            sanitize_model_id("sentence-transformers/all-MiniLM-L6-v2"),
            "sentence-transformers_all-MiniLM-L6-v2"
        );
    }
    #[test]
    fn test_hash_url() {
        // 8 digest bytes, hex-encoded.
        let hash = hash_url("https://example.com/model.onnx");
        assert_eq!(hash.len(), 16); // 8 bytes = 16 hex chars
    }
}

View File

@@ -0,0 +1,397 @@
//! Pooling strategies for combining token embeddings into sentence embeddings
use crate::config::PoolingStrategy;
use rayon::prelude::*;
use tracing::{debug, instrument};
/// Pooler for combining token embeddings into one sentence embedding
#[derive(Debug, Clone)]
pub struct Pooler {
    /// Which reduction to apply over the token axis.
    strategy: PoolingStrategy,
    /// When true, pooled vectors are L2-normalized before being returned.
    normalize: bool,
}
impl Pooler {
    /// Create a new pooler with the given strategy
    pub fn new(strategy: PoolingStrategy, normalize: bool) -> Self {
        Self { strategy, normalize }
    }
    /// Pool token embeddings into sentence embeddings
    ///
    /// Sequences are processed in parallel; when `self.normalize` is set the
    /// pooled vectors are L2-normalized in a second parallel pass.
    ///
    /// # Arguments
    /// * `token_embeddings` - Token embeddings for each sequence [batch][seq_len * hidden]
    /// * `attention_mask` - Attention mask for each sequence [batch][seq_len]
    /// * `seq_length` - Sequence length
    /// * `hidden_size` - Hidden dimension size
    #[instrument(skip_all, fields(batch_size = token_embeddings.len(), strategy = ?self.strategy))]
    pub fn pool(
        &self,
        token_embeddings: &[Vec<f32>],
        attention_mask: &[Vec<i64>],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<Vec<f32>> {
        debug!(
            "Pooling {} sequences with strategy {:?}",
            token_embeddings.len(),
            self.strategy
        );
        let embeddings: Vec<Vec<f32>> = token_embeddings
            .par_iter()
            .zip(attention_mask.par_iter())
            .map(|(tokens, mask)| {
                self.pool_single(tokens, mask, seq_length, hidden_size)
            })
            .collect();
        if self.normalize {
            embeddings
                .into_par_iter()
                .map(|emb| Self::normalize_vector(&emb))
                .collect()
        } else {
            embeddings
        }
    }
    /// Pool a single sequence
    ///
    /// Dispatches to the concrete reduction selected at construction time.
    fn pool_single(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        match self.strategy {
            PoolingStrategy::Mean => {
                self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::Cls => {
                self.cls_pool(token_embeddings, hidden_size)
            }
            PoolingStrategy::Max => {
                self.max_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::MeanSqrtLen => {
                self.mean_sqrt_len_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::LastToken => {
                self.last_token_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
            PoolingStrategy::WeightedMean => {
                self.weighted_mean_pool(token_embeddings, attention_mask, seq_length, hidden_size)
            }
        }
    }
    /// Mean pooling over all tokens (weighted by attention mask)
    ///
    /// Sums the embeddings of non-padding tokens (mask == 1), then divides
    /// by their count; an all-padding sequence yields the zero vector.
    fn mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = vec![0.0f32; hidden_size];
        let mut count = 0.0f32;
        for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
            if mask == 1 {
                let start = i * hidden_size;
                let end = start + hidden_size;
                for (j, val) in token_embeddings[start..end].iter().enumerate() {
                    result[j] += val;
                }
                count += 1.0;
            }
        }
        if count > 0.0 {
            for val in &mut result {
                *val /= count;
            }
        }
        result
    }
    /// CLS token pooling (first token)
    fn cls_pool(&self, token_embeddings: &[f32], hidden_size: usize) -> Vec<f32> {
        token_embeddings[..hidden_size].to_vec()
    }
    /// Max pooling over all tokens
    ///
    /// Element-wise maximum across non-padding tokens.
    fn max_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = vec![f32::NEG_INFINITY; hidden_size];
        for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
            if mask == 1 {
                let start = i * hidden_size;
                let end = start + hidden_size;
                for (j, val) in token_embeddings[start..end].iter().enumerate() {
                    if *val > result[j] {
                        result[j] = *val;
                    }
                }
            }
        }
        // Replace -inf with 0 for empty sequences
        for val in &mut result {
            if val.is_infinite() {
                *val = 0.0;
            }
        }
        result
    }
    /// Mean pooling with sqrt(length) scaling
    ///
    /// Multiplying the mean by sqrt(len) is equivalent to sum / sqrt(len).
    fn mean_sqrt_len_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = self.mean_pool(token_embeddings, attention_mask, seq_length, hidden_size);
        let length: f32 = attention_mask.iter().filter(|&&m| m == 1).count() as f32;
        if length > 0.0 {
            let scale = length.sqrt();
            for val in &mut result {
                *val *= scale;
            }
        }
        result
    }
    /// Last token pooling (for decoder models)
    ///
    /// Falls back to the first token when the computed slice would run past
    /// the end of the buffer.
    fn last_token_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        _seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        // Find last non-padding token
        let last_idx = attention_mask
            .iter()
            .rposition(|&m| m == 1)
            .unwrap_or(0);
        let start = last_idx * hidden_size;
        let end = start + hidden_size;
        if end <= token_embeddings.len() {
            token_embeddings[start..end].to_vec()
        } else {
            self.cls_pool(token_embeddings, hidden_size)
        }
    }
    /// Weighted mean pooling based on position
    ///
    /// NOTE(review): weights are 1/(i+1), i.e. EARLY tokens dominate. Some
    /// weighted-mean implementations weight later tokens more heavily —
    /// confirm this orientation matches the intended models.
    fn weighted_mean_pool(
        &self,
        token_embeddings: &[f32],
        attention_mask: &[i64],
        seq_length: usize,
        hidden_size: usize,
    ) -> Vec<f32> {
        let mut result = vec![0.0f32; hidden_size];
        let mut total_weight = 0.0f32;
        for (i, &mask) in attention_mask.iter().enumerate().take(seq_length) {
            if mask == 1 {
                // Weight decreases with position (more weight to early tokens)
                let weight = 1.0 / (i + 1) as f32;
                let start = i * hidden_size;
                let end = start + hidden_size;
                for (j, val) in token_embeddings[start..end].iter().enumerate() {
                    result[j] += val * weight;
                }
                total_weight += weight;
            }
        }
        if total_weight > 0.0 {
            for val in &mut result {
                *val /= total_weight;
            }
        }
        result
    }
    /// L2 normalize a vector
    ///
    /// Vectors with near-zero norm are returned unchanged to avoid dividing
    /// by (almost) zero.
    pub fn normalize_vector(vec: &[f32]) -> Vec<f32> {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-12 {
            vec.iter().map(|x| x / norm).collect()
        } else {
            vec.to_vec()
        }
    }
    /// Compute cosine similarity between two vectors (SIMD-optimized)
    #[cfg(feature = "simsimd")]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        use simsimd::SpatialSimilarity;
        f32::cosine(a, b).unwrap_or(0.0) as f32
    }
    /// Scalar fallback; returns 0.0 when either vector has (near-)zero norm.
    #[cfg(not(feature = "simsimd"))]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a > 1e-12 && norm_b > 1e-12 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }
    /// Compute dot product between two vectors (SIMD-optimized)
    #[cfg(feature = "simsimd")]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        use simsimd::SpatialSimilarity;
        f32::dot(a, b).unwrap_or(0.0) as f32
    }
    /// Scalar fallback for the dot product.
    #[cfg(not(feature = "simsimd"))]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
    /// Compute Euclidean distance between two vectors (SIMD-optimized)
    #[cfg(feature = "simsimd")]
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        use simsimd::SpatialSimilarity;
        (f32::sqeuclidean(a, b).unwrap_or(0.0) as f32).sqrt()
    }
    /// Scalar fallback for the Euclidean (L2) distance.
    #[cfg(not(feature = "simsimd"))]
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}
impl Default for Pooler {
fn default() -> Self {
Self::new(PoolingStrategy::Mean, true)
}
}
/// Cosine similarity of `query` against every candidate, computed in
/// parallel with rayon; scores come back in candidate order.
pub fn batch_cosine_similarity(
    query: &[f32],
    candidates: &[Vec<f32>],
) -> Vec<f32> {
    candidates
        .par_iter()
        .map(|candidate| Pooler::cosine_similarity(query, candidate))
        .collect()
}
/// Find the `k` candidates most similar to `query` by cosine similarity.
///
/// Returns `(index, score)` pairs sorted best-first; fewer than `k` pairs
/// are returned when there are fewer candidates.
pub fn top_k_similar(
    query: &[f32],
    candidates: &[Vec<f32>],
    k: usize,
) -> Vec<(usize, f32)> {
    let mut ranked: Vec<(usize, f32)> = candidates
        .par_iter()
        .enumerate()
        .map(|(idx, candidate)| (idx, Pooler::cosine_similarity(query, candidate)))
        .collect();
    // Highest score first; NaN scores compare as Equal so they sort safely.
    ranked.sort_by(|lhs, rhs| {
        rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(k);
    ranked
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_normalize_vector() {
        // A normalized vector must have unit L2 norm.
        let vec = vec![3.0, 4.0];
        let normalized = Pooler::normalize_vector(&vec);
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_cosine_similarity() {
        // Identical vectors score 1.0; orthogonal vectors score 0.0.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        assert!((Pooler::cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        assert!((Pooler::cosine_similarity(&a, &c)).abs() < 1e-6);
    }
    #[test]
    fn test_mean_pooling() {
        let pooler = Pooler::new(PoolingStrategy::Mean, false);
        // 2 tokens, 3 dimensions
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1i64, 1];
        let result = pooler.pool_single(&embeddings, &mask, 2, 3);
        // Expected: element-wise mean of [1,2,3] and [4,5,6].
        assert_eq!(result.len(), 3);
        assert!((result[0] - 2.5).abs() < 1e-6);
        assert!((result[1] - 3.5).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
    }
    #[test]
    fn test_cls_pooling() {
        // CLS pooling returns exactly the first token's embedding.
        let pooler = Pooler::new(PoolingStrategy::Cls, false);
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1i64, 1];
        let result = pooler.pool_single(&embeddings, &mask, 2, 3);
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }
    #[test]
    fn test_top_k_similar() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];
        let results = top_k_similar(&query, &candidates, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0); // Most similar
    }
}

View File

@@ -0,0 +1,565 @@
//! Standalone vector database integration for ONNX embeddings
//!
//! This module provides a lightweight vector database built on top of the
//! embedding system, demonstrating how to integrate with RuVector or use
//! as a standalone semantic search engine.
use crate::{Embedder, EmbeddingError, Result};
use parking_lot::RwLock;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{debug, info, instrument};
use uuid::Uuid;
/// Vector ID type (using String for compatibility with RuVector)
pub type VectorId = String;
/// Distance metric for similarity calculation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Distance {
    /// Cosine similarity (default, best for normalized embeddings)
    #[default]
    Cosine,
    /// Euclidean (L2) distance
    Euclidean,
    /// Dot product (inner product)
    DotProduct,
}
/// Search result with text and score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Vector ID
    pub id: VectorId,
    /// Original text as supplied at insert time
    pub text: String,
    /// Similarity score (higher is better for cosine, lower for euclidean)
    pub score: f32,
    /// Optional metadata supplied by the caller at insert time
    pub metadata: Option<serde_json::Value>,
}
/// Stored vector entry
///
/// Internal record pairing the original text with its embedding and any
/// caller-supplied metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredEntry {
    // UUID v4 string assigned at insert time.
    id: VectorId,
    // Original input text, echoed back in search results.
    text: String,
    // Embedding vector for the text.
    vector: Vec<f32>,
    // Optional caller-supplied JSON metadata.
    metadata: Option<serde_json::Value>,
}
/// Configuration for creating a vector index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Distance metric used for scoring
    pub distance: Distance,
    /// Maximum number of elements (for pre-allocation)
    pub max_elements: usize,
    /// Number of results to over-fetch for filtering
    pub ef_search: usize,
}
impl Default for IndexConfig {
fn default() -> Self {
Self {
distance: Distance::Cosine,
max_elements: 100_000,
ef_search: 100,
}
}
}
/// RuVector-compatible embeddings index
///
/// A lightweight in-memory vector database that integrates ONNX embeddings
/// with similarity search. Compatible with RuVector's API patterns.
pub struct RuVectorEmbeddings {
    /// The embedder for generating vectors (wrapped in RwLock for mutable access)
    embedder: Arc<RwLock<Embedder>>,
    /// Stored vectors and metadata (flat list, scanned at query time)
    entries: RwLock<Vec<StoredEntry>>,
    /// Index name
    name: String,
    /// Configuration (distance metric, capacity hint, over-fetch factor)
    config: IndexConfig,
}
impl RuVectorEmbeddings {
    /// Create a new RuVector index with the given embedder
    ///
    /// Takes ownership of `embedder`; its output dimension determines the
    /// dimensionality of every stored vector.
    #[instrument(skip_all)]
    pub fn new(
        name: impl Into<String>,
        embedder: Embedder,
        config: IndexConfig,
    ) -> Result<Self> {
        let name = name.into();
        let dimension = embedder.dimension();
        info!(
            "Creating RuVector index '{}' with dimension {} and {:?} distance",
            name, dimension, config.distance
        );
        Ok(Self {
            embedder: Arc::new(RwLock::new(embedder)),
            // Pre-allocation is capped at 10k entries so a huge
            // `max_elements` doesn't reserve memory up front.
            entries: RwLock::new(Vec::with_capacity(config.max_elements.min(10_000))),
            name,
            config,
        })
    }
/// Create with default configuration
pub fn new_default(name: impl Into<String>, embedder: Embedder) -> Result<Self> {
Self::new(name, embedder, IndexConfig::default())
}
/// Insert a single text with optional metadata
#[instrument(skip(self, text, metadata), fields(text_len = text.len()))]
pub fn insert(
&self,
text: &str,
metadata: Option<serde_json::Value>,
) -> Result<VectorId> {
let embedding = self.embedder.write().embed_one(text)?;
self.insert_with_embedding(text, embedding, metadata)
}
/// Insert with pre-computed embedding
pub fn insert_with_embedding(
&self,
text: &str,
embedding: Vec<f32>,
metadata: Option<serde_json::Value>,
) -> Result<VectorId> {
let id = Uuid::new_v4().to_string();
let entry = StoredEntry {
id: id.clone(),
text: text.to_string(),
vector: embedding,
metadata,
};
self.entries.write().push(entry);
debug!("Inserted text with ID {}", id);
Ok(id)
}
/// Insert multiple texts
#[instrument(skip(self, texts), fields(count = texts.len()))]
pub fn insert_batch<S: AsRef<str>>(&self, texts: &[S]) -> Result<Vec<VectorId>> {
let embeddings = self.embedder.write().embed(texts)?;
self.insert_batch_with_embeddings(texts, embeddings.embeddings)
}
/// Insert batch with pre-computed embeddings
pub fn insert_batch_with_embeddings<S: AsRef<str>>(
&self,
texts: &[S],
embeddings: Vec<Vec<f32>>,
) -> Result<Vec<VectorId>> {
if texts.len() != embeddings.len() {
return Err(EmbeddingError::dimension_mismatch(
texts.len(),
embeddings.len(),
));
}
let entries: Vec<StoredEntry> = texts
.iter()
.zip(embeddings)
.map(|(text, vector)| StoredEntry {
id: Uuid::new_v4().to_string(),
text: text.as_ref().to_string(),
vector,
metadata: None,
})
.collect();
let ids: Vec<VectorId> = entries.iter().map(|e| e.id.clone()).collect();
self.entries.write().extend(entries);
info!("Inserted {} vectors", ids.len());
Ok(ids)
}
/// Search for similar texts
#[instrument(skip(self, query), fields(k))]
pub fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
let query_embedding = self.embedder.write().embed_one(query)?;
self.search_with_embedding(&query_embedding, k)
}
/// Search with pre-computed query embedding
pub fn search_with_embedding(
&self,
query_embedding: &[f32],
k: usize,
) -> Result<Vec<SearchResult>> {
let entries = self.entries.read();
if entries.is_empty() {
return Ok(Vec::new());
}
// Calculate similarities in parallel
let mut scored: Vec<(usize, f32)> = entries
.par_iter()
.enumerate()
.map(|(i, entry)| {
let score = self.compute_similarity(query_embedding, &entry.vector);
(i, score)
})
.collect();
// Sort by score (descending for cosine/dot, ascending for euclidean)
match self.config.distance {
Distance::Cosine | Distance::DotProduct => {
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
Distance::Euclidean => {
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
}
}
// Take top k
let results: Vec<SearchResult> = scored
.into_iter()
.take(k)
.map(|(i, score)| {
let entry = &entries[i];
SearchResult {
id: entry.id.clone(),
text: entry.text.clone(),
score,
metadata: entry.metadata.clone(),
}
})
.collect();
debug!("Search returned {} results", results.len());
Ok(results)
}
/// Compute similarity/distance between two vectors
fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
match self.config.distance {
Distance::Cosine => Self::cosine_similarity(a, b),
Distance::Euclidean => Self::euclidean_distance(a, b),
Distance::DotProduct => Self::dot_product(a, b),
}
}
/// Cosine similarity between two vectors
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a > 1e-10 && norm_b > 1e-10 {
dot / (norm_a * norm_b)
} else {
0.0
}
}
/// Euclidean (L2) distance
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
/// Dot product
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
/// Search with metadata filter
#[instrument(skip(self, query, filter), fields(k))]
pub fn search_filtered<F>(&self, query: &str, k: usize, filter: F) -> Result<Vec<SearchResult>>
where
F: Fn(&serde_json::Value) -> bool + Sync,
{
let query_embedding = self.embedder.write().embed_one(query)?;
let entries = self.entries.read();
if entries.is_empty() {
return Ok(Vec::new());
}
// Calculate similarities with filtering
let mut scored: Vec<(usize, f32)> = entries
.par_iter()
.enumerate()
.filter_map(|(i, entry)| {
// Apply filter
if let Some(ref meta) = entry.metadata {
if !filter(meta) {
return None;
}
}
let score = self.compute_similarity(&query_embedding, &entry.vector);
Some((i, score))
})
.collect();
// Sort
match self.config.distance {
Distance::Cosine | Distance::DotProduct => {
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
Distance::Euclidean => {
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
}
}
let results: Vec<SearchResult> = scored
.into_iter()
.take(k)
.map(|(i, score)| {
let entry = &entries[i];
SearchResult {
id: entry.id.clone(),
text: entry.text.clone(),
score,
metadata: entry.metadata.clone(),
}
})
.collect();
Ok(results)
}
/// Get a vector by ID
pub fn get(&self, id: &str) -> Option<(String, Vec<f32>)> {
let entries = self.entries.read();
entries
.iter()
.find(|e| e.id == id)
.map(|e| (e.text.clone(), e.vector.clone()))
}
/// Delete a vector by ID
pub fn delete(&self, id: &str) -> bool {
let mut entries = self.entries.write();
let len_before = entries.len();
entries.retain(|e| e.id != id);
entries.len() < len_before
}
/// Get the number of vectors in the index
pub fn len(&self) -> usize {
self.entries.read().len()
}
/// Check if the index is empty
pub fn is_empty(&self) -> bool {
self.entries.read().is_empty()
}
/// Get index name
pub fn name(&self) -> &str {
&self.name
}
/// Get the embedding dimension
pub fn dimension(&self) -> usize {
self.embedder.read().dimension()
}
/// Get reference to the embedder (wrapped in Arc<RwLock>)
pub fn embedder(&self) -> &Arc<RwLock<Embedder>> {
&self.embedder
}
/// Clear all vectors
pub fn clear(&self) {
self.entries.write().clear();
}
/// Export all entries for persistence
pub fn export(&self) -> Vec<(VectorId, String, Vec<f32>, Option<serde_json::Value>)> {
self.entries
.read()
.iter()
.map(|e| (e.id.clone(), e.text.clone(), e.vector.clone(), e.metadata.clone()))
.collect()
}
/// Import entries (for loading from persistence)
pub fn import(
&self,
entries: Vec<(VectorId, String, Vec<f32>, Option<serde_json::Value>)>,
) {
let stored: Vec<StoredEntry> = entries
.into_iter()
.map(|(id, text, vector, metadata)| StoredEntry {
id,
text,
vector,
metadata,
})
.collect();
*self.entries.write() = stored;
}
}
/// Builder for creating RuVector indexes
///
/// Collects the name, embedder and configuration before constructing a
/// [`RuVectorEmbeddings`]; the embedder is the only mandatory piece.
pub struct RuVectorBuilder {
    name: String,
    embedder: Option<Embedder>,
    config: IndexConfig,
}
impl RuVectorBuilder {
    /// Start building an index with the given name and default configuration.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            embedder: None,
            config: IndexConfig::default(),
        }
    }
    /// Supply the embedder that will vectorize inserted text (required).
    pub fn embedder(mut self, embedder: Embedder) -> Self {
        self.embedder = Some(embedder);
        self
    }
    /// Choose the distance metric used for scoring.
    pub fn distance(mut self, metric: Distance) -> Self {
        self.config.distance = metric;
        self
    }
    /// Set the capacity hint for the underlying storage.
    pub fn max_elements(mut self, capacity: usize) -> Self {
        self.config.max_elements = capacity;
        self
    }
    /// Set the `ef_search` tuning parameter.
    pub fn ef_search(mut self, ef: usize) -> Self {
        self.config.ef_search = ef;
        self
    }
    /// Finalize the builder and create the index.
    ///
    /// # Errors
    /// Fails with an invalid-config error when no embedder was supplied.
    pub fn build(self) -> Result<RuVectorEmbeddings> {
        match self.embedder {
            Some(embedder) => RuVectorEmbeddings::new(self.name, embedder, self.config),
            None => Err(EmbeddingError::invalid_config("Embedder is required")),
        }
    }
}
/// RAG (Retrieval-Augmented Generation) helper
///
/// Thin convenience layer over a [`RuVectorEmbeddings`] index: retrieves the
/// `top_k` most similar chunks for a query and formats them into a prompt.
pub struct RagPipeline {
    index: RuVectorEmbeddings,
    top_k: usize,
}
impl RagPipeline {
    /// Build a pipeline over `index` that returns `top_k` chunks per query.
    pub fn new(index: RuVectorEmbeddings, top_k: usize) -> Self {
        Self { index, top_k }
    }
    /// Retrieve the texts of the `top_k` most similar chunks for `query`.
    pub fn retrieve(&self, query: &str) -> Result<Vec<String>> {
        let hits = self.index.search(query, self.top_k)?;
        Ok(hits.into_iter().map(|hit| hit.text).collect())
    }
    /// Retrieve chunks together with their similarity scores.
    pub fn retrieve_with_scores(&self, query: &str) -> Result<Vec<(String, f32)>> {
        let hits = self.index.search(query, self.top_k)?;
        Ok(hits.into_iter().map(|hit| (hit.text, hit.score)).collect())
    }
    /// Format the retrieved context and the question as a single prompt.
    pub fn format_context(&self, query: &str) -> Result<String> {
        let body: String = self
            .retrieve(query)?
            .iter()
            .enumerate()
            .map(|(idx, ctx)| format!("[{}] {}\n", idx + 1, ctx))
            .collect();
        Ok(format!("Context:\n{body}\nQuestion: {query}"))
    }
    /// Format the prompt with each chunk's relevance score included.
    pub fn format_context_with_scores(&self, query: &str) -> Result<String> {
        let body: String = self
            .retrieve_with_scores(query)?
            .iter()
            .enumerate()
            .map(|(idx, (ctx, score))| format!("[{} - {:.3}] {}\n", idx + 1, score, ctx))
            .collect();
        Ok(format!(
            "Context (with relevance scores):\n{body}\nQuestion: {query}"
        ))
    }
    /// Embed and add documents to the underlying index.
    pub fn add_documents<S: AsRef<str>>(&self, documents: &[S]) -> Result<Vec<VectorId>> {
        self.index.insert_batch(documents)
    }
    /// Borrow the underlying index.
    pub fn index(&self) -> &RuVectorEmbeddings {
        &self.index
    }
    /// Change how many chunks are retrieved per query.
    pub fn set_top_k(&mut self, k: usize) {
        self.top_k = k;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Identical unit vectors score 1.0; orthogonal ones score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let same = RuVectorEmbeddings::cosine_similarity(&x_axis, &x_axis);
        let orthogonal = RuVectorEmbeddings::cosine_similarity(&x_axis, &y_axis);
        assert!((same - 1.0).abs() < 1e-6);
        assert!(orthogonal.abs() < 1e-6);
    }
    /// Classic 3-4-5 triangle: distance from the origin is 5.
    #[test]
    fn test_euclidean_distance() {
        let origin = vec![0.0, 0.0];
        let corner = vec![3.0, 4.0];
        let dist = RuVectorEmbeddings::euclidean_distance(&origin, &corner);
        assert!((dist - 5.0).abs() < 1e-6);
    }
    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];
        assert!((RuVectorEmbeddings::dot_product(&lhs, &rhs) - 32.0).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,260 @@
//! Text tokenization using HuggingFace tokenizers
use crate::{EmbeddingError, Result};
use std::path::Path;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use tracing::{debug, instrument};
/// Wrapper around HuggingFace tokenizer with batch processing
pub struct Tokenizer {
    // Underlying HuggingFace tokenizer (vocab, special tokens, encoding rules).
    inner: HfTokenizer,
    // Hard cap on tokens per sequence; longer encodings are sliced down in
    // `encode_batch`.
    max_length: usize,
    // Token ID used to right-pad shorter sequences in a batch (resolved via
    // `find_pad_token_id`).
    pad_token_id: u32,
}
/// Encoded batch output
#[derive(Debug, Clone)]
pub struct EncodedBatch {
    /// Token IDs [batch_size, seq_length]
    pub input_ids: Vec<Vec<i64>>,
    /// Attention mask [batch_size, seq_length]
    pub attention_mask: Vec<Vec<i64>>,
    /// Token type IDs [batch_size, seq_length]
    pub token_type_ids: Vec<Vec<i64>>,
    /// Original sequence lengths before padding
    pub original_lengths: Vec<usize>,
}
impl EncodedBatch {
    /// Number of sequences in the batch.
    pub fn batch_size(&self) -> usize {
        self.input_ids.len()
    }
    /// Padded length shared by every sequence (0 for an empty batch).
    pub fn seq_length(&self) -> usize {
        match self.input_ids.first() {
            Some(row) => row.len(),
            None => 0,
        }
    }
    /// Flatten the per-sequence rows into row-major 1-D buffers for ONNX.
    ///
    /// Returns `(input_ids, attention_mask, token_type_ids, shape)` where
    /// `shape` is `[batch_size, seq_length]`.
    pub fn to_onnx_inputs(&self) -> (Vec<i64>, Vec<i64>, Vec<i64>, Vec<usize>) {
        let shape = vec![self.batch_size(), self.seq_length()];
        let flatten =
            |rows: &Vec<Vec<i64>>| -> Vec<i64> { rows.iter().flatten().copied().collect() };
        (
            flatten(&self.input_ids),
            flatten(&self.attention_mask),
            flatten(&self.token_type_ids),
            shape,
        )
    }
}
/// Helper function to find pad token ID from vocabulary
///
/// Checks the common pad-token spellings in order and falls back to 0 when
/// none of them exists in the vocabulary.
fn find_pad_token_id(tokenizer: &HfTokenizer) -> u32 {
    let vocab = tokenizer.get_vocab(true);
    ["[PAD]", "<pad>", "<|pad|>"]
        .into_iter()
        .find_map(|token| vocab.get(token).copied())
        .unwrap_or(0)
}
impl Tokenizer {
    /// Load tokenizer from a local file
    ///
    /// # Errors
    /// Returns a tokenizer-not-found error when the file cannot be read or
    /// parsed as a HuggingFace `tokenizer.json`.
    #[instrument(skip_all, fields(path = %path.as_ref().display()))]
    pub fn from_file(path: impl AsRef<Path>, max_length: usize) -> Result<Self> {
        let path = path.as_ref();
        debug!("Loading tokenizer from file");
        let inner = HfTokenizer::from_file(path)
            .map_err(|e| EmbeddingError::tokenizer_not_found(e.to_string()))?;
        let pad_token_id = find_pad_token_id(&inner);
        Ok(Self {
            inner,
            max_length,
            pad_token_id,
        })
    }
    /// Load tokenizer from HuggingFace Hub by downloading tokenizer.json
    ///
    /// Performs a *blocking* HTTP GET against the `main` revision of the
    /// repository — do not call this from an async context.
    ///
    /// # Errors
    /// Returns a download error on network failure or a non-2xx status, and
    /// a tokenizer-not-found error if the downloaded bytes fail to parse.
    #[instrument(skip_all, fields(model_id = %model_id))]
    pub fn from_pretrained(model_id: &str, max_length: usize) -> Result<Self> {
        debug!("Loading tokenizer from HuggingFace Hub: {}", model_id);
        // Download tokenizer.json from HuggingFace Hub
        let url = format!(
            "https://huggingface.co/{}/resolve/main/tokenizer.json",
            model_id
        );
        let response = reqwest::blocking::get(&url)
            .map_err(|e| EmbeddingError::download_failed(format!("Failed to download tokenizer: {}", e)))?;
        if !response.status().is_success() {
            return Err(EmbeddingError::download_failed(format!(
                "Failed to download tokenizer from {}: HTTP {}",
                url,
                response.status()
            )));
        }
        let bytes = response.bytes()
            .map_err(|e| EmbeddingError::download_failed(e.to_string()))?;
        let inner = HfTokenizer::from_bytes(&bytes)
            .map_err(|e| EmbeddingError::tokenizer_not_found(e.to_string()))?;
        let pad_token_id = find_pad_token_id(&inner);
        Ok(Self {
            inner,
            max_length,
            pad_token_id,
        })
    }
    /// Load tokenizer from JSON string
    ///
    /// # Errors
    /// Returns a tokenizer-not-found error if `json` is not a valid
    /// `tokenizer.json` document.
    pub fn from_json(json: &str, max_length: usize) -> Result<Self> {
        let inner = HfTokenizer::from_bytes(json.as_bytes())
            .map_err(|e| EmbeddingError::tokenizer_not_found(e.to_string()))?;
        let pad_token_id = find_pad_token_id(&inner);
        Ok(Self {
            inner,
            max_length,
            pad_token_id,
        })
    }
    /// Encode a single text
    ///
    /// Convenience wrapper around [`Self::encode_batch`] for one input.
    pub fn encode(&self, text: &str) -> Result<EncodedBatch> {
        self.encode_batch(&[text])
    }
    /// Encode a batch of texts
    ///
    /// Sequences are padded to the longest encoding in the batch (dynamic
    /// padding), capped at `max_length`; anything longer is truncated.
    ///
    /// # Errors
    /// Returns `EmptyInput` for an empty slice and propagates tokenizer
    /// encoding failures.
    #[instrument(skip_all, fields(batch_size = texts.len()))]
    pub fn encode_batch<S: AsRef<str>>(&self, texts: &[S]) -> Result<EncodedBatch> {
        if texts.is_empty() {
            return Err(EmbeddingError::EmptyInput);
        }
        debug!("Encoding batch of {} texts", texts.len());
        // Encode all texts
        // The `true` flag asks the tokenizer to add its special tokens.
        let encodings: Vec<_> = texts
            .iter()
            .map(|t| self.inner.encode(t.as_ref(), true))
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(EmbeddingError::from)?;
        // Find max length in batch (capped at max_length)
        let max_len = encodings
            .iter()
            .map(|e| e.get_ids().len().min(self.max_length))
            .max()
            .unwrap_or(0);
        // Pad all sequences to the same length
        let mut input_ids = Vec::with_capacity(texts.len());
        let mut attention_mask = Vec::with_capacity(texts.len());
        let mut token_type_ids = Vec::with_capacity(texts.len());
        let mut original_lengths = Vec::with_capacity(texts.len());
        for encoding in &encodings {
            let ids = encoding.get_ids();
            let type_ids = encoding.get_type_ids();
            let len = ids.len().min(self.max_length);
            original_lengths.push(len);
            // Truncate if necessary and convert to i64
            // NOTE(review): truncation slices off the tail, which may drop a
            // trailing special token (e.g. a separator) — confirm this is
            // acceptable for the target models.
            let mut ids_vec: Vec<i64> = ids[..len].iter().map(|&x| x as i64).collect();
            let mut mask_vec: Vec<i64> = vec![1; len];
            let mut type_vec: Vec<i64> = type_ids[..len].iter().map(|&x| x as i64).collect();
            // Pad to max_len
            // Padding positions get the pad token, mask 0 and type id 0.
            let pad_len = max_len - len;
            if pad_len > 0 {
                ids_vec.extend(std::iter::repeat_n(self.pad_token_id as i64, pad_len));
                mask_vec.extend(std::iter::repeat_n(0i64, pad_len));
                type_vec.extend(std::iter::repeat_n(0i64, pad_len));
            }
            input_ids.push(ids_vec);
            attention_mask.push(mask_vec);
            token_type_ids.push(type_vec);
        }
        Ok(EncodedBatch {
            input_ids,
            attention_mask,
            token_type_ids,
            original_lengths,
        })
    }
    /// Get the vocabulary size
    ///
    /// Includes added/special tokens (`get_vocab_size(true)`).
    pub fn vocab_size(&self) -> usize {
        self.inner.get_vocab_size(true)
    }
    /// Get the max length
    pub fn max_length(&self) -> usize {
        self.max_length
    }
    /// Set the max length
    pub fn set_max_length(&mut self, max_length: usize) {
        self.max_length = max_length;
    }
    /// Decode token IDs back to text
    ///
    /// # Errors
    /// Propagates any failure from the underlying tokenizer's decoder.
    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.inner
            .decode(ids, skip_special_tokens)
            .map_err(EmbeddingError::from)
    }
    /// Get the pad token ID
    pub fn pad_token_id(&self) -> u32 {
        self.pad_token_id
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// `to_onnx_inputs` must emit row-major flat buffers and a
    /// `[batch, seq]` shape.
    #[test]
    fn test_encoded_batch_to_onnx() {
        let batch = EncodedBatch {
            input_ids: vec![vec![101, 2054, 2003, 102], vec![101, 2054, 102, 0]],
            attention_mask: vec![vec![1, 1, 1, 1], vec![1, 1, 1, 0]],
            token_type_ids: vec![vec![0; 4], vec![0; 4]],
            original_lengths: vec![4, 3],
        };
        let (ids, mask, types, shape) = batch.to_onnx_inputs();
        assert_eq!(shape, vec![2, 4]);
        for flat in [&ids, &mask, &types] {
            assert_eq!(flat.len(), 2 * 4);
        }
    }
}