Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,355 @@
/**
 * IndexedDB Persistence Layer for Ruvector
 *
 * Provides:
 * - Save/load database state to IndexedDB
 * - Batch operations for performance
 * - Progressive loading with pagination
 * - LRU cache for hot vectors
 */
// Default database name (overridable via the IndexedDBPersistence constructor).
const DB_NAME = 'ruvector_storage';
// Schema version passed to indexedDB.open(); bump when object stores change.
const DB_VERSION = 1;
// Object store holding vector records ({ id, vector, metadata, timestamp }).
const VECTOR_STORE = 'vectors';
// Object store holding key/value metadata ({ key, value }).
const META_STORE = 'metadata';
/**
* LRU Cache for hot vectors
*/
/**
 * LRU Cache for hot vectors.
 *
 * Backed by a Map, whose iteration order is insertion order: the first
 * key is always the least recently used entry, so eviction is O(1).
 */
class LRUCache {
  /**
   * @param {number} capacity - Maximum number of entries kept (default 1000).
   */
  constructor(capacity = 1000) {
    this.capacity = capacity;
    this.cache = new Map();
  }
  /**
   * Look up a key, marking it as most recently used.
   * @returns {*} The cached value, or null when absent.
   */
  get(key) {
    if (!this.cache.has(key)) return null;
    // Re-insert to move the key to the end (most recently used).
    const value = this.cache.get(key);
    this.cache.delete(key);
    this.cache.set(key, value);
    return value;
  }
  /**
   * Insert or refresh a key, evicting the least recently used entry
   * when capacity is exceeded.
   */
  set(key, value) {
    // Remove first so re-insertion places the key at the MRU end.
    if (this.cache.has(key)) {
      this.cache.delete(key);
    }
    this.cache.set(key, value);
    // Evict the oldest entry (first Map key) if over capacity.
    if (this.cache.size > this.capacity) {
      const firstKey = this.cache.keys().next().value;
      this.cache.delete(firstKey);
    }
  }
  /**
   * Remove a key from the cache.
   *
   * Fix: IndexedDBPersistence.deleteVector calls cache.delete(id); this
   * method did not exist before, so that call threw a TypeError.
   * @returns {boolean} true when the key was present.
   */
  delete(key) {
    return this.cache.delete(key);
  }
  /** @returns {boolean} true when the key is cached (does not touch LRU order). */
  has(key) {
    return this.cache.has(key);
  }
  /** Drop all entries. */
  clear() {
    this.cache.clear();
  }
  /** @returns {number} Current number of cached entries. */
  get size() {
    return this.cache.size;
  }
}
/**
* IndexedDB Persistence Manager
*/
/**
 * IndexedDB Persistence Manager
 *
 * Persists vectors and metadata to IndexedDB with an LRU cache of hot
 * vectors in front of it. Vectors are stored as plain arrays (IndexedDB
 * structured-clone friendly) and handed back as Float32Array.
 */
export class IndexedDBPersistence {
  /**
   * @param {string|null} dbName - Database name; defaults to DB_NAME.
   */
  constructor(dbName = null) {
    this.dbName = dbName || DB_NAME;
    this.db = null;
    this.cache = new LRUCache(1000);
  }
  /**
   * Open IndexedDB connection, creating object stores on first use.
   * @returns {Promise<IDBDatabase>}
   */
  async open() {
    return new Promise((resolve, reject) => {
      const request = indexedDB.open(this.dbName, DB_VERSION);
      request.onerror = () => reject(request.error);
      request.onsuccess = () => {
        this.db = request.result;
        resolve(this.db);
      };
      request.onupgradeneeded = (event) => {
        const db = event.target.result;
        // Create object stores if they don't exist
        if (!db.objectStoreNames.contains(VECTOR_STORE)) {
          const vectorStore = db.createObjectStore(VECTOR_STORE, { keyPath: 'id' });
          vectorStore.createIndex('timestamp', 'timestamp', { unique: false });
        }
        if (!db.objectStoreNames.contains(META_STORE)) {
          db.createObjectStore(META_STORE, { keyPath: 'key' });
        }
      };
    });
  }
  /**
   * Save a single vector.
   * @param {string} id - Primary key.
   * @param {Float32Array|number[]} vector - Vector data.
   * @param {*} [metadata] - Optional metadata stored alongside.
   * @returns {Promise<string>} Resolves with the saved id.
   */
  async saveVector(id, vector, metadata = null) {
    if (!this.db) await this.open();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
      const store = transaction.objectStore(VECTOR_STORE);
      const data = {
        id,
        vector: Array.from(vector), // Convert Float32Array to regular array
        metadata,
        timestamp: Date.now()
      };
      const request = store.put(data);
      request.onsuccess = () => {
        // Fix: cache the typed-array form so a later cache hit in
        // loadVector() has the same shape as a record loaded from the DB.
        this.cache.set(id, { ...data, vector: new Float32Array(data.vector) });
        resolve(id);
      };
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Save vectors in batch (one transaction per chunk — far fewer
   * transaction commits than per-vector saves).
   * @param {Array<{id: string, vector: Float32Array|number[], metadata?: *}>} entries
   * @param {number} [batchSize] - Entries per transaction (default 100).
   * @returns {Promise<number>} Number of entries written.
   */
  async saveBatch(entries, batchSize = 100) {
    if (!this.db) await this.open();
    const chunks = [];
    for (let i = 0; i < entries.length; i += batchSize) {
      chunks.push(entries.slice(i, i + batchSize));
    }
    for (const chunk of chunks) {
      await new Promise((resolve, reject) => {
        const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
        const store = transaction.objectStore(VECTOR_STORE);
        for (const entry of chunk) {
          const data = {
            id: entry.id,
            vector: Array.from(entry.vector),
            metadata: entry.metadata,
            timestamp: Date.now()
          };
          store.put(data);
          // Fix: cache the typed-array form (see saveVector).
          this.cache.set(entry.id, { ...data, vector: new Float32Array(data.vector) });
        }
        transaction.oncomplete = () => resolve();
        transaction.onerror = () => reject(transaction.error);
      });
    }
    return entries.length;
  }
  /**
   * Load a single vector by ID (cache first, then IndexedDB).
   * @returns {Promise<object|undefined>} Record with vector as Float32Array,
   *   or undefined when not found.
   */
  async loadVector(id) {
    // Check cache first
    if (this.cache.has(id)) {
      return this.cache.get(id);
    }
    if (!this.db) await this.open();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readonly');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.get(id);
      request.onsuccess = () => {
        const data = request.result;
        if (data) {
          // Convert stored array back to Float32Array
          data.vector = new Float32Array(data.vector);
          this.cache.set(id, data);
        }
        resolve(data);
      };
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Load all vectors with progressive loading.
   *
   * When onProgress is supplied, vectors are delivered in batches through
   * the callback and the batch buffer is reused; the resolved value only
   * reports the count.
   * @param {?function({loaded: number, vectors: object[], complete?: boolean})} onProgress
   * @param {number} [batchSize] - Callback cadence (default 100).
   * @returns {Promise<{count: number, complete: boolean}>}
   */
  async loadAll(onProgress = null, batchSize = 100) {
    if (!this.db) await this.open();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readonly');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.openCursor();
      const vectors = [];
      let count = 0;
      request.onsuccess = (event) => {
        const cursor = event.target.result;
        if (cursor) {
          const data = cursor.value;
          data.vector = new Float32Array(data.vector);
          vectors.push(data);
          count++;
          // Cache hot vectors (first 1000)
          if (count <= 1000) {
            this.cache.set(data.id, data);
          }
          // Report progress every batch
          if (onProgress && count % batchSize === 0) {
            onProgress({
              loaded: count,
              vectors: [...vectors]
            });
            vectors.length = 0; // Clear batch
          }
          cursor.continue();
        } else {
          // Done — flush any partial final batch
          if (onProgress && vectors.length > 0) {
            onProgress({
              loaded: count,
              vectors: vectors,
              complete: true
            });
          }
          resolve({ count, complete: true });
        }
      };
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Delete a vector by ID from both the cache and the store.
   * @returns {Promise<boolean>} Resolves true on success.
   */
  async deleteVector(id) {
    if (!this.db) await this.open();
    // Fix: optional call — the LRU cache originally shipped without a
    // delete() method, making this line throw a TypeError.
    this.cache.delete?.(id);
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.delete(id);
      request.onsuccess = () => resolve(true);
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Clear all vectors (cache and store).
   * @returns {Promise<void>}
   */
  async clear() {
    if (!this.db) await this.open();
    this.cache.clear();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readwrite');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.clear();
      request.onsuccess = () => resolve();
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Get database statistics.
   * @returns {Promise<{totalVectors: number, cacheSize: number, cacheHitRate: number}>}
   */
  async getStats() {
    if (!this.db) await this.open();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([VECTOR_STORE], 'readonly');
      const store = transaction.objectStore(VECTOR_STORE);
      const request = store.count();
      request.onsuccess = () => {
        const totalVectors = request.result;
        resolve({
          totalVectors,
          cacheSize: this.cache.size,
          // NOTE: this is really the cache coverage ratio (cached / total),
          // not a true hit rate; key name kept for backward compatibility.
          // Fix: guard against division by zero (NaN/Infinity) on an empty store.
          cacheHitRate: totalVectors > 0 ? this.cache.size / totalVectors : 0
        });
      };
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Save a metadata key/value pair.
   * @returns {Promise<void>}
   */
  async saveMeta(key, value) {
    if (!this.db) await this.open();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([META_STORE], 'readwrite');
      const store = transaction.objectStore(META_STORE);
      const request = store.put({ key, value });
      request.onsuccess = () => resolve();
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Load a metadata value by key.
   * @returns {Promise<*>} The stored value, or null when absent.
   */
  async loadMeta(key) {
    if (!this.db) await this.open();
    return new Promise((resolve, reject) => {
      const transaction = this.db.transaction([META_STORE], 'readonly');
      const store = transaction.objectStore(META_STORE);
      const request = store.get(key);
      request.onsuccess = () => {
        const data = request.result;
        resolve(data ? data.value : null);
      };
      request.onerror = () => reject(request.error);
    });
  }
  /**
   * Close the database connection. Safe to call repeatedly.
   */
  close() {
    if (this.db) {
      this.db.close();
      this.db = null;
    }
  }
}
export default IndexedDBPersistence;

View File

@@ -0,0 +1,334 @@
//! Trusted Kernel Allowlist
//!
//! Maintains a list of approved kernel hashes for additional security.
//! This provides defense-in-depth beyond signature verification.
use crate::kernel::error::VerifyError;
use std::collections::{HashMap, HashSet};
/// Trusted kernel allowlist
///
/// Maintains approved kernel hashes organized by kernel ID.
/// Even if a kernel has a valid signature, it must be in the allowlist
/// to be executed (when allowlist enforcement is enabled).
#[derive(Debug, Clone)]
pub struct TrustedKernelAllowlist {
    /// Set of approved kernel hashes (format: "sha256:...", stored lowercased
    /// by `add_hash`/`add_kernel_hash` so lookups are case-insensitive)
    approved_hashes: HashSet<String>,
    /// Map of kernel_id -> approved hashes for that kernel
    kernel_hashes: HashMap<String, HashSet<String>>,
    /// Whether to enforce allowlist (can be disabled for development;
    /// when false, every `is_allowed*` check returns true)
    enforce: bool,
    /// Allowlist version/update timestamp
    version: String,
}
impl TrustedKernelAllowlist {
    /// Create a new empty allowlist (enforcement enabled, version "1.0.0")
    pub fn new() -> Self {
        TrustedKernelAllowlist {
            approved_hashes: HashSet::new(),
            kernel_hashes: HashMap::new(),
            enforce: true,
            version: "1.0.0".to_string(),
        }
    }
    /// Create an allowlist that doesn't enforce checks (for development)
    ///
    /// # Warning
    /// This should NEVER be used in production.
    pub fn insecure_allow_all() -> Self {
        TrustedKernelAllowlist {
            approved_hashes: HashSet::new(),
            kernel_hashes: HashMap::new(),
            enforce: false,
            version: "dev".to_string(),
        }
    }
    /// Load allowlist from JSON
    ///
    /// Expected shape: `{"version": "...", "kernels": {"<id>": ["sha256:...", ...]}}`.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        #[derive(serde::Deserialize)]
        struct AllowlistJson {
            version: String,
            kernels: HashMap<String, Vec<String>>,
        }
        let parsed: AllowlistJson = serde_json::from_str(json)?;
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.version = parsed.version;
        for (kernel_id, hashes) in parsed.kernels {
            for hash in hashes {
                // Routes through add_kernel_hash so hashes get lowercased.
                allowlist.add_kernel_hash(&kernel_id, &hash);
            }
        }
        Ok(allowlist)
    }
    /// Serialize allowlist to JSON (same shape accepted by `from_json`)
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        #[derive(serde::Serialize)]
        struct AllowlistJson {
            version: String,
            kernels: HashMap<String, Vec<String>>,
        }
        let kernels: HashMap<String, Vec<String>> = self
            .kernel_hashes
            .iter()
            .map(|(k, v)| (k.clone(), v.iter().cloned().collect()))
            .collect();
        let json = AllowlistJson {
            version: self.version.clone(),
            kernels,
        };
        serde_json::to_string_pretty(&json)
    }
    /// Add a hash to the global approved set (stored lowercased)
    pub fn add_hash(&mut self, hash: &str) {
        self.approved_hashes.insert(hash.to_lowercase());
    }
    /// Add a hash for a specific kernel ID (also added to the global set)
    pub fn add_kernel_hash(&mut self, kernel_id: &str, hash: &str) {
        let lowercase_hash = hash.to_lowercase();
        self.approved_hashes.insert(lowercase_hash.clone());
        self.kernel_hashes
            .entry(kernel_id.to_string())
            .or_default() // idiomatic replacement for or_insert_with(HashSet::new)
            .insert(lowercase_hash);
    }
    /// Remove a hash from the global set and every kernel-specific set
    pub fn remove_hash(&mut self, hash: &str) {
        let lowercase_hash = hash.to_lowercase();
        self.approved_hashes.remove(&lowercase_hash);
        for hashes in self.kernel_hashes.values_mut() {
            hashes.remove(&lowercase_hash);
        }
    }
    /// Check if a hash is in the allowlist (always true when not enforcing)
    pub fn is_allowed(&self, hash: &str) -> bool {
        if !self.enforce {
            return true;
        }
        self.approved_hashes.contains(&hash.to_lowercase())
    }
    /// Check if a hash is allowed for a specific kernel ID
    ///
    /// A kernel-specific entry, when present, takes precedence over the
    /// global set; otherwise the global set is consulted.
    pub fn is_allowed_for_kernel(&self, kernel_id: &str, hash: &str) -> bool {
        if !self.enforce {
            return true;
        }
        let lowercase_hash = hash.to_lowercase();
        // Check kernel-specific allowlist first
        if let Some(kernel_hashes) = self.kernel_hashes.get(kernel_id) {
            return kernel_hashes.contains(&lowercase_hash);
        }
        // Fall back to global allowlist
        self.approved_hashes.contains(&lowercase_hash)
    }
    /// Verify a kernel is in the allowlist
    ///
    /// # Errors
    /// `VerifyError::NotInAllowlist` when the hash is not approved for the kernel.
    pub fn verify(&self, kernel_id: &str, hash: &str) -> Result<(), VerifyError> {
        if self.is_allowed_for_kernel(kernel_id, hash) {
            Ok(())
        } else {
            Err(VerifyError::NotInAllowlist {
                kernel_id: kernel_id.to_string(),
            })
        }
    }
    /// Get number of approved hashes in the global set
    pub fn hash_count(&self) -> usize {
        self.approved_hashes.len()
    }
    /// Get all approved hashes for a kernel ID (None if the kernel is unknown)
    pub fn get_kernel_hashes(&self, kernel_id: &str) -> Option<&HashSet<String>> {
        self.kernel_hashes.get(kernel_id)
    }
    /// List all kernel IDs with approved hashes (unordered)
    pub fn kernel_ids(&self) -> Vec<&str> {
        self.kernel_hashes.keys().map(|s| s.as_str()).collect()
    }
    /// Get allowlist version
    pub fn version(&self) -> &str {
        &self.version
    }
    /// Set allowlist version
    pub fn set_version(&mut self, version: &str) {
        self.version = version.to_string();
    }
    /// Check if enforcement is enabled
    pub fn is_enforced(&self) -> bool {
        self.enforce
    }
    /// Merge another allowlist into this one
    ///
    /// Takes the union of hashes; `self`'s `enforce` flag and `version`
    /// are kept unchanged.
    pub fn merge(&mut self, other: &TrustedKernelAllowlist) {
        self.approved_hashes
            .extend(other.approved_hashes.iter().cloned());
        for (kernel_id, hashes) in &other.kernel_hashes {
            self.kernel_hashes
                .entry(kernel_id.clone())
                .or_default()
                .extend(hashes.iter().cloned());
        }
    }
}
impl Default for TrustedKernelAllowlist {
    /// Equivalent to `TrustedKernelAllowlist::new()`: empty, enforcing allowlist.
    fn default() -> Self {
        Self::new()
    }
}
/// Built-in allowlist of official RuvLLM kernels
///
/// This provides a starting point with known-good kernel hashes.
/// Production deployments should maintain their own allowlist.
pub fn builtin_allowlist() -> TrustedKernelAllowlist {
    let mut builtin = TrustedKernelAllowlist::new();
    builtin.set_version("0.1.0-builtin");
    // Placeholder entries for official kernels; real hashes would be
    // registered here in a production build:
    // builtin.add_kernel_hash("rope_f32", "sha256:...");
    // builtin.add_kernel_hash("rmsnorm_f32", "sha256:...");
    // builtin.add_kernel_hash("swiglu_f32", "sha256:...");
    builtin
}
#[cfg(test)]
mod tests {
    // Covers hash registration, per-kernel lookup, verify(), removal,
    // JSON round-tripping, merging, and the non-enforcing dev mode.
    use super::*;
    #[test]
    fn test_add_and_check_hash() {
        let mut allowlist = TrustedKernelAllowlist::new();
        let hash = "sha256:abc123def456";
        assert!(!allowlist.is_allowed(hash));
        allowlist.add_hash(hash);
        assert!(allowlist.is_allowed(hash));
        // Case insensitive
        assert!(allowlist.is_allowed("SHA256:ABC123DEF456"));
    }
    #[test]
    fn test_kernel_specific_hash() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("rope_f32", "sha256:rope_hash");
        allowlist.add_kernel_hash("rmsnorm_f32", "sha256:rmsnorm_hash");
        // A kernel with its own entry only accepts its own hashes.
        assert!(allowlist.is_allowed_for_kernel("rope_f32", "sha256:rope_hash"));
        assert!(!allowlist.is_allowed_for_kernel("rope_f32", "sha256:rmsnorm_hash"));
        assert!(allowlist.is_allowed_for_kernel("rmsnorm_f32", "sha256:rmsnorm_hash"));
    }
    #[test]
    fn test_verify() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("rope_f32", "sha256:valid_hash");
        assert!(allowlist.verify("rope_f32", "sha256:valid_hash").is_ok());
        assert!(matches!(
            allowlist.verify("rope_f32", "sha256:invalid_hash"),
            Err(VerifyError::NotInAllowlist { .. })
        ));
    }
    #[test]
    fn test_insecure_allow_all() {
        let allowlist = TrustedKernelAllowlist::insecure_allow_all();
        // Should allow any hash when not enforcing
        assert!(allowlist.is_allowed("sha256:anything"));
        assert!(allowlist.is_allowed_for_kernel("any_kernel", "sha256:anything"));
        assert!(!allowlist.is_enforced());
    }
    #[test]
    fn test_remove_hash() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("kernel", "sha256:hash");
        assert!(allowlist.is_allowed("sha256:hash"));
        allowlist.remove_hash("sha256:hash");
        assert!(!allowlist.is_allowed("sha256:hash"));
    }
    #[test]
    fn test_json_roundtrip() {
        let mut original = TrustedKernelAllowlist::new();
        original.set_version("1.2.3");
        original.add_kernel_hash("rope_f32", "sha256:hash1");
        original.add_kernel_hash("rope_f32", "sha256:hash2");
        original.add_kernel_hash("rmsnorm_f32", "sha256:hash3");
        let json = original.to_json().unwrap();
        let restored = TrustedKernelAllowlist::from_json(&json).unwrap();
        assert_eq!(restored.version(), "1.2.3");
        assert!(restored.is_allowed_for_kernel("rope_f32", "sha256:hash1"));
        assert!(restored.is_allowed_for_kernel("rope_f32", "sha256:hash2"));
        assert!(restored.is_allowed_for_kernel("rmsnorm_f32", "sha256:hash3"));
    }
    #[test]
    fn test_merge() {
        let mut allowlist1 = TrustedKernelAllowlist::new();
        allowlist1.add_kernel_hash("kernel1", "sha256:hash1");
        let mut allowlist2 = TrustedKernelAllowlist::new();
        allowlist2.add_kernel_hash("kernel2", "sha256:hash2");
        allowlist1.merge(&allowlist2);
        assert!(allowlist1.is_allowed_for_kernel("kernel1", "sha256:hash1"));
        assert!(allowlist1.is_allowed_for_kernel("kernel2", "sha256:hash2"));
    }
    #[test]
    fn test_kernel_ids() {
        let mut allowlist = TrustedKernelAllowlist::new();
        allowlist.add_kernel_hash("kernel_a", "sha256:a");
        allowlist.add_kernel_hash("kernel_b", "sha256:b");
        let ids = allowlist.kernel_ids();
        assert!(ids.contains(&"kernel_a"));
        assert!(ids.contains(&"kernel_b"));
    }
}

View File

@@ -0,0 +1,314 @@
//! Epoch-Based Interruption
//!
//! Provides execution budget management using Wasmtime's epoch mechanism.
//! This allows coarse-grained interruption of WASM execution with minimal overhead.
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Epoch controller for managing execution budgets
///
/// The epoch mechanism works by periodically incrementing a counter.
/// WASM code checks this counter at certain points (function calls, loops)
/// and traps if the deadline has been exceeded.
///
/// Cloning shares the same counters (both fields are `Arc`-wrapped).
#[derive(Debug, Clone)]
pub struct EpochController {
    /// Current epoch value (shared with other holders via `epoch_counter()`)
    current_epoch: Arc<AtomicU64>,
    /// Tick interval
    tick_interval: Duration,
    /// Whether the controller is running
    /// NOTE(review): nothing in this module ever stores `true` here, so
    /// `is_running()` always reports false — presumably a background
    /// ticker thread elsewhere sets it; confirm.
    running: Arc<std::sync::atomic::AtomicBool>,
}
impl EpochController {
    /// Create a new epoch controller
    ///
    /// # Arguments
    /// * `tick_interval` - How often to increment the epoch (e.g., 10ms)
    pub fn new(tick_interval: Duration) -> Self {
        EpochController {
            current_epoch: Arc::new(AtomicU64::new(0)),
            tick_interval,
            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }
    /// Create with default 10ms tick interval
    pub fn default_interval() -> Self {
        Self::new(Duration::from_millis(10))
    }
    /// Get current epoch value
    pub fn current(&self) -> u64 {
        self.current_epoch.load(Ordering::Relaxed)
    }
    /// Manually increment the epoch
    pub fn increment(&self) {
        self.current_epoch.fetch_add(1, Ordering::Relaxed);
    }
    /// Reset epoch to zero
    pub fn reset(&self) {
        self.current_epoch.store(0, Ordering::Relaxed);
    }
    /// Get tick interval
    pub fn tick_interval(&self) -> Duration {
        self.tick_interval
    }
    /// Check if the controller is running
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::Relaxed)
    }
    /// Get a clone of the epoch counter for sharing
    pub fn epoch_counter(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.current_epoch)
    }
    /// Calculate deadline epoch for a given budget
    ///
    /// # Arguments
    /// * `budget_ticks` - Number of ticks before timeout
    ///
    /// # Returns
    /// The epoch value that represents the deadline. Saturates at
    /// `u64::MAX` (fix: `current() + budget_ticks` could overflow for
    /// "unlimited" budgets such as `u64::MAX`).
    pub fn deadline_for_budget(&self, budget_ticks: u64) -> u64 {
        self.current().saturating_add(budget_ticks)
    }
    /// Check if an epoch deadline has been exceeded
    pub fn is_deadline_exceeded(&self, deadline: u64) -> bool {
        self.current() >= deadline
    }
    /// Convert epoch ticks to approximate duration
    ///
    /// Fix: the previous `tick_interval * ticks as u32` silently truncated
    /// tick counts above `u32::MAX`; compute in nanoseconds and saturate.
    pub fn ticks_to_duration(&self, ticks: u64) -> Duration {
        let nanos = self
            .tick_interval
            .as_nanos()
            .saturating_mul(ticks as u128)
            .min(u64::MAX as u128) as u64;
        Duration::from_nanos(nanos)
    }
    /// Convert duration to approximate epoch ticks
    ///
    /// Returns `u64::MAX` for a zero tick interval (fix: avoids a
    /// divide-by-zero panic) and saturates rather than truncating.
    pub fn duration_to_ticks(&self, duration: Duration) -> u64 {
        let tick_nanos = self.tick_interval.as_nanos();
        if tick_nanos == 0 {
            return u64::MAX;
        }
        (duration.as_nanos() / tick_nanos).min(u64::MAX as u128) as u64
    }
}
impl Default for EpochController {
    /// Defaults to the 10ms tick interval (`default_interval()`).
    fn default() -> Self {
        Self::default_interval()
    }
}
/// Configuration for epoch-based execution limits
///
/// The approximate wall-clock limit is `budget * tick_interval_ms`.
#[derive(Debug, Clone, Copy)]
pub struct EpochConfig {
    /// Enable epoch interruption
    pub enabled: bool,
    /// Tick interval in milliseconds
    pub tick_interval_ms: u64,
    /// Default budget in ticks
    pub default_budget: u64,
    /// Maximum allowed budget (prevents abuse; see `clamp_budget`)
    pub max_budget: u64,
}
impl EpochConfig {
    /// Create a new epoch configuration
    ///
    /// `max_budget` is derived as 10x the default budget, saturating at
    /// `u64::MAX` (fix: `default_budget * 10` could overflow — a debug-build
    /// panic — for very large budgets).
    pub fn new(tick_interval_ms: u64, default_budget: u64) -> Self {
        EpochConfig {
            enabled: true,
            tick_interval_ms,
            default_budget,
            max_budget: default_budget.saturating_mul(10), // 10x default as max
        }
    }
    /// Create configuration for server workloads (longer budgets)
    pub fn server() -> Self {
        EpochConfig {
            enabled: true,
            tick_interval_ms: 10,
            default_budget: 1000, // 10 seconds
            max_budget: 6000,     // 60 seconds max
        }
    }
    /// Create configuration for embedded/constrained workloads
    pub fn embedded() -> Self {
        EpochConfig {
            enabled: true,
            tick_interval_ms: 1,
            default_budget: 100, // 100ms
            max_budget: 1000,    // 1 second max
        }
    }
    /// Create configuration with interruption disabled (for benchmarking)
    ///
    /// # Warning
    /// Only use this for controlled benchmarking scenarios.
    pub fn disabled() -> Self {
        EpochConfig {
            enabled: false,
            tick_interval_ms: 10,
            default_budget: u64::MAX,
            max_budget: u64::MAX,
        }
    }
    /// Get tick interval as Duration
    pub fn tick_interval(&self) -> Duration {
        Duration::from_millis(self.tick_interval_ms)
    }
    /// Clamp a requested budget to the allowed maximum
    pub fn clamp_budget(&self, requested: u64) -> u64 {
        requested.min(self.max_budget)
    }
    /// Convert budget ticks to approximate duration
    ///
    /// Fix: saturates instead of overflowing for huge budgets such as the
    /// `u64::MAX` used by `disabled()`.
    pub fn budget_duration(&self, budget: u64) -> Duration {
        Duration::from_millis(budget.saturating_mul(self.tick_interval_ms))
    }
}
impl Default for EpochConfig {
    /// Defaults to the server profile (10ms ticks, 10s default budget).
    fn default() -> Self {
        Self::server()
    }
}
/// Epoch deadline tracker for a single kernel invocation
#[derive(Debug, Clone, Copy)]
pub struct EpochDeadline {
    /// The epoch value at which execution should stop
    /// (computed as `start_epoch + budget` by `new`)
    pub deadline: u64,
    /// The budget that was allocated
    pub budget: u64,
    /// When the execution started (epoch value)
    pub start_epoch: u64,
}
impl EpochDeadline {
    /// Create a new deadline
    ///
    /// Fix: the deadline saturates at `u64::MAX` instead of overflowing
    /// when `start_epoch + budget` exceeds `u64::MAX` (e.g. an "unlimited"
    /// `u64::MAX` budget).
    pub fn new(start_epoch: u64, budget: u64) -> Self {
        EpochDeadline {
            deadline: start_epoch.saturating_add(budget),
            budget,
            start_epoch,
        }
    }
    /// Calculate elapsed ticks since the start epoch (0 if `current` is earlier)
    pub fn elapsed(&self, current: u64) -> u64 {
        current.saturating_sub(self.start_epoch)
    }
    /// Calculate remaining ticks until the deadline (0 once exceeded)
    pub fn remaining(&self, current: u64) -> u64 {
        self.deadline.saturating_sub(current)
    }
    /// Check if deadline is exceeded
    pub fn is_exceeded(&self, current: u64) -> bool {
        current >= self.deadline
    }
}
#[cfg(test)]
mod tests {
    // Exercises the controller counter, deadline math, duration/tick
    // conversions, and the preset configurations.
    use super::*;
    #[test]
    fn test_epoch_controller() {
        let controller = EpochController::default_interval();
        assert_eq!(controller.current(), 0);
        controller.increment();
        assert_eq!(controller.current(), 1);
        controller.increment();
        assert_eq!(controller.current(), 2);
        controller.reset();
        assert_eq!(controller.current(), 0);
    }
    #[test]
    fn test_deadline_calculation() {
        let controller = EpochController::default_interval();
        let deadline = controller.deadline_for_budget(100);
        assert_eq!(deadline, 100);
        assert!(!controller.is_deadline_exceeded(deadline));
        // Simulate time passing
        for _ in 0..100 {
            controller.increment();
        }
        assert!(controller.is_deadline_exceeded(deadline));
    }
    #[test]
    fn test_duration_conversion() {
        let config = EpochConfig::new(10, 1000);
        assert_eq!(config.budget_duration(100), Duration::from_secs(1));
        let controller = EpochController::new(Duration::from_millis(10));
        assert_eq!(controller.ticks_to_duration(100), Duration::from_secs(1));
        assert_eq!(controller.duration_to_ticks(Duration::from_secs(1)), 100);
    }
    #[test]
    fn test_epoch_config_clamp() {
        let config = EpochConfig::new(10, 1000);
        assert_eq!(config.max_budget, 10000);
        assert_eq!(config.clamp_budget(500), 500);
        assert_eq!(config.clamp_budget(20000), 10000);
    }
    #[test]
    fn test_epoch_deadline() {
        let deadline = EpochDeadline::new(10, 100);
        assert_eq!(deadline.deadline, 110);
        assert_eq!(deadline.elapsed(50), 40);
        assert_eq!(deadline.remaining(50), 60);
        assert!(!deadline.is_exceeded(50));
        // Deadline is inclusive: exceeded exactly at `deadline`.
        assert!(deadline.is_exceeded(110));
        assert!(deadline.is_exceeded(200));
    }
    #[test]
    fn test_server_config() {
        let config = EpochConfig::server();
        assert!(config.enabled);
        assert_eq!(config.tick_interval_ms, 10);
        assert_eq!(config.default_budget, 1000);
    }
    #[test]
    fn test_embedded_config() {
        let config = EpochConfig::embedded();
        assert!(config.enabled);
        assert_eq!(config.tick_interval_ms, 1);
        assert_eq!(config.default_budget, 100);
    }
    #[test]
    fn test_disabled_config() {
        let config = EpochConfig::disabled();
        assert!(!config.enabled);
    }
}

