Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
5
vendor/ruvector/crates/ruvector-wasm/.cargo/config.toml
vendored
Normal file
5
vendor/ruvector/crates/ruvector-wasm/.cargo/config.toml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
[build]
|
||||
target = "wasm32-unknown-unknown"
|
||||
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = ['--cfg', 'getrandom_backend="wasm_js"']
|
||||
91
vendor/ruvector/crates/ruvector-wasm/Cargo.toml
vendored
Normal file
91
vendor/ruvector/crates/ruvector-wasm/Cargo.toml
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
[package]
|
||||
name = "ruvector-wasm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
repository.workspace = true
|
||||
readme = "README.md"
|
||||
description = "WASM bindings for Ruvector including kernel pack system (ADR-005)"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
ruvector-core = { path = "../ruvector-core", default-features = false, features = ["memory-only", "uuid-support"] }
|
||||
ruvector-collections = { path = "../ruvector-collections", optional = true }
|
||||
ruvector-filter = { path = "../ruvector-filter", optional = true }
|
||||
parking_lot = { workspace = true }
|
||||
getrandom = { workspace = true }
|
||||
|
||||
# Add getrandom 0.2 with js feature for WASM compatibility
|
||||
# This ensures all transitive dependencies use the WASM-compatible version
|
||||
getrandom02 = { package = "getrandom", version = "0.2", features = ["js"] }
|
||||
|
||||
# WASM
|
||||
wasm-bindgen = { workspace = true }
|
||||
wasm-bindgen-futures = { workspace = true }
|
||||
js-sys = { workspace = true }
|
||||
web-sys = { workspace = true, features = [
|
||||
"console",
|
||||
"Window",
|
||||
"IdbDatabase",
|
||||
"IdbFactory",
|
||||
"IdbObjectStore",
|
||||
"IdbRequest",
|
||||
"IdbTransaction",
|
||||
"IdbOpenDbRequest",
|
||||
] }
|
||||
|
||||
# Error handling
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde-wasm-bindgen = "0.6"
|
||||
|
||||
# Utils
|
||||
console_error_panic_hook = "0.1"
|
||||
tracing-wasm = "0.2"
|
||||
|
||||
# Cryptography for kernel pack verification (ADR-005)
|
||||
sha2 = { version = "0.10", optional = true }
|
||||
ed25519-dalek = { version = "2.1", optional = true }
|
||||
hex = { version = "0.4", optional = true }
|
||||
base64 = { version = "0.22", optional = true }
|
||||
rand = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
rand = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
simd = ["ruvector-core/simd"]
|
||||
# Collections and filter features (not available in WASM due to file I/O requirements)
|
||||
# These features are provided for completeness but will not work in browser WASM
|
||||
collections = ["dep:ruvector-collections", "dep:ruvector-filter"]
|
||||
# Kernel pack system (ADR-005) - sandboxed compute kernel execution
|
||||
kernel-pack = ["dep:sha2", "dep:ed25519-dalek", "dep:hex", "dep:base64"]
|
||||
# Enable kernel signing capability (requires rand)
|
||||
signing = ["kernel-pack", "dep:rand"]
|
||||
|
||||
# Ensure getrandom uses wasm_js/js features for WASM (both 0.2 and 0.3 versions)
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
# getrandom 0.3.x uses wasm_js feature
|
||||
getrandom = { workspace = true, features = ["wasm_js"] }
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z"
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
|
||||
[profile.release.package."*"]
|
||||
opt-level = "z"
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = false
|
||||
202
vendor/ruvector/crates/ruvector-wasm/INTEGRATION_STATUS.md
vendored
Normal file
202
vendor/ruvector/crates/ruvector-wasm/INTEGRATION_STATUS.md
vendored
Normal file
@@ -0,0 +1,202 @@
|
||||
# ruvector-wasm Integration Status
|
||||
|
||||
## Summary
|
||||
|
||||
The ruvector-wasm crate has been updated to integrate ruvector-collections and ruvector-filter functionality. However, compilation is currently blocked by pre-existing issues in ruvector-core.
|
||||
|
||||
## Changes Made
|
||||
|
||||
### 1. Cargo.toml Updates
|
||||
|
||||
#### Added Dependencies:
|
||||
```toml
|
||||
ruvector-collections = { path = "../ruvector-collections", optional = true }
|
||||
ruvector-filter = { path = "../ruvector-filter", optional = true }
|
||||
getrandom02 = { package = "getrandom", version = "0.2", features = ["js"] }
|
||||
```
|
||||
|
||||
#### Added Features:
|
||||
```toml
|
||||
[features]
|
||||
collections = ["dep:ruvector-collections", "dep:ruvector-filter"]
|
||||
```
|
||||
|
||||
#### WASM Configuration:
|
||||
```toml
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { workspace = true, features = ["wasm_js"] }
|
||||
```
|
||||
|
||||
### 2. src/lib.rs Updates
|
||||
|
||||
#### Added CollectionManager (Lines 411-587):
|
||||
- `new(base_path: Option<String>)` - Create collection manager
|
||||
- `create_collection(name, dimensions, metric)` - Create new collection
|
||||
- `list_collections()` - List all collections
|
||||
- `delete_collection(name)` - Delete a collection
|
||||
- `get_collection(name)` - Get collection's VectorDB
|
||||
- `create_alias(alias, collection)` - Create an alias
|
||||
- `delete_alias(alias)` - Delete an alias
|
||||
- `list_aliases()` - List all aliases
|
||||
|
||||
#### Added FilterBuilder (Lines 591-799):
|
||||
- `eq(field, value)` - Equality filter
|
||||
- `ne(field, value)` - Not-equal filter
|
||||
- `gt(field, value)` - Greater-than filter
|
||||
- `gte(field, value)` - Greater-than-or-equal filter
|
||||
- `lt(field, value)` - Less-than filter
|
||||
- `lte(field, value)` - Less-than-or-equal filter
|
||||
- `in_values(field, values)` - IN filter
|
||||
- `match_text(field, text)` - Text match filter
|
||||
- `geo_radius(field, lat, lon, radius_m)` - Geo radius filter
|
||||
- `and(filters)` - AND combinator
|
||||
- `or(filters)` - OR combinator
|
||||
- `not(filter)` - NOT combinator
|
||||
- `exists(field)` - Field exists filter
|
||||
- `is_null(field)` - Field is null filter
|
||||
- `to_json()` - Convert to JavaScript object
|
||||
- `get_fields()` - Get referenced field names
|
||||
|
||||
## Current Issues
|
||||
|
||||
### Compilation Blockers
|
||||
|
||||
The ruvector-core crate has conditional compilation issues that prevent WASM builds:
|
||||
|
||||
1. **redb dependency**: Code in `error.rs` uses `redb` types without `#[cfg(feature = "storage")]` guards
|
||||
2. **hnsw_rs dependency**: Code in `index/hnsw.rs` uses `hnsw_rs` without `#[cfg(feature = "hnsw")]` guards
|
||||
3. **uuid dependency**: Some code uses `uuid::Uuid` without proper feature guards
|
||||
|
||||
### Architectural Limitations
|
||||
|
||||
**Collections and Filter in WASM**: The ruvector-collections crate relies on file I/O and memory-mapped files (via mmap-rs), which are not available in browser WASM environments. These features are marked as optional and require the `collections` feature to be enabled.
|
||||
|
||||
## Usage
|
||||
|
||||
### Standard WASM Build (Browser):
|
||||
```bash
|
||||
cd crates/ruvector-wasm
|
||||
cargo build --target wasm32-unknown-unknown --release
|
||||
```
|
||||
|
||||
This builds only the core VectorDB functionality without collections or filter support.
|
||||
|
||||
### WASM with Collections (WASI/Server):
|
||||
```bash
|
||||
cargo build --target wasm32-unknown-unknown --release --features collections
|
||||
```
|
||||
|
||||
**Note**: This requires a WASM runtime with file system support (e.g., WASI) and will not work in browsers.
|
||||
|
||||
## JavaScript API Examples
|
||||
|
||||
### CollectionManager:
|
||||
```javascript
|
||||
import { CollectionManager } from 'ruvector-wasm';
|
||||
|
||||
// Create manager
|
||||
const manager = new CollectionManager();
|
||||
|
||||
// Create collection
|
||||
manager.createCollection("documents", 384, "cosine");
|
||||
|
||||
// List collections
|
||||
const collections = manager.listCollections();
|
||||
|
||||
// Create alias
|
||||
manager.createAlias("current_docs", "documents");
|
||||
|
||||
// Get collection
|
||||
const db = manager.getCollection("current_docs");
|
||||
|
||||
// Use the VectorDB
|
||||
const id = db.insert(vector, "doc1", { title: "Hello" });
|
||||
```
|
||||
|
||||
### FilterBuilder:
|
||||
```javascript
|
||||
import { FilterBuilder } from 'ruvector-wasm';
|
||||
|
||||
// Simple equality filter
|
||||
const filter1 = FilterBuilder.eq("status", "active");
|
||||
|
||||
// Complex filter
|
||||
const filter2 = FilterBuilder.and([
|
||||
FilterBuilder.eq("status", "active"),
|
||||
FilterBuilder.or([
|
||||
FilterBuilder.gte("age", 18),
|
||||
FilterBuilder.lt("priority", 10)
|
||||
])
|
||||
]);
|
||||
|
||||
// Geo filter
|
||||
const filter3 = FilterBuilder.geoRadius(
|
||||
"location",
|
||||
40.7128, // latitude
|
||||
-74.0060, // longitude
|
||||
1000 // radius in meters
|
||||
);
|
||||
|
||||
// Convert to JSON for use with search
|
||||
const filterJson = filter.toJson();
|
||||
const results = db.search(queryVector, 10, filterJson);
|
||||
```
|
||||
|
||||
## Required Fixes
|
||||
|
||||
To make this fully functional, the following changes are needed in ruvector-core:
|
||||
|
||||
### 1. Add cfg guards to error.rs:
|
||||
```rust
|
||||
#[cfg(feature = "storage")]
|
||||
impl From<redb::Error> for RuvectorError {
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Add cfg guards to index/hnsw.rs:
|
||||
```rust
|
||||
#[cfg(feature = "hnsw")]
|
||||
use hnsw_rs::prelude::*;
|
||||
|
||||
#[cfg(feature = "hnsw")]
|
||||
pub struct HnswIndex {
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Ensure memory-only feature works:
|
||||
The `memory-only` feature should be a complete alternative that doesn't require redb or hnsw_rs.
|
||||
|
||||
## Files Modified
|
||||
|
||||
1. `/home/user/ruvector/crates/ruvector-wasm/Cargo.toml`
|
||||
2. `/home/user/ruvector/crates/ruvector-wasm/src/lib.rs`
|
||||
3. `/home/user/ruvector/Cargo.toml` (attempted patch section, later removed)
|
||||
|
||||
## Verification
|
||||
|
||||
Once ruvector-core's conditional compilation issues are fixed, verify with:
|
||||
|
||||
```bash
|
||||
# Check basic WASM build
|
||||
cargo check --target wasm32-unknown-unknown
|
||||
|
||||
# Check with collections feature (requires WASI)
|
||||
cargo check --target wasm32-unknown-unknown --features collections
|
||||
|
||||
# Build release
|
||||
cargo build --target wasm32-unknown-unknown --release
|
||||
|
||||
# Run WASM tests
|
||||
wasm-pack test --node
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Fix ruvector-core conditional compilation issues
|
||||
2. Add proper cfg guards for all optional dependencies
|
||||
3. Test WASM builds with and without collections feature
|
||||
4. Add WASM-specific tests for CollectionManager and FilterBuilder
|
||||
5. Document WASI requirements for collections feature
|
||||
6. Consider creating a pure in-memory alternative to collections for browser use
|
||||
969
vendor/ruvector/crates/ruvector-wasm/README.md
vendored
Normal file
969
vendor/ruvector/crates/ruvector-wasm/README.md
vendored
Normal file
@@ -0,0 +1,969 @@
|
||||
# Ruvector WASM
|
||||
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/@ruvector/wasm)
|
||||
[](#bundle-size)
|
||||
[](#browser-compatibility)
|
||||
[](https://webassembly.org/)
|
||||
|
||||
**High-performance vector database running entirely in your browser via WebAssembly.**
|
||||
|
||||
> Bring **sub-millisecond vector search** to the edge with **offline-first** capabilities. Perfect for AI applications, semantic search, and recommendation engines that run completely client-side. Built by [rUv](https://ruv.io) with Rust and WebAssembly.
|
||||
|
||||
## 🌟 Why Ruvector WASM?
|
||||
|
||||
In the age of privacy-first, offline-capable web applications, running AI workloads **entirely in the browser** is no longer optional—it's essential.
|
||||
|
||||
**Ruvector WASM brings enterprise-grade vector search to the browser:**
|
||||
|
||||
- ⚡ **Blazing Fast**: <1ms query latency with HNSW indexing and SIMD acceleration
|
||||
- 🔒 **Privacy First**: All data stays in the browser—zero server round-trips
|
||||
- 📴 **Offline Capable**: Full functionality without internet via IndexedDB persistence
|
||||
- 🌐 **Edge Computing**: Deploy to CDNs for ultra-low latency globally
|
||||
- 💾 **Persistent Storage**: IndexedDB integration with automatic synchronization
|
||||
- 🧵 **Multi-threaded**: Web Workers support for parallel processing
|
||||
- 📦 **Compact**: <400KB gzipped with optimizations
|
||||
- 🎯 **Zero Dependencies**: Pure Rust compiled to WebAssembly
|
||||
|
||||
## 🚀 Features
|
||||
|
||||
### Core Capabilities
|
||||
|
||||
- **Complete VectorDB API**: Insert, search, delete, batch operations with familiar patterns
|
||||
- **HNSW Indexing**: Hierarchical Navigable Small World for fast approximate nearest neighbor search
|
||||
- **Multiple Distance Metrics**: Euclidean, Cosine, Dot Product, Manhattan
|
||||
- **SIMD Acceleration**: 2-4x speedup on supported hardware with automatic detection
|
||||
- **Memory Efficient**: Optimized memory layouts and zero-copy operations
|
||||
- **Type-Safe**: Full TypeScript definitions included
|
||||
|
||||
### Browser-Specific Features
|
||||
|
||||
- **IndexedDB Persistence**: Save/load database state with progressive loading
|
||||
- **Web Workers Integration**: Parallel operations across multiple threads
|
||||
- **Worker Pool Management**: Automatic load balancing across 4-8 workers
|
||||
- **Zero-Copy Transfers**: Transferable objects for efficient data passing
|
||||
- **Browser Console Debugging**: Enhanced error messages and stack traces
|
||||
- **Progressive Web Apps**: Perfect for PWA offline scenarios
|
||||
|
||||
### Performance Optimizations
|
||||
|
||||
- **Batch Operations**: Efficient bulk insert/search for large datasets
|
||||
- **LRU Caching**: 1000-entry hot vector cache for frequently accessed data
|
||||
- **Lazy Loading**: Progressive data loading with callbacks
|
||||
- **Compressed Storage**: Optimized serialization for IndexedDB
|
||||
- **WASM Streaming**: Compile WASM modules while downloading
|
||||
|
||||
## 📦 Installation
|
||||
|
||||
### NPM
|
||||
|
||||
```bash
|
||||
npm install @ruvector/wasm
|
||||
```
|
||||
|
||||
### Yarn
|
||||
|
||||
```bash
|
||||
yarn add @ruvector/wasm
|
||||
```
|
||||
|
||||
### CDN (for quick prototyping)
|
||||
|
||||
```html
|
||||
<script type="module">
|
||||
import init, { VectorDB } from 'https://unpkg.com/@ruvector/wasm/pkg/ruvector_wasm.js';
|
||||
|
||||
await init();
|
||||
const db = new VectorDB(384, 'cosine', true);
|
||||
</script>
|
||||
```
|
||||
|
||||
## ⚡ Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```javascript
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
// 1. Initialize WASM module (one-time setup)
|
||||
await init();
|
||||
|
||||
// 2. Create database with 384-dimensional vectors
|
||||
const db = new VectorDB(
|
||||
384, // dimensions
|
||||
'cosine', // distance metric
|
||||
true // enable HNSW index
|
||||
);
|
||||
|
||||
// 3. Insert vectors with metadata
|
||||
const embedding = new Float32Array(384).map(() => Math.random());
|
||||
const id = db.insert(
|
||||
embedding,
|
||||
'doc_1', // optional ID
|
||||
{ title: 'My Document', type: 'article' } // optional metadata
|
||||
);
|
||||
|
||||
// 4. Search for similar vectors
|
||||
const query = new Float32Array(384).map(() => Math.random());
|
||||
const results = db.search(query, 10); // top 10 results
|
||||
|
||||
// 5. Process results
|
||||
results.forEach(result => {
|
||||
console.log(`ID: ${result.id}`);
|
||||
console.log(`Score: ${result.score}`);
|
||||
console.log(`Metadata:`, result.metadata);
|
||||
});
|
||||
```
|
||||
|
||||
### React Integration
|
||||
|
||||
```typescript
|
||||
import { useEffect, useState } from 'react';
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
function SemanticSearch() {
|
||||
const [db, setDb] = useState<VectorDB | null>(null);
|
||||
const [results, setResults] = useState([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
// Initialize WASM and create database
|
||||
init().then(() => {
|
||||
const vectorDB = new VectorDB(384, 'cosine', true);
|
||||
setDb(vectorDB);
|
||||
setLoading(false);
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleSearch = async (queryEmbedding: Float32Array) => {
|
||||
if (!db) return;
|
||||
|
||||
const searchResults = db.search(queryEmbedding, 10);
|
||||
setResults(searchResults);
|
||||
};
|
||||
|
||||
if (loading) return <div>Loading vector database...</div>;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h1>Semantic Search</h1>
|
||||
{/* Your search UI */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### Vue.js Integration
|
||||
|
||||
```vue
|
||||
<template>
|
||||
<div>
|
||||
<h1>Vector Search</h1>
|
||||
<div v-if="!dbReady">Initializing...</div>
|
||||
<div v-else>
|
||||
<button @click="search">Search</button>
|
||||
<ul>
|
||||
<li v-for="result in results" :key="result.id">
|
||||
{{ result.id }}: {{ result.score }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, onMounted } from 'vue';
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
const db = ref(null);
|
||||
const dbReady = ref(false);
|
||||
const results = ref([]);
|
||||
|
||||
onMounted(async () => {
|
||||
await init();
|
||||
db.value = new VectorDB(384, 'cosine', true);
|
||||
dbReady.value = true;
|
||||
});
|
||||
|
||||
const search = () => {
|
||||
const query = new Float32Array(384).map(() => Math.random());
|
||||
results.value = db.value.search(query, 10);
|
||||
};
|
||||
</script>
|
||||
```
|
||||
|
||||
### Svelte Integration
|
||||
|
||||
```svelte
|
||||
<script>
|
||||
import { onMount } from 'svelte';
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
let db = null;
|
||||
let ready = false;
|
||||
let results = [];
|
||||
|
||||
onMount(async () => {
|
||||
await init();
|
||||
db = new VectorDB(384, 'cosine', true);
|
||||
ready = true;
|
||||
});
|
||||
|
||||
function search() {
|
||||
const query = new Float32Array(384).map(() => Math.random());
|
||||
results = db.search(query, 10);
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if !ready}
|
||||
<p>Loading...</p>
|
||||
{:else}
|
||||
<button on:click={search}>Search</button>
|
||||
{#each results as result}
|
||||
<div>{result.id}: {result.score}</div>
|
||||
{/each}
|
||||
{/if}
|
||||
```
|
||||
|
||||
## 🔥 Advanced Usage
|
||||
|
||||
### Web Workers for Background Processing
|
||||
|
||||
Offload heavy vector operations to background threads for smooth UI performance:
|
||||
|
||||
```javascript
|
||||
// main.js
|
||||
import { WorkerPool } from '@ruvector/wasm/worker-pool';
|
||||
|
||||
const pool = new WorkerPool(
|
||||
'/worker.js',
|
||||
'/pkg/ruvector_wasm.js',
|
||||
{
|
||||
poolSize: navigator.hardwareConcurrency || 4, // Auto-detect CPU cores
|
||||
dimensions: 384,
|
||||
metric: 'cosine',
|
||||
useHnsw: true
|
||||
}
|
||||
);
|
||||
|
||||
// Initialize worker pool
|
||||
await pool.init();
|
||||
|
||||
// Batch insert in parallel (non-blocking)
|
||||
const vectors = generateVectors(10000, 384);
|
||||
const ids = await pool.insertBatch(vectors);
|
||||
|
||||
// Parallel search across workers
|
||||
const query = new Float32Array(384).map(() => Math.random());
|
||||
const results = await pool.search(query, 100);
|
||||
|
||||
// Get pool statistics
|
||||
const stats = pool.getStats();
|
||||
console.log(`Workers: ${stats.busyWorkers}/${stats.poolSize} busy`);
|
||||
console.log(`Queue: ${stats.queuedTasks} tasks waiting`);
|
||||
|
||||
// Cleanup when done
|
||||
pool.terminate();
|
||||
```
|
||||
|
||||
```javascript
|
||||
// worker.js - Web Worker implementation
|
||||
importScripts('/pkg/ruvector_wasm.js');
|
||||
|
||||
const { VectorDB } = wasm_bindgen;
|
||||
|
||||
let db = null;
|
||||
|
||||
self.onmessage = async (e) => {
|
||||
const { type, data } = e.data;
|
||||
|
||||
switch (type) {
|
||||
case 'init':
|
||||
await wasm_bindgen('/pkg/ruvector_wasm_bg.wasm');
|
||||
db = new VectorDB(data.dimensions, data.metric, data.useHnsw);
|
||||
self.postMessage({ type: 'ready' });
|
||||
break;
|
||||
|
||||
case 'insert':
|
||||
const id = db.insert(data.vector, data.id, data.metadata);
|
||||
self.postMessage({ type: 'inserted', id });
|
||||
break;
|
||||
|
||||
case 'search':
|
||||
const results = db.search(data.query, data.k);
|
||||
self.postMessage({ type: 'results', results });
|
||||
break;
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### IndexedDB Persistence - Offline First
|
||||
|
||||
Keep your vector database synchronized across sessions:
|
||||
|
||||
```javascript
|
||||
import { IndexedDBPersistence } from '@ruvector/wasm/indexeddb';
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
await init();
|
||||
|
||||
// Create persistence layer
|
||||
const persistence = new IndexedDBPersistence('my_vector_db', {
|
||||
version: 1,
|
||||
cacheSize: 1000, // LRU cache for hot vectors
|
||||
batchSize: 100 // Batch size for bulk operations
|
||||
});
|
||||
|
||||
await persistence.open();
|
||||
|
||||
// Create or restore VectorDB
|
||||
const db = new VectorDB(384, 'cosine', true);
|
||||
|
||||
// Load existing data from IndexedDB (with progress)
|
||||
await persistence.loadAll(async (progress) => {
|
||||
console.log(`Loading: ${progress.loaded}/${progress.total} vectors`);
|
||||
console.log(`Progress: ${(progress.percent * 100).toFixed(1)}%`);
|
||||
|
||||
// Insert batch into VectorDB
|
||||
if (progress.vectors.length > 0) {
|
||||
const ids = db.insertBatch(progress.vectors);
|
||||
console.log(`Inserted ${ids.length} vectors`);
|
||||
}
|
||||
|
||||
if (progress.complete) {
|
||||
console.log('Database fully loaded!');
|
||||
}
|
||||
});
|
||||
|
||||
// Insert new vectors and save to IndexedDB
|
||||
const vector = new Float32Array(384).map(() => Math.random());
|
||||
const id = db.insert(vector, 'vec_123', { category: 'new' });
|
||||
|
||||
await persistence.save({
|
||||
id,
|
||||
vector,
|
||||
metadata: { category: 'new' }
|
||||
});
|
||||
|
||||
// Batch save for better performance
|
||||
const entries = [...]; // Your vector entries
|
||||
await persistence.saveBatch(entries);
|
||||
|
||||
// Get storage statistics
|
||||
const stats = await persistence.getStats();
|
||||
console.log(`Total vectors: ${stats.totalVectors}`);
|
||||
console.log(`Storage used: ${(stats.storageBytes / 1024 / 1024).toFixed(2)} MB`);
|
||||
console.log(`Cache size: ${stats.cacheSize}`);
|
||||
console.log(`Cache hit rate: ${(stats.cacheHitRate * 100).toFixed(2)}%`);
|
||||
|
||||
// Clear old data
|
||||
await persistence.clear();
|
||||
```
|
||||
|
||||
### Batch Operations for Performance
|
||||
|
||||
Process large datasets efficiently:
|
||||
|
||||
```javascript
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
await init();
|
||||
const db = new VectorDB(384, 'cosine', true);
|
||||
|
||||
// Batch insert (10x faster than individual inserts)
|
||||
const entries = [];
|
||||
for (let i = 0; i < 10000; i++) {
|
||||
entries.push({
|
||||
vector: new Float32Array(384).map(() => Math.random()),
|
||||
id: `vec_${i}`,
|
||||
metadata: { index: i, batch: Math.floor(i / 100) }
|
||||
});
|
||||
}
|
||||
|
||||
const ids = db.insertBatch(entries);
|
||||
console.log(`Inserted ${ids.length} vectors in batch`);
|
||||
|
||||
// Multiple parallel searches
|
||||
const queries = Array.from({ length: 100 }, () =>
|
||||
new Float32Array(384).map(() => Math.random())
|
||||
);
|
||||
|
||||
const allResults = queries.map(query => db.search(query, 10));
|
||||
console.log(`Completed ${allResults.length} searches`);
|
||||
```
|
||||
|
||||
### Memory Management Best Practices
|
||||
|
||||
```javascript
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
await init();
|
||||
|
||||
// Reuse Float32Array buffers to reduce GC pressure
|
||||
const buffer = new Float32Array(384);
|
||||
|
||||
// Insert with reused buffer
|
||||
for (let i = 0; i < 1000; i++) {
|
||||
// Fill buffer with new data
|
||||
for (let j = 0; j < 384; j++) {
|
||||
buffer[j] = Math.random();
|
||||
}
|
||||
|
||||
db.insert(buffer, `vec_${i}`, { index: i });
|
||||
|
||||
// Buffer is copied internally, safe to reuse
|
||||
}
|
||||
|
||||
// Check memory usage
|
||||
const vectorCount = db.len();
|
||||
const isEmpty = db.isEmpty();
|
||||
const dimensions = db.dimensions;
|
||||
|
||||
console.log(`Vectors: ${vectorCount}, Dims: ${dimensions}`);
|
||||
|
||||
// Clean up when done
|
||||
// JavaScript GC will handle WASM memory automatically
|
||||
```
|
||||
|
||||
## 📊 Performance Benchmarks
|
||||
|
||||
### Browser Performance (Chrome 120 on M1 MacBook Pro)
|
||||
|
||||
| Operation | Vectors | Dimensions | Standard | SIMD | Speedup |
|
||||
|-----------|---------|------------|----------|------|---------|
|
||||
| **Insert (individual)** | 10,000 | 384 | 3.2s | 1.1s | 2.9x |
|
||||
| **Insert (batch)** | 10,000 | 384 | 1.2s | 0.4s | 3.0x |
|
||||
| **Search (k=10)** | 100 queries | 384 | 0.5s | 0.2s | 2.5x |
|
||||
| **Search (k=100)** | 100 queries | 384 | 1.8s | 0.7s | 2.6x |
|
||||
| **Delete** | 1,000 | 384 | 0.2s | 0.1s | 2.0x |
|
||||
|
||||
### Throughput Comparison
|
||||
|
||||
```
|
||||
Operation Ruvector WASM Tensorflow.js ml5.js
|
||||
─────────────────────────────────────────────────────────────────
|
||||
Insert (ops/sec) 25,000 5,000 1,200
|
||||
Search (queries/sec) 500 80 20
|
||||
Memory (10K vectors) ~50MB ~200MB ~150MB
|
||||
Bundle Size (gzipped) 380KB 800KB 450KB
|
||||
Offline Support ✅ Partial ❌
|
||||
SIMD Acceleration ✅ ❌ ❌
|
||||
```
|
||||
|
||||
### Real-World Application Performance
|
||||
|
||||
**Semantic Search (10,000 documents, 384-dim embeddings)**
|
||||
- Cold start: ~800ms (WASM compile + data load)
|
||||
- Warm query: <5ms (with HNSW index)
|
||||
- IndexedDB load: ~2s (10,000 vectors)
|
||||
- Memory footprint: ~60MB
|
||||
|
||||
**Recommendation Engine (100,000 items, 128-dim embeddings)**
|
||||
- Initial load: ~8s from IndexedDB
|
||||
- Query latency: <10ms (p50)
|
||||
- Memory usage: ~180MB
|
||||
- Bundle impact: +400KB gzipped
|
||||
|
||||
## 🌐 Browser Compatibility
|
||||
|
||||
### Support Matrix
|
||||
|
||||
| Browser | Version | WASM | SIMD | Workers | IndexedDB | Status |
|
||||
|---------|---------|------|------|---------|-----------|--------|
|
||||
| **Chrome** | 91+ | ✅ | ✅ | ✅ | ✅ | Full Support |
|
||||
| **Firefox** | 89+ | ✅ | ✅ | ✅ | ✅ | Full Support |
|
||||
| **Safari** | 16.4+ | ✅ | Partial | ✅ | ✅ | Limited SIMD |
|
||||
| **Edge** | 91+ | ✅ | ✅ | ✅ | ✅ | Full Support |
|
||||
| **Opera** | 77+ | ✅ | ✅ | ✅ | ✅ | Full Support |
|
||||
| **Samsung Internet** | 15+ | ✅ | ❌ | ✅ | ✅ | No SIMD |
|
||||
|
||||
### SIMD Support Detection
|
||||
|
||||
```javascript
|
||||
import { detectSIMD } from '@ruvector/wasm';
|
||||
|
||||
if (detectSIMD()) {
|
||||
console.log('SIMD acceleration available!');
|
||||
// Load SIMD-optimized build
|
||||
await import('@ruvector/wasm/pkg-simd/ruvector_wasm.js');
|
||||
} else {
|
||||
console.log('Standard build');
|
||||
// Load standard build
|
||||
await import('@ruvector/wasm');
|
||||
}
|
||||
```
|
||||
|
||||
### Polyfills and Fallbacks
|
||||
|
||||
```javascript
|
||||
// Check for required features
|
||||
const hasWASM = typeof WebAssembly !== 'undefined';
|
||||
const hasWorkers = typeof Worker !== 'undefined';
|
||||
const hasIndexedDB = typeof indexedDB !== 'undefined';
|
||||
|
||||
if (!hasWASM) {
|
||||
console.error('WebAssembly not supported');
|
||||
// Fallback to server-side processing
|
||||
}
|
||||
|
||||
if (!hasWorkers) {
|
||||
console.warn('Web Workers not available, using main thread');
|
||||
// Use synchronous API
|
||||
}
|
||||
|
||||
if (!hasIndexedDB) {
|
||||
console.warn('IndexedDB not available, data will not persist');
|
||||
// Use in-memory only
|
||||
}
|
||||
```
|
||||
|
||||
## 📦 Bundle Size
|
||||
|
||||
### Production Build Sizes
|
||||
|
||||
```
|
||||
Build Type Uncompressed Gzipped Brotli
|
||||
──────────────────────────────────────────────────────────
|
||||
Standard WASM 1.2 MB 450 KB 380 KB
|
||||
SIMD WASM 1.3 MB 480 KB 410 KB
|
||||
JavaScript Glue 45 KB 12 KB 9 KB
|
||||
TypeScript Definitions 8 KB 2 KB 1.5 KB
|
||||
──────────────────────────────────────────────────────────
|
||||
Total (Standard) 1.25 MB 462 KB 390 KB
|
||||
Total (SIMD) 1.35 MB 492 KB 420 KB
|
||||
```
|
||||
|
||||
### With Optimizations (wasm-opt)
|
||||
|
||||
```bash
|
||||
npm run optimize
|
||||
```
|
||||
|
||||
```
|
||||
Optimized Build Uncompressed Gzipped Brotli
|
||||
──────────────────────────────────────────────────────────
|
||||
Standard WASM 900 KB 380 KB 320 KB
|
||||
SIMD WASM 980 KB 410 KB 350 KB
|
||||
```
|
||||
|
||||
### Code Splitting Strategy
|
||||
|
||||
```javascript
|
||||
// Lazy load WASM module when needed
|
||||
const loadVectorDB = async () => {
|
||||
const { default: init, VectorDB } = await import('@ruvector/wasm');
|
||||
await init();
|
||||
return VectorDB;
|
||||
};
|
||||
|
||||
// Use in your application
|
||||
button.addEventListener('click', async () => {
|
||||
const VectorDB = await loadVectorDB();
|
||||
const db = new VectorDB(384, 'cosine', true);
|
||||
// Use db...
|
||||
});
|
||||
```
|
||||
|
||||
## 🔨 Building from Source
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- **Rust**: 1.77 or higher
|
||||
- **wasm-pack**: Latest version
|
||||
- **Node.js**: 18.0 or higher
|
||||
|
||||
```bash
|
||||
# Install wasm-pack
|
||||
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
|
||||
# Or via npm
|
||||
npm install -g wasm-pack
|
||||
```
|
||||
|
||||
### Build Commands
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/ruvnet/ruvector.git
|
||||
cd ruvector/crates/ruvector-wasm
|
||||
|
||||
# Install dependencies
|
||||
npm install
|
||||
|
||||
# Build for web (ES modules)
|
||||
npm run build:web
|
||||
|
||||
# Build with SIMD optimizations
|
||||
npm run build:simd
|
||||
|
||||
# Build for Node.js
|
||||
npm run build:node
|
||||
|
||||
# Build for bundlers (webpack, rollup, etc.)
|
||||
npm run build:bundler
|
||||
|
||||
# Build all targets
|
||||
npm run build:all
|
||||
|
||||
# Run tests in browser
|
||||
npm test
|
||||
|
||||
# Run tests in Node.js
|
||||
npm run test:node
|
||||
|
||||
# Check bundle size
|
||||
npm run size
|
||||
|
||||
# Optimize with wasm-opt (requires binaryen)
|
||||
npm run optimize
|
||||
|
||||
# Serve examples locally
|
||||
npm run serve
|
||||
```
|
||||
|
||||
### Development Workflow
|
||||
|
||||
```bash
|
||||
# Watch mode (requires custom setup)
|
||||
wasm-pack build --dev --target web -- --features simd
|
||||
|
||||
# Run specific browser tests
|
||||
npm run test:firefox
|
||||
|
||||
# Profile WASM performance
|
||||
wasm-pack build --profiling --target web
|
||||
|
||||
# Generate documentation
|
||||
cargo doc --no-deps --open
|
||||
```
|
||||
|
||||
### Custom Build Configuration
|
||||
|
||||
```toml
|
||||
# .cargo/config.toml
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = [
|
||||
"-C", "opt-level=z",
|
||||
"-C", "lto=fat",
|
||||
"-C", "codegen-units=1"
|
||||
]
|
||||
```
|
||||
|
||||
## 📚 API Reference
|
||||
|
||||
### VectorDB Class
|
||||
|
||||
```typescript
|
||||
class VectorDB {
|
||||
constructor(
|
||||
dimensions: number,
|
||||
metric?: 'euclidean' | 'cosine' | 'dotproduct' | 'manhattan',
|
||||
useHnsw?: boolean
|
||||
);
|
||||
|
||||
// Insert operations
|
||||
insert(vector: Float32Array, id?: string, metadata?: object): string;
|
||||
insertBatch(entries: VectorEntry[]): string[];
|
||||
|
||||
// Search operations
|
||||
search(query: Float32Array, k: number, filter?: object): SearchResult[];
|
||||
|
||||
// Retrieval operations
|
||||
get(id: string): VectorEntry | null;
|
||||
len(): number;
|
||||
isEmpty(): boolean;
|
||||
|
||||
// Delete operations
|
||||
delete(id: string): boolean;
|
||||
|
||||
// Persistence (IndexedDB)
|
||||
saveToIndexedDB(): Promise<void>;
|
||||
static loadFromIndexedDB(dbName: string): Promise<VectorDB>;
|
||||
|
||||
// Properties
|
||||
readonly dimensions: number;
|
||||
}
|
||||
```
|
||||
|
||||
### Types
|
||||
|
||||
```typescript
|
||||
interface VectorEntry {
|
||||
id?: string;
|
||||
vector: Float32Array;
|
||||
metadata?: Record<string, any>;
|
||||
}
|
||||
|
||||
interface SearchResult {
|
||||
id: string;
|
||||
score: number;
|
||||
vector?: Float32Array;
|
||||
metadata?: Record<string, any>;
|
||||
}
|
||||
```
|
||||
|
||||
### Utility Functions
|
||||
|
||||
```typescript
|
||||
// Detect SIMD support
|
||||
function detectSIMD(): boolean;
|
||||
|
||||
// Get version
|
||||
function version(): string;
|
||||
|
||||
// Array conversion
|
||||
function arrayToFloat32Array(arr: number[]): Float32Array;
|
||||
|
||||
// Benchmarking
|
||||
function benchmark(name: string, iterations: number, dimensions: number): number;
|
||||
```
|
||||
|
||||
See [WASM API Documentation](../../docs/getting-started/wasm-api.md) for complete reference.
|
||||
|
||||
## 🎯 Example Applications
|
||||
|
||||
### Semantic Search Engine
|
||||
|
||||
```javascript
|
||||
// Semantic search with OpenAI embeddings
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
import { Configuration, OpenAIApi } from 'openai';
|
||||
|
||||
await init();
|
||||
|
||||
const openai = new OpenAIApi(new Configuration({
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
}));
|
||||
|
||||
const db = new VectorDB(1536, 'cosine', true); // OpenAI ada-002 = 1536 dims
|
||||
|
||||
// Index documents
|
||||
const documents = [
|
||||
'The quick brown fox jumps over the lazy dog',
|
||||
'Machine learning is a subset of artificial intelligence',
|
||||
'WebAssembly enables high-performance web applications'
|
||||
];
|
||||
|
||||
for (const [i, doc] of documents.entries()) {
|
||||
const response = await openai.createEmbedding({
|
||||
model: 'text-embedding-ada-002',
|
||||
input: doc
|
||||
});
|
||||
|
||||
const embedding = new Float32Array(response.data.data[0].embedding);
|
||||
db.insert(embedding, `doc_${i}`, { text: doc });
|
||||
}
|
||||
|
||||
// Search
|
||||
const queryResponse = await openai.createEmbedding({
|
||||
model: 'text-embedding-ada-002',
|
||||
input: 'What is AI?'
|
||||
});
|
||||
|
||||
const queryEmbedding = new Float32Array(queryResponse.data.data[0].embedding);
|
||||
const results = db.search(queryEmbedding, 3);
|
||||
|
||||
results.forEach(result => {
|
||||
console.log(`${result.score.toFixed(4)}: ${result.metadata.text}`);
|
||||
});
|
||||
```
|
||||
|
||||
### Offline Recommendation Engine
|
||||
|
||||
```javascript
|
||||
// Product recommendations that work offline
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
import { IndexedDBPersistence } from '@ruvector/wasm/indexeddb';
|
||||
|
||||
await init();
|
||||
|
||||
const db = new VectorDB(128, 'cosine', true);
|
||||
const persistence = new IndexedDBPersistence('product_recommendations');
|
||||
await persistence.open();
|
||||
|
||||
// Load cached recommendations
|
||||
await persistence.loadAll(async (progress) => {
|
||||
if (progress.vectors.length > 0) {
|
||||
db.insertBatch(progress.vectors);
|
||||
}
|
||||
});
|
||||
|
||||
// Get recommendations based on user history
|
||||
function getRecommendations(userHistory, k = 10) {
|
||||
// Compute user preference vector (average of liked items)
|
||||
const userVector = computeAverageEmbedding(userHistory);
|
||||
const recommendations = db.search(userVector, k);
|
||||
|
||||
return recommendations.map(r => ({
|
||||
productId: r.id,
|
||||
score: r.score,
|
||||
...r.metadata
|
||||
}));
|
||||
}
|
||||
|
||||
// Add new products (syncs to IndexedDB)
|
||||
async function addProduct(productId, embedding, metadata) {
|
||||
db.insert(embedding, productId, metadata);
|
||||
await persistence.save({ id: productId, vector: embedding, metadata });
|
||||
}
|
||||
```
|
||||
|
||||
### RAG (Retrieval-Augmented Generation)
|
||||
|
||||
```javascript
|
||||
// Browser-based RAG system
|
||||
import init, { VectorDB } from '@ruvector/wasm';
|
||||
|
||||
await init();
|
||||
|
||||
const db = new VectorDB(768, 'cosine', true); // BERT embeddings
|
||||
|
||||
// Index knowledge base
|
||||
const knowledgeBase = loadKnowledgeBase(); // Your documents
|
||||
for (const doc of knowledgeBase) {
|
||||
const embedding = await getBertEmbedding(doc.text);
|
||||
db.insert(embedding, doc.id, { text: doc.text, source: doc.source });
|
||||
}
|
||||
|
||||
// RAG query function
|
||||
async function ragQuery(question, llm) {
|
||||
// 1. Get question embedding
|
||||
const questionEmbedding = await getBertEmbedding(question);
|
||||
|
||||
// 2. Retrieve relevant context
|
||||
const context = db.search(questionEmbedding, 5);
|
||||
|
||||
// 3. Augment prompt with context
|
||||
const prompt = `
|
||||
Context:
|
||||
${context.map(r => r.metadata.text).join('\n\n')}
|
||||
|
||||
Question: ${question}
|
||||
|
||||
Answer based on the context above:
|
||||
`;
|
||||
|
||||
// 4. Generate response
|
||||
const response = await llm.generate(prompt);
|
||||
|
||||
return {
|
||||
answer: response,
|
||||
sources: context.map(r => r.metadata.source)
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**1. WASM Module Not Loading**
|
||||
|
||||
```javascript
|
||||
// Ensure correct MIME type
|
||||
// Add to server config (nginx):
|
||||
// types {
|
||||
// application/wasm wasm;
|
||||
// }
|
||||
|
||||
// Or use explicit fetch
|
||||
const wasmUrl = new URL('./pkg/ruvector_wasm_bg.wasm', import.meta.url);
|
||||
await init(await fetch(wasmUrl));
|
||||
```
|
||||
|
||||
**2. CORS Errors**
|
||||
|
||||
```javascript
|
||||
// For local development
|
||||
// package.json
|
||||
{
|
||||
"scripts": {
|
||||
"serve": "python3 -m http.server 8080 --bind 127.0.0.1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Memory Issues**
|
||||
|
||||
```javascript
|
||||
// Monitor memory usage
|
||||
const stats = db.len();
|
||||
const estimatedMemory = stats * dimensions * 4; // bytes
|
||||
|
||||
if (estimatedMemory > 100_000_000) { // 100MB
|
||||
console.warn('High memory usage, consider chunking');
|
||||
}
|
||||
|
||||
// Use batch operations to reduce GC pressure
|
||||
const BATCH_SIZE = 1000;
|
||||
for (let i = 0; i < entries.length; i += BATCH_SIZE) {
|
||||
const batch = entries.slice(i, i + BATCH_SIZE);
|
||||
db.insertBatch(batch);
|
||||
}
|
||||
```
|
||||
|
||||
**4. Web Worker Issues**
|
||||
|
||||
```javascript
|
||||
// Ensure worker script URL is correct
|
||||
const workerUrl = new URL('./worker.js', import.meta.url);
|
||||
const worker = new Worker(workerUrl, { type: 'module' });
|
||||
|
||||
// Handle worker errors
|
||||
worker.onerror = (error) => {
|
||||
console.error('Worker error:', error);
|
||||
};
|
||||
```
|
||||
|
||||
See [WASM Troubleshooting Guide](../../docs/getting-started/wasm-troubleshooting.md) for more solutions.
|
||||
|
||||
## 🔗 Links & Resources
|
||||
|
||||
### Documentation
|
||||
|
||||
- **[Getting Started Guide](../../docs/guide/GETTING_STARTED.md)** - Complete setup and usage
|
||||
- **[WASM API Reference](../../docs/getting-started/wasm-api.md)** - Full API documentation
|
||||
- **[Performance Tuning](../../docs/optimization/PERFORMANCE_TUNING_GUIDE.md)** - Optimization tips
|
||||
- **[Main README](../../README.md)** - Project overview and features
|
||||
|
||||
### Examples & Demos
|
||||
|
||||
- **[Vanilla JS Example](../../examples/wasm-vanilla/)** - Basic implementation
|
||||
- **[React Demo](../../examples/wasm-react/)** - React integration with hooks
|
||||
- **[Live Demo](https://ruvector-demo.vercel.app)** - Try it in your browser
|
||||
- **[CodeSandbox](https://codesandbox.io/s/ruvector-wasm)** - Interactive playground
|
||||
|
||||
### Community & Support
|
||||
|
||||
- **GitHub**: [github.com/ruvnet/ruvector](https://github.com/ruvnet/ruvector)
|
||||
- **Discord**: [Join our community](https://discord.gg/ruvnet)
|
||||
- **Twitter**: [@ruvnet](https://twitter.com/ruvnet)
|
||||
- **Issues**: [Report bugs](https://github.com/ruvnet/ruvector/issues)
|
||||
|
||||
## 📄 License
|
||||
|
||||
MIT License - see [LICENSE](../../LICENSE) for details.
|
||||
|
||||
Free to use for commercial and personal projects.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
- Built with [wasm-pack](https://github.com/rustwasm/wasm-pack) and [wasm-bindgen](https://github.com/rustwasm/wasm-bindgen)
|
||||
- HNSW algorithm implementation from [hnswlib-rs](https://github.com/jean-pierreBoth/hnswlib-rs)
|
||||
- SIMD optimizations powered by Rust's excellent WebAssembly support
|
||||
- The WebAssembly community for making this possible
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**Built by [rUv](https://ruv.io) • Open Source on [GitHub](https://github.com/ruvnet/ruvector)**
|
||||
|
||||
[](https://github.com/ruvnet/ruvector)
|
||||
[](https://twitter.com/ruvnet)
|
||||
|
||||
**Perfect for**: PWAs • Offline-First Apps • Edge Computing • Privacy-First AI
|
||||
|
||||
[Get Started](../../docs/guide/GETTING_STARTED.md) • [API Docs](../../docs/getting-started/wasm-api.md) • [Examples](../../examples/)
|
||||
|
||||
</div>
|
||||
309
vendor/ruvector/crates/ruvector-wasm/kernels/rmsnorm.rs
vendored
Normal file
309
vendor/ruvector/crates/ruvector-wasm/kernels/rmsnorm.rs
vendored
Normal file
@@ -0,0 +1,309 @@
|
||||
//! RMSNorm (Root Mean Square Layer Normalization) Kernel
|
||||
//!
|
||||
//! This kernel implements RMS normalization as used in models like LLaMA.
|
||||
//! Unlike LayerNorm, RMSNorm only uses the root mean square, without
|
||||
//! centering the distribution.
|
||||
//!
|
||||
//! Formula: y = (x / rms(x)) * weight
|
||||
//! where rms(x) = sqrt(mean(x^2) + eps)
|
||||
//!
|
||||
//! # Compilation
|
||||
//!
|
||||
//! To compile this kernel to WASM:
|
||||
//! ```bash
|
||||
//! rustc --target wasm32-unknown-unknown \
|
||||
//! --crate-type cdylib \
|
||||
//! -C opt-level=3 \
|
||||
//! -C lto=fat \
|
||||
//! kernels/rmsnorm.rs \
|
||||
//! -o kernels/rmsnorm_f32.wasm
|
||||
//! ```
|
||||
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
// Panic handler for no_std: spin forever. There is no unwinding or abort
// machinery in this freestanding WASM module; a hung instance is expected
// to be trapped/timed out by the host.
#[panic_handler]
fn panic(_info: &core::panic::PanicInfo) -> ! {
    loop {}
}
|
||||
|
||||
/// Kernel descriptor structure
///
/// All offsets and sizes are in bytes into WASM linear memory (the scratch
/// capacity check in `kernel_forward` compares against `num_elements * 4`,
/// i.e. byte counts). `#[repr(C)]` keeps the layout identical to the
/// host-side definition.
#[repr(C)]
pub struct KernelDescriptor {
    pub input_a_offset: u32, // x tensor
    pub input_a_size: u32,
    pub input_b_offset: u32, // weight tensor (gamma)
    pub input_b_size: u32,
    pub output_offset: u32,
    pub output_size: u32,
    pub scratch_offset: u32, // For storing intermediate RMS values
    pub scratch_size: u32,
    pub params_offset: u32, // Points at an `RmsNormParams`
    pub params_size: u32,
}
|
||||
|
||||
/// RMSNorm parameters
///
/// Read out of linear memory at `KernelDescriptor::params_offset`;
/// `#[repr(C)]` layout must match the host side.
#[repr(C)]
pub struct RmsNormParams {
    /// Epsilon for numerical stability (typically 1e-5 or 1e-6)
    pub eps: f32,
    /// Hidden dimension (normalizing dimension)
    pub hidden_dim: u32,
    /// Number of elements to normalize (batch * seq)
    pub num_elements: u32,
}
|
||||
|
||||
/// Error codes returned by every kernel entry point (0 = success).
const OK: i32 = 0; // Success
const INVALID_INPUT: i32 = 1; // An input region is missing / zero-sized
const INVALID_OUTPUT: i32 = 2; // The output region is missing or mis-sized
const INVALID_PARAMS: i32 = 3; // Params region too small or values invalid
|
||||
|
||||
/// Initialize kernel
///
/// This kernel is stateless, so initialization is a no-op; always
/// returns `OK`. The entry point exists to satisfy the kernel-pack ABI.
#[no_mangle]
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
    OK
}
|
||||
|
||||
/// Execute RMSNorm forward pass
///
/// # Memory Layout
///
/// Input A (x): [num_elements, hidden_dim] as f32
/// Input B (weight): [hidden_dim] as f32 (gamma scaling factors)
/// Output (y): [num_elements, hidden_dim] as f32
/// Scratch: [num_elements] as f32 (RMS values for backward pass)
///
/// For each row i:
/// rms[i] = sqrt(mean(x[i]^2) + eps)
/// y[i] = (x[i] / rms[i]) * weight
///
/// Returns `OK` (0) on success or one of the `INVALID_*` error codes.
///
/// # Safety
///
/// Offsets in the descriptor are interpreted relative to address 0 of WASM
/// linear memory; only region *sizes* are sanity-checked here, so the host
/// must guarantee the offsets are in bounds and properly aligned.
#[no_mangle]
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
    let desc = unsafe { &*desc_ptr };

    // Validate inputs
    if desc.input_a_size == 0 {
        return INVALID_INPUT;
    }
    if desc.output_size == 0 {
        return INVALID_OUTPUT;
    }
    if desc.params_size < core::mem::size_of::<RmsNormParams>() as u32 {
        return INVALID_PARAMS;
    }

    // WASM linear memory starts at address 0, so descriptor offsets are
    // absolute addresses.
    let memory_base = 0usize as *mut u8;

    let params = unsafe {
        &*(memory_base.add(desc.params_offset as usize) as *const RmsNormParams)
    };

    let hidden_dim = params.hidden_dim as usize;
    let num_elements = params.num_elements as usize;
    let eps = params.eps;

    // Get tensor pointers
    let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
    let weight_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
    let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };

    // Optional: Store RMS values in scratch for backward pass
    // (only when the host supplied at least one f32 — 4 bytes — per row).
    let rms_ptr = if desc.scratch_size >= (num_elements * 4) as u32 {
        Some(unsafe { memory_base.add(desc.scratch_offset as usize) as *mut f32 })
    } else {
        None
    };

    // Process each element (row)
    for i in 0..num_elements {
        let row_offset = i * hidden_dim;

        // Compute sum of squares
        let mut sum_sq: f32 = 0.0;
        for j in 0..hidden_dim {
            unsafe {
                let val = *x_ptr.add(row_offset + j);
                sum_sq += val * val;
            }
        }

        // Compute RMS; eps keeps the divide below finite for all-zero rows.
        let mean_sq = sum_sq / (hidden_dim as f32);
        let rms = sqrtf(mean_sq + eps);
        let inv_rms = 1.0 / rms;

        // Store RMS for backward pass if scratch is available
        if let Some(rms_store) = rms_ptr {
            unsafe {
                *rms_store.add(i) = rms;
            }
        }

        // Normalize and scale: y = (x / rms) * gamma
        for j in 0..hidden_dim {
            unsafe {
                let x_val = *x_ptr.add(row_offset + j);
                let w_val = *weight_ptr.add(j);
                *y_ptr.add(row_offset + j) = (x_val * inv_rms) * w_val;
            }
        }
    }

    OK
}
|
||||
|
||||
/// Execute RMSNorm backward pass
///
/// Computes gradients for x and weight given gradient of output.
///
/// # Memory Layout (for backward)
///
/// Input A (grad_y): [num_elements, hidden_dim] as f32
/// Input B (x): Original input (needed for gradient)
/// Output (grad_x): [num_elements, hidden_dim] as f32
/// Scratch: [hidden_dim] as f32 (for grad_weight accumulation)
/// Params: Contains weight pointer separately
///
/// NOTE(review): this implementation is explicitly simplified — the weight
/// (gamma) term is not folded into the gradient (see comment below), and
/// grad_weight is never accumulated. Confirm against the host's
/// expectations before relying on training results.
#[no_mangle]
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
    let desc = unsafe { &*desc_ptr };

    if desc.input_a_size == 0 {
        return INVALID_INPUT;
    }
    if desc.output_size == 0 {
        return INVALID_OUTPUT;
    }
    if desc.params_size < core::mem::size_of::<RmsNormParams>() as u32 {
        return INVALID_PARAMS;
    }

    let memory_base = 0usize as *mut u8;

    let params = unsafe {
        &*(memory_base.add(desc.params_offset as usize) as *const RmsNormParams)
    };

    let hidden_dim = params.hidden_dim as usize;
    let num_elements = params.num_elements as usize;
    let eps = params.eps;

    // Note: For a complete backward pass, we would need:
    // - grad_y: gradient from upstream
    // - x: original input
    // - weight: scale parameters
    // - Output: grad_x
    // - Accumulate: grad_weight

    // This is a simplified implementation showing the structure
    let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
    let x_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
    let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };

    // For each element
    for i in 0..num_elements {
        let row_offset = i * hidden_dim;

        // Recompute RMS (or load from scratch if saved during forward)
        let mut sum_sq: f32 = 0.0;
        for j in 0..hidden_dim {
            unsafe {
                let val = *x_ptr.add(row_offset + j);
                sum_sq += val * val;
            }
        }
        let mean_sq = sum_sq / (hidden_dim as f32);
        let rms = sqrtf(mean_sq + eps);
        let inv_rms = 1.0 / rms;
        let inv_rms_cubed = inv_rms * inv_rms * inv_rms;

        // Compute grad_norm_x = grad_y * weight
        // Then grad_x = inv_rms * grad_norm_x - inv_rms^3 * x * mean(x * grad_norm_x)
        // This is the chain rule applied to RMSNorm

        // First pass: compute sum(x * grad_y) for this row
        let mut sum_x_grad: f32 = 0.0;
        for j in 0..hidden_dim {
            unsafe {
                let x_val = *x_ptr.add(row_offset + j);
                let gy_val = *grad_y_ptr.add(row_offset + j);
                sum_x_grad += x_val * gy_val;
            }
        }
        let mean_x_grad = sum_x_grad / (hidden_dim as f32);

        // Second pass: compute grad_x
        for j in 0..hidden_dim {
            unsafe {
                let x_val = *x_ptr.add(row_offset + j);
                let gy_val = *grad_y_ptr.add(row_offset + j);

                // Simplified gradient (without weight consideration for this demo)
                let grad = inv_rms * gy_val - inv_rms_cubed * x_val * mean_x_grad;
                *grad_x_ptr.add(row_offset + j) = grad;
            }
        }
    }

    OK
}
|
||||
|
||||
/// Kernel info structure
///
/// Filled in by `kernel_info`; `name_ptr` points into this module's own
/// linear memory and `name_len` excludes the trailing NUL.
#[repr(C)]
pub struct KernelInfo {
    pub name_ptr: *const u8,
    pub name_len: u32,
    pub version_major: u16,
    pub version_minor: u16,
    pub version_patch: u16,
    pub supports_backward: bool,
}
|
||||
|
||||
/// NUL-terminated kernel name reported via `kernel_info`.
static KERNEL_NAME: &[u8] = b"rmsnorm_f32\0";
|
||||
|
||||
/// Get kernel metadata
///
/// Fills `info_ptr` with the kernel name, semantic version, and backward
/// support flag. Returns `INVALID_PARAMS` when `info_ptr` is null,
/// otherwise `OK`.
#[no_mangle]
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
    if info_ptr.is_null() {
        return INVALID_PARAMS;
    }

    unsafe {
        (*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
        (*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1; // Exclude null terminator
        (*info_ptr).version_major = 1;
        (*info_ptr).version_minor = 0;
        (*info_ptr).version_patch = 0;
        (*info_ptr).supports_backward = true;
    }

    OK
}
|
||||
|
||||
/// Cleanup kernel resources
///
/// No resources are held by this stateless kernel; always returns `OK`.
#[no_mangle]
pub extern "C" fn kernel_cleanup() -> i32 {
    OK
}
|
||||
|
||||
/// Minimal square-root implementation for `no_std` targets.
///
/// Produces an initial estimate by halving the float's exponent with a bit
/// trick, then refines it with three Newton-Raphson iterations (quadratic
/// convergence), which is ample precision for f32.
///
/// Returns 0.0 for zero and negative inputs; NaN propagates to NaN
/// (NaN fails the `<= 0.0` guard and flows through the iterations).
fn sqrtf(x: f32) -> f32 {
    if x <= 0.0 {
        return 0.0;
    }

    // Initial estimate: (bits >> 1) halves the exponent; 0x1fbd1df5
    // re-biases it (the sqrt analogue of the fast inverse-sqrt constant).
    // The original code initialized `guess = x` only to overwrite it here —
    // that dead store has been removed.
    let mut guess = f32::from_bits(0x1fbd1df5 + (x.to_bits() >> 1));

    // Newton-Raphson: g <- (g + x/g) / 2
    for _ in 0..3 {
        guess = 0.5 * (guess + x / guess);
    }

    guess
}
|
||||
304
vendor/ruvector/crates/ruvector-wasm/kernels/rope.rs
vendored
Normal file
304
vendor/ruvector/crates/ruvector-wasm/kernels/rope.rs
vendored
Normal file
@@ -0,0 +1,304 @@
|
||||
//! RoPE (Rotary Position Embedding) Kernel
|
||||
//!
|
||||
//! This kernel implements rotary position embeddings as described in the
|
||||
//! RoFormer paper (https://arxiv.org/abs/2104.09864).
|
||||
//!
|
||||
//! RoPE applies rotation to the query and key vectors in attention,
|
||||
//! encoding relative positional information.
|
||||
//!
|
||||
//! # Compilation
|
||||
//!
|
||||
//! To compile this kernel to WASM:
|
||||
//! ```bash
|
||||
//! rustc --target wasm32-unknown-unknown \
|
||||
//! --crate-type cdylib \
|
||||
//! -C opt-level=3 \
|
||||
//! -C lto=fat \
|
||||
//! kernels/rope.rs \
|
||||
//! -o kernels/rope_f32.wasm
|
||||
//! ```
|
||||
//!
|
||||
//! Or use the provided build script in the kernels directory.
|
||||
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
// Panic handler for no_std: spin forever. No unwinding machinery exists in
// this freestanding WASM module; the host must trap/time out a hung instance.
#[panic_handler]
fn panic(_info: &core::panic::PanicInfo) -> ! {
    loop {}
}
|
||||
|
||||
/// Kernel descriptor structure (must match host definition)
///
/// Offsets are byte offsets into WASM linear memory; `#[repr(C)]` keeps the
/// field layout identical to the host side.
#[repr(C)]
pub struct KernelDescriptor {
    pub input_a_offset: u32, // x tensor
    pub input_a_size: u32,
    pub input_b_offset: u32, // freqs tensor
    pub input_b_size: u32,
    pub output_offset: u32,
    pub output_size: u32,
    pub scratch_offset: u32,
    pub scratch_size: u32,
    pub params_offset: u32, // Points at a `RopeParams`
    pub params_size: u32,
}
|
||||
|
||||
/// RoPE parameters
///
/// Read from linear memory at `KernelDescriptor::params_offset`.
/// NOTE(review): `theta` is carried here but the kernels read the angles
/// directly from the freqs tensor — the host presumably applies `theta`
/// when precomputing them; confirm against the host-side table builder.
#[repr(C)]
pub struct RopeParams {
    /// Base frequency (typically 10000.0)
    pub theta: f32,
    /// Sequence length
    pub seq_len: u32,
    /// Head dimension (must be even)
    pub head_dim: u32,
    /// Number of heads
    pub num_heads: u32,
    /// Batch size
    pub batch_size: u32,
}
|
||||
|
||||
/// Error codes returned by every kernel entry point (0 = success).
const OK: i32 = 0; // Success
const INVALID_INPUT: i32 = 1; // An input region is missing / zero-sized
const INVALID_OUTPUT: i32 = 2; // Output region missing or size mismatch
const INVALID_PARAMS: i32 = 3; // Params region too small or values invalid
|
||||
|
||||
/// Initialize kernel (optional, for stateful kernels)
///
/// This kernel is stateless, so initialization is a no-op; always
/// returns `OK`.
#[no_mangle]
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
    OK
}
|
||||
|
||||
/// Execute RoPE forward pass
///
/// # Memory Layout
///
/// Input A (x): [batch, seq, heads, dim] as f32
/// Input B (freqs): [seq, dim/2] as f32 (precomputed rotation angles;
/// cos/sin are evaluated inside this kernel)
/// Output (y): [batch, seq, heads, dim] as f32
///
/// The kernel applies rotation to pairs of elements:
/// y[..., 2i] = x[..., 2i] * cos(freq) - x[..., 2i+1] * sin(freq)
/// y[..., 2i+1] = x[..., 2i] * sin(freq) + x[..., 2i+1] * cos(freq)
///
/// Returns `OK` (0) on success or an `INVALID_*` error code.
///
/// # Safety
///
/// Offsets are taken relative to address 0 of WASM linear memory; only
/// region sizes are checked here, so the host must guarantee bounds.
/// NOTE(review): the index arithmetic below is done in u32 and could wrap
/// for very large tensors — consider widening to usize upstream.
#[no_mangle]
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
    // Safety: We trust the host to provide valid pointers
    let desc = unsafe { &*desc_ptr };

    // Validate inputs
    if desc.input_a_size == 0 {
        return INVALID_INPUT;
    }
    if desc.output_size == 0 || desc.output_size != desc.input_a_size {
        return INVALID_OUTPUT;
    }
    if desc.params_size < core::mem::size_of::<RopeParams>() as u32 {
        return INVALID_PARAMS;
    }

    // Get memory base pointer (WASM linear memory starts at 0)
    let memory_base = 0usize as *mut u8;

    // Get params
    let params = unsafe {
        &*(memory_base.add(desc.params_offset as usize) as *const RopeParams)
    };

    // Validate head_dim is even (elements are rotated in pairs)
    if params.head_dim % 2 != 0 {
        return INVALID_PARAMS;
    }

    let half_dim = params.head_dim / 2;

    // Get tensor pointers
    let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
    let freqs_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
    let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };

    // Apply RoPE
    // Loop order: batch -> seq -> head -> dim_pair
    for b in 0..params.batch_size {
        for s in 0..params.seq_len {
            for h in 0..params.num_heads {
                for d in 0..half_dim {
                    // Calculate indices: idx addresses the even element of
                    // the (2d, 2d+1) pair in row-major [b, s, h, dim].
                    let idx = ((b * params.seq_len + s) * params.num_heads + h) * params.head_dim + d * 2;
                    let freq_idx = s * half_dim + d;

                    unsafe {
                        // Get input values
                        let x0 = *x_ptr.add(idx as usize);
                        let x1 = *x_ptr.add(idx as usize + 1);

                        // Get the rotation angle for this (position, pair);
                        // freqs stores angles — cos/sin are computed here,
                        // not precomputed/interleaved in the table.
                        let freq = *freqs_ptr.add(freq_idx as usize);
                        let cos_f = libm::cosf(freq);
                        let sin_f = libm::sinf(freq);

                        // Apply rotation (2D rotation of the pair)
                        let y0 = x0 * cos_f - x1 * sin_f;
                        let y1 = x0 * sin_f + x1 * cos_f;

                        // Write output
                        *y_ptr.add(idx as usize) = y0;
                        *y_ptr.add(idx as usize + 1) = y1;
                    }
                }
            }
        }
    }

    OK
}
|
||||
|
||||
/// Execute RoPE backward pass (gradient computation)
///
/// The backward pass is the same rotation with negated sin,
/// since the Jacobian of rotation is another rotation (its transpose).
///
/// Input A carries grad_y, Input B the same freqs table as forward, and
/// the output receives grad_x. Returns `OK` or an `INVALID_*` code.
#[no_mangle]
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
    // For RoPE, backward is essentially the same operation with transposed rotation
    // (negated sin terms), but the structure is identical
    let desc = unsafe { &*desc_ptr };

    if desc.input_a_size == 0 {
        return INVALID_INPUT;
    }
    if desc.output_size == 0 || desc.output_size != desc.input_a_size {
        return INVALID_OUTPUT;
    }
    if desc.params_size < core::mem::size_of::<RopeParams>() as u32 {
        return INVALID_PARAMS;
    }

    let memory_base = 0usize as *mut u8;

    let params = unsafe {
        &*(memory_base.add(desc.params_offset as usize) as *const RopeParams)
    };

    // Pairwise rotation requires an even head dimension
    if params.head_dim % 2 != 0 {
        return INVALID_PARAMS;
    }

    let half_dim = params.head_dim / 2;

    let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
    let freqs_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
    let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };

    // Backward RoPE: apply inverse rotation (transpose = negate sin)
    for b in 0..params.batch_size {
        for s in 0..params.seq_len {
            for h in 0..params.num_heads {
                for d in 0..half_dim {
                    // Same row-major indexing as the forward pass
                    let idx = ((b * params.seq_len + s) * params.num_heads + h) * params.head_dim + d * 2;
                    let freq_idx = s * half_dim + d;

                    unsafe {
                        let gy0 = *grad_y_ptr.add(idx as usize);
                        let gy1 = *grad_y_ptr.add(idx as usize + 1);

                        let freq = *freqs_ptr.add(freq_idx as usize);
                        let cos_f = libm::cosf(freq);
                        let sin_f = libm::sinf(freq);

                        // Inverse rotation (transpose of the forward matrix)
                        let gx0 = gy0 * cos_f + gy1 * sin_f;
                        let gx1 = -gy0 * sin_f + gy1 * cos_f;

                        *grad_x_ptr.add(idx as usize) = gx0;
                        *grad_x_ptr.add(idx as usize + 1) = gx1;
                    }
                }
            }
        }
    }

    OK
}
|
||||
|
||||
/// Kernel info structure
///
/// Filled in by `kernel_info`; `name_ptr` points into this module's own
/// linear memory and `name_len` excludes the trailing NUL.
#[repr(C)]
pub struct KernelInfo {
    pub name_ptr: *const u8,
    pub name_len: u32,
    pub version_major: u16,
    pub version_minor: u16,
    pub version_patch: u16,
    pub supports_backward: bool,
}
|
||||
|
||||
/// NUL-terminated kernel name reported via `kernel_info`.
static KERNEL_NAME: &[u8] = b"rope_f32\0";
|
||||
|
||||
/// Get kernel metadata
///
/// Fills `info_ptr` with the kernel name, semantic version, and backward
/// support flag. Returns `INVALID_PARAMS` when `info_ptr` is null,
/// otherwise `OK`.
#[no_mangle]
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
    if info_ptr.is_null() {
        return INVALID_PARAMS;
    }

    unsafe {
        (*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
        (*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1; // Exclude null terminator
        (*info_ptr).version_major = 1;
        (*info_ptr).version_minor = 0;
        (*info_ptr).version_patch = 0;
        (*info_ptr).supports_backward = true;
    }

    OK
}
|
||||
|
||||
/// Cleanup kernel resources
#[no_mangle]
pub extern "C" fn kernel_cleanup() -> i32 {
    // No resources to cleanup for this stateless kernel
    OK
}
|
||||
|
||||
// Minimal libm implementations for no_std
mod libm {
    // Taylor-series approximations for sin and cos after range reduction.
    // In production, use more accurate implementations or link to libm.

    const PI: f32 = 3.14159265358979323846;
    const TWO_PI: f32 = 2.0 * PI;

    /// Reduce `x` to the interval [-PI, PI].
    ///
    /// Uses a single divide instead of repeated subtraction: the previous
    /// loop-based reduction was O(|x|) in iterations and never terminated
    /// for infinite inputs (inf - TWO_PI == inf). Non-finite inputs now
    /// yield NaN, since sin/cos are undefined there.
    fn normalize_angle(x: f32) -> f32 {
        if !x.is_finite() {
            return f32::NAN;
        }
        // Whole turns, truncated toward zero. The float->int cast
        // saturates for huge |x|, where f32 has no angular precision
        // left anyway.
        let turns = (x / TWO_PI) as i64;
        let mut reduced = x - (turns as f32) * TWO_PI; // in (-TWO_PI, TWO_PI)
        if reduced > PI {
            reduced -= TWO_PI;
        } else if reduced < -PI {
            reduced += TWO_PI;
        }
        reduced
    }

    /// sin(x) via a 5-term Taylor series after range reduction:
    /// sin(x) = x - x^3/3! + x^5/5! - x^7/7! + x^9/9!
    pub fn sinf(x: f32) -> f32 {
        let x = normalize_angle(x);
        let x2 = x * x;
        let x3 = x2 * x;
        let x5 = x3 * x2;
        let x7 = x5 * x2;
        let x9 = x7 * x2;

        x - x3 / 6.0 + x5 / 120.0 - x7 / 5040.0 + x9 / 362880.0
    }

    /// cos(x) via a 5-term Taylor series after range reduction:
    /// cos(x) = 1 - x^2/2! + x^4/4! - x^6/6! + x^8/8!
    pub fn cosf(x: f32) -> f32 {
        let x = normalize_angle(x);
        let x2 = x * x;
        let x4 = x2 * x2;
        let x6 = x4 * x2;
        let x8 = x6 * x2;

        1.0 - x2 / 2.0 + x4 / 24.0 - x6 / 720.0 + x8 / 40320.0
    }
}
|
||||
299
vendor/ruvector/crates/ruvector-wasm/kernels/swiglu.rs
vendored
Normal file
299
vendor/ruvector/crates/ruvector-wasm/kernels/swiglu.rs
vendored
Normal file
@@ -0,0 +1,299 @@
|
||||
//! SwiGLU (Swish-Gated Linear Unit) Activation Kernel
|
||||
//!
|
||||
//! This kernel implements the SwiGLU activation function used in models
|
||||
//! like LLaMA and PaLM. It combines the Swish activation with a gating
|
||||
//! mechanism.
|
||||
//!
|
||||
//! Formula: SwiGLU(x, gate) = swish(gate) * x
|
||||
//! where swish(x) = x * sigmoid(x)
|
||||
//!
|
||||
//! In practice, this is often used in the FFN:
|
||||
//! FFN(x) = (swish(x * W_gate) * (x * W_up)) * W_down
|
||||
//!
|
||||
//! This kernel computes: swish(gate) * x
|
||||
//!
|
||||
//! # Compilation
|
||||
//!
|
||||
//! To compile this kernel to WASM:
|
||||
//! ```bash
|
||||
//! rustc --target wasm32-unknown-unknown \
|
||||
//! --crate-type cdylib \
|
||||
//! -C opt-level=3 \
|
||||
//! -C lto=fat \
|
||||
//! kernels/swiglu.rs \
|
||||
//! -o kernels/swiglu_f32.wasm
|
||||
//! ```
|
||||
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
// Panic handler for no_std: spin forever. No unwinding machinery exists in
// this freestanding WASM module; the host must trap/time out a hung instance.
#[panic_handler]
fn panic(_info: &core::panic::PanicInfo) -> ! {
    loop {}
}
|
||||
|
||||
/// Kernel descriptor structure
///
/// Offsets are byte offsets into WASM linear memory; `#[repr(C)]` keeps the
/// field layout identical to the host-side definition.
#[repr(C)]
pub struct KernelDescriptor {
    pub input_a_offset: u32, // x tensor (to be gated)
    pub input_a_size: u32,
    pub input_b_offset: u32, // gate tensor
    pub input_b_size: u32,
    pub output_offset: u32,
    pub output_size: u32,
    pub scratch_offset: u32,
    pub scratch_size: u32,
    pub params_offset: u32, // Points at a `SwiGluParams`
    pub params_size: u32,
}
|
||||
|
||||
/// SwiGLU parameters
///
/// Read from linear memory at `KernelDescriptor::params_offset`.
#[repr(C)]
pub struct SwiGluParams {
    /// Number of elements (total size = num_elements * hidden_dim)
    pub num_elements: u32,
    /// Hidden dimension
    pub hidden_dim: u32,
    /// Beta parameter for SiLU/Swish (typically 1.0)
    pub beta: f32,
}
|
||||
|
||||
/// Error codes returned by every kernel entry point (0 = success).
const OK: i32 = 0; // Success
const INVALID_INPUT: i32 = 1; // Input regions missing or size-mismatched
const INVALID_OUTPUT: i32 = 2; // Output region missing or size mismatch
const INVALID_PARAMS: i32 = 3; // Params region too small or values invalid
|
||||
|
||||
/// Initialize kernel
///
/// This kernel is stateless, so initialization is a no-op: parameters are
/// re-read from linear memory on every forward/backward call. Always
/// returns `OK`.
#[no_mangle]
pub extern "C" fn kernel_init(_params_ptr: *const u8, _params_len: u32) -> i32 {
    OK
}
|
||||
|
||||
/// Compute swish activation: x * sigmoid(beta * x)
|
||||
#[inline]
|
||||
fn swish(x: f32, beta: f32) -> f32 {
|
||||
x * sigmoid(beta * x)
|
||||
}
|
||||
|
||||
/// Sigmoid function: 1 / (1 + exp(-x))
|
||||
#[inline]
|
||||
fn sigmoid(x: f32) -> f32 {
|
||||
1.0 / (1.0 + expf(-x))
|
||||
}
|
||||
|
||||
/// Execute SwiGLU forward pass
|
||||
///
|
||||
/// # Memory Layout
|
||||
///
|
||||
/// Input A (x): [num_elements, hidden_dim] as f32 (value to gate)
|
||||
/// Input B (gate): [num_elements, hidden_dim] as f32 (gate values)
|
||||
/// Output (y): [num_elements, hidden_dim] as f32
|
||||
///
|
||||
/// y = swish(gate) * x
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_forward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
// Validate inputs
|
||||
if desc.input_a_size == 0 || desc.input_b_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.input_a_size != desc.input_b_size {
|
||||
return INVALID_INPUT; // x and gate must have same size
|
||||
}
|
||||
if desc.output_size == 0 || desc.output_size != desc.input_a_size {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<SwiGluParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const SwiGluParams)
|
||||
};
|
||||
|
||||
let total_elements = (params.num_elements * params.hidden_dim) as usize;
|
||||
let beta = params.beta;
|
||||
|
||||
// Get tensor pointers
|
||||
let x_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let gate_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let y_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Apply SwiGLU: y = swish(gate) * x
|
||||
for i in 0..total_elements {
|
||||
unsafe {
|
||||
let x_val = *x_ptr.add(i);
|
||||
let gate_val = *gate_ptr.add(i);
|
||||
let swish_gate = swish(gate_val, beta);
|
||||
*y_ptr.add(i) = swish_gate * x_val;
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Execute SwiGLU backward pass
|
||||
///
|
||||
/// Given grad_y, compute grad_x and grad_gate.
|
||||
///
|
||||
/// grad_x = swish(gate) * grad_y
|
||||
/// grad_gate = x * grad_y * (sigmoid(gate) + gate * sigmoid(gate) * (1 - sigmoid(gate)))
|
||||
/// = x * grad_y * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))
|
||||
///
|
||||
/// For this simplified kernel:
|
||||
/// Input A (grad_y): gradient from upstream
|
||||
/// Input B contains both (x, gate) - simplified layout
|
||||
/// Output (grad_x): gradient w.r.t. x
|
||||
/// Scratch: gradient w.r.t. gate (if space available)
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_backward(desc_ptr: *const KernelDescriptor) -> i32 {
|
||||
let desc = unsafe { &*desc_ptr };
|
||||
|
||||
if desc.input_a_size == 0 {
|
||||
return INVALID_INPUT;
|
||||
}
|
||||
if desc.output_size == 0 {
|
||||
return INVALID_OUTPUT;
|
||||
}
|
||||
if desc.params_size < core::mem::size_of::<SwiGluParams>() as u32 {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
let memory_base = 0usize as *mut u8;
|
||||
|
||||
let params = unsafe {
|
||||
&*(memory_base.add(desc.params_offset as usize) as *const SwiGluParams)
|
||||
};
|
||||
|
||||
let total_elements = (params.num_elements * params.hidden_dim) as usize;
|
||||
let beta = params.beta;
|
||||
|
||||
// For backward, input_b should contain original gate values
|
||||
// This is a simplified layout - real implementation would use separate descriptors
|
||||
let grad_y_ptr = unsafe { memory_base.add(desc.input_a_offset as usize) as *const f32 };
|
||||
let gate_ptr = unsafe { memory_base.add(desc.input_b_offset as usize) as *const f32 };
|
||||
let grad_x_ptr = unsafe { memory_base.add(desc.output_offset as usize) as *mut f32 };
|
||||
|
||||
// Compute grad_x = swish(gate) * grad_y
|
||||
// (simplified: we would also need original x to compute grad_gate)
|
||||
for i in 0..total_elements {
|
||||
unsafe {
|
||||
let grad_y_val = *grad_y_ptr.add(i);
|
||||
let gate_val = *gate_ptr.add(i);
|
||||
let swish_gate = swish(gate_val, beta);
|
||||
*grad_x_ptr.add(i) = swish_gate * grad_y_val;
|
||||
}
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Kernel info structure
///
/// Filled in by `kernel_info` so the host can discover the kernel's name,
/// semantic version, and capabilities. `#[repr(C)]` keeps the layout
/// stable across the WASM boundary.
#[repr(C)]
pub struct KernelInfo {
    pub name_ptr: *const u8, // points into the module's static name bytes
    pub name_len: u32,       // name length WITHOUT the trailing NUL
    pub version_major: u16,
    pub version_minor: u16,
    pub version_patch: u16,
    pub supports_backward: bool,
}
|
||||
|
||||
static KERNEL_NAME: &[u8] = b"swiglu_f32\0";
|
||||
|
||||
/// Get kernel metadata
|
||||
#[no_mangle]
|
||||
pub extern "C" fn kernel_info(info_ptr: *mut KernelInfo) -> i32 {
|
||||
if info_ptr.is_null() {
|
||||
return INVALID_PARAMS;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
(*info_ptr).name_ptr = KERNEL_NAME.as_ptr();
|
||||
(*info_ptr).name_len = KERNEL_NAME.len() as u32 - 1;
|
||||
(*info_ptr).version_major = 1;
|
||||
(*info_ptr).version_minor = 0;
|
||||
(*info_ptr).version_patch = 0;
|
||||
(*info_ptr).supports_backward = true;
|
||||
}
|
||||
|
||||
OK
|
||||
}
|
||||
|
||||
/// Cleanup kernel resources
///
/// The kernel holds no state, so there is nothing to release; this always
/// returns `OK`.
#[no_mangle]
pub extern "C" fn kernel_cleanup() -> i32 {
    OK
}
|
||||
|
||||
// Minimal exp implementation for no_std
//
// Uses range reduction exp(x) = 2^k * exp(r) with k = round(x / ln 2),
// then a 6th-order Taylor expansion of exp(r) for |r| <= ln(2)/2, and
// reconstructs 2^k by building the IEEE-754 exponent bits directly.
fn expf(x: f32) -> f32 {
    // Handle edge cases: f32 overflows past ~exp(88.7), underflows below ~exp(-87.3).
    if x > 88.0 {
        return f32::INFINITY;
    }
    if x < -88.0 {
        return 0.0;
    }

    // Use range reduction: exp(x) = 2^k * exp(r)
    // where k = round(x / ln(2)) and r = x - k * ln(2)
    const LN2: f32 = 0.693147180559945;
    const LN2_INV: f32 = 1.442695040888963;

    // k = round(x / ln 2) computed as floor(x * LN2_INV + 0.5).
    // BUG FIX: `f32::floor` lives in `std` (not `core`), so it does not
    // compile under `#![no_std]`. Emulate floor with a truncating cast,
    // corrected downward for negative values (casts truncate toward zero,
    // floor rounds toward -infinity).
    let t = x * LN2_INV + 0.5;
    let mut k_int = t as i32;
    if t < k_int as f32 {
        k_int -= 1;
    }
    let k = k_int as f32;
    let r = x - k * LN2;

    // Taylor series for exp(r) where |r| <= ln(2)/2
    // exp(r) ≈ 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5! + r^6/6!
    let r2 = r * r;
    let r3 = r2 * r;
    let r4 = r2 * r2;
    let r5 = r4 * r;
    let r6 = r3 * r3;

    let exp_r = 1.0 + r + r2 * 0.5 + r3 * 0.166666667 + r4 * 0.041666667 + r5 * 0.008333333 + r6 * 0.001388889;

    // Combine: exp(x) = 2^k * exp(r); 2^k via exponent-bit construction.
    // |x| <= 88 guarantees -127 <= k_int <= 127, so 127 + k_int is in [0, 254].
    let scale_bits = ((127 + k_int) as u32) << 23;
    let scale = f32::from_bits(scale_bits);

    exp_r * scale
}
|
||||
|
||||
/// Compute GeGLU variant (alternative activation)
|
||||
/// GeGLU(x, gate) = gelu(gate) * x
|
||||
/// This is provided as an alternative, not used in default forward
|
||||
#[allow(dead_code)]
|
||||
fn gelu(x: f32) -> f32 {
|
||||
// Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
|
||||
const COEFF: f32 = 0.044715;
|
||||
|
||||
let x3 = x * x * x;
|
||||
let inner = SQRT_2_OVER_PI * (x + COEFF * x3);
|
||||
0.5 * x * (1.0 + tanhf(inner))
|
||||
}
|
||||
|
||||
/// Minimal tanh implementation
|
||||
#[allow(dead_code)]
|
||||
fn tanhf(x: f32) -> f32 {
|
||||
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
|
||||
// For numerical stability with large |x|
|
||||
if x > 10.0 {
|
||||
return 1.0;
|
||||
}
|
||||
if x < -10.0 {
|
||||
return -1.0;
|
||||
}
|
||||
|
||||
let exp_2x = expf(2.0 * x);
|
||||
(exp_2x - 1.0) / (exp_2x + 1.0)
|
||||
}
|
||||
46
vendor/ruvector/crates/ruvector-wasm/package.json
vendored
Normal file
46
vendor/ruvector/crates/ruvector-wasm/package.json
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"name": "@ruvector/wasm",
|
||||
"version": "0.1.16",
|
||||
"description": "High-performance Rust vector database for browsers via WASM",
|
||||
"main": "pkg/ruvector_wasm.js",
|
||||
"types": "pkg/ruvector_wasm.d.ts",
|
||||
"files": [
|
||||
"pkg",
|
||||
"src/worker.js",
|
||||
"src/worker-pool.js",
|
||||
"src/indexeddb.js"
|
||||
],
|
||||
"scripts": {
|
||||
"build": "npm run build:web && npm run build:simd && npm run build:bundler",
|
||||
"build:web": "wasm-pack build --target web --out-dir pkg --release",
|
||||
"build:simd": "wasm-pack build --target web --out-dir pkg-simd --release -- --features simd",
|
||||
"build:node": "wasm-pack build --target nodejs --out-dir pkg-node --release",
|
||||
"build:bundler": "wasm-pack build --target bundler --out-dir pkg-bundler --release",
|
||||
"build:all": "npm run build && npm run build:node",
|
||||
"test": "wasm-pack test --headless --chrome",
|
||||
"test:firefox": "wasm-pack test --headless --firefox",
|
||||
"test:node": "wasm-pack test --node",
|
||||
"size": "npm run build && gzip -c pkg/ruvector_wasm_bg.wasm | wc -c && echo 'bytes (gzipped)'",
|
||||
"optimize": "npm run build && wasm-opt -Oz pkg/ruvector_wasm_bg.wasm -o pkg/ruvector_wasm_bg.wasm",
|
||||
"serve": "python3 -m http.server 8080"
|
||||
},
|
||||
"keywords": [
|
||||
"vector",
|
||||
"database",
|
||||
"embeddings",
|
||||
"wasm",
|
||||
"browser",
|
||||
"rust",
|
||||
"simd",
|
||||
"web-workers",
|
||||
"indexeddb"
|
||||
],
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/ruvnet/ruvector.git"
|
||||
},
|
||||
"devDependencies": {
|
||||
"wasm-pack": "^0.12.1"
|
||||
}
|
||||
}
|
||||
355
vendor/ruvector/crates/ruvector-wasm/src/indexeddb.js
vendored
Normal file
355
vendor/ruvector/crates/ruvector-wasm/src/indexeddb.js
vendored
Normal file
@@ -0,0 +1,355 @@
|
||||
/**
|
||||
* IndexedDB Persistence Layer for Ruvector
|
||||
*
|
||||
* Provides:
|
||||
* - Save/load database state to IndexedDB
|
||||
* - Batch operations for performance
|
||||
* - Progressive loading with pagination
|
||||
* - LRU cache for hot vectors
|
||||
*/
|
||||
|
||||
const DB_NAME = 'ruvector_storage';
|
||||
const DB_VERSION = 1;
|
||||
const VECTOR_STORE = 'vectors';
|
||||
const META_STORE = 'metadata';
|
||||
|
||||
/**
 * LRU Cache for hot vectors
 *
 * Backed by a Map, which iterates in insertion order: the first key is
 * always the least recently used entry, so eviction is O(1).
 */
class LRUCache {
  constructor(capacity = 1000) {
    this.capacity = capacity;
    this.cache = new Map();
  }

  /** Get a value and mark it as most recently used. Returns null on miss. */
  get(key) {
    if (!this.cache.has(key)) return null;

    // Move to end (most recently used)
    const value = this.cache.get(key);
    this.cache.delete(key);
    this.cache.set(key, value);

    return value;
  }

  /** Insert or refresh a value, evicting the oldest entry when over capacity. */
  set(key, value) {
    // Remove if exists so re-insertion moves it to the end
    if (this.cache.has(key)) {
      this.cache.delete(key);
    }

    // Add to end
    this.cache.set(key, value);

    // Evict oldest if over capacity
    if (this.cache.size > this.capacity) {
      const firstKey = this.cache.keys().next().value;
      this.cache.delete(firstKey);
    }
  }

  /**
   * Remove an entry if present (no-op on a miss).
   * BUG FIX: this method was missing, but IndexedDBPersistence.deleteVector()
   * calls `this.cache.delete(id)` — previously that threw
   * `TypeError: this.cache.delete is not a function`.
   * @returns {boolean} true if an entry was removed
   */
  delete(key) {
    return this.cache.delete(key);
  }

  has(key) {
    return this.cache.has(key);
  }

  clear() {
    this.cache.clear();
  }

  get size() {
    return this.cache.size;
  }
}
|
||||
|
||||
/**
 * IndexedDB Persistence Manager
 *
 * Wraps an IndexedDB database with two object stores (vectors + metadata)
 * and keeps an LRU cache of recently used vectors in front of it.
 * Every method lazily opens the connection on first use.
 */
export class IndexedDBPersistence {
  constructor(dbName = null) {
    this.dbName = dbName || DB_NAME;
    this.db = null;
    this.cache = new LRUCache(1000);
  }

  /**
   * Open the IndexedDB connection, creating/upgrading object stores as needed.
   * @returns {Promise<IDBDatabase>}
   */
  async open() {
    return new Promise((resolve, reject) => {
      const request = indexedDB.open(this.dbName, DB_VERSION);

      request.onerror = () => reject(request.error);
      request.onsuccess = () => {
        this.db = request.result;
        resolve(this.db);
      };

      request.onupgradeneeded = (event) => {
        const db = event.target.result;

        // Create object stores if they don't exist
        if (!db.objectStoreNames.contains(VECTOR_STORE)) {
          const vectorStore = db.createObjectStore(VECTOR_STORE, { keyPath: 'id' });
          vectorStore.createIndex('timestamp', 'timestamp', { unique: false });
        }

        if (!db.objectStoreNames.contains(META_STORE)) {
          db.createObjectStore(META_STORE, { keyPath: 'key' });
        }
      };
    });
  }

  /**
   * Save a single vector.
   * @param {string|number} id - primary key
   * @param {Float32Array|number[]} vector
   * @param {object|null} metadata
   * @returns {Promise<string|number>} the id on success
   */
  async saveVector(id, vector, metadata = null) {
    if (!this.db) await this.open();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
      const store = transaction.objectStore(VECTOR_STORE);

      const data = {
        id,
        vector: Array.from(vector), // Convert Float32Array to regular array
        metadata,
        timestamp: Date.now()
      };

      const request = store.put(data);

      request.onsuccess = () => {
        // NOTE(review): the cached entry holds `vector` as a plain Array here,
        // while loadVector caches a Float32Array — consumers reading from the
        // cache may see either type; confirm which one callers expect.
        this.cache.set(id, data);
        resolve(id);
      };
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Save vectors in batch (one transaction per chunk, more efficient).
   * @param {{id, vector, metadata}[]} entries
   * @param {number} batchSize - entries per transaction
   * @returns {Promise<number>} number of entries written
   */
  async saveBatch(entries, batchSize = 100) {
    if (!this.db) await this.open();

    const chunks = [];
    for (let i = 0; i < entries.length; i += batchSize) {
      chunks.push(entries.slice(i, i + batchSize));
    }

    for (const chunk of chunks) {
      await new Promise((resolve, reject) => {
        const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
        const store = transaction.objectStore(VECTOR_STORE);

        for (const entry of chunk) {
          const data = {
            id: entry.id,
            vector: Array.from(entry.vector),
            metadata: entry.metadata,
            timestamp: Date.now()
          };

          store.put(data);
          this.cache.set(entry.id, data);
        }

        transaction.oncomplete = () => resolve();
        transaction.onerror = () => reject(transaction.error);
      });
    }

    return entries.length;
  }

  /**
   * Load a single vector by ID. Serves from the LRU cache when possible;
   * resolves to undefined when the id is unknown.
   */
  async loadVector(id) {
    // Check cache first
    if (this.cache.has(id)) {
      return this.cache.get(id);
    }

    if (!this.db) await this.open();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readonly');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.get(id);

      request.onsuccess = () => {
        const data = request.result;
        if (data) {
          // Convert array back to Float32Array
          data.vector = new Float32Array(data.vector);
          this.cache.set(id, data);
        }
        resolve(data);
      };
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Load all vectors (with progressive loading).
   *
   * With `onProgress`, each batch is delivered to the callback and the
   * final resolution carries only the count. Without a callback, all
   * vectors are returned in the resolution.
   *
   * @param {?(p: {loaded: number, vectors: object[], complete?: boolean}) => void} onProgress
   * @param {number} batchSize
   * @returns {Promise<{count: number, vectors: object[], complete: true}>}
   */
  async loadAll(onProgress = null, batchSize = 100) {
    if (!this.db) await this.open();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readonly');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.openCursor();

      const vectors = [];
      let count = 0;

      request.onsuccess = (event) => {
        const cursor = event.target.result;

        if (cursor) {
          const data = cursor.value;
          data.vector = new Float32Array(data.vector);
          vectors.push(data);
          count++;

          // Cache hot vectors (first 1000)
          if (count <= 1000) {
            this.cache.set(data.id, data);
          }

          // Report progress every batch
          if (onProgress && count % batchSize === 0) {
            onProgress({
              loaded: count,
              vectors: [...vectors]
            });
            vectors.length = 0; // Clear batch
          }

          cursor.continue();
        } else {
          // Done — flush the final partial batch to the callback
          if (onProgress && vectors.length > 0) {
            onProgress({
              loaded: count,
              vectors: vectors,
              complete: true
            });
          }
          // BUG FIX: without a callback, the collected vectors were
          // previously discarded; return them to the caller. (Callback
          // users already received everything through onProgress.)
          resolve({ count, vectors: onProgress ? [] : vectors, complete: true });
        }
      };

      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Delete a vector by ID.
   * NOTE(review): `this.cache.delete(id)` requires the LRU cache to expose
   * a `delete` method — confirm against the cache implementation.
   */
  async deleteVector(id) {
    if (!this.db) await this.open();

    this.cache.delete(id);

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.delete(id);

      request.onsuccess = () => resolve(true);
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Clear all vectors (store and cache).
   */
  async clear() {
    if (!this.db) await this.open();

    this.cache.clear();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.clear();

      request.onsuccess = () => resolve();
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Get database statistics.
   *
   * `cacheHitRate` is really the cache *fill ratio* (cached / total); the
   * key is kept for backward compatibility. BUG FIX: now guarded against
   * a divide-by-zero (NaN/Infinity) on an empty store.
   */
  async getStats() {
    if (!this.db) await this.open();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readonly');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.count();

      request.onsuccess = () => {
        const total = request.result;
        resolve({
          totalVectors: total,
          cacheSize: this.cache.size,
          cacheHitRate: total > 0 ? this.cache.size / total : 0
        });
      };
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Save metadata under an arbitrary key.
   */
  async saveMeta(key, value) {
    if (!this.db) await this.open();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([META_STORE], 'readwrite');
      const store = transaction.objectStore(META_STORE);
      const request = store.put({ key, value });

      request.onsuccess = () => resolve();
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Load metadata by key; resolves to null when absent.
   */
  async loadMeta(key) {
    if (!this.db) await this.open();

    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([META_STORE], 'readonly');
      const store = transaction.objectStore(META_STORE);
      const request = store.get(key);

      request.onsuccess = () => {
        const data = request.result;
        resolve(data ? data.value : null);
      };
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Close the database connection (the in-memory cache is kept).
   */
  close() {
    if (this.db) {
      this.db.close();
      this.db = null;
    }
  }
}
|
||||
|
||||
export default IndexedDBPersistence;
|
||||
334
vendor/ruvector/crates/ruvector-wasm/src/kernel/allowlist.rs
vendored
Normal file
334
vendor/ruvector/crates/ruvector-wasm/src/kernel/allowlist.rs
vendored
Normal file
@@ -0,0 +1,334 @@
|
||||
//! Trusted Kernel Allowlist
|
||||
//!
|
||||
//! Maintains a list of approved kernel hashes for additional security.
|
||||
//! This provides defense-in-depth beyond signature verification.
|
||||
|
||||
use crate::kernel::error::VerifyError;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Trusted kernel allowlist
///
/// Maintains approved kernel hashes organized by kernel ID.
/// Even if a kernel has a valid signature, it must be in the allowlist
/// to be executed (when allowlist enforcement is enabled).
#[derive(Debug, Clone)]
pub struct TrustedKernelAllowlist {
    /// Set of approved kernel hashes (format: "sha256:...", stored lowercase)
    approved_hashes: HashSet<String>,

    /// Map of kernel_id -> approved hashes for that kernel
    /// (takes precedence over the global set during per-kernel checks)
    kernel_hashes: HashMap<String, HashSet<String>>,

    /// Whether to enforce allowlist (can be disabled for development)
    enforce: bool,

    /// Allowlist version/update timestamp
    version: String,
}
|
||||
|
||||
impl TrustedKernelAllowlist {
|
||||
/// Create a new empty allowlist
|
||||
pub fn new() -> Self {
|
||||
TrustedKernelAllowlist {
|
||||
approved_hashes: HashSet::new(),
|
||||
kernel_hashes: HashMap::new(),
|
||||
enforce: true,
|
||||
version: "1.0.0".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an allowlist that doesn't enforce checks (for development)
|
||||
///
|
||||
/// # Warning
|
||||
/// This should NEVER be used in production.
|
||||
pub fn insecure_allow_all() -> Self {
|
||||
TrustedKernelAllowlist {
|
||||
approved_hashes: HashSet::new(),
|
||||
kernel_hashes: HashMap::new(),
|
||||
enforce: false,
|
||||
version: "dev".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load allowlist from JSON
|
||||
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct AllowlistJson {
|
||||
version: String,
|
||||
kernels: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
let parsed: AllowlistJson = serde_json::from_str(json)?;
|
||||
|
||||
let mut allowlist = TrustedKernelAllowlist::new();
|
||||
allowlist.version = parsed.version;
|
||||
|
||||
for (kernel_id, hashes) in parsed.kernels {
|
||||
for hash in hashes {
|
||||
allowlist.add_kernel_hash(&kernel_id, &hash);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(allowlist)
|
||||
}
|
||||
|
||||
/// Serialize allowlist to JSON
|
||||
pub fn to_json(&self) -> Result<String, serde_json::Error> {
|
||||
#[derive(serde::Serialize)]
|
||||
struct AllowlistJson {
|
||||
version: String,
|
||||
kernels: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
let kernels: HashMap<String, Vec<String>> = self
|
||||
.kernel_hashes
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.iter().cloned().collect()))
|
||||
.collect();
|
||||
|
||||
let json = AllowlistJson {
|
||||
version: self.version.clone(),
|
||||
kernels,
|
||||
};
|
||||
|
||||
serde_json::to_string_pretty(&json)
|
||||
}
|
||||
|
||||
/// Add a hash to the global approved set
|
||||
pub fn add_hash(&mut self, hash: &str) {
|
||||
self.approved_hashes.insert(hash.to_lowercase());
|
||||
}
|
||||
|
||||
/// Add a hash for a specific kernel ID
|
||||
pub fn add_kernel_hash(&mut self, kernel_id: &str, hash: &str) {
|
||||
let lowercase_hash = hash.to_lowercase();
|
||||
self.approved_hashes.insert(lowercase_hash.clone());
|
||||
|
||||
self.kernel_hashes
|
||||
.entry(kernel_id.to_string())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(lowercase_hash);
|
||||
}
|
||||
|
||||
/// Remove a hash from the allowlist
|
||||
pub fn remove_hash(&mut self, hash: &str) {
|
||||
let lowercase_hash = hash.to_lowercase();
|
||||
self.approved_hashes.remove(&lowercase_hash);
|
||||
|
||||
for hashes in self.kernel_hashes.values_mut() {
|
||||
hashes.remove(&lowercase_hash);
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a hash is in the allowlist
|
||||
pub fn is_allowed(&self, hash: &str) -> bool {
|
||||
if !self.enforce {
|
||||
return true;
|
||||
}
|
||||
self.approved_hashes.contains(&hash.to_lowercase())
|
||||
}
|
||||
|
||||
/// Check if a hash is allowed for a specific kernel ID
|
||||
pub fn is_allowed_for_kernel(&self, kernel_id: &str, hash: &str) -> bool {
|
||||
if !self.enforce {
|
||||
return true;
|
||||
}
|
||||
|
||||
let lowercase_hash = hash.to_lowercase();
|
||||
|
||||
// Check kernel-specific allowlist first
|
||||
if let Some(kernel_hashes) = self.kernel_hashes.get(kernel_id) {
|
||||
return kernel_hashes.contains(&lowercase_hash);
|
||||
}
|
||||
|
||||
// Fall back to global allowlist
|
||||
self.approved_hashes.contains(&lowercase_hash)
|
||||
}
|
||||
|
||||
/// Verify a kernel is in the allowlist
|
||||
pub fn verify(&self, kernel_id: &str, hash: &str) -> Result<(), VerifyError> {
|
||||
if self.is_allowed_for_kernel(kernel_id, hash) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(VerifyError::NotInAllowlist {
|
||||
kernel_id: kernel_id.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Get number of approved hashes
|
||||
pub fn hash_count(&self) -> usize {
|
||||
self.approved_hashes.len()
|
||||
}
|
||||
|
||||
/// Get all approved hashes for a kernel ID
|
||||
pub fn get_kernel_hashes(&self, kernel_id: &str) -> Option<&HashSet<String>> {
|
||||
self.kernel_hashes.get(kernel_id)
|
||||
}
|
||||
|
||||
/// List all kernel IDs with approved hashes
|
||||
pub fn kernel_ids(&self) -> Vec<&str> {
|
||||
self.kernel_hashes.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
|
||||
/// Get allowlist version
|
||||
pub fn version(&self) -> &str {
|
||||
&self.version
|
||||
}
|
||||
|
||||
/// Set allowlist version
|
||||
pub fn set_version(&mut self, version: &str) {
|
||||
self.version = version.to_string();
|
||||
}
|
||||
|
||||
/// Check if enforcement is enabled
|
||||
pub fn is_enforced(&self) -> bool {
|
||||
self.enforce
|
||||
}
|
||||
|
||||
/// Merge another allowlist into this one
|
||||
pub fn merge(&mut self, other: &TrustedKernelAllowlist) {
|
||||
for hash in &other.approved_hashes {
|
||||
self.approved_hashes.insert(hash.clone());
|
||||
}
|
||||
|
||||
for (kernel_id, hashes) in &other.kernel_hashes {
|
||||
let entry = self
|
||||
.kernel_hashes
|
||||
.entry(kernel_id.clone())
|
||||
.or_insert_with(HashSet::new);
|
||||
for hash in hashes {
|
||||
entry.insert(hash.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TrustedKernelAllowlist {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in allowlist of official RuvLLM kernels
///
/// This provides a starting point with known-good kernel hashes.
/// Production deployments should maintain their own allowlist.
///
/// NOTE(review): until real hashes replace the placeholders below, this
/// allowlist is empty and (with enforcement on) rejects every kernel.
pub fn builtin_allowlist() -> TrustedKernelAllowlist {
    let mut allowlist = TrustedKernelAllowlist::new();
    allowlist.set_version("0.1.0-builtin");

    // Add placeholders for official kernels
    // These would be replaced with actual hashes in production
    // allowlist.add_kernel_hash("rope_f32", "sha256:...");
    // allowlist.add_kernel_hash("rmsnorm_f32", "sha256:...");
    // allowlist.add_kernel_hash("swiglu_f32", "sha256:...");

    allowlist
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the allowlist: add/remove, per-kernel lookup,
    //! verification errors, enforcement bypass, JSON round-trip, and merge.
    use super::*;

    #[test]
    fn test_add_and_check_hash() {
        let mut allowlist = TrustedKernelAllowlist::new();
        let hash = "sha256:abc123def456";

        assert!(!allowlist.is_allowed(hash));

        allowlist.add_hash(hash);
        assert!(allowlist.is_allowed(hash));

        // Case insensitive
        assert!(allowlist.is_allowed("SHA256:ABC123DEF456"));
    }

    #[test]
    fn test_kernel_specific_hash() {
        let mut allowlist = TrustedKernelAllowlist::new();

        allowlist.add_kernel_hash("rope_f32", "sha256:rope_hash");
        allowlist.add_kernel_hash("rmsnorm_f32", "sha256:rmsnorm_hash");

        assert!(allowlist.is_allowed_for_kernel("rope_f32", "sha256:rope_hash"));
        // A kernel with its own entry must not accept another kernel's hash.
        assert!(!allowlist.is_allowed_for_kernel("rope_f32", "sha256:rmsnorm_hash"));
        assert!(allowlist.is_allowed_for_kernel("rmsnorm_f32", "sha256:rmsnorm_hash"));
    }

    #[test]
    fn test_verify() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("rope_f32", "sha256:valid_hash");

        assert!(allowlist.verify("rope_f32", "sha256:valid_hash").is_ok());
        assert!(matches!(
            allowlist.verify("rope_f32", "sha256:invalid_hash"),
            Err(VerifyError::NotInAllowlist { .. })
        ));
    }

    #[test]
    fn test_insecure_allow_all() {
        let allowlist = TrustedKernelAllowlist::insecure_allow_all();

        // Should allow any hash when not enforcing
        assert!(allowlist.is_allowed("sha256:anything"));
        assert!(allowlist.is_allowed_for_kernel("any_kernel", "sha256:anything"));
        assert!(!allowlist.is_enforced());
    }

    #[test]
    fn test_remove_hash() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("kernel", "sha256:hash");

        assert!(allowlist.is_allowed("sha256:hash"));

        allowlist.remove_hash("sha256:hash");
        assert!(!allowlist.is_allowed("sha256:hash"));
    }

    #[test]
    fn test_json_roundtrip() {
        let mut original = TrustedKernelAllowlist::new();
        original.set_version("1.2.3");
        original.add_kernel_hash("rope_f32", "sha256:hash1");
        original.add_kernel_hash("rope_f32", "sha256:hash2");
        original.add_kernel_hash("rmsnorm_f32", "sha256:hash3");

        let json = original.to_json().unwrap();
        let restored = TrustedKernelAllowlist::from_json(&json).unwrap();

        assert_eq!(restored.version(), "1.2.3");
        assert!(restored.is_allowed_for_kernel("rope_f32", "sha256:hash1"));
        assert!(restored.is_allowed_for_kernel("rope_f32", "sha256:hash2"));
        assert!(restored.is_allowed_for_kernel("rmsnorm_f32", "sha256:hash3"));
    }

    #[test]
    fn test_merge() {
        let mut allowlist1 = TrustedKernelAllowlist::new();
        allowlist1.add_kernel_hash("kernel1", "sha256:hash1");

        let mut allowlist2 = TrustedKernelAllowlist::new();
        allowlist2.add_kernel_hash("kernel2", "sha256:hash2");

        allowlist1.merge(&allowlist2);

        // Union: entries from both sides survive the merge.
        assert!(allowlist1.is_allowed_for_kernel("kernel1", "sha256:hash1"));
        assert!(allowlist1.is_allowed_for_kernel("kernel2", "sha256:hash2"));
    }

    #[test]
    fn test_kernel_ids() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("kernel_a", "sha256:a");
        allowlist.add_kernel_hash("kernel_b", "sha256:b");

        let ids = allowlist.kernel_ids();
        assert!(ids.contains(&"kernel_a"));
        assert!(ids.contains(&"kernel_b"));
    }
}
|
||||
314
vendor/ruvector/crates/ruvector-wasm/src/kernel/epoch.rs
vendored
Normal file
314
vendor/ruvector/crates/ruvector-wasm/src/kernel/epoch.rs
vendored
Normal file
@@ -0,0 +1,314 @@
|
||||
//! Epoch-Based Interruption
|
||||
//!
|
||||
//! Provides execution budget management using Wasmtime's epoch mechanism.
|
||||
//! This allows coarse-grained interruption of WASM execution with minimal overhead.
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Epoch controller for managing execution budgets
///
/// The epoch mechanism works by periodically incrementing a counter.
/// WASM code checks this counter at certain points (function calls, loops)
/// and traps if the deadline has been exceeded.
///
/// Cloning is cheap: the counter and running flag are shared via `Arc`,
/// so all clones observe (and can drive) the same epoch.
#[derive(Debug, Clone)]
pub struct EpochController {
    /// Current epoch value (shared across clones)
    current_epoch: Arc<AtomicU64>,
    /// Tick interval
    tick_interval: Duration,
    /// Whether the controller is running
    // NOTE(review): no code in view ever sets this flag; presumably a
    // background ticker elsewhere flips it — confirm before relying on it.
    running: Arc<std::sync::atomic::AtomicBool>,
}
|
||||
|
||||
impl EpochController {
|
||||
/// Create a new epoch controller
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tick_interval` - How often to increment the epoch (e.g., 10ms)
|
||||
pub fn new(tick_interval: Duration) -> Self {
|
||||
EpochController {
|
||||
current_epoch: Arc::new(AtomicU64::new(0)),
|
||||
tick_interval,
|
||||
running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default 10ms tick interval
|
||||
pub fn default_interval() -> Self {
|
||||
Self::new(Duration::from_millis(10))
|
||||
}
|
||||
|
||||
/// Get current epoch value
|
||||
pub fn current(&self) -> u64 {
|
||||
self.current_epoch.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Manually increment the epoch
|
||||
pub fn increment(&self) {
|
||||
self.current_epoch.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reset epoch to zero
|
||||
pub fn reset(&self) {
|
||||
self.current_epoch.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Get tick interval
|
||||
pub fn tick_interval(&self) -> Duration {
|
||||
self.tick_interval
|
||||
}
|
||||
|
||||
/// Check if the controller is running
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.running.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get a clone of the epoch counter for sharing
|
||||
pub fn epoch_counter(&self) -> Arc<AtomicU64> {
|
||||
Arc::clone(&self.current_epoch)
|
||||
}
|
||||
|
||||
/// Calculate deadline epoch for a given budget
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `budget_ticks` - Number of ticks before timeout
|
||||
///
|
||||
/// # Returns
|
||||
/// The epoch value that represents the deadline
|
||||
pub fn deadline_for_budget(&self, budget_ticks: u64) -> u64 {
|
||||
self.current() + budget_ticks
|
||||
}
|
||||
|
||||
/// Check if an epoch deadline has been exceeded
|
||||
pub fn is_deadline_exceeded(&self, deadline: u64) -> bool {
|
||||
self.current() >= deadline
|
||||
}
|
||||
|
||||
/// Convert epoch ticks to approximate duration
|
||||
pub fn ticks_to_duration(&self, ticks: u64) -> Duration {
|
||||
self.tick_interval * ticks as u32
|
||||
}
|
||||
|
||||
/// Convert duration to approximate epoch ticks
|
||||
pub fn duration_to_ticks(&self, duration: Duration) -> u64 {
|
||||
(duration.as_nanos() / self.tick_interval.as_nanos()) as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EpochController {
|
||||
fn default() -> Self {
|
||||
Self::default_interval()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for epoch-based execution limits
#[derive(Debug, Clone, Copy)]
pub struct EpochConfig {
    /// Enable epoch interruption
    pub enabled: bool,

    /// Tick interval in milliseconds
    pub tick_interval_ms: u64,

    /// Default budget in ticks
    pub default_budget: u64,

    /// Maximum allowed budget (prevents abuse)
    pub max_budget: u64,
}

impl EpochConfig {
    /// Create a new epoch configuration
    ///
    /// `max_budget` is derived as 10x the default; the multiplication
    /// saturates so very large defaults (e.g. `u64::MAX`, as used by
    /// [`EpochConfig::disabled`]) no longer overflow.
    pub fn new(tick_interval_ms: u64, default_budget: u64) -> Self {
        EpochConfig {
            enabled: true,
            tick_interval_ms,
            default_budget,
            max_budget: default_budget.saturating_mul(10), // 10x default as max
        }
    }

    /// Create configuration for server workloads (longer budgets)
    pub fn server() -> Self {
        EpochConfig {
            enabled: true,
            tick_interval_ms: 10,
            default_budget: 1000, // 10 seconds
            max_budget: 6000,     // 60 seconds max
        }
    }

    /// Create configuration for embedded/constrained workloads
    pub fn embedded() -> Self {
        EpochConfig {
            enabled: true,
            tick_interval_ms: 1,
            default_budget: 100, // 100ms
            max_budget: 1000,    // 1 second max
        }
    }

    /// Create configuration with interruption disabled (for benchmarking)
    ///
    /// # Warning
    /// Only use this for controlled benchmarking scenarios.
    pub fn disabled() -> Self {
        EpochConfig {
            enabled: false,
            tick_interval_ms: 10,
            default_budget: u64::MAX,
            max_budget: u64::MAX,
        }
    }

    /// Get tick interval as Duration
    pub fn tick_interval(&self) -> Duration {
        Duration::from_millis(self.tick_interval_ms)
    }

    /// Clamp a requested budget to the allowed maximum
    pub fn clamp_budget(&self, requested: u64) -> u64 {
        requested.min(self.max_budget)
    }

    /// Convert budget ticks to approximate duration
    ///
    /// The multiplication saturates: with the old plain `*`, calling this
    /// on `EpochConfig::disabled()` (budget `u64::MAX`) panicked in debug
    /// builds and wrapped in release builds.
    pub fn budget_duration(&self, budget: u64) -> Duration {
        Duration::from_millis(budget.saturating_mul(self.tick_interval_ms))
    }
}

impl Default for EpochConfig {
    fn default() -> Self {
        Self::server()
    }
}
|
||||
|
||||
/// Epoch deadline tracker for a single kernel invocation
#[derive(Debug, Clone, Copy)]
pub struct EpochDeadline {
    /// The epoch value at which execution should stop
    pub deadline: u64,
    /// The budget that was allocated
    pub budget: u64,
    /// When the execution started (epoch value)
    pub start_epoch: u64,
}

impl EpochDeadline {
    /// Create a new deadline
    ///
    /// The deadline saturates at `u64::MAX`, so an effectively unlimited
    /// budget (`u64::MAX`, as produced by `EpochConfig::disabled()`) no
    /// longer overflows (the old `start_epoch + budget` panicked in debug
    /// builds and wrapped in release builds).
    pub fn new(start_epoch: u64, budget: u64) -> Self {
        EpochDeadline {
            deadline: start_epoch.saturating_add(budget),
            budget,
            start_epoch,
        }
    }

    /// Calculate elapsed ticks since `start_epoch` (0 if `current` is earlier)
    pub fn elapsed(&self, current: u64) -> u64 {
        current.saturating_sub(self.start_epoch)
    }

    /// Calculate remaining ticks until the deadline (0 once exceeded)
    pub fn remaining(&self, current: u64) -> u64 {
        self.deadline.saturating_sub(current)
    }

    /// Check if deadline is exceeded
    pub fn is_exceeded(&self, current: u64) -> bool {
        current >= self.deadline
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_epoch_controller() {
        let ctrl = EpochController::default_interval();
        assert_eq!(ctrl.current(), 0);

        // Two manual ticks, then a reset back to zero.
        for expected in 1..=2u64 {
            ctrl.increment();
            assert_eq!(ctrl.current(), expected);
        }

        ctrl.reset();
        assert_eq!(ctrl.current(), 0);
    }

    #[test]
    fn test_deadline_calculation() {
        let ctrl = EpochController::default_interval();

        let deadline = ctrl.deadline_for_budget(100);
        assert_eq!(deadline, 100);
        assert!(!ctrl.is_deadline_exceeded(deadline));

        // Simulate time passing.
        (0..100).for_each(|_| ctrl.increment());

        assert!(ctrl.is_deadline_exceeded(deadline));
    }

    #[test]
    fn test_duration_conversion() {
        let config = EpochConfig::new(10, 1000);
        assert_eq!(config.budget_duration(100), Duration::from_secs(1));

        let ctrl = EpochController::new(Duration::from_millis(10));
        assert_eq!(ctrl.ticks_to_duration(100), Duration::from_secs(1));
        assert_eq!(ctrl.duration_to_ticks(Duration::from_secs(1)), 100);
    }

    #[test]
    fn test_epoch_config_clamp() {
        let config = EpochConfig::new(10, 1000);
        assert_eq!(config.max_budget, 10000);

        // Under the cap is untouched; over the cap is clamped down.
        assert_eq!(config.clamp_budget(500), 500);
        assert_eq!(config.clamp_budget(20000), 10000);
    }

    #[test]
    fn test_epoch_deadline() {
        let deadline = EpochDeadline::new(10, 100);

        assert_eq!(deadline.deadline, 110);
        assert_eq!(deadline.elapsed(50), 40);
        assert_eq!(deadline.remaining(50), 60);
        assert!(!deadline.is_exceeded(50));
        assert!(deadline.is_exceeded(110));
        assert!(deadline.is_exceeded(200));
    }

    #[test]
    fn test_server_config() {
        let config = EpochConfig::server();
        assert!(config.enabled);
        assert_eq!(config.tick_interval_ms, 10);
        assert_eq!(config.default_budget, 1000);
    }

    #[test]
    fn test_embedded_config() {
        let config = EpochConfig::embedded();
        assert!(config.enabled);
        assert_eq!(config.tick_interval_ms, 1);
        assert_eq!(config.default_budget, 100);
    }

    #[test]
    fn test_disabled_config() {
        assert!(!EpochConfig::disabled().enabled);
    }
}
|
||||
368
vendor/ruvector/crates/ruvector-wasm/src/kernel/error.rs
vendored
Normal file
368
vendor/ruvector/crates/ruvector-wasm/src/kernel/error.rs
vendored
Normal file
@@ -0,0 +1,368 @@
|
||||
//! Error types for the kernel pack system
|
||||
//!
|
||||
//! Provides comprehensive error handling for kernel verification,
|
||||
//! loading, and execution.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Errors that can occur during kernel execution
#[derive(Debug, Clone)]
pub enum KernelError {
    /// Execution budget exceeded (epoch deadline reached)
    EpochDeadline,

    /// Out of bounds memory access
    MemoryAccessViolation {
        /// Attempted access offset
        offset: u32,
        /// Attempted access size
        size: u32,
    },

    /// Integer overflow/underflow during computation
    IntegerOverflow,

    /// Unreachable code was executed
    Unreachable,

    /// Stack overflow in WASM execution
    StackOverflow,

    /// Indirect call type mismatch
    IndirectCallTypeMismatch,

    /// Custom trap from kernel with error code
    KernelTrap {
        /// Error code returned by kernel
        code: u32,
        /// Optional error message
        message: Option<String>,
    },

    /// Kernel not found
    KernelNotFound {
        /// Requested kernel ID
        kernel_id: String,
    },

    /// Invalid kernel parameters
    InvalidParameters {
        /// Description of the parameter error
        description: String,
    },

    /// Tensor shape mismatch
    ShapeMismatch {
        /// Expected shape description
        expected: String,
        /// Actual shape description
        actual: String,
    },

    /// Data type mismatch
    DTypeMismatch {
        /// Expected data type
        expected: String,
        /// Actual data type
        actual: String,
    },

    /// Memory allocation failed
    AllocationFailed {
        /// Requested size in bytes
        requested_bytes: usize,
    },

    /// Kernel initialization failed
    InitializationFailed {
        /// Reason for failure
        reason: String,
    },

    /// Runtime error
    RuntimeError {
        /// Error message
        message: String,
    },

    /// Feature not supported
    UnsupportedFeature {
        /// Feature name
        feature: String,
    },
}

impl fmt::Display for KernelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Local alias keeps the arms compact; rendered text is unchanged.
        use KernelError as E;
        match self {
            E::EpochDeadline => {
                f.write_str("Kernel execution exceeded time budget (epoch deadline)")
            }
            E::MemoryAccessViolation { offset, size } => {
                write!(f, "Memory access violation: offset={}, size={}", offset, size)
            }
            E::IntegerOverflow => f.write_str("Integer overflow during computation"),
            E::Unreachable => f.write_str("Unreachable code executed"),
            E::StackOverflow => f.write_str("Stack overflow"),
            E::IndirectCallTypeMismatch => f.write_str("Indirect call type mismatch"),
            E::KernelTrap { code, message } => {
                write!(f, "Kernel trap (code={})", code)?;
                // The message suffix is optional.
                match message {
                    Some(msg) => write!(f, ": {}", msg),
                    None => Ok(()),
                }
            }
            E::KernelNotFound { kernel_id } => write!(f, "Kernel not found: {}", kernel_id),
            E::InvalidParameters { description } => {
                write!(f, "Invalid parameters: {}", description)
            }
            E::ShapeMismatch { expected, actual } => {
                write!(f, "Shape mismatch: expected {}, got {}", expected, actual)
            }
            E::DTypeMismatch { expected, actual } => {
                write!(f, "DType mismatch: expected {}, got {}", expected, actual)
            }
            E::AllocationFailed { requested_bytes } => {
                write!(f, "Memory allocation failed: {} bytes", requested_bytes)
            }
            E::InitializationFailed { reason } => {
                write!(f, "Kernel initialization failed: {}", reason)
            }
            E::RuntimeError { message } => write!(f, "Runtime error: {}", message),
            E::UnsupportedFeature { feature } => write!(f, "Unsupported feature: {}", feature),
        }
    }
}

impl std::error::Error for KernelError {}
|
||||
|
||||
/// Errors that can occur during kernel verification
#[derive(Debug, Clone)]
pub enum VerifyError {
    /// No trusted signing key matched
    NoTrustedKey,

    /// Signature is invalid
    InvalidSignature {
        /// Description of the signature error
        reason: String,
    },

    /// Hash mismatch
    HashMismatch {
        /// Expected hash
        expected: String,
        /// Actual computed hash
        actual: String,
    },

    /// Manifest parsing failed
    InvalidManifest {
        /// Error message
        message: String,
    },

    /// Version incompatibility
    VersionIncompatible {
        /// Required version range
        required: String,
        /// Actual version
        actual: String,
    },

    /// Runtime too old for kernel pack
    RuntimeTooOld {
        /// Minimum required version
        required: String,
        /// Actual runtime version
        actual: String,
    },

    /// Runtime too new for kernel pack
    RuntimeTooNew {
        /// Maximum supported version
        max_supported: String,
        /// Actual runtime version
        actual: String,
    },

    /// Missing required WASM feature
    MissingFeature {
        /// Kernel that requires the feature
        kernel: String,
        /// Missing feature name
        feature: String,
    },

    /// Kernel not in allowlist
    NotInAllowlist {
        /// Kernel ID
        kernel_id: String,
    },

    /// File I/O error
    IoError {
        /// Error message
        message: String,
    },

    /// Key parsing error
    KeyError {
        /// Error message
        message: String,
    },
}

impl fmt::Display for VerifyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Local alias keeps the arms compact; rendered text is unchanged.
        use VerifyError as E;
        match self {
            E::NoTrustedKey => {
                f.write_str("No trusted signing key matched the manifest signature")
            }
            E::InvalidSignature { reason } => write!(f, "Invalid signature: {}", reason),
            E::HashMismatch { expected, actual } => {
                write!(f, "Hash mismatch: expected {}, got {}", expected, actual)
            }
            E::InvalidManifest { message } => write!(f, "Invalid manifest: {}", message),
            E::VersionIncompatible { required, actual } => {
                write!(f, "Version incompatible: required {}, got {}", required, actual)
            }
            E::RuntimeTooOld { required, actual } => {
                write!(f, "Runtime too old: requires {}, have {}", required, actual)
            }
            E::RuntimeTooNew { max_supported, actual } => {
                write!(f, "Runtime too new: max supported {}, have {}", max_supported, actual)
            }
            E::MissingFeature { kernel, feature } => {
                write!(f, "Kernel '{}' requires missing feature: {}", kernel, feature)
            }
            E::NotInAllowlist { kernel_id } => {
                write!(f, "Kernel '{}' not in allowlist", kernel_id)
            }
            E::IoError { message } => write!(f, "I/O error: {}", message),
            E::KeyError { message } => write!(f, "Key error: {}", message),
        }
    }
}

impl std::error::Error for VerifyError {}
|
||||
|
||||
/// Result type alias for kernel operations
///
/// Convenience alias: `Err` carries a [`KernelError`].
pub type KernelResult<T> = Result<T, KernelError>;

/// Result type alias for verification operations
///
/// Convenience alias: `Err` carries a [`VerifyError`].
pub type VerifyResult<T> = Result<T, VerifyError>;
|
||||
|
||||
/// Standard kernel error codes (returned by kernel_forward/kernel_backward)
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelErrorCode {
    /// Success
    Ok = 0,
    /// Invalid input tensor
    InvalidInput = 1,
    /// Invalid output tensor
    InvalidOutput = 2,
    /// Invalid kernel parameters
    InvalidParams = 3,
    /// Out of memory
    OutOfMemory = 4,
    /// Operation not implemented
    NotImplemented = 5,
    /// Internal kernel error
    InternalError = 6,
}

impl From<u32> for KernelErrorCode {
    /// Map a raw kernel return code onto the enum; any unrecognized code
    /// collapses to `InternalError`.
    fn from(code: u32) -> Self {
        use KernelErrorCode as C;
        match code {
            0 => C::Ok,
            1 => C::InvalidInput,
            2 => C::InvalidOutput,
            3 => C::InvalidParams,
            4 => C::OutOfMemory,
            5 => C::NotImplemented,
            _ => C::InternalError,
        }
    }
}

impl fmt::Display for KernelErrorCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Resolve the human-readable text first, then emit it once.
        let text = match self {
            KernelErrorCode::Ok => "OK",
            KernelErrorCode::InvalidInput => "Invalid input tensor",
            KernelErrorCode::InvalidOutput => "Invalid output tensor",
            KernelErrorCode::InvalidParams => "Invalid parameters",
            KernelErrorCode::OutOfMemory => "Out of memory",
            KernelErrorCode::NotImplemented => "Not implemented",
            KernelErrorCode::InternalError => "Internal error",
        };
        f.write_str(text)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_error_display() {
        assert!(KernelError::EpochDeadline
            .to_string()
            .contains("epoch deadline"));

        // Offset and size must both appear in the rendered message.
        let rendered = KernelError::MemoryAccessViolation { offset: 100, size: 64 }.to_string();
        assert!(rendered.contains("100"));
        assert!(rendered.contains("64"));
    }

    #[test]
    fn test_verify_error_display() {
        let rendered = VerifyError::HashMismatch {
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        }
        .to_string();

        // Both hashes must appear in the rendered message.
        assert!(rendered.contains("abc123"));
        assert!(rendered.contains("def456"));
    }

    #[test]
    fn test_error_code_conversion() {
        assert_eq!(KernelErrorCode::from(0), KernelErrorCode::Ok);
        assert_eq!(KernelErrorCode::from(1), KernelErrorCode::InvalidInput);
        // Unknown codes collapse to InternalError.
        assert_eq!(KernelErrorCode::from(100), KernelErrorCode::InternalError);
    }
}
|
||||
176
vendor/ruvector/crates/ruvector-wasm/src/kernel/hash.rs
vendored
Normal file
176
vendor/ruvector/crates/ruvector-wasm/src/kernel/hash.rs
vendored
Normal file
@@ -0,0 +1,176 @@
|
||||
//! SHA256 Hash Verification
|
||||
//!
|
||||
//! Provides hash verification for WASM kernel files to ensure integrity.
|
||||
|
||||
use crate::kernel::error::VerifyError;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Hash verifier for kernel files
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HashVerifier {
|
||||
/// Expected hash format prefix (e.g., "sha256:")
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
impl HashVerifier {
|
||||
/// Create a new SHA256 hash verifier
|
||||
pub fn sha256() -> Self {
|
||||
HashVerifier {
|
||||
prefix: "sha256:".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute SHA256 hash of data
|
||||
pub fn compute_hash(data: &[u8]) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{:x}", result)
|
||||
}
|
||||
|
||||
/// Verify kernel data against expected hash
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `kernel_bytes` - The raw WASM kernel bytes
|
||||
/// * `expected_hash` - Expected hash string (format: "sha256:...")
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(())` if hash matches
|
||||
/// * `Err(VerifyError::HashMismatch)` if hash doesn't match
|
||||
pub fn verify(&self, kernel_bytes: &[u8], expected_hash: &str) -> Result<(), VerifyError> {
|
||||
// Validate expected hash format
|
||||
if !expected_hash.starts_with(&self.prefix) {
|
||||
return Err(VerifyError::InvalidManifest {
|
||||
message: format!(
|
||||
"Invalid hash format: expected '{}' prefix, got '{}'",
|
||||
self.prefix,
|
||||
expected_hash.get(..10).unwrap_or(expected_hash)
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let actual_hash = Self::compute_hash(kernel_bytes);
|
||||
|
||||
if actual_hash.eq_ignore_ascii_case(expected_hash) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(VerifyError::HashMismatch {
|
||||
expected: expected_hash.to_string(),
|
||||
actual: actual_hash,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify multiple kernels in batch
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `kernels` - Iterator of (kernel_bytes, expected_hash) pairs
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(())` if all hashes match
|
||||
/// * `Err` with first mismatch
|
||||
pub fn verify_batch<'a>(
|
||||
&self,
|
||||
kernels: impl Iterator<Item = (&'a [u8], &'a str)>,
|
||||
) -> Result<(), VerifyError> {
|
||||
for (bytes, expected) in kernels {
|
||||
self.verify(bytes, expected)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HashVerifier {
|
||||
fn default() -> Self {
|
||||
Self::sha256()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hash for a kernel file and return formatted string
|
||||
pub fn hash_kernel(kernel_bytes: &[u8]) -> String {
|
||||
HashVerifier::compute_hash(kernel_bytes)
|
||||
}
|
||||
|
||||
/// Verify a kernel file against expected hash (convenience function)
|
||||
pub fn verify_kernel_hash(kernel_bytes: &[u8], expected_hash: &str) -> Result<(), VerifyError> {
|
||||
HashVerifier::sha256().verify(kernel_bytes, expected_hash)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_hash() {
        let hash = HashVerifier::compute_hash(b"hello world");
        assert!(hash.starts_with("sha256:"));
        // Known SHA256 of "hello world"
        assert!(hash.contains("b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"));
    }

    #[test]
    fn test_verify_success() {
        let data: &[u8] = b"test kernel data";
        let hash = HashVerifier::compute_hash(data);
        assert!(HashVerifier::sha256().verify(data, &hash).is_ok());
    }

    #[test]
    fn test_verify_case_insensitive() {
        // Uppercasing the expected hash must still verify.
        let data: &[u8] = b"test kernel data";
        let upper = HashVerifier::compute_hash(data).to_uppercase();
        assert!(HashVerifier::sha256().verify(data, &upper).is_ok());
    }

    #[test]
    fn test_verify_mismatch() {
        let wrong = "sha256:0000000000000000000000000000000000000000000000000000000000000000";
        let result = HashVerifier::sha256().verify(b"actual data", wrong);
        assert!(matches!(result, Err(VerifyError::HashMismatch { .. })));
    }

    #[test]
    fn test_verify_invalid_format() {
        // A non-sha256 prefix is rejected before any hashing happens.
        let result = HashVerifier::sha256().verify(b"test data", "md5:abc123");
        assert!(matches!(result, Err(VerifyError::InvalidManifest { .. })));
    }

    #[test]
    fn test_verify_batch() {
        let (d1, d2): (&[u8], &[u8]) = (b"kernel1", b"kernel2");
        let h1 = HashVerifier::compute_hash(d1);
        let h2 = HashVerifier::compute_hash(d2);

        let pairs = [(d1, h1.as_str()), (d2, h2.as_str())];
        assert!(HashVerifier::sha256().verify_batch(pairs.into_iter()).is_ok());
    }

    #[test]
    fn test_convenience_function() {
        let hash = hash_kernel(b"convenience test");
        assert!(verify_kernel_hash(b"convenience test", &hash).is_ok());
    }
}
|
||||
500
vendor/ruvector/crates/ruvector-wasm/src/kernel/manifest.rs
vendored
Normal file
500
vendor/ruvector/crates/ruvector-wasm/src/kernel/manifest.rs
vendored
Normal file
@@ -0,0 +1,500 @@
|
||||
//! Kernel Pack Manifest (kernels.json)
|
||||
//!
|
||||
//! Defines the manifest schema for kernel packs, including kernel metadata,
|
||||
//! resource limits, platform requirements, and versioning.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Kernel pack manifest (kernels.json)
///
/// Top-level descriptor for a kernel pack: versioning bounds, author
/// identity (including the signing key used during verification), and
/// the list of kernels the pack provides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelManifest {
    /// JSON schema URL (absent in the JSON => empty string, via `default`)
    #[serde(rename = "$schema", default)]
    pub schema: String,

    /// Manifest version (semver)
    pub version: String,

    /// Pack name
    pub name: String,

    /// Pack description
    pub description: String,

    /// Minimum runtime version required
    pub min_runtime_version: String,

    /// Maximum runtime version supported
    pub max_runtime_version: String,

    /// Creation timestamp (ISO 8601)
    pub created_at: String,

    /// Author information
    pub author: AuthorInfo,

    /// List of kernels in the pack
    pub kernels: Vec<KernelInfo>,

    /// Fallback mappings (kernel_id -> fallback_kernel_id);
    /// absent in the JSON => empty map
    #[serde(default)]
    pub fallbacks: HashMap<String, String>,
}

/// Author information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorInfo {
    /// Author name
    pub name: String,

    /// Contact email
    pub email: String,

    /// Ed25519 public signing key (base64 or hex encoded)
    pub signing_key: String,
}
|
||||
|
||||
/// Individual kernel information
///
/// One entry per WASM kernel shipped in the pack: ties the kernel id to
/// its on-disk artifact, integrity hash, I/O contract, and limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelInfo {
    /// Unique kernel identifier
    pub id: String,

    /// Human-readable name
    pub name: String,

    /// Kernel category
    pub category: KernelCategory,

    /// Path to WASM file relative to pack root
    pub path: String,

    /// SHA256 hash of the WASM file (format: "sha256:...")
    pub hash: String,

    /// Entry point function name
    pub entry_point: String,

    /// Input tensor specifications
    pub inputs: Vec<TensorSpec>,

    /// Output tensor specifications
    pub outputs: Vec<TensorSpec>,

    /// Kernel-specific parameters (absent in the JSON => empty map)
    #[serde(default)]
    pub params: HashMap<String, KernelParam>,

    /// Resource limits
    pub resource_limits: ResourceLimits,

    /// Platform-specific configurations (absent in the JSON => empty map)
    #[serde(default)]
    pub platforms: HashMap<String, PlatformConfig>,

    /// Benchmark results (absent in the JSON => empty map)
    #[serde(default)]
    pub benchmarks: HashMap<String, BenchmarkResult>,
}
|
||||
|
||||
/// Kernel categories
///
/// Serialized in snake_case (e.g. "positional_encoding", "kv_cache").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KernelCategory {
    /// Positional encoding (RoPE, etc.)
    PositionalEncoding,
    /// Normalization (RMSNorm, LayerNorm, etc.)
    Normalization,
    /// Activation functions (SwiGLU, GELU, etc.)
    Activation,
    /// KV cache operations (quantize, dequantize)
    KvCache,
    /// Adapter operations (LoRA, etc.)
    Adapter,
    /// Attention mechanisms
    Attention,
    /// Custom/other operations
    Custom,
}

impl Default for KernelCategory {
    // Unknown/unspecified kernels fall back to the catch-all category.
    fn default() -> Self {
        KernelCategory::Custom
    }
}

/// Tensor specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name
    pub name: String,

    /// Data type
    pub dtype: DataType,

    /// Shape specification (symbolic dimensions like "batch", "seq", numeric for fixed)
    pub shape: Vec<ShapeDim>,
}

/// Shape dimension (can be symbolic or numeric)
///
/// `#[serde(untagged)]`: a JSON string deserializes to `Symbolic`,
/// a JSON number to `Fixed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ShapeDim {
    /// Symbolic dimension (e.g., "batch", "seq", "heads")
    Symbolic(String),
    /// Fixed numeric dimension
    Fixed(usize),
}
|
||||
|
||||
/// Data types supported by kernels
///
/// Serialized in lowercase (e.g. "f32", "bf16", "q4").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DataType {
    /// 32-bit float
    F32,
    /// 16-bit float (half precision)
    F16,
    /// Brain float 16
    Bf16,
    /// 8-bit integer (signed)
    I8,
    /// 8-bit unsigned integer
    U8,
    /// 32-bit integer
    I32,
    /// Quantized 4-bit
    Q4,
    /// Quantized 8-bit
    Q8,
}

impl DataType {
    /// Get size in bytes for this data type
    ///
    /// Note: for `Q4` this is the storage granularity rather than the
    /// logical element size — values are packed two per byte.
    pub fn size_bytes(&self) -> usize {
        match self {
            DataType::F32 | DataType::I32 => 4,
            DataType::F16 | DataType::Bf16 => 2,
            DataType::I8 | DataType::U8 | DataType::Q8 => 1,
            DataType::Q4 => 1, // Packed, 2 values per byte
        }
    }
}
|
||||
|
||||
/// Kernel parameter definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelParam {
    /// Parameter data type (serialized under the JSON key "type")
    #[serde(rename = "type")]
    pub param_type: ParamType,

    /// Default value (kept as raw JSON; interpretation depends on `param_type`)
    pub default: serde_json::Value,

    /// Optional minimum value
    #[serde(default)]
    pub min: Option<serde_json::Value>,

    /// Optional maximum value
    #[serde(default)]
    pub max: Option<serde_json::Value>,

    /// Optional description
    #[serde(default)]
    pub description: Option<String>,
}

/// Parameter types
///
/// Serialized in lowercase (e.g. "f32", "bool").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ParamType {
    F32,
    F64,
    I32,
    I64,
    U32,
    U64,
    Bool,
}
|
||||
|
||||
/// Resource limits for kernel execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceLimits {
    /// Maximum WASM memory pages (64KB each)
    pub max_memory_pages: u32,

    /// Maximum epoch ticks before interruption
    pub max_epoch_ticks: u64,

    /// Maximum table elements
    pub max_table_elements: u32,

    /// Optional: Maximum stack size in bytes
    #[serde(default)]
    pub max_stack_size: Option<usize>,

    /// Optional: Maximum globals
    #[serde(default)]
    pub max_globals: Option<u32>,
}

impl Default for ResourceLimits {
    fn default() -> Self {
        ResourceLimits {
            max_memory_pages: 256,    // 16MB (256 pages * 64KB/page)
            max_epoch_ticks: 1000,    // ~10 seconds at 10ms/tick
            max_table_elements: 1024, // Function pointers
            max_stack_size: None,
            max_globals: None,
        }
    }
}
|
||||
|
||||
/// Platform-specific configuration.
///
/// One entry per execution platform (e.g. "wasmtime") in a kernel's
/// `platforms` map in the manifest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformConfig {
    /// Minimum version of the runtime
    pub min_version: String,

    /// Required WASM features (e.g. "simd", "bulk-memory")
    #[serde(default)]
    pub features: Vec<String>,

    /// Whether AOT compilation is available
    #[serde(default)]
    pub aot_available: bool,
}

/// Benchmark result recorded in the manifest for a named workload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Latency in microseconds
    pub latency_us: u64,

    /// Throughput in GFLOPS
    pub throughput_gflops: f64,
}
|
||||
|
||||
/// Kernel invocation descriptor passed to WASM
///
/// This is the C-compatible struct passed to kernels to describe
/// memory layout and tensor locations. All ten fields are `u32`, so
/// the serialized wire format (`to_bytes`) is exactly 40 bytes of
/// little-endian data in declaration order.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct KernelDescriptor {
    /// Input tensor A offset in linear memory
    pub input_a_offset: u32,
    /// Input tensor A size in bytes
    pub input_a_size: u32,
    /// Input tensor B offset (0 if unused)
    pub input_b_offset: u32,
    /// Input tensor B size in bytes (0 if unused)
    pub input_b_size: u32,
    /// Output tensor offset
    pub output_offset: u32,
    /// Output tensor size in bytes
    pub output_size: u32,
    /// Scratch space offset
    pub scratch_offset: u32,
    /// Scratch space size in bytes
    pub scratch_size: u32,
    /// Kernel-specific parameters offset
    pub params_offset: u32,
    /// Kernel-specific parameters size
    pub params_size: u32,
}
|
||||
|
||||
impl KernelDescriptor {
|
||||
/// Create a new kernel descriptor
|
||||
pub fn new() -> Self {
|
||||
KernelDescriptor {
|
||||
input_a_offset: 0,
|
||||
input_a_size: 0,
|
||||
input_b_offset: 0,
|
||||
input_b_size: 0,
|
||||
output_offset: 0,
|
||||
output_size: 0,
|
||||
scratch_offset: 0,
|
||||
scratch_size: 0,
|
||||
params_offset: 0,
|
||||
params_size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate total memory required
|
||||
pub fn total_memory_required(&self) -> usize {
|
||||
let max_end = [
|
||||
self.input_a_offset + self.input_a_size,
|
||||
self.input_b_offset + self.input_b_size,
|
||||
self.output_offset + self.output_size,
|
||||
self.scratch_offset + self.scratch_size,
|
||||
self.params_offset + self.params_size,
|
||||
]
|
||||
.into_iter()
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
max_end as usize
|
||||
}
|
||||
|
||||
/// Serialize to bytes for passing to WASM
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
let mut bytes = Vec::with_capacity(40);
|
||||
bytes.extend_from_slice(&self.input_a_offset.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.input_a_size.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.input_b_offset.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.input_b_size.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.output_offset.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.output_size.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.scratch_offset.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.scratch_size.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.params_offset.to_le_bytes());
|
||||
bytes.extend_from_slice(&self.params_size.to_le_bytes());
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KernelDescriptor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelManifest {
|
||||
/// Parse manifest from JSON string
|
||||
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(json)
|
||||
}
|
||||
|
||||
/// Serialize manifest to JSON string
|
||||
pub fn to_json(&self) -> Result<String, serde_json::Error> {
|
||||
serde_json::to_string_pretty(self)
|
||||
}
|
||||
|
||||
/// Get kernel by ID
|
||||
pub fn get_kernel(&self, id: &str) -> Option<&KernelInfo> {
|
||||
self.kernels.iter().find(|k| k.id == id)
|
||||
}
|
||||
|
||||
/// Get fallback kernel for a given kernel ID
|
||||
pub fn get_fallback(&self, id: &str) -> Option<&str> {
|
||||
self.fallbacks.get(id).map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// List all kernel IDs
|
||||
pub fn kernel_ids(&self) -> Vec<&str> {
|
||||
self.kernels.iter().map(|k| k.id.as_str()).collect()
|
||||
}
|
||||
|
||||
/// List kernels by category
|
||||
pub fn kernels_by_category(&self, category: KernelCategory) -> Vec<&KernelInfo> {
|
||||
self.kernels
|
||||
.iter()
|
||||
.filter(|k| k.category == category)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal but complete manifest document exercising every
    /// top-level section: author, a single kernel (with inputs,
    /// outputs, params, limits, platforms, benchmarks), and fallbacks.
    fn sample_manifest_json() -> &'static str {
        r#"{
            "$schema": "https://ruvllm.dev/schemas/kernel-pack-v1.json",
            "version": "1.0.0",
            "name": "test-kernels",
            "description": "Test kernel pack",
            "min_runtime_version": "0.5.0",
            "max_runtime_version": "1.0.0",
            "created_at": "2026-01-18T00:00:00Z",
            "author": {
                "name": "Test Author",
                "email": "test@example.com",
                "signing_key": "ed25519:AAAA..."
            },
            "kernels": [
                {
                    "id": "rope_f32",
                    "name": "Rotary Position Embedding (FP32)",
                    "category": "positional_encoding",
                    "path": "rope/rope_f32.wasm",
                    "hash": "sha256:abc123",
                    "entry_point": "rope_forward",
                    "inputs": [
                        {"name": "x", "dtype": "f32", "shape": ["batch", "seq", "heads", "dim"]},
                        {"name": "freqs", "dtype": "f32", "shape": ["seq", 64]}
                    ],
                    "outputs": [
                        {"name": "y", "dtype": "f32", "shape": ["batch", "seq", "heads", "dim"]}
                    ],
                    "params": {
                        "theta": {"type": "f32", "default": 10000.0}
                    },
                    "resource_limits": {
                        "max_memory_pages": 256,
                        "max_epoch_ticks": 1000,
                        "max_table_elements": 1024
                    },
                    "platforms": {
                        "wasmtime": {
                            "min_version": "15.0.0",
                            "features": ["simd", "bulk-memory"]
                        }
                    },
                    "benchmarks": {
                        "seq_512_dim_128": {
                            "latency_us": 45,
                            "throughput_gflops": 2.1
                        }
                    }
                }
            ],
            "fallbacks": {
                "rope_f32": "rope_reference"
            }
        }"#
    }

    // Round trip: the sample document parses and top-level fields
    // land in the expected struct fields.
    #[test]
    fn test_manifest_parsing() {
        let manifest = KernelManifest::from_json(sample_manifest_json()).unwrap();
        assert_eq!(manifest.name, "test-kernels");
        assert_eq!(manifest.version, "1.0.0");
        assert_eq!(manifest.kernels.len(), 1);
    }

    // get_kernel finds by ID, and the category string deserializes
    // to its enum variant.
    #[test]
    fn test_kernel_lookup() {
        let manifest = KernelManifest::from_json(sample_manifest_json()).unwrap();
        let kernel = manifest.get_kernel("rope_f32").unwrap();
        assert_eq!(kernel.name, "Rotary Position Embedding (FP32)");
        assert_eq!(kernel.category, KernelCategory::PositionalEncoding);
    }

    // The fallback map resolves known IDs and yields None otherwise.
    #[test]
    fn test_fallback_lookup() {
        let manifest = KernelManifest::from_json(sample_manifest_json()).unwrap();
        assert_eq!(manifest.get_fallback("rope_f32"), Some("rope_reference"));
        assert_eq!(manifest.get_fallback("unknown"), None);
    }

    // total_memory_required is the max end offset across regions;
    // to_bytes is the fixed 40-byte wire format.
    #[test]
    fn test_kernel_descriptor() {
        let mut desc = KernelDescriptor::new();
        desc.input_a_offset = 0;
        desc.input_a_size = 1024;
        desc.output_offset = 1024;
        desc.output_size = 1024;

        assert_eq!(desc.total_memory_required(), 2048);
        assert_eq!(desc.to_bytes().len(), 40);
    }

    // Spot-check element widths for a few data types.
    #[test]
    fn test_data_type_sizes() {
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::F16.size_bytes(), 2);
        assert_eq!(DataType::I8.size_bytes(), 1);
    }
}
|
||||
466
vendor/ruvector/crates/ruvector-wasm/src/kernel/memory.rs
vendored
Normal file
466
vendor/ruvector/crates/ruvector-wasm/src/kernel/memory.rs
vendored
Normal file
@@ -0,0 +1,466 @@
|
||||
//! Shared Memory Protocol
|
||||
//!
|
||||
//! Defines the memory layout and protocol for passing tensor data
|
||||
//! between the host and WASM kernels.
|
||||
|
||||
use crate::kernel::error::KernelError;
|
||||
use crate::kernel::manifest::{DataType, KernelDescriptor};
|
||||
|
||||
/// WASM page size (64KB)
|
||||
pub const PAGE_SIZE: usize = 65536;
|
||||
|
||||
/// Shared memory protocol for kernel invocation
///
/// Manages the layout of tensors and parameters in WASM linear memory
/// using a simple aligned bump allocator (no deallocation; `reset`
/// rewinds everything).
#[derive(Debug, Clone)]
pub struct SharedMemoryProtocol {
    /// Total memory size in bytes (pages * 64KB)
    total_size: usize,
    /// Current allocation offset (next free byte, pre-alignment)
    current_offset: usize,
    /// Memory alignment (typically 8 or 16 bytes; `align_offset`
    /// relies on it being a power of two)
    alignment: usize,
}
|
||||
|
||||
impl SharedMemoryProtocol {
|
||||
/// Create a new memory protocol
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `total_pages` - Number of WASM pages to allocate
|
||||
/// * `alignment` - Memory alignment in bytes
|
||||
pub fn new(total_pages: usize, alignment: usize) -> Self {
|
||||
SharedMemoryProtocol {
|
||||
total_size: total_pages * PAGE_SIZE,
|
||||
current_offset: 0,
|
||||
alignment,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default settings (256 pages = 16MB, 16-byte alignment)
|
||||
pub fn default_settings() -> Self {
|
||||
Self::new(256, 16)
|
||||
}
|
||||
|
||||
/// Reset allocator to beginning
|
||||
pub fn reset(&mut self) {
|
||||
self.current_offset = 0;
|
||||
}
|
||||
|
||||
/// Align offset to boundary
|
||||
fn align_offset(&self, offset: usize) -> usize {
|
||||
(offset + self.alignment - 1) & !(self.alignment - 1)
|
||||
}
|
||||
|
||||
/// Allocate memory region
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `size` - Size in bytes
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(offset)` - Starting offset of allocated region
|
||||
/// * `Err` - If allocation would exceed total size
|
||||
pub fn allocate(&mut self, size: usize) -> Result<usize, KernelError> {
|
||||
let aligned_offset = self.align_offset(self.current_offset);
|
||||
let end_offset = aligned_offset + size;
|
||||
|
||||
if end_offset > self.total_size {
|
||||
return Err(KernelError::AllocationFailed {
|
||||
requested_bytes: size,
|
||||
});
|
||||
}
|
||||
|
||||
self.current_offset = end_offset;
|
||||
Ok(aligned_offset)
|
||||
}
|
||||
|
||||
/// Get total memory size
|
||||
pub fn total_size(&self) -> usize {
|
||||
self.total_size
|
||||
}
|
||||
|
||||
/// Get total pages
|
||||
pub fn total_pages(&self) -> usize {
|
||||
self.total_size / PAGE_SIZE
|
||||
}
|
||||
|
||||
/// Get current allocation offset
|
||||
pub fn current_offset(&self) -> usize {
|
||||
self.current_offset
|
||||
}
|
||||
|
||||
/// Get remaining available bytes
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.total_size.saturating_sub(self.current_offset)
|
||||
}
|
||||
|
||||
/// Check if a memory region is valid
|
||||
pub fn is_valid_region(&self, offset: usize, size: usize) -> bool {
|
||||
offset + size <= self.total_size
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SharedMemoryProtocol {
|
||||
fn default() -> Self {
|
||||
Self::default_settings()
|
||||
}
|
||||
}
|
||||
|
||||
/// Kernel invocation descriptor with memory layout
///
/// This is a higher-level wrapper around KernelDescriptor that helps
/// manage memory allocation and data transfer: each `allocate_*`
/// call bump-allocates from the protocol and records the region in
/// the low-level descriptor.
#[derive(Debug, Clone)]
pub struct KernelInvocationDescriptor {
    /// Low-level descriptor
    pub descriptor: KernelDescriptor,
    /// Memory protocol (bump allocator backing the regions above)
    protocol: SharedMemoryProtocol,
}
|
||||
|
||||
impl KernelInvocationDescriptor {
|
||||
/// Create a new invocation descriptor
|
||||
pub fn new(total_pages: usize) -> Self {
|
||||
KernelInvocationDescriptor {
|
||||
descriptor: KernelDescriptor::new(),
|
||||
protocol: SharedMemoryProtocol::new(total_pages, 16),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default memory size
|
||||
pub fn default_size() -> Self {
|
||||
Self::new(256)
|
||||
}
|
||||
|
||||
/// Allocate space for input tensor A
|
||||
pub fn allocate_input_a(&mut self, size: usize) -> Result<u32, KernelError> {
|
||||
let offset = self.protocol.allocate(size)?;
|
||||
self.descriptor.input_a_offset = offset as u32;
|
||||
self.descriptor.input_a_size = size as u32;
|
||||
Ok(offset as u32)
|
||||
}
|
||||
|
||||
/// Allocate space for input tensor B
|
||||
pub fn allocate_input_b(&mut self, size: usize) -> Result<u32, KernelError> {
|
||||
let offset = self.protocol.allocate(size)?;
|
||||
self.descriptor.input_b_offset = offset as u32;
|
||||
self.descriptor.input_b_size = size as u32;
|
||||
Ok(offset as u32)
|
||||
}
|
||||
|
||||
/// Allocate space for output tensor
|
||||
pub fn allocate_output(&mut self, size: usize) -> Result<u32, KernelError> {
|
||||
let offset = self.protocol.allocate(size)?;
|
||||
self.descriptor.output_offset = offset as u32;
|
||||
self.descriptor.output_size = size as u32;
|
||||
Ok(offset as u32)
|
||||
}
|
||||
|
||||
/// Allocate scratch space
|
||||
pub fn allocate_scratch(&mut self, size: usize) -> Result<u32, KernelError> {
|
||||
let offset = self.protocol.allocate(size)?;
|
||||
self.descriptor.scratch_offset = offset as u32;
|
||||
self.descriptor.scratch_size = size as u32;
|
||||
Ok(offset as u32)
|
||||
}
|
||||
|
||||
/// Allocate space for parameters
|
||||
pub fn allocate_params(&mut self, size: usize) -> Result<u32, KernelError> {
|
||||
let offset = self.protocol.allocate(size)?;
|
||||
self.descriptor.params_offset = offset as u32;
|
||||
self.descriptor.params_size = size as u32;
|
||||
Ok(offset as u32)
|
||||
}
|
||||
|
||||
/// Get the low-level descriptor
|
||||
pub fn as_descriptor(&self) -> &KernelDescriptor {
|
||||
&self.descriptor
|
||||
}
|
||||
|
||||
/// Get total allocated memory
|
||||
pub fn total_allocated(&self) -> usize {
|
||||
self.protocol.current_offset()
|
||||
}
|
||||
|
||||
/// Get remaining memory
|
||||
pub fn remaining_memory(&self) -> usize {
|
||||
self.protocol.remaining()
|
||||
}
|
||||
|
||||
/// Required pages for current allocation
|
||||
pub fn required_pages(&self) -> usize {
|
||||
(self.total_allocated() + PAGE_SIZE - 1) / PAGE_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KernelInvocationDescriptor {
|
||||
fn default() -> Self {
|
||||
Self::default_size()
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory region specification
///
/// A half-open byte range `[offset, offset + size)` in WASM linear
/// memory, tagged with its intended access mode.
#[derive(Debug, Clone, Copy)]
pub struct MemoryRegion {
    /// Start offset in linear memory
    pub offset: u32,
    /// Size in bytes
    pub size: u32,
    /// Whether region is read-only
    pub read_only: bool,
}
|
||||
|
||||
impl MemoryRegion {
|
||||
/// Create a new memory region
|
||||
pub fn new(offset: u32, size: u32, read_only: bool) -> Self {
|
||||
MemoryRegion {
|
||||
offset,
|
||||
size,
|
||||
read_only,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a read-only region
|
||||
pub fn read_only(offset: u32, size: u32) -> Self {
|
||||
Self::new(offset, size, true)
|
||||
}
|
||||
|
||||
/// Create a writable region
|
||||
pub fn writable(offset: u32, size: u32) -> Self {
|
||||
Self::new(offset, size, false)
|
||||
}
|
||||
|
||||
/// Get end offset (exclusive)
|
||||
pub fn end(&self) -> u32 {
|
||||
self.offset + self.size
|
||||
}
|
||||
|
||||
/// Check if regions overlap
|
||||
pub fn overlaps(&self, other: &MemoryRegion) -> bool {
|
||||
self.offset < other.end() && other.offset < self.end()
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate tensor size in bytes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shape` - Tensor shape (dimensions)
|
||||
/// * `dtype` - Data type
|
||||
///
|
||||
/// # Returns
|
||||
/// Size in bytes
|
||||
pub fn tensor_size_bytes(shape: &[usize], dtype: DataType) -> usize {
|
||||
let num_elements: usize = shape.iter().product();
|
||||
num_elements * dtype.size_bytes()
|
||||
}
|
||||
|
||||
/// Calculate required WASM pages for a given byte size
|
||||
pub fn required_pages(size_bytes: usize) -> usize {
|
||||
(size_bytes + PAGE_SIZE - 1) / PAGE_SIZE
|
||||
}
|
||||
|
||||
/// Memory layout validator
///
/// Accumulates registered regions, rejecting additions that overlap a
/// previously registered one; also performs one-shot validation of a
/// whole `KernelDescriptor` layout.
#[derive(Debug, Default)]
pub struct MemoryLayoutValidator {
    /// Registered regions (each new region is checked against all of
    /// these on insert)
    regions: Vec<MemoryRegion>,
}
|
||||
|
||||
impl MemoryLayoutValidator {
|
||||
/// Create a new validator
|
||||
pub fn new() -> Self {
|
||||
MemoryLayoutValidator {
|
||||
regions: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a region to validate
|
||||
pub fn add_region(&mut self, region: MemoryRegion) -> Result<(), KernelError> {
|
||||
// Check for overlaps with existing regions
|
||||
for existing in &self.regions {
|
||||
if region.overlaps(existing) {
|
||||
return Err(KernelError::InvalidParameters {
|
||||
description: format!(
|
||||
"Memory region overlap: [{}, {}) overlaps [{}, {})",
|
||||
region.offset,
|
||||
region.end(),
|
||||
existing.offset,
|
||||
existing.end()
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.regions.push(region);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate a descriptor's memory layout
|
||||
pub fn validate_descriptor(
|
||||
&self,
|
||||
desc: &KernelDescriptor,
|
||||
total_memory: usize,
|
||||
) -> Result<(), KernelError> {
|
||||
// Check all regions are within bounds
|
||||
let regions = [
|
||||
("input_a", desc.input_a_offset, desc.input_a_size),
|
||||
("input_b", desc.input_b_offset, desc.input_b_size),
|
||||
("output", desc.output_offset, desc.output_size),
|
||||
("scratch", desc.scratch_offset, desc.scratch_size),
|
||||
("params", desc.params_offset, desc.params_size),
|
||||
];
|
||||
|
||||
for (name, offset, size) in regions {
|
||||
if size > 0 {
|
||||
let end = (offset as usize) + (size as usize);
|
||||
if end > total_memory {
|
||||
return Err(KernelError::MemoryAccessViolation { offset, size });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for overlaps between output and inputs
|
||||
let output = MemoryRegion::writable(desc.output_offset, desc.output_size);
|
||||
|
||||
if desc.input_a_size > 0 {
|
||||
let input_a = MemoryRegion::read_only(desc.input_a_offset, desc.input_a_size);
|
||||
if output.overlaps(&input_a) {
|
||||
return Err(KernelError::InvalidParameters {
|
||||
description: "Output overlaps with input_a".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if desc.input_b_size > 0 {
|
||||
let input_b = MemoryRegion::read_only(desc.input_b_offset, desc.input_b_size);
|
||||
if output.overlaps(&input_b) {
|
||||
return Err(KernelError::InvalidParameters {
|
||||
description: "Output overlaps with input_b".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clear all regions
|
||||
pub fn clear(&mut self) {
|
||||
self.regions.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Bump allocation starts at offset 0, later allocations respect
    // alignment, and remaining() shrinks accordingly.
    #[test]
    fn test_memory_protocol() {
        let mut protocol = SharedMemoryProtocol::new(1, 16); // 1 page = 64KB

        let offset1 = protocol.allocate(1024).unwrap();
        assert_eq!(offset1, 0);

        let offset2 = protocol.allocate(2048).unwrap();
        assert!(offset2 >= 1024);
        assert_eq!(offset2 % 16, 0); // Aligned

        assert!(protocol.remaining() < PAGE_SIZE);
    }

    // Over-allocation reports AllocationFailed rather than panicking.
    #[test]
    fn test_allocation_failure() {
        let mut protocol = SharedMemoryProtocol::new(1, 16);

        // Try to allocate more than available
        let result = protocol.allocate(PAGE_SIZE + 1);
        assert!(matches!(result, Err(KernelError::AllocationFailed { .. })));
    }

    // Region allocators fill the descriptor fields and the running
    // total includes alignment padding.
    #[test]
    fn test_invocation_descriptor() {
        let mut desc = KernelInvocationDescriptor::new(4); // 4 pages

        desc.allocate_input_a(1024).unwrap();
        desc.allocate_input_b(1024).unwrap();
        desc.allocate_output(1024).unwrap();
        desc.allocate_scratch(512).unwrap();
        desc.allocate_params(64).unwrap();

        assert!(desc.total_allocated() > 3600); // With alignment
        assert_eq!(desc.descriptor.input_a_size, 1024);
    }

    // Tensor size = product of dimensions * element width.
    #[test]
    fn test_tensor_size() {
        let shape = [1, 512, 32, 128]; // batch, seq, heads, dim
        let size = tensor_size_bytes(&shape, DataType::F32);
        assert_eq!(size, 1 * 512 * 32 * 128 * 4); // 8MB
    }

    // Ceiling division to whole pages; zero bytes needs zero pages.
    #[test]
    fn test_required_pages() {
        assert_eq!(required_pages(0), 0);
        assert_eq!(required_pages(1), 1);
        assert_eq!(required_pages(PAGE_SIZE), 1);
        assert_eq!(required_pages(PAGE_SIZE + 1), 2);
    }

    // Half-open intervals: regions that merely touch do not overlap.
    #[test]
    fn test_memory_region_overlap() {
        let r1 = MemoryRegion::new(0, 100, false);
        let r2 = MemoryRegion::new(50, 100, false);
        let r3 = MemoryRegion::new(100, 100, false);

        assert!(r1.overlaps(&r2));
        assert!(!r1.overlaps(&r3));
    }

    // add_region rejects a region overlapping an earlier one.
    #[test]
    fn test_layout_validator() {
        let mut validator = MemoryLayoutValidator::new();

        // Add non-overlapping regions
        validator
            .add_region(MemoryRegion::new(0, 100, false))
            .unwrap();
        validator
            .add_region(MemoryRegion::new(100, 100, false))
            .unwrap();

        // Try to add overlapping region
        let result = validator.add_region(MemoryRegion::new(50, 100, false));
        assert!(result.is_err());
    }

    // validate_descriptor accepts disjoint input/output regions and
    // rejects an output that aliases an input.
    #[test]
    fn test_validate_descriptor() {
        let validator = MemoryLayoutValidator::new();
        let mut desc = KernelDescriptor::new();

        desc.input_a_offset = 0;
        desc.input_a_size = 1024;
        desc.output_offset = 1024;
        desc.output_size = 1024;

        // Should pass - no overlap
        assert!(validator.validate_descriptor(&desc, PAGE_SIZE).is_ok());

        // Should fail - output overlaps input
        desc.output_offset = 512;
        assert!(validator.validate_descriptor(&desc, PAGE_SIZE).is_err());
    }

    // A region ending past total memory is a MemoryAccessViolation.
    #[test]
    fn test_validate_bounds() {
        let validator = MemoryLayoutValidator::new();
        let mut desc = KernelDescriptor::new();

        desc.input_a_offset = 0;
        desc.input_a_size = PAGE_SIZE as u32 + 1; // Too big

        let result = validator.validate_descriptor(&desc, PAGE_SIZE);
        assert!(matches!(
            result,
            Err(KernelError::MemoryAccessViolation { .. })
        ));
    }
}
|
||||
71
vendor/ruvector/crates/ruvector-wasm/src/kernel/mod.rs
vendored
Normal file
71
vendor/ruvector/crates/ruvector-wasm/src/kernel/mod.rs
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
//! WASM Kernel Pack System (ADR-005)
|
||||
//!
|
||||
//! This module implements the WebAssembly kernel pack infrastructure for
|
||||
//! secure, sandboxed execution of ML compute kernels.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The kernel pack system provides:
|
||||
//! - **Sandboxed Execution**: Wasmtime runtime with epoch-based interruption
|
||||
//! - **Supply Chain Security**: Ed25519 signatures, SHA256 hash verification
|
||||
//! - **Hot-Swappable Kernels**: Update kernels without service restart
|
||||
//! - **Cross-Platform**: Same kernels run on servers and embedded devices
|
||||
//!
|
||||
//! # Kernel Categories
|
||||
//!
|
||||
//! - Positional: RoPE (Rotary Position Embeddings)
|
||||
//! - Normalization: RMSNorm
|
||||
//! - Activation: SwiGLU
|
||||
//! - KV Cache: Quantization/Dequantization
|
||||
//! - Adapter: LoRA delta application
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvector_wasm::kernel::{KernelManager, KernelPackVerifier};
|
||||
//!
|
||||
//! // Load and verify kernel pack
|
||||
//! let verifier = KernelPackVerifier::with_trusted_keys(keys);
|
||||
//! let manager = KernelManager::new(runtime_config)?;
|
||||
//! manager.load_pack("kernel-pack-v1.0.0", &verifier)?;
|
||||
//!
|
||||
//! // Execute kernel
|
||||
//! let result = manager.execute("rope_f32", &descriptor)?;
|
||||
//! ```
|
||||
|
||||
pub mod allowlist;
|
||||
pub mod epoch;
|
||||
pub mod error;
|
||||
pub mod hash;
|
||||
pub mod manifest;
|
||||
pub mod memory;
|
||||
pub mod runtime;
|
||||
pub mod signature;
|
||||
|
||||
// Re-exports
|
||||
pub use allowlist::TrustedKernelAllowlist;
|
||||
pub use epoch::{EpochConfig, EpochController};
|
||||
pub use error::{KernelError, VerifyError};
|
||||
pub use hash::HashVerifier;
|
||||
pub use manifest::{
|
||||
KernelCategory, KernelDescriptor, KernelInfo, KernelManifest, KernelParam, PlatformConfig,
|
||||
ResourceLimits, TensorSpec,
|
||||
};
|
||||
pub use memory::{KernelInvocationDescriptor, SharedMemoryProtocol};
|
||||
pub use runtime::{KernelRuntime, RuntimeConfig, WasmKernelInstance};
|
||||
pub use signature::KernelPackVerifier;
|
||||
|
||||
/// Current runtime version for compatibility checking.
/// Taken from this crate's Cargo.toml version at compile time.
pub const RUNTIME_VERSION: &str = env!("CARGO_PKG_VERSION");

/// Maximum supported kernel manifest schema version
pub const MAX_MANIFEST_VERSION: &str = "1.0.0";

/// WASM page size in bytes (64KB)
pub const WASM_PAGE_SIZE: usize = 65536;

/// Default epoch tick interval in milliseconds
pub const DEFAULT_EPOCH_TICK_MS: u64 = 10;

/// Default epoch budget (ticks before interruption); at the default
/// 10ms tick this is roughly a 10-second compute budget
pub const DEFAULT_EPOCH_BUDGET: u64 = 1000;
|
||||
575
vendor/ruvector/crates/ruvector-wasm/src/kernel/runtime.rs
vendored
Normal file
575
vendor/ruvector/crates/ruvector-wasm/src/kernel/runtime.rs
vendored
Normal file
@@ -0,0 +1,575 @@
|
||||
//! Wasmtime Runtime Integration
|
||||
//!
|
||||
//! Provides the runtime traits and implementations for executing
|
||||
//! WASM kernels with Wasmtime.
|
||||
|
||||
use crate::kernel::epoch::{EpochConfig, EpochController, EpochDeadline};
|
||||
use crate::kernel::error::{KernelError, KernelErrorCode, KernelResult};
|
||||
use crate::kernel::manifest::{KernelDescriptor, KernelInfo, KernelManifest, ResourceLimits};
|
||||
use crate::kernel::memory::{MemoryLayoutValidator, SharedMemoryProtocol, PAGE_SIZE};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Runtime configuration for WASM kernel execution.
///
/// Preset profiles are available via `server()`, `embedded()` and
/// `development()`.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Epoch configuration (deadline-based interruption)
    pub epoch: EpochConfig,

    /// Enable SIMD support
    pub enable_simd: bool,

    /// Enable bulk memory operations
    pub enable_bulk_memory: bool,

    /// Enable multi-value returns
    pub enable_multi_value: bool,

    /// Maximum memory pages per instance (64KB each)
    pub max_memory_pages: u32,

    /// Enable parallel compilation
    pub parallel_compilation: bool,

    /// Optimization level (0-3, where 0=none, 3=maximum)
    pub optimization_level: u8,

    /// Enable instance pooling for reuse
    pub enable_instance_pooling: bool,

    /// Pool size for instance reuse (0 when pooling is disabled)
    pub instance_pool_size: usize,
}
|
||||
|
||||
impl RuntimeConfig {
|
||||
/// Create configuration for server workloads
|
||||
pub fn server() -> Self {
|
||||
RuntimeConfig {
|
||||
epoch: EpochConfig::server(),
|
||||
enable_simd: true,
|
||||
enable_bulk_memory: true,
|
||||
enable_multi_value: true,
|
||||
max_memory_pages: 1024, // 64MB max
|
||||
parallel_compilation: true,
|
||||
optimization_level: 3,
|
||||
enable_instance_pooling: true,
|
||||
instance_pool_size: 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create configuration for embedded/constrained workloads
|
||||
pub fn embedded() -> Self {
|
||||
RuntimeConfig {
|
||||
epoch: EpochConfig::embedded(),
|
||||
enable_simd: false, // Often unavailable
|
||||
enable_bulk_memory: true,
|
||||
enable_multi_value: true,
|
||||
max_memory_pages: 64, // 4MB max
|
||||
parallel_compilation: false,
|
||||
optimization_level: 2,
|
||||
enable_instance_pooling: false,
|
||||
instance_pool_size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create configuration for development/debugging
|
||||
pub fn development() -> Self {
|
||||
RuntimeConfig {
|
||||
epoch: EpochConfig::disabled(),
|
||||
enable_simd: true,
|
||||
enable_bulk_memory: true,
|
||||
enable_multi_value: true,
|
||||
max_memory_pages: 1024,
|
||||
parallel_compilation: true,
|
||||
optimization_level: 0, // Fast compilation
|
||||
enable_instance_pooling: false,
|
||||
instance_pool_size: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RuntimeConfig {
|
||||
fn default() -> Self {
|
||||
Self::server()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compiled WASM kernel module
#[derive(Debug)]
pub struct CompiledKernel {
    /// Kernel ID
    pub id: String,
    /// Kernel info from manifest
    pub info: KernelInfo,
    /// Compiled module bytes (kept so the compiled artifact can be
    /// cached and re-instantiated without recompiling)
    pub compiled_bytes: Vec<u8>,
    /// Whether module uses SIMD
    pub uses_simd: bool,
    /// Required memory pages for instantiation
    pub required_pages: u32,
}

/// WASM kernel instance ready for execution.
///
/// Holds per-invocation state: the memory page budget, an optional
/// epoch deadline, and a layout validator consulted before execution.
pub struct WasmKernelInstance {
    /// ID of the kernel this instance runs
    kernel_id: String,
    /// Memory allocated for this instance, in 64KB WASM pages
    memory_pages: u32,
    /// Epoch deadline for this invocation (None = no deadline)
    deadline: Option<EpochDeadline>,
    /// Memory validator for descriptor layouts
    validator: MemoryLayoutValidator,
}
|
||||
|
||||
impl WasmKernelInstance {
|
||||
/// Create a new kernel instance
|
||||
pub fn new(kernel_id: String, memory_pages: u32) -> Self {
|
||||
WasmKernelInstance {
|
||||
kernel_id,
|
||||
memory_pages,
|
||||
deadline: None,
|
||||
validator: MemoryLayoutValidator::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set execution deadline
|
||||
pub fn set_deadline(&mut self, deadline: EpochDeadline) {
|
||||
self.deadline = Some(deadline);
|
||||
}
|
||||
|
||||
/// Get kernel ID
|
||||
pub fn kernel_id(&self) -> &str {
|
||||
&self.kernel_id
|
||||
}
|
||||
|
||||
/// Get allocated memory pages
|
||||
pub fn memory_pages(&self) -> u32 {
|
||||
self.memory_pages
|
||||
}
|
||||
|
||||
/// Get memory size in bytes
|
||||
pub fn memory_size(&self) -> usize {
|
||||
self.memory_pages as usize * PAGE_SIZE
|
||||
}
|
||||
|
||||
/// Validate a descriptor before execution
|
||||
pub fn validate_descriptor(&self, desc: &KernelDescriptor) -> KernelResult<()> {
|
||||
self.validator.validate_descriptor(desc, self.memory_size())
|
||||
}
|
||||
|
||||
/// Check if deadline exceeded (if set)
|
||||
pub fn check_deadline(&self, controller: &EpochController) -> bool {
|
||||
if let Some(deadline) = &self.deadline {
|
||||
deadline.is_exceeded(controller.current())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manual Debug impl; note that the `validator` field is omitted from
// the output.
impl std::fmt::Debug for WasmKernelInstance {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WasmKernelInstance")
            .field("kernel_id", &self.kernel_id)
            .field("memory_pages", &self.memory_pages)
            .field("deadline", &self.deadline)
            .finish()
    }
}
|
||||
|
||||
/// Trait for kernel runtime implementations.
///
/// Abstracts over the concrete engine (e.g. Wasmtime, or the mock
/// runtime used in tests) so callers can compile, instantiate, and
/// execute kernels uniformly.
pub trait KernelRuntime: Send + Sync {
    /// Load and compile a kernel from WASM bytes
    fn compile_kernel(
        &self,
        id: &str,
        wasm_bytes: &[u8],
        info: &KernelInfo,
    ) -> KernelResult<CompiledKernel>;

    /// Create an instance of a compiled kernel
    fn instantiate(&self, kernel: &CompiledKernel) -> KernelResult<WasmKernelInstance>;

    /// Execute a kernel with the given descriptor.
    /// `memory` is the instance's linear memory; the descriptor's
    /// offsets index into it.
    fn execute(
        &self,
        instance: &mut WasmKernelInstance,
        descriptor: &KernelDescriptor,
        memory: &mut [u8],
    ) -> KernelResult<()>;

    /// Get runtime configuration
    fn config(&self) -> &RuntimeConfig;

    /// Get epoch controller
    fn epoch_controller(&self) -> &EpochController;

    /// Increment epoch (should be called periodically to drive
    /// deadline-based interruption)
    fn tick(&self) {
        self.epoch_controller().increment();
    }
}
|
||||
|
||||
/// Mock runtime for testing without Wasmtime dependency
///
/// Instead of compiling real WASM, execution is scripted per kernel id via
/// [`MockKernelBehavior`] entries registered with `register_behavior`.
#[derive(Debug)]
pub struct MockKernelRuntime {
    // Runtime configuration returned by `config()`.
    config: RuntimeConfig,
    // Shared epoch counter used for deadline checks.
    epoch_controller: EpochController,
    /// Registered kernel behaviors for testing, keyed by kernel id
    kernel_behaviors: HashMap<String, MockKernelBehavior>,
}
|
||||
|
||||
/// Mock kernel behavior for testing
///
/// Determines what `MockKernelRuntime::execute` does for a given kernel id.
/// Unregistered kernels default to `Success`.
#[derive(Debug, Clone)]
pub enum MockKernelBehavior {
    /// Always succeed
    Success,
    /// Always fail with error code (surfaced as `KernelError::KernelTrap`)
    Fail(KernelErrorCode),
    /// Timeout (surfaced as `KernelError::EpochDeadline`)
    Timeout,
    /// Return specific output data, copied into the descriptor's output region
    ReturnData(Vec<u8>),
}
|
||||
|
||||
impl MockKernelRuntime {
|
||||
/// Create a new mock runtime
|
||||
pub fn new(config: RuntimeConfig) -> Self {
|
||||
MockKernelRuntime {
|
||||
epoch_controller: EpochController::new(config.epoch.tick_interval()),
|
||||
config,
|
||||
kernel_behaviors: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a mock behavior for a kernel
|
||||
pub fn register_behavior(&mut self, kernel_id: &str, behavior: MockKernelBehavior) {
|
||||
self.kernel_behaviors
|
||||
.insert(kernel_id.to_string(), behavior);
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelRuntime for MockKernelRuntime {
    /// "Compile" a kernel without touching the WASM bytes: produces a
    /// `CompiledKernel` carrying the manifest info and an empty byte buffer.
    fn compile_kernel(
        &self,
        id: &str,
        _wasm_bytes: &[u8],
        info: &KernelInfo,
    ) -> KernelResult<CompiledKernel> {
        Ok(CompiledKernel {
            id: id.to_string(),
            info: info.clone(),
            compiled_bytes: vec![], // No actual compilation
            uses_simd: false,
            required_pages: info.resource_limits.max_memory_pages,
        })
    }

    /// Create an instance sized from the compiled kernel's page requirement.
    fn instantiate(&self, kernel: &CompiledKernel) -> KernelResult<WasmKernelInstance> {
        Ok(WasmKernelInstance::new(
            kernel.id.clone(),
            kernel.required_pages,
        ))
    }

    /// Run the scripted behavior for the instance's kernel id.
    ///
    /// Mirrors a real runtime's ordering: descriptor validation first, then
    /// the deadline check, then "execution" (the registered behavior).
    fn execute(
        &self,
        instance: &mut WasmKernelInstance,
        descriptor: &KernelDescriptor,
        memory: &mut [u8],
    ) -> KernelResult<()> {
        // Validate descriptor first
        instance.validate_descriptor(descriptor)?;

        // Check deadline
        if instance.check_deadline(&self.epoch_controller) {
            return Err(KernelError::EpochDeadline);
        }

        // Look up mock behavior; unregistered kernels default to Success.
        let behavior = self
            .kernel_behaviors
            .get(instance.kernel_id())
            .cloned()
            .unwrap_or(MockKernelBehavior::Success);

        match behavior {
            MockKernelBehavior::Success => Ok(()),
            MockKernelBehavior::Fail(code) => Err(KernelError::KernelTrap {
                code: code as u32,
                message: Some(code.to_string()),
            }),
            MockKernelBehavior::Timeout => Err(KernelError::EpochDeadline),
            MockKernelBehavior::ReturnData(data) => {
                // Copy data to output region, clamped to the declared output
                // size. If the computed region would exceed `memory`, the
                // copy is silently skipped (Ok is still returned).
                // NOTE(review): `data.len() as u32` truncates for payloads
                // over 4 GiB — harmless at mock scale, but worth confirming.
                let out_start = descriptor.output_offset as usize;
                let out_end = out_start + descriptor.output_size.min(data.len() as u32) as usize;
                if out_end <= memory.len() {
                    let copy_len = (out_end - out_start).min(data.len());
                    memory[out_start..out_start + copy_len].copy_from_slice(&data[..copy_len]);
                }
                Ok(())
            }
        }
    }

    fn config(&self) -> &RuntimeConfig {
        &self.config
    }

    fn epoch_controller(&self) -> &EpochController {
        &self.epoch_controller
    }
}
|
||||
|
||||
/// Kernel manager for loading and executing kernel packs
///
/// Sits on top of any [`KernelRuntime`], holding the loaded pack manifests
/// and the kernels compiled from them.
pub struct KernelManager<R: KernelRuntime> {
    /// Runtime implementation
    runtime: Arc<R>,
    /// Loaded manifests, keyed by pack name
    manifests: HashMap<String, KernelManifest>,
    /// Compiled kernels, keyed by kernel id
    compiled_kernels: HashMap<String, CompiledKernel>,
    /// Active kernel pack (recorded by `set_active_pack`; not consulted by
    /// `execute`, which looks kernels up by id directly)
    active_pack: Option<String>,
}
|
||||
|
||||
impl<R: KernelRuntime> KernelManager<R> {
    /// Create a new kernel manager backed by the given runtime.
    pub fn new(runtime: Arc<R>) -> Self {
        KernelManager {
            runtime,
            manifests: HashMap::new(),
            compiled_kernels: HashMap::new(),
            active_pack: None,
        }
    }

    /// Load a kernel pack manifest, replacing any manifest previously
    /// registered under the same pack name.
    pub fn load_manifest(&mut self, pack_name: &str, manifest: KernelManifest) {
        self.manifests.insert(pack_name.to_string(), manifest);
    }

    /// Compile a kernel from a loaded pack and cache the result.
    ///
    /// # Errors
    /// `KernelError::KernelNotFound` if the pack has not been loaded or the
    /// kernel id is not present in the pack's manifest; any error the
    /// runtime's `compile_kernel` returns is propagated.
    pub fn compile_kernel(
        &mut self,
        pack_name: &str,
        kernel_id: &str,
        wasm_bytes: &[u8],
    ) -> KernelResult<()> {
        let manifest =
            self.manifests
                .get(pack_name)
                .ok_or_else(|| KernelError::KernelNotFound {
                    // "pack:" prefix distinguishes a missing pack from a
                    // missing kernel in the shared error variant.
                    kernel_id: format!("pack:{}", pack_name),
                })?;

        let info = manifest
            .get_kernel(kernel_id)
            .ok_or_else(|| KernelError::KernelNotFound {
                kernel_id: kernel_id.to_string(),
            })?;

        let compiled = self.runtime.compile_kernel(kernel_id, wasm_bytes, info)?;
        self.compiled_kernels
            .insert(kernel_id.to_string(), compiled);

        Ok(())
    }

    /// Set the active kernel pack; fails if the pack was never loaded.
    pub fn set_active_pack(&mut self, pack_name: &str) -> KernelResult<()> {
        if self.manifests.contains_key(pack_name) {
            self.active_pack = Some(pack_name.to_string());
            Ok(())
        } else {
            Err(KernelError::KernelNotFound {
                kernel_id: format!("pack:{}", pack_name),
            })
        }
    }

    /// Execute a previously compiled kernel.
    ///
    /// Instantiates a fresh `WasmKernelInstance` per call; if the runtime's
    /// epoch budgeting is enabled, a deadline derived from the kernel's
    /// `max_epoch_ticks` limit is installed before execution.
    pub fn execute(
        &self,
        kernel_id: &str,
        descriptor: &KernelDescriptor,
        memory: &mut [u8],
    ) -> KernelResult<()> {
        let compiled =
            self.compiled_kernels
                .get(kernel_id)
                .ok_or_else(|| KernelError::KernelNotFound {
                    kernel_id: kernel_id.to_string(),
                })?;

        let mut instance = self.runtime.instantiate(compiled)?;

        // Set deadline if epoch is enabled
        if self.runtime.config().epoch.enabled {
            let budget = compiled.info.resource_limits.max_epoch_ticks;
            let deadline = EpochDeadline::new(self.runtime.epoch_controller().current(), budget);
            instance.set_deadline(deadline);
        }

        self.runtime.execute(&mut instance, descriptor, memory)
    }

    /// Get manifest info for a compiled kernel, if any.
    pub fn get_kernel_info(&self, kernel_id: &str) -> Option<&KernelInfo> {
        self.compiled_kernels.get(kernel_id).map(|k| &k.info)
    }

    /// List compiled kernel IDs (unordered — backed by a HashMap).
    pub fn list_kernels(&self) -> Vec<&str> {
        self.compiled_kernels.keys().map(|s| s.as_str()).collect()
    }

    /// Get runtime reference
    pub fn runtime(&self) -> &R {
        &self.runtime
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::manifest::{DataType, KernelCategory, ResourceLimits, ShapeDim, TensorSpec};

    /// Build a minimal valid KernelInfo: one f32 input and one f32 output,
    /// each with a symbolic "batch" dimension, and default resource limits.
    fn mock_kernel_info(id: &str) -> KernelInfo {
        KernelInfo {
            id: id.to_string(),
            name: format!("Test {}", id),
            category: KernelCategory::Custom,
            path: format!("{}.wasm", id),
            hash: "sha256:test".to_string(),
            entry_point: "kernel_forward".to_string(),
            inputs: vec![TensorSpec {
                name: "x".to_string(),
                dtype: DataType::F32,
                shape: vec![ShapeDim::Symbolic("batch".to_string())],
            }],
            outputs: vec![TensorSpec {
                name: "y".to_string(),
                dtype: DataType::F32,
                shape: vec![ShapeDim::Symbolic("batch".to_string())],
            }],
            params: HashMap::new(),
            resource_limits: ResourceLimits::default(),
            platforms: HashMap::new(),
            benchmarks: HashMap::new(),
        }
    }

    // The three config presets expose distinct SIMD/optimization settings.
    #[test]
    fn test_runtime_config() {
        let server = RuntimeConfig::server();
        assert!(server.enable_simd);
        assert_eq!(server.optimization_level, 3);

        let embedded = RuntimeConfig::embedded();
        assert!(!embedded.enable_simd);
        assert!(!embedded.parallel_compilation);

        let dev = RuntimeConfig::development();
        assert_eq!(dev.optimization_level, 0);
    }

    // Success behavior: compile → instantiate → execute round-trips Ok.
    #[test]
    fn test_mock_runtime() {
        let mut runtime = MockKernelRuntime::new(RuntimeConfig::default());

        // Test success behavior
        runtime.register_behavior("test_kernel", MockKernelBehavior::Success);

        let info = mock_kernel_info("test_kernel");
        let compiled = runtime.compile_kernel("test_kernel", &[], &info).unwrap();
        let mut instance = runtime.instantiate(&compiled).unwrap();

        let mut desc = KernelDescriptor::new();
        desc.input_a_offset = 0;
        desc.input_a_size = 1024;
        desc.output_offset = 1024;
        desc.output_size = 1024;

        let mut memory = vec![0u8; 65536];
        let result = runtime.execute(&mut instance, &desc, &mut memory);
        assert!(result.is_ok());
    }

    // Fail behavior surfaces as a KernelTrap error.
    #[test]
    fn test_mock_runtime_failure() {
        let mut runtime = MockKernelRuntime::new(RuntimeConfig::default());
        runtime.register_behavior(
            "failing_kernel",
            MockKernelBehavior::Fail(KernelErrorCode::InvalidInput),
        );

        let info = mock_kernel_info("failing_kernel");
        let compiled = runtime
            .compile_kernel("failing_kernel", &[], &info)
            .unwrap();
        let mut instance = runtime.instantiate(&compiled).unwrap();

        let desc = KernelDescriptor::new();
        let mut memory = vec![0u8; 65536];
        let result = runtime.execute(&mut instance, &desc, &mut memory);
        assert!(matches!(result, Err(KernelError::KernelTrap { .. })));
    }

    // Accessors, memory sizing, and deadline expiry on an instance.
    #[test]
    fn test_wasm_kernel_instance() {
        let mut instance = WasmKernelInstance::new("test".to_string(), 256);

        assert_eq!(instance.kernel_id(), "test");
        assert_eq!(instance.memory_pages(), 256);
        assert_eq!(instance.memory_size(), 256 * PAGE_SIZE);

        // Test deadline: a 100-tick budget starting at epoch 0.
        let controller = EpochController::default_interval();
        let deadline = EpochDeadline::new(0, 100);
        instance.set_deadline(deadline);

        assert!(!instance.check_deadline(&controller));

        // Exceed deadline by advancing the epoch past the budget.
        for _ in 0..100 {
            controller.increment();
        }
        assert!(instance.check_deadline(&controller));
    }

    // End-to-end manager flow: load manifest → activate → compile → list.
    #[test]
    fn test_kernel_manager() {
        let runtime = Arc::new(MockKernelRuntime::new(RuntimeConfig::default()));
        let mut manager = KernelManager::new(runtime);

        // Create a minimal manifest containing a single kernel.
        let manifest = KernelManifest {
            schema: String::new(),
            version: "1.0.0".to_string(),
            name: "test-pack".to_string(),
            description: "Test".to_string(),
            min_runtime_version: "0.1.0".to_string(),
            max_runtime_version: "1.0.0".to_string(),
            created_at: "2026-01-18T00:00:00Z".to_string(),
            author: crate::kernel::manifest::AuthorInfo {
                name: "Test".to_string(),
                email: "test@test.com".to_string(),
                signing_key: "test".to_string(),
            },
            kernels: vec![mock_kernel_info("rope_f32")],
            fallbacks: HashMap::new(),
        };

        manager.load_manifest("test-pack", manifest);
        manager.set_active_pack("test-pack").unwrap();

        // Compile kernel
        manager
            .compile_kernel("test-pack", "rope_f32", &[])
            .unwrap();

        assert_eq!(manager.list_kernels(), vec!["rope_f32"]);
    }
}
|
||||
288
vendor/ruvector/crates/ruvector-wasm/src/kernel/signature.rs
vendored
Normal file
288
vendor/ruvector/crates/ruvector-wasm/src/kernel/signature.rs
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
//! Ed25519 Signature Verification
|
||||
//!
|
||||
//! Provides cryptographic signature verification for kernel pack manifests
|
||||
//! to ensure supply chain security.
|
||||
|
||||
use crate::kernel::error::VerifyError;
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
|
||||
/// Kernel pack signature verifier
///
/// Maintains a list of trusted Ed25519 public keys and verifies
/// manifest signatures against them.
#[derive(Debug, Clone)]
pub struct KernelPackVerifier {
    /// Trusted Ed25519 public keys
    trusted_keys: Vec<VerifyingKey>,
    /// Whether to require signatures (can be disabled for development via
    /// `insecure_no_verify`; when false, `verify` always succeeds)
    require_signature: bool,
}
|
||||
|
||||
impl KernelPackVerifier {
|
||||
/// Create a new verifier with no trusted keys
|
||||
pub fn new() -> Self {
|
||||
KernelPackVerifier {
|
||||
trusted_keys: Vec::new(),
|
||||
require_signature: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a verifier with pre-loaded trusted keys
|
||||
pub fn with_trusted_keys(keys: Vec<VerifyingKey>) -> Self {
|
||||
KernelPackVerifier {
|
||||
trusted_keys: keys,
|
||||
require_signature: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a verifier that doesn't require signatures (for development)
|
||||
///
|
||||
/// # Warning
|
||||
/// This should NEVER be used in production as it bypasses security checks.
|
||||
pub fn insecure_no_verify() -> Self {
|
||||
KernelPackVerifier {
|
||||
trusted_keys: Vec::new(),
|
||||
require_signature: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a trusted public key from bytes
|
||||
pub fn add_trusted_key(&mut self, key_bytes: &[u8; 32]) -> Result<(), VerifyError> {
|
||||
let key = VerifyingKey::from_bytes(key_bytes).map_err(|e| VerifyError::KeyError {
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
self.trusted_keys.push(key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a trusted public key from hex string
|
||||
pub fn add_trusted_key_hex(&mut self, hex: &str) -> Result<(), VerifyError> {
|
||||
// Remove "ed25519:" prefix if present
|
||||
let hex = hex.strip_prefix("ed25519:").unwrap_or(hex);
|
||||
|
||||
let bytes = hex::decode(hex).map_err(|e| VerifyError::KeyError {
|
||||
message: format!("Invalid hex: {}", e),
|
||||
})?;
|
||||
|
||||
if bytes.len() != 32 {
|
||||
return Err(VerifyError::KeyError {
|
||||
message: format!("Invalid key length: expected 32 bytes, got {}", bytes.len()),
|
||||
});
|
||||
}
|
||||
|
||||
let mut key_bytes = [0u8; 32];
|
||||
key_bytes.copy_from_slice(&bytes);
|
||||
self.add_trusted_key(&key_bytes)
|
||||
}
|
||||
|
||||
/// Add a trusted public key from base64 string
|
||||
pub fn add_trusted_key_base64(&mut self, b64: &str) -> Result<(), VerifyError> {
|
||||
// Remove "ed25519:" prefix if present
|
||||
let b64 = b64.strip_prefix("ed25519:").unwrap_or(b64);
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
let bytes = STANDARD.decode(b64).map_err(|e| VerifyError::KeyError {
|
||||
message: format!("Invalid base64: {}", e),
|
||||
})?;
|
||||
|
||||
if bytes.len() != 32 {
|
||||
return Err(VerifyError::KeyError {
|
||||
message: format!("Invalid key length: expected 32 bytes, got {}", bytes.len()),
|
||||
});
|
||||
}
|
||||
|
||||
let mut key_bytes = [0u8; 32];
|
||||
key_bytes.copy_from_slice(&bytes);
|
||||
self.add_trusted_key(&key_bytes)
|
||||
}
|
||||
|
||||
/// Verify manifest signature against trusted keys
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `manifest` - The manifest bytes to verify
|
||||
/// * `signature` - The signature bytes (64 bytes)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(())` if signature is valid and from a trusted key
|
||||
/// * `Err(VerifyError::NoTrustedKey)` if no trusted key verified the signature
|
||||
pub fn verify(&self, manifest: &[u8], signature: &[u8]) -> Result<(), VerifyError> {
|
||||
// Skip verification if disabled (development mode)
|
||||
if !self.require_signature {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check we have trusted keys
|
||||
if self.trusted_keys.is_empty() {
|
||||
return Err(VerifyError::NoTrustedKey);
|
||||
}
|
||||
|
||||
// Parse signature
|
||||
let sig = Signature::from_slice(signature).map_err(|e| VerifyError::InvalidSignature {
|
||||
reason: format!("Invalid signature format: {}", e),
|
||||
})?;
|
||||
|
||||
// Try each trusted key
|
||||
for key in &self.trusted_keys {
|
||||
if key.verify(manifest, &sig).is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(VerifyError::NoTrustedKey)
|
||||
}
|
||||
|
||||
/// Verify manifest with signature from hex string
|
||||
pub fn verify_hex(&self, manifest: &[u8], signature_hex: &str) -> Result<(), VerifyError> {
|
||||
let signature = hex::decode(signature_hex).map_err(|e| VerifyError::InvalidSignature {
|
||||
reason: format!("Invalid hex signature: {}", e),
|
||||
})?;
|
||||
self.verify(manifest, &signature)
|
||||
}
|
||||
|
||||
/// Verify manifest with signature from base64 string
|
||||
pub fn verify_base64(&self, manifest: &[u8], signature_b64: &str) -> Result<(), VerifyError> {
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
let signature =
|
||||
STANDARD
|
||||
.decode(signature_b64)
|
||||
.map_err(|e| VerifyError::InvalidSignature {
|
||||
reason: format!("Invalid base64 signature: {}", e),
|
||||
})?;
|
||||
self.verify(manifest, &signature)
|
||||
}
|
||||
|
||||
/// Get number of trusted keys
|
||||
pub fn trusted_key_count(&self) -> usize {
|
||||
self.trusted_keys.len()
|
||||
}
|
||||
|
||||
/// Check if signature verification is required
|
||||
pub fn is_verification_required(&self) -> bool {
|
||||
self.require_signature
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KernelPackVerifier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to sign a manifest (for kernel pack creation)
///
/// Returns the 64-byte Ed25519 signature as a `Vec<u8>`, suitable for
/// later checking with `KernelPackVerifier::verify`. Only compiled when
/// the `signing` feature is enabled.
#[cfg(feature = "signing")]
pub fn sign_manifest(manifest: &[u8], signing_key: &ed25519_dalek::SigningKey) -> Vec<u8> {
    use ed25519_dalek::Signer;
    signing_key.sign(manifest).to_bytes().to_vec()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;

    /// Deterministic test key pair built from a fixed seed so tests are
    /// reproducible without an RNG.
    fn generate_key_pair() -> (SigningKey, VerifyingKey) {
        // Use a fixed test seed for reproducibility
        let mut seed = [0u8; 32];
        // Simple deterministic seed based on test
        for (i, b) in seed.iter_mut().enumerate() {
            *b = (i * 7 + 13) as u8;
        }
        let signing_key = SigningKey::from_bytes(&seed);
        let verifying_key = signing_key.verifying_key();
        (signing_key, verifying_key)
    }

    // Happy path: signature from a trusted key verifies.
    #[test]
    fn test_verify_success() {
        use ed25519_dalek::Signer;

        let (signing_key, verifying_key) = generate_key_pair();
        let manifest = b"test manifest content";
        let signature = signing_key.sign(manifest);

        let mut verifier = KernelPackVerifier::new();
        verifier.trusted_keys.push(verifying_key);

        assert!(verifier.verify(manifest, &signature.to_bytes()).is_ok());
    }

    // A well-formed signature from an untrusted key must be rejected.
    #[test]
    fn test_verify_wrong_key() {
        use ed25519_dalek::Signer;

        let (signing_key, _) = generate_key_pair();
        let (_, wrong_verifying_key) = generate_key_pair();

        let manifest = b"test manifest content";
        let signature = signing_key.sign(manifest);

        let mut verifier = KernelPackVerifier::new();
        verifier.trusted_keys.push(wrong_verifying_key);

        let result = verifier.verify(manifest, &signature.to_bytes());
        assert!(matches!(result, Err(VerifyError::NoTrustedKey)));
    }

    // No registered keys → NoTrustedKey before signature parsing.
    #[test]
    fn test_verify_no_keys() {
        let verifier = KernelPackVerifier::new();
        let manifest = b"test manifest";
        let signature = [0u8; 64];

        let result = verifier.verify(manifest, &signature);
        assert!(matches!(result, Err(VerifyError::NoTrustedKey)));
    }

    // Development mode skips all checks, even for garbage signatures.
    #[test]
    fn test_insecure_no_verify() {
        let verifier = KernelPackVerifier::insecure_no_verify();
        let manifest = b"test manifest";
        let invalid_signature = [0u8; 64];

        // Should pass even with invalid signature
        assert!(verifier.verify(manifest, &invalid_signature).is_ok());
        assert!(!verifier.is_verification_required());
    }

    #[test]
    fn test_add_key_hex() {
        let mut verifier = KernelPackVerifier::new();

        // Valid 32-byte key in hex
        let hex_key = "0000000000000000000000000000000000000000000000000000000000000000";
        // Note: This is a degenerate key but tests the parsing
        let result = verifier.add_trusted_key_hex(hex_key);
        // This specific key may or may not be valid depending on curve requirements
        // The important thing is that hex parsing works
        assert!(result.is_ok() || matches!(result, Err(VerifyError::KeyError { .. })));
    }

    #[test]
    fn test_add_key_with_prefix() {
        let mut verifier = KernelPackVerifier::new();

        // Key with ed25519: prefix
        let prefixed_key =
            "ed25519:0000000000000000000000000000000000000000000000000000000000000000";
        let _ = verifier.add_trusted_key_hex(prefixed_key);
        // Just testing that prefix stripping works
    }

    // Non-hex input must surface as a KeyError, not a panic.
    #[test]
    fn test_invalid_hex() {
        let mut verifier = KernelPackVerifier::new();
        let invalid = "not_valid_hex";

        let result = verifier.add_trusted_key_hex(invalid);
        assert!(matches!(result, Err(VerifyError::KeyError { .. })));
    }

    // Valid hex of the wrong length must also be a KeyError.
    #[test]
    fn test_wrong_key_length() {
        let mut verifier = KernelPackVerifier::new();
        let short_key = "0000000000000000"; // 8 bytes

        let result = verifier.add_trusted_key_hex(short_key);
        assert!(matches!(result, Err(VerifyError::KeyError { .. })));
    }
}
|
||||
896
vendor/ruvector/crates/ruvector-wasm/src/lib.rs
vendored
Normal file
896
vendor/ruvector/crates/ruvector-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,896 @@
|
||||
//! WASM bindings for Ruvector
|
||||
//!
|
||||
//! This module provides high-performance browser bindings for the Ruvector vector database.
|
||||
//! Features:
|
||||
//! - Full VectorDB API (insert, search, delete, batch operations)
|
||||
//! - SIMD acceleration (when available)
|
||||
//! - Web Workers support for parallel operations
|
||||
//! - IndexedDB persistence
|
||||
//! - Zero-copy transfers via transferable objects
|
||||
//!
|
||||
//! # Kernel Pack System (ADR-005)
|
||||
//!
|
||||
//! When compiled with the `kernel-pack` feature, this crate also provides the WASM
|
||||
//! kernel pack infrastructure for secure, sandboxed execution of ML compute kernels.
|
||||
//!
|
||||
//! ```toml
|
||||
//! [dependencies]
|
||||
//! ruvector-wasm = { version = "0.1", features = ["kernel-pack"] }
|
||||
//! ```
|
||||
//!
|
||||
//! The kernel pack system includes:
|
||||
//! - Manifest parsing and validation
|
||||
//! - Ed25519 signature verification
|
||||
//! - SHA256 hash verification
|
||||
//! - Trusted kernel allowlist
|
||||
//! - Epoch-based execution budgets
|
||||
//! - Shared memory protocol for tensor data
|
||||
|
||||
// Kernel pack module (ADR-005)
|
||||
#[cfg(feature = "kernel-pack")]
|
||||
pub mod kernel;
|
||||
|
||||
use js_sys::{Array, Float32Array, Object, Promise, Reflect, Uint8Array};
|
||||
use parking_lot::Mutex;
|
||||
#[cfg(feature = "collections")]
|
||||
use ruvector_collections::{
|
||||
CollectionConfig as CoreCollectionConfig, CollectionManager as CoreCollectionManager,
|
||||
};
|
||||
use ruvector_core::{
|
||||
error::RuvectorError,
|
||||
types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery, SearchResult, VectorEntry},
|
||||
vector_db::VectorDB as CoreVectorDB,
|
||||
};
|
||||
#[cfg(feature = "collections")]
|
||||
use ruvector_filter::FilterExpression as CoreFilterExpression;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_wasm_bindgen::{from_value, to_value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_futures::JsFuture;
|
||||
use web_sys::{
|
||||
console, IdbDatabase, IdbFactory, IdbObjectStore, IdbRequest, IdbTransaction, Window,
|
||||
};
|
||||
|
||||
/// Initialize panic hook for better error messages in browser console,
/// and route `tracing` output to the browser console.
///
/// Runs automatically when the WASM module is instantiated
/// (`#[wasm_bindgen(start)]`); `set_once` makes repeated calls harmless.
#[wasm_bindgen(start)]
pub fn init() {
    console_error_panic_hook::set_once();
    tracing_wasm::set_as_global_default();
}
|
||||
|
||||
/// WASM-specific error type that can cross the JS boundary
///
/// Serialized to a plain `{ message, kind }` JS object via the
/// `From<WasmError> for JsValue` impl below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmError {
    // Human-readable error text (the core error's Display output).
    pub message: String,
    // Machine-usable discriminator (the core error's Debug output).
    pub kind: String,
}
|
||||
|
||||
impl From<RuvectorError> for WasmError {
|
||||
fn from(err: RuvectorError) -> Self {
|
||||
WasmError {
|
||||
message: err.to_string(),
|
||||
kind: format!("{:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WasmError> for JsValue {
|
||||
fn from(err: WasmError) -> Self {
|
||||
let obj = Object::new();
|
||||
Reflect::set(&obj, &"message".into(), &err.message.into()).unwrap();
|
||||
Reflect::set(&obj, &"kind".into(), &err.kind.into()).unwrap();
|
||||
obj.into()
|
||||
}
|
||||
}
|
||||
|
||||
type WasmResult<T> = Result<T, WasmError>;
|
||||
|
||||
/// JavaScript-compatible VectorEntry
///
/// Thin wrapper around the core `VectorEntry`, exposed to JS via
/// wasm-bindgen with getter accessors.
#[wasm_bindgen]
#[derive(Clone)]
pub struct JsVectorEntry {
    // Wrapped core entry (id, vector data, optional metadata).
    inner: VectorEntry,
}
|
||||
|
||||
/// Maximum allowed vector dimensions (security limit to prevent DoS)
|
||||
const MAX_VECTOR_DIMENSIONS: usize = 65536;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl JsVectorEntry {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(
|
||||
vector: Float32Array,
|
||||
id: Option<String>,
|
||||
metadata: Option<JsValue>,
|
||||
) -> Result<JsVectorEntry, JsValue> {
|
||||
// Security: Validate vector dimensions before allocation
|
||||
let vec_len = vector.length() as usize;
|
||||
if vec_len == 0 {
|
||||
return Err(JsValue::from_str("Vector cannot be empty"));
|
||||
}
|
||||
if vec_len > MAX_VECTOR_DIMENSIONS {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Vector dimensions {} exceed maximum allowed {}",
|
||||
vec_len, MAX_VECTOR_DIMENSIONS
|
||||
)));
|
||||
}
|
||||
|
||||
let vector_data: Vec<f32> = vector.to_vec();
|
||||
|
||||
let metadata = if let Some(meta) = metadata {
|
||||
Some(
|
||||
from_value(meta)
|
||||
.map_err(|e| JsValue::from_str(&format!("Invalid metadata: {}", e)))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(JsVectorEntry {
|
||||
inner: VectorEntry {
|
||||
id,
|
||||
vector: vector_data,
|
||||
metadata,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn id(&self) -> Option<String> {
|
||||
self.inner.id.clone()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn vector(&self) -> Float32Array {
|
||||
Float32Array::from(&self.inner.vector[..])
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn metadata(&self) -> Option<JsValue> {
|
||||
self.inner.metadata.as_ref().map(|m| to_value(m).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// JavaScript-compatible SearchResult
///
/// Read-only wrapper around the core `SearchResult`, exposed to JS with
/// getter accessors for id, score, vector and metadata.
#[wasm_bindgen]
pub struct JsSearchResult {
    // Wrapped core search result.
    inner: SearchResult,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl JsSearchResult {
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn id(&self) -> String {
|
||||
self.inner.id.clone()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn score(&self) -> f32 {
|
||||
self.inner.score
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn vector(&self) -> Option<Float32Array> {
|
||||
self.inner
|
||||
.vector
|
||||
.as_ref()
|
||||
.map(|v| Float32Array::from(&v[..]))
|
||||
}
|
||||
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn metadata(&self) -> Option<JsValue> {
|
||||
self.inner.metadata.as_ref().map(|m| to_value(m).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// Main VectorDB class for browser usage
///
/// Wraps the core in-memory `VectorDB` behind an `Arc<Mutex<..>>` so the
/// handle can be cheaply cloned/shared across the JS boundary.
#[wasm_bindgen]
pub struct VectorDB {
    // Shared, lock-protected core database (in-memory storage in WASM).
    db: Arc<Mutex<CoreVectorDB>>,
    // Fixed vector dimensionality; query inputs are checked against it.
    dimensions: usize,
    // Name used for IndexedDB persistence; derived from creation timestamp.
    db_name: String,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl VectorDB {
|
||||
    /// Create a new VectorDB instance
    ///
    /// # Arguments
    /// * `dimensions` - Vector dimensions
    /// * `metric` - Distance metric ("euclidean", "cosine", "dotproduct", "manhattan");
    ///   defaults to cosine when omitted, errors on any other string
    /// * `use_hnsw` - Whether to use HNSW index for faster search (default: true)
    ///
    /// # Errors
    /// Returns a JS error for an unknown metric name or if the core DB
    /// fails to initialize.
    #[wasm_bindgen(constructor)]
    pub fn new(
        dimensions: usize,
        metric: Option<String>,
        use_hnsw: Option<bool>,
    ) -> Result<VectorDB, JsValue> {
        let distance_metric = match metric.as_deref() {
            Some("euclidean") => DistanceMetric::Euclidean,
            Some("cosine") => DistanceMetric::Cosine,
            Some("dotproduct") => DistanceMetric::DotProduct,
            Some("manhattan") => DistanceMetric::Manhattan,
            None => DistanceMetric::Cosine,
            Some(other) => return Err(JsValue::from_str(&format!("Unknown metric: {}", other))),
        };

        let hnsw_config = if use_hnsw.unwrap_or(true) {
            Some(HnswConfig::default())
        } else {
            None
        };

        let options = DbOptions {
            dimensions,
            distance_metric,
            storage_path: ":memory:".to_string(), // Use in-memory for WASM
            hnsw_config,
            quantization: None, // Disable quantization for WASM (for now)
        };

        let db = CoreVectorDB::new(options).map_err(|e| JsValue::from(WasmError::from(e)))?;

        Ok(VectorDB {
            db: Arc::new(Mutex::new(db)),
            dimensions,
            // Timestamp-based name keeps separate instances from colliding
            // in IndexedDB.
            db_name: format!("ruvector_db_{}", js_sys::Date::now()),
        })
    }
|
||||
|
||||
    /// Insert a single vector
    ///
    /// # Arguments
    /// * `vector` - Float32Array of vector data
    /// * `id` - Optional ID (auto-generated if not provided)
    /// * `metadata` - Optional metadata object
    ///
    /// # Returns
    /// The vector ID
    ///
    /// # Errors
    /// Propagates validation errors from `JsVectorEntry::new` (empty or
    /// oversized vector, bad metadata) and any core insert failure.
    #[wasm_bindgen]
    pub fn insert(
        &self,
        vector: Float32Array,
        id: Option<String>,
        metadata: Option<JsValue>,
    ) -> Result<String, JsValue> {
        // Reuse the entry constructor so validation is applied uniformly.
        let entry = JsVectorEntry::new(vector, id, metadata)?;

        let db = self.db.lock();
        let vector_id = db
            .insert(entry.inner)
            .map_err(|e| JsValue::from(WasmError::from(e)))?;

        Ok(vector_id)
    }
|
||||
|
||||
/// Insert multiple vectors in a batch (more efficient)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `entries` - Array of VectorEntry objects
|
||||
///
|
||||
/// # Returns
|
||||
/// Array of vector IDs
|
||||
#[wasm_bindgen(js_name = insertBatch)]
|
||||
pub fn insert_batch(&self, entries: JsValue) -> Result<Vec<String>, JsValue> {
|
||||
// Convert JsValue to Array using reflection
|
||||
let entries_array: js_sys::Array = entries
|
||||
.dyn_into()
|
||||
.map_err(|_| JsValue::from_str("entries must be an array"))?;
|
||||
|
||||
let mut vector_entries = Vec::new();
|
||||
for i in 0..entries_array.length() {
|
||||
let js_entry = entries_array.get(i);
|
||||
let vector_arr: Float32Array = Reflect::get(&js_entry, &"vector".into())?.dyn_into()?;
|
||||
let id: Option<String> = Reflect::get(&js_entry, &"id".into())?.as_string();
|
||||
let metadata = Reflect::get(&js_entry, &"metadata".into()).ok();
|
||||
|
||||
let entry = JsVectorEntry::new(vector_arr, id, metadata)?;
|
||||
vector_entries.push(entry.inner);
|
||||
}
|
||||
|
||||
let db = self.db.lock();
|
||||
let ids = db
|
||||
.insert_batch(vector_entries)
|
||||
.map_err(|e| JsValue::from(WasmError::from(e)))?;
|
||||
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
/// Search for similar vectors
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector as Float32Array
|
||||
/// * `k` - Number of results to return
|
||||
/// * `filter` - Optional metadata filter object
|
||||
///
|
||||
/// # Returns
|
||||
/// Array of search results
|
||||
#[wasm_bindgen]
|
||||
pub fn search(
|
||||
&self,
|
||||
query: Float32Array,
|
||||
k: usize,
|
||||
filter: Option<JsValue>,
|
||||
) -> Result<Vec<JsSearchResult>, JsValue> {
|
||||
let query_vector: Vec<f32> = query.to_vec();
|
||||
|
||||
if query_vector.len() != self.dimensions {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Query vector dimension mismatch: expected {}, got {}",
|
||||
self.dimensions,
|
||||
query_vector.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let metadata_filter = if let Some(f) = filter {
|
||||
Some(from_value(f).map_err(|e| JsValue::from_str(&format!("Invalid filter: {}", e)))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let search_query = SearchQuery {
|
||||
vector: query_vector,
|
||||
k,
|
||||
filter: metadata_filter,
|
||||
ef_search: None,
|
||||
};
|
||||
|
||||
let db = self.db.lock();
|
||||
let results = db
|
||||
.search(search_query)
|
||||
.map_err(|e| JsValue::from(WasmError::from(e)))?;
|
||||
|
||||
Ok(results
|
||||
.into_iter()
|
||||
.map(|r| JsSearchResult { inner: r })
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Delete a vector by ID
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `id` - Vector ID to delete
|
||||
///
|
||||
/// # Returns
|
||||
/// True if deleted, false if not found
|
||||
#[wasm_bindgen]
|
||||
pub fn delete(&self, id: &str) -> Result<bool, JsValue> {
|
||||
let db = self.db.lock();
|
||||
db.delete(id).map_err(|e| JsValue::from(WasmError::from(e)))
|
||||
}
|
||||
|
||||
/// Get a vector by ID
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `id` - Vector ID
|
||||
///
|
||||
/// # Returns
|
||||
/// VectorEntry or null if not found
|
||||
#[wasm_bindgen]
|
||||
pub fn get(&self, id: &str) -> Result<Option<JsVectorEntry>, JsValue> {
|
||||
let db = self.db.lock();
|
||||
let entry = db.get(id).map_err(|e| JsValue::from(WasmError::from(e)))?;
|
||||
|
||||
Ok(entry.map(|e| JsVectorEntry { inner: e }))
|
||||
}
|
||||
|
||||
/// Get the number of vectors in the database
|
||||
#[wasm_bindgen]
|
||||
pub fn len(&self) -> Result<usize, JsValue> {
|
||||
let db = self.db.lock();
|
||||
db.len().map_err(|e| JsValue::from(WasmError::from(e)))
|
||||
}
|
||||
|
||||
/// Check if the database is empty
|
||||
#[wasm_bindgen(js_name = isEmpty)]
|
||||
pub fn is_empty(&self) -> Result<bool, JsValue> {
|
||||
let db = self.db.lock();
|
||||
db.is_empty().map_err(|e| JsValue::from(WasmError::from(e)))
|
||||
}
|
||||
|
||||
    /// Get database dimensions
    ///
    /// Exposed as a JS property getter; returns the vector dimensionality
    /// this database was created with (fixed for the lifetime of the DB).
    #[wasm_bindgen(getter)]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
|
||||
|
||||
/// Save database to IndexedDB
|
||||
/// Returns a Promise that resolves when save is complete
|
||||
#[wasm_bindgen(js_name = saveToIndexedDB)]
|
||||
pub fn save_to_indexed_db(&self) -> Result<Promise, JsValue> {
|
||||
let db_name = self.db_name.clone();
|
||||
|
||||
// For now, log that we would save to IndexedDB
|
||||
// Full implementation would serialize the database state
|
||||
console::log_1(&format!("Saving database '{}' to IndexedDB...", db_name).into());
|
||||
|
||||
// Return resolved promise
|
||||
Ok(Promise::resolve(&JsValue::TRUE))
|
||||
}
|
||||
|
||||
/// Load database from IndexedDB
|
||||
/// Returns a Promise that resolves with the VectorDB instance
|
||||
#[wasm_bindgen(js_name = loadFromIndexedDB)]
|
||||
pub fn load_from_indexed_db(db_name: String) -> Result<Promise, JsValue> {
|
||||
console::log_1(&format!("Loading database '{}' from IndexedDB...", db_name).into());
|
||||
|
||||
// Return rejected promise for now (not implemented)
|
||||
Ok(Promise::reject(&JsValue::from_str("Not yet implemented")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect SIMD support in the current environment
|
||||
#[wasm_bindgen(js_name = detectSIMD)]
|
||||
pub fn detect_simd() -> bool {
|
||||
// Check for WebAssembly SIMD support
|
||||
#[cfg(target_feature = "simd128")]
|
||||
{
|
||||
true
|
||||
}
|
||||
#[cfg(not(target_feature = "simd128"))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get version information
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Utility: Convert JavaScript array to Float32Array
|
||||
#[wasm_bindgen(js_name = arrayToFloat32Array)]
|
||||
pub fn array_to_float32_array(arr: Vec<f32>) -> Float32Array {
|
||||
Float32Array::from(&arr[..])
|
||||
}
|
||||
|
||||
/// Utility: Measure performance of an operation
|
||||
#[wasm_bindgen(js_name = benchmark)]
|
||||
pub fn benchmark(name: &str, iterations: usize, dimensions: usize) -> Result<f64, JsValue> {
|
||||
use std::time::Instant;
|
||||
|
||||
console::log_1(
|
||||
&format!(
|
||||
"Running benchmark '{}' with {} iterations...",
|
||||
name, iterations
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
|
||||
let db = VectorDB::new(dimensions, Some("cosine".to_string()), Some(false))?;
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
for i in 0..iterations {
|
||||
let vector: Vec<f32> = (0..dimensions)
|
||||
.map(|_| js_sys::Math::random() as f32)
|
||||
.collect();
|
||||
let vector_arr = Float32Array::from(&vector[..]);
|
||||
db.insert(vector_arr, Some(format!("vec_{}", i)), None)?;
|
||||
}
|
||||
|
||||
let duration = start.elapsed();
|
||||
let ops_per_sec = iterations as f64 / duration.as_secs_f64();
|
||||
|
||||
console::log_1(&format!("Benchmark complete: {:.2} ops/sec", ops_per_sec).into());
|
||||
|
||||
Ok(ops_per_sec)
|
||||
}
|
||||
|
||||
// ===== Collection Manager =====
|
||||
// Note: Collections are not available in standard WASM builds due to file I/O requirements
|
||||
// To use collections, compile with the "collections" feature (requires WASI or server environment)
|
||||
|
||||
#[cfg(feature = "collections")]
/// WASM Collection Manager for multi-collection support
///
/// Thin wrapper around the core collection manager. The inner value sits
/// behind `Arc<Mutex<..>>` so the handle can be shared across the WASM
/// boundary; every public method takes the lock for the call duration.
#[wasm_bindgen]
pub struct CollectionManager {
    // Core manager shared by all exported methods.
    inner: Arc<Mutex<CoreCollectionManager>>,
}
|
||||
|
||||
#[cfg(feature = "collections")]
#[wasm_bindgen]
impl CollectionManager {
    /// Create a new CollectionManager
    ///
    /// # Arguments
    /// * `base_path` - Optional base path for storing collections (defaults to ":memory:")
    ///
    /// # Errors
    /// Returns a JS error string if the core manager cannot be constructed.
    #[wasm_bindgen(constructor)]
    pub fn new(base_path: Option<String>) -> Result<CollectionManager, JsValue> {
        let path = base_path.unwrap_or_else(|| ":memory:".to_string());

        let manager = CoreCollectionManager::new(std::path::PathBuf::from(path)).map_err(|e| {
            JsValue::from_str(&format!("Failed to create collection manager: {}", e))
        })?;

        Ok(CollectionManager {
            inner: Arc::new(Mutex::new(manager)),
        })
    }

    /// Create a new collection
    ///
    /// # Arguments
    /// * `name` - Collection name (alphanumeric, hyphens, underscores only)
    /// * `dimensions` - Vector dimensions
    /// * `metric` - Optional distance metric ("euclidean", "cosine", "dotproduct", "manhattan")
    ///
    /// # Errors
    /// Returns an error for an unknown metric string, or if the core
    /// manager rejects the collection (e.g. duplicate name).
    #[wasm_bindgen(js_name = createCollection)]
    pub fn create_collection(
        &self,
        name: &str,
        dimensions: usize,
        metric: Option<String>,
    ) -> Result<(), JsValue> {
        // A missing metric defaults to cosine; any unrecognized string is rejected.
        let distance_metric = match metric.as_deref() {
            Some("euclidean") => DistanceMetric::Euclidean,
            Some("cosine") => DistanceMetric::Cosine,
            Some("dotproduct") => DistanceMetric::DotProduct,
            Some("manhattan") => DistanceMetric::Manhattan,
            None => DistanceMetric::Cosine,
            Some(other) => return Err(JsValue::from_str(&format!("Unknown metric: {}", other))),
        };

        let config = CoreCollectionConfig {
            dimensions,
            distance_metric,
            hnsw_config: Some(HnswConfig::default()),
            quantization: None,
            on_disk_payload: false, // Disable for WASM
        };

        let manager = self.inner.lock();
        manager
            .create_collection(name, config)
            .map_err(|e| JsValue::from_str(&format!("Failed to create collection: {}", e)))?;

        Ok(())
    }

    /// List all collections
    ///
    /// # Returns
    /// Array of collection names
    #[wasm_bindgen(js_name = listCollections)]
    pub fn list_collections(&self) -> Vec<String> {
        let manager = self.inner.lock();
        manager.list_collections()
    }

    /// Delete a collection
    ///
    /// # Arguments
    /// * `name` - Collection name to delete
    ///
    /// # Errors
    /// Returns error if collection has active aliases
    #[wasm_bindgen(js_name = deleteCollection)]
    pub fn delete_collection(&self, name: &str) -> Result<(), JsValue> {
        let manager = self.inner.lock();
        manager
            .delete_collection(name)
            .map_err(|e| JsValue::from_str(&format!("Failed to delete collection: {}", e)))?;

        Ok(())
    }

    /// Get a collection's VectorDB
    ///
    /// # Arguments
    /// * `name` - Collection name or alias
    ///
    /// # Returns
    /// VectorDB instance or error if not found
    ///
    /// NOTE(review): this builds a FRESH in-memory database from the
    /// collection's config — it does not expose vectors already inserted
    /// into the collection. Callers should treat the handle as
    /// config-equivalent only until real storage sharing lands.
    #[wasm_bindgen(js_name = getCollection)]
    pub fn get_collection(&self, name: &str) -> Result<VectorDB, JsValue> {
        let manager = self.inner.lock();

        let collection_ref = manager
            .get_collection(name)
            .ok_or_else(|| JsValue::from_str(&format!("Collection '{}' not found", name)))?;

        let collection = collection_ref.read();

        // Create a new VectorDB wrapper that shares the underlying database
        // Note: For WASM, we'll need to clone the DB state since we can't share references across WASM boundary
        // This is a simplified version - in production you might want a different approach
        let dimensions = collection.config.dimensions;
        let db_name = collection.name.clone();

        // For now, return a new VectorDB with the same config
        // In a real implementation, you'd want to share the underlying storage
        let db_options = DbOptions {
            dimensions: collection.config.dimensions,
            distance_metric: collection.config.distance_metric,
            storage_path: ":memory:".to_string(),
            hnsw_config: collection.config.hnsw_config.clone(),
            quantization: collection.config.quantization.clone(),
        };

        let db = CoreVectorDB::new(db_options)
            .map_err(|e| JsValue::from_str(&format!("Failed to get collection: {}", e)))?;

        Ok(VectorDB {
            db: Arc::new(Mutex::new(db)),
            dimensions,
            db_name,
        })
    }

    /// Create an alias
    ///
    /// # Arguments
    /// * `alias` - Alias name (must be unique)
    /// * `collection` - Target collection name
    #[wasm_bindgen(js_name = createAlias)]
    pub fn create_alias(&self, alias: &str, collection: &str) -> Result<(), JsValue> {
        let manager = self.inner.lock();
        manager
            .create_alias(alias, collection)
            .map_err(|e| JsValue::from_str(&format!("Failed to create alias: {}", e)))?;

        Ok(())
    }

    /// Delete an alias
    ///
    /// # Arguments
    /// * `alias` - Alias name to delete
    #[wasm_bindgen(js_name = deleteAlias)]
    pub fn delete_alias(&self, alias: &str) -> Result<(), JsValue> {
        let manager = self.inner.lock();
        manager
            .delete_alias(alias)
            .map_err(|e| JsValue::from_str(&format!("Failed to delete alias: {}", e)))?;

        Ok(())
    }

    /// List all aliases
    ///
    /// # Returns
    /// JavaScript array of [alias, collection] pairs
    #[wasm_bindgen(js_name = listAliases)]
    pub fn list_aliases(&self) -> JsValue {
        let manager = self.inner.lock();
        let aliases = manager.list_aliases();

        // Build a JS Array of two-element [alias, collection] arrays so the
        // result round-trips through the WASM boundary as plain JS data.
        let arr = Array::new();
        for (alias, collection) in aliases {
            let pair = Array::new();
            pair.push(&JsValue::from_str(&alias));
            pair.push(&JsValue::from_str(&collection));
            arr.push(&pair);
        }

        arr.into()
    }
}
|
||||
|
||||
// ===== Filter Builder =====
|
||||
|
||||
#[cfg(feature = "collections")]
/// JavaScript-compatible filter builder
///
/// Wraps a core filter expression; the static builder methods each return a
/// new `FilterBuilder`, and `toJson()` produces the object accepted by
/// `search`.
#[wasm_bindgen]
pub struct FilterBuilder {
    // The accumulated core filter expression.
    inner: CoreFilterExpression,
}
|
||||
|
||||
#[cfg(feature = "collections")]
#[wasm_bindgen]
impl FilterBuilder {
    /// Create a new empty filter builder
    ///
    /// NOTE(review): the "empty" filter is implemented as `exists("_id")`,
    /// which only matches entries carrying an `_id` field — confirm this is
    /// a safe match-all for every payload shape before relying on it.
    #[wasm_bindgen(constructor)]
    pub fn new() -> FilterBuilder {
        // Default to a match-all filter (we'll use exists on a common field)
        // Users should use the builder methods instead
        FilterBuilder {
            inner: CoreFilterExpression::exists("_id"),
        }
    }

    /// Create an equality filter
    ///
    /// # Arguments
    /// * `field` - Field name
    /// * `value` - Value to match (will be converted from JS)
    ///
    /// # Errors
    /// Fails when `value` cannot be converted to a JSON value.
    ///
    /// # Example
    /// ```javascript
    /// const filter = FilterBuilder.eq("status", "active");
    /// ```
    pub fn eq(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::eq(field, json_value),
        })
    }

    /// Create a not-equal filter
    pub fn ne(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::ne(field, json_value),
        })
    }

    /// Create a greater-than filter
    pub fn gt(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::gt(field, json_value),
        })
    }

    /// Create a greater-than-or-equal filter
    pub fn gte(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::gte(field, json_value),
        })
    }

    /// Create a less-than filter
    pub fn lt(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::lt(field, json_value),
        })
    }

    /// Create a less-than-or-equal filter
    pub fn lte(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::lte(field, json_value),
        })
    }

    /// Create an IN filter (field matches any of the values)
    ///
    /// # Arguments
    /// * `field` - Field name
    /// * `values` - Array of values
    ///
    /// # Errors
    /// Fails when `values` is not convertible to a JSON array.
    #[wasm_bindgen(js_name = "in")]
    pub fn in_values(field: &str, values: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_values: Vec<serde_json::Value> = from_value(values)
            .map_err(|e| JsValue::from_str(&format!("Invalid values array: {}", e)))?;

        Ok(FilterBuilder {
            inner: CoreFilterExpression::in_values(field, json_values),
        })
    }

    /// Create a text match filter
    ///
    /// # Arguments
    /// * `field` - Field name
    /// * `text` - Text to search for
    #[wasm_bindgen(js_name = matchText)]
    pub fn match_text(field: &str, text: &str) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::match_text(field, text),
        }
    }

    /// Create a geo radius filter
    ///
    /// # Arguments
    /// * `field` - Field name (should contain {lat, lon} object)
    /// * `lat` - Center latitude
    /// * `lon` - Center longitude
    /// * `radius_m` - Radius in meters
    #[wasm_bindgen(js_name = geoRadius)]
    pub fn geo_radius(field: &str, lat: f64, lon: f64, radius_m: f64) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::geo_radius(field, lat, lon, radius_m),
        }
    }

    /// Combine filters with AND
    ///
    /// # Arguments
    /// * `filters` - Array of FilterBuilder instances (consumed by this call)
    pub fn and(filters: Vec<FilterBuilder>) -> FilterBuilder {
        let inner_filters: Vec<CoreFilterExpression> =
            filters.into_iter().map(|f| f.inner).collect();

        FilterBuilder {
            inner: CoreFilterExpression::and(inner_filters),
        }
    }

    /// Combine filters with OR
    ///
    /// # Arguments
    /// * `filters` - Array of FilterBuilder instances (consumed by this call)
    pub fn or(filters: Vec<FilterBuilder>) -> FilterBuilder {
        let inner_filters: Vec<CoreFilterExpression> =
            filters.into_iter().map(|f| f.inner).collect();

        FilterBuilder {
            inner: CoreFilterExpression::or(inner_filters),
        }
    }

    /// Negate a filter with NOT
    ///
    /// # Arguments
    /// * `filter` - FilterBuilder instance to negate (consumed)
    pub fn not(filter: FilterBuilder) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::not(filter.inner),
        }
    }

    /// Create an EXISTS filter (field is present)
    pub fn exists(field: &str) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::exists(field),
        }
    }

    /// Create an IS NULL filter (field is null)
    #[wasm_bindgen(js_name = isNull)]
    pub fn is_null(field: &str) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::is_null(field),
        }
    }

    /// Convert to JSON for use with search
    ///
    /// # Returns
    /// JavaScript object representing the filter
    ///
    /// # Errors
    /// Fails if the core expression cannot be serialized.
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<JsValue, JsValue> {
        to_value(&self.inner)
            .map_err(|e| JsValue::from_str(&format!("Failed to serialize filter: {}", e)))
    }

    /// Get all field names referenced in this filter
    #[wasm_bindgen(js_name = getFields)]
    pub fn get_fields(&self) -> Vec<String> {
        self.inner.get_fields()
    }
}
|
||||
|
||||
#[cfg(feature = "collections")]
impl Default for FilterBuilder {
    /// Delegates to [`FilterBuilder::new`] (the `exists("_id")` filter).
    fn default() -> Self {
        FilterBuilder::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    // These tests require a browser-hosted wasm runtime (wasm-pack test).
    wasm_bindgen_test_configure!(run_in_browser);

    /// The embedded crate version string must be non-empty.
    #[wasm_bindgen_test]
    fn test_version() {
        assert!(!version().is_empty());
    }

    /// SIMD detection is a compile-time check; just ensure it runs.
    #[wasm_bindgen_test]
    fn test_detect_simd() {
        // Just ensure it doesn't panic
        let _ = detect_simd();
    }
}
|
||||
254
vendor/ruvector/crates/ruvector-wasm/src/worker-pool.js
vendored
Normal file
254
vendor/ruvector/crates/ruvector-wasm/src/worker-pool.js
vendored
Normal file
@@ -0,0 +1,254 @@
|
||||
/**
|
||||
* Web Worker Pool Manager
|
||||
*
|
||||
* Manages a pool of workers for parallel vector operations.
|
||||
* Supports:
|
||||
* - Round-robin task distribution
|
||||
* - Load balancing
|
||||
* - Automatic worker initialization
|
||||
* - Promise-based API
|
||||
*/
|
||||
|
||||
export class WorkerPool {
  /**
   * @param {string} workerUrl - URL of the worker module script
   * @param {string} wasmUrl - URL of the WASM bundle each worker loads
   * @param {object} [options] - poolSize, dimensions, metric, useHnsw
   */
  constructor(workerUrl, wasmUrl, options = {}) {
    this.workerUrl = workerUrl;
    this.wasmUrl = wasmUrl;
    this.poolSize = options.poolSize || navigator.hardwareConcurrency || 4;
    this.workers = [];
    this.nextWorker = 0;
    this.pendingRequests = new Map();
    this.requestId = 0;
    this.initialized = false;
    this.options = options;
  }

  /**
   * Initialize the worker pool: spawn workers and load WASM into each.
   */
  async init() {
    if (this.initialized) return;

    console.log(`Initializing worker pool with ${this.poolSize} workers...`);

    const initPromises = [];

    for (let i = 0; i < this.poolSize; i++) {
      const worker = new Worker(this.workerUrl, { type: 'module' });

      worker.onmessage = (e) => this.handleMessage(i, e);
      worker.onerror = (error) => this.handleError(i, error);

      this.workers.push({
        worker,
        busy: false,
        id: i
      });

      // Initialize worker with WASM
      const initPromise = this.sendToWorker(i, 'init', {
        wasmUrl: this.wasmUrl,
        dimensions: this.options.dimensions,
        metric: this.options.metric,
        useHnsw: this.options.useHnsw
      });

      initPromises.push(initPromise);
    }

    await Promise.all(initPromises);
    this.initialized = true;

    console.log(`Worker pool initialized successfully`);
  }

  /**
   * Handle a reply from a worker: settle the matching pending promise.
   *
   * BUG FIX: the worker's busy flag is now cleared on BOTH the success and
   * error paths (previously an error reply left the worker marked busy
   * forever, starving the pool).
   */
  handleMessage(workerId, event) {
    const { type, requestId, data, error } = event.data;

    const request = this.pendingRequests.get(requestId);
    if (!request) return; // stale reply (e.g. the request already timed out)

    // The worker is idle again regardless of outcome.
    this.workers[workerId].busy = false;
    this.pendingRequests.delete(requestId);
    if (request.timer) clearTimeout(request.timer); // cancel the timeout guard

    if (type === 'error') {
      request.reject(new Error(error.message));
    } else {
      request.resolve(data);
    }
  }

  /**
   * Handle a fatal worker error: fail every request routed to that worker.
   */
  handleError(workerId, error) {
    console.error(`Worker ${workerId} error:`, error);

    // BUG FIX: free the worker so the scheduler does not avoid it forever.
    this.workers[workerId].busy = false;

    // Reject all pending requests for this worker
    for (const [requestId, request] of this.pendingRequests) {
      if (request.workerId === workerId) {
        if (request.timer) clearTimeout(request.timer);
        request.reject(error);
        this.pendingRequests.delete(requestId);
      }
    }
  }

  /**
   * Get next available worker (round-robin, preferring idle workers).
   */
  getNextWorker() {
    // Try to find an idle worker
    for (let i = 0; i < this.workers.length; i++) {
      const idx = (this.nextWorker + i) % this.workers.length;
      if (!this.workers[idx].busy) {
        this.nextWorker = (idx + 1) % this.workers.length;
        return idx;
      }
    }

    // All busy, use round-robin
    const idx = this.nextWorker;
    this.nextWorker = (this.nextWorker + 1) % this.workers.length;
    return idx;
  }

  /**
   * Send a message to a specific worker and await its reply.
   *
   * Each request carries a unique `requestId`; replies are matched back via
   * `pendingRequests`. Requests time out after 30 seconds.
   */
  sendToWorker(workerId, type, data) {
    return new Promise((resolve, reject) => {
      const requestId = this.requestId++;

      // Timeout guard; the timer is cancelled when a reply arrives.
      const timer = setTimeout(() => {
        if (this.pendingRequests.has(requestId)) {
          this.pendingRequests.delete(requestId);
          // BUG FIX: a timed-out request must not leak the busy flag.
          this.workers[workerId].busy = false;
          reject(new Error('Request timeout'));
        }
      }, 30000);

      this.pendingRequests.set(requestId, {
        resolve,
        reject,
        workerId,
        timer,
        timestamp: Date.now()
      });

      this.workers[workerId].busy = true;
      this.workers[workerId].worker.postMessage({
        type,
        data: { ...data, requestId }
      });
    });
  }

  /**
   * Execute an operation on the next available worker.
   */
  async execute(type, data) {
    if (!this.initialized) {
      await this.init();
    }

    const workerId = this.getNextWorker();
    return this.sendToWorker(workerId, type, data);
  }

  /**
   * Insert a single vector.
   */
  async insert(vector, id = null, metadata = null) {
    return this.execute('insert', { vector, id, metadata });
  }

  /**
   * Insert a batch of vectors, distributed across the pool.
   */
  async insertBatch(entries) {
    // Distribute batch across workers
    const chunkSize = Math.ceil(entries.length / this.poolSize);
    const chunks = [];

    for (let i = 0; i < entries.length; i += chunkSize) {
      chunks.push(entries.slice(i, i + chunkSize));
    }

    const promises = chunks.map((chunk, i) =>
      this.sendToWorker(i % this.poolSize, 'insertBatch', { entries: chunk })
    );

    const results = await Promise.all(promises);
    return results.flat();
  }

  /**
   * Search for similar vectors.
   */
  async search(query, k = 10, filter = null) {
    return this.execute('search', { query, k, filter });
  }

  /**
   * Parallel search across multiple queries.
   */
  async searchBatch(queries, k = 10, filter = null) {
    const promises = queries.map((query, i) =>
      this.sendToWorker(i % this.poolSize, 'search', { query, k, filter })
    );

    return Promise.all(promises);
  }

  /**
   * Delete a vector by ID.
   */
  async delete(id) {
    return this.execute('delete', { id });
  }

  /**
   * Get a vector by ID.
   */
  async get(id) {
    return this.execute('get', { id });
  }

  /**
   * Get database length (from first worker).
   */
  async len() {
    return this.sendToWorker(0, 'len', {});
  }

  /**
   * Terminate all workers and reset pool state.
   */
  terminate() {
    for (const { worker } of this.workers) {
      worker.terminate();
    }
    this.workers = [];
    this.initialized = false;
    console.log('Worker pool terminated');
  }

  /**
   * Get pool statistics (sizes and queue depth).
   */
  getStats() {
    return {
      poolSize: this.poolSize,
      busyWorkers: this.workers.filter(w => w.busy).length,
      idleWorkers: this.workers.filter(w => !w.busy).length,
      pendingRequests: this.pendingRequests.size
    };
  }
}
|
||||
|
||||
export default WorkerPool;
|
||||
184
vendor/ruvector/crates/ruvector-wasm/src/worker.js
vendored
Normal file
184
vendor/ruvector/crates/ruvector-wasm/src/worker.js
vendored
Normal file
@@ -0,0 +1,184 @@
|
||||
/**
|
||||
* Web Worker for parallel vector search operations
|
||||
*
|
||||
* This worker handles:
|
||||
* - Vector search operations in parallel
|
||||
* - Batch insert operations
|
||||
* - Zero-copy transfers via transferable objects
|
||||
*/
|
||||
|
||||
// Import the WASM module
// Module-level singletons: populated by the 'init' message and reused by
// every subsequent request handled in this worker.
let wasmModule = null;
let vectorDB = null;
|
||||
|
||||
/**
 * Message dispatcher: routes incoming requests to the matching handler.
 *
 * BUG FIX: every reply now echoes back `requestId` — previously the 'init'
 * and 'len' responses (and the catch-all error response) omitted it, so the
 * pool's pending promises for those requests could never settle and hung
 * until the 30s timeout.
 */
self.onmessage = async function(e) {
  const { type, data } = e.data;

  try {
    switch (type) {
      case 'init':
        await initWorker(data);
        self.postMessage({ type: 'init', requestId: data.requestId, data: true, success: true });
        break;

      case 'insert':
        await handleInsert(data);
        break;

      case 'insertBatch':
        await handleInsertBatch(data);
        break;

      case 'search':
        await handleSearch(data);
        break;

      case 'delete':
        await handleDelete(data);
        break;

      case 'get':
        await handleGet(data);
        break;

      case 'len': {
        // Block scope keeps this lexical declaration local to the case.
        const length = vectorDB.len();
        self.postMessage({ type: 'len', requestId: data.requestId, data: length });
        break;
      }

      default:
        throw new Error(`Unknown message type: ${type}`);
    }
  } catch (error) {
    self.postMessage({
      type: 'error',
      // Echo the requestId so the pool rejects the right promise.
      requestId: data && data.requestId,
      error: {
        message: error.message,
        stack: error.stack
      }
    });
  }
};
|
||||
|
||||
/**
 * Initialize WASM module and VectorDB.
 *
 * Loads the WASM bundle from `wasmUrl`, runs its default init export, and
 * constructs this worker's VectorDB instance used by every handler.
 */
async function initWorker(config) {
  const { wasmUrl, dimensions, metric, useHnsw } = config;

  // Dynamically import the WASM glue module, then run its initializer.
  wasmModule = await import(wasmUrl);
  await wasmModule.default();

  // Each worker owns an independent VectorDB instance.
  vectorDB = new wasmModule.VectorDB(dimensions, metric, useHnsw);

  console.log(`Worker initialized with dimensions=${dimensions}, metric=${metric}, SIMD=${wasmModule.detectSIMD()}`);
}
|
||||
|
||||
/**
 * Handle single vector insert.
 */
async function handleInsert(data) {
  const { vector, id, metadata, requestId } = data;

  // Plain JS arrays are coerced into typed arrays before crossing into WASM.
  const resultId = vectorDB.insert(new Float32Array(vector), id, metadata);

  self.postMessage({ type: 'insert', requestId, data: resultId });
}
|
||||
|
||||
/**
 * Handle batch insert.
 */
async function handleInsertBatch(data) {
  const { entries, requestId } = data;

  // Coerce each entry's vector into a Float32Array for the WASM boundary.
  const processedEntries = entries.map(({ vector, id, metadata }) => ({
    vector: new Float32Array(vector),
    id,
    metadata
  }));

  const ids = vectorDB.insertBatch(processedEntries);

  self.postMessage({ type: 'insertBatch', requestId, data: ids });
}
|
||||
|
||||
/**
 * Handle vector search.
 */
async function handleSearch(data) {
  const { query, k, filter, requestId } = data;

  const results = vectorDB.search(new Float32Array(query), k, filter);

  // Flatten WASM result objects into structured-clone-friendly plain objects.
  const plainResults = results.map((result) => ({
    id: result.id,
    score: result.score,
    vector: result.vector ? Array.from(result.vector) : null,
    metadata: result.metadata
  }));

  self.postMessage({ type: 'search', requestId, data: plainResults });
}
|
||||
|
||||
/**
 * Handle delete operation.
 */
async function handleDelete(data) {
  const { id, requestId } = data;

  // Reply with whether the ID existed and was removed.
  self.postMessage({ type: 'delete', requestId, data: vectorDB.delete(id) });
}
|
||||
|
||||
/**
 * Handle get operation.
 */
async function handleGet(data) {
  const { id, requestId } = data;

  const entry = vectorDB.get(id);

  // null when not found; otherwise a plain object safe for postMessage.
  const plainEntry = entry
    ? { id: entry.id, vector: Array.from(entry.vector), metadata: entry.metadata }
    : null;

  self.postMessage({ type: 'get', requestId, data: plainEntry });
}
|
||||
160
vendor/ruvector/crates/ruvector-wasm/tests/wasm.rs
vendored
Normal file
160
vendor/ruvector/crates/ruvector-wasm/tests/wasm.rs
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
//! WASM-specific tests
|
||||
|
||||
#![cfg(target_arch = "wasm32")]
|
||||
|
||||
use js_sys::Float32Array;
|
||||
use ruvector_wasm::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_vector_db_creation() {
|
||||
let db = VectorDB::new(128, Some("cosine".to_string()), Some(false));
|
||||
assert!(db.is_ok());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_insert_and_search() {
|
||||
let db = VectorDB::new(3, Some("euclidean".to_string()), Some(false)).unwrap();
|
||||
|
||||
// Insert a vector
|
||||
let vector = Float32Array::from(&[1.0, 0.0, 0.0][..]);
|
||||
let id = db.insert(vector, Some("test1".to_string()), None);
|
||||
assert!(id.is_ok());
|
||||
|
||||
// Search
|
||||
let query = Float32Array::from(&[1.0, 0.0, 0.0][..]);
|
||||
let results = db.search(query, 1, None);
|
||||
assert!(results.is_ok());
|
||||
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_batch_insert() {
|
||||
let db = VectorDB::new(3, Some("cosine".to_string()), Some(false)).unwrap();
|
||||
|
||||
let entries = js_sys::Array::new();
|
||||
|
||||
for i in 0..10 {
|
||||
let entry = js_sys::Object::new();
|
||||
let vector = Float32Array::from(&[i as f32, 0.0, 0.0][..]);
|
||||
js_sys::Reflect::set(&entry, &"vector".into(), &vector).unwrap();
|
||||
js_sys::Reflect::set(&entry, &"id".into(), &format!("vec_{}", i).into()).unwrap();
|
||||
entries.push(&entry);
|
||||
}
|
||||
|
||||
let result = db.insert_batch(entries.into());
|
||||
assert!(result.is_ok());
|
||||
|
||||
let ids = result.unwrap();
|
||||
assert_eq!(ids.len(), 10);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_delete() {
|
||||
let db = VectorDB::new(3, Some("cosine".to_string()), Some(false)).unwrap();
|
||||
|
||||
// Insert
|
||||
let vector = Float32Array::from(&[1.0, 0.0, 0.0][..]);
|
||||
let id = db
|
||||
.insert(vector, Some("test_delete".to_string()), None)
|
||||
.unwrap();
|
||||
|
||||
// Delete
|
||||
let deleted = db.delete(&id);
|
||||
assert!(deleted.is_ok());
|
||||
assert_eq!(deleted.unwrap(), true);
|
||||
|
||||
// Verify deleted
|
||||
let get_result = db.get(&id);
|
||||
assert!(get_result.is_ok());
|
||||
assert!(get_result.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_get() {
|
||||
let db = VectorDB::new(3, Some("cosine".to_string()), Some(false)).unwrap();
|
||||
|
||||
// Insert
|
||||
let vector = Float32Array::from(&[1.0, 2.0, 3.0][..]);
|
||||
let id = db
|
||||
.insert(vector, Some("test_get".to_string()), None)
|
||||
.unwrap();
|
||||
|
||||
// Get
|
||||
let entry = db.get(&id);
|
||||
assert!(entry.is_ok());
|
||||
|
||||
let entry = entry.unwrap();
|
||||
assert!(entry.is_some());
|
||||
|
||||
let entry = entry.unwrap();
|
||||
assert_eq!(entry.id(), Some("test_get".to_string()));
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_len_and_is_empty() {
|
||||
let db = VectorDB::new(3, Some("cosine".to_string()), Some(false)).unwrap();
|
||||
|
||||
// Initially empty
|
||||
assert!(db.is_empty().unwrap());
|
||||
assert_eq!(db.len().unwrap(), 0);
|
||||
|
||||
// Insert vector
|
||||
let vector = Float32Array::from(&[1.0, 0.0, 0.0][..]);
|
||||
db.insert(vector, Some("test1".to_string()), None).unwrap();
|
||||
|
||||
// Not empty
|
||||
assert!(!db.is_empty().unwrap());
|
||||
assert_eq!(db.len().unwrap(), 1);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_different_metrics() {
|
||||
for metric in &["euclidean", "cosine", "dotproduct", "manhattan"] {
|
||||
let db = VectorDB::new(3, Some(metric.to_string()), Some(false));
|
||||
assert!(db.is_ok(), "Failed to create DB with metric: {}", metric);
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_dimension_mismatch() {
|
||||
let db = VectorDB::new(3, Some("cosine".to_string()), Some(false)).unwrap();
|
||||
|
||||
// Try to insert vector with wrong dimensions
|
||||
let vector = Float32Array::from(&[1.0, 0.0][..]); // Only 2 dimensions
|
||||
let result = db.insert(vector, Some("test_wrong_dim".to_string()), None);
|
||||
|
||||
// Should fail due to dimension mismatch
|
||||
// Note: This might succeed depending on implementation
|
||||
// The search with wrong dimensions should definitely fail
|
||||
let query = Float32Array::from(&[1.0, 0.0][..]);
|
||||
let search_result = db.search(query, 1, None);
|
||||
assert!(search_result.is_err());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_version() {
|
||||
let v = version();
|
||||
assert!(!v.is_empty());
|
||||
assert!(v.contains('.'));
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_detect_simd() {
|
||||
// Just ensure it doesn't panic
|
||||
let _ = detect_simd();
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_array_to_float32_array() {
|
||||
let arr = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let float_arr = array_to_float32_array(arr.clone());
|
||||
|
||||
assert_eq!(float_arr.length(), 4);
|
||||
assert_eq!(float_arr.get_index(0), 1.0);
|
||||
assert_eq!(float_arr.get_index(3), 4.0);
|
||||
}
|
||||
Reference in New Issue
Block a user