Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,213 @@
//! Main WASM embedder implementation
use crate::error::{Result, WasmEmbeddingError};
use crate::model::TractModel;
use crate::pooling::{cosine_similarity, normalize_l2, PoolingStrategy};
use crate::tokenizer::WasmTokenizer;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Configuration for the WASM embedder
///
/// All fields are marked `#[wasm_bindgen(skip)]`, so they are not exposed as
/// raw JS properties; JS code configures them through the builder methods on
/// the `#[wasm_bindgen] impl` block instead.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmEmbedderConfig {
    /// Maximum sequence length in tokens; inputs are padded/truncated to this
    #[wasm_bindgen(skip)]
    pub max_length: usize,
    /// Pooling strategy used to collapse per-token embeddings into one vector
    #[wasm_bindgen(skip)]
    pub pooling: PoolingStrategy,
    /// Whether to L2 normalize embeddings after pooling
    #[wasm_bindgen(skip)]
    pub normalize: bool,
}
#[wasm_bindgen]
impl WasmEmbedderConfig {
    /// Create a new configuration with the default settings
    /// (max_length = 256, mean pooling, L2 normalization enabled).
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum sequence length
    ///
    /// Consumes and returns the config so calls can be chained from JS.
    #[wasm_bindgen(js_name = setMaxLength)]
    pub fn set_max_length(self, max_length: usize) -> Self {
        Self { max_length, ..self }
    }

    /// Set whether to normalize embeddings
    #[wasm_bindgen(js_name = setNormalize)]
    pub fn set_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Set pooling strategy (0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen, 4=LastToken)
    ///
    /// Any unrecognized code falls back to mean pooling.
    #[wasm_bindgen(js_name = setPooling)]
    pub fn set_pooling(self, pooling: u8) -> Self {
        let strategy = match pooling {
            1 => PoolingStrategy::Cls,
            2 => PoolingStrategy::Max,
            3 => PoolingStrategy::MeanSqrtLen,
            4 => PoolingStrategy::LastToken,
            // 0 and every unknown code map to the default Mean strategy.
            _ => PoolingStrategy::Mean,
        };
        Self { pooling: strategy, ..self }
    }
}
impl Default for WasmEmbedderConfig {
fn default() -> Self {
Self {
max_length: 256,
pooling: PoolingStrategy::Mean,
normalize: true,
}
}
}
/// WASM-compatible embedder using Tract for inference
///
/// Owns the compiled model plan, the tokenizer, and the active config.
#[wasm_bindgen]
pub struct WasmEmbedder {
    // Compiled tract inference plan (see model.rs).
    model: TractModel,
    // Tokenizer that pads/truncates to config.max_length (see tokenizer.rs).
    tokenizer: WasmTokenizer,
    config: WasmEmbedderConfig,
    // Embedding dimension; starts from the model's default and may be
    // re-detected after the first inference (see embed_one_internal).
    hidden_size: usize,
}
#[wasm_bindgen]
impl WasmEmbedder {
    /// Create a new embedder from model and tokenizer bytes
    ///
    /// # Arguments
    /// * `model_bytes` - ONNX model file bytes
    /// * `tokenizer_json` - Tokenizer JSON configuration
    ///
    /// # Errors
    /// Returns a JS string error if the model or tokenizer fails to load.
    #[wasm_bindgen(constructor)]
    pub fn new(model_bytes: &[u8], tokenizer_json: &str) -> std::result::Result<WasmEmbedder, JsValue> {
        Self::with_config(model_bytes, tokenizer_json, WasmEmbedderConfig::default())
    }

    /// Create embedder with custom configuration
    ///
    /// Errors are converted via the `From<WasmEmbeddingError> for JsValue`
    /// impl in error.rs (which stringifies the error), replacing the
    /// previously duplicated `|e| JsValue::from_str(&e.to_string())` closures.
    #[wasm_bindgen(js_name = withConfig)]
    pub fn with_config(
        model_bytes: &[u8],
        tokenizer_json: &str,
        config: WasmEmbedderConfig,
    ) -> std::result::Result<WasmEmbedder, JsValue> {
        let model = TractModel::from_bytes(model_bytes, config.max_length)
            .map_err(JsValue::from)?;
        let tokenizer = WasmTokenizer::from_json(tokenizer_json, config.max_length)
            .map_err(JsValue::from)?;
        let hidden_size = model.hidden_size();
        Ok(Self {
            model,
            tokenizer,
            config,
            hidden_size,
        })
    }