View File

@@ -0,0 +1,368 @@
//! Error types for the kernel pack system
//!
//! Provides comprehensive error handling for kernel verification,
//! loading, and execution.
use std::fmt;
/// Errors that can occur during kernel execution
///
/// All variants are `Clone` so errors can be recorded and re-reported.
#[derive(Debug, Clone)]
pub enum KernelError {
    /// Execution budget exceeded (epoch deadline reached)
    EpochDeadline,
    /// Out of bounds memory access
    MemoryAccessViolation {
        /// Attempted access offset
        offset: u32,
        /// Attempted access size
        size: u32,
    },
    /// Integer overflow/underflow during computation
    IntegerOverflow,
    /// Unreachable code was executed
    Unreachable,
    /// Stack overflow in WASM execution
    StackOverflow,
    /// Indirect call type mismatch
    IndirectCallTypeMismatch,
    /// Custom trap from kernel with error code
    KernelTrap {
        /// Error code returned by kernel
        code: u32,
        /// Optional error message
        message: Option<String>,
    },
    /// Kernel not found
    KernelNotFound {
        /// Requested kernel ID
        kernel_id: String,
    },
    /// Invalid kernel parameters
    InvalidParameters {
        /// Description of the parameter error
        description: String,
    },
    /// Tensor shape mismatch
    ShapeMismatch {
        /// Expected shape description
        expected: String,
        /// Actual shape description
        actual: String,
    },
    /// Data type mismatch
    DTypeMismatch {
        /// Expected data type
        expected: String,
        /// Actual data type
        actual: String,
    },
    /// Memory allocation failed
    AllocationFailed {
        /// Requested size in bytes
        requested_bytes: usize,
    },
    /// Kernel initialization failed
    InitializationFailed {
        /// Reason for failure
        reason: String,
    },
    /// Runtime error
    RuntimeError {
        /// Error message
        message: String,
    },
    /// Feature not supported
    UnsupportedFeature {
        /// Feature name
        feature: String,
    },
}
impl fmt::Display for KernelError {
    // NOTE: the unit tests below assert on substrings of these messages
    // (e.g. "epoch deadline"); keep wording stable when editing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            KernelError::EpochDeadline => {
                write!(f, "Kernel execution exceeded time budget (epoch deadline)")
            }
            KernelError::MemoryAccessViolation { offset, size } => {
                write!(
                    f,
                    "Memory access violation: offset={}, size={}",
                    offset, size
                )
            }
            KernelError::IntegerOverflow => write!(f, "Integer overflow during computation"),
            KernelError::Unreachable => write!(f, "Unreachable code executed"),
            KernelError::StackOverflow => write!(f, "Stack overflow"),
            KernelError::IndirectCallTypeMismatch => {
                write!(f, "Indirect call type mismatch")
            }
            KernelError::KernelTrap { code, message } => {
                // The message suffix is only appended when present.
                write!(f, "Kernel trap (code={})", code)?;
                if let Some(msg) = message {
                    write!(f, ": {}", msg)?;
                }
                Ok(())
            }
            KernelError::KernelNotFound { kernel_id } => {
                write!(f, "Kernel not found: {}", kernel_id)
            }
            KernelError::InvalidParameters { description } => {
                write!(f, "Invalid parameters: {}", description)
            }
            KernelError::ShapeMismatch { expected, actual } => {
                write!(f, "Shape mismatch: expected {}, got {}", expected, actual)
            }
            KernelError::DTypeMismatch { expected, actual } => {
                write!(f, "DType mismatch: expected {}, got {}", expected, actual)
            }
            KernelError::AllocationFailed { requested_bytes } => {
                write!(f, "Memory allocation failed: {} bytes", requested_bytes)
            }
            KernelError::InitializationFailed { reason } => {
                write!(f, "Kernel initialization failed: {}", reason)
            }
            KernelError::RuntimeError { message } => {
                write!(f, "Runtime error: {}", message)
            }
            KernelError::UnsupportedFeature { feature } => {
                write!(f, "Unsupported feature: {}", feature)
            }
        }
    }
}
// KernelError carries no wrapped source error; the default
// std::error::Error methods suffice.
impl std::error::Error for KernelError {}
/// Errors that can occur during kernel verification
///
/// Covers signature, hash, manifest, version, feature, allowlist,
/// I/O, and key-handling failures.
#[derive(Debug, Clone)]
pub enum VerifyError {
    /// No trusted signing key matched
    NoTrustedKey,
    /// Signature is invalid
    InvalidSignature {
        /// Description of the signature error
        reason: String,
    },
    /// Hash mismatch
    HashMismatch {
        /// Expected hash
        expected: String,
        /// Actual computed hash
        actual: String,
    },
    /// Manifest parsing failed
    InvalidManifest {
        /// Error message
        message: String,
    },
    /// Version incompatibility
    VersionIncompatible {
        /// Required version range
        required: String,
        /// Actual version
        actual: String,
    },
    /// Runtime too old for kernel pack
    RuntimeTooOld {
        /// Minimum required version
        required: String,
        /// Actual runtime version
        actual: String,
    },
    /// Runtime too new for kernel pack
    RuntimeTooNew {
        /// Maximum supported version
        max_supported: String,
        /// Actual runtime version
        actual: String,
    },
    /// Missing required WASM feature
    MissingFeature {
        /// Kernel that requires the feature
        kernel: String,
        /// Missing feature name
        feature: String,
    },
    /// Kernel not in allowlist
    NotInAllowlist {
        /// Kernel ID
        kernel_id: String,
    },
    /// File I/O error
    IoError {
        /// Error message
        message: String,
    },
    /// Key parsing error
    KeyError {
        /// Error message
        message: String,
    },
}
impl fmt::Display for VerifyError {
    // NOTE: unit tests assert on substrings of these messages; keep
    // wording stable when editing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            VerifyError::NoTrustedKey => {
                write!(f, "No trusted signing key matched the manifest signature")
            }
            VerifyError::InvalidSignature { reason } => {
                write!(f, "Invalid signature: {}", reason)
            }
            VerifyError::HashMismatch { expected, actual } => {
                write!(f, "Hash mismatch: expected {}, got {}", expected, actual)
            }
            VerifyError::InvalidManifest { message } => {
                write!(f, "Invalid manifest: {}", message)
            }
            VerifyError::VersionIncompatible { required, actual } => {
                write!(
                    f,
                    "Version incompatible: required {}, got {}",
                    required, actual
                )
            }
            VerifyError::RuntimeTooOld { required, actual } => {
                write!(f, "Runtime too old: requires {}, have {}", required, actual)
            }
            VerifyError::RuntimeTooNew {
                max_supported,
                actual,
            } => {
                write!(
                    f,
                    "Runtime too new: max supported {}, have {}",
                    max_supported, actual
                )
            }
            VerifyError::MissingFeature { kernel, feature } => {
                write!(
                    f,
                    "Kernel '{}' requires missing feature: {}",
                    kernel, feature
                )
            }
            VerifyError::NotInAllowlist { kernel_id } => {
                write!(f, "Kernel '{}' not in allowlist", kernel_id)
            }
            VerifyError::IoError { message } => write!(f, "I/O error: {}", message),
            VerifyError::KeyError { message } => write!(f, "Key error: {}", message),
        }
    }
}
// VerifyError carries no wrapped source error; the default
// std::error::Error methods suffice.
impl std::error::Error for VerifyError {}
/// Result type alias for kernel operations
pub type KernelResult<T> = Result<T, KernelError>;
/// Result type alias for verification operations
pub type VerifyResult<T> = Result<T, VerifyError>;
/// Standard kernel error codes (returned by kernel_forward/kernel_backward)
///
/// `#[repr(u32)]` pins the discriminants so they stay stable across the
/// WASM boundary; see `From<u32>` below for the reverse mapping.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelErrorCode {
    /// Success
    Ok = 0,
    /// Invalid input tensor
    InvalidInput = 1,
    /// Invalid output tensor
    InvalidOutput = 2,
    /// Invalid kernel parameters
    InvalidParams = 3,
    /// Out of memory
    OutOfMemory = 4,
    /// Operation not implemented
    NotImplemented = 5,
    /// Internal kernel error
    InternalError = 6,
}
impl From<u32> for KernelErrorCode {
fn from(code: u32) -> Self {
match code {
0 => KernelErrorCode::Ok,
1 => KernelErrorCode::InvalidInput,
2 => KernelErrorCode::InvalidOutput,
3 => KernelErrorCode::InvalidParams,
4 => KernelErrorCode::OutOfMemory,
5 => KernelErrorCode::NotImplemented,
_ => KernelErrorCode::InternalError,
}
}
}
impl fmt::Display for KernelErrorCode {
    // Short human-readable labels for log/error output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            KernelErrorCode::Ok => write!(f, "OK"),
            KernelErrorCode::InvalidInput => write!(f, "Invalid input tensor"),
            KernelErrorCode::InvalidOutput => write!(f, "Invalid output tensor"),
            KernelErrorCode::InvalidParams => write!(f, "Invalid parameters"),
            KernelErrorCode::OutOfMemory => write!(f, "Out of memory"),
            KernelErrorCode::NotImplemented => write!(f, "Not implemented"),
            KernelErrorCode::InternalError => write!(f, "Internal error"),
        }
    }
}
#[cfg(test)]
mod tests {
    // Checks Display formatting substrings and the u32 -> KernelErrorCode
    // conversion (including the unknown-code fallback).
    use super::*;
    #[test]
    fn test_kernel_error_display() {
        let err = KernelError::EpochDeadline;
        assert!(err.to_string().contains("epoch deadline"));
        let err = KernelError::MemoryAccessViolation {
            offset: 100,
            size: 64,
        };
        assert!(err.to_string().contains("100"));
        assert!(err.to_string().contains("64"));
    }
    #[test]
    fn test_verify_error_display() {
        let err = VerifyError::HashMismatch {
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        };
        assert!(err.to_string().contains("abc123"));
        assert!(err.to_string().contains("def456"));
    }
    #[test]
    fn test_error_code_conversion() {
        assert_eq!(KernelErrorCode::from(0), KernelErrorCode::Ok);
        assert_eq!(KernelErrorCode::from(1), KernelErrorCode::InvalidInput);
        // Unknown codes fall back to InternalError.
        assert_eq!(KernelErrorCode::from(100), KernelErrorCode::InternalError);
    }
}

