Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

vendor/ruvector/examples/onnx-embeddings-wasm/Cargo.lock (generated, vendored, 1983 lines)

File diff suppressed because it is too large


@@ -0,0 +1,61 @@
[package]
name = "ruvector-onnx-embeddings-wasm"
version = "0.1.2"
edition = "2021"
authors = ["RuVector Team"]
description = "WASM embedding generation with SIMD - runs in browsers, Cloudflare Workers, Deno, and edge runtimes"
license = "MIT"
repository = "https://github.com/ruvnet/ruvector"
keywords = ["onnx", "embeddings", "wasm", "webassembly", "ml"]
categories = ["wasm", "science", "algorithms"]
# Standalone package
[workspace]
[lib]
crate-type = ["cdylib", "rlib"]
[dependencies]
# Tract - ONNX inference that compiles to WASM
tract-onnx = "0.21"
tract-core = "0.21"
# Tokenization - HuggingFace tokenizers (WASM compatible)
tokenizers = { version = "0.20", default-features = false, features = ["unstable_wasm"] }
# WASM bindings
wasm-bindgen = "0.2"
wasm-bindgen-futures = "0.4"
js-sys = "0.3"
web-sys = { version = "0.3", features = ["console"] }
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde-wasm-bindgen = "0.6"
# Error handling
thiserror = "2.0"
anyhow = "1.0"
# Async (WASM compatible)
futures = "0.3"
# Console logging for WASM
console_error_panic_hook = { version = "0.1", optional = true }
# Getrandom for WASM
getrandom = { version = "0.2", features = ["js"] }
[dev-dependencies]
wasm-bindgen-test = "0.3"
[features]
default = ["console_error_panic_hook"]
[profile.release]
opt-level = "s"
lto = true
[package.metadata.wasm-pack.profile.release]
wasm-opt = ["-Os", "--enable-mutable-globals"]


@@ -0,0 +1,435 @@
# RuVector ONNX Embeddings WASM
[![npm version](https://img.shields.io/npm/v/ruvector-onnx-embeddings-wasm.svg)](https://www.npmjs.com/package/ruvector-onnx-embeddings-wasm)
[![crates.io](https://img.shields.io/crates/v/ruvector-onnx-embeddings-wasm.svg)](https://crates.io/crates/ruvector-onnx-embeddings-wasm)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![WebAssembly](https://img.shields.io/badge/WebAssembly-654FF0?logo=webassembly&logoColor=white)](https://webassembly.org/)
[![SIMD](https://img.shields.io/badge/SIMD-128bit-green)](https://webassembly.org/roadmap/)
> **Portable embedding generation with SIMD acceleration and parallel workers**
Generate text embeddings directly in browsers, Cloudflare Workers, Deno, Node.js, and any WASM runtime. Built with [Tract](https://github.com/sonos/tract) for pure Rust ONNX inference.
## Features
| Feature | Description |
|---------|-------------|
| 🌐 **Browser Support** | Generate embeddings client-side, no server needed |
| ⚡ **SIMD Acceleration** | WASM SIMD128 for vectorized operations |
| 🚀 **Parallel Workers** | Multi-threaded batch processing (3.8x speedup) |
| 🏢 **Edge Computing** | Deploy to Cloudflare Workers, Vercel Edge, Deno Deploy |
| 📦 **Zero Dependencies** | Single WASM binary, no native modules |
| 🤗 **HuggingFace Models** | Pre-configured URLs for popular models |
| 🔄 **Auto Caching** | Browser Cache API for instant reloads |
| 🎯 **Same API** | Compatible with native `ruvector-onnx-embeddings` |
## Installation
```bash
npm install ruvector-onnx-embeddings-wasm
```
## Quick Start
### Node.js (Sequential)
```javascript
import { createEmbedder, similarity, embed } from 'ruvector-onnx-embeddings-wasm/loader';
// One-liner similarity
const score = await similarity("I love dogs", "I adore puppies");
console.log(score); // ~0.85
// One-liner embedding
const embedding = await embed("Hello world");
console.log(embedding.length); // 384
// Full control
const embedder = await createEmbedder('bge-small-en-v1.5');
const emb1 = embedder.embedOne("First text");
const emb2 = embedder.embedOne("Second text");
```
### Node.js (Parallel - 3.8x faster)
```javascript
import { ParallelEmbedder } from 'ruvector-onnx-embeddings-wasm/parallel';
// Initialize with worker threads
const embedder = new ParallelEmbedder({ numWorkers: 4 });
await embedder.init('all-MiniLM-L6-v2');
// Batch embed with parallel processing
const texts = [
"Machine learning is transforming technology",
"Deep learning uses neural networks",
"Natural language processing understands text",
"Computer vision analyzes images"
];
const embeddings = await embedder.embedBatch(texts);
// Compute similarity
const sim = await embedder.similarity("I love Rust", "Rust is great");
console.log(sim); // ~0.85
// Cleanup
await embedder.shutdown();
```
### Browser (ES Modules)
```html
<script type="module">
import init, { WasmEmbedder } from 'https://unpkg.com/ruvector-onnx-embeddings-wasm/ruvector_onnx_embeddings_wasm.js';
import { createEmbedder } from 'https://unpkg.com/ruvector-onnx-embeddings-wasm/loader.js';
// Initialize WASM
await init();
// Create embedder (downloads model automatically)
const embedder = await createEmbedder('all-MiniLM-L6-v2');
// Generate embeddings
const embedding = embedder.embedOne("Hello, world!");
console.log("Dimension:", embedding.length); // 384
// Compute similarity
const sim = embedder.similarity("I love Rust", "Rust is great");
console.log("Similarity:", sim.toFixed(4)); // ~0.85
</script>
```
### Cloudflare Workers
```javascript
import { WasmEmbedder, WasmEmbedderConfig } from 'ruvector-onnx-embeddings-wasm';
export default {
async fetch(request, env) {
// Load model and tokenizer from a KV namespace binding
const modelBytes = await env.MODELS.get('model.onnx', 'arrayBuffer');
const tokenizerJson = await env.MODELS.get('tokenizer.json', 'text');
const embedder = new WasmEmbedder(
new Uint8Array(modelBytes),
tokenizerJson
);
const { text } = await request.json();
const embedding = embedder.embedOne(text);
return Response.json({
embedding: Array.from(embedding),
dimension: embedding.length
});
}
};
```
## Available Models
| Model | Dimension | Size | Speed | Quality | Best For |
|-------|-----------|------|-------|---------|----------|
| **all-MiniLM-L6-v2** ⭐ | 384 | 23MB | ⚡⚡⚡ | ⭐⭐⭐ | Default, fast |
| **all-MiniLM-L12-v2** | 384 | 33MB | ⚡⚡ | ⭐⭐⭐⭐ | Better quality |
| **bge-small-en-v1.5** | 384 | 33MB | ⚡⚡⚡ | ⭐⭐⭐⭐ | State-of-the-art |
| **bge-base-en-v1.5** | 768 | 110MB | ⚡ | ⭐⭐⭐⭐⭐ | Best quality |
| **e5-small-v2** | 384 | 33MB | ⚡⚡⚡ | ⭐⭐⭐⭐ | Search/retrieval |
| **gte-small** | 384 | 33MB | ⚡⚡⚡ | ⭐⭐⭐⭐ | Multilingual |
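The same catalog is available programmatically via `ModelLoader.listModels()` (see the API Reference below):
```javascript
import { ModelLoader } from 'ruvector-onnx-embeddings-wasm/loader';
for (const m of ModelLoader.listModels()) {
  console.log(`${m.id}: ${m.dimension}d, ${m.size} - ${m.description}`);
}
```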
## Performance
### Sequential vs Parallel (Node.js)
| Batch Size | Sequential | Parallel (4 workers) | Speedup |
|------------|------------|----------------------|---------|
| 4 texts | 1,573ms | 410ms | **3.83x** |
| 8 texts | 3,105ms | 861ms | **3.61x** |
| 12 texts | 4,667ms | 1,235ms | **3.78x** |
*Tested on 16-core machine with all-MiniLM-L6-v2*
### Environment Benchmarks
| Environment | Mode | Throughput | Latency |
|-------------|------|------------|---------|
| Node.js 20 | Sequential | ~2.5 texts/sec | ~390ms |
| Node.js 20 | Parallel (4w) | ~9.7 texts/sec | ~103ms |
| Chrome (M1 Mac) | Sequential | ~50 texts/sec | ~20ms |
| Firefox (M1 Mac) | Sequential | ~45 texts/sec | ~22ms |
| Cloudflare Workers | Sequential | ~30 texts/sec | ~33ms |
| Deno | Sequential | ~75 texts/sec | ~13ms |
*Browser benchmarks with smaller inputs; Node.js with full model warmup*
### SIMD Support
WASM SIMD128 is enabled by default and provides:
- Vectorized tensor operations
- A 180KB smaller binary (shipped alongside SIMD in v0.1.1)
- Support in Chrome 91+, Firefox 89+, Safari 16.4+, and Node.js 16+
```javascript
import { simd_available } from 'ruvector-onnx-embeddings-wasm';
console.log('SIMD enabled:', simd_available()); // true
```
## API Reference
### ModelLoader
```javascript
import { ModelLoader, MODELS, DEFAULT_MODEL } from 'ruvector-onnx-embeddings-wasm/loader';
// List available models
console.log(ModelLoader.listModels());
// Load with progress
const loader = new ModelLoader({
cache: true,
onProgress: ({ loaded, total, percent }) => console.log(`${percent}%`)
});
const { modelBytes, tokenizerJson, config } = await loader.loadModel('all-MiniLM-L6-v2');
```
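Cached files can be removed with `clearCache()`, which deletes the configured Cache API storage:
```javascript
await loader.clearCache(); // No-op outside browsers (Cache API unavailable)
```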
### WasmEmbedder
```typescript
class WasmEmbedder {
constructor(modelBytes: Uint8Array, tokenizerJson: string);
static withConfig(
modelBytes: Uint8Array,
tokenizerJson: string,
config: WasmEmbedderConfig
): WasmEmbedder;
embedOne(text: string): Float32Array;
embedBatch(texts: string[]): Float32Array;
similarity(text1: string, text2: string): number;
dimension(): number;
maxLength(): number;
}
```
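In Node.js you can also construct an embedder from local files via `ModelLoader.loadFromFiles`; a minimal sketch (the paths are placeholders):
```javascript
import { WasmEmbedder } from 'ruvector-onnx-embeddings-wasm';
import { ModelLoader } from 'ruvector-onnx-embeddings-wasm/loader';
const { modelBytes, tokenizerJson } = await new ModelLoader()
  .loadFromFiles('./model.onnx', './tokenizer.json'); // placeholder paths
const embedder = new WasmEmbedder(modelBytes, tokenizerJson);
```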
### WasmEmbedderConfig
```typescript
class WasmEmbedderConfig {
constructor();
setMaxLength(length: number): WasmEmbedderConfig;
setNormalize(normalize: boolean): WasmEmbedderConfig;
setPooling(strategy: number): WasmEmbedderConfig;
// 0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen, 4=LastToken
}
```
### ParallelEmbedder (Node.js only)
```typescript
class ParallelEmbedder {
constructor(options?: { numWorkers?: number });
init(modelName?: string): Promise<void>;
embedOne(text: string): Promise<Float32Array>;
embedBatch(texts: string[]): Promise<number[][]>;
similarity(text1: string, text2: string): Promise<number>;
shutdown(): Promise<void>;
}
```
### Utility Functions
```typescript
function cosineSimilarity(a: Float32Array, b: Float32Array): number;
function normalizeL2(embedding: Float32Array): Float32Array;
function version(): string;
function simd_available(): boolean;
```
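For example, mirroring the package's validation test:
```javascript
import { cosineSimilarity, normalizeL2, version, simd_available } from 'ruvector-onnx-embeddings-wasm';
console.log(version(), 'SIMD:', simd_available());
const v = normalizeL2(new Float32Array([3.0, 4.0])); // [0.6, 0.8]
console.log(cosineSimilarity(v, new Float32Array([1.0, 0.0]))); // 0.6
```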
### Convenience Functions
```typescript
// One-liner embedding
async function embed(text: string | string[], modelName?: string): Promise<Float32Array>;
// One-liner similarity
async function similarity(text1: string, text2: string, modelName?: string): Promise<number>;
// Create configured embedder
async function createEmbedder(modelName?: string): Promise<WasmEmbedder>;
```
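Note that `embed` with an array resolves to a single flattened `Float32Array` (number of texts × dimension values), matching `WasmEmbedder.embedBatch`:
```javascript
import { embed } from 'ruvector-onnx-embeddings-wasm/loader';
const flat = await embed(["Hello", "World"]); // 2 × 384 values for the default model
const dim = 384;
const [first, second] = [flat.slice(0, dim), flat.slice(dim, 2 * dim)];
```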
## Pooling Strategies
| Value | Strategy | Description |
|-------|----------|-------------|
| 0 | **Mean** | Average all tokens (default, recommended) |
| 1 | **Cls** | Use [CLS] token only (BERT-style) |
| 2 | **Max** | Max pooling across tokens |
| 3 | **MeanSqrtLen** | Mean normalized by sqrt(length) |
| 4 | **LastToken** | Last token (decoder models) |
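To use a non-default strategy, build the embedder with an explicit config; a sketch selecting CLS pooling:
```javascript
import { WasmEmbedder, WasmEmbedderConfig } from 'ruvector-onnx-embeddings-wasm';
import { ModelLoader } from 'ruvector-onnx-embeddings-wasm/loader';
const { modelBytes, tokenizerJson, config } = await new ModelLoader().loadModel('bge-small-en-v1.5');
const clsConfig = new WasmEmbedderConfig()
  .setMaxLength(config.maxLength)
  .setNormalize(true)
  .setPooling(1); // 1 = Cls
const embedder = WasmEmbedder.withConfig(modelBytes, tokenizerJson, clsConfig);
```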
## Comparison: Native vs WASM
| Aspect | Native (`ort`) | WASM (`tract`) |
|--------|----------------|----------------|
| Speed | ⚡⚡⚡ Native | ⚡⚡ ~2-3x slower |
| Browser | ❌ | ✅ |
| Edge Workers | ❌ | ✅ |
| Parallel | Multi-process | Worker threads |
| GPU | CUDA, TensorRT | ❌ |
| Bundle Size | ~50MB | ~7.4MB |
| SIMD | AVX2/AVX-512 | SIMD128 |
| Portability | Platform-specific | Universal |
**Use native** for: servers, high throughput, GPU acceleration
**Use WASM** for: browsers, edge, portability, simpler deployment
## Building from Source
```bash
# Install wasm-pack
cargo install wasm-pack
# Build for Node.js with SIMD
RUSTFLAGS='-C target-feature=+simd128' wasm-pack build --target nodejs --release
# Build for web with SIMD
RUSTFLAGS='-C target-feature=+simd128' wasm-pack build --target web --release
# Build for bundlers (webpack, vite) with SIMD
RUSTFLAGS='-C target-feature=+simd128' wasm-pack build --target bundler --release
# Build without SIMD (for older browsers)
wasm-pack build --target web --release
```
## Use Cases
### Semantic Search
```javascript
import { createEmbedder, cosineSimilarity } from 'ruvector-onnx-embeddings-wasm/loader';
const embedder = await createEmbedder();
// Index documents
const docs = ["Rust is fast", "Python is easy", "JavaScript runs everywhere"];
const embeddings = docs.map(d => embedder.embedOne(d));
// Search
const query = embedder.embedOne("Which language is performant?");
const scores = embeddings.map((e, i) => ({
doc: docs[i],
score: cosineSimilarity(query, e)
}));
scores.sort((a, b) => b.score - a.score);
console.log(scores[0]); // { doc: "Rust is fast", score: 0.82 }
```
### Batch Processing with Parallel Workers
```javascript
import { ParallelEmbedder } from 'ruvector-onnx-embeddings-wasm/parallel';
const embedder = new ParallelEmbedder({ numWorkers: 4 });
await embedder.init();
// Process large datasets efficiently
const documents = loadDocuments(); // Array of 1000+ texts
const batchSize = 100;
for (let i = 0; i < documents.length; i += batchSize) {
const batch = documents.slice(i, i + batchSize);
const embeddings = await embedder.embedBatch(batch);
await saveEmbeddings(embeddings);
}
await embedder.shutdown();
```
### RAG (Retrieval-Augmented Generation)
```javascript
// Assumes `embedder` from createEmbedder() and cosineSimilarity from the loader
// Build knowledge base
const knowledge = [
"RuVector is a vector database",
"Embeddings capture semantic meaning",
// ... more docs
];
const knowledgeEmbeddings = knowledge.map(k => embedder.embedOne(k));
// Retrieve relevant context for LLM
function getContext(query, topK = 3) {
const queryEmb = embedder.embedOne(query);
const scores = knowledgeEmbeddings.map((e, i) => ({
text: knowledge[i],
score: cosineSimilarity(queryEmb, e)
}));
return scores.sort((a, b) => b.score - a.score).slice(0, topK);
}
```
### Text Clustering
```javascript
const texts = [
"Machine learning is amazing",
"Deep learning uses neural networks",
"I love pizza",
"Italian food is delicious"
];
const embeddings = texts.map(t => embedder.embedOne(t));
// Use k-means or hierarchical clustering on embeddings
```
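A minimal k-means sketch over those embeddings (`kMeans` is illustrative, not part of this package):
```javascript
function kMeans(vectors, k, iterations = 20) {
  const dim = vectors[0].length;
  const dist2 = (a, b) => {
    let d = 0;
    for (let i = 0; i < dim; i++) d += (a[i] - b[i]) ** 2;
    return d;
  };
  // Seed centroids with the first k vectors (fine for a demo)
  let centroids = vectors.slice(0, k).map(v => Float32Array.from(v));
  let assignments = new Array(vectors.length).fill(0);
  for (let iter = 0; iter < iterations; iter++) {
    // Assignment step: nearest centroid per vector
    assignments = vectors.map(v => {
      let best = 0;
      for (let c = 1; c < k; c++) {
        if (dist2(v, centroids[c]) < dist2(v, centroids[best])) best = c;
      }
      return best;
    });
    // Update step: centroid = mean of its members
    centroids = centroids.map((old, c) => {
      const members = vectors.filter((_, i) => assignments[i] === c);
      if (members.length === 0) return old; // keep empty clusters in place
      const mean = new Float32Array(dim);
      for (const m of members) {
        for (let i = 0; i < dim; i++) mean[i] += m[i] / members.length;
      }
      return mean;
    });
  }
  return assignments;
}
console.log(kMeans(embeddings, 2)); // e.g. [0, 0, 1, 1]
```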
## Browser Compatibility
| Browser | SIMD | Status |
|---------|------|--------|
| Chrome 91+ | ✅ | Full support |
| Firefox 89+ | ✅ | Full support |
| Safari 16.4+ | ✅ | Full support |
| Edge 91+ | ✅ | Full support |
| Node.js 16+ | ✅ | Full support |
| Deno | ✅ | Full support |
| Cloudflare Workers | ✅ | Full support |
## Related Packages
| Package | Runtime | Use Case |
|---------|---------|----------|
| [ruvector-onnx-embeddings](https://crates.io/crates/ruvector-onnx-embeddings) | Native | High-performance servers |
| **ruvector-onnx-embeddings-wasm** | WASM | Browsers, edge, portable |
## Changelog
### v0.1.2
- Added `ParallelEmbedder` for multi-threaded batch processing (3.8x speedup)
- Worker threads support for Node.js environments
### v0.1.1
- Enabled WASM SIMD128 for vectorized operations
- Added `simd_available()` function
- Reduced binary size by 180KB
### v0.1.0
- Initial release
- HuggingFace model loader with caching
- Browser and Node.js support
- 6 pre-configured models
## License
MIT License - See [LICENSE](../../LICENSE) for details.
---
<p align="center">
<b>Part of the RuVector ecosystem</b><br>
High-performance vector operations in Rust
</p>


@@ -0,0 +1,351 @@
/**
* Model Loader for RuVector ONNX Embeddings WASM
*
* Provides easy loading of pre-trained models from HuggingFace Hub
*/
/**
* Pre-configured models with their HuggingFace URLs
*/
export const MODELS = {
// Sentence Transformers - Small & Fast
'all-MiniLM-L6-v2': {
name: 'all-MiniLM-L6-v2',
dimension: 384,
maxLength: 256,
size: '23MB',
description: 'Fast, general-purpose embeddings',
model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx',
tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json',
},
'all-MiniLM-L12-v2': {
name: 'all-MiniLM-L12-v2',
dimension: 384,
maxLength: 256,
size: '33MB',
description: 'Better quality, balanced speed',
model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/onnx/model.onnx',
tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer.json',
},
// BGE Models - State of the art
'bge-small-en-v1.5': {
name: 'bge-small-en-v1.5',
dimension: 384,
maxLength: 512,
size: '33MB',
description: 'State-of-the-art small model',
model: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx',
tokenizer: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/tokenizer.json',
},
'bge-base-en-v1.5': {
name: 'bge-base-en-v1.5',
dimension: 768,
maxLength: 512,
size: '110MB',
description: 'Best overall quality',
model: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx',
tokenizer: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json',
},
// E5 Models - Microsoft
'e5-small-v2': {
name: 'e5-small-v2',
dimension: 384,
maxLength: 512,
size: '33MB',
description: 'Excellent for search & retrieval',
model: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/onnx/model.onnx',
tokenizer: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/tokenizer.json',
},
// GTE Models - Alibaba
'gte-small': {
name: 'gte-small',
dimension: 384,
maxLength: 512,
size: '33MB',
description: 'Good multilingual support',
model: 'https://huggingface.co/thenlper/gte-small/resolve/main/onnx/model.onnx',
tokenizer: 'https://huggingface.co/thenlper/gte-small/resolve/main/tokenizer.json',
},
};
/**
* Default model for quick start
*/
export const DEFAULT_MODEL = 'all-MiniLM-L6-v2';
/**
* Model loader with caching support
*/
export class ModelLoader {
constructor(options = {}) {
this.cache = options.cache ?? true;
this.cacheStorage = options.cacheStorage ?? 'ruvector-models';
this.onProgress = options.onProgress ?? null;
}
/**
* Load a pre-configured model by name
* @param {string} modelName - Model name from MODELS
* @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string, config: object}>}
*/
async loadModel(modelName = DEFAULT_MODEL) {
const modelConfig = MODELS[modelName];
if (!modelConfig) {
throw new Error(`Unknown model: ${modelName}. Available: ${Object.keys(MODELS).join(', ')}`);
}
console.log(`Loading model: ${modelConfig.name} (${modelConfig.size})`);
const [modelBytes, tokenizerJson] = await Promise.all([
this.fetchWithCache(modelConfig.model, `${modelName}-model.onnx`, 'arraybuffer'),
this.fetchWithCache(modelConfig.tokenizer, `${modelName}-tokenizer.json`, 'text'),
]);
return {
modelBytes: new Uint8Array(modelBytes),
tokenizerJson,
config: modelConfig,
};
}
/**
* Load model from custom URLs
* @param {string} modelUrl - URL to ONNX model
* @param {string} tokenizerUrl - URL to tokenizer.json
* @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
*/
async loadFromUrls(modelUrl, tokenizerUrl) {
const [modelBytes, tokenizerJson] = await Promise.all([
this.fetchWithCache(modelUrl, null, 'arraybuffer'),
this.fetchWithCache(tokenizerUrl, null, 'text'),
]);
return {
modelBytes: new Uint8Array(modelBytes),
tokenizerJson,
};
}
/**
* Load model from local files (Node.js)
* @param {string} modelPath - Path to ONNX model
* @param {string} tokenizerPath - Path to tokenizer.json
* @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
*/
async loadFromFiles(modelPath, tokenizerPath) {
// Node.js environment
if (typeof process !== 'undefined' && process.versions?.node) {
const fs = await import('fs/promises');
const [modelBytes, tokenizerJson] = await Promise.all([
fs.readFile(modelPath),
fs.readFile(tokenizerPath, 'utf8'),
]);
return {
modelBytes: new Uint8Array(modelBytes),
tokenizerJson,
};
}
throw new Error('loadFromFiles is only available in Node.js');
}
/**
* Fetch with optional caching (uses Cache API in browsers)
*/
async fetchWithCache(url, cacheKey, responseType) {
// Try cache first (browser only)
if (this.cache && typeof caches !== 'undefined' && cacheKey) {
try {
const cache = await caches.open(this.cacheStorage);
const cached = await cache.match(cacheKey);
if (cached) {
console.log(` Cache hit: ${cacheKey}`);
return responseType === 'arraybuffer'
? await cached.arrayBuffer()
: await cached.text();
}
} catch (e) {
// Cache API not available, continue with fetch
}
}
// Fetch from network
console.log(` Downloading: ${url}`);
const response = await this.fetchWithProgress(url);
if (!response.ok) {
throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`);
}
// Clone for caching
const responseClone = response.clone();
// Cache the response (browser only)
if (this.cache && typeof caches !== 'undefined' && cacheKey) {
try {
const cache = await caches.open(this.cacheStorage);
await cache.put(cacheKey, responseClone);
} catch (e) {
// Cache write failed, continue
}
}
return responseType === 'arraybuffer'
? await response.arrayBuffer()
: await response.text();
}
/**
* Fetch with progress reporting
*/
async fetchWithProgress(url) {
const response = await fetch(url);
if (!this.onProgress || !response.body) {
return response;
}
const contentLength = response.headers.get('content-length');
if (!contentLength) {
return response;
}
const total = parseInt(contentLength, 10);
let loaded = 0;
const reader = response.body.getReader();
const chunks = [];
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
loaded += value.length;
this.onProgress({
loaded,
total,
percent: Math.round((loaded / total) * 100),
});
}
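// Reassemble the streamed chunks into one contiguous buffer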
const body = new Uint8Array(loaded);
let position = 0;
for (const chunk of chunks) {
body.set(chunk, position);
position += chunk.length;
}
return new Response(body, {
headers: response.headers,
status: response.status,
statusText: response.statusText,
});
}
/**
* Clear cached models
*/
async clearCache() {
if (typeof caches !== 'undefined') {
await caches.delete(this.cacheStorage);
console.log('Model cache cleared');
}
}
/**
* List available models
*/
static listModels() {
return Object.entries(MODELS).map(([key, config]) => ({
id: key,
...config,
}));
}
}
/**
* Quick helper to create an embedder with a pre-configured model
*
* @example
* ```javascript
* import { createEmbedder } from './loader.js';
*
* const embedder = await createEmbedder('all-MiniLM-L6-v2');
* const embedding = embedder.embedOne("Hello world");
* ```
*/
export async function createEmbedder(modelName = DEFAULT_MODEL, wasmModule = null) {
// Import WASM module if not provided
if (!wasmModule) {
wasmModule = await import('./pkg/ruvector_onnx_embeddings_wasm.js');
// Only call default() for web/bundler targets (not nodejs)
if (typeof wasmModule.default === 'function') {
await wasmModule.default();
}
}
const loader = new ModelLoader();
const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName);
const embedderConfig = new wasmModule.WasmEmbedderConfig()
.setMaxLength(config.maxLength)
.setNormalize(true)
.setPooling(0); // Mean pooling
const embedder = wasmModule.WasmEmbedder.withConfig(
modelBytes,
tokenizerJson,
embedderConfig
);
return embedder;
}
/**
* Quick helper for one-off embedding (loads model, embeds, returns)
*
* @example
* ```javascript
* import { embed } from './loader.js';
*
* const embedding = await embed("Hello world");
* const embeddings = await embed(["Hello", "World"]);
* ```
*/
export async function embed(text, modelName = DEFAULT_MODEL) {
const embedder = await createEmbedder(modelName);
if (Array.isArray(text)) {
return embedder.embedBatch(text);
}
return embedder.embedOne(text);
}
/**
* Quick helper for similarity comparison
*
* @example
* ```javascript
* import { similarity } from './loader.js';
*
* const score = await similarity("I love dogs", "I adore puppies");
* console.log(score); // ~0.85
* ```
*/
export async function similarity(text1, text2, modelName = DEFAULT_MODEL) {
const embedder = await createEmbedder(modelName);
return embedder.similarity(text1, text2);
}
export default {
MODELS,
DEFAULT_MODEL,
ModelLoader,
createEmbedder,
embed,
similarity,
};


@@ -0,0 +1,168 @@
/**
* Parallel ONNX Embedder using Worker Threads
*
* Distributes embedding work across multiple CPU cores for true parallelism.
*/
import { Worker } from 'worker_threads';
import { cpus } from 'os';
import { fileURLToPath } from 'url';
import { dirname, join } from 'path';
import { ModelLoader, DEFAULT_MODEL } from './loader.js';
const __dirname = dirname(fileURLToPath(import.meta.url));
export class ParallelEmbedder {
constructor(options = {}) {
this.numWorkers = options.numWorkers || Math.max(1, cpus().length - 1);
this.workers = [];
this.readyWorkers = [];
this.pendingRequests = new Map();
this.requestId = 0;
this.initialized = false;
this.modelData = null;
}
/**
* Initialize the parallel embedder with a model
*/
async init(modelName = DEFAULT_MODEL) {
if (this.initialized) return;
console.log(`🚀 Initializing ${this.numWorkers} worker threads...`);
// Load model once in main thread
const loader = new ModelLoader({ cache: false });
const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName);
// Store as transferable data
this.modelData = {
modelBytes: Array.from(modelBytes), // Convert to regular array for transfer
tokenizerJson,
config
};
this.dimension = config.dimension;
// Spawn workers
const workerPromises = [];
for (let i = 0; i < this.numWorkers; i++) {
workerPromises.push(this._spawnWorker(i));
}
await Promise.all(workerPromises);
this.initialized = true;
console.log(`${this.numWorkers} workers ready`);
}
async _spawnWorker(index) {
return new Promise((resolve, reject) => {
const worker = new Worker(join(__dirname, 'parallel-worker.mjs'), {
workerData: this.modelData
});
worker.on('message', (msg) => {
if (msg.type === 'ready') {
this.readyWorkers.push(worker);
resolve();
} else if (msg.type === 'result') {
const { id, embeddings } = msg;
const pending = this.pendingRequests.get(id);
if (pending) {
pending.resolve(embeddings);
this.pendingRequests.delete(id);
this.readyWorkers.push(worker);
this._processQueue();
}
} else if (msg.type === 'error') {
const { id, error } = msg;
const pending = this.pendingRequests.get(id);
if (pending) {
pending.reject(new Error(error));
this.pendingRequests.delete(id);
this.readyWorkers.push(worker);
this._processQueue();
}
}
});
worker.on('error', reject);
this.workers.push(worker);
});
}
_processQueue() {
// No-op for now: embedBatch dispatches eagerly to ready or round-robin workers; kept as a hook for future request queueing
}
/**
* Embed texts in parallel across worker threads
*/
async embedBatch(texts) {
if (!this.initialized) {
throw new Error('ParallelEmbedder not initialized. Call init() first.');
}
// Split texts into chunks for each worker
const chunkSize = Math.ceil(texts.length / this.numWorkers);
const chunks = [];
for (let i = 0; i < texts.length; i += chunkSize) {
chunks.push(texts.slice(i, i + chunkSize));
}
// Send to workers in parallel
const promises = chunks.map((chunk, i) => {
return new Promise((resolve, reject) => {
const id = this.requestId++;
const worker = this.readyWorkers.shift() || this.workers[i % this.workers.length];
this.pendingRequests.set(id, { resolve, reject });
worker.postMessage({ type: 'embed', id, texts: chunk });
});
});
// Wait for all results
const results = await Promise.all(promises);
// Flatten results
return results.flat();
}
/**
* Embed a single text (uses one worker)
*/
async embedOne(text) {
const results = await this.embedBatch([text]);
return new Float32Array(results[0]);
}
/**
* Compute similarity between two texts
*/
async similarity(text1, text2) {
const [emb1, emb2] = await this.embedBatch([text1, text2]);
return this._cosineSimilarity(emb1, emb2);
}
_cosineSimilarity(a, b) {
let dot = 0, normA = 0, normB = 0;
for (let i = 0; i < a.length; i++) {
dot += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
/**
* Shutdown all workers
*/
async shutdown() {
for (const worker of this.workers) {
worker.postMessage({ type: 'shutdown' });
}
this.workers = [];
this.readyWorkers = [];
this.initialized = false;
}
}
export default ParallelEmbedder;


@@ -0,0 +1,37 @@
/**
* Worker thread for parallel ONNX embedding generation
*/
import { parentPort, workerData } from 'worker_threads';
import { WasmEmbedder, WasmEmbedderConfig } from './pkg/ruvector_onnx_embeddings_wasm.js';
// Initialize embedder with model data passed from main thread
const { modelBytes, tokenizerJson, config } = workerData;
const embedderConfig = new WasmEmbedderConfig()
.setMaxLength(config.maxLength)
.setNormalize(true)
.setPooling(0);
const embedder = WasmEmbedder.withConfig(
new Uint8Array(modelBytes),
tokenizerJson,
embedderConfig
);
// Listen for texts to embed
parentPort.on('message', (message) => {
if (message.type === 'embed') {
const { id, texts } = message;
try {
const embeddings = texts.map(text => Array.from(embedder.embedOne(text)));
parentPort.postMessage({ type: 'result', id, embeddings });
} catch (error) {
parentPort.postMessage({ type: 'error', id, error: error.message });
}
} else if (message.type === 'shutdown') {
process.exit(0);
}
});
// Signal ready
parentPort.postMessage({ type: 'ready' });


@@ -0,0 +1,213 @@
//! Main WASM embedder implementation
use crate::error::{Result, WasmEmbeddingError};
use crate::model::TractModel;
use crate::pooling::{cosine_similarity, normalize_l2, PoolingStrategy};
use crate::tokenizer::WasmTokenizer;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Configuration for the WASM embedder
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmEmbedderConfig {
/// Maximum sequence length
#[wasm_bindgen(skip)]
pub max_length: usize,
/// Pooling strategy
#[wasm_bindgen(skip)]
pub pooling: PoolingStrategy,
/// Whether to L2 normalize embeddings
#[wasm_bindgen(skip)]
pub normalize: bool,
}
#[wasm_bindgen]
impl WasmEmbedderConfig {
/// Create a new configuration
#[wasm_bindgen(constructor)]
pub fn new() -> Self {
Self::default()
}
/// Set maximum sequence length
#[wasm_bindgen(js_name = setMaxLength)]
pub fn set_max_length(mut self, max_length: usize) -> Self {
self.max_length = max_length;
self
}
/// Set whether to normalize embeddings
#[wasm_bindgen(js_name = setNormalize)]
pub fn set_normalize(mut self, normalize: bool) -> Self {
self.normalize = normalize;
self
}
/// Set pooling strategy (0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen, 4=LastToken)
#[wasm_bindgen(js_name = setPooling)]
pub fn set_pooling(mut self, pooling: u8) -> Self {
self.pooling = match pooling {
0 => PoolingStrategy::Mean,
1 => PoolingStrategy::Cls,
2 => PoolingStrategy::Max,
3 => PoolingStrategy::MeanSqrtLen,
4 => PoolingStrategy::LastToken,
_ => PoolingStrategy::Mean,
};
self
}
}
impl Default for WasmEmbedderConfig {
fn default() -> Self {
Self {
max_length: 256,
pooling: PoolingStrategy::Mean,
normalize: true,
}
}
}
/// WASM-compatible embedder using Tract for inference
#[wasm_bindgen]
pub struct WasmEmbedder {
model: TractModel,
tokenizer: WasmTokenizer,
config: WasmEmbedderConfig,
hidden_size: usize,
}
#[wasm_bindgen]
impl WasmEmbedder {
/// Create a new embedder from model and tokenizer bytes
///
/// # Arguments
/// * `model_bytes` - ONNX model file bytes
/// * `tokenizer_json` - Tokenizer JSON configuration
#[wasm_bindgen(constructor)]
pub fn new(model_bytes: &[u8], tokenizer_json: &str) -> std::result::Result<WasmEmbedder, JsValue> {
Self::with_config(model_bytes, tokenizer_json, WasmEmbedderConfig::default())
}
/// Create embedder with custom configuration
#[wasm_bindgen(js_name = withConfig)]
pub fn with_config(
model_bytes: &[u8],
tokenizer_json: &str,
config: WasmEmbedderConfig,
) -> std::result::Result<WasmEmbedder, JsValue> {
let model = TractModel::from_bytes(model_bytes, config.max_length)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
let tokenizer = WasmTokenizer::from_json(tokenizer_json, config.max_length)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
let hidden_size = model.hidden_size();
Ok(Self {
model,
tokenizer,
config,
hidden_size,
})
}
/// Generate embedding for a single text
#[wasm_bindgen(js_name = embedOne)]
pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, JsValue> {
self.embed_one_internal(text)
.map_err(|e| JsValue::from_str(&e.to_string()))
}
/// Generate embeddings for multiple texts
#[wasm_bindgen(js_name = embedBatch)]
pub fn embed_batch(&mut self, texts: Vec<String>) -> std::result::Result<Vec<f32>, JsValue> {
let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
self.embed_batch_internal(&refs)
.map_err(|e| JsValue::from_str(&e.to_string()))
}
/// Compute similarity between two texts
#[wasm_bindgen]
pub fn similarity(&mut self, text1: &str, text2: &str) -> std::result::Result<f32, JsValue> {
let emb1 = self.embed_one_internal(text1)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
let emb2 = self.embed_one_internal(text2)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(cosine_similarity(&emb1, &emb2))
}
/// Get the embedding dimension
#[wasm_bindgen]
pub fn dimension(&self) -> usize {
self.hidden_size
}
/// Get maximum sequence length
#[wasm_bindgen(js_name = maxLength)]
pub fn max_length(&self) -> usize {
self.config.max_length
}
}
// Internal implementation
impl WasmEmbedder {
fn embed_one_internal(&mut self, text: &str) -> Result<Vec<f32>> {
// Tokenize
let encoded = self.tokenizer.encode(text)?;
let attention_mask = encoded.attention_mask.clone();
// Run inference
let raw_output = self.model.run(&encoded)?;
// Determine hidden size from output
let seq_len = self.config.max_length;
if raw_output.len() >= seq_len {
let detected_hidden = raw_output.len() / seq_len;
if detected_hidden != self.hidden_size && detected_hidden > 0 {
self.hidden_size = detected_hidden;
self.model.set_hidden_size(detected_hidden);
}
}
// Apply pooling
let mut embedding = self.config.pooling.apply(
&raw_output,
&attention_mask,
self.hidden_size,
);
// Normalize if configured
if self.config.normalize {
normalize_l2(&mut embedding);
}
Ok(embedding)
}
fn embed_batch_internal(&mut self, texts: &[&str]) -> Result<Vec<f32>> {
let mut all_embeddings = Vec::with_capacity(texts.len() * self.hidden_size);
for text in texts {
let embedding = self.embed_one_internal(text)?;
all_embeddings.extend(embedding);
}
Ok(all_embeddings)
}
}
/// Compute cosine similarity between two embedding vectors (JS-friendly)
#[wasm_bindgen(js_name = cosineSimilarity)]
pub fn js_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> f32 {
cosine_similarity(&a, &b)
}
/// L2 normalize an embedding vector (JS-friendly)
#[wasm_bindgen(js_name = normalizeL2)]
pub fn js_normalize_l2(mut embedding: Vec<f32>) -> Vec<f32> {
normalize_l2(&mut embedding);
embedding
}


@@ -0,0 +1,62 @@
//! Error types for WASM embeddings
use thiserror::Error;
use wasm_bindgen::prelude::*;
/// Result type for WASM embedding operations
pub type Result<T> = std::result::Result<T, WasmEmbeddingError>;
/// Errors that can occur during WASM embedding operations
#[derive(Error, Debug)]
pub enum WasmEmbeddingError {
#[error("Model error: {0}")]
Model(String),
#[error("Tokenizer error: {0}")]
Tokenizer(String),
#[error("Inference error: {0}")]
Inference(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Serialization error: {0}")]
Serialization(String),
}
impl WasmEmbeddingError {
pub fn model(msg: impl Into<String>) -> Self {
Self::Model(msg.into())
}
pub fn tokenizer(msg: impl Into<String>) -> Self {
Self::Tokenizer(msg.into())
}
pub fn inference(msg: impl Into<String>) -> Self {
Self::Inference(msg.into())
}
pub fn invalid_input(msg: impl Into<String>) -> Self {
Self::InvalidInput(msg.into())
}
}
impl From<WasmEmbeddingError> for JsValue {
fn from(err: WasmEmbeddingError) -> Self {
JsValue::from_str(&err.to_string())
}
}
impl From<tract_onnx::prelude::TractError> for WasmEmbeddingError {
fn from(err: tract_onnx::prelude::TractError) -> Self {
Self::Model(err.to_string())
}
}
impl From<serde_json::Error> for WasmEmbeddingError {
fn from(err: serde_json::Error) -> Self {
Self::Serialization(err.to_string())
}
}


@@ -0,0 +1,66 @@
//! # RuVector ONNX Embeddings - WASM Edition
//!
//! WASM-compatible embedding generation using Tract for inference.
//! Runs in browsers, Cloudflare Workers, Deno, and any WASM runtime.
//!
//! ## Features
//!
//! - **Browser Support**: Generate embeddings directly in the browser
//! - **Edge Computing**: Deploy to Cloudflare Workers, Vercel Edge, etc.
//! - **Portable**: Single WASM binary, no platform-specific dependencies
//! - **Same API**: Compatible with the native ruvector-onnx-embeddings crate
//!
//! ## Usage (JavaScript)
//!
//! ```javascript
//! import init, { WasmEmbedder } from 'ruvector-onnx-embeddings-wasm';
//!
//! await init();
//!
//! // Load model from bytes
//! const modelBytes = await fetch('/model.onnx').then(r => r.arrayBuffer());
//! const tokenizerJson = await fetch('/tokenizer.json').then(r => r.text());
//!
//! const embedder = new WasmEmbedder(new Uint8Array(modelBytes), tokenizerJson);
//!
//! // Generate embeddings
//! const embedding = embedder.embedOne("Hello, world!");
//! console.log("Embedding dimension:", embedding.length);
//!
//! // Compute similarity
//! const similarity = embedder.similarity("I love Rust", "Rust is great");
//! console.log("Similarity:", similarity);
//! ```
mod embedder;
mod error;
mod model;
mod pooling;
mod tokenizer;
pub use embedder::{WasmEmbedder, WasmEmbedderConfig};
pub use error::WasmEmbeddingError;
pub use pooling::PoolingStrategy;
use wasm_bindgen::prelude::*;
/// Initialize panic hook for better error messages in WASM
#[wasm_bindgen(start)]
pub fn init() {
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}
/// Get the library version
#[wasm_bindgen]
pub fn version() -> String {
env!("CARGO_PKG_VERSION").to_string()
}
/// Check if SIMD is available (for performance info)
/// Returns true if compiled with WASM SIMD128 support
#[wasm_bindgen]
pub fn simd_available() -> bool {
// Check if compiled with SIMD128 target feature
cfg!(target_feature = "simd128")
}


@@ -0,0 +1,116 @@
//! Tract-based ONNX model for WASM inference
use crate::error::{Result, WasmEmbeddingError};
use crate::tokenizer::EncodedInput;
use tract_onnx::prelude::*;
/// Tract ONNX model wrapper for WASM
pub struct TractModel {
model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
hidden_size: usize,
}
impl TractModel {
/// Load model from ONNX bytes
pub fn from_bytes(bytes: &[u8], max_seq_length: usize) -> Result<Self> {
// Parse ONNX model
let model = tract_onnx::onnx()
.model_for_read(&mut std::io::Cursor::new(bytes))
.map_err(|e| WasmEmbeddingError::model(format!("Failed to parse ONNX: {}", e)))?;
// Set input shapes for optimization
// Standard transformer inputs: [batch, seq_len]
let batch = 1usize;
let seq_len = max_seq_length;
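// Inputs 0, 1, 2 correspond to input_ids, attention_mask, and token_type_ids (matching run())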
let model = model
.with_input_fact(
0,
InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
)?
.with_input_fact(
1,
InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
)?
.with_input_fact(
2,
InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
)?;
// Optimize the model
let model = model
.into_optimized()
.map_err(|e| WasmEmbeddingError::model(format!("Failed to optimize: {}", e)))?;
let model = model
.into_runnable()
.map_err(|e| WasmEmbeddingError::model(format!("Failed to make runnable: {}", e)))?;
// Default hidden size (will be determined from output)
let hidden_size = 384;
Ok(Self { model, hidden_size })
}
/// Run inference on encoded input
pub fn run(&self, input: &EncodedInput) -> Result<Vec<f32>> {
let seq_len = input.input_ids.len();
// Create input tensors
let input_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
(1, seq_len),
input.input_ids.clone(),
)
.map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
.into();
let attention_mask: Tensor = tract_ndarray::Array2::from_shape_vec(
(1, seq_len),
input.attention_mask.clone(),
)
.map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
.into();
let token_type_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
(1, seq_len),
input.token_type_ids.clone(),
)
.map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
.into();
// Run inference
let inputs = tvec![
input_ids.into(),
attention_mask.into(),
token_type_ids.into()
];
let outputs = self
.model
.run(inputs)
.map_err(|e| WasmEmbeddingError::inference(format!("Inference failed: {}", e)))?;
// Extract output tensor
// Output is typically [batch, seq_len, hidden_size] or [batch, hidden_size]
let output = outputs
.first()
.ok_or_else(|| WasmEmbeddingError::inference("No output tensor"))?;
let output_array = output
.to_array_view::<f32>()
.map_err(|e| WasmEmbeddingError::inference(format!("Failed to extract output: {}", e)))?;
// Flatten and return
Ok(output_array.iter().copied().collect())
}
/// Get the hidden size
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
/// Update hidden size (called after first inference)
pub fn set_hidden_size(&mut self, size: usize) {
self.hidden_size = size;
}
}


@@ -0,0 +1,181 @@
//! Pooling strategies for converting token embeddings to sentence embeddings
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Strategy for pooling token embeddings into a single sentence embedding
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub enum PoolingStrategy {
/// Average all token embeddings (most common)
#[default]
Mean,
/// Use only the [CLS] token embedding
Cls,
/// Take the maximum value across all tokens for each dimension
Max,
/// Mean pooling normalized by sqrt of sequence length
MeanSqrtLen,
/// Use the last token embedding (for decoder models)
LastToken,
}
impl PoolingStrategy {
/// Apply pooling to token embeddings
///
/// # Arguments
/// * `embeddings` - Token embeddings [seq_len, hidden_size]
/// * `attention_mask` - Attention mask [seq_len]
///
/// # Returns
/// Pooled embedding [hidden_size]
pub fn apply(&self, embeddings: &[f32], attention_mask: &[i64], hidden_size: usize) -> Vec<f32> {
let seq_len = attention_mask.len();
if embeddings.is_empty() || hidden_size == 0 {
return vec![0.0; hidden_size];
}
match self {
PoolingStrategy::Mean => {
self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len)
}
PoolingStrategy::Cls => {
// First token (CLS)
embeddings[..hidden_size].to_vec()
}
PoolingStrategy::Max => {
self.max_pooling(embeddings, attention_mask, hidden_size, seq_len)
}
PoolingStrategy::MeanSqrtLen => {
let mut pooled = self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len);
let valid_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
// Guard against an all-zero mask to avoid dividing by zero
if valid_tokens > 0.0 {
let scale = 1.0 / valid_tokens.sqrt();
for v in &mut pooled {
*v *= scale;
}
}
pooled
}
PoolingStrategy::LastToken => {
// Find last valid token
let last_idx = attention_mask
.iter()
.rposition(|&m| m == 1)
.unwrap_or(0);
let start = last_idx * hidden_size;
embeddings[start..start + hidden_size].to_vec()
}
}
}
fn mean_pooling(
&self,
embeddings: &[f32],
attention_mask: &[i64],
hidden_size: usize,
seq_len: usize,
) -> Vec<f32> {
let mut pooled = vec![0.0f32; hidden_size];
let mut count = 0.0f32;
for (i, &mask) in attention_mask.iter().enumerate() {
if mask == 1 && i < seq_len {
let start = i * hidden_size;
if start + hidden_size <= embeddings.len() {
for (j, v) in pooled.iter_mut().enumerate() {
*v += embeddings[start + j];
}
count += 1.0;
}
}
}
if count > 0.0 {
for v in &mut pooled {
*v /= count;
}
}
pooled
}
fn max_pooling(
&self,
embeddings: &[f32],
attention_mask: &[i64],
hidden_size: usize,
seq_len: usize,
) -> Vec<f32> {
let mut pooled = vec![f32::NEG_INFINITY; hidden_size];
for (i, &mask) in attention_mask.iter().enumerate() {
if mask == 1 && i < seq_len {
let start = i * hidden_size;
if start + hidden_size <= embeddings.len() {
for (j, v) in pooled.iter_mut().enumerate() {
*v = v.max(embeddings[start + j]);
}
}
}
}
// Replace -inf with 0 for dimensions with no valid tokens
for v in &mut pooled {
if v.is_infinite() {
*v = 0.0;
}
}
pooled
}
}
/// L2 normalize a vector in place
pub fn normalize_l2(embedding: &mut [f32]) {
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for v in embedding {
*v /= norm;
}
}
}
/// Compute cosine similarity between two embeddings
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot / (norm_a * norm_b)
} else {
0.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
let c = vec![0.0, 1.0, 0.0];
assert!(cosine_similarity(&a, &c).abs() < 1e-6);
}
#[test]
fn test_normalize_l2() {
let mut v = vec![3.0, 4.0];
normalize_l2(&mut v);
assert!((v[0] - 0.6).abs() < 1e-6);
assert!((v[1] - 0.8).abs() < 1e-6);
}
}


@@ -0,0 +1,114 @@
//! Tokenizer wrapper for WASM embedding generation
use crate::error::{Result, WasmEmbeddingError};
use tokenizers::Tokenizer;
/// Tokenizer wrapper that handles text encoding
pub struct WasmTokenizer {
tokenizer: Tokenizer,
max_length: usize,
}
/// Encoded text ready for model inference
#[derive(Debug, Clone)]
pub struct EncodedInput {
pub input_ids: Vec<i64>,
pub attention_mask: Vec<i64>,
pub token_type_ids: Vec<i64>,
}
impl WasmTokenizer {
/// Create a new tokenizer from JSON configuration
pub fn from_json(json: &str, max_length: usize) -> Result<Self> {
let tokenizer = Tokenizer::from_bytes(json.as_bytes())
.map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
Ok(Self {
tokenizer,
max_length,
})
}
/// Create tokenizer from raw bytes
pub fn from_bytes(bytes: &[u8], max_length: usize) -> Result<Self> {
let tokenizer = Tokenizer::from_bytes(bytes)
.map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
Ok(Self {
tokenizer,
max_length,
})
}
/// Encode a single text
pub fn encode(&self, text: &str) -> Result<EncodedInput> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let mut attention_mask: Vec<i64> =
encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
let mut token_type_ids: Vec<i64> =
encoding.get_type_ids().iter().map(|&t| t as i64).collect();
// Truncate if necessary
if input_ids.len() > self.max_length {
input_ids.truncate(self.max_length);
attention_mask.truncate(self.max_length);
token_type_ids.truncate(self.max_length);
}
// Pad if necessary
while input_ids.len() < self.max_length {
input_ids.push(0);
attention_mask.push(0);
token_type_ids.push(0);
}
Ok(EncodedInput {
input_ids,
attention_mask,
token_type_ids,
})
}
/// Encode multiple texts with padding to the same length
pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<EncodedInput>> {
texts.iter().map(|text| self.encode(text)).collect()
}
/// Get the maximum sequence length
pub fn max_length(&self) -> usize {
self.max_length
}
}
#[cfg(test)]
mod tests {
use super::*;
// Basic tokenizer JSON for testing
const TEST_TOKENIZER: &str = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {"type": "Whitespace"},
"post_processor": null,
"decoder": null,
"model": {
"type": "WordLevel",
"vocab": {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3},
"unk_token": "[UNK]"
}
}"#;
#[test]
fn test_tokenizer_creation() {
let tokenizer = WasmTokenizer::from_json(TEST_TOKENIZER, 128);
assert!(tokenizer.is_ok());
}
}


@@ -0,0 +1,142 @@
#!/usr/bin/env node
/**
* Full end-to-end test with model download
*
* Downloads all-MiniLM-L6-v2 and runs embedding tests
*/
import { ModelLoader, MODELS, DEFAULT_MODEL } from './loader.js';
import {
WasmEmbedder,
WasmEmbedderConfig,
cosineSimilarity,
} from './pkg/ruvector_onnx_embeddings_wasm.js';
console.log('🧪 RuVector ONNX Embeddings WASM - Full E2E Test\n');
console.log('='.repeat(60));
// List available models
console.log('\n📦 Available Models:');
ModelLoader.listModels().forEach(m => {
const isDefault = m.id === DEFAULT_MODEL ? ' ⭐ DEFAULT' : '';
console.log(`${m.id} (${m.dimension}d, ${m.size})${isDefault}`);
console.log(` ${m.description}`);
});
console.log('\n' + '='.repeat(60));
console.log(`\n🔄 Loading model: ${DEFAULT_MODEL}...\n`);
// Load model with progress
const loader = new ModelLoader({
cache: false, // Disable cache for testing
onProgress: ({ loaded, total, percent }) => {
process.stdout.write(`\r Progress: ${percent}% (${(loaded/1024/1024).toFixed(1)}MB / ${(total/1024/1024).toFixed(1)}MB)`);
}
});
try {
const { modelBytes, tokenizerJson, config } = await loader.loadModel(DEFAULT_MODEL);
console.log('\n');
console.log(` ✅ Model loaded: ${config.name}`);
console.log(` ✅ Model size: ${(modelBytes.length / 1024 / 1024).toFixed(2)} MB`);
console.log(` ✅ Tokenizer size: ${(tokenizerJson.length / 1024).toFixed(2)} KB`);
// Create embedder
console.log('\n🔧 Creating embedder...');
const embedderConfig = new WasmEmbedderConfig()
.setMaxLength(config.maxLength)
.setNormalize(true)
.setPooling(0);
const embedder = WasmEmbedder.withConfig(modelBytes, tokenizerJson, embedderConfig);
console.log(` ✅ Embedder created`);
console.log(` ✅ Dimension: ${embedder.dimension()}`);
console.log(` ✅ Max length: ${embedder.maxLength()}`);
// Test 1: Single embedding
console.log('\n' + '='.repeat(60));
console.log('\n📝 Test 1: Single Embedding');
const text1 = "The quick brown fox jumps over the lazy dog.";
console.log(` Input: "${text1}"`);
const start1 = performance.now();
const embedding1 = embedder.embedOne(text1);
const time1 = performance.now() - start1;
console.log(` ✅ Output dimension: ${embedding1.length}`);
console.log(` ✅ First 5 values: [${Array.from(embedding1.slice(0, 5)).map(v => v.toFixed(4)).join(', ')}]`);
console.log(` ✅ Time: ${time1.toFixed(2)}ms`);
// Test 2: Semantic similarity
console.log('\n' + '='.repeat(60));
console.log('\n📝 Test 2: Semantic Similarity');
const pairs = [
["I love programming in Rust", "Rust is my favorite programming language"],
["The weather is nice today", "It's sunny outside"],
["I love programming in Rust", "The weather is nice today"],
["Machine learning is fascinating", "AI and deep learning are interesting"],
];
for (const [a, b] of pairs) {
const start = performance.now();
const sim = embedder.similarity(a, b);
const time = performance.now() - start;
const label = sim > 0.5 ? '🟢 Similar' : '🔴 Different';
console.log(`\n "${a.substring(0, 30)}..."`);
console.log(` "${b.substring(0, 30)}..."`);
console.log(` ${label}: ${sim.toFixed(4)} (${time.toFixed(1)}ms)`);
}
// Test 3: Batch embedding
console.log('\n' + '='.repeat(60));
console.log('\n📝 Test 3: Batch Embedding');
const texts = [
"Artificial intelligence is transforming technology.",
"Machine learning models learn from data.",
"Deep learning uses neural networks.",
"Vector databases enable semantic search.",
];
console.log(` Embedding ${texts.length} texts...`);
const start3 = performance.now();
const batchEmbeddings = embedder.embedBatch(texts);
const time3 = performance.now() - start3;
const embeddingDim = embedder.dimension();
const numEmbeddings = batchEmbeddings.length / embeddingDim;
console.log(` ✅ Total values: ${batchEmbeddings.length}`);
console.log(` ✅ Embeddings: ${numEmbeddings} x ${embeddingDim}d`);
console.log(` ✅ Time: ${time3.toFixed(2)}ms (${(time3/texts.length).toFixed(2)}ms per text)`);
// Compute pairwise similarities
console.log('\n Pairwise similarities:');
for (let i = 0; i < numEmbeddings; i++) {
for (let j = i + 1; j < numEmbeddings; j++) {
const emb_i = batchEmbeddings.slice(i * embeddingDim, (i + 1) * embeddingDim);
const emb_j = batchEmbeddings.slice(j * embeddingDim, (j + 1) * embeddingDim);
const sim = cosineSimilarity(emb_i, emb_j);
console.log(` [${i}] vs [${j}]: ${sim.toFixed(4)}`);
}
}
// Summary
console.log('\n' + '='.repeat(60));
console.log('\n✅ All tests passed!');
console.log('='.repeat(60));
console.log('\n📊 Performance Summary:');
console.log(` • Model: ${config.name}`);
console.log(` • Dimension: ${embeddingDim}`);
console.log(` • Single embed: ~${time1.toFixed(0)}ms`);
console.log(` • Batch (4 texts): ~${time3.toFixed(0)}ms`);
console.log(` • Throughput: ~${(1000 / (time3/texts.length)).toFixed(0)} texts/sec`);
} catch (error) {
console.error('\n❌ Error:', error.message);
console.error(error.stack);
process.exit(1);
}


@@ -0,0 +1,121 @@
#!/usr/bin/env node
/**
* Benchmark: Sequential vs Parallel ONNX Embeddings
*/
import { cpus } from 'os';
import { ParallelEmbedder } from './parallel-embedder.mjs';
import { createEmbedder } from './loader.js';
console.log('🧪 Parallel vs Sequential ONNX Embeddings Benchmark\n');
console.log(`CPU Cores: ${cpus().length}`);
console.log('='.repeat(60));
// Test data - various batch sizes
const testTexts = [
"Machine learning is transforming technology",
"Deep learning uses neural networks",
"Natural language processing understands text",
"Computer vision analyzes images",
"Reinforcement learning learns from rewards",
"Generative AI creates new content",
"Vector databases enable semantic search",
"Embeddings capture semantic meaning",
"Transformers revolutionized NLP",
"BERT is a popular language model",
"GPT generates human-like text",
"RAG combines retrieval and generation",
];
async function benchmarkSequential(embedder, texts, iterations = 3) {
const times = [];
for (let i = 0; i < iterations; i++) {
const start = performance.now();
for (const text of texts) {
embedder.embedOne(text);
}
times.push(performance.now() - start);
}
return times.reduce((a, b) => a + b) / times.length;
}
async function benchmarkParallel(embedder, texts, iterations = 3) {
const times = [];
for (let i = 0; i < iterations; i++) {
const start = performance.now();
await embedder.embedBatch(texts);
times.push(performance.now() - start);
}
return times.reduce((a, b) => a + b) / times.length;
}
async function main() {
try {
// Initialize sequential embedder
console.log('\n📦 Loading model for sequential test...');
const seqEmbedder = await createEmbedder();
console.log('✅ Sequential embedder ready\n');
// Warm up
seqEmbedder.embedOne("warmup");
// Initialize parallel embedder
console.log('📦 Initializing parallel embedder...');
const parEmbedder = new ParallelEmbedder({ numWorkers: Math.min(4, cpus().length) });
await parEmbedder.init();
// Benchmark different batch sizes
for (const batchSize of [4, 8, 12]) {
const texts = testTexts.slice(0, batchSize);
console.log(`\n${'='.repeat(60)}`);
console.log(`📊 Batch Size: ${batchSize} texts`);
console.log('='.repeat(60));
// Sequential benchmark
console.log('\n⏱ Sequential (single-threaded)...');
const seqTime = await benchmarkSequential(seqEmbedder, texts);
console.log(` Time: ${seqTime.toFixed(1)}ms`);
console.log(` Per text: ${(seqTime / batchSize).toFixed(1)}ms`);
// Parallel benchmark
console.log('\n⏱ Parallel (worker threads)...');
const parTime = await benchmarkParallel(parEmbedder, texts);
console.log(` Time: ${parTime.toFixed(1)}ms`);
console.log(` Per text: ${(parTime / batchSize).toFixed(1)}ms`);
// Speedup
const speedup = seqTime / parTime;
const icon = speedup > 1.2 ? '🚀' : speedup > 1 ? '✅' : '⚠️';
console.log(`\n${icon} Speedup: ${speedup.toFixed(2)}x`);
}
// Verify correctness
console.log(`\n${'='.repeat(60)}`);
console.log('🔍 Verifying correctness...');
console.log('='.repeat(60));
const testText = "Vector databases are awesome";
const seqEmb = seqEmbedder.embedOne(testText);
const parEmb = await parEmbedder.embedOne(testText);
// Compare embeddings
let diff = 0;
for (let i = 0; i < seqEmb.length; i++) {
diff += Math.abs(seqEmb[i] - parEmb[i]);
}
const avgDiff = diff / seqEmb.length;
console.log(`\nEmbedding difference: ${avgDiff.toExponential(4)}`);
console.log(avgDiff < 1e-6 ? '✅ Embeddings match!' : '⚠️ Embeddings differ');
// Cleanup
await parEmbedder.shutdown();
console.log('\n✅ Benchmark complete!');
} catch (error) {
console.error('❌ Error:', error.message);
console.error(error.stack);
process.exit(1);
}
}
main();


@@ -0,0 +1,72 @@
#!/usr/bin/env node
/**
* Test script to validate WASM embeddings package works
*/
import {
version,
simd_available,
cosineSimilarity,
normalizeL2,
WasmEmbedderConfig
} from './pkg/ruvector_onnx_embeddings_wasm.js';
console.log('🧪 RuVector ONNX Embeddings WASM - Validation Test\n');
// Test 1: Version check
console.log('Test 1: Version check');
const ver = version();
console.log(` ✅ Version: ${ver}`);
// Test 2: SIMD availability
console.log('\nTest 2: SIMD availability');
const simd = simd_available();
console.log(` ✅ SIMD available: ${simd}`);
// Test 3: Cosine similarity utility
console.log('\nTest 3: Cosine similarity utility');
const a = new Float32Array([1.0, 0.0, 0.0]);
const b = new Float32Array([1.0, 0.0, 0.0]);
const c = new Float32Array([0.0, 1.0, 0.0]);
const simSame = cosineSimilarity(a, b);
const simDiff = cosineSimilarity(a, c);
console.log(` ✅ Same vectors similarity: ${simSame.toFixed(4)} (expected: 1.0)`);
console.log(` ✅ Orthogonal vectors similarity: ${simDiff.toFixed(4)} (expected: 0.0)`);
if (Math.abs(simSame - 1.0) > 0.0001) {
throw new Error('Cosine similarity failed for same vectors');
}
if (Math.abs(simDiff) > 0.0001) {
throw new Error('Cosine similarity failed for orthogonal vectors');
}
// Test 4: L2 normalization utility
console.log('\nTest 4: L2 normalization utility');
const unnormalized = new Float32Array([3.0, 4.0]);
const normalized = normalizeL2(unnormalized);
const norm = Math.sqrt(normalized[0]**2 + normalized[1]**2);
console.log(` ✅ Normalized vector: [${normalized[0].toFixed(4)}, ${normalized[1].toFixed(4)}]`);
console.log(` ✅ L2 norm: ${norm.toFixed(4)} (expected: 1.0)`);
if (Math.abs(norm - 1.0) > 0.0001) {
throw new Error('L2 normalization failed');
}
// Test 5: Config creation
console.log('\nTest 5: WasmEmbedderConfig creation');
const config = new WasmEmbedderConfig()
.setMaxLength(256)
.setNormalize(true)
.setPooling(0);
console.log(' ✅ Config created and chained successfully');
// Note: config is consumed by chaining, no need to free
console.log('\n' + '='.repeat(50));
console.log('✅ All utility tests passed!');
console.log('='.repeat(50));
console.log('\n📝 Note: Full embedder test requires ONNX model + tokenizer files.');
console.log(' The core WASM bindings are working correctly.\n');