    /// Generate embedding for a single text
    ///
    /// Returns a flat `[hidden_size]` vector (normalized if configured).
    #[wasm_bindgen(js_name = embedOne)]
    pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, JsValue> {
        self.embed_one_internal(text).map_err(JsValue::from)
    }

    /// Generate embeddings for multiple texts
    ///
    /// Returns one flat vector of `texts.len() * hidden_size` floats,
    /// concatenated in input order.
    #[wasm_bindgen(js_name = embedBatch)]
    pub fn embed_batch(&mut self, texts: Vec<String>) -> std::result::Result<Vec<f32>, JsValue> {
        let refs: Vec<&str> = texts.iter().map(String::as_str).collect();
        self.embed_batch_internal(&refs).map_err(JsValue::from)
    }

    /// Compute cosine similarity between two texts
    #[wasm_bindgen]
    pub fn similarity(&mut self, text1: &str, text2: &str) -> std::result::Result<f32, JsValue> {
        let emb1 = self.embed_one_internal(text1).map_err(JsValue::from)?;
        let emb2 = self.embed_one_internal(text2).map_err(JsValue::from)?;
        Ok(cosine_similarity(&emb1, &emb2))
    }

    /// Get the embedding dimension (the model's hidden size)
    #[wasm_bindgen]
    pub fn dimension(&self) -> usize {
        self.hidden_size
    }

    /// Get maximum sequence length
    #[wasm_bindgen(js_name = maxLength)]
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }
}
// Internal implementation
impl WasmEmbedder {
    /// Tokenize `text`, run inference, pool, and (optionally) L2-normalize.
    ///
    /// Takes `&mut self` because the hidden size is re-detected from the first
    /// inference output and cached on both `self` and the model.
    fn embed_one_internal(&mut self, text: &str) -> Result<Vec<f32>> {
        // Tokenize
        let encoded = self.tokenizer.encode(text)?;
        let attention_mask = encoded.attention_mask.clone();
        // Run inference
        let raw_output = self.model.run(&encoded)?;
        // Determine hidden size from output
        // NOTE(review): this assumes a token-level [seq_len, hidden] output so
        // that len / seq_len yields the hidden size. A model emitting an
        // already-pooled [hidden] vector with hidden >= max_length would be
        // mis-detected here — confirm the expected ONNX output shape.
        let seq_len = self.config.max_length;
        if raw_output.len() >= seq_len {
            let detected_hidden = raw_output.len() / seq_len;
            if detected_hidden != self.hidden_size && detected_hidden > 0 {
                self.hidden_size = detected_hidden;
                self.model.set_hidden_size(detected_hidden);
            }
        }
        // Apply pooling
        let mut embedding = self.config.pooling.apply(
            &raw_output,
            &attention_mask,
            self.hidden_size,
        );
        // Normalize if configured
        if self.config.normalize {
            normalize_l2(&mut embedding);
        }
        Ok(embedding)
    }

    /// Embed each text and concatenate the results into one flat buffer
    /// (`texts.len() * hidden_size` floats, in input order).
    fn embed_batch_internal(&mut self, texts: &[&str]) -> Result<Vec<f32>> {
        // Capacity is a best-effort guess: hidden_size may still be corrected
        // by the first embed_one_internal call below.
        let mut all_embeddings = Vec::with_capacity(texts.len() * self.hidden_size);
        for text in texts {
            let embedding = self.embed_one_internal(text)?;
            all_embeddings.extend(embedding);
        }
        Ok(all_embeddings)
    }
}
/// Compute cosine similarity between two embedding vectors (JS-friendly)
///
/// Thin wrapper over `pooling::cosine_similarity` that accepts owned vectors
/// from the JS boundary.
#[wasm_bindgen(js_name = cosineSimilarity)]
pub fn js_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> f32 {
    cosine_similarity(a.as_slice(), b.as_slice())
}
/// L2 normalize an embedding vector (JS-friendly)
///
/// Normalizes in place and returns the same vector to the caller.
#[wasm_bindgen(js_name = normalizeL2)]
pub fn js_normalize_l2(mut embedding: Vec<f32>) -> Vec<f32> {
    normalize_l2(embedding.as_mut_slice());
    embedding
}

View File