View File

@@ -0,0 +1,176 @@
//! SHA256 Hash Verification
//!
//! Provides hash verification for WASM kernel files to ensure integrity.
use crate::kernel::error::VerifyError;
use sha2::{Digest, Sha256};
/// Hash verifier for kernel files
///
/// Currently only SHA-256 is supported (see the `sha256()` constructor).
#[derive(Debug, Clone)]
pub struct HashVerifier {
    /// Expected hash format prefix (e.g., "sha256:")
    prefix: String,
}
impl HashVerifier {
    /// Create a new SHA256 hash verifier
    pub fn sha256() -> Self {
        HashVerifier {
            prefix: "sha256:".to_string(),
        }
    }
    /// Compute SHA256 hash of data
    ///
    /// Returns the lowercase hex digest prefixed with `"sha256:"`.
    pub fn compute_hash(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        let result = hasher.finalize();
        format!("sha256:{:x}", result)
    }
    /// Verify kernel data against expected hash
    ///
    /// The comparison is ASCII case-insensitive for both the scheme prefix
    /// and the hex digest. The digest comparison was already
    /// case-insensitive; the prefix check now matches that behavior, so
    /// `"SHA256:ABC..."` is accepted (as the existing
    /// `test_verify_case_insensitive` test expects).
    ///
    /// # Arguments
    /// * `kernel_bytes` - The raw WASM kernel bytes
    /// * `expected_hash` - Expected hash string (format: "sha256:...")
    ///
    /// # Returns
    /// * `Ok(())` if hash matches
    /// * `Err(VerifyError::HashMismatch)` if hash doesn't match
    /// * `Err(VerifyError::InvalidManifest)` if the scheme prefix is unknown
    pub fn verify(&self, kernel_bytes: &[u8], expected_hash: &str) -> Result<(), VerifyError> {
        // Validate expected hash format. `get` (not slicing) cannot panic on
        // a non-char boundary, and the comparison is case-insensitive for
        // consistency with the digest comparison below.
        let prefix_ok = expected_hash
            .get(..self.prefix.len())
            .map_or(false, |p| p.eq_ignore_ascii_case(&self.prefix));
        if !prefix_ok {
            return Err(VerifyError::InvalidManifest {
                message: format!(
                    "Invalid hash format: expected '{}' prefix, got '{}'",
                    self.prefix,
                    expected_hash.get(..10).unwrap_or(expected_hash)
                ),
            });
        }
        let actual_hash = Self::compute_hash(kernel_bytes);
        if actual_hash.eq_ignore_ascii_case(expected_hash) {
            Ok(())
        } else {
            Err(VerifyError::HashMismatch {
                expected: expected_hash.to_string(),
                actual: actual_hash,
            })
        }
    }
    /// Verify multiple kernels in batch
    ///
    /// # Arguments
    /// * `kernels` - Iterator of (kernel_bytes, expected_hash) pairs
    ///
    /// # Returns
    /// * `Ok(())` if all hashes match
    /// * `Err` with first mismatch (remaining kernels are not checked)
    pub fn verify_batch<'a>(
        &self,
        kernels: impl Iterator<Item = (&'a [u8], &'a str)>,
    ) -> Result<(), VerifyError> {
        for (bytes, expected) in kernels {
            self.verify(bytes, expected)?;
        }
        Ok(())
    }
}
impl Default for HashVerifier {
    /// The default verifier uses the SHA-256 scheme.
    fn default() -> Self {
        HashVerifier::sha256()
    }
}
/// Compute hash for a kernel file and return formatted string
pub fn hash_kernel(kernel_bytes: &[u8]) -> String {
    HashVerifier::compute_hash(kernel_bytes)
}
/// Verify a kernel file against expected hash (convenience function)
pub fn verify_kernel_hash(kernel_bytes: &[u8], expected_hash: &str) -> Result<(), VerifyError> {
    let verifier = HashVerifier::sha256();
    verifier.verify(kernel_bytes, expected_hash)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_compute_hash() {
        let data = b"hello world";
        let hash = HashVerifier::compute_hash(data);
        assert!(hash.starts_with("sha256:"));
        // Known SHA256 of "hello world"
        assert!(hash.contains("b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"));
    }
    #[test]
    fn test_verify_success() {
        let data = b"test kernel data";
        let hash = HashVerifier::compute_hash(data);
        let verifier = HashVerifier::sha256();
        assert!(verifier.verify(data, &hash).is_ok());
    }
    #[test]
    fn test_verify_case_insensitive() {
        let data = b"test kernel data";
        let hash = HashVerifier::compute_hash(data);
        // NOTE(review): to_uppercase() also uppercases the "sha256:" prefix,
        // so this test only passes if verify()'s prefix check is itself
        // case-insensitive; a case-sensitive starts_with("sha256:") would
        // reject the string as InvalidManifest — confirm verify() matches.
        let upper_hash = hash.to_uppercase();
        let verifier = HashVerifier::sha256();
        assert!(verifier.verify(data, &upper_hash).is_ok());
    }
    #[test]
    fn test_verify_mismatch() {
        let data = b"actual data";
        // Well-formed prefix, wrong digest: must be HashMismatch, not a
        // format error.
        let wrong_hash = "sha256:0000000000000000000000000000000000000000000000000000000000000000";
        let verifier = HashVerifier::sha256();
        let result = verifier.verify(data, wrong_hash);
        assert!(matches!(result, Err(VerifyError::HashMismatch { .. })));
    }
    #[test]
    fn test_verify_invalid_format() {
        let data = b"test data";
        let invalid_hash = "md5:abc123";
        let verifier = HashVerifier::sha256();
        let result = verifier.verify(data, invalid_hash);
        assert!(matches!(result, Err(VerifyError::InvalidManifest { .. })));
    }
    #[test]
    fn test_verify_batch() {
        let data1 = b"kernel1";
        let data2 = b"kernel2";
        let hash1 = HashVerifier::compute_hash(data1);
        let hash2 = HashVerifier::compute_hash(data2);
        let verifier = HashVerifier::sha256();
        let kernels = vec![
            (data1.as_slice(), hash1.as_str()),
            (data2.as_slice(), hash2.as_str()),
        ];
        assert!(verifier.verify_batch(kernels.into_iter()).is_ok());
    }
    #[test]
    fn test_convenience_function() {
        let data = b"convenience test";
        let hash = hash_kernel(data);
        assert!(verify_kernel_hash(data, &hash).is_ok());
    }
}

View File

@@ -0,0 +1,500 @@
//! Kernel Pack Manifest (kernels.json)
//!
//! Defines the manifest schema for kernel packs, including kernel metadata,
//! resource limits, platform requirements, and versioning.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Kernel pack manifest (kernels.json)
///
/// Top-level schema for a kernel pack. Field names map 1:1 onto JSON keys
/// via serde, so renaming any field here changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelManifest {
    /// JSON schema URL
    #[serde(rename = "$schema", default)]
    pub schema: String,
    /// Manifest version (semver)
    pub version: String,
    /// Pack name
    pub name: String,
    /// Pack description
    pub description: String,
    /// Minimum runtime version required
    pub min_runtime_version: String,
    /// Maximum runtime version supported
    pub max_runtime_version: String,
    /// Creation timestamp (ISO 8601)
    pub created_at: String,
    /// Author information
    pub author: AuthorInfo,
    /// List of kernels in the pack
    pub kernels: Vec<KernelInfo>,
    /// Fallback mappings (kernel_id -> fallback_kernel_id);
    /// absent in JSON -> empty map
    #[serde(default)]
    pub fallbacks: HashMap<String, String>,
}
/// Author information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorInfo {
    /// Author name
    pub name: String,
    /// Contact email
    pub email: String,
    /// Ed25519 public signing key (base64 or hex encoded)
    pub signing_key: String,
}
/// Individual kernel information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelInfo {
    /// Unique kernel identifier
    pub id: String,
    /// Human-readable name
    pub name: String,
    /// Kernel category
    pub category: KernelCategory,
    /// Path to WASM file relative to pack root
    pub path: String,
    /// SHA256 hash of the WASM file (format: "sha256:...")
    pub hash: String,
    /// Entry point function name
    pub entry_point: String,
    /// Input tensor specifications
    pub inputs: Vec<TensorSpec>,
    /// Output tensor specifications
    pub outputs: Vec<TensorSpec>,
    /// Kernel-specific parameters; absent in JSON -> empty map
    #[serde(default)]
    pub params: HashMap<String, KernelParam>,
    /// Resource limits
    pub resource_limits: ResourceLimits,
    /// Platform-specific configurations, keyed by runtime name
    /// (e.g. "wasmtime")
    #[serde(default)]
    pub platforms: HashMap<String, PlatformConfig>,
    /// Benchmark results, keyed by scenario name
    #[serde(default)]
    pub benchmarks: HashMap<String, BenchmarkResult>,
}
/// Kernel categories
///
/// Serialized in snake_case (e.g. "positional_encoding", "kv_cache").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KernelCategory {
    /// Positional encoding (RoPE, etc.)
    PositionalEncoding,
    /// Normalization (RMSNorm, LayerNorm, etc.)
    Normalization,
    /// Activation functions (SwiGLU, GELU, etc.)
    Activation,
    /// KV cache operations (quantize, dequantize)
    KvCache,
    /// Adapter operations (LoRA, etc.)
    Adapter,
    /// Attention mechanisms
    Attention,
    /// Custom/other operations
    Custom,
}
impl Default for KernelCategory {
    // Unknown/unspecified kernels fall into the catch-all category.
    fn default() -> Self {
        KernelCategory::Custom
    }
}
/// Tensor specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name
    pub name: String,
    /// Data type
    pub dtype: DataType,
    /// Shape specification (symbolic dimensions like "batch", "seq", numeric for fixed)
    pub shape: Vec<ShapeDim>,
}
/// Shape dimension (can be symbolic or numeric)
///
/// `#[serde(untagged)]`: JSON strings deserialize to `Symbolic`, JSON
/// numbers to `Fixed` — no discriminant key in the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ShapeDim {
    /// Symbolic dimension (e.g., "batch", "seq", "heads")
    Symbolic(String),
    /// Fixed numeric dimension
    Fixed(usize),
}
/// Data types supported by kernels
///
/// Serialized in lowercase (e.g. "f32", "bf16", "q4").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DataType {
    /// 32-bit float
    F32,
    /// 16-bit float (half precision)
    F16,
    /// Brain float 16
    Bf16,
    /// 8-bit integer (signed)
    I8,
    /// 8-bit unsigned integer
    U8,
    /// 32-bit integer
    I32,
    /// Quantized 4-bit
    Q4,
    /// Quantized 8-bit
    Q8,
}
impl DataType {
/// Get size in bytes for this data type
pub fn size_bytes(&self) -> usize {
match self {
DataType::F32 | DataType::I32 => 4,
DataType::F16 | DataType::Bf16 => 2,
DataType::I8 | DataType::U8 | DataType::Q8 => 1,
DataType::Q4 => 1, // Packed, 2 values per byte
}
}
}
/// Kernel parameter definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelParam {
    /// Parameter data type ("type" in JSON; `type` is a Rust keyword)
    #[serde(rename = "type")]
    pub param_type: ParamType,
    /// Default value (raw JSON; interpreted according to `param_type`)
    pub default: serde_json::Value,
    /// Optional minimum value
    #[serde(default)]
    pub min: Option<serde_json::Value>,
    /// Optional maximum value
    #[serde(default)]
    pub max: Option<serde_json::Value>,
    /// Optional description
    #[serde(default)]
    pub description: Option<String>,
}
/// Parameter types
///
/// Serialized in lowercase (e.g. "f32", "bool").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ParamType {
    F32,
    F64,
    I32,
    I64,
    U32,
    U64,
    Bool,
}
/// Resource limits for kernel execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceLimits {
    /// Maximum WASM memory pages (64KB each)
    pub max_memory_pages: u32,
    /// Maximum epoch ticks before interruption
    pub max_epoch_ticks: u64,
    /// Maximum table elements
    pub max_table_elements: u32,
    /// Optional: Maximum stack size in bytes
    #[serde(default)]
    pub max_stack_size: Option<usize>,
    /// Optional: Maximum globals
    #[serde(default)]
    pub max_globals: Option<u32>,
}
impl Default for ResourceLimits {
    // Conservative defaults for a typical kernel; see inline notes.
    fn default() -> Self {
        ResourceLimits {
            max_memory_pages: 256, // 16MB
            max_epoch_ticks: 1000, // ~10 seconds at 10ms/tick
            max_table_elements: 1024, // Function pointers
            max_stack_size: None,
            max_globals: None,
        }
    }
}
/// Platform-specific configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformConfig {
    /// Minimum version of the runtime
    pub min_version: String,
    /// Required WASM features (e.g. "simd", "bulk-memory")
    #[serde(default)]
    pub features: Vec<String>,
    /// Whether AOT compilation is available
    #[serde(default)]
    pub aot_available: bool,
}
/// Benchmark result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Latency in microseconds
    pub latency_us: u64,
    /// Throughput in GFLOPS
    pub throughput_gflops: f64,
}
/// Kernel invocation descriptor passed to WASM
///
/// This is the C-compatible struct passed to kernels to describe
/// memory layout and tensor locations. All fields are u32 because WASM32
/// linear memory is addressed with 32-bit offsets; `#[repr(C)]` fixes the
/// field order to match the serialized layout in `to_bytes`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct KernelDescriptor {
    /// Input tensor A offset in linear memory
    pub input_a_offset: u32,
    /// Input tensor A size in bytes
    pub input_a_size: u32,
    /// Input tensor B offset (0 if unused)
    pub input_b_offset: u32,
    /// Input tensor B size in bytes
    pub input_b_size: u32,
    /// Output tensor offset
    pub output_offset: u32,
    /// Output tensor size in bytes
    pub output_size: u32,
    /// Scratch space offset
    pub scratch_offset: u32,
    /// Scratch space size in bytes
    pub scratch_size: u32,
    /// Kernel-specific parameters offset
    pub params_offset: u32,
    /// Kernel-specific parameters size
    pub params_size: u32,
}
impl KernelDescriptor {
    /// Create a new, all-zero kernel descriptor
    pub fn new() -> Self {
        KernelDescriptor {
            input_a_offset: 0,
            input_a_size: 0,
            input_b_offset: 0,
            input_b_size: 0,
            output_offset: 0,
            output_size: 0,
            scratch_offset: 0,
            scratch_size: 0,
            params_offset: 0,
            params_size: 0,
        }
    }
    /// Calculate total memory required
    ///
    /// Widens each `offset + size` pair to `usize` with saturating
    /// addition before taking the maximum, so a pair near `u32::MAX`
    /// cannot wrap around (the previous u32 addition would silently
    /// overflow in release builds and report a too-small requirement).
    pub fn total_memory_required(&self) -> usize {
        [
            (self.input_a_offset, self.input_a_size),
            (self.input_b_offset, self.input_b_size),
            (self.output_offset, self.output_size),
            (self.scratch_offset, self.scratch_size),
            (self.params_offset, self.params_size),
        ]
        .into_iter()
        .map(|(offset, size)| (offset as usize).saturating_add(size as usize))
        .max()
        .unwrap_or(0)
    }
    /// Serialize to bytes for passing to WASM
    ///
    /// Layout is 10 consecutive little-endian u32 fields (40 bytes), in
    /// declaration order, matching the `#[repr(C)]` struct layout.
    pub fn to_bytes(&self) -> Vec<u8> {
        let fields = [
            self.input_a_offset,
            self.input_a_size,
            self.input_b_offset,
            self.input_b_size,
            self.output_offset,
            self.output_size,
            self.scratch_offset,
            self.scratch_size,
            self.params_offset,
            self.params_size,
        ];
        let mut bytes = Vec::with_capacity(fields.len() * 4);
        for value in fields {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }
}
impl Default for KernelDescriptor {
fn default() -> Self {
Self::new()
}
}
impl KernelManifest {
    /// Parse manifest from JSON string
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(json)
    }
    /// Serialize manifest to JSON string (pretty-printed)
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }
    /// Get kernel by ID (linear scan over the kernel list)
    pub fn get_kernel(&self, id: &str) -> Option<&KernelInfo> {
        self.kernels.iter().find(|kernel| kernel.id == id)
    }
    /// Get fallback kernel for a given kernel ID
    pub fn get_fallback(&self, id: &str) -> Option<&str> {
        self.fallbacks.get(id).map(String::as_str)
    }
    /// List all kernel IDs, in manifest order
    pub fn kernel_ids(&self) -> Vec<&str> {
        self.kernels.iter().map(|kernel| kernel.id.as_str()).collect()
    }
    /// List kernels by category
    pub fn kernels_by_category(&self, category: KernelCategory) -> Vec<&KernelInfo> {
        self.kernels
            .iter()
            .filter(|kernel| kernel.category == category)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Single-kernel fixture exercising every optional section of the
    // schema (params, platforms, benchmarks, fallbacks).
    fn sample_manifest_json() -> &'static str {
        r#"{
            "$schema": "https://ruvllm.dev/schemas/kernel-pack-v1.json",
            "version": "1.0.0",
            "name": "test-kernels",
            "description": "Test kernel pack",
            "min_runtime_version": "0.5.0",
            "max_runtime_version": "1.0.0",
            "created_at": "2026-01-18T00:00:00Z",
            "author": {
                "name": "Test Author",
                "email": "test@example.com",
                "signing_key": "ed25519:AAAA..."
            },
            "kernels": [
                {
                    "id": "rope_f32",
                    "name": "Rotary Position Embedding (FP32)",
                    "category": "positional_encoding",
                    "path": "rope/rope_f32.wasm",
                    "hash": "sha256:abc123",
                    "entry_point": "rope_forward",
                    "inputs": [
                        {"name": "x", "dtype": "f32", "shape": ["batch", "seq", "heads", "dim"]},
                        {"name": "freqs", "dtype": "f32", "shape": ["seq", 64]}
                    ],
                    "outputs": [
                        {"name": "y", "dtype": "f32", "shape": ["batch", "seq", "heads", "dim"]}
                    ],
                    "params": {
                        "theta": {"type": "f32", "default": 10000.0}
                    },
                    "resource_limits": {
                        "max_memory_pages": 256,
                        "max_epoch_ticks": 1000,
                        "max_table_elements": 1024
                    },
                    "platforms": {
                        "wasmtime": {
                            "min_version": "15.0.0",
                            "features": ["simd", "bulk-memory"]
                        }
                    },
                    "benchmarks": {
                        "seq_512_dim_128": {
                            "latency_us": 45,
                            "throughput_gflops": 2.1
                        }
                    }
                }
            ],
            "fallbacks": {
                "rope_f32": "rope_reference"
            }
        }"#
    }
    #[test]
    fn test_manifest_parsing() {
        let manifest = KernelManifest::from_json(sample_manifest_json()).unwrap();
        assert_eq!(manifest.name, "test-kernels");
        assert_eq!(manifest.version, "1.0.0");
        assert_eq!(manifest.kernels.len(), 1);
    }
    #[test]
    fn test_kernel_lookup() {
        let manifest = KernelManifest::from_json(sample_manifest_json()).unwrap();
        let kernel = manifest.get_kernel("rope_f32").unwrap();
        assert_eq!(kernel.name, "Rotary Position Embedding (FP32)");
        // "positional_encoding" in JSON maps to the PascalCase variant.
        assert_eq!(kernel.category, KernelCategory::PositionalEncoding);
    }
    #[test]
    fn test_fallback_lookup() {
        let manifest = KernelManifest::from_json(sample_manifest_json()).unwrap();
        assert_eq!(manifest.get_fallback("rope_f32"), Some("rope_reference"));
        assert_eq!(manifest.get_fallback("unknown"), None);
    }
    #[test]
    fn test_kernel_descriptor() {
        let mut desc = KernelDescriptor::new();
        desc.input_a_offset = 0;
        desc.input_a_size = 1024;
        desc.output_offset = 1024;
        desc.output_size = 1024;
        // Max over all (offset + size) pairs; 10 u32 fields = 40 bytes.
        assert_eq!(desc.total_memory_required(), 2048);
        assert_eq!(desc.to_bytes().len(), 40);
    }
    #[test]
    fn test_data_type_sizes() {
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::F16.size_bytes(), 2);
        assert_eq!(DataType::I8.size_bytes(), 1);
    }
}

