Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
1983
vendor/ruvector/examples/onnx-embeddings-wasm/Cargo.lock
generated
vendored
Normal file
1983
vendor/ruvector/examples/onnx-embeddings-wasm/Cargo.lock
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
61
vendor/ruvector/examples/onnx-embeddings-wasm/Cargo.toml
vendored
Normal file
61
vendor/ruvector/examples/onnx-embeddings-wasm/Cargo.toml
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
[package]
|
||||
name = "ruvector-onnx-embeddings-wasm"
|
||||
version = "0.1.2"
|
||||
edition = "2021"
|
||||
authors = ["RuVector Team"]
|
||||
description = "WASM embedding generation with SIMD - runs in browsers, Cloudflare Workers, Deno, and edge runtimes"
|
||||
license = "MIT"
|
||||
repository = "https://github.com/ruvnet/ruvector"
|
||||
keywords = ["onnx", "embeddings", "wasm", "webassembly", "ml"]
|
||||
categories = ["wasm", "science", "algorithms"]
|
||||
|
||||
# Standalone package
|
||||
[workspace]
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
# Tract - ONNX inference that compiles to WASM
|
||||
tract-onnx = "0.21"
|
||||
tract-core = "0.21"
|
||||
|
||||
# Tokenization - HuggingFace tokenizers (WASM compatible)
|
||||
tokenizers = { version = "0.20", default-features = false, features = ["unstable_wasm"] }
|
||||
|
||||
# WASM bindings
|
||||
wasm-bindgen = "0.2"
|
||||
wasm-bindgen-futures = "0.4"
|
||||
js-sys = "0.3"
|
||||
web-sys = { version = "0.3", features = ["console"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
serde-wasm-bindgen = "0.6"
|
||||
|
||||
# Error handling
|
||||
thiserror = "2.0"
|
||||
anyhow = "1.0"
|
||||
|
||||
# Async (WASM compatible)
|
||||
futures = "0.3"
|
||||
|
||||
# Console logging for WASM
|
||||
console_error_panic_hook = { version = "0.1", optional = true }
|
||||
|
||||
# Getrandom for WASM
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[features]
|
||||
default = ["console_error_panic_hook"]
|
||||
|
||||
[profile.release]
|
||||
opt-level = "s"
|
||||
lto = true
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = ["-Os", "--enable-mutable-globals"]
|
||||
435
vendor/ruvector/examples/onnx-embeddings-wasm/README.md
vendored
Normal file
435
vendor/ruvector/examples/onnx-embeddings-wasm/README.md
vendored
Normal file
@@ -0,0 +1,435 @@
|
||||
# RuVector ONNX Embeddings WASM
|
||||
|
||||
[](https://www.npmjs.com/package/ruvector-onnx-embeddings-wasm)
|
||||
[](https://crates.io/crates/ruvector-onnx-embeddings-wasm)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://webassembly.org/)
|
||||
[](https://webassembly.org/roadmap/)
|
||||
|
||||
> **Portable embedding generation with SIMD acceleration and parallel workers**
|
||||
|
||||
Generate text embeddings directly in browsers, Cloudflare Workers, Deno, Node.js, and any WASM runtime. Built with [Tract](https://github.com/sonos/tract) for pure Rust ONNX inference.
|
||||
|
||||
## Features
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| 🌐 **Browser Support** | Generate embeddings client-side, no server needed |
|
||||
| ⚡ **SIMD Acceleration** | WASM SIMD128 for vectorized operations |
|
||||
| 🚀 **Parallel Workers** | Multi-threaded batch processing (3.8x speedup) |
|
||||
| 🏢 **Edge Computing** | Deploy to Cloudflare Workers, Vercel Edge, Deno Deploy |
|
||||
| 📦 **Zero Dependencies** | Single WASM binary, no native modules |
|
||||
| 🤗 **HuggingFace Models** | Pre-configured URLs for popular models |
|
||||
| 🔄 **Auto Caching** | Browser Cache API for instant reloads |
|
||||
| 🎯 **Same API** | Compatible with native `ruvector-onnx-embeddings` |
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-onnx-embeddings-wasm
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Node.js (Sequential)
|
||||
|
||||
```javascript
|
||||
import { createEmbedder, similarity, embed } from 'ruvector-onnx-embeddings-wasm/loader';
|
||||
|
||||
// One-liner similarity
|
||||
const score = await similarity("I love dogs", "I adore puppies");
|
||||
console.log(score); // ~0.85
|
||||
|
||||
// One-liner embedding
|
||||
const embedding = await embed("Hello world");
|
||||
console.log(embedding.length); // 384
|
||||
|
||||
// Full control
|
||||
const embedder = await createEmbedder('bge-small-en-v1.5');
|
||||
const emb1 = embedder.embedOne("First text");
|
||||
const emb2 = embedder.embedOne("Second text");
|
||||
```
|
||||
|
||||
### Node.js (Parallel - 3.8x faster)
|
||||
|
||||
```javascript
|
||||
import { ParallelEmbedder } from 'ruvector-onnx-embeddings-wasm/parallel';
|
||||
|
||||
// Initialize with worker threads
|
||||
const embedder = new ParallelEmbedder({ numWorkers: 4 });
|
||||
await embedder.init('all-MiniLM-L6-v2');
|
||||
|
||||
// Batch embed with parallel processing
|
||||
const texts = [
|
||||
"Machine learning is transforming technology",
|
||||
"Deep learning uses neural networks",
|
||||
"Natural language processing understands text",
|
||||
"Computer vision analyzes images"
|
||||
];
|
||||
const embeddings = await embedder.embedBatch(texts);
|
||||
|
||||
// Compute similarity
|
||||
const sim = await embedder.similarity("I love Rust", "Rust is great");
|
||||
console.log(sim); // ~0.85
|
||||
|
||||
// Cleanup
|
||||
await embedder.shutdown();
|
||||
```
|
||||
|
||||
### Browser (ES Modules)
|
||||
|
||||
```html
|
||||
<script type="module">
|
||||
import init, { WasmEmbedder } from 'https://unpkg.com/ruvector-onnx-embeddings-wasm/ruvector_onnx_embeddings_wasm.js';
|
||||
import { createEmbedder } from 'https://unpkg.com/ruvector-onnx-embeddings-wasm/loader.js';
|
||||
|
||||
// Initialize WASM
|
||||
await init();
|
||||
|
||||
// Create embedder (downloads model automatically)
|
||||
const embedder = await createEmbedder('all-MiniLM-L6-v2');
|
||||
|
||||
// Generate embeddings
|
||||
const embedding = embedder.embedOne("Hello, world!");
|
||||
console.log("Dimension:", embedding.length); // 384
|
||||
|
||||
// Compute similarity
|
||||
const sim = embedder.similarity("I love Rust", "Rust is great");
|
||||
console.log("Similarity:", sim.toFixed(4)); // ~0.85
|
||||
</script>
|
||||
```
|
||||
|
||||
### Cloudflare Workers
|
||||
|
||||
```javascript
|
||||
import { WasmEmbedder, WasmEmbedderConfig } from 'ruvector-onnx-embeddings-wasm';
|
||||
|
||||
export default {
|
||||
async fetch(request, env) {
|
||||
// Load model from R2 or KV
|
||||
const modelBytes = await env.MODELS.get('model.onnx', 'arrayBuffer');
|
||||
const tokenizerJson = await env.MODELS.get('tokenizer.json', 'text');
|
||||
|
||||
const embedder = new WasmEmbedder(
|
||||
new Uint8Array(modelBytes),
|
||||
tokenizerJson
|
||||
);
|
||||
|
||||
const { text } = await request.json();
|
||||
const embedding = embedder.embedOne(text);
|
||||
|
||||
return Response.json({
|
||||
embedding: Array.from(embedding),
|
||||
dimension: embedding.length
|
||||
});
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
## Available Models
|
||||
|
||||
| Model | Dimension | Size | Speed | Quality | Best For |
|
||||
|-------|-----------|------|-------|---------|----------|
|
||||
| **all-MiniLM-L6-v2** ⭐ | 384 | 23MB | ⚡⚡⚡ | ⭐⭐⭐ | Default, fast |
|
||||
| **all-MiniLM-L12-v2** | 384 | 33MB | ⚡⚡ | ⭐⭐⭐⭐ | Better quality |
|
||||
| **bge-small-en-v1.5** | 384 | 33MB | ⚡⚡⚡ | ⭐⭐⭐⭐ | State-of-the-art |
|
||||
| **bge-base-en-v1.5** | 768 | 110MB | ⚡ | ⭐⭐⭐⭐⭐ | Best quality |
|
||||
| **e5-small-v2** | 384 | 33MB | ⚡⚡⚡ | ⭐⭐⭐⭐ | Search/retrieval |
|
||||
| **gte-small** | 384 | 33MB | ⚡⚡⚡ | ⭐⭐⭐⭐ | Multilingual |
|
||||
|
||||
## Performance
|
||||
|
||||
### Sequential vs Parallel (Node.js)
|
||||
|
||||
| Batch Size | Sequential | Parallel (4 workers) | Speedup |
|
||||
|------------|------------|----------------------|---------|
|
||||
| 4 texts | 1,573ms | 410ms | **3.83x** |
|
||||
| 8 texts | 3,105ms | 861ms | **3.61x** |
|
||||
| 12 texts | 4,667ms | 1,235ms | **3.78x** |
|
||||
|
||||
*Tested on 16-core machine with all-MiniLM-L6-v2*
|
||||
|
||||
### Environment Benchmarks
|
||||
|
||||
| Environment | Mode | Throughput | Latency |
|
||||
|-------------|------|------------|---------|
|
||||
| Node.js 20 | Sequential | ~2.5 texts/sec | ~390ms |
|
||||
| Node.js 20 | Parallel (4w) | ~9.7 texts/sec | ~103ms |
|
||||
| Chrome (M1 Mac) | Sequential | ~50 texts/sec | ~20ms |
|
||||
| Firefox (M1 Mac) | Sequential | ~45 texts/sec | ~22ms |
|
||||
| Cloudflare Workers | Sequential | ~30 texts/sec | ~33ms |
|
||||
| Deno | Sequential | ~75 texts/sec | ~13ms |
|
||||
|
||||
*Browser benchmarks with smaller inputs; Node.js with full model warmup*
|
||||
|
||||
### SIMD Support
|
||||
|
||||
WASM SIMD128 is enabled by default and provides:
|
||||
- Smaller binary size (180KB reduction)
|
||||
- Vectorized tensor operations
|
||||
- Supported in Chrome 91+, Firefox 89+, Safari 16.4+, Node.js 16+
|
||||
|
||||
```javascript
|
||||
import { simd_available } from 'ruvector-onnx-embeddings-wasm';
|
||||
console.log('SIMD enabled:', simd_available()); // true
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### ModelLoader
|
||||
|
||||
```javascript
|
||||
import { ModelLoader, MODELS, DEFAULT_MODEL } from 'ruvector-onnx-embeddings-wasm/loader';
|
||||
|
||||
// List available models
|
||||
console.log(ModelLoader.listModels());
|
||||
|
||||
// Load with progress
|
||||
const loader = new ModelLoader({
|
||||
cache: true,
|
||||
onProgress: ({ loaded, total, percent }) => console.log(`${percent}%`)
|
||||
});
|
||||
|
||||
const { modelBytes, tokenizerJson, config } = await loader.loadModel('all-MiniLM-L6-v2');
|
||||
```
|
||||
|
||||
### WasmEmbedder
|
||||
|
||||
```typescript
|
||||
class WasmEmbedder {
|
||||
constructor(modelBytes: Uint8Array, tokenizerJson: string);
|
||||
|
||||
static withConfig(
|
||||
modelBytes: Uint8Array,
|
||||
tokenizerJson: string,
|
||||
config: WasmEmbedderConfig
|
||||
): WasmEmbedder;
|
||||
|
||||
embedOne(text: string): Float32Array;
|
||||
embedBatch(texts: string[]): Float32Array;
|
||||
similarity(text1: string, text2: string): number;
|
||||
|
||||
dimension(): number;
|
||||
maxLength(): number;
|
||||
}
|
||||
```
|
||||
|
||||
### WasmEmbedderConfig
|
||||
|
||||
```typescript
|
||||
class WasmEmbedderConfig {
|
||||
constructor();
|
||||
setMaxLength(length: number): WasmEmbedderConfig;
|
||||
setNormalize(normalize: boolean): WasmEmbedderConfig;
|
||||
setPooling(strategy: number): WasmEmbedderConfig;
|
||||
// 0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen, 4=LastToken
|
||||
}
|
||||
```
|
||||
|
||||
### ParallelEmbedder (Node.js only)
|
||||
|
||||
```typescript
|
||||
class ParallelEmbedder {
|
||||
constructor(options?: { numWorkers?: number });
|
||||
|
||||
init(modelName?: string): Promise<void>;
|
||||
embedOne(text: string): Promise<Float32Array>;
|
||||
embedBatch(texts: string[]): Promise<number[][]>;
|
||||
similarity(text1: string, text2: string): Promise<number>;
|
||||
shutdown(): Promise<void>;
|
||||
}
|
||||
```
|
||||
|
||||
### Utility Functions
|
||||
|
||||
```typescript
|
||||
function cosineSimilarity(a: Float32Array, b: Float32Array): number;
|
||||
function normalizeL2(embedding: Float32Array): Float32Array;
|
||||
function version(): string;
|
||||
function simd_available(): boolean;
|
||||
```
|
||||
|
||||
### Convenience Functions
|
||||
|
||||
```typescript
|
||||
// One-liner embedding
|
||||
async function embed(text: string | string[], modelName?: string): Promise<Float32Array>;
|
||||
|
||||
// One-liner similarity
|
||||
async function similarity(text1: string, text2: string, modelName?: string): Promise<number>;
|
||||
|
||||
// Create configured embedder
|
||||
async function createEmbedder(modelName?: string): Promise<WasmEmbedder>;
|
||||
```
|
||||
|
||||
## Pooling Strategies
|
||||
|
||||
| Value | Strategy | Description |
|
||||
|-------|----------|-------------|
|
||||
| 0 | **Mean** | Average all tokens (default, recommended) |
|
||||
| 1 | **Cls** | Use [CLS] token only (BERT-style) |
|
||||
| 2 | **Max** | Max pooling across tokens |
|
||||
| 3 | **MeanSqrtLen** | Mean normalized by sqrt(length) |
|
||||
| 4 | **LastToken** | Last token (decoder models) |
|
||||
|
||||
## Comparison: Native vs WASM
|
||||
|
||||
| Aspect | Native (`ort`) | WASM (`tract`) |
|
||||
|--------|----------------|----------------|
|
||||
| Speed | ⚡⚡⚡ Native | ⚡⚡ ~2-3x slower |
|
||||
| Browser | ❌ | ✅ |
|
||||
| Edge Workers | ❌ | ✅ |
|
||||
| Parallel | Multi-process | Worker threads |
|
||||
| GPU | CUDA, TensorRT | ❌ |
|
||||
| Bundle Size | ~50MB | ~7.4MB |
|
||||
| SIMD | AVX2/AVX-512 | SIMD128 |
|
||||
| Portability | Platform-specific | Universal |
|
||||
|
||||
**Use native** for: servers, high throughput, GPU acceleration
|
||||
**Use WASM** for: browsers, edge, portability, simpler deployment
|
||||
|
||||
## Building from Source
|
||||
|
||||
```bash
|
||||
# Install wasm-pack
|
||||
cargo install wasm-pack
|
||||
|
||||
# Build for Node.js with SIMD
|
||||
RUSTFLAGS='-C target-feature=+simd128' wasm-pack build --target nodejs --release
|
||||
|
||||
# Build for web with SIMD
|
||||
RUSTFLAGS='-C target-feature=+simd128' wasm-pack build --target web --release
|
||||
|
||||
# Build for bundlers (webpack, vite) with SIMD
|
||||
RUSTFLAGS='-C target-feature=+simd128' wasm-pack build --target bundler --release
|
||||
|
||||
# Build without SIMD (for older browsers)
|
||||
wasm-pack build --target web --release
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Semantic Search
|
||||
|
||||
```javascript
|
||||
import { createEmbedder, cosineSimilarity } from 'ruvector-onnx-embeddings-wasm/loader';
|
||||
|
||||
const embedder = await createEmbedder();
|
||||
|
||||
// Index documents
|
||||
const docs = ["Rust is fast", "Python is easy", "JavaScript runs everywhere"];
|
||||
const embeddings = docs.map(d => embedder.embedOne(d));
|
||||
|
||||
// Search
|
||||
const query = embedder.embedOne("Which language is performant?");
|
||||
const scores = embeddings.map((e, i) => ({
|
||||
doc: docs[i],
|
||||
score: cosineSimilarity(query, e)
|
||||
}));
|
||||
scores.sort((a, b) => b.score - a.score);
|
||||
console.log(scores[0]); // { doc: "Rust is fast", score: 0.82 }
|
||||
```
|
||||
|
||||
### Batch Processing with Parallel Workers
|
||||
|
||||
```javascript
|
||||
import { ParallelEmbedder } from 'ruvector-onnx-embeddings-wasm/parallel';
|
||||
|
||||
const embedder = new ParallelEmbedder({ numWorkers: 4 });
|
||||
await embedder.init();
|
||||
|
||||
// Process large datasets efficiently
|
||||
const documents = loadDocuments(); // Array of 1000+ texts
|
||||
const batchSize = 100;
|
||||
|
||||
for (let i = 0; i < documents.length; i += batchSize) {
|
||||
const batch = documents.slice(i, i + batchSize);
|
||||
const embeddings = await embedder.embedBatch(batch);
|
||||
await saveEmbeddings(embeddings);
|
||||
}
|
||||
|
||||
await embedder.shutdown();
|
||||
```
|
||||
|
||||
### RAG (Retrieval-Augmented Generation)
|
||||
|
||||
```javascript
|
||||
// Build knowledge base
|
||||
const knowledge = [
|
||||
"RuVector is a vector database",
|
||||
"Embeddings capture semantic meaning",
|
||||
// ... more docs
|
||||
];
|
||||
const knowledgeEmbeddings = knowledge.map(k => embedder.embedOne(k));
|
||||
|
||||
// Retrieve relevant context for LLM
|
||||
function getContext(query, topK = 3) {
|
||||
const queryEmb = embedder.embedOne(query);
|
||||
const scores = knowledgeEmbeddings.map((e, i) => ({
|
||||
text: knowledge[i],
|
||||
score: cosineSimilarity(queryEmb, e)
|
||||
}));
|
||||
return scores.sort((a, b) => b.score - a.score).slice(0, topK);
|
||||
}
|
||||
```
|
||||
|
||||
### Text Clustering
|
||||
|
||||
```javascript
|
||||
const texts = [
|
||||
"Machine learning is amazing",
|
||||
"Deep learning uses neural networks",
|
||||
"I love pizza",
|
||||
"Italian food is delicious"
|
||||
];
|
||||
|
||||
const embeddings = texts.map(t => embedder.embedOne(t));
|
||||
// Use k-means or hierarchical clustering on embeddings
|
||||
```
|
||||
|
||||
## Browser Compatibility
|
||||
|
||||
| Browser | SIMD | Status |
|
||||
|---------|------|--------|
|
||||
| Chrome 91+ | ✅ | Full support |
|
||||
| Firefox 89+ | ✅ | Full support |
|
||||
| Safari 16.4+ | ✅ | Full support |
|
||||
| Edge 91+ | ✅ | Full support |
|
||||
| Node.js 16+ | ✅ | Full support |
|
||||
| Deno | ✅ | Full support |
|
||||
| Cloudflare Workers | ✅ | Full support |
|
||||
|
||||
## Related Packages
|
||||
|
||||
| Package | Runtime | Use Case |
|
||||
|---------|---------|----------|
|
||||
| [ruvector-onnx-embeddings](https://crates.io/crates/ruvector-onnx-embeddings) | Native | High-performance servers |
|
||||
| **ruvector-onnx-embeddings-wasm** | WASM | Browsers, edge, portable |
|
||||
|
||||
## Changelog
|
||||
|
||||
### v0.1.2
|
||||
- Added `ParallelEmbedder` for multi-threaded batch processing (3.8x speedup)
|
||||
- Worker threads support for Node.js environments
|
||||
|
||||
### v0.1.1
|
||||
- Enabled WASM SIMD128 for vectorized operations
|
||||
- Added `simd_available()` function
|
||||
- Reduced binary size by 180KB
|
||||
|
||||
### v0.1.0
|
||||
- Initial release
|
||||
- HuggingFace model loader with caching
|
||||
- Browser and Node.js support
|
||||
- 6 pre-configured models
|
||||
|
||||
## License
|
||||
|
||||
MIT License - See [LICENSE](../../LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
<b>Part of the RuVector ecosystem</b><br>
|
||||
High-performance vector operations in Rust
|
||||
</p>
|
||||
351
vendor/ruvector/examples/onnx-embeddings-wasm/loader.js
vendored
Normal file
351
vendor/ruvector/examples/onnx-embeddings-wasm/loader.js
vendored
Normal file
@@ -0,0 +1,351 @@
|
||||
/**
|
||||
* Model Loader for RuVector ONNX Embeddings WASM
|
||||
*
|
||||
* Provides easy loading of pre-trained models from HuggingFace Hub
|
||||
*/
|
||||
|
||||
/**
|
||||
* Pre-configured models with their HuggingFace URLs
|
||||
*/
|
||||
export const MODELS = {
|
||||
// Sentence Transformers - Small & Fast
|
||||
'all-MiniLM-L6-v2': {
|
||||
name: 'all-MiniLM-L6-v2',
|
||||
dimension: 384,
|
||||
maxLength: 256,
|
||||
size: '23MB',
|
||||
description: 'Fast, general-purpose embeddings',
|
||||
model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx',
|
||||
tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json',
|
||||
},
|
||||
'all-MiniLM-L12-v2': {
|
||||
name: 'all-MiniLM-L12-v2',
|
||||
dimension: 384,
|
||||
maxLength: 256,
|
||||
size: '33MB',
|
||||
description: 'Better quality, balanced speed',
|
||||
model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/onnx/model.onnx',
|
||||
tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer.json',
|
||||
},
|
||||
|
||||
// BGE Models - State of the art
|
||||
'bge-small-en-v1.5': {
|
||||
name: 'bge-small-en-v1.5',
|
||||
dimension: 384,
|
||||
maxLength: 512,
|
||||
size: '33MB',
|
||||
description: 'State-of-the-art small model',
|
||||
model: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx',
|
||||
tokenizer: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/tokenizer.json',
|
||||
},
|
||||
'bge-base-en-v1.5': {
|
||||
name: 'bge-base-en-v1.5',
|
||||
dimension: 768,
|
||||
maxLength: 512,
|
||||
size: '110MB',
|
||||
description: 'Best overall quality',
|
||||
model: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx',
|
||||
tokenizer: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json',
|
||||
},
|
||||
|
||||
// E5 Models - Microsoft
|
||||
'e5-small-v2': {
|
||||
name: 'e5-small-v2',
|
||||
dimension: 384,
|
||||
maxLength: 512,
|
||||
size: '33MB',
|
||||
description: 'Excellent for search & retrieval',
|
||||
model: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/onnx/model.onnx',
|
||||
tokenizer: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/tokenizer.json',
|
||||
},
|
||||
|
||||
// GTE Models - Alibaba
|
||||
'gte-small': {
|
||||
name: 'gte-small',
|
||||
dimension: 384,
|
||||
maxLength: 512,
|
||||
size: '33MB',
|
||||
description: 'Good multilingual support',
|
||||
model: 'https://huggingface.co/thenlper/gte-small/resolve/main/onnx/model.onnx',
|
||||
tokenizer: 'https://huggingface.co/thenlper/gte-small/resolve/main/tokenizer.json',
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Default model for quick start
|
||||
*/
|
||||
export const DEFAULT_MODEL = 'all-MiniLM-L6-v2';
|
||||
|
||||
/**
|
||||
* Model loader with caching support
|
||||
*/
|
||||
export class ModelLoader {
|
||||
constructor(options = {}) {
|
||||
this.cache = options.cache ?? true;
|
||||
this.cacheStorage = options.cacheStorage ?? 'ruvector-models';
|
||||
this.onProgress = options.onProgress ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a pre-configured model by name
|
||||
* @param {string} modelName - Model name from MODELS
|
||||
* @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string, config: object}>}
|
||||
*/
|
||||
async loadModel(modelName = DEFAULT_MODEL) {
|
||||
const modelConfig = MODELS[modelName];
|
||||
if (!modelConfig) {
|
||||
throw new Error(`Unknown model: ${modelName}. Available: ${Object.keys(MODELS).join(', ')}`);
|
||||
}
|
||||
|
||||
console.log(`Loading model: ${modelConfig.name} (${modelConfig.size})`);
|
||||
|
||||
const [modelBytes, tokenizerJson] = await Promise.all([
|
||||
this.fetchWithCache(modelConfig.model, `${modelName}-model.onnx`, 'arraybuffer'),
|
||||
this.fetchWithCache(modelConfig.tokenizer, `${modelName}-tokenizer.json`, 'text'),
|
||||
]);
|
||||
|
||||
return {
|
||||
modelBytes: new Uint8Array(modelBytes),
|
||||
tokenizerJson,
|
||||
config: modelConfig,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Load model from custom URLs
|
||||
* @param {string} modelUrl - URL to ONNX model
|
||||
* @param {string} tokenizerUrl - URL to tokenizer.json
|
||||
* @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
|
||||
*/
|
||||
async loadFromUrls(modelUrl, tokenizerUrl) {
|
||||
const [modelBytes, tokenizerJson] = await Promise.all([
|
||||
this.fetchWithCache(modelUrl, null, 'arraybuffer'),
|
||||
this.fetchWithCache(tokenizerUrl, null, 'text'),
|
||||
]);
|
||||
|
||||
return {
|
||||
modelBytes: new Uint8Array(modelBytes),
|
||||
tokenizerJson,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Load model from local files (Node.js)
|
||||
* @param {string} modelPath - Path to ONNX model
|
||||
* @param {string} tokenizerPath - Path to tokenizer.json
|
||||
* @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
|
||||
*/
|
||||
async loadFromFiles(modelPath, tokenizerPath) {
|
||||
// Node.js environment
|
||||
if (typeof process !== 'undefined' && process.versions?.node) {
|
||||
const fs = await import('fs/promises');
|
||||
const [modelBytes, tokenizerJson] = await Promise.all([
|
||||
fs.readFile(modelPath),
|
||||
fs.readFile(tokenizerPath, 'utf8'),
|
||||
]);
|
||||
return {
|
||||
modelBytes: new Uint8Array(modelBytes),
|
||||
tokenizerJson,
|
||||
};
|
||||
}
|
||||
throw new Error('loadFromFiles is only available in Node.js');
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch with optional caching (uses Cache API in browsers)
|
||||
*/
|
||||
async fetchWithCache(url, cacheKey, responseType) {
|
||||
// Try cache first (browser only)
|
||||
if (this.cache && typeof caches !== 'undefined' && cacheKey) {
|
||||
try {
|
||||
const cache = await caches.open(this.cacheStorage);
|
||||
const cached = await cache.match(cacheKey);
|
||||
if (cached) {
|
||||
console.log(` Cache hit: ${cacheKey}`);
|
||||
return responseType === 'arraybuffer'
|
||||
? await cached.arrayBuffer()
|
||||
: await cached.text();
|
||||
}
|
||||
} catch (e) {
|
||||
// Cache API not available, continue with fetch
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch from network
|
||||
console.log(` Downloading: ${url}`);
|
||||
const response = await this.fetchWithProgress(url);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
|
||||
// Clone for caching
|
||||
const responseClone = response.clone();
|
||||
|
||||
// Cache the response (browser only)
|
||||
if (this.cache && typeof caches !== 'undefined' && cacheKey) {
|
||||
try {
|
||||
const cache = await caches.open(this.cacheStorage);
|
||||
await cache.put(cacheKey, responseClone);
|
||||
} catch (e) {
|
||||
// Cache write failed, continue
|
||||
}
|
||||
}
|
||||
|
||||
return responseType === 'arraybuffer'
|
||||
? await response.arrayBuffer()
|
||||
: await response.text();
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch with progress reporting
|
||||
*/
|
||||
async fetchWithProgress(url) {
|
||||
const response = await fetch(url);
|
||||
|
||||
if (!this.onProgress || !response.body) {
|
||||
return response;
|
||||
}
|
||||
|
||||
const contentLength = response.headers.get('content-length');
|
||||
if (!contentLength) {
|
||||
return response;
|
||||
}
|
||||
|
||||
const total = parseInt(contentLength, 10);
|
||||
let loaded = 0;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const chunks = [];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
chunks.push(value);
|
||||
loaded += value.length;
|
||||
|
||||
this.onProgress({
|
||||
loaded,
|
||||
total,
|
||||
percent: Math.round((loaded / total) * 100),
|
||||
});
|
||||
}
|
||||
|
||||
const body = new Uint8Array(loaded);
|
||||
let position = 0;
|
||||
for (const chunk of chunks) {
|
||||
body.set(chunk, position);
|
||||
position += chunk.length;
|
||||
}
|
||||
|
||||
return new Response(body, {
|
||||
headers: response.headers,
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached models
|
||||
*/
|
||||
async clearCache() {
|
||||
if (typeof caches !== 'undefined') {
|
||||
await caches.delete(this.cacheStorage);
|
||||
console.log('Model cache cleared');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List available models
|
||||
*/
|
||||
static listModels() {
|
||||
return Object.entries(MODELS).map(([key, config]) => ({
|
||||
id: key,
|
||||
...config,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Quick helper to create an embedder with a pre-configured model
|
||||
*
|
||||
* @example
|
||||
* ```javascript
|
||||
* import { createEmbedder } from './loader.js';
|
||||
*
|
||||
* const embedder = await createEmbedder('all-MiniLM-L6-v2');
|
||||
* const embedding = embedder.embedOne("Hello world");
|
||||
* ```
|
||||
*/
|
||||
export async function createEmbedder(modelName = DEFAULT_MODEL, wasmModule = null) {
|
||||
// Import WASM module if not provided
|
||||
if (!wasmModule) {
|
||||
wasmModule = await import('./pkg/ruvector_onnx_embeddings_wasm.js');
|
||||
// Only call default() for web/bundler targets (not nodejs)
|
||||
if (typeof wasmModule.default === 'function') {
|
||||
await wasmModule.default();
|
||||
}
|
||||
}
|
||||
|
||||
const loader = new ModelLoader();
|
||||
const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName);
|
||||
|
||||
const embedderConfig = new wasmModule.WasmEmbedderConfig()
|
||||
.setMaxLength(config.maxLength)
|
||||
.setNormalize(true)
|
||||
.setPooling(0); // Mean pooling
|
||||
|
||||
const embedder = wasmModule.WasmEmbedder.withConfig(
|
||||
modelBytes,
|
||||
tokenizerJson,
|
||||
embedderConfig
|
||||
);
|
||||
|
||||
return embedder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Quick helper for one-off embedding (loads model, embeds, returns)
|
||||
*
|
||||
* @example
|
||||
* ```javascript
|
||||
* import { embed } from './loader.js';
|
||||
*
|
||||
* const embedding = await embed("Hello world");
|
||||
* const embeddings = await embed(["Hello", "World"]);
|
||||
* ```
|
||||
*/
|
||||
export async function embed(text, modelName = DEFAULT_MODEL) {
|
||||
const embedder = await createEmbedder(modelName);
|
||||
|
||||
if (Array.isArray(text)) {
|
||||
return embedder.embedBatch(text);
|
||||
}
|
||||
return embedder.embedOne(text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Quick helper for similarity comparison
|
||||
*
|
||||
* @example
|
||||
* ```javascript
|
||||
* import { similarity } from './loader.js';
|
||||
*
|
||||
* const score = await similarity("I love dogs", "I adore puppies");
|
||||
* console.log(score); // ~0.85
|
||||
* ```
|
||||
*/
|
||||
export async function similarity(text1, text2, modelName = DEFAULT_MODEL) {
|
||||
const embedder = await createEmbedder(modelName);
|
||||
return embedder.similarity(text1, text2);
|
||||
}
|
||||
|
||||
export default {
|
||||
MODELS,
|
||||
DEFAULT_MODEL,
|
||||
ModelLoader,
|
||||
createEmbedder,
|
||||
embed,
|
||||
similarity,
|
||||
};
|
||||
168
vendor/ruvector/examples/onnx-embeddings-wasm/parallel-embedder.mjs
vendored
Normal file
168
vendor/ruvector/examples/onnx-embeddings-wasm/parallel-embedder.mjs
vendored
Normal file
@@ -0,0 +1,168 @@
|
||||
/**
|
||||
* Parallel ONNX Embedder using Worker Threads
|
||||
*
|
||||
* Distributes embedding work across multiple CPU cores for true parallelism.
|
||||
*/
|
||||
import { Worker } from 'worker_threads';
|
||||
import { cpus } from 'os';
|
||||
import { fileURLToPath } from 'url';
|
||||
import { dirname, join } from 'path';
|
||||
import { ModelLoader, DEFAULT_MODEL } from './loader.js';
|
||||
|
||||
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||
|
||||
export class ParallelEmbedder {
|
||||
constructor(options = {}) {
|
||||
this.numWorkers = options.numWorkers || Math.max(1, cpus().length - 1);
|
||||
this.workers = [];
|
||||
this.readyWorkers = [];
|
||||
this.pendingRequests = new Map();
|
||||
this.requestId = 0;
|
||||
this.initialized = false;
|
||||
this.modelData = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the parallel embedder with a model
|
||||
*/
|
||||
async init(modelName = DEFAULT_MODEL) {
|
||||
if (this.initialized) return;
|
||||
|
||||
console.log(`🚀 Initializing ${this.numWorkers} worker threads...`);
|
||||
|
||||
// Load model once in main thread
|
||||
const loader = new ModelLoader({ cache: false });
|
||||
const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName);
|
||||
|
||||
// Store as transferable data
|
||||
this.modelData = {
|
||||
modelBytes: Array.from(modelBytes), // Convert to regular array for transfer
|
||||
tokenizerJson,
|
||||
config
|
||||
};
|
||||
this.dimension = config.dimension;
|
||||
|
||||
// Spawn workers
|
||||
const workerPromises = [];
|
||||
for (let i = 0; i < this.numWorkers; i++) {
|
||||
workerPromises.push(this._spawnWorker(i));
|
||||
}
|
||||
await Promise.all(workerPromises);
|
||||
|
||||
this.initialized = true;
|
||||
console.log(`✅ ${this.numWorkers} workers ready`);
|
||||
}
|
||||
|
||||
async _spawnWorker(index) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const worker = new Worker(join(__dirname, 'parallel-worker.mjs'), {
|
||||
workerData: this.modelData
|
||||
});
|
||||
|
||||
worker.on('message', (msg) => {
|
||||
if (msg.type === 'ready') {
|
||||
this.readyWorkers.push(worker);
|
||||
resolve();
|
||||
} else if (msg.type === 'result') {
|
||||
const { id, embeddings } = msg;
|
||||
const pending = this.pendingRequests.get(id);
|
||||
if (pending) {
|
||||
pending.resolve(embeddings);
|
||||
this.pendingRequests.delete(id);
|
||||
this.readyWorkers.push(worker);
|
||||
this._processQueue();
|
||||
}
|
||||
} else if (msg.type === 'error') {
|
||||
const { id, error } = msg;
|
||||
const pending = this.pendingRequests.get(id);
|
||||
if (pending) {
|
||||
pending.reject(new Error(error));
|
||||
this.pendingRequests.delete(id);
|
||||
this.readyWorkers.push(worker);
|
||||
this._processQueue();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
worker.on('error', reject);
|
||||
this.workers.push(worker);
|
||||
});
|
||||
}
|
||||
|
||||
// NOTE(review): intentionally a no-op for now — embedBatch dispatches
// directly and falls back to round-robin when no worker is idle, so there
// is no separate request queue to drain yet. Kept as a hook for a future
// queueing implementation.
_processQueue() {
  // Process any queued requests when workers become available
}
|
||||
|
||||
/**
|
||||
* Embed texts in parallel across worker threads
|
||||
*/
|
||||
async embedBatch(texts) {
|
||||
if (!this.initialized) {
|
||||
throw new Error('ParallelEmbedder not initialized. Call init() first.');
|
||||
}
|
||||
|
||||
// Split texts into chunks for each worker
|
||||
const chunkSize = Math.ceil(texts.length / this.numWorkers);
|
||||
const chunks = [];
|
||||
for (let i = 0; i < texts.length; i += chunkSize) {
|
||||
chunks.push(texts.slice(i, i + chunkSize));
|
||||
}
|
||||
|
||||
// Send to workers in parallel
|
||||
const promises = chunks.map((chunk, i) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const id = this.requestId++;
|
||||
const worker = this.readyWorkers.shift() || this.workers[i % this.workers.length];
|
||||
|
||||
this.pendingRequests.set(id, { resolve, reject });
|
||||
worker.postMessage({ type: 'embed', id, texts: chunk });
|
||||
});
|
||||
});
|
||||
|
||||
// Wait for all results
|
||||
const results = await Promise.all(promises);
|
||||
|
||||
// Flatten results
|
||||
return results.flat();
|
||||
}
|
||||
|
||||
/**
|
||||
* Embed a single text (uses one worker)
|
||||
*/
|
||||
async embedOne(text) {
|
||||
const results = await this.embedBatch([text]);
|
||||
return new Float32Array(results[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute similarity between two texts
|
||||
*/
|
||||
async similarity(text1, text2) {
|
||||
const [emb1, emb2] = await this.embedBatch([text1, text2]);
|
||||
return this._cosineSimilarity(emb1, emb2);
|
||||
}
|
||||
|
||||
_cosineSimilarity(a, b) {
|
||||
let dot = 0, normA = 0, normB = 0;
|
||||
for (let i = 0; i < a.length; i++) {
|
||||
dot += a[i] * b[i];
|
||||
normA += a[i] * a[i];
|
||||
normB += b[i] * b[i];
|
||||
}
|
||||
return dot / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown all workers
|
||||
*/
|
||||
async shutdown() {
|
||||
for (const worker of this.workers) {
|
||||
worker.postMessage({ type: 'shutdown' });
|
||||
}
|
||||
this.workers = [];
|
||||
this.readyWorkers = [];
|
||||
this.initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
export default ParallelEmbedder;
|
||||
37
vendor/ruvector/examples/onnx-embeddings-wasm/parallel-worker.mjs
vendored
Normal file
37
vendor/ruvector/examples/onnx-embeddings-wasm/parallel-worker.mjs
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
/**
 * Worker thread for parallel ONNX embedding generation
 *
 * Loaded by ParallelEmbedder via `new Worker(...)`. Receives the model
 * payload (modelBytes, tokenizerJson, config) through `workerData`, builds
 * one WASM embedder per thread at load time, then serves 'embed' messages
 * until it receives 'shutdown'.
 */
import { parentPort, workerData } from 'worker_threads';
import { WasmEmbedder, WasmEmbedderConfig } from './pkg/ruvector_onnx_embeddings_wasm.js';

// Initialize embedder with model data passed from main thread
const { modelBytes, tokenizerJson, config } = workerData;

// Pooling code 0 selects mean pooling; embeddings are L2-normalized.
const embedderConfig = new WasmEmbedderConfig()
  .setMaxLength(config.maxLength)
  .setNormalize(true)
  .setPooling(0);

// modelBytes arrives as a plain number array (structured-clone friendly);
// the WASM API requires a Uint8Array.
const embedder = WasmEmbedder.withConfig(
  new Uint8Array(modelBytes),
  tokenizerJson,
  embedderConfig
);

// Listen for texts to embed
parentPort.on('message', (message) => {
  if (message.type === 'embed') {
    const { id, texts } = message;
    try {
      // Convert each embedding back to a plain array so the result can be
      // posted across the thread boundary.
      const embeddings = texts.map(text => Array.from(embedder.embedOne(text)));
      parentPort.postMessage({ type: 'result', id, embeddings });
    } catch (error) {
      // Report failures by request id; the main thread rejects the matching promise.
      parentPort.postMessage({ type: 'error', id, error: error.message });
    }
  } else if (message.type === 'shutdown') {
    process.exit(0);
  }
});

// Signal ready
parentPort.postMessage({ type: 'ready' });
|
||||
213
vendor/ruvector/examples/onnx-embeddings-wasm/src/embedder.rs
vendored
Normal file
213
vendor/ruvector/examples/onnx-embeddings-wasm/src/embedder.rs
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
//! Main WASM embedder implementation
|
||||
|
||||
use crate::error::{Result, WasmEmbeddingError};
|
||||
use crate::model::TractModel;
|
||||
use crate::pooling::{cosine_similarity, normalize_l2, PoolingStrategy};
|
||||
use crate::tokenizer::WasmTokenizer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Configuration for the WASM embedder.
///
/// Built fluently from JavaScript via the `setMaxLength`/`setNormalize`/
/// `setPooling` builder methods; fields are hidden from the JS bindings
/// (`#[wasm_bindgen(skip)]`) and only reachable through those setters.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmEmbedderConfig {
    /// Maximum sequence length (tokens beyond this are truncated; shorter
    /// inputs are padded up to this length before inference)
    #[wasm_bindgen(skip)]
    pub max_length: usize,
    /// Pooling strategy used to collapse token embeddings into one vector
    #[wasm_bindgen(skip)]
    pub pooling: PoolingStrategy,
    /// Whether to L2 normalize embeddings after pooling
    #[wasm_bindgen(skip)]
    pub normalize: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmEmbedderConfig {
|
||||
/// Create a new configuration
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set maximum sequence length
|
||||
#[wasm_bindgen(js_name = setMaxLength)]
|
||||
pub fn set_max_length(mut self, max_length: usize) -> Self {
|
||||
self.max_length = max_length;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to normalize embeddings
|
||||
#[wasm_bindgen(js_name = setNormalize)]
|
||||
pub fn set_normalize(mut self, normalize: bool) -> Self {
|
||||
self.normalize = normalize;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set pooling strategy (0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen, 4=LastToken)
|
||||
#[wasm_bindgen(js_name = setPooling)]
|
||||
pub fn set_pooling(mut self, pooling: u8) -> Self {
|
||||
self.pooling = match pooling {
|
||||
0 => PoolingStrategy::Mean,
|
||||
1 => PoolingStrategy::Cls,
|
||||
2 => PoolingStrategy::Max,
|
||||
3 => PoolingStrategy::MeanSqrtLen,
|
||||
4 => PoolingStrategy::LastToken,
|
||||
_ => PoolingStrategy::Mean,
|
||||
};
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WasmEmbedderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_length: 256,
|
||||
pooling: PoolingStrategy::Mean,
|
||||
normalize: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WASM-compatible embedder using Tract for inference.
///
/// Owns the loaded model, the tokenizer, and the active configuration.
#[wasm_bindgen]
pub struct WasmEmbedder {
    // ONNX model compiled into a runnable Tract plan
    model: TractModel,
    // HuggingFace tokenizer wrapper (pads/truncates to config.max_length)
    tokenizer: WasmTokenizer,
    // Pooling / normalization / length settings
    config: WasmEmbedderConfig,
    // Embedding width; starts from the model's default and is corrected
    // from the first inference's output size (see embed_one_internal)
    hidden_size: usize,
}
|
||||
|
||||
#[wasm_bindgen]
impl WasmEmbedder {
    /// Create a new embedder from model and tokenizer bytes
    ///
    /// Uses the default configuration (256 tokens, mean pooling, normalized).
    ///
    /// # Arguments
    /// * `model_bytes` - ONNX model file bytes
    /// * `tokenizer_json` - Tokenizer JSON configuration
    ///
    /// # Errors
    /// Returns a `JsValue` error string if the model or tokenizer fails to parse.
    #[wasm_bindgen(constructor)]
    pub fn new(model_bytes: &[u8], tokenizer_json: &str) -> std::result::Result<WasmEmbedder, JsValue> {
        Self::with_config(model_bytes, tokenizer_json, WasmEmbedderConfig::default())
    }

    /// Create embedder with custom configuration
    ///
    /// Internal errors are stringified into `JsValue` so they surface as
    /// JavaScript exceptions.
    #[wasm_bindgen(js_name = withConfig)]
    pub fn with_config(
        model_bytes: &[u8],
        tokenizer_json: &str,
        config: WasmEmbedderConfig,
    ) -> std::result::Result<WasmEmbedder, JsValue> {
        let model = TractModel::from_bytes(model_bytes, config.max_length)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let tokenizer = WasmTokenizer::from_json(tokenizer_json, config.max_length)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // Initial guess; refined after the first inference.
        let hidden_size = model.hidden_size();

        Ok(Self {
            model,
            tokenizer,
            config,
            hidden_size,
        })
    }

    /// Generate embedding for a single text
    ///
    /// Returns a vector of `dimension()` floats. Takes `&mut self` because
    /// the first inference may update the detected hidden size.
    #[wasm_bindgen(js_name = embedOne)]
    pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, JsValue> {
        self.embed_one_internal(text)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Generate embeddings for multiple texts
    ///
    /// Returns one flat `Vec<f32>` containing `texts.len() * dimension()`
    /// values (embeddings concatenated in input order).
    #[wasm_bindgen(js_name = embedBatch)]
    pub fn embed_batch(&mut self, texts: Vec<String>) -> std::result::Result<Vec<f32>, JsValue> {
        let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
        self.embed_batch_internal(&refs)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Compute similarity between two texts
    ///
    /// Embeds both texts and returns their cosine similarity.
    #[wasm_bindgen]
    pub fn similarity(&mut self, text1: &str, text2: &str) -> std::result::Result<f32, JsValue> {
        let emb1 = self.embed_one_internal(text1)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let emb2 = self.embed_one_internal(text2)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(cosine_similarity(&emb1, &emb2))
    }

    /// Get the embedding dimension
    ///
    /// NOTE: before the first inference this is the model's default guess;
    /// it is corrected from the actual output size afterwards.
    #[wasm_bindgen]
    pub fn dimension(&self) -> usize {
        self.hidden_size
    }

    /// Get maximum sequence length
    #[wasm_bindgen(js_name = maxLength)]
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }
}
|
||||
|
||||
// Internal implementation (not exposed through wasm_bindgen)
impl WasmEmbedder {
    /// Tokenize, run inference, pool, and optionally normalize one text.
    fn embed_one_internal(&mut self, text: &str) -> Result<Vec<f32>> {
        // Tokenize (padded/truncated to config.max_length)
        let encoded = self.tokenizer.encode(text)?;
        let attention_mask = encoded.attention_mask.clone();

        // Run inference; raw_output is the flattened output tensor
        let raw_output = self.model.run(&encoded)?;

        // Determine hidden size from output.
        // NOTE(review): this assumes the output is the flattened token-level
        // tensor [1, seq_len, hidden_size]; for a model whose output is
        // already pooled ([1, hidden_size]) the division would infer a wrong
        // hidden size — verify against the models actually shipped.
        let seq_len = self.config.max_length;
        if raw_output.len() >= seq_len {
            let detected_hidden = raw_output.len() / seq_len;
            if detected_hidden != self.hidden_size && detected_hidden > 0 {
                // Cache the corrected size for subsequent calls and dimension().
                self.hidden_size = detected_hidden;
                self.model.set_hidden_size(detected_hidden);
            }
        }

        // Apply pooling (mask-aware; see PoolingStrategy::apply)
        let mut embedding = self.config.pooling.apply(
            &raw_output,
            &attention_mask,
            self.hidden_size,
        );

        // Normalize if configured
        if self.config.normalize {
            normalize_l2(&mut embedding);
        }

        Ok(embedding)
    }

    /// Embed several texts sequentially, concatenating the results into one
    /// flat vector of texts.len() * hidden_size values.
    fn embed_batch_internal(&mut self, texts: &[&str]) -> Result<Vec<f32>> {
        let mut all_embeddings = Vec::with_capacity(texts.len() * self.hidden_size);

        for text in texts {
            let embedding = self.embed_one_internal(text)?;
            all_embeddings.extend(embedding);
        }

        Ok(all_embeddings)
    }
}
|
||||
|
||||
/// Compute cosine similarity between two embedding vectors (JS-friendly).
///
/// Thin wrapper over `cosine_similarity` exported as `cosineSimilarity`;
/// takes owned `Vec<f32>` because that is what wasm_bindgen marshals from
/// JS arrays.
#[wasm_bindgen(js_name = cosineSimilarity)]
pub fn js_cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> f32 {
    cosine_similarity(&a, &b)
}
|
||||
|
||||
/// L2 normalize an embedding vector (JS-friendly).
///
/// Exported as `normalizeL2`; normalizes in place and returns the vector,
/// since WASM/JS ownership makes in-place mutation of a JS array impractical.
#[wasm_bindgen(js_name = normalizeL2)]
pub fn js_normalize_l2(mut embedding: Vec<f32>) -> Vec<f32> {
    normalize_l2(&mut embedding);
    embedding
}
|
||||
62
vendor/ruvector/examples/onnx-embeddings-wasm/src/error.rs
vendored
Normal file
62
vendor/ruvector/examples/onnx-embeddings-wasm/src/error.rs
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
//! Error types for WASM embeddings
|
||||
|
||||
use thiserror::Error;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Result type for WASM embedding operations
pub type Result<T> = std::result::Result<T, WasmEmbeddingError>;

/// Errors that can occur during WASM embedding operations.
///
/// Every variant carries a human-readable message; errors cross the WASM
/// boundary as plain strings.
#[derive(Error, Debug)]
pub enum WasmEmbeddingError {
    // ONNX parsing / optimization / loading failures
    #[error("Model error: {0}")]
    Model(String),

    // Tokenizer construction or encoding failures
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),

    // Failures while running the model or extracting its output
    #[error("Inference error: {0}")]
    Inference(String),

    // Caller-supplied input rejected before inference
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    // serde / JSON (de)serialization failures
    #[error("Serialization error: {0}")]
    Serialization(String),
}
|
||||
|
||||
// Convenience constructors so call sites can write
// `WasmEmbeddingError::model(format!(...))` instead of spelling the variant.
impl WasmEmbeddingError {
    /// Build a `Model` error from any string-like message.
    pub fn model(msg: impl Into<String>) -> Self {
        Self::Model(msg.into())
    }

    /// Build a `Tokenizer` error from any string-like message.
    pub fn tokenizer(msg: impl Into<String>) -> Self {
        Self::Tokenizer(msg.into())
    }

    /// Build an `Inference` error from any string-like message.
    pub fn inference(msg: impl Into<String>) -> Self {
        Self::Inference(msg.into())
    }

    /// Build an `InvalidInput` error from any string-like message.
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput(msg.into())
    }
}
|
||||
|
||||
// Let `?` propagate embedding errors straight into wasm_bindgen signatures:
// the error is flattened to its Display string on the JS side.
impl From<WasmEmbeddingError> for JsValue {
    fn from(err: WasmEmbeddingError) -> Self {
        JsValue::from_str(&err.to_string())
    }
}

// Tract errors surface as Model errors (they arise while loading/running
// the ONNX graph).
impl From<tract_onnx::prelude::TractError> for WasmEmbeddingError {
    fn from(err: tract_onnx::prelude::TractError) -> Self {
        Self::Model(err.to_string())
    }
}

// serde_json errors surface as Serialization errors.
impl From<serde_json::Error> for WasmEmbeddingError {
    fn from(err: serde_json::Error) -> Self {
        Self::Serialization(err.to_string())
    }
}
|
||||
66
vendor/ruvector/examples/onnx-embeddings-wasm/src/lib.rs
vendored
Normal file
66
vendor/ruvector/examples/onnx-embeddings-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
//! # RuVector ONNX Embeddings - WASM Edition
|
||||
//!
|
||||
//! WASM-compatible embedding generation using Tract for inference.
|
||||
//! Runs in browsers, Cloudflare Workers, Deno, and any WASM runtime.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Browser Support**: Generate embeddings directly in the browser
|
||||
//! - **Edge Computing**: Deploy to Cloudflare Workers, Vercel Edge, etc.
|
||||
//! - **Portable**: Single WASM binary, no platform-specific dependencies
|
||||
//! - **Same API**: Compatible with the native ruvector-onnx-embeddings crate
|
||||
//!
|
||||
//! ## Usage (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import init, { WasmEmbedder } from 'ruvector-onnx-embeddings-wasm';
|
||||
//!
|
||||
//! await init();
|
||||
//!
|
||||
//! // Load model from bytes
|
||||
//! const modelBytes = await fetch('/model.onnx').then(r => r.arrayBuffer());
|
||||
//! const tokenizerJson = await fetch('/tokenizer.json').then(r => r.text());
|
||||
//!
|
||||
//! const embedder = new WasmEmbedder(new Uint8Array(modelBytes), tokenizerJson);
|
||||
//!
|
||||
//! // Generate embeddings
|
||||
//! const embedding = embedder.embed_one("Hello, world!");
|
||||
//! console.log("Embedding dimension:", embedding.length);
|
||||
//!
|
||||
//! // Compute similarity
|
||||
//! const similarity = embedder.similarity("I love Rust", "Rust is great");
|
||||
//! console.log("Similarity:", similarity);
|
||||
//! ```
|
||||
|
||||
mod embedder;
|
||||
mod error;
|
||||
mod model;
|
||||
mod pooling;
|
||||
mod tokenizer;
|
||||
|
||||
pub use embedder::{WasmEmbedder, WasmEmbedderConfig};
|
||||
pub use error::WasmEmbeddingError;
|
||||
pub use pooling::PoolingStrategy;
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Initialize panic hook for better error messages in WASM.
///
/// Runs automatically when the WASM module is instantiated
/// (`#[wasm_bindgen(start)]`). The hook is only installed when the crate
/// is built with the optional `console_error_panic_hook` feature; otherwise
/// this is a no-op.
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
|
||||
|
||||
/// Get the library version.
///
/// The value is baked in at compile time from this crate's Cargo.toml
/// (`CARGO_PKG_VERSION`).
#[wasm_bindgen]
pub fn version() -> String {
    env!("CARGO_PKG_VERSION").to_string()
}
|
||||
|
||||
/// Check if SIMD is available (for performance info)
/// Returns true if compiled with WASM SIMD128 support
///
/// NOTE: this is a compile-time check of the build's target features, not a
/// runtime probe of the host — a binary built without simd128 returns false
/// even on a SIMD-capable runtime.
#[wasm_bindgen]
pub fn simd_available() -> bool {
    // Check if compiled with SIMD128 target feature
    cfg!(target_feature = "simd128")
}
|
||||
116
vendor/ruvector/examples/onnx-embeddings-wasm/src/model.rs
vendored
Normal file
116
vendor/ruvector/examples/onnx-embeddings-wasm/src/model.rs
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
//! Tract-based ONNX model for WASM inference
|
||||
|
||||
use crate::error::{Result, WasmEmbeddingError};
|
||||
use crate::tokenizer::EncodedInput;
|
||||
use tract_onnx::prelude::*;
|
||||
|
||||
/// Tract ONNX model wrapper for WASM.
///
/// Holds the optimized, runnable inference plan plus the current best guess
/// of the model's hidden (embedding) size.
pub struct TractModel {
    // Fully typed, optimized, runnable Tract execution plan
    model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    // Embedding width; defaults to 384 at load and may be corrected via
    // set_hidden_size after the first inference
    hidden_size: usize,
}
|
||||
|
||||
impl TractModel {
    /// Load model from ONNX bytes.
    ///
    /// Pins all three inputs to shape [1, max_seq_length] (i64), optimizes
    /// the graph, and compiles it into a runnable plan.
    ///
    /// NOTE(review): input facts are set positionally (0, 1, 2) — this
    /// assumes the ONNX graph declares its inputs in the order
    /// input_ids, attention_mask, token_type_ids. Verify for each model.
    pub fn from_bytes(bytes: &[u8], max_seq_length: usize) -> Result<Self> {
        // Parse ONNX model
        let model = tract_onnx::onnx()
            .model_for_read(&mut std::io::Cursor::new(bytes))
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to parse ONNX: {}", e)))?;

        // Set input shapes for optimization
        // Standard transformer inputs: [batch, seq_len]
        let batch = 1usize;
        let seq_len = max_seq_length;

        let model = model
            .with_input_fact(
                0,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?
            .with_input_fact(
                1,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?
            .with_input_fact(
                2,
                InferenceFact::dt_shape(i64::datum_type(), tvec![batch, seq_len]),
            )?;

        // Optimize the model
        let model = model
            .into_optimized()
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to optimize: {}", e)))?;

        let model = model
            .into_runnable()
            .map_err(|e| WasmEmbeddingError::model(format!("Failed to make runnable: {}", e)))?;

        // Default hidden size (will be determined from output)
        // NOTE(review): 384 matches MiniLM-class models; callers correct it
        // via set_hidden_size after the first inference.
        let hidden_size = 384;

        Ok(Self { model, hidden_size })
    }

    /// Run inference on encoded input.
    ///
    /// Builds [1, seq_len] i64 tensors for input_ids / attention_mask /
    /// token_type_ids, runs the plan, and returns the first output tensor
    /// flattened to a Vec<f32>.
    pub fn run(&self, input: &EncodedInput) -> Result<Vec<f32>> {
        let seq_len = input.input_ids.len();

        // Create input tensors
        let input_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.input_ids.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();

        let attention_mask: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.attention_mask.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();

        let token_type_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
            (1, seq_len),
            input.token_type_ids.clone(),
        )
        .map_err(|e| WasmEmbeddingError::inference(e.to_string()))?
        .into();

        // Run inference (same positional order the input facts were set in)
        let inputs = tvec![
            input_ids.into(),
            attention_mask.into(),
            token_type_ids.into()
        ];

        let outputs = self
            .model
            .run(inputs)
            .map_err(|e| WasmEmbeddingError::inference(format!("Inference failed: {}", e)))?;

        // Extract output tensor
        // Output is typically [batch, seq_len, hidden_size] or [batch, hidden_size]
        let output = outputs
            .first()
            .ok_or_else(|| WasmEmbeddingError::inference("No output tensor"))?;

        let output_array = output
            .to_array_view::<f32>()
            .map_err(|e| WasmEmbeddingError::inference(format!("Failed to extract output: {}", e)))?;

        // Flatten and return
        Ok(output_array.iter().copied().collect())
    }

    /// Get the hidden size
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Update hidden size (called after first inference)
    pub fn set_hidden_size(&mut self, size: usize) {
        self.hidden_size = size;
    }
}
|
||||
181
vendor/ruvector/examples/onnx-embeddings-wasm/src/pooling.rs
vendored
Normal file
181
vendor/ruvector/examples/onnx-embeddings-wasm/src/pooling.rs
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
//! Pooling strategies for converting token embeddings to sentence embeddings
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Strategy for pooling token embeddings into a single sentence embedding.
///
/// Selected from JavaScript by numeric code via
/// `WasmEmbedderConfig::setPooling` (0=Mean, 1=Cls, 2=Max, 3=MeanSqrtLen,
/// 4=LastToken).
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub enum PoolingStrategy {
    /// Average all token embeddings (most common)
    #[default]
    Mean,
    /// Use only the [CLS] token embedding
    Cls,
    /// Take the maximum value across all tokens for each dimension
    Max,
    /// Mean pooling normalized by sqrt of sequence length
    MeanSqrtLen,
    /// Use the last token embedding (for decoder models)
    LastToken,
}
|
||||
|
||||
impl PoolingStrategy {
|
||||
/// Apply pooling to token embeddings
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embeddings` - Token embeddings [seq_len, hidden_size]
|
||||
/// * `attention_mask` - Attention mask [seq_len]
|
||||
///
|
||||
/// # Returns
|
||||
/// Pooled embedding [hidden_size]
|
||||
pub fn apply(&self, embeddings: &[f32], attention_mask: &[i64], hidden_size: usize) -> Vec<f32> {
|
||||
let seq_len = attention_mask.len();
|
||||
|
||||
if embeddings.is_empty() || hidden_size == 0 {
|
||||
return vec![0.0; hidden_size];
|
||||
}
|
||||
|
||||
match self {
|
||||
PoolingStrategy::Mean => {
|
||||
self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len)
|
||||
}
|
||||
PoolingStrategy::Cls => {
|
||||
// First token (CLS)
|
||||
embeddings[..hidden_size].to_vec()
|
||||
}
|
||||
PoolingStrategy::Max => {
|
||||
self.max_pooling(embeddings, attention_mask, hidden_size, seq_len)
|
||||
}
|
||||
PoolingStrategy::MeanSqrtLen => {
|
||||
let mut pooled = self.mean_pooling(embeddings, attention_mask, hidden_size, seq_len);
|
||||
let valid_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
|
||||
let scale = 1.0 / valid_tokens.sqrt();
|
||||
for v in &mut pooled {
|
||||
*v *= scale;
|
||||
}
|
||||
pooled
|
||||
}
|
||||
PoolingStrategy::LastToken => {
|
||||
// Find last valid token
|
||||
let last_idx = attention_mask
|
||||
.iter()
|
||||
.rposition(|&m| m == 1)
|
||||
.unwrap_or(0);
|
||||
let start = last_idx * hidden_size;
|
||||
embeddings[start..start + hidden_size].to_vec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn mean_pooling(
|
||||
&self,
|
||||
embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
hidden_size: usize,
|
||||
seq_len: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut pooled = vec![0.0f32; hidden_size];
|
||||
let mut count = 0.0f32;
|
||||
|
||||
for (i, &mask) in attention_mask.iter().enumerate() {
|
||||
if mask == 1 && i < seq_len {
|
||||
let start = i * hidden_size;
|
||||
if start + hidden_size <= embeddings.len() {
|
||||
for (j, v) in pooled.iter_mut().enumerate() {
|
||||
*v += embeddings[start + j];
|
||||
}
|
||||
count += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0.0 {
|
||||
for v in &mut pooled {
|
||||
*v /= count;
|
||||
}
|
||||
}
|
||||
|
||||
pooled
|
||||
}
|
||||
|
||||
fn max_pooling(
|
||||
&self,
|
||||
embeddings: &[f32],
|
||||
attention_mask: &[i64],
|
||||
hidden_size: usize,
|
||||
seq_len: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut pooled = vec![f32::NEG_INFINITY; hidden_size];
|
||||
|
||||
for (i, &mask) in attention_mask.iter().enumerate() {
|
||||
if mask == 1 && i < seq_len {
|
||||
let start = i * hidden_size;
|
||||
if start + hidden_size <= embeddings.len() {
|
||||
for (j, v) in pooled.iter_mut().enumerate() {
|
||||
*v = v.max(embeddings[start + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Replace -inf with 0 for dimensions with no valid tokens
|
||||
for v in &mut pooled {
|
||||
if v.is_infinite() {
|
||||
*v = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
pooled
|
||||
}
|
||||
}
|
||||
|
||||
/// L2 normalize a vector in place.
///
/// Zero vectors are left untouched (no division by zero).
pub fn normalize_l2(embedding: &mut [f32]) {
    let sum_sq: f32 = embedding.iter().map(|x| x * x).sum();
    let norm = sum_sq.sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|v| *v /= norm);
    }
}
|
||||
|
||||
/// Compute cosine similarity between two embeddings.
///
/// Returns 0.0 for mismatched lengths, empty input, or zero-norm vectors.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Identical vectors have similarity 1; orthogonal vectors have 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 1e-6);
    }

    // A 3-4-5 triangle normalizes to (0.6, 0.8).
    #[test]
    fn test_normalize_l2() {
        let mut v = vec![3.0, 4.0];
        normalize_l2(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }
}
|
||||
114
vendor/ruvector/examples/onnx-embeddings-wasm/src/tokenizer.rs
vendored
Normal file
114
vendor/ruvector/examples/onnx-embeddings-wasm/src/tokenizer.rs
vendored
Normal file
@@ -0,0 +1,114 @@
|
||||
//! Tokenizer wrapper for WASM embedding generation
|
||||
|
||||
use crate::error::{Result, WasmEmbeddingError};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
/// Tokenizer wrapper that handles text encoding.
///
/// Wraps a HuggingFace `tokenizers::Tokenizer` and enforces a fixed output
/// length (truncation + padding) so every encoding matches the model's
/// expected input shape.
pub struct WasmTokenizer {
    // Underlying HuggingFace tokenizer
    tokenizer: Tokenizer,
    // Fixed output length every encoding is padded/truncated to
    max_length: usize,
}
|
||||
|
||||
/// Encoded text ready for model inference.
///
/// All three vectors have the same (padded) length.
#[derive(Debug, Clone)]
pub struct EncodedInput {
    // Token ids (i64, as required by the ONNX model inputs)
    pub input_ids: Vec<i64>,
    // 1 for real tokens, 0 for padding
    pub attention_mask: Vec<i64>,
    // Segment ids (all zeros for single-sentence input)
    pub token_type_ids: Vec<i64>,
}
|
||||
|
||||
impl WasmTokenizer {
|
||||
/// Create a new tokenizer from JSON configuration
|
||||
pub fn from_json(json: &str, max_length: usize) -> Result<Self> {
|
||||
let tokenizer = Tokenizer::from_bytes(json.as_bytes())
|
||||
.map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
tokenizer,
|
||||
max_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create tokenizer from raw bytes
|
||||
pub fn from_bytes(bytes: &[u8], max_length: usize) -> Result<Self> {
|
||||
let tokenizer = Tokenizer::from_bytes(bytes)
|
||||
.map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
tokenizer,
|
||||
max_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode a single text
|
||||
pub fn encode(&self, text: &str) -> Result<EncodedInput> {
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(text, true)
|
||||
.map_err(|e| WasmEmbeddingError::tokenizer(e.to_string()))?;
|
||||
|
||||
let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
|
||||
let mut attention_mask: Vec<i64> =
|
||||
encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
|
||||
let mut token_type_ids: Vec<i64> =
|
||||
encoding.get_type_ids().iter().map(|&t| t as i64).collect();
|
||||
|
||||
// Truncate if necessary
|
||||
if input_ids.len() > self.max_length {
|
||||
input_ids.truncate(self.max_length);
|
||||
attention_mask.truncate(self.max_length);
|
||||
token_type_ids.truncate(self.max_length);
|
||||
}
|
||||
|
||||
// Pad if necessary
|
||||
while input_ids.len() < self.max_length {
|
||||
input_ids.push(0);
|
||||
attention_mask.push(0);
|
||||
token_type_ids.push(0);
|
||||
}
|
||||
|
||||
Ok(EncodedInput {
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode multiple texts with padding to the same length
|
||||
pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<EncodedInput>> {
|
||||
texts.iter().map(|text| self.encode(text)).collect()
|
||||
}
|
||||
|
||||
/// Get the maximum sequence length
|
||||
pub fn max_length(&self) -> usize {
|
||||
self.max_length
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Basic tokenizer JSON for testing: a minimal whitespace/WordLevel
    // tokenizer with a four-entry vocabulary.
    const TEST_TOKENIZER: &str = r#"{
        "version": "1.0",
        "truncation": null,
        "padding": null,
        "added_tokens": [],
        "normalizer": null,
        "pre_tokenizer": {"type": "Whitespace"},
        "post_processor": null,
        "decoder": null,
        "model": {
            "type": "WordLevel",
            "vocab": {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3},
            "unk_token": "[UNK]"
        }
    }"#;

    // Smoke test: a valid tokenizer JSON constructs without error.
    #[test]
    fn test_tokenizer_creation() {
        let tokenizer = WasmTokenizer::from_json(TEST_TOKENIZER, 128);
        assert!(tokenizer.is_ok());
    }
}
|
||||
142
vendor/ruvector/examples/onnx-embeddings-wasm/test-full.mjs
vendored
Normal file
142
vendor/ruvector/examples/onnx-embeddings-wasm/test-full.mjs
vendored
Normal file
@@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* Full end-to-end test with model download
|
||||
*
|
||||
* Downloads all-MiniLM-L6-v2 and runs embedding tests
|
||||
*/
|
||||
|
||||
import { ModelLoader, MODELS, DEFAULT_MODEL } from './loader.js';
|
||||
import {
|
||||
WasmEmbedder,
|
||||
WasmEmbedderConfig,
|
||||
cosineSimilarity,
|
||||
} from './pkg/ruvector_onnx_embeddings_wasm.js';
|
||||
|
||||
console.log('🧪 RuVector ONNX Embeddings WASM - Full E2E Test\n');
|
||||
console.log('='.repeat(60));
|
||||
|
||||
// List available models
|
||||
console.log('\n📦 Available Models:');
|
||||
ModelLoader.listModels().forEach(m => {
|
||||
const isDefault = m.id === DEFAULT_MODEL ? ' ⭐ DEFAULT' : '';
|
||||
console.log(` • ${m.id} (${m.dimension}d, ${m.size})${isDefault}`);
|
||||
console.log(` ${m.description}`);
|
||||
});
|
||||
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log(`\n🔄 Loading model: ${DEFAULT_MODEL}...\n`);
|
||||
|
||||
// Load model with progress
|
||||
const loader = new ModelLoader({
|
||||
cache: false, // Disable cache for testing
|
||||
onProgress: ({ loaded, total, percent }) => {
|
||||
process.stdout.write(`\r Progress: ${percent}% (${(loaded/1024/1024).toFixed(1)}MB / ${(total/1024/1024).toFixed(1)}MB)`);
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
const { modelBytes, tokenizerJson, config } = await loader.loadModel(DEFAULT_MODEL);
|
||||
console.log('\n');
|
||||
console.log(` ✅ Model loaded: ${config.name}`);
|
||||
console.log(` ✅ Model size: ${(modelBytes.length / 1024 / 1024).toFixed(2)} MB`);
|
||||
console.log(` ✅ Tokenizer size: ${(tokenizerJson.length / 1024).toFixed(2)} KB`);
|
||||
|
||||
// Create embedder
|
||||
console.log('\n🔧 Creating embedder...');
|
||||
const embedderConfig = new WasmEmbedderConfig()
|
||||
.setMaxLength(config.maxLength)
|
||||
.setNormalize(true)
|
||||
.setPooling(0);
|
||||
|
||||
const embedder = WasmEmbedder.withConfig(modelBytes, tokenizerJson, embedderConfig);
|
||||
console.log(` ✅ Embedder created`);
|
||||
console.log(` ✅ Dimension: ${embedder.dimension()}`);
|
||||
console.log(` ✅ Max length: ${embedder.maxLength()}`);
|
||||
|
||||
// Test 1: Single embedding
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log('\n📝 Test 1: Single Embedding');
|
||||
const text1 = "The quick brown fox jumps over the lazy dog.";
|
||||
console.log(` Input: "${text1}"`);
|
||||
|
||||
const start1 = performance.now();
|
||||
const embedding1 = embedder.embedOne(text1);
|
||||
const time1 = performance.now() - start1;
|
||||
|
||||
console.log(` ✅ Output dimension: ${embedding1.length}`);
|
||||
console.log(` ✅ First 5 values: [${Array.from(embedding1.slice(0, 5)).map(v => v.toFixed(4)).join(', ')}]`);
|
||||
console.log(` ✅ Time: ${time1.toFixed(2)}ms`);
|
||||
|
||||
// Test 2: Semantic similarity
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log('\n📝 Test 2: Semantic Similarity');
|
||||
|
||||
const pairs = [
|
||||
["I love programming in Rust", "Rust is my favorite programming language"],
|
||||
["The weather is nice today", "It's sunny outside"],
|
||||
["I love programming in Rust", "The weather is nice today"],
|
||||
["Machine learning is fascinating", "AI and deep learning are interesting"],
|
||||
];
|
||||
|
||||
for (const [a, b] of pairs) {
|
||||
const start = performance.now();
|
||||
const sim = embedder.similarity(a, b);
|
||||
const time = performance.now() - start;
|
||||
|
||||
const label = sim > 0.5 ? '🟢 Similar' : '🔴 Different';
|
||||
console.log(`\n "${a.substring(0, 30)}..."`);
|
||||
console.log(` "${b.substring(0, 30)}..."`);
|
||||
console.log(` ${label}: ${sim.toFixed(4)} (${time.toFixed(1)}ms)`);
|
||||
}
|
||||
|
||||
// Test 3: Batch embedding
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log('\n📝 Test 3: Batch Embedding');
|
||||
|
||||
const texts = [
|
||||
"Artificial intelligence is transforming technology.",
|
||||
"Machine learning models learn from data.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Vector databases enable semantic search.",
|
||||
];
|
||||
|
||||
console.log(` Embedding ${texts.length} texts...`);
|
||||
const start3 = performance.now();
|
||||
const batchEmbeddings = embedder.embedBatch(texts);
|
||||
const time3 = performance.now() - start3;
|
||||
|
||||
const embeddingDim = embedder.dimension();
|
||||
const numEmbeddings = batchEmbeddings.length / embeddingDim;
|
||||
|
||||
console.log(` ✅ Total values: ${batchEmbeddings.length}`);
|
||||
console.log(` ✅ Embeddings: ${numEmbeddings} x ${embeddingDim}d`);
|
||||
console.log(` ✅ Time: ${time3.toFixed(2)}ms (${(time3/texts.length).toFixed(2)}ms per text)`);
|
||||
|
||||
// Compute pairwise similarities
|
||||
console.log('\n Pairwise similarities:');
|
||||
for (let i = 0; i < numEmbeddings; i++) {
|
||||
for (let j = i + 1; j < numEmbeddings; j++) {
|
||||
const emb_i = batchEmbeddings.slice(i * embeddingDim, (i + 1) * embeddingDim);
|
||||
const emb_j = batchEmbeddings.slice(j * embeddingDim, (j + 1) * embeddingDim);
|
||||
const sim = cosineSimilarity(emb_i, emb_j);
|
||||
console.log(` [${i}] vs [${j}]: ${sim.toFixed(4)}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Summary
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log('\n✅ All tests passed!');
|
||||
console.log('='.repeat(60));
|
||||
|
||||
console.log('\n📊 Performance Summary:');
|
||||
console.log(` • Model: ${config.name}`);
|
||||
console.log(` • Dimension: ${embeddingDim}`);
|
||||
console.log(` • Single embed: ~${time1.toFixed(0)}ms`);
|
||||
console.log(` • Batch (4 texts): ~${time3.toFixed(0)}ms`);
|
||||
console.log(` • Throughput: ~${(1000 / (time3/texts.length)).toFixed(0)} texts/sec`);
|
||||
|
||||
} catch (error) {
|
||||
console.error('\n❌ Error:', error.message);
|
||||
console.error(error.stack);
|
||||
process.exit(1);
|
||||
}
|
||||
121
vendor/ruvector/examples/onnx-embeddings-wasm/test-parallel.mjs
vendored
Normal file
121
vendor/ruvector/examples/onnx-embeddings-wasm/test-parallel.mjs
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env node
/**
 * Benchmark: Sequential vs Parallel ONNX Embeddings
 *
 * Compares single-threaded embedding (embedOne in a loop) against a
 * worker-thread pool (ParallelEmbedder.embedBatch) across batch sizes,
 * then verifies both paths produce matching embeddings.
 */
import { cpus } from 'os';
import { ParallelEmbedder } from './parallel-embedder.mjs';
import { createEmbedder } from './loader.js';

console.log('🧪 Parallel vs Sequential ONNX Embeddings Benchmark\n');
console.log(`CPU Cores: ${cpus().length}`);
console.log('='.repeat(60));

// Test data - various batch sizes are sliced from this list below.
const testTexts = [
  "Machine learning is transforming technology",
  "Deep learning uses neural networks",
  "Natural language processing understands text",
  "Computer vision analyzes images",
  "Reinforcement learning learns from rewards",
  "Generative AI creates new content",
  "Vector databases enable semantic search",
  "Embeddings capture semantic meaning",
  "Transformers revolutionized NLP",
  "BERT is a popular language model",
  "GPT generates human-like text",
  "RAG combines retrieval and generation",
];
/**
 * Measure average wall-clock time (ms) to embed `texts` one at a time.
 *
 * @param {{embedOne: (text: string) => unknown}} embedder - anything with a
 *   synchronous `embedOne`; results are discarded (timing only).
 * @param {string[]} texts - inputs embedded sequentially each iteration.
 * @param {number} [iterations=3] - repetitions to average over; must be >= 1.
 * @returns {Promise<number>} mean elapsed ms per iteration.
 */
async function benchmarkSequential(embedder, texts, iterations = 3) {
  const times = [];
  for (let i = 0; i < iterations; i++) {
    const start = performance.now();
    for (const text of texts) {
      embedder.embedOne(text); // result intentionally unused
    }
    times.push(performance.now() - start);
  }
  return times.reduce((a, b) => a + b) / times.length;
}
/**
 * Measure average wall-clock time (ms) to embed `texts` as one batch call.
 *
 * @param {{embedBatch: (texts: string[]) => Promise<unknown>}} embedder -
 *   anything with an awaitable `embedBatch`; results are discarded.
 * @param {string[]} texts - inputs embedded as a single batch per iteration.
 * @param {number} [iterations=3] - repetitions to average over; must be >= 1.
 * @returns {Promise<number>} mean elapsed ms per iteration.
 */
async function benchmarkParallel(embedder, texts, iterations = 3) {
  const times = [];
  for (let i = 0; i < iterations; i++) {
    const start = performance.now();
    await embedder.embedBatch(texts); // result intentionally unused
    times.push(performance.now() - start);
  }
  return times.reduce((a, b) => a + b) / times.length;
}
/**
 * Entry point: benchmarks sequential vs parallel embedding at several
 * batch sizes, verifies both paths agree on one text, then shuts the
 * worker pool down. Exits with code 1 on any failure.
 */
async function main() {
  try {
    // Initialize sequential embedder
    console.log('\n📦 Loading model for sequential test...');
    const seqEmbedder = await createEmbedder();
    console.log('✅ Sequential embedder ready\n');

    // Warm up so first-inference overhead doesn't skew the benchmark
    seqEmbedder.embedOne("warmup");

    // Initialize parallel embedder (worker count capped at 4)
    console.log('📦 Initializing parallel embedder...');
    const parEmbedder = new ParallelEmbedder({ numWorkers: Math.min(4, cpus().length) });
    await parEmbedder.init();

    // Benchmark different batch sizes
    for (const batchSize of [4, 8, 12]) {
      const texts = testTexts.slice(0, batchSize);

      console.log(`\n${'='.repeat(60)}`);
      console.log(`📊 Batch Size: ${batchSize} texts`);
      console.log('='.repeat(60));

      // Sequential benchmark
      console.log('\n⏱️  Sequential (single-threaded)...');
      const seqTime = await benchmarkSequential(seqEmbedder, texts);
      console.log(`  Time: ${seqTime.toFixed(1)}ms`);
      console.log(`  Per text: ${(seqTime / batchSize).toFixed(1)}ms`);

      // Parallel benchmark
      console.log('\n⏱️  Parallel (worker threads)...');
      const parTime = await benchmarkParallel(parEmbedder, texts);
      console.log(`  Time: ${parTime.toFixed(1)}ms`);
      console.log(`  Per text: ${(parTime / batchSize).toFixed(1)}ms`);

      // Speedup (>1 means the worker pool won)
      const speedup = seqTime / parTime;
      const icon = speedup > 1.2 ? '🚀' : speedup > 1 ? '✅' : '⚠️';
      console.log(`\n${icon} Speedup: ${speedup.toFixed(2)}x`);
    }

    // Verify correctness: both embedders should produce (near-)identical
    // vectors for the same input.
    console.log(`\n${'='.repeat(60)}`);
    console.log('🔍 Verifying correctness...');
    console.log('='.repeat(60));

    const testText = "Vector databases are awesome";
    const seqEmb = seqEmbedder.embedOne(testText);
    const parEmb = await parEmbedder.embedOne(testText);

    // Mean absolute element-wise difference between the two embeddings
    let diff = 0;
    for (let i = 0; i < seqEmb.length; i++) {
      diff += Math.abs(seqEmb[i] - parEmb[i]);
    }
    const avgDiff = diff / seqEmb.length;
    console.log(`\nEmbedding difference: ${avgDiff.toExponential(4)}`);
    console.log(avgDiff < 1e-6 ? '✅ Embeddings match!' : '⚠️ Embeddings differ');

    // Cleanup: terminate workers so the process can exit
    await parEmbedder.shutdown();
    console.log('\n✅ Benchmark complete!');

  } catch (error) {
    console.error('❌ Error:', error.message);
    console.error(error.stack);
    process.exit(1);
  }
}

main();
72
vendor/ruvector/examples/onnx-embeddings-wasm/test.mjs
vendored
Normal file
72
vendor/ruvector/examples/onnx-embeddings-wasm/test.mjs
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env node
/**
 * Test script to validate WASM embeddings package works.
 *
 * Exercises the pure utility exports (version, SIMD flag, cosine
 * similarity, L2 normalization, config builder); full embedder tests
 * require model/tokenizer files and are out of scope here. Throws on
 * any failed check, so a non-zero exit code signals failure.
 */

import {
  version,
  simd_available,
  cosineSimilarity,
  normalizeL2,
  WasmEmbedderConfig
} from './pkg/ruvector_onnx_embeddings_wasm.js';

console.log('🧪 RuVector ONNX Embeddings WASM - Validation Test\n');

// Test 1: Version check
console.log('Test 1: Version check');
const ver = version();
console.log(`  ✅ Version: ${ver}`);

// Test 2: SIMD availability (informational — either value is OK)
console.log('\nTest 2: SIMD availability');
const simd = simd_available();
console.log(`  ✅ SIMD available: ${simd}`);

// Test 3: Cosine similarity — identical vectors => 1.0, orthogonal => 0.0
console.log('\nTest 3: Cosine similarity utility');
const a = new Float32Array([1.0, 0.0, 0.0]);
const b = new Float32Array([1.0, 0.0, 0.0]);
const c = new Float32Array([0.0, 1.0, 0.0]);

const simSame = cosineSimilarity(a, b);
const simDiff = cosineSimilarity(a, c);

console.log(`  ✅ Same vectors similarity: ${simSame.toFixed(4)} (expected: 1.0)`);
console.log(`  ✅ Orthogonal vectors similarity: ${simDiff.toFixed(4)} (expected: 0.0)`);

if (Math.abs(simSame - 1.0) > 0.0001) {
  throw new Error('Cosine similarity failed for same vectors');
}
if (Math.abs(simDiff) > 0.0001) {
  throw new Error('Cosine similarity failed for orthogonal vectors');
}

// Test 4: L2 normalization — [3,4] has norm 5, so result must be unit length
console.log('\nTest 4: L2 normalization utility');
const unnormalized = new Float32Array([3.0, 4.0]);
const normalized = normalizeL2(unnormalized);

const norm = Math.sqrt(normalized[0] ** 2 + normalized[1] ** 2);
console.log(`  ✅ Normalized vector: [${normalized[0].toFixed(4)}, ${normalized[1].toFixed(4)}]`);
console.log(`  ✅ L2 norm: ${norm.toFixed(4)} (expected: 1.0)`);

if (Math.abs(norm - 1.0) > 0.0001) {
  throw new Error('L2 normalization failed');
}

// Test 5: Config creation — builder methods must chain without throwing
console.log('\nTest 5: WasmEmbedderConfig creation');
const config = new WasmEmbedderConfig()
  .setMaxLength(256)
  .setNormalize(true)
  .setPooling(0);
console.log('  ✅ Config created and chained successfully');

// Note: config is consumed by chaining, no need to free

console.log('\n' + '='.repeat(50));
console.log('✅ All utility tests passed!');
console.log('='.repeat(50));
console.log('\n📝 Note: Full embedder test requires ONNX model + tokenizer files.');
console.log('   The core WASM bindings are working correctly.\n');
Reference in New Issue
Block a user