@@ -0,0 +1,62 @@
//! Error types for WASM embeddings
use thiserror::Error;
use wasm_bindgen::prelude::*;
/// Result type for WASM embedding operations
pub type Result<T> = std::result::Result<T, WasmEmbeddingError>;
/// Errors that can occur during WASM embedding operations
///
/// Each variant carries a human-readable message; the thiserror `#[error]`
/// attributes define the `Display` text that ultimately reaches JS via the
/// `From<WasmEmbeddingError> for JsValue` impl below.
#[derive(Error, Debug)]
pub enum WasmEmbeddingError {
    /// Model parsing/optimization/compilation failure
    #[error("Model error: {0}")]
    Model(String),
    /// Tokenizer construction or encoding failure
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
    /// Failure while running the compiled model or reading its output
    #[error("Inference error: {0}")]
    Inference(String),
    /// Caller-supplied input was rejected
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// serde (de)serialization failure
    #[error("Serialization error: {0}")]
    Serialization(String),
}
impl WasmEmbeddingError {
pub fn model(msg: impl Into<String>) -> Self {
Self::Model(msg.into())
}
pub fn tokenizer(msg: impl Into<String>) -> Self {
Self::Tokenizer(msg.into())
}
pub fn inference(msg: impl Into<String>) -> Self {
Self::Inference(msg.into())
}
pub fn invalid_input(msg: impl Into<String>) -> Self {
Self::InvalidInput(msg.into())
}
}
impl From<WasmEmbeddingError> for JsValue {
    /// Surface the error to JS as a plain string (its `Display` output).
    fn from(err: WasmEmbeddingError) -> Self {
        let message = err.to_string();
        JsValue::from_str(&message)
    }
}
impl From<tract_onnx::prelude::TractError> for WasmEmbeddingError {
fn from(err: tract_onnx::prelude::TractError) -> Self {
Self::Model(err.to_string())
}
}
impl From<serde_json::Error> for WasmEmbeddingError {
fn from(err: serde_json::Error) -> Self {
Self::Serialization(err.to_string())
}
}

View File

@@ -0,0 +1,66 @@
//! # RuVector ONNX Embeddings - WASM Edition
//!
//! WASM-compatible embedding generation using Tract for inference.
//! Runs in browsers, Cloudflare Workers, Deno, and any WASM runtime.
//!
//! ## Features
//!
//! - **Browser Support**: Generate embeddings directly in the browser
//! - **Edge Computing**: Deploy to Cloudflare Workers, Vercel Edge, etc.
//! - **Portable**: Single WASM binary, no platform-specific dependencies
//! - **Same API**: Compatible with the native ruvector-onnx-embeddings crate
//!
//! ## Usage (JavaScript)
//!
//! ```javascript
//! import init, { WasmEmbedder } from 'ruvector-onnx-embeddings-wasm';
//!
//! await init();
//!
//! // Load model from bytes
//! const modelBytes = await fetch('/model.onnx').then(r => r.arrayBuffer());
//! const tokenizerJson = await fetch('/tokenizer.json').then(r => r.text());
//!
//! const embedder = new WasmEmbedder(new Uint8Array(modelBytes), tokenizerJson);
//!
//! // Generate embeddings
//! const embedding = embedder.embedOne("Hello, world!");
//! console.log("Embedding dimension:", embedding.length);
//!
//! // Compute similarity
//! const similarity = embedder.similarity("I love Rust", "Rust is great");
//! console.log("Similarity:", similarity);
//! ```
mod embedder;
mod error;
mod model;
mod pooling;
mod tokenizer;
pub use embedder::{WasmEmbedder, WasmEmbedderConfig};
pub use error::WasmEmbeddingError;
pub use pooling::PoolingStrategy;
use wasm_bindgen::prelude::*;
/// Initialize panic hook for better error messages in WASM
#[wasm_bindgen(start)]
pub fn init() {
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}
/// Get the library version
///
/// Returns the crate version baked in at compile time via `CARGO_PKG_VERSION`.
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Check if SIMD is available (for performance info)
/// Returns true if compiled with WASM SIMD128 support
#[wasm_bindgen]
pub fn simd_available() -> bool {
    // Check if compiled with SIMD128 target feature.
    // `cfg!` is evaluated at compile time, so this reports how the binary was
    // built, not a runtime capability probe of the host.
    cfg!(target_feature = "simd128")
}

View File