View File

@@ -0,0 +1,466 @@
//! Shared Memory Protocol
//!
//! Defines the memory layout and protocol for passing tensor data
//! between the host and WASM kernels.
use crate::kernel::error::KernelError;
use crate::kernel::manifest::{DataType, KernelDescriptor};
/// WASM page size (64KB), the fixed page granularity of WASM linear memory
pub const PAGE_SIZE: usize = 65536;
/// Shared memory protocol for kernel invocation
///
/// Manages the layout of tensors and parameters in WASM linear memory.
/// Behaves as a simple bump allocator: offsets only grow until `reset`.
#[derive(Debug, Clone)]
pub struct SharedMemoryProtocol {
    /// Total memory size in bytes
    total_size: usize,
    /// Current allocation offset (next free byte, pre-alignment)
    current_offset: usize,
    /// Memory alignment (typically 8 or 16 bytes)
    alignment: usize,
}
impl SharedMemoryProtocol {
    /// Create a new memory protocol
    ///
    /// # Arguments
    /// * `total_pages` - Number of WASM pages to allocate
    /// * `alignment` - Memory alignment in bytes (must be non-zero)
    pub fn new(total_pages: usize, alignment: usize) -> Self {
        SharedMemoryProtocol {
            total_size: total_pages * PAGE_SIZE,
            current_offset: 0,
            alignment,
        }
    }
    /// Create with default settings (256 pages = 16MB, 16-byte alignment)
    pub fn default_settings() -> Self {
        Self::new(256, 16)
    }
    /// Reset allocator to beginning (previously returned offsets become stale)
    pub fn reset(&mut self) {
        self.current_offset = 0;
    }
    /// Round `offset` up to the next multiple of `self.alignment`.
    ///
    /// Uses divide/round-up rather than the bitmask trick so the result is
    /// correct for any non-zero alignment, not only powers of two (the
    /// previous `(offset + a - 1) & !(a - 1)` silently mis-aligned for
    /// non-power-of-two values). For power-of-two alignments the result
    /// is identical.
    fn align_offset(&self, offset: usize) -> usize {
        let a = self.alignment;
        ((offset + a - 1) / a) * a
    }
    /// Allocate memory region
    ///
    /// # Arguments
    /// * `size` - Size in bytes
    ///
    /// # Returns
    /// * `Ok(offset)` - Starting offset of allocated region (aligned)
    /// * `Err` - If allocation would exceed total size (or overflow `usize`)
    pub fn allocate(&mut self, size: usize) -> Result<usize, KernelError> {
        let aligned_offset = self.align_offset(self.current_offset);
        // checked_add: a pathological `size` near usize::MAX must fail
        // cleanly instead of wrapping and "succeeding".
        let end_offset = aligned_offset
            .checked_add(size)
            .ok_or(KernelError::AllocationFailed {
                requested_bytes: size,
            })?;
        if end_offset > self.total_size {
            return Err(KernelError::AllocationFailed {
                requested_bytes: size,
            });
        }
        self.current_offset = end_offset;
        Ok(aligned_offset)
    }
    /// Get total memory size in bytes
    pub fn total_size(&self) -> usize {
        self.total_size
    }
    /// Get total pages
    pub fn total_pages(&self) -> usize {
        self.total_size / PAGE_SIZE
    }
    /// Get current allocation offset
    pub fn current_offset(&self) -> usize {
        self.current_offset
    }
    /// Get remaining available bytes
    pub fn remaining(&self) -> usize {
        self.total_size.saturating_sub(self.current_offset)
    }
    /// Check if a memory region is valid
    ///
    /// Overflow-safe: `offset + size` near `usize::MAX` no longer wraps
    /// to a small value and falsely validates.
    pub fn is_valid_region(&self, offset: usize, size: usize) -> bool {
        offset
            .checked_add(size)
            .map_or(false, |end| end <= self.total_size)
    }
}
impl Default for SharedMemoryProtocol {
fn default() -> Self {
Self::default_settings()
}
}
/// Kernel invocation descriptor with memory layout
///
/// This is a higher-level wrapper around KernelDescriptor that helps
/// manage memory allocation and data transfer.
#[derive(Debug, Clone)]
pub struct KernelInvocationDescriptor {
    /// Low-level descriptor (C-compatible, handed to the WASM kernel)
    pub descriptor: KernelDescriptor,
    /// Memory protocol (bump allocator that fills in descriptor offsets)
    protocol: SharedMemoryProtocol,
}
impl KernelInvocationDescriptor {
    /// Create a new invocation descriptor
    pub fn new(total_pages: usize) -> Self {
        KernelInvocationDescriptor {
            descriptor: KernelDescriptor::new(),
            protocol: SharedMemoryProtocol::new(total_pages, 16),
        }
    }
    /// Create with default memory size (256 pages = 16MB)
    pub fn default_size() -> Self {
        Self::new(256)
    }
    /// Shared allocation helper: bump-allocate `size` bytes and return the
    /// (offset, size) pair as u32 for storage in the C-compatible descriptor.
    ///
    /// NOTE(review): the `as u32` casts assume offsets/sizes fit in 32 bits.
    /// WASM32 linear memory guarantees this for realistically sized pools,
    /// but confirm no caller constructs a >4GB protocol.
    fn allocate_region(&mut self, size: usize) -> Result<(u32, u32), KernelError> {
        let offset = self.protocol.allocate(size)?;
        Ok((offset as u32, size as u32))
    }
    /// Allocate space for input tensor A
    pub fn allocate_input_a(&mut self, size: usize) -> Result<u32, KernelError> {
        let (offset, len) = self.allocate_region(size)?;
        self.descriptor.input_a_offset = offset;
        self.descriptor.input_a_size = len;
        Ok(offset)
    }
    /// Allocate space for input tensor B
    pub fn allocate_input_b(&mut self, size: usize) -> Result<u32, KernelError> {
        let (offset, len) = self.allocate_region(size)?;
        self.descriptor.input_b_offset = offset;
        self.descriptor.input_b_size = len;
        Ok(offset)
    }
    /// Allocate space for output tensor
    pub fn allocate_output(&mut self, size: usize) -> Result<u32, KernelError> {
        let (offset, len) = self.allocate_region(size)?;
        self.descriptor.output_offset = offset;
        self.descriptor.output_size = len;
        Ok(offset)
    }
    /// Allocate scratch space
    pub fn allocate_scratch(&mut self, size: usize) -> Result<u32, KernelError> {
        let (offset, len) = self.allocate_region(size)?;
        self.descriptor.scratch_offset = offset;
        self.descriptor.scratch_size = len;
        Ok(offset)
    }
    /// Allocate space for parameters
    pub fn allocate_params(&mut self, size: usize) -> Result<u32, KernelError> {
        let (offset, len) = self.allocate_region(size)?;
        self.descriptor.params_offset = offset;
        self.descriptor.params_size = len;
        Ok(offset)
    }
    /// Get the low-level descriptor
    pub fn as_descriptor(&self) -> &KernelDescriptor {
        &self.descriptor
    }
    /// Get total allocated memory in bytes
    pub fn total_allocated(&self) -> usize {
        self.protocol.current_offset()
    }
    /// Get remaining memory in bytes
    pub fn remaining_memory(&self) -> usize {
        self.protocol.remaining()
    }
    /// Required pages for current allocation (rounded up)
    pub fn required_pages(&self) -> usize {
        (self.total_allocated() + PAGE_SIZE - 1) / PAGE_SIZE
    }
}
impl Default for KernelInvocationDescriptor {
fn default() -> Self {
Self::default_size()
}
}
/// Memory region specification
///
/// Half-open byte interval `[offset, offset + size)` in WASM linear memory.
#[derive(Debug, Clone, Copy)]
pub struct MemoryRegion {
    /// Start offset in linear memory
    pub offset: u32,
    /// Size in bytes
    pub size: u32,
    /// Whether region is read-only
    pub read_only: bool,
}
impl MemoryRegion {
    /// Create a new memory region
    pub fn new(offset: u32, size: u32, read_only: bool) -> Self {
        MemoryRegion {
            offset,
            size,
            read_only,
        }
    }
    /// Create a read-only region
    pub fn read_only(offset: u32, size: u32) -> Self {
        Self::new(offset, size, true)
    }
    /// Create a writable region
    pub fn writable(offset: u32, size: u32) -> Self {
        Self::new(offset, size, false)
    }
    /// Get end offset (exclusive)
    ///
    /// Saturates at `u32::MAX` instead of wrapping: an (offset, size)
    /// pair that overflows u32 previously produced a tiny `end`, making
    /// `overlaps` report false negatives for such degenerate regions.
    pub fn end(&self) -> u32 {
        self.offset.saturating_add(self.size)
    }
    /// Check if regions overlap (half-open interval intersection;
    /// zero-size regions never overlap anything)
    pub fn overlaps(&self, other: &MemoryRegion) -> bool {
        self.offset < other.end() && other.offset < self.end()
    }
}
/// Calculate tensor size in bytes
///
/// # Arguments
/// * `shape` - Tensor shape (dimensions)
/// * `dtype` - Data type
///
/// # Returns
/// Size in bytes (element count times per-element size)
pub fn tensor_size_bytes(shape: &[usize], dtype: DataType) -> usize {
    shape.iter().product::<usize>() * dtype.size_bytes()
}
/// Calculate required WASM pages for a given byte size
/// (rounded up; zero bytes require zero pages)
pub fn required_pages(size_bytes: usize) -> usize {
    let full_pages = size_bytes / PAGE_SIZE;
    if size_bytes % PAGE_SIZE == 0 {
        full_pages
    } else {
        full_pages + 1
    }
}
/// Memory layout validator
///
/// Collects registered regions and rejects overlapping additions; also
/// validates complete `KernelDescriptor` layouts against a memory bound.
#[derive(Debug, Default)]
pub struct MemoryLayoutValidator {
    /// Registered regions
    regions: Vec<MemoryRegion>,
}
impl MemoryLayoutValidator {
    /// Create a new validator with no registered regions
    pub fn new() -> Self {
        MemoryLayoutValidator {
            regions: Vec::new(),
        }
    }
    /// Add a region, rejecting it if it overlaps any already-registered one
    pub fn add_region(&mut self, region: MemoryRegion) -> Result<(), KernelError> {
        // Check for overlaps with existing regions (linear scan; the
        // region count per descriptor is small).
        for existing in &self.regions {
            if region.overlaps(existing) {
                return Err(KernelError::InvalidParameters {
                    description: format!(
                        "Memory region overlap: [{}, {}) overlaps [{}, {})",
                        region.offset,
                        region.end(),
                        existing.offset,
                        existing.end()
                    ),
                });
            }
        }
        self.regions.push(region);
        Ok(())
    }
    /// Validate a descriptor's memory layout
    ///
    /// Checks that every non-empty region fits within `total_memory` and
    /// that the output region does not alias either input region.
    /// (Overlap of scratch/params with inputs is intentionally not checked.)
    pub fn validate_descriptor(
        &self,
        desc: &KernelDescriptor,
        total_memory: usize,
    ) -> Result<(), KernelError> {
        // Check all regions are within bounds. The name column is kept for
        // readability of the table (`_name` silences the previous
        // unused-variable warning); widen to u64 so `offset + size`
        // cannot overflow even on 32-bit hosts.
        let regions = [
            ("input_a", desc.input_a_offset, desc.input_a_size),
            ("input_b", desc.input_b_offset, desc.input_b_size),
            ("output", desc.output_offset, desc.output_size),
            ("scratch", desc.scratch_offset, desc.scratch_size),
            ("params", desc.params_offset, desc.params_size),
        ];
        for (_name, offset, size) in regions {
            if size > 0 {
                let end = u64::from(offset) + u64::from(size);
                if end > total_memory as u64 {
                    return Err(KernelError::MemoryAccessViolation { offset, size });
                }
            }
        }
        // Check for overlaps between output and inputs
        let output = MemoryRegion::writable(desc.output_offset, desc.output_size);
        if desc.input_a_size > 0 {
            let input_a = MemoryRegion::read_only(desc.input_a_offset, desc.input_a_size);
            if output.overlaps(&input_a) {
                return Err(KernelError::InvalidParameters {
                    description: "Output overlaps with input_a".to_string(),
                });
            }
        }
        if desc.input_b_size > 0 {
            let input_b = MemoryRegion::read_only(desc.input_b_offset, desc.input_b_size);
            if output.overlaps(&input_b) {
                return Err(KernelError::InvalidParameters {
                    description: "Output overlaps with input_b".to_string(),
                });
            }
        }
        Ok(())
    }
    /// Clear all registered regions
    pub fn clear(&mut self) {
        self.regions.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_memory_protocol() {
        let mut protocol = SharedMemoryProtocol::new(1, 16); // 1 page = 64KB
        let offset1 = protocol.allocate(1024).unwrap();
        assert_eq!(offset1, 0);
        let offset2 = protocol.allocate(2048).unwrap();
        assert!(offset2 >= 1024);
        assert_eq!(offset2 % 16, 0); // Aligned
        assert!(protocol.remaining() < PAGE_SIZE);
    }
    #[test]
    fn test_allocation_failure() {
        let mut protocol = SharedMemoryProtocol::new(1, 16);
        // Try to allocate more than available
        let result = protocol.allocate(PAGE_SIZE + 1);
        assert!(matches!(result, Err(KernelError::AllocationFailed { .. })));
    }
    #[test]
    fn test_invocation_descriptor() {
        let mut desc = KernelInvocationDescriptor::new(4); // 4 pages
        desc.allocate_input_a(1024).unwrap();
        desc.allocate_input_b(1024).unwrap();
        desc.allocate_output(1024).unwrap();
        desc.allocate_scratch(512).unwrap();
        desc.allocate_params(64).unwrap();
        assert!(desc.total_allocated() > 3600); // With alignment
        assert_eq!(desc.descriptor.input_a_size, 1024);
    }
    #[test]
    fn test_tensor_size() {
        let shape = [1, 512, 32, 128]; // batch, seq, heads, dim
        let size = tensor_size_bytes(&shape, DataType::F32);
        assert_eq!(size, 1 * 512 * 32 * 128 * 4); // 8MB
    }
    #[test]
    fn test_required_pages() {
        // Round-up semantics: zero stays zero, exact multiples don't
        // gain an extra page.
        assert_eq!(required_pages(0), 0);
        assert_eq!(required_pages(1), 1);
        assert_eq!(required_pages(PAGE_SIZE), 1);
        assert_eq!(required_pages(PAGE_SIZE + 1), 2);
    }
    #[test]
    fn test_memory_region_overlap() {
        // Half-open intervals: touching regions do not overlap.
        let r1 = MemoryRegion::new(0, 100, false);
        let r2 = MemoryRegion::new(50, 100, false);
        let r3 = MemoryRegion::new(100, 100, false);
        assert!(r1.overlaps(&r2));
        assert!(!r1.overlaps(&r3));
    }
    #[test]
    fn test_layout_validator() {
        let mut validator = MemoryLayoutValidator::new();
        // Add non-overlapping regions
        validator
            .add_region(MemoryRegion::new(0, 100, false))
            .unwrap();
        validator
            .add_region(MemoryRegion::new(100, 100, false))
            .unwrap();
        // Try to add overlapping region
        let result = validator.add_region(MemoryRegion::new(50, 100, false));
        assert!(result.is_err());
    }
    #[test]
    fn test_validate_descriptor() {
        let validator = MemoryLayoutValidator::new();
        let mut desc = KernelDescriptor::new();
        desc.input_a_offset = 0;
        desc.input_a_size = 1024;
        desc.output_offset = 1024;
        desc.output_size = 1024;
        // Should pass - no overlap
        assert!(validator.validate_descriptor(&desc, PAGE_SIZE).is_ok());
        // Should fail - output overlaps input
        desc.output_offset = 512;
        assert!(validator.validate_descriptor(&desc, PAGE_SIZE).is_err());
    }
    #[test]
    fn test_validate_bounds() {
        let validator = MemoryLayoutValidator::new();
        let mut desc = KernelDescriptor::new();
        desc.input_a_offset = 0;
        desc.input_a_size = PAGE_SIZE as u32 + 1; // Too big
        let result = validator.validate_descriptor(&desc, PAGE_SIZE);
        assert!(matches!(
            result,
            Err(KernelError::MemoryAccessViolation { .. })
        ));
    }
}

View File

@@ -0,0 +1,71 @@
//! WASM Kernel Pack System (ADR-005)
//!
//! This module implements the WebAssembly kernel pack infrastructure for
//! secure, sandboxed execution of ML compute kernels.
//!
//! # Architecture
//!
//! The kernel pack system provides:
//! - **Sandboxed Execution**: Wasmtime runtime with epoch-based interruption
//! - **Supply Chain Security**: Ed25519 signatures, SHA256 hash verification
//! - **Hot-Swappable Kernels**: Update kernels without service restart
//! - **Cross-Platform**: Same kernels run on servers and embedded devices
//!
//! # Kernel Categories
//!
//! - Positional: RoPE (Rotary Position Embeddings)
//! - Normalization: RMSNorm
//! - Activation: SwiGLU
//! - KV Cache: Quantization/Dequantization
//! - Adapter: LoRA delta application
//!
//! # Example
//!
//! ```rust,ignore
//! use ruvector_wasm::kernel::{KernelManager, KernelPackVerifier};
//!
//! // Load and verify kernel pack
//! let verifier = KernelPackVerifier::with_trusted_keys(keys);
//! let manager = KernelManager::new(runtime_config)?;
//! manager.load_pack("kernel-pack-v1.0.0", &verifier)?;
//!
//! // Execute kernel
//! let result = manager.execute("rope_f32", &descriptor)?;
//! ```
pub mod allowlist;
pub mod epoch;
pub mod error;
pub mod hash;
pub mod manifest;
pub mod memory;
pub mod runtime;
pub mod signature;
// Re-exports
pub use allowlist::TrustedKernelAllowlist;
pub use epoch::{EpochConfig, EpochController};
pub use error::{KernelError, VerifyError};
pub use hash::HashVerifier;
pub use manifest::{
KernelCategory, KernelDescriptor, KernelInfo, KernelManifest, KernelParam, PlatformConfig,
ResourceLimits, TensorSpec,
};
pub use memory::{KernelInvocationDescriptor, SharedMemoryProtocol};
pub use runtime::{KernelRuntime, RuntimeConfig, WasmKernelInstance};
pub use signature::KernelPackVerifier;
/// Current runtime version for compatibility checking
/// (embedded from the crate's Cargo.toml at compile time)
pub const RUNTIME_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Maximum supported kernel manifest schema version
pub const MAX_MANIFEST_VERSION: &str = "1.0.0";
/// WASM page size in bytes (64KB)
pub const WASM_PAGE_SIZE: usize = 65536;
/// Default epoch tick interval in milliseconds
pub const DEFAULT_EPOCH_TICK_MS: u64 = 10;
/// Default epoch budget (ticks before interruption;
/// 1000 ticks x 10ms/tick = roughly 10 seconds wall clock)
pub const DEFAULT_EPOCH_BUDGET: u64 = 1000;

View File

@@ -0,0 +1,575 @@
//! Wasmtime Runtime Integration
//!
//! Provides the runtime traits and implementations for executing
//! WASM kernels with Wasmtime.
use crate::kernel::epoch::{EpochConfig, EpochController, EpochDeadline};
use crate::kernel::error::{KernelError, KernelErrorCode, KernelResult};
use crate::kernel::manifest::{KernelDescriptor, KernelInfo, KernelManifest, ResourceLimits};
use crate::kernel::memory::{MemoryLayoutValidator, SharedMemoryProtocol, PAGE_SIZE};
use std::collections::HashMap;
use std::sync::Arc;
/// Runtime configuration for WASM kernel execution
///
/// Construct via [`RuntimeConfig::server`], [`RuntimeConfig::embedded`],
/// or [`RuntimeConfig::development`] for profile-appropriate defaults.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Epoch configuration
    pub epoch: EpochConfig,
    /// Enable SIMD support
    pub enable_simd: bool,
    /// Enable bulk memory operations
    pub enable_bulk_memory: bool,
    /// Enable multi-value returns
    pub enable_multi_value: bool,
    /// Maximum memory pages per instance (64KB each)
    pub max_memory_pages: u32,
    /// Enable parallel compilation
    pub parallel_compilation: bool,
    /// Optimization level (0-3, where 0=none, 3=maximum)
    pub optimization_level: u8,
    /// Enable instance pooling for reuse
    pub enable_instance_pooling: bool,
    /// Pool size for instance reuse (ignored when pooling is disabled)
    pub instance_pool_size: usize,
}
impl RuntimeConfig {
    /// Create configuration for server workloads
    ///
    /// Tuned for throughput: full feature set, maximum optimization,
    /// large memory ceiling, and instance pooling enabled.
    pub fn server() -> Self {
        RuntimeConfig {
            epoch: EpochConfig::server(),
            max_memory_pages: 1024, // 64MB max
            optimization_level: 3,
            enable_simd: true,
            enable_bulk_memory: true,
            enable_multi_value: true,
            parallel_compilation: true,
            enable_instance_pooling: true,
            instance_pool_size: 16,
        }
    }
    /// Create configuration for embedded/constrained workloads
    ///
    /// Conservative settings: SIMD disabled (often unavailable on small
    /// targets), 4MB memory cap, no pooling, moderate optimization.
    pub fn embedded() -> Self {
        RuntimeConfig {
            epoch: EpochConfig::embedded(),
            max_memory_pages: 64, // 4MB max
            optimization_level: 2,
            enable_simd: false, // Often unavailable
            enable_bulk_memory: true,
            enable_multi_value: true,
            parallel_compilation: false,
            enable_instance_pooling: false,
            instance_pool_size: 0,
        }
    }
    /// Create configuration for development/debugging
    ///
    /// Epoch interruption disabled and optimization off for the fastest
    /// possible compile/run cycle.
    pub fn development() -> Self {
        RuntimeConfig {
            epoch: EpochConfig::disabled(),
            max_memory_pages: 1024,
            optimization_level: 0, // Fast compilation
            enable_simd: true,
            enable_bulk_memory: true,
            enable_multi_value: true,
            parallel_compilation: true,
            enable_instance_pooling: false,
            instance_pool_size: 0,
        }
    }
}
impl Default for RuntimeConfig {
fn default() -> Self {
Self::server()
}
}
/// Compiled WASM kernel module
///
/// Output of `KernelRuntime::compile_kernel`; holds the compiled artifact
/// plus the metadata needed to instantiate it later.
#[derive(Debug)]
pub struct CompiledKernel {
    /// Kernel ID
    pub id: String,
    /// Kernel info from manifest
    pub info: KernelInfo,
    /// Compiled module bytes (for caching)
    pub compiled_bytes: Vec<u8>,
    /// Whether module uses SIMD
    pub uses_simd: bool,
    /// Required memory pages (64KB each)
    pub required_pages: u32,
}
/// WASM kernel instance ready for execution
///
/// Holds per-invocation state: allocated memory budget, optional epoch
/// deadline, and a layout validator for descriptor safety checks.
pub struct WasmKernelInstance {
    /// Kernel ID
    kernel_id: String,
    /// Memory allocated for this instance (in 64KB pages)
    memory_pages: u32,
    /// Epoch deadline for this invocation (None = no deadline)
    deadline: Option<EpochDeadline>,
    /// Memory validator (not shown in the manual Debug impl)
    validator: MemoryLayoutValidator,
}
impl WasmKernelInstance {
    /// Build an instance for `kernel_id` backed by `memory_pages`
    /// linear-memory pages, with no deadline attached.
    pub fn new(kernel_id: String, memory_pages: u32) -> Self {
        Self {
            kernel_id,
            memory_pages,
            deadline: None,
            validator: MemoryLayoutValidator::new(),
        }
    }

