Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
vendor/ruvector/examples/onnx-embeddings-wasm/src/embedder.rs (new vendored file, 213 lines)
@@ -0,0 +1,213 @@
//! Main WASM embedder implementation

use crate::error::{Result, WasmEmbeddingError};
use crate::model::TractModel;
use crate::pooling::{cosine_similarity, normalize_l2, PoolingStrategy};
use crate::tokenizer::WasmTokenizer;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;

/// Configuration for the WASM embedder
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmEmbedderConfig {
    /// Maximum sequence length
    #[wasm_bindgen(skip)]
    pub max_length: usize,
    /// Pooling strategy
    #[wasm_bindgen(skip)]
    pub pooling: PoolingStrategy,
    /// Whether to L2 normalize embeddings
    #[wasm_bindgen(skip)]
    pub normalize: bool,
}

#[wasm_bindgen]
impl WasmEmbedderConfig {
    /// Create a new configuration
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum sequence length
    #[wasm_bindgen(js_name = setMaxLength)]
    pub fn set_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }

    /// Set whether to normalize embeddings
    #[wasm_bindgen(js_name = setNormalize)]
    pub fn set_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Set pooling strategy (0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen, 4=LastToken)
    #[wasm_bindgen(js_name = setPooling)]
    pub fn set_pooling(mut self, pooling: u8) -> Self {
        self.pooling = match pooling {
            0 => PoolingStrategy::Mean,
            1 => PoolingStrategy::Cls,
            2 => PoolingStrategy::Max,
            3 => PoolingStrategy::MeanSqrtLen,
            4 => PoolingStrategy::LastToken,
            _ => PoolingStrategy::Mean,
        };
        self
    }
}

impl Default for WasmEmbedderConfig {
    fn default() -> Self {
        Self {
            max_length: 256,
            pooling: PoolingStrategy::Mean,
            normalize: true,
        }
    }
}

/// WASM-compatible embedder using Tract for inference
#[wasm_bindgen]
pub struct WasmEmbedder {
    model: TractModel,
    tokenizer: WasmTokenizer,
    config: WasmEmbedderConfig,
    hidden_size: usize,
}

#[wasm_bindgen]
impl WasmEmbedder {
    /// Create a new embedder from model and tokenizer bytes
    ///
    /// # Arguments
    /// * `model_bytes` - ONNX model file bytes
    /// * `tokenizer_json` - Tokenizer JSON configuration
    #[wasm_bindgen(constructor)]
    pub fn new(model_bytes: &[u8], tokenizer_json: &str) -> std::result::Result<WasmEmbedder, JsValue> {
        Self::with_config(model_bytes, tokenizer_json, WasmEmbedderConfig::default())
    }

    /// Create embedder with custom configuration
    #[wasm_bindgen(js_name = withConfig)]
    pub fn with_config(
        model_bytes: &[u8],
        tokenizer_json: &str,
        config: WasmEmbedderConfig,
    ) -> std::result::Result<WasmEmbedder, JsValue> {
        let model = TractModel::from_bytes(model_bytes, config.max_length)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let tokenizer = WasmTokenizer::from_json(tokenizer_json, config.max_length)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let hidden_size = model.hidden_size();

        Ok(Self {
            model,
            tokenizer,
            config,
            hidden_size,
        })
    }

    /// Generate embedding for a single text
    #[wasm_bindgen(js_name = embedOne)]
    pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, JsValue> {
        self.embed_one_internal(text)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Generate embeddings for multiple texts
    #[wasm_bindgen(js_name = embedBatch)]
    pub fn embed_batch(&mut self, texts: Vec<String>) -> std::result::Result<Vec<f32>, JsValue> {
        let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
        self.embed_batch_internal(&refs)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Compute similarity between two texts
    #[wasm_bindgen]
    pub fn similarity(&mut self, text1: &str, text2: &str) -> std::result::Result<f32, JsValue> {
        let emb1 = self.embed_one_internal(text1)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let emb2 = self.embed_one_internal(text2)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(cosine_similarity(&emb1, &emb2))
    }

    /// Get the embedding dimension
    #[wasm_bindgen]
    pub fn dimension(&self) -> usize {
        self.hidden_size
    }

    /// Get maximum sequence length
    #[wasm_bindgen(js_name = maxLength)]
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }
}

// Internal implementation
impl WasmEmbedder {
    fn embed_one_internal(&mut self, text: &str) -> Result<Vec<f32>> {
        // Tokenize
        let encoded = self.tokenizer.encode(text)?;
        let attention_mask = encoded.attention_mask.clone();

        // Run inference
        let raw_output = self.model.run(&encoded)?;

        // Determine hidden size from output
        let seq_len = self.config.max_length;
        if raw_output.len() >= seq_len {
            let detected_hidden = raw_output.len() / seq_len;
            if detected_hidden != self.hidden_size && detected_hidden > 0 {
                self.hidden_size = detected_hidden;
                self.model.set_hidden_size(detected_hidden);
            }
        }

        // Apply pooling
        let mut embedding = self.config.pooling.apply(
            &raw_output,
            &attention_mask,
            self.hidden_size,
        );

        // Normalize if configured
        if self.config.normalize {
            normalize_l2(&mut embedding);
        }

        Ok(embedding)
    }

    fn embed_batch_internal(&mut self, texts: &[&str]) -> Result<Vec<f32>> {
        let mut all_embeddings = Vec::with_capacity(texts.len() * self.hidden_size);

        for text in texts {
            let embedding = self.embed_one_internal(text)?;
            all_embeddings.extend(embedding);
        }

        Ok(all_embeddings)
    }
}

/// Compute cosine similarity between two embedding vectors (JS-friendly)
#[wasm_bindgen(js_name = cosineSimilarity)]
pub fn js_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> f32 {
    cosine_similarity(&a, &b)
}

/// L2 normalize an embedding vector (JS-friendly)
#[wasm_bindgen(js_name = normalizeL2)]
pub fn js_normalize_l2(mut embedding: Vec<f32>) -> Vec<f32> {
    normalize_l2(&mut embedding);
    embedding
}
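Note on the batch API: embedBatch returns a single flat Float32Array of length texts.length * dimension(); callers slice it per text. A minimal sketch of that reshaping (illustrative only, not part of the vendored file):

// Split the flat buffer returned by embed_batch into per-text slices.
// `dim` is the value reported by WasmEmbedder::dimension().
fn split_batch(flat: &[f32], dim: usize) -> Vec<&[f32]> {
    flat.chunks(dim).collect()
}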
vendor/ruvector/examples/onnx-embeddings-wasm/src/error.rs (new vendored file, 62 lines)
@@ -0,0 +1,62 @@
//! Error types for WASM embeddings

use thiserror::Error;
use wasm_bindgen::prelude::*;

/// Result type for WASM embedding operations
pub type Result<T> = std::result::Result<T, WasmEmbeddingError>;

/// Errors that can occur during WASM embedding operations
#[derive(Error, Debug)]
pub enum WasmEmbeddingError {
    #[error("Model error: {0}")]
    Model(String),

    #[error("Tokenizer error: {0}")]
    Tokenizer(String),

    #[error("Inference error: {0}")]
    Inference(String),

    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Serialization error: {0}")]
    Serialization(String),
}

impl WasmEmbeddingError {
    pub fn model(msg: impl Into<String>) -> Self {
        Self::Model(msg.into())
    }

    pub fn tokenizer(msg: impl Into<String>) -> Self {
        Self::Tokenizer(msg.into())
    }

    pub fn inference(msg: impl Into<String>) -> Self {
        Self::Inference(msg.into())
    }

    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput(msg.into())
    }
}

impl From<WasmEmbeddingError> for JsValue {
    fn from(err: WasmEmbeddingError) -> Self {
        JsValue::from_str(&err.to_string())
    }
}

impl From<tract_onnx::prelude::TractError> for WasmEmbeddingError {
    fn from(err: tract_onnx::prelude::TractError) -> Self {
        Self::Model(err.to_string())
    }
}

impl From<serde_json::Error> for WasmEmbeddingError {
    fn from(err: serde_json::Error) -> Self {
        Self::Serialization(err.to_string())
    }
}
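The From impls above are what keep the rest of the crate terse: tract and serde_json errors convert into WasmEmbeddingError via `?`, and WasmEmbeddingError in turn converts into JsValue at the WASM boundary. A minimal sketch under that assumption (hypothetical helper, not part of the vendored file):

// Hypothetical helper: serde_json::Error flows into
// WasmEmbeddingError::Serialization through the From impl above.
fn parse_hidden_size(json: &str) -> Result<u64> {
    let value: serde_json::Value = serde_json::from_str(json)?;
    Ok(value["hidden_size"].as_u64().unwrap_or(0))
}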
vendor/ruvector/examples/onnx-embeddings-wasm/src/lib.rs (new vendored file, 66 lines)
@@ -0,0 +1,66 @@
//! # RuVector ONNX Embeddings - WASM Edition
//!
//! WASM-compatible embedding generation using Tract for inference.
//! Runs in browsers, Cloudflare Workers, Deno, and any WASM runtime.
//!
//! ## Features
//!
//! - **Browser Support**: Generate embeddings directly in the browser
//! - **Edge Computing**: Deploy to Cloudflare Workers, Vercel Edge, etc.
//! - **Portable**: Single WASM binary, no platform-specific dependencies
//! - **Same API**: Compatible with the native ruvector-onnx-embeddings crate
//!
//! ## Usage (JavaScript)
//!
//! ```javascript
//! import init, { WasmEmbedder } from 'ruvector-onnx-embeddings-wasm';
//!
//! await init();
//!
//! // Load model from bytes
//! const modelBytes = await fetch('/model.onnx').then(r => r.arrayBuffer());
//! const tokenizerJson = await fetch('/tokenizer.json').then(r => r.text());
//!
//! const embedder = new WasmEmbedder(new Uint8Array(modelBytes), tokenizerJson);
//!
//! // Generate embeddings (exported to JS as embedOne)
//! const embedding = embedder.embedOne("Hello, world!");
//! console.log("Embedding dimension:", embedding.length);
//!
//! // Compute similarity
//! const similarity = embedder.similarity("I love Rust", "Rust is great");
//! console.log("Similarity:", similarity);
//! ```

mod embedder;
mod error;
mod model;
mod pooling;
mod tokenizer;

pub use embedder::{WasmEmbedder, WasmEmbedderConfig};
pub use error::WasmEmbeddingError;
pub use pooling::PoolingStrategy;

use wasm_bindgen::prelude::*;

/// Initialize panic hook for better error messages in WASM
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}

/// Get the library version
#[wasm_bindgen]
pub fn version() -> String {
    env!("CARGO_PKG_VERSION").to_string()
}

/// Check if SIMD is available (for performance info)
/// Returns true if compiled with WASM SIMD128 support
#[wasm_bindgen]
pub fn simd_available() -> bool {
    // Check if compiled with SIMD128 target feature
    cfg!(target_feature = "simd128")
}
vendor/ruvector/examples/onnx-embeddings-wasm/src/model.rs (new vendored file, 116 lines)
@@ -0,0 +1,116 @@
//! Tract-based ONNX model for WASM inference

use crate::error::{Result, WasmEmbeddingError};
use crate::tokenizer::EncodedInput;
use tract_onnx::prelude::*;

/// Tract ONNX model wrapper for WASM
pub struct TractModel {
    model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    hidden_size: usize,
}

impl TractModel {
    /// Load model from ONNX bytes
    pub fn from_bytes(bytes: &[u8], max_seq_length: usize) -> Result<Self> {
        // Parse ONNX model
        let model = tract_onnx::onnx()
            .model_for_read(&mut std::io::Cursor::new(bytes))
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to parse ONNX: {}", e)))?;

        // Set input shapes for optimization
        // Standard transformer inputs: [batch, seq_len]
        let batch = 1usize;
        let seq_len = max_seq_length;

        let model = model
            .with_input_fact(
                0,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?
            .with_input_fact(
                1,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?
            .with_input_fact(
                2,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?;

        // Optimize the model
        let model = model
            .into_optimized()
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to optimize: {}", e)))?;

        let model = model
            .into_runnable()
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to make runnable: {}", e)))?;

        // Default hidden size (will be determined from output)
        let hidden_size = 384;

        Ok(Self { model, hidden_size })
    }

    /// Run inference on encoded input
    pub fn run(&self, input: &EncodedInput) -> Result<Vec<f32>> {
        let seq_len = input.input_ids.len();

        // Create input tensors
        let input_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.input_ids.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();

        let attention_mask: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.attention_mask.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();

        let token_type_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.token_type_ids.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();

        // Run inference
        let inputs = tvec![
            input_ids.into(),
            attention_mask.into(),
            token_type_ids.into()
        ];

        let outputs = self
            .model
            .run(inputs)
            .map_err(|e| WasmEmbeddingError::inference(format!("Inference failed: {}", e)))?;

        // Extract output tensor
        // Output is typically [batch, seq_len, hidden_size] or [batch, hidden_size]
        let output = outputs
            .first()
            .ok_or_else(|| WasmEmbeddingError::inference("No output tensor"))?;

        let output_array = output
            .to_array_view::<f32>()
            .map_err(|e| WasmEmbeddingError::inference(format!("Failed to extract output: {}", e)))?;

        // Flatten and return
        Ok(output_array.iter().copied().collect())
    }

    /// Get the hidden size
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Update hidden size (called after first inference)
    pub fn set_hidden_size(&mut self, size: usize) {
        self.hidden_size = size;
    }
}
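run() returns the output tensor flattened; for the typical [1, seq_len, hidden_size] layout, token i occupies flat[i * hidden_size .. (i + 1) * hidden_size]. That indexing, as a small illustrative helper (not part of the vendored file):

// Token i's embedding within the flat buffer from TractModel::run,
// assuming a [1, seq_len, hidden_size] output layout.
fn token_slice(flat: &[f32], i: usize, hidden_size: usize) -> &[f32] {
    &flat[i * hidden_size..(i + 1) * hidden_size]
}

The same layout assumption is what lets embed_one_internal recover hidden_size as raw_output.len() / seq_len.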
vendor/ruvector/examples/onnx-embeddings-wasm/src/pooling.rs (new vendored file, 181 lines)
@@ -0,0 +1,181 @@
//! Pooling strategies for converting token embeddings to sentence embeddings

use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;

/// Strategy for pooling token embeddings into a single sentence embedding
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub enum PoolingStrategy {
    /// Average all token embeddings (most common)
    #[default]
    Mean,
    /// Use only the [CLS] token embedding
    Cls,
    /// Take the maximum value across all tokens for each dimension
    Max,
    /// Mean pooling normalized by sqrt of sequence length
    MeanSqrtLen,
    /// Use the last token embedding (for decoder models)
    LastToken,
}

impl PoolingStrategy {
    /// Apply pooling to token embeddings
    ///
    /// # Arguments
    /// * `embeddings` - Token embeddings [seq_len, hidden_size]
    /// * `attention_mask` - Attention mask [seq_len]
    /// * `hidden_size` - Embedding dimension
    ///
    /// # Returns
    /// Pooled embedding [hidden_size]
    pub fn apply(&self, embeddings: &[f32], attention_mask: &[i64], hidden_size: usize) -> Vec<f32> {
        let seq_len = attention_mask.len();

        if embeddings.is_empty() || hidden_size == 0 {
            return vec![0.0; hidden_size];
        }

        match self {
            PoolingStrategy::Mean => {
                self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len)
            }
            PoolingStrategy::Cls => {
                // First token (CLS)
                embeddings[..hidden_size].to_vec()
            }
            PoolingStrategy::Max => {
                self.max_pooling(embeddings, attention_mask, hidden_size, seq_len)
            }
            PoolingStrategy::MeanSqrtLen => {
                let mut pooled = self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len);
                let valid_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
                // Guard against an all-zero mask (avoids dividing by zero)
                if valid_tokens > 0.0 {
                    let scale = 1.0 / valid_tokens.sqrt();
                    for v in &mut pooled {
                        *v *= scale;
                    }
                }
                pooled
            }
            PoolingStrategy::LastToken => {
                // Find last valid token
                let last_idx = attention_mask
                    .iter()
                    .rposition(|&m| m == 1)
                    .unwrap_or(0);
                let start = last_idx * hidden_size;
                embeddings[start..start + hidden_size].to_vec()
            }
        }
    }

    fn mean_pooling(
        &self,
        embeddings: &[f32],
        attention_mask: &[i64],
        hidden_size: usize,
        seq_len: usize,
    ) -> Vec<f32> {
        let mut pooled = vec![0.0f32; hidden_size];
        let mut count = 0.0f32;

        for (i, &mask) in attention_mask.iter().enumerate() {
            if mask == 1 && i < seq_len {
                let start = i * hidden_size;
                if start + hidden_size <= embeddings.len() {
                    for (j, v) in pooled.iter_mut().enumerate() {
                        *v += embeddings[start + j];
                    }
                    count += 1.0;
                }
            }
        }

        if count > 0.0 {
            for v in &mut pooled {
                *v /= count;
            }
        }

        pooled
    }

    fn max_pooling(
        &self,
        embeddings: &[f32],
        attention_mask: &[i64],
        hidden_size: usize,
        seq_len: usize,
    ) -> Vec<f32> {
        let mut pooled = vec![f32::NEG_INFINITY; hidden_size];

        for (i, &mask) in attention_mask.iter().enumerate() {
            if mask == 1 && i < seq_len {
                let start = i * hidden_size;
                if start + hidden_size <= embeddings.len() {
                    for (j, v) in pooled.iter_mut().enumerate() {
                        *v = v.max(embeddings[start + j]);
                    }
                }
            }
        }

        // Replace -inf with 0 for dimensions with no valid tokens
        for v in &mut pooled {
            if v.is_infinite() {
                *v = 0.0;
            }
        }

        pooled
    }
}

/// L2 normalize a vector in place
pub fn normalize_l2(embedding: &mut [f32]) {
    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for v in embedding {
            *v /= norm;
        }
    }
}

/// Compute cosine similarity between two embeddings
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2() {
        let mut v = vec![3.0, 4.0];
        normalize_l2(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }
}
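Worked example of the masked mean: with seq_len = 2, hidden_size = 2, and the second token masked out, only token 0 contributes, so the pooled vector equals that token's embedding. As an illustrative extra test (not part of the vendored file):

#[test]
fn mean_pooling_ignores_masked_tokens() {
    // Flattened [seq_len = 2, hidden_size = 2]; the second row is padding.
    let embeddings = [1.0f32, 2.0, 10.0, 20.0];
    let mask = [1i64, 0];
    let pooled = PoolingStrategy::Mean.apply(&embeddings, &mask, 2);
    assert_eq!(pooled, vec![1.0, 2.0]);
}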
vendor/ruvector/examples/onnx-embeddings-wasm/src/tokenizer.rs (new vendored file, 114 lines)
@@ -0,0 +1,114 @@
//! Tokenizer wrapper for WASM embedding generation

use crate::error::{Result, WasmEmbeddingError};
use tokenizers::Tokenizer;

/// Tokenizer wrapper that handles text encoding
pub struct WasmTokenizer {
    tokenizer: Tokenizer,
    max_length: usize,
}

/// Encoded text ready for model inference
#[derive(Debug, Clone)]
pub struct EncodedInput {
    pub input_ids: Vec<i64>,
    pub attention_mask: Vec<i64>,
    pub token_type_ids: Vec<i64>,
}

impl WasmTokenizer {
    /// Create a new tokenizer from JSON configuration
    pub fn from_json(json: &str, max_length: usize) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(json.as_bytes())
            .map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;

        Ok(Self {
            tokenizer,
            max_length,
        })
    }

    /// Create tokenizer from raw bytes
    pub fn from_bytes(bytes: &[u8], max_length: usize) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes)
            .map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;

        Ok(Self {
            tokenizer,
            max_length,
        })
    }

    /// Encode a single text
    pub fn encode(&self, text: &str) -> Result<EncodedInput> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;

        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
        let mut attention_mask: Vec<i64> =
            encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
        let mut token_type_ids: Vec<i64> =
            encoding.get_type_ids().iter().map(|&t| t as i64).collect();

        // Truncate if necessary
        if input_ids.len() > self.max_length {
            input_ids.truncate(self.max_length);
            attention_mask.truncate(self.max_length);
            token_type_ids.truncate(self.max_length);
        }

        // Pad if necessary
        while input_ids.len() < self.max_length {
            input_ids.push(0);
            attention_mask.push(0);
            token_type_ids.push(0);
        }

        Ok(EncodedInput {
            input_ids,
            attention_mask,
            token_type_ids,
        })
    }

    /// Encode multiple texts with padding to the same length
    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<EncodedInput>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }

    /// Get the maximum sequence length
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Basic tokenizer JSON for testing
    const TEST_TOKENIZER: &str = r#"{
        "version": "1.0",
        "truncation": null,
        "padding": null,
        "added_tokens": [],
        "normalizer": null,
        "pre_tokenizer": {"type": "Whitespace"},
        "post_processor": null,
        "decoder": null,
        "model": {
            "type": "WordLevel",
            "vocab": {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3},
            "unk_token": "[UNK]"
        }
    }"#;

    #[test]
    fn test_tokenizer_creation() {
        let tokenizer = WasmTokenizer::from_json(TEST_TOKENIZER, 128);
        assert!(tokenizer.is_ok());
    }
}
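With the WordLevel test tokenizer above, encode() shows the fixed-length contract: "hello world" maps to ids [2, 3], then everything is right-padded with zeros up to max_length. An illustrative extra test under that assumption (not part of the vendored file):

#[test]
fn encode_pads_to_max_length() {
    let tok = WasmTokenizer::from_json(TEST_TOKENIZER, 4).unwrap();
    let enc = tok.encode("hello world").unwrap();
    assert_eq!(enc.input_ids, vec![2, 3, 0, 0]);
    assert_eq!(enc.attention_mask, vec![1, 1, 0, 0]);
}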