@@ -0,0 +1,116 @@
//! Tract-based ONNX model for WASM inference
use crate::error::{Result, WasmEmbeddingError};
use crate::tokenizer::EncodedInput;
use tract_onnx::prelude::*;
/// Tract ONNX model wrapper for WASM
pub struct TractModel {
    // Fully typed, optimized, runnable tract plan built in `from_bytes`.
    model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    // Embedding dimension; starts at a default and is corrected via
    // `set_hidden_size` after the first inference.
    hidden_size: usize,
}
impl TractModel {
    /// Load model from ONNX bytes
    ///
    /// Pins the input shapes to `[1, max_seq_length]` so tract can fully
    /// type-check and optimize the graph, then compiles it into a runnable plan.
    ///
    /// # Errors
    /// Returns `WasmEmbeddingError::Model` if parsing, shape annotation,
    /// optimization, or plan creation fails.
    ///
    /// NOTE(review): exactly three int64 inputs (indices 0-2) are assumed —
    /// input_ids, attention_mask, token_type_ids. A model exported without
    /// token_type_ids would fail here; confirm against the target ONNX export.
    pub fn from_bytes(bytes: &[u8], max_seq_length: usize) -> Result<Self> {
        // Parse ONNX model
        let model = tract_onnx::onnx()
            .model_for_read(&mut std::io::Cursor::new(bytes))
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to parse ONNX: {}", e)))?;
        // Set input shapes for optimization
        // Standard transformer inputs: [batch, seq_len]
        let batch = 1usize;
        let seq_len = max_seq_length;
        let model = model
            .with_input_fact(
                0,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?
            .with_input_fact(
                1,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?
            .with_input_fact(
                2,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?;
        // Optimize the model
        let model = model
            .into_optimized()
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to optimize: {}", e)))?;
        let model = model
            .into_runnable()
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to make runnable: {}", e)))?;
        // Default hidden size (will be determined from output)
        // Placeholder only: the embedder corrects it after the first inference
        // via set_hidden_size().
        let hidden_size = 384;
        Ok(Self { model, hidden_size })
    }
    /// Run inference on encoded input
    ///
    /// Builds three `[1, seq_len]` i64 tensors (ids, mask, type ids) from
    /// `input`, runs the plan, and returns the first output tensor flattened
    /// to a `Vec<f32>`.
    ///
    /// # Errors
    /// Returns `WasmEmbeddingError::Inference` on a tensor-shape mismatch, a
    /// failed run, a missing output, or an output that is not f32.
    pub fn run(&self, input: &EncodedInput) -> Result<Vec<f32>> {
        let seq_len = input.input_ids.len();
        // Create input tensors
        let input_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.input_ids.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();
        let attention_mask: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.attention_mask.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();
        let token_type_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.token_type_ids.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();
        // Run inference; input order must match the input facts set in
        // from_bytes (0=ids, 1=mask, 2=type ids).
        let inputs = tvec![
            input_ids.into(),
            attention_mask.into(),
            token_type_ids.into()
        ];
        let outputs = self
            .model
            .run(inputs)
            .map_err(|e| WasmEmbeddingError::inference(format!("Inference failed: {}", e)))?;
        // Extract output tensor
        // Output is typically [batch, seq_len, hidden_size] or [batch, hidden_size]
        let output = outputs
            .first()
            .ok_or_else(|| WasmEmbeddingError::inference("No output tensor"))?;
        let output_array = output
            .to_array_view::<f32>()
            .map_err(|e| WasmEmbeddingError::inference(format!("Failed to extract output: {}", e)))?;
        // Flatten and return
        Ok(output_array.iter().copied().collect())
    }
    /// Get the hidden size
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    /// Update hidden size (called after first inference)
    pub fn set_hidden_size(&mut self, size: usize) {
        self.hidden_size = size;
    }
}

View File

@@ -0,0 +1,181 @@
//! Pooling strategies for converting token embeddings to sentence embeddings
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Strategy for pooling token embeddings into a single sentence embedding
///
/// Exposed to JS via `#[wasm_bindgen]`; `WasmEmbedderConfig::set_pooling`
/// selects a variant from a numeric code.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub enum PoolingStrategy {
    /// Average all token embeddings (most common)
    #[default]
    Mean,
    /// Use only the [CLS] token embedding
    Cls,
    /// Take the maximum value across all tokens for each dimension
    Max,
    /// Mean pooling normalized by sqrt of sequence length
    MeanSqrtLen,
    /// Use the last token embedding (for decoder models)
    LastToken,
}
impl PoolingStrategy {
    /// Apply pooling to token embeddings
    ///
    /// # Arguments
    /// * `embeddings` - Flattened token embeddings `[seq_len * hidden_size]`
    /// * `attention_mask` - Attention mask `[seq_len]` (1 = real token, 0 = padding)
    /// * `hidden_size` - Embedding dimension of a single token
    ///
    /// # Returns
    /// Pooled embedding `[hidden_size]`. Returns an all-zero vector when the
    /// input is empty, `hidden_size` is 0, the mask selects no tokens, or
    /// (for Cls/LastToken) the requested token slice would fall outside
    /// `embeddings` — previously those slices could panic.
    pub fn apply(&self, embeddings: &[f32], attention_mask: &[i64], hidden_size: usize) -> Vec<f32> {
        let seq_len = attention_mask.len();
        if embeddings.is_empty() || hidden_size == 0 {
            return vec![0.0; hidden_size];
        }
        match self {
            PoolingStrategy::Mean => {
                self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len)
            }
            PoolingStrategy::Cls => {
                // First token (CLS). `get` guards against an output shorter
                // than one token, which would otherwise panic on the slice.
                embeddings
                    .get(..hidden_size)
                    .map(|s| s.to_vec())
                    .unwrap_or_else(|| vec![0.0; hidden_size])
            }
            PoolingStrategy::Max => {
                self.max_pooling(embeddings, attention_mask, hidden_size, seq_len)
            }
            PoolingStrategy::MeanSqrtLen => {
                let mut pooled = self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len);
                let valid_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
                // Guard: an all-padding mask would give 1/sqrt(0) = inf and
                // fill the vector with NaNs; leave the (all-zero) mean as-is.
                if valid_tokens > 0.0 {
                    let scale = 1.0 / valid_tokens.sqrt();
                    for v in &mut pooled {
                        *v *= scale;
                    }
                }
                pooled
            }
            PoolingStrategy::LastToken => {
                // Find last valid (non-padding) token; default to token 0.
                let last_idx = attention_mask
                    .iter()
                    .rposition(|&m| m == 1)
                    .unwrap_or(0);
                let start = last_idx * hidden_size;
                // `get` guards against a slice past the end of the buffer
                // (possible if the output is shorter than the mask implies).
                embeddings
                    .get(start..start + hidden_size)
                    .map(|s| s.to_vec())
                    .unwrap_or_else(|| vec![0.0; hidden_size])
            }
        }
    }