    /// Attach an execution deadline to the next invocation.
    pub fn set_deadline(&mut self, deadline: EpochDeadline) {
        self.deadline = Some(deadline);
    }

    /// The kernel this instance was created from.
    pub fn kernel_id(&self) -> &str {
        self.kernel_id.as_str()
    }

    /// Number of linear-memory pages allocated to this instance.
    pub fn memory_pages(&self) -> u32 {
        self.memory_pages
    }

    /// Allocated memory size in bytes (`pages * PAGE_SIZE`).
    pub fn memory_size(&self) -> usize {
        PAGE_SIZE * self.memory_pages as usize
    }

    /// Check a descriptor's layout against this instance's memory bounds.
    pub fn validate_descriptor(&self, desc: &KernelDescriptor) -> KernelResult<()> {
        self.validator.validate_descriptor(desc, self.memory_size())
    }

    /// True when a deadline is set and the controller's current epoch has
    /// passed it; always false when no deadline was attached.
    pub fn check_deadline(&self, controller: &EpochController) -> bool {
        self.deadline
            .as_ref()
            .map_or(false, |d| d.is_exceeded(controller.current()))
    }
}
// Manual Debug impl — presumably because MemoryLayoutValidator does not
// implement Debug; the validator field is intentionally omitted.
impl std::fmt::Debug for WasmKernelInstance {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WasmKernelInstance")
            .field("kernel_id", &self.kernel_id)
            .field("memory_pages", &self.memory_pages)
            .field("deadline", &self.deadline)
            .finish()
    }
}
/// Trait for kernel runtime implementations
///
/// Implementors provide compile/instantiate/execute; `tick` has a default
/// implementation that advances the shared epoch controller.
pub trait KernelRuntime: Send + Sync {
    /// Load and compile a kernel from WASM bytes
    fn compile_kernel(
        &self,
        id: &str,
        wasm_bytes: &[u8],
        info: &KernelInfo,
    ) -> KernelResult<CompiledKernel>;
    /// Create an instance of a compiled kernel
    fn instantiate(&self, kernel: &CompiledKernel) -> KernelResult<WasmKernelInstance>;
    /// Execute a kernel with the given descriptor
    ///
    /// `memory` is the linear memory the descriptor's offsets refer to.
    fn execute(
        &self,
        instance: &mut WasmKernelInstance,
        descriptor: &KernelDescriptor,
        memory: &mut [u8],
    ) -> KernelResult<()>;
    /// Get runtime configuration
    fn config(&self) -> &RuntimeConfig;
    /// Get epoch controller
    fn epoch_controller(&self) -> &EpochController;
    /// Increment epoch (should be called periodically)
    fn tick(&self) {
        self.epoch_controller().increment();
    }
}
/// Mock runtime for testing without Wasmtime dependency
///
/// Behaviors are registered per kernel ID; unregistered kernels succeed.
#[derive(Debug)]
pub struct MockKernelRuntime {
    config: RuntimeConfig,
    epoch_controller: EpochController,
    /// Registered kernel behaviors for testing
    kernel_behaviors: HashMap<String, MockKernelBehavior>,
}
/// Mock kernel behavior for testing
#[derive(Debug, Clone)]
pub enum MockKernelBehavior {
    /// Always succeed
    Success,
    /// Always fail with error code
    Fail(KernelErrorCode),
    /// Timeout (exceed epoch)
    Timeout,
    /// Return specific output data (copied into the descriptor's output
    /// region when it fits in memory)
    ReturnData(Vec<u8>),
}
impl MockKernelRuntime {
    /// Build a mock runtime whose epoch controller ticks at the interval
    /// from `config.epoch`.
    pub fn new(config: RuntimeConfig) -> Self {
        let epoch_controller = EpochController::new(config.epoch.tick_interval());
        Self {
            epoch_controller,
            config,
            kernel_behaviors: HashMap::new(),
        }
    }

    /// Register the behavior `execute` should simulate for `kernel_id`.
    pub fn register_behavior(&mut self, kernel_id: &str, behavior: MockKernelBehavior) {
        self.kernel_behaviors.insert(kernel_id.to_owned(), behavior);
    }
}
impl KernelRuntime for MockKernelRuntime {
    fn compile_kernel(
        &self,
        id: &str,
        _wasm_bytes: &[u8],
        info: &KernelInfo,
    ) -> KernelResult<CompiledKernel> {
        // No real compilation happens; echo the manifest info back.
        Ok(CompiledKernel {
            id: id.to_owned(),
            info: info.clone(),
            compiled_bytes: Vec::new(),
            uses_simd: false,
            required_pages: info.resource_limits.max_memory_pages,
        })
    }

    fn instantiate(&self, kernel: &CompiledKernel) -> KernelResult<WasmKernelInstance> {
        let id = kernel.id.clone();
        Ok(WasmKernelInstance::new(id, kernel.required_pages))
    }

    fn execute(
        &self,
        instance: &mut WasmKernelInstance,
        descriptor: &KernelDescriptor,
        memory: &mut [u8],
    ) -> KernelResult<()> {
        // Same preflight as a real runtime: layout check, then deadline check.
        instance.validate_descriptor(descriptor)?;
        if instance.check_deadline(&self.epoch_controller) {
            return Err(KernelError::EpochDeadline);
        }

        // Unregistered kernels default to success.
        let behavior = match self.kernel_behaviors.get(instance.kernel_id()) {
            Some(b) => b.clone(),
            None => MockKernelBehavior::Success,
        };

        match behavior {
            MockKernelBehavior::Success => Ok(()),
            MockKernelBehavior::Fail(code) => Err(KernelError::KernelTrap {
                code: code as u32,
                message: Some(code.to_string()),
            }),
            MockKernelBehavior::Timeout => Err(KernelError::EpochDeadline),
            MockKernelBehavior::ReturnData(data) => {
                // Copy min(output_size, data.len()) bytes to output_offset;
                // a copy that would run past `memory` is silently skipped,
                // matching a best-effort mock write.
                let start = descriptor.output_offset as usize;
                let len = descriptor.output_size.min(data.len() as u32) as usize;
                if start + len <= memory.len() {
                    memory[start..start + len].copy_from_slice(&data[..len]);
                }
                Ok(())
            }
        }
    }

    fn config(&self) -> &RuntimeConfig {
        &self.config
    }

    fn epoch_controller(&self) -> &EpochController {
        &self.epoch_controller
    }
}
/// Kernel manager for loading and executing kernel packs
///
/// Generic over the runtime `R` so tests can plug in `MockKernelRuntime`.
pub struct KernelManager<R: KernelRuntime> {
    /// Runtime implementation
    runtime: Arc<R>,
    /// Loaded manifests
    manifests: HashMap<String, KernelManifest>,
    /// Compiled kernels
    compiled_kernels: HashMap<String, CompiledKernel>,
    /// Active kernel pack (set by `set_active_pack`; not yet consulted
    /// during execution in the methods visible here)
    active_pack: Option<String>,
}
impl<R: KernelRuntime> KernelManager<R> {
    /// Build an empty manager around the given runtime.
    pub fn new(runtime: Arc<R>) -> Self {
        Self {
            runtime,
            manifests: HashMap::new(),
            compiled_kernels: HashMap::new(),
            active_pack: None,
        }
    }

    /// Register a kernel pack manifest under `pack_name`.
    pub fn load_manifest(&mut self, pack_name: &str, manifest: KernelManifest) {
        self.manifests.insert(pack_name.to_owned(), manifest);
    }

    /// Compile `kernel_id` from the pack registered as `pack_name` and
    /// cache the result for later execution.
    pub fn compile_kernel(
        &mut self,
        pack_name: &str,
        kernel_id: &str,
        wasm_bytes: &[u8],
    ) -> KernelResult<()> {
        let manifest = match self.manifests.get(pack_name) {
            Some(m) => m,
            None => {
                return Err(KernelError::KernelNotFound {
                    kernel_id: format!("pack:{}", pack_name),
                })
            }
        };
        let info = match manifest.get_kernel(kernel_id) {
            Some(i) => i,
            None => {
                return Err(KernelError::KernelNotFound {
                    kernel_id: kernel_id.to_string(),
                })
            }
        };
        let compiled = self.runtime.compile_kernel(kernel_id, wasm_bytes, info)?;
        self.compiled_kernels.insert(kernel_id.to_owned(), compiled);
        Ok(())
    }

    /// Mark a previously loaded pack as the active one.
    pub fn set_active_pack(&mut self, pack_name: &str) -> KernelResult<()> {
        if !self.manifests.contains_key(pack_name) {
            return Err(KernelError::KernelNotFound {
                kernel_id: format!("pack:{}", pack_name),
            });
        }
        self.active_pack = Some(pack_name.to_owned());
        Ok(())
    }

    /// Instantiate a compiled kernel and run it against `memory`.
    pub fn execute(
        &self,
        kernel_id: &str,
        descriptor: &KernelDescriptor,
        memory: &mut [u8],
    ) -> KernelResult<()> {
        let kernel = self
            .compiled_kernels
            .get(kernel_id)
            .ok_or_else(|| KernelError::KernelNotFound {
                kernel_id: kernel_id.to_string(),
            })?;
        let mut instance = self.runtime.instantiate(kernel)?;
        // Give the instance an epoch budget when interruption is enabled.
        if self.runtime.config().epoch.enabled {
            let ticks = kernel.info.resource_limits.max_epoch_ticks;
            let now = self.runtime.epoch_controller().current();
            instance.set_deadline(EpochDeadline::new(now, ticks));
        }
        self.runtime.execute(&mut instance, descriptor, memory)
    }

    /// Manifest info for a compiled kernel, if present.
    pub fn get_kernel_info(&self, kernel_id: &str) -> Option<&KernelInfo> {
        self.compiled_kernels.get(kernel_id).map(|k| &k.info)
    }

    /// IDs of all kernels compiled so far.
    pub fn list_kernels(&self) -> Vec<&str> {
        self.compiled_kernels.keys().map(String::as_str).collect()
    }

    /// Borrow the underlying runtime.
    pub fn runtime(&self) -> &R {
        &self.runtime
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::manifest::{DataType, KernelCategory, ResourceLimits, ShapeDim, TensorSpec};
    /// Build a minimal one-input/one-output KernelInfo for the tests below.
    fn mock_kernel_info(id: &str) -> KernelInfo {
        KernelInfo {
            id: id.to_string(),
            name: format!("Test {}", id),
            category: KernelCategory::Custom,
            path: format!("{}.wasm", id),
            hash: "sha256:test".to_string(),
            entry_point: "kernel_forward".to_string(),
            inputs: vec![TensorSpec {
                name: "x".to_string(),
                dtype: DataType::F32,
                shape: vec![ShapeDim::Symbolic("batch".to_string())],
            }],
            outputs: vec![TensorSpec {
                name: "y".to_string(),
                dtype: DataType::F32,
                shape: vec![ShapeDim::Symbolic("batch".to_string())],
            }],
            params: HashMap::new(),
            resource_limits: ResourceLimits::default(),
            platforms: HashMap::new(),
            benchmarks: HashMap::new(),
        }
    }
    // The three presets should differ in SIMD, parallelism, and opt level.
    #[test]
    fn test_runtime_config() {
        let server = RuntimeConfig::server();
        assert!(server.enable_simd);
        assert_eq!(server.optimization_level, 3);
        let embedded = RuntimeConfig::embedded();
        assert!(!embedded.enable_simd);
        assert!(!embedded.parallel_compilation);
        let dev = RuntimeConfig::development();
        assert_eq!(dev.optimization_level, 0);
    }
    // Full compile -> instantiate -> execute round trip on the mock runtime.
    #[test]
    fn test_mock_runtime() {
        let mut runtime = MockKernelRuntime::new(RuntimeConfig::default());
        // Test success behavior
        runtime.register_behavior("test_kernel", MockKernelBehavior::Success);
        let info = mock_kernel_info("test_kernel");
        let compiled = runtime.compile_kernel("test_kernel", &[], &info).unwrap();
        let mut instance = runtime.instantiate(&compiled).unwrap();
        let mut desc = KernelDescriptor::new();
        desc.input_a_offset = 0;
        desc.input_a_size = 1024;
        desc.output_offset = 1024;
        desc.output_size = 1024;
        let mut memory = vec![0u8; 65536];
        let result = runtime.execute(&mut instance, &desc, &mut memory);
        assert!(result.is_ok());
    }
    // A registered Fail behavior must surface as a KernelTrap error.
    #[test]
    fn test_mock_runtime_failure() {
        let mut runtime = MockKernelRuntime::new(RuntimeConfig::default());
        runtime.register_behavior(
            "failing_kernel",
            MockKernelBehavior::Fail(KernelErrorCode::InvalidInput),
        );
        let info = mock_kernel_info("failing_kernel");
        let compiled = runtime
            .compile_kernel("failing_kernel", &[], &info)
            .unwrap();
        let mut instance = runtime.instantiate(&compiled).unwrap();
        let desc = KernelDescriptor::new();
        let mut memory = vec![0u8; 65536];
        let result = runtime.execute(&mut instance, &desc, &mut memory);
        assert!(matches!(result, Err(KernelError::KernelTrap { .. })));
    }
    // Accessors plus deadline semantics: a budget of 100 ticks from epoch 0
    // is exceeded once the controller has been incremented 100 times.
    #[test]
    fn test_wasm_kernel_instance() {
        let mut instance = WasmKernelInstance::new("test".to_string(), 256);
        assert_eq!(instance.kernel_id(), "test");
        assert_eq!(instance.memory_pages(), 256);
        assert_eq!(instance.memory_size(), 256 * PAGE_SIZE);
        // Test deadline
        let controller = EpochController::default_interval();
        let deadline = EpochDeadline::new(0, 100);
        instance.set_deadline(deadline);
        assert!(!instance.check_deadline(&controller));
        // Exceed deadline
        for _ in 0..100 {
            controller.increment();
        }
        assert!(instance.check_deadline(&controller));
    }
    // Manifest loading, pack activation, and compilation through the manager.
    #[test]
    fn test_kernel_manager() {
        let runtime = Arc::new(MockKernelRuntime::new(RuntimeConfig::default()));
        let mut manager = KernelManager::new(runtime);
        // Create a minimal manifest
        let manifest = KernelManifest {
            schema: String::new(),
            version: "1.0.0".to_string(),
            name: "test-pack".to_string(),
            description: "Test".to_string(),
            min_runtime_version: "0.1.0".to_string(),
            max_runtime_version: "1.0.0".to_string(),
            created_at: "2026-01-18T00:00:00Z".to_string(),
            author: crate::kernel::manifest::AuthorInfo {
                name: "Test".to_string(),
                email: "test@test.com".to_string(),
                signing_key: "test".to_string(),
            },
            kernels: vec![mock_kernel_info("rope_f32")],
            fallbacks: HashMap::new(),
        };
        manager.load_manifest("test-pack", manifest);
        manager.set_active_pack("test-pack").unwrap();
        // Compile kernel
        manager
            .compile_kernel("test-pack", "rope_f32", &[])
            .unwrap();
        assert_eq!(manager.list_kernels(), vec!["rope_f32"]);
    }
}

View File

@@ -0,0 +1,288 @@
//! Ed25519 Signature Verification
//!
//! Provides cryptographic signature verification for kernel pack manifests
//! to ensure supply chain security.
use crate::kernel::error::VerifyError;
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
/// Kernel pack signature verifier
///
/// Maintains a list of trusted Ed25519 public keys and verifies
/// manifest signatures against them. A signature is accepted if ANY
/// trusted key verifies it.
#[derive(Debug, Clone)]
pub struct KernelPackVerifier {
    /// Trusted Ed25519 public keys
    trusted_keys: Vec<VerifyingKey>,
    /// Whether to require signatures (can be disabled for development)
    require_signature: bool,
}
impl KernelPackVerifier {
    /// Create a new verifier with no trusted keys
    pub fn new() -> Self {
        KernelPackVerifier {
            trusted_keys: Vec::new(),
            require_signature: true,
        }
    }
    /// Create a verifier with pre-loaded trusted keys
    pub fn with_trusted_keys(keys: Vec<VerifyingKey>) -> Self {
        KernelPackVerifier {
            trusted_keys: keys,
            require_signature: true,
        }
    }
    /// Create a verifier that doesn't require signatures (for development)
    ///
    /// # Warning
    /// This should NEVER be used in production as it bypasses security checks.
    pub fn insecure_no_verify() -> Self {
        KernelPackVerifier {
            trusted_keys: Vec::new(),
            require_signature: false,
        }
    }
    /// Add a trusted public key from bytes
    pub fn add_trusted_key(&mut self, key_bytes: &[u8; 32]) -> Result<(), VerifyError> {
        let key = VerifyingKey::from_bytes(key_bytes).map_err(|e| VerifyError::KeyError {
            message: e.to_string(),
        })?;
        self.trusted_keys.push(key);
        Ok(())
    }
    /// Shared tail of the hex/base64 key loaders: enforce the 32-byte
    /// Ed25519 public-key length, then register the key.
    /// (Extracted to remove the duplicated length-check/copy logic that
    /// previously appeared in both `add_trusted_key_hex` and
    /// `add_trusted_key_base64`.)
    fn add_trusted_key_bytes(&mut self, bytes: &[u8]) -> Result<(), VerifyError> {
        if bytes.len() != 32 {
            return Err(VerifyError::KeyError {
                message: format!("Invalid key length: expected 32 bytes, got {}", bytes.len()),
            });
        }
        let mut key_bytes = [0u8; 32];
        key_bytes.copy_from_slice(bytes);
        self.add_trusted_key(&key_bytes)
    }
    /// Add a trusted public key from hex string
    pub fn add_trusted_key_hex(&mut self, hex: &str) -> Result<(), VerifyError> {
        // Remove "ed25519:" prefix if present
        let hex = hex.strip_prefix("ed25519:").unwrap_or(hex);
        let bytes = hex::decode(hex).map_err(|e| VerifyError::KeyError {
            message: format!("Invalid hex: {}", e),
        })?;
        self.add_trusted_key_bytes(&bytes)
    }
    /// Add a trusted public key from base64 string
    pub fn add_trusted_key_base64(&mut self, b64: &str) -> Result<(), VerifyError> {
        // Remove "ed25519:" prefix if present
        let b64 = b64.strip_prefix("ed25519:").unwrap_or(b64);
        use base64::{engine::general_purpose::STANDARD, Engine};
        let bytes = STANDARD.decode(b64).map_err(|e| VerifyError::KeyError {
            message: format!("Invalid base64: {}", e),
        })?;
        self.add_trusted_key_bytes(&bytes)
    }
    /// Verify manifest signature against trusted keys
    ///
    /// # Arguments
    /// * `manifest` - The manifest bytes to verify
    /// * `signature` - The signature bytes (64 bytes)
    ///
    /// # Returns
    /// * `Ok(())` if signature is valid and from a trusted key
    /// * `Err(VerifyError::NoTrustedKey)` if no trusted key verified the signature
    pub fn verify(&self, manifest: &[u8], signature: &[u8]) -> Result<(), VerifyError> {
        // Skip verification if disabled (development mode)
        if !self.require_signature {
            return Ok(());
        }
        // Check we have trusted keys
        if self.trusted_keys.is_empty() {
            return Err(VerifyError::NoTrustedKey);
        }
        // Parse signature
        let sig = Signature::from_slice(signature).map_err(|e| VerifyError::InvalidSignature {
            reason: format!("Invalid signature format: {}", e),
        })?;
        // Try each trusted key; any single match is sufficient.
        for key in &self.trusted_keys {
            if key.verify(manifest, &sig).is_ok() {
                return Ok(());
            }
        }
        Err(VerifyError::NoTrustedKey)
    }
    /// Verify manifest with signature from hex string
    pub fn verify_hex(&self, manifest: &[u8], signature_hex: &str) -> Result<(), VerifyError> {
        let signature = hex::decode(signature_hex).map_err(|e| VerifyError::InvalidSignature {
            reason: format!("Invalid hex signature: {}", e),
        })?;
        self.verify(manifest, &signature)
    }
    /// Verify manifest with signature from base64 string
    pub fn verify_base64(&self, manifest: &[u8], signature_b64: &str) -> Result<(), VerifyError> {
        use base64::{engine::general_purpose::STANDARD, Engine};
        let signature =
            STANDARD
                .decode(signature_b64)
                .map_err(|e| VerifyError::InvalidSignature {
                    reason: format!("Invalid base64 signature: {}", e),
                })?;
        self.verify(manifest, &signature)
    }
    /// Get number of trusted keys
    pub fn trusted_key_count(&self) -> usize {
        self.trusted_keys.len()
    }
    /// Check if signature verification is required
    pub fn is_verification_required(&self) -> bool {
        self.require_signature
    }
}
impl Default for KernelPackVerifier {
    /// Equivalent to [`KernelPackVerifier::new`]: no trusted keys,
    /// signatures required.
    fn default() -> Self {
        Self::new()
    }
}
/// Utility function to sign a manifest (for kernel pack creation)
///
/// Only compiled with the `signing` feature. Returns the Ed25519
/// signature bytes (64 bytes) as a `Vec<u8>`.
#[cfg(feature = "signing")]
pub fn sign_manifest(manifest: &[u8], signing_key: &ed25519_dalek::SigningKey) -> Vec<u8> {
    use ed25519_dalek::Signer;
    signing_key.sign(manifest).to_bytes().to_vec()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;
    /// Deterministic key pair derived from `salt`, so tests get
    /// reproducible but *distinct* keys for different salts.
    fn generate_key_pair_salted(salt: u8) -> (SigningKey, VerifyingKey) {
        let mut seed = [0u8; 32];
        for (i, b) in seed.iter_mut().enumerate() {
            *b = ((i * 7 + 13) as u8) ^ salt;
        }
        let signing_key = SigningKey::from_bytes(&seed);
        let verifying_key = signing_key.verifying_key();
        (signing_key, verifying_key)
    }
    /// Default deterministic test key pair.
    fn generate_key_pair() -> (SigningKey, VerifyingKey) {
        generate_key_pair_salted(0)
    }
    #[test]
    fn test_verify_success() {
        use ed25519_dalek::Signer;
        let (signing_key, verifying_key) = generate_key_pair();
        let manifest = b"test manifest content";
        let signature = signing_key.sign(manifest);
        let verifier = KernelPackVerifier::with_trusted_keys(vec![verifying_key]);
        assert!(verifier.verify(manifest, &signature.to_bytes()).is_ok());
    }
    #[test]
    fn test_verify_wrong_key() {
        use ed25519_dalek::Signer;
        // Bug fix: the original derived BOTH pairs from the same fixed seed,
        // so the "wrong" verifying key was identical to the signing key and
        // verification actually succeeded, contradicting the assertion below.
        // Use a different salt so the second key really is different.
        let (signing_key, _) = generate_key_pair();
        let (_, wrong_verifying_key) = generate_key_pair_salted(0x55);
        let manifest = b"test manifest content";
        let signature = signing_key.sign(manifest);
        let verifier = KernelPackVerifier::with_trusted_keys(vec![wrong_verifying_key]);
        let result = verifier.verify(manifest, &signature.to_bytes());
        assert!(matches!(result, Err(VerifyError::NoTrustedKey)));
    }
    #[test]
    fn test_verify_no_keys() {
        let verifier = KernelPackVerifier::new();
        let manifest = b"test manifest";
        let signature = [0u8; 64];
        let result = verifier.verify(manifest, &signature);
        assert!(matches!(result, Err(VerifyError::NoTrustedKey)));
    }
    #[test]
    fn test_insecure_no_verify() {
        let verifier = KernelPackVerifier::insecure_no_verify();
        let manifest = b"test manifest";
        let invalid_signature = [0u8; 64];
        // Should pass even with invalid signature
        assert!(verifier.verify(manifest, &invalid_signature).is_ok());
        assert!(!verifier.is_verification_required());
    }
    #[test]
    fn test_add_key_hex() {
        let mut verifier = KernelPackVerifier::new();
        // Valid 32-byte key in hex
        let hex_key = "0000000000000000000000000000000000000000000000000000000000000000";
        // Note: This is a degenerate key but tests the parsing
        let result = verifier.add_trusted_key_hex(hex_key);
        // This specific key may or may not be valid depending on curve requirements
        // The important thing is that hex parsing works
        assert!(result.is_ok() || matches!(result, Err(VerifyError::KeyError { .. })));
    }
    #[test]
    fn test_add_key_with_prefix() {
        let mut verifier = KernelPackVerifier::new();
        // Key with ed25519: prefix
        let prefixed_key =
            "ed25519:0000000000000000000000000000000000000000000000000000000000000000";
        let _ = verifier.add_trusted_key_hex(prefixed_key);
        // Just testing that prefix stripping works
    }
    #[test]
    fn test_invalid_hex() {
        let mut verifier = KernelPackVerifier::new();
        let invalid = "not_valid_hex";
        let result = verifier.add_trusted_key_hex(invalid);
        assert!(matches!(result, Err(VerifyError::KeyError { .. })));
    }
    #[test]
    fn test_wrong_key_length() {
        let mut verifier = KernelPackVerifier::new();
        let short_key = "0000000000000000"; // 8 bytes
        let result = verifier.add_trusted_key_hex(short_key);
        assert!(matches!(result, Err(VerifyError::KeyError { .. })));
    }
}

View File

@@ -0,0 +1,896 @@
//! WASM bindings for Ruvector
//!
//! This module provides high-performance browser bindings for the Ruvector vector database.
//! Features:
//! - Full VectorDB API (insert, search, delete, batch operations)
//! - SIMD acceleration (when available)
//! - Web Workers support for parallel operations
//! - IndexedDB persistence
//! - Zero-copy transfers via transferable objects
//!
//! # Kernel Pack System (ADR-005)
//!
//! When compiled with the `kernel-pack` feature, this crate also provides the WASM
//! kernel pack infrastructure for secure, sandboxed execution of ML compute kernels.
//!
//! ```toml
//! [dependencies]
//! ruvector-wasm = { version = "0.1", features = ["kernel-pack"] }
//! ```
//!
//! The kernel pack system includes:
//! - Manifest parsing and validation
//! - Ed25519 signature verification
//! - SHA256 hash verification
//! - Trusted kernel allowlist
//! - Epoch-based execution budgets
//! - Shared memory protocol for tensor data
// Kernel pack module (ADR-005)
#[cfg(feature = "kernel-pack")]
pub mod kernel;
use js_sys::{Array, Float32Array, Object, Promise, Reflect, Uint8Array};
use parking_lot::Mutex;
#[cfg(feature = "collections")]
use ruvector_collections::{
CollectionConfig as CoreCollectionConfig, CollectionManager as CoreCollectionManager,
};
use ruvector_core::{
error::RuvectorError,
types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery, SearchResult, VectorEntry},
vector_db::VectorDB as CoreVectorDB,
};
#[cfg(feature = "collections")]
use ruvector_filter::FilterExpression as CoreFilterExpression;
use serde::{Deserialize, Serialize};
use serde_wasm_bindgen::{from_value, to_value};
use std::collections::HashMap;
use std::sync::Arc;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use web_sys::{
console, IdbDatabase, IdbFactory, IdbObjectStore, IdbRequest, IdbTransaction, Window,
};
/// Initialize panic hook for better error messages in browser console
///
/// Runs automatically on module instantiation (`#[wasm_bindgen(start)]`)
/// and also routes `tracing` output to the browser console.
#[wasm_bindgen(start)]
pub fn init() {
    console_error_panic_hook::set_once();
    tracing_wasm::set_as_global_default();
}
/// WASM-specific error type that can cross the JS boundary
///
/// `message` is the human-readable error text; `kind` is the Debug
/// rendering of the originating `RuvectorError` variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmError {
    pub message: String,
    pub kind: String,
}
impl From<RuvectorError> for WasmError {
    /// Flatten a core error into a JS-friendly message/kind pair.
    fn from(err: RuvectorError) -> Self {
        let kind = format!("{:?}", err);
        let message = err.to_string();
        WasmError { message, kind }
    }
}
impl From<WasmError> for JsValue {
    /// Convert to a plain JS object `{ message, kind }`.
    fn from(err: WasmError) -> Self {
        let obj = Object::new();
        // Reflect::set should not fail on a freshly created plain Object,
        // hence the unwraps — presumably safe; confirm if Object changes.
        Reflect::set(&obj, &"message".into(), &err.message.into()).unwrap();
        Reflect::set(&obj, &"kind".into(), &err.kind.into()).unwrap();
        obj.into()
    }
}
/// Result alias used by fallible WASM-side helpers.
type WasmResult<T> = Result<T, WasmError>;
/// JavaScript-compatible VectorEntry
///
/// Thin wrapper around the core `VectorEntry` exposed to JS via
/// wasm-bindgen getters.
#[wasm_bindgen]
#[derive(Clone)]
pub struct JsVectorEntry {
    inner: VectorEntry,
}
/// Maximum allowed vector dimensions (security limit to prevent DoS)
/// via oversized allocations from untrusted JS callers.
const MAX_VECTOR_DIMENSIONS: usize = 65536;
#[wasm_bindgen]
impl JsVectorEntry {
    /// Construct an entry from a JS Float32Array plus optional id/metadata.
    ///
    /// Rejects empty vectors and vectors larger than
    /// `MAX_VECTOR_DIMENSIONS` before any allocation.
    #[wasm_bindgen(constructor)]
    pub fn new(
        vector: Float32Array,
        id: Option<String>,
        metadata: Option<JsValue>,
    ) -> Result<JsVectorEntry, JsValue> {
        // Security: validate dimensions before copying data out of JS.
        let len = vector.length() as usize;
        if len == 0 {
            return Err(JsValue::from_str("Vector cannot be empty"));
        }
        if len > MAX_VECTOR_DIMENSIONS {
            return Err(JsValue::from_str(&format!(
                "Vector dimensions {} exceed maximum allowed {}",
                len, MAX_VECTOR_DIMENSIONS
            )));
        }
        let metadata = match metadata {
            Some(meta) => Some(
                from_value(meta)
                    .map_err(|e| JsValue::from_str(&format!("Invalid metadata: {}", e)))?,
            ),
            None => None,
        };
        Ok(JsVectorEntry {
            inner: VectorEntry {
                id,
                vector: vector.to_vec(),
                metadata,
            },
        })
    }