    /// Mask-weighted mean over valid tokens; all zeros if no token is valid.
    fn mean_pooling(
        &self,
        embeddings: &[f32],
        attention_mask: &[i64],
        hidden_size: usize,
        seq_len: usize,
    ) -> Vec<f32> {
        let mut pooled = vec![0.0f32; hidden_size];
        let mut count = 0.0f32;
        for (i, &mask) in attention_mask.iter().enumerate() {
            // Only accumulate tokens that are unmasked and fully inside the
            // embeddings buffer.
            if mask == 1 && i < seq_len {
                let start = i * hidden_size;
                if start + hidden_size <= embeddings.len() {
                    for (j, v) in pooled.iter_mut().enumerate() {
                        *v += embeddings[start + j];
                    }
                    count += 1.0;
                }
            }
        }
        if count > 0.0 {
            for v in &mut pooled {
                *v /= count;
            }
        }
        pooled
    }

    /// Element-wise max over valid tokens; dimensions with no valid token
    /// come back as 0 instead of -inf.
    fn max_pooling(
        &self,
        embeddings: &[f32],
        attention_mask: &[i64],
        hidden_size: usize,
        seq_len: usize,
    ) -> Vec<f32> {
        let mut pooled = vec![f32::NEG_INFINITY; hidden_size];
        for (i, &mask) in attention_mask.iter().enumerate() {
            if mask == 1 && i < seq_len {
                let start = i * hidden_size;
                if start + hidden_size <= embeddings.len() {
                    for (j, v) in pooled.iter_mut().enumerate() {
                        *v = v.max(embeddings[start + j]);
                    }
                }
            }
        }
        // Replace -inf with 0 for dimensions with no valid tokens
        for v in &mut pooled {
            if v.is_infinite() {
                *v = 0.0;
            }
        }
        pooled
    }
}
/// L2 normalize a vector in place
///
/// Divides every component by the Euclidean norm; a zero-norm vector is left
/// untouched to avoid dividing by zero.
pub fn normalize_l2(embedding: &mut [f32]) {
    let norm = embedding
        .iter()
        .fold(0.0f32, |acc, x| acc + x * x)
        .sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|v| *v /= norm);
    }
}
/// Compute cosine similarity between two embeddings
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either vector
/// has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass: accumulate the dot product and both squared norms.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        // Identical unit vectors score exactly 1.
        let unit_x = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &same) - 1.0).abs() < 1e-6);
        // Orthogonal vectors score 0.
        let unit_y = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2() {
        // 3-4-5 triangle: normalizing [3, 4] yields [0.6, 0.8].
        let mut vector = vec![3.0, 4.0];
        normalize_l2(&mut vector);
        assert!((vector[0] - 0.6).abs() < 1e-6);
        assert!((vector[1] - 0.8).abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,114 @@
//! Tokenizer wrapper for WASM embedding generation
use crate::error::{Result, WasmEmbeddingError};
use tokenizers::Tokenizer;
/// Tokenizer wrapper that handles text encoding
pub struct WasmTokenizer {
    // HuggingFace `tokenizers` tokenizer parsed from JSON bytes.
    tokenizer: Tokenizer,
    // Fixed output length: every encoding is truncated/padded to this.
    max_length: usize,
}
/// Encoded text ready for model inference
///
/// All three vectors have the same length (`max_length` after padding) and
/// are i64 to match the model's int64 input tensors.
#[derive(Debug, Clone)]
pub struct EncodedInput {
    /// Token ids (0-padded)
    pub input_ids: Vec<i64>,
    /// 1 for real tokens, 0 for padding
    pub attention_mask: Vec<i64>,
    /// Segment ids (0-padded)
    pub token_type_ids: Vec<i64>,
}
impl WasmTokenizer {
    /// Create a new tokenizer from JSON configuration
    ///
    /// Thin wrapper over [`WasmTokenizer::from_bytes`] — the underlying
    /// `Tokenizer::from_bytes` parses the same JSON either way, so the
    /// previously duplicated body is delegated.
    pub fn from_json(json: &str, max_length: usize) -> Result<Self> {
        Self::from_bytes(json.as_bytes(), max_length)
    }

    /// Create tokenizer from raw bytes
    ///
    /// # Errors
    /// Returns `WasmEmbeddingError::Tokenizer` if the bytes are not a valid
    /// tokenizer configuration.
    pub fn from_bytes(bytes: &[u8], max_length: usize) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes)
            .map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
        Ok(Self {
            tokenizer,
            max_length,
        })
    }

    /// Encode a single text into fixed-length model inputs
    ///
    /// The encoding (with special tokens added) is truncated to `max_length`
    /// and then zero-padded so all three vectors are exactly `max_length`
    /// long (pad id 0, mask 0, type id 0).
    ///
    /// NOTE(review): plain truncation can drop a trailing special token such
    /// as [SEP] on BERT-style models — confirm this is acceptable for the
    /// target models. Pad id 0 is assumed; verify against tokenizer config.
    pub fn encode(&self, text: &str) -> Result<EncodedInput> {
        let encoding = self
            .tokenizer
            .encode(text, true) // true = add special tokens
            .map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
        // Widen the tokenizer's u32 outputs to the i64 the model expects.
        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
        let mut attention_mask: Vec<i64> =
            encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
        let mut token_type_ids: Vec<i64> =
            encoding.get_type_ids().iter().map(|&t| t as i64).collect();
        // Truncate if necessary
        if input_ids.len() > self.max_length {
            input_ids.truncate(self.max_length);
            attention_mask.truncate(self.max_length);
            token_type_ids.truncate(self.max_length);
        }
        // Pad if necessary
        while input_ids.len() < self.max_length {
            input_ids.push(0);
            attention_mask.push(0);
            token_type_ids.push(0);
        }
        Ok(EncodedInput {
            input_ids,
            attention_mask,
            token_type_ids,
        })
    }

    /// Encode multiple texts with padding to the same length
    ///
    /// Each entry is padded to `max_length` by [`Self::encode`], so all
    /// outputs share the same length. Fails fast on the first bad text.
    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<EncodedInput>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }

    /// Get the maximum sequence length
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Basic tokenizer JSON for testing
    // Minimal word-level config: 4-token vocab with whitespace
    // pre-tokenization — just enough to exercise construction.
    const TEST_TOKENIZER: &str = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {"type": "Whitespace"},
"post_processor": null,
"decoder": null,
"model": {
"type": "WordLevel",
"vocab": {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3},
"unk_token": "[UNK]"
}
}"#;
    #[test]
    fn test_tokenizer_creation() {
        // from_json should accept the minimal config above without error.
        let tokenizer = WasmTokenizer::from_json(TEST_TOKENIZER, 128);
        assert!(tokenizer.is_ok());
    }
}
}