    /// The entry's id, if one was assigned.
    #[wasm_bindgen(getter)]
    pub fn id(&self) -> Option<String> {
        self.inner.id.clone()
    }

    /// A fresh Float32Array copy of the stored vector.
    #[wasm_bindgen(getter)]
    pub fn vector(&self) -> Float32Array {
        Float32Array::from(self.inner.vector.as_slice())
    }

    /// The entry's metadata serialized back to a JS value, if any.
    #[wasm_bindgen(getter)]
    pub fn metadata(&self) -> Option<JsValue> {
        // NOTE(review): to_value failure panics here (unwrap); metadata
        // originally crossed the JS boundary, so it is assumed
        // re-serializable — confirm if metadata types ever widen.
        self.inner.metadata.as_ref().map(|m| to_value(m).unwrap())
    }
}
/// JavaScript-compatible SearchResult
///
/// Read-only wrapper around the core `SearchResult` for JS consumers.
#[wasm_bindgen]
pub struct JsSearchResult {
    inner: SearchResult,
}
#[wasm_bindgen]
impl JsSearchResult {
    /// ID of the matched vector.
    #[wasm_bindgen(getter)]
    pub fn id(&self) -> String {
        self.inner.id.clone()
    }

    /// Score reported by the search for this hit.
    #[wasm_bindgen(getter)]
    pub fn score(&self) -> f32 {
        self.inner.score
    }

    /// The stored vector, when the search was asked to return vectors.
    #[wasm_bindgen(getter)]
    pub fn vector(&self) -> Option<Float32Array> {
        match self.inner.vector.as_ref() {
            Some(v) => Some(Float32Array::from(v.as_slice())),
            None => None,
        }
    }

    /// Metadata attached to the matched vector, if any.
    #[wasm_bindgen(getter)]
    pub fn metadata(&self) -> Option<JsValue> {
        self.inner.metadata.as_ref().map(|m| to_value(m).unwrap())
    }
}
/// Main VectorDB class for browser usage
///
/// Wraps the in-memory core database behind `Arc<Mutex<..>>`; `db_name`
/// is a timestamp-derived identifier intended for IndexedDB persistence.
#[wasm_bindgen]
pub struct VectorDB {
    db: Arc<Mutex<CoreVectorDB>>,
    dimensions: usize,
    db_name: String,
}
#[wasm_bindgen]
impl VectorDB {
    /// Create a new VectorDB instance
    ///
    /// # Arguments
    /// * `dimensions` - Vector dimensions
    /// * `metric` - Distance metric ("euclidean", "cosine", "dotproduct", "manhattan");
    ///   defaults to cosine when omitted
    /// * `use_hnsw` - Whether to use HNSW index for faster search (default: true)
    #[wasm_bindgen(constructor)]
    pub fn new(
        dimensions: usize,
        metric: Option<String>,
        use_hnsw: Option<bool>,
    ) -> Result<VectorDB, JsValue> {
        let distance_metric = match metric.as_deref() {
            Some("euclidean") => DistanceMetric::Euclidean,
            Some("cosine") => DistanceMetric::Cosine,
            Some("dotproduct") => DistanceMetric::DotProduct,
            Some("manhattan") => DistanceMetric::Manhattan,
            None => DistanceMetric::Cosine,
            Some(other) => return Err(JsValue::from_str(&format!("Unknown metric: {}", other))),
        };
        let hnsw_config = if use_hnsw.unwrap_or(true) {
            Some(HnswConfig::default())
        } else {
            None
        };
        let options = DbOptions {
            dimensions,
            distance_metric,
            storage_path: ":memory:".to_string(), // Use in-memory for WASM
            hnsw_config,
            quantization: None, // Disable quantization for WASM (for now)
        };
        let db = CoreVectorDB::new(options).map_err(|e| JsValue::from(WasmError::from(e)))?;
        Ok(VectorDB {
            db: Arc::new(Mutex::new(db)),
            dimensions,
            db_name: format!("ruvector_db_{}", js_sys::Date::now()),
        })
    }
    /// Insert a single vector
    ///
    /// # Arguments
    /// * `vector` - Float32Array of vector data
    /// * `id` - Optional ID (auto-generated if not provided)
    /// * `metadata` - Optional metadata object
    ///
    /// # Returns
    /// The vector ID
    #[wasm_bindgen]
    pub fn insert(
        &self,
        vector: Float32Array,
        id: Option<String>,
        metadata: Option<JsValue>,
    ) -> Result<String, JsValue> {
        let entry = JsVectorEntry::new(vector, id, metadata)?;
        let db = self.db.lock();
        let vector_id = db
            .insert(entry.inner)
            .map_err(|e| JsValue::from(WasmError::from(e)))?;
        Ok(vector_id)
    }
    /// Insert multiple vectors in a batch (more efficient)
    ///
    /// # Arguments
    /// * `entries` - Array of objects with `vector` (Float32Array), optional
    ///   `id` (string), and optional `metadata` properties
    ///
    /// # Returns
    /// Array of vector IDs
    #[wasm_bindgen(js_name = insertBatch)]
    pub fn insert_batch(&self, entries: JsValue) -> Result<Vec<String>, JsValue> {
        // Convert JsValue to Array using reflection
        let entries_array: js_sys::Array = entries
            .dyn_into()
            .map_err(|_| JsValue::from_str("entries must be an array"))?;
        let mut vector_entries = Vec::new();
        for i in 0..entries_array.length() {
            let js_entry = entries_array.get(i);
            let vector_arr: Float32Array = Reflect::get(&js_entry, &"vector".into())?.dyn_into()?;
            let id: Option<String> = Reflect::get(&js_entry, &"id".into())?.as_string();
            // Bug fix: Reflect::get returns Ok(undefined) for a missing
            // property, so `.ok()` alone produced Some(undefined) and the
            // subsequent metadata deserialization failed for entries that
            // carry no metadata. Treat undefined/null as "no metadata".
            let metadata = Reflect::get(&js_entry, &"metadata".into())
                .ok()
                .filter(|v| !v.is_undefined() && !v.is_null());
            let entry = JsVectorEntry::new(vector_arr, id, metadata)?;
            vector_entries.push(entry.inner);
        }
        let db = self.db.lock();
        let ids = db
            .insert_batch(vector_entries)
            .map_err(|e| JsValue::from(WasmError::from(e)))?;
        Ok(ids)
    }
    /// Search for similar vectors
    ///
    /// # Arguments
    /// * `query` - Query vector as Float32Array (must match `dimensions`)
    /// * `k` - Number of results to return
    /// * `filter` - Optional metadata filter object
    ///
    /// # Returns
    /// Array of search results
    #[wasm_bindgen]
    pub fn search(
        &self,
        query: Float32Array,
        k: usize,
        filter: Option<JsValue>,
    ) -> Result<Vec<JsSearchResult>, JsValue> {
        let query_vector: Vec<f32> = query.to_vec();
        if query_vector.len() != self.dimensions {
            return Err(JsValue::from_str(&format!(
                "Query vector dimension mismatch: expected {}, got {}",
                self.dimensions,
                query_vector.len()
            )));
        }
        let metadata_filter = if let Some(f) = filter {
            Some(from_value(f).map_err(|e| JsValue::from_str(&format!("Invalid filter: {}", e)))?)
        } else {
            None
        };
        let search_query = SearchQuery {
            vector: query_vector,
            k,
            filter: metadata_filter,
            ef_search: None,
        };
        let db = self.db.lock();
        let results = db
            .search(search_query)
            .map_err(|e| JsValue::from(WasmError::from(e)))?;
        Ok(results
            .into_iter()
            .map(|r| JsSearchResult { inner: r })
            .collect())
    }
    /// Delete a vector by ID
    ///
    /// # Arguments
    /// * `id` - Vector ID to delete
    ///
    /// # Returns
    /// True if deleted, false if not found
    #[wasm_bindgen]
    pub fn delete(&self, id: &str) -> Result<bool, JsValue> {
        let db = self.db.lock();
        db.delete(id).map_err(|e| JsValue::from(WasmError::from(e)))
    }
    /// Get a vector by ID
    ///
    /// # Arguments
    /// * `id` - Vector ID
    ///
    /// # Returns
    /// VectorEntry or null if not found
    #[wasm_bindgen]
    pub fn get(&self, id: &str) -> Result<Option<JsVectorEntry>, JsValue> {
        let db = self.db.lock();
        let entry = db.get(id).map_err(|e| JsValue::from(WasmError::from(e)))?;
        Ok(entry.map(|e| JsVectorEntry { inner: e }))
    }
    /// Get the number of vectors in the database
    #[wasm_bindgen]
    pub fn len(&self) -> Result<usize, JsValue> {
        let db = self.db.lock();
        db.len().map_err(|e| JsValue::from(WasmError::from(e)))
    }
    /// Check if the database is empty
    #[wasm_bindgen(js_name = isEmpty)]
    pub fn is_empty(&self) -> Result<bool, JsValue> {
        let db = self.db.lock();
        db.is_empty().map_err(|e| JsValue::from(WasmError::from(e)))
    }
    /// Get database dimensions
    #[wasm_bindgen(getter)]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
    /// Save database to IndexedDB
    /// Returns a Promise that resolves when save is complete
    ///
    /// NOTE: currently a stub — it only logs; no state is serialized yet.
    #[wasm_bindgen(js_name = saveToIndexedDB)]
    pub fn save_to_indexed_db(&self) -> Result<Promise, JsValue> {
        let db_name = self.db_name.clone();
        // For now, log that we would save to IndexedDB
        // Full implementation would serialize the database state
        console::log_1(&format!("Saving database '{}' to IndexedDB...", db_name).into());
        // Return resolved promise
        Ok(Promise::resolve(&JsValue::TRUE))
    }
    /// Load database from IndexedDB
    /// Returns a Promise that resolves with the VectorDB instance
    ///
    /// NOTE: currently a stub — it always returns a rejected Promise.
    #[wasm_bindgen(js_name = loadFromIndexedDB)]
    pub fn load_from_indexed_db(db_name: String) -> Result<Promise, JsValue> {
        console::log_1(&format!("Loading database '{}' from IndexedDB...", db_name).into());
        // Return rejected promise for now (not implemented)
        Ok(Promise::reject(&JsValue::from_str("Not yet implemented")))
    }
}
/// Detect SIMD support in the current environment
///
/// This is a compile-time property of the build, not a runtime probe: it
/// reports whether this module was compiled with the `simd128` target
/// feature. `cfg!` is the idiomatic one-expression equivalent of the
/// previous pair of `#[cfg]` blocks.
#[wasm_bindgen(js_name = detectSIMD)]
pub fn detect_simd() -> bool {
    cfg!(target_feature = "simd128")
}
/// Get version information
///
/// Returns the crate version baked in at compile time.
#[wasm_bindgen]
pub fn version() -> String {
    env!("CARGO_PKG_VERSION").to_string()
}
/// Utility: Convert JavaScript array to Float32Array
///
/// wasm-bindgen converts the incoming JS array to `Vec<f32>`; this copies
/// it back out as a typed array.
#[wasm_bindgen(js_name = arrayToFloat32Array)]
pub fn array_to_float32_array(arr: Vec<f32>) -> Float32Array {
    Float32Array::from(&arr[..])
}
/// Utility: Measure performance of an operation
///
/// Inserts `iterations` random vectors of `dimensions` dims into a fresh
/// in-memory database and reports the insert throughput in ops/sec.
///
/// # Arguments
/// * `name` - Label used in console log output
/// * `iterations` - Number of inserts to perform
/// * `dimensions` - Dimensionality of each random vector
#[wasm_bindgen(js_name = benchmark)]
pub fn benchmark(name: &str, iterations: usize, dimensions: usize) -> Result<f64, JsValue> {
    console::log_1(
        &format!(
            "Running benchmark '{}' with {} iterations...",
            name, iterations
        )
        .into(),
    );
    let db = VectorDB::new(dimensions, Some("cosine".to_string()), Some(false))?;
    // Bug fix: std::time::Instant::now() panics on wasm32-unknown-unknown,
    // so time with the JS clock (js_sys::Date::now, milliseconds) instead.
    let start_ms = js_sys::Date::now();
    for i in 0..iterations {
        let vector: Vec<f32> = (0..dimensions)
            .map(|_| js_sys::Math::random() as f32)
            .collect();
        let vector_arr = Float32Array::from(&vector[..]);
        db.insert(vector_arr, Some(format!("vec_{}", i)), None)?;
    }
    // Guard against a zero elapsed reading from the ms-resolution clock.
    let elapsed_secs = ((js_sys::Date::now() - start_ms) / 1000.0).max(f64::MIN_POSITIVE);
    let ops_per_sec = iterations as f64 / elapsed_secs;
    console::log_1(&format!("Benchmark complete: {:.2} ops/sec", ops_per_sec).into());
    Ok(ops_per_sec)
}
// ===== Collection Manager =====
// Note: Collections are not available in standard WASM builds due to file I/O requirements
// To use collections, compile with the "collections" feature (requires WASI or server environment)
#[cfg(feature = "collections")]
/// WASM Collection Manager for multi-collection support
///
/// Only compiled with the "collections" feature (collections need file
/// I/O, unavailable in standard browser WASM builds). Wraps the core
/// manager behind `Arc<Mutex<..>>` for shared access.
#[wasm_bindgen]
pub struct CollectionManager {
    inner: Arc<Mutex<CoreCollectionManager>>,
}
#[cfg(feature = "collections")]
#[wasm_bindgen]
impl CollectionManager {
    /// Create a new CollectionManager
    ///
    /// # Arguments
    /// * `base_path` - Optional base path for storing collections (defaults to ":memory:")
    ///
    /// # Errors
    /// Returns a JS error if the core manager cannot be constructed.
    #[wasm_bindgen(constructor)]
    pub fn new(base_path: Option<String>) -> Result<CollectionManager, JsValue> {
        // ":memory:" keeps all state in RAM — the sensible default in WASM,
        // where there is no real filesystem.
        let path = base_path.unwrap_or_else(|| ":memory:".to_string());
        let manager = CoreCollectionManager::new(std::path::PathBuf::from(path)).map_err(|e| {
            JsValue::from_str(&format!("Failed to create collection manager: {}", e))
        })?;
        Ok(CollectionManager {
            inner: Arc::new(Mutex::new(manager)),
        })
    }
    /// Create a new collection
    ///
    /// # Arguments
    /// * `name` - Collection name (alphanumeric, hyphens, underscores only)
    /// * `dimensions` - Vector dimensions
    /// * `metric` - Optional distance metric ("euclidean", "cosine", "dotproduct", "manhattan")
    ///
    /// # Errors
    /// Rejects unrecognized metric strings and propagates core failures
    /// (e.g. duplicate collection names — confirm against core semantics).
    #[wasm_bindgen(js_name = createCollection)]
    pub fn create_collection(
        &self,
        name: &str,
        dimensions: usize,
        metric: Option<String>,
    ) -> Result<(), JsValue> {
        // A missing metric defaults to cosine; any other string is an error.
        let distance_metric = match metric.as_deref() {
            Some("euclidean") => DistanceMetric::Euclidean,
            Some("cosine") => DistanceMetric::Cosine,
            Some("dotproduct") => DistanceMetric::DotProduct,
            Some("manhattan") => DistanceMetric::Manhattan,
            None => DistanceMetric::Cosine,
            Some(other) => return Err(JsValue::from_str(&format!("Unknown metric: {}", other))),
        };
        let config = CoreCollectionConfig {
            dimensions,
            distance_metric,
            hnsw_config: Some(HnswConfig::default()),
            quantization: None,
            on_disk_payload: false, // Disable for WASM
        };
        let manager = self.inner.lock();
        manager
            .create_collection(name, config)
            .map_err(|e| JsValue::from_str(&format!("Failed to create collection: {}", e)))?;
        Ok(())
    }
    /// List all collections
    ///
    /// # Returns
    /// Array of collection names
    #[wasm_bindgen(js_name = listCollections)]
    pub fn list_collections(&self) -> Vec<String> {
        let manager = self.inner.lock();
        manager.list_collections()
    }
    /// Delete a collection
    ///
    /// # Arguments
    /// * `name` - Collection name to delete
    ///
    /// # Errors
    /// Returns error if collection has active aliases
    #[wasm_bindgen(js_name = deleteCollection)]
    pub fn delete_collection(&self, name: &str) -> Result<(), JsValue> {
        let manager = self.inner.lock();
        manager
            .delete_collection(name)
            .map_err(|e| JsValue::from_str(&format!("Failed to delete collection: {}", e)))?;
        Ok(())
    }
    /// Get a collection's VectorDB
    ///
    /// # Arguments
    /// * `name` - Collection name or alias
    ///
    /// # Returns
    /// VectorDB instance or error if not found
    ///
    /// CAUTION: the returned VectorDB is a *fresh, empty* database built from
    /// the collection's config — it does not share storage with the managed
    /// collection (see notes below). Writes to it are not reflected back.
    #[wasm_bindgen(js_name = getCollection)]
    pub fn get_collection(&self, name: &str) -> Result<VectorDB, JsValue> {
        let manager = self.inner.lock();
        let collection_ref = manager
            .get_collection(name)
            .ok_or_else(|| JsValue::from_str(&format!("Collection '{}' not found", name)))?;
        let collection = collection_ref.read();
        // Create a new VectorDB wrapper that shares the underlying database
        // Note: For WASM, we'll need to clone the DB state since we can't share references across WASM boundary
        // This is a simplified version - in production you might want a different approach
        let dimensions = collection.config.dimensions;
        let db_name = collection.name.clone();
        // For now, return a new VectorDB with the same config
        // In a real implementation, you'd want to share the underlying storage
        let db_options = DbOptions {
            dimensions: collection.config.dimensions,
            distance_metric: collection.config.distance_metric,
            storage_path: ":memory:".to_string(),
            hnsw_config: collection.config.hnsw_config.clone(),
            quantization: collection.config.quantization.clone(),
        };
        let db = CoreVectorDB::new(db_options)
            .map_err(|e| JsValue::from_str(&format!("Failed to get collection: {}", e)))?;
        Ok(VectorDB {
            db: Arc::new(Mutex::new(db)),
            dimensions,
            db_name,
        })
    }
    /// Create an alias
    ///
    /// # Arguments
    /// * `alias` - Alias name (must be unique)
    /// * `collection` - Target collection name
    #[wasm_bindgen(js_name = createAlias)]
    pub fn create_alias(&self, alias: &str, collection: &str) -> Result<(), JsValue> {
        let manager = self.inner.lock();
        manager
            .create_alias(alias, collection)
            .map_err(|e| JsValue::from_str(&format!("Failed to create alias: {}", e)))?;
        Ok(())
    }
    /// Delete an alias
    ///
    /// # Arguments
    /// * `alias` - Alias name to delete
    #[wasm_bindgen(js_name = deleteAlias)]
    pub fn delete_alias(&self, alias: &str) -> Result<(), JsValue> {
        let manager = self.inner.lock();
        manager
            .delete_alias(alias)
            .map_err(|e| JsValue::from_str(&format!("Failed to delete alias: {}", e)))?;
        Ok(())
    }
    /// List all aliases
    ///
    /// # Returns
    /// JavaScript array of [alias, collection] pairs
    #[wasm_bindgen(js_name = listAliases)]
    pub fn list_aliases(&self) -> JsValue {
        let manager = self.inner.lock();
        let aliases = manager.list_aliases();
        // Build a nested JS array: [[alias, collection], ...].
        let arr = Array::new();
        for (alias, collection) in aliases {
            let pair = Array::new();
            pair.push(&JsValue::from_str(&alias));
            pair.push(&JsValue::from_str(&collection));
            arr.push(&pair);
        }
        arr.into()
    }
}
// ===== Filter Builder =====
#[cfg(feature = "collections")]
/// JavaScript-compatible filter builder
///
/// Thin wrapper over the core filter-expression tree. Instances are produced
/// by the static factory methods (eq, gt, and, or, ...) and serialized for
/// search via `toJson`.
#[wasm_bindgen]
pub struct FilterBuilder {
    // The wrapped core expression; holds leaf and composite filters alike.
    inner: CoreFilterExpression,
}
#[cfg(feature = "collections")]
#[wasm_bindgen]
impl FilterBuilder {
    /// Create a new empty filter builder
    ///
    /// Defaults to `exists("_id")`, which is intended as a match-all filter
    /// (assumes every stored entry has an `_id` field — verify against the
    /// storage layer). Prefer the static factory methods for real filters.
    #[wasm_bindgen(constructor)]
    pub fn new() -> FilterBuilder {
        // Default to a match-all filter (we'll use exists on a common field)
        // Users should use the builder methods instead
        FilterBuilder {
            inner: CoreFilterExpression::exists("_id"),
        }
    }
    /// Create an equality filter
    ///
    /// # Arguments
    /// * `field` - Field name
    /// * `value` - Value to match (will be converted from JS)
    ///
    /// # Errors
    /// Fails if `value` cannot be deserialized into a JSON value.
    ///
    /// # Example
    /// ```javascript
    /// const filter = FilterBuilder.eq("status", "active");
    /// ```
    pub fn eq(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::eq(field, json_value),
        })
    }
    /// Create a not-equal filter
    pub fn ne(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::ne(field, json_value),
        })
    }
    /// Create a greater-than filter
    pub fn gt(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::gt(field, json_value),
        })
    }
    /// Create a greater-than-or-equal filter
    pub fn gte(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::gte(field, json_value),
        })
    }
    /// Create a less-than filter
    pub fn lt(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::lt(field, json_value),
        })
    }
    /// Create a less-than-or-equal filter
    pub fn lte(field: &str, value: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_value: serde_json::Value =
            from_value(value).map_err(|e| JsValue::from_str(&format!("Invalid value: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::lte(field, json_value),
        })
    }
    /// Create an IN filter (field matches any of the values)
    ///
    /// # Arguments
    /// * `field` - Field name
    /// * `values` - Array of values
    ///
    /// # Errors
    /// Fails if `values` is not convertible to a JSON array.
    #[wasm_bindgen(js_name = "in")]
    pub fn in_values(field: &str, values: JsValue) -> Result<FilterBuilder, JsValue> {
        let json_values: Vec<serde_json::Value> = from_value(values)
            .map_err(|e| JsValue::from_str(&format!("Invalid values array: {}", e)))?;
        Ok(FilterBuilder {
            inner: CoreFilterExpression::in_values(field, json_values),
        })
    }
    /// Create a text match filter
    ///
    /// # Arguments
    /// * `field` - Field name
    /// * `text` - Text to search for
    #[wasm_bindgen(js_name = matchText)]
    pub fn match_text(field: &str, text: &str) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::match_text(field, text),
        }
    }
    /// Create a geo radius filter
    ///
    /// # Arguments
    /// * `field` - Field name (should contain {lat, lon} object)
    /// * `lat` - Center latitude
    /// * `lon` - Center longitude
    /// * `radius_m` - Radius in meters
    #[wasm_bindgen(js_name = geoRadius)]
    pub fn geo_radius(field: &str, lat: f64, lon: f64, radius_m: f64) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::geo_radius(field, lat, lon, radius_m),
        }
    }
    /// Combine filters with AND
    ///
    /// # Arguments
    /// * `filters` - Array of FilterBuilder instances (consumed by this call)
    pub fn and(filters: Vec<FilterBuilder>) -> FilterBuilder {
        let inner_filters: Vec<CoreFilterExpression> =
            filters.into_iter().map(|f| f.inner).collect();
        FilterBuilder {
            inner: CoreFilterExpression::and(inner_filters),
        }
    }
    /// Combine filters with OR
    ///
    /// # Arguments
    /// * `filters` - Array of FilterBuilder instances (consumed by this call)
    pub fn or(filters: Vec<FilterBuilder>) -> FilterBuilder {
        let inner_filters: Vec<CoreFilterExpression> =
            filters.into_iter().map(|f| f.inner).collect();
        FilterBuilder {
            inner: CoreFilterExpression::or(inner_filters),
        }
    }
    /// Negate a filter with NOT
    ///
    /// # Arguments
    /// * `filter` - FilterBuilder instance to negate (consumed)
    pub fn not(filter: FilterBuilder) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::not(filter.inner),
        }
    }
    /// Create an EXISTS filter (field is present)
    pub fn exists(field: &str) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::exists(field),
        }
    }
    /// Create an IS NULL filter (field is null)
    #[wasm_bindgen(js_name = isNull)]
    pub fn is_null(field: &str) -> FilterBuilder {
        FilterBuilder {
            inner: CoreFilterExpression::is_null(field),
        }
    }
    /// Convert to JSON for use with search
    ///
    /// # Returns
    /// JavaScript object representing the filter
    ///
    /// # Errors
    /// Fails if the inner expression cannot be serialized.
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<JsValue, JsValue> {
        to_value(&self.inner)
            .map_err(|e| JsValue::from_str(&format!("Failed to serialize filter: {}", e)))
    }
    /// Get all field names referenced in this filter
    #[wasm_bindgen(js_name = getFields)]
    pub fn get_fields(&self) -> Vec<String> {
        self.inner.get_fields()
    }
}
#[cfg(feature = "collections")]
impl Default for FilterBuilder {
    // Delegates to `new()`, i.e. the match-all `exists("_id")` filter.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;
    // These tests must run in a browser environment (e.g. wasm-pack test
    // --headless) because they exercise wasm-bindgen exports.
    wasm_bindgen_test_configure!(run_in_browser);
    #[wasm_bindgen_test]
    fn test_version() {
        // CARGO_PKG_VERSION is embedded at compile time and is never empty.
        assert!(!version().is_empty());
    }
    #[wasm_bindgen_test]
    fn test_detect_simd() {
        // Just ensure it doesn't panic
        let _ = detect_simd();
    }
}

View File

@@ -0,0 +1,254 @@
/**
 * Web Worker Pool Manager
 *
 * Manages a pool of workers for parallel vector operations.
 * Supports:
 * - Round-robin task distribution
 * - Load balancing
 * - Automatic worker initialization
 * - Promise-based API
 */
export class WorkerPool {
  /**
   * @param {string} workerUrl - URL of the worker module script.
   * @param {string} wasmUrl - URL of the WASM module each worker loads.
   * @param {{poolSize?: number, dimensions?: number, metric?: string, useHnsw?: boolean}} [options]
   */
  constructor(workerUrl, wasmUrl, options = {}) {
    this.workerUrl = workerUrl;
    this.wasmUrl = wasmUrl;
    this.poolSize = options.poolSize || navigator.hardwareConcurrency || 4;
    this.workers = [];
    this.nextWorker = 0;
    this.pendingRequests = new Map();
    this.requestId = 0;
    this.initialized = false;
    this.options = options;
  }

  /**
   * Initialize the worker pool: spawn workers and load WASM in each.
   * Idempotent — subsequent calls are no-ops.
   */
  async init() {
    if (this.initialized) return;
    console.log(`Initializing worker pool with ${this.poolSize} workers...`);
    const initPromises = [];
    for (let i = 0; i < this.poolSize; i++) {
      const worker = new Worker(this.workerUrl, { type: 'module' });
      worker.onmessage = (e) => this.handleMessage(i, e);
      worker.onerror = (error) => this.handleError(i, error);
      this.workers.push({
        worker,
        busy: false,
        id: i
      });
      // Each worker instantiates the WASM module and its own VectorDB.
      initPromises.push(
        this.sendToWorker(i, 'init', {
          wasmUrl: this.wasmUrl,
          dimensions: this.options.dimensions,
          metric: this.options.metric,
          useHnsw: this.options.useHnsw
        })
      );
    }
    await Promise.all(initPromises);
    this.initialized = true;
    console.log(`Worker pool initialized successfully`);
  }

  /**
   * Handle a message from a worker and settle the matching pending request.
   * BUGFIX: the worker's busy flag is now cleared on error replies too (it
   * previously stayed set forever), and the request's timeout timer is
   * cancelled once a reply arrives.
   */
  handleMessage(workerId, event) {
    const { type, requestId, data, error } = event.data;
    const request = this.pendingRequests.get(requestId);
    if (!request) return; // already timed out, or reply without a known id
    this.pendingRequests.delete(requestId);
    clearTimeout(request.timeoutId);
    this.workers[workerId].busy = false;
    if (type === 'error') {
      request.reject(new Error(error.message));
    } else {
      request.resolve(data);
    }
  }

  /**
   * Handle a fatal worker error: reject every request routed to that worker
   * and free its busy slot.
   */
  handleError(workerId, error) {
    console.error(`Worker ${workerId} error:`, error);
    for (const [requestId, request] of this.pendingRequests) {
      if (request.workerId === workerId) {
        clearTimeout(request.timeoutId);
        this.pendingRequests.delete(requestId);
        request.reject(error);
      }
    }
    // BUGFIX: a crashed worker must not stay marked busy.
    if (this.workers[workerId]) this.workers[workerId].busy = false;
  }

  /**
   * Get the next worker index: idle-first, falling back to round-robin.
   * @throws {Error} if the pool has no workers (init() not run).
   */
  getNextWorker() {
    // BUGFIX: guard against modulo-by-zero (NaN index) on an empty pool.
    if (this.workers.length === 0) {
      throw new Error('Worker pool is empty - call init() first');
    }
    for (let i = 0; i < this.workers.length; i++) {
      const idx = (this.nextWorker + i) % this.workers.length;
      if (!this.workers[idx].busy) {
        this.nextWorker = (idx + 1) % this.workers.length;
        return idx;
      }
    }
    // All busy: plain round-robin.
    const idx = this.nextWorker;
    this.nextWorker = (this.nextWorker + 1) % this.workers.length;
    return idx;
  }

  /**
   * Send a message to a specific worker; resolves with the tagged reply.
   * Times out after 30 seconds. BUGFIX: the timeout timer is stored and
   * cancelled on completion (previously every request leaked a 30 s timer),
   * and a timed-out worker's busy slot is released.
   */
  sendToWorker(workerId, type, data) {
    return new Promise((resolve, reject) => {
      const requestId = this.requestId++;
      const timeoutId = setTimeout(() => {
        if (this.pendingRequests.has(requestId)) {
          this.pendingRequests.delete(requestId);
          if (this.workers[workerId]) this.workers[workerId].busy = false;
          reject(new Error('Request timeout'));
        }
      }, 30000);
      this.pendingRequests.set(requestId, {
        resolve,
        reject,
        workerId,
        timeoutId,
        timestamp: Date.now()
      });
      this.workers[workerId].busy = true;
      this.workers[workerId].worker.postMessage({
        type,
        data: { ...data, requestId }
      });
    });
  }

  /** Execute an operation on the next available worker. */
  async execute(type, data) {
    if (!this.initialized) {
      await this.init();
    }
    const workerId = this.getNextWorker();
    return this.sendToWorker(workerId, type, data);
  }

  /** Insert a single vector. */
  async insert(vector, id = null, metadata = null) {
    return this.execute('insert', { vector, id, metadata });
  }

  /**
   * Insert a batch of vectors, distributed evenly across workers.
   * BUGFIX: initializes the pool first (previously addressed workers that
   * did not exist yet) and short-circuits on an empty batch.
   */
  async insertBatch(entries) {
    if (!this.initialized) await this.init();
    if (entries.length === 0) return [];
    const chunkSize = Math.ceil(entries.length / this.poolSize);
    const chunks = [];
    for (let i = 0; i < entries.length; i += chunkSize) {
      chunks.push(entries.slice(i, i + chunkSize));
    }
    const promises = chunks.map((chunk, i) =>
      this.sendToWorker(i % this.poolSize, 'insertBatch', { entries: chunk })
    );
    const results = await Promise.all(promises);
    return results.flat();
  }

  /** Search for the k most similar vectors. */
  async search(query, k = 10, filter = null) {
    return this.execute('search', { query, k, filter });
  }

  /**
   * Run many searches in parallel, one query per worker (round-robin).
   * BUGFIX: initializes the pool first.
   */
  async searchBatch(queries, k = 10, filter = null) {
    if (!this.initialized) await this.init();
    const promises = queries.map((query, i) =>
      this.sendToWorker(i % this.poolSize, 'search', { query, k, filter })
    );
    return Promise.all(promises);
  }

  /** Delete a vector by id. */
  async delete(id) {
    return this.execute('delete', { id });
  }

  /** Get a vector entry by id. */
  async get(id) {
    return this.execute('get', { id });
  }

  /**
   * Get database length (from first worker).
   * NOTE(review): each worker holds an independent DB, so this reports only
   * worker 0's count — confirm whether callers expect a pool-wide total.
   * BUGFIX: initializes the pool first.
   */
  async len() {
    if (!this.initialized) await this.init();
    return this.sendToWorker(0, 'len', {});
  }

  /** Terminate all workers and reset the pool to its pre-init state. */
  terminate() {
    for (const { worker } of this.workers) {
      worker.terminate();
    }
    this.workers = [];
    this.initialized = false;
    console.log('Worker pool terminated');
  }

  /** Get a snapshot of pool utilization. */
  getStats() {
    return {
      poolSize: this.poolSize,
      busyWorkers: this.workers.filter(w => w.busy).length,
      idleWorkers: this.workers.filter(w => !w.busy).length,
      pendingRequests: this.pendingRequests.size
    };
  }
}
export default WorkerPool;

View File

@@ -0,0 +1,184 @@
/**
 * Web Worker for parallel vector search operations
 *
 * This worker handles:
 * - Vector search operations in parallel
 * - Batch insert operations
 * - Zero-copy transfers via transferable objects
 */
// Module state: the imported WASM glue module and the worker-local database.
let wasmModule = null;
let vectorDB = null;
/**
 * Message dispatcher.
 *
 * BUGFIX: every reply now carries the originating `requestId` — including
 * 'init', 'len', and error replies, which previously omitted it. Without it
 * the pool's pendingRequests lookup missed, so init()/len() promises (and
 * error rejections) only settled via the pool's 30 s timeout.
 */
self.onmessage = async function(e) {
  const { type, data } = e.data;
  const requestId = data?.requestId;
  try {
    switch (type) {
      case 'init':
        await initWorker(data);
        self.postMessage({ type: 'init', requestId, success: true });
        break;
      case 'insert':
        await handleInsert(data);
        break;
      case 'insertBatch':
        await handleInsertBatch(data);
        break;
      case 'search':
        await handleSearch(data);
        break;
      case 'delete':
        await handleDelete(data);
        break;
      case 'get':
        await handleGet(data);
        break;
      case 'len': {
        // Braced so the lexical declaration is scoped to this case.
        const length = vectorDB.len();
        self.postMessage({ type: 'len', requestId, data: length });
        break;
      }
      default:
        throw new Error(`Unknown message type: ${type}`);
    }
  } catch (error) {
    self.postMessage({
      type: 'error',
      requestId,
      error: {
        message: error.message,
        stack: error.stack
      }
    });
  }
};
/**
 * Initialize WASM module and VectorDB
 *
 * Dynamically imports the WASM JS glue from `wasmUrl`, runs its default
 * export to instantiate the WASM binary, then constructs this worker's own
 * VectorDB with the supplied configuration. Each worker holds an independent
 * database instance.
 *
 * @param {{wasmUrl: string, dimensions: number, metric: (string|undefined), useHnsw: (boolean|undefined)}} config
 */
async function initWorker(config) {
  const { wasmUrl, dimensions, metric, useHnsw } = config;
  // Import WASM module
  wasmModule = await import(wasmUrl);
  // Initialize WASM
  await wasmModule.default();
  // Create VectorDB instance
  vectorDB = new wasmModule.VectorDB(dimensions, metric, useHnsw);
  console.log(`Worker initialized with dimensions=${dimensions}, metric=${metric}, SIMD=${wasmModule.detectSIMD()}`);
}
/**
 * Insert a single vector into the worker-local database and post the
 * assigned id back, tagged with the originating request id.
 */
async function handleInsert(data) {
  const { vector, id, metadata, requestId } = data;
  // The WASM API expects a Float32Array; plain arrays arrive via postMessage.
  const assignedId = vectorDB.insert(new Float32Array(vector), id, metadata);
  self.postMessage({ type: 'insert', requestId, data: assignedId });
}
/**
 * Insert a batch of vectors and post back the list of assigned ids.
 */
async function handleInsertBatch(data) {
  const { entries, requestId } = data;
  // Re-wrap each vector as a Float32Array for the WASM boundary.
  const prepared = [];
  for (const { vector, id, metadata } of entries) {
    prepared.push({ vector: new Float32Array(vector), id, metadata });
  }
  const ids = vectorDB.insertBatch(prepared);
  self.postMessage({ type: 'insertBatch', requestId, data: ids });
}
/**
 * Run a k-nearest-neighbor search and post back plain-object results
 * (structured-clone friendly: vectors converted to plain arrays).
 */
async function handleSearch(data) {
  const { query, k, filter, requestId } = data;
  const hits = vectorDB.search(new Float32Array(query), k, filter);
  const payload = [];
  for (const hit of hits) {
    payload.push({
      id: hit.id,
      score: hit.score,
      vector: hit.vector ? Array.from(hit.vector) : null,
      metadata: hit.metadata
    });
  }
  self.postMessage({ type: 'search', requestId, data: payload });
}
/**
 * Delete a vector by id; posts back whatever the WASM delete call returns.
 */
async function handleDelete(data) {
  const { id, requestId } = data;
  const wasDeleted = vectorDB.delete(id);
  self.postMessage({ type: 'delete', requestId, data: wasDeleted });
}
/**
 * Fetch a vector entry by id; posts back a plain object (vector converted
 * to a plain array) or null when the id is unknown.
 */
async function handleGet(data) {
  const { id, requestId } = data;
  const entry = vectorDB.get(id);
  let plain = null;
  if (entry) {
    plain = {
      id: entry.id,
      vector: Array.from(entry.vector),
      metadata: entry.metadata
    };
  }
  self.postMessage({ type: 'get', requestId, data: plain });
}