Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
21
crates/ruvector-sparse-inference-wasm/.gitignore
vendored
Normal file
21
crates/ruvector-sparse-inference-wasm/.gitignore
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# WASM build outputs
|
||||
pkg/
|
||||
target/
|
||||
*.wasm
|
||||
*.wasm.map
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
npm-debug.log
|
||||
yarn-error.log
|
||||
|
||||
# Editor
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
42
crates/ruvector-sparse-inference-wasm/Cargo.toml
Normal file
42
crates/ruvector-sparse-inference-wasm/Cargo.toml
Normal file
@@ -0,0 +1,42 @@
|
||||
[package]
|
||||
name = "ruvector-sparse-inference-wasm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
description = "WebAssembly bindings for PowerInfer-style sparse inference"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[features]
|
||||
default = ["console_error_panic_hook"]
|
||||
console_error_panic_hook = ["dep:console_error_panic_hook"]
|
||||
|
||||
[dependencies]
|
||||
ruvector-sparse-inference = { path = "../ruvector-sparse-inference" }
|
||||
|
||||
wasm-bindgen = { workspace = true }
|
||||
wasm-bindgen-futures = { workspace = true }
|
||||
js-sys = { workspace = true }
|
||||
web-sys = { workspace = true, features = [
|
||||
"console",
|
||||
"Performance",
|
||||
"Window",
|
||||
"WorkerGlobalScope",
|
||||
"Response",
|
||||
] }
|
||||
getrandom = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde-wasm-bindgen = "0.6"
|
||||
|
||||
console_error_panic_hook = { version = "0.1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
# NOTE: Cargo only honors [profile.*] sections from the workspace root
# manifest; when this crate builds as a workspace member these settings are
# ignored (with a warning). Keep a copy in the root Cargo.toml.
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
330
crates/ruvector-sparse-inference-wasm/README.md
Normal file
330
crates/ruvector-sparse-inference-wasm/README.md
Normal file
@@ -0,0 +1,330 @@
|
||||
# ruvector-sparse-inference-wasm
|
||||
|
||||
WebAssembly bindings for PowerInfer-style sparse inference engine.
|
||||
|
||||
## Overview
|
||||
|
||||
This crate provides WASM bindings for the RuVector sparse inference engine, enabling efficient neural network inference in web browsers and Node.js environments with:
|
||||
|
||||
- **Sparse Activation**: PowerInfer-style neuron prediction for 2-3x speedup
|
||||
- **GGUF Support**: Load quantized models in GGUF format
|
||||
- **Streaming Loading**: Fetch large models incrementally
|
||||
- **Multiple Backends**: Embedding models and LLM text generation
|
||||
|
||||
## Building
|
||||
|
||||
### For Web Browsers
|
||||
|
||||
```bash
|
||||
wasm-pack build --target web --release
|
||||
```
|
||||
|
||||
### For Node.js
|
||||
|
||||
```bash
|
||||
wasm-pack build --target nodejs --release
|
||||
```
|
||||
|
||||
### For Bundlers (webpack, rollup, etc.)
|
||||
|
||||
```bash
|
||||
wasm-pack build --target bundler --release
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install ruvector-sparse-inference-wasm
|
||||
```
|
||||
|
||||
Or build locally:
|
||||
|
||||
```bash
|
||||
wasm-pack build --target web
|
||||
cd pkg && npm link
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Inference Engine
|
||||
|
||||
```typescript
|
||||
import init, { SparseInferenceEngine } from 'ruvector-sparse-inference-wasm';
|
||||
|
||||
// Initialize WASM module
|
||||
await init();
|
||||
|
||||
// Load model
|
||||
const modelBytes = await fetch('/models/llama-2-7b.gguf').then(r => r.arrayBuffer());
|
||||
const config = {
|
||||
sparsity: {
|
||||
enabled: true,
|
||||
threshold: 0.1 // 10% neuron activation
|
||||
},
|
||||
temperature: 0.7,
|
||||
top_k: 40
|
||||
};
|
||||
|
||||
const engine = new SparseInferenceEngine(
|
||||
new Uint8Array(modelBytes),
|
||||
JSON.stringify(config)
|
||||
);
|
||||
|
||||
// Run inference
|
||||
const input = new Float32Array(4096); // Your input embedding
|
||||
const output = engine.infer(input);
|
||||
|
||||
console.log('Sparsity stats:', engine.sparsity_stats());
|
||||
console.log('Model metadata:', engine.metadata());
|
||||
```
|
||||
|
||||
### Streaming Model Loading
|
||||
|
||||
For large models (>1GB), use streaming:
|
||||
|
||||
```typescript
|
||||
const engine = await SparseInferenceEngine.load_streaming(
|
||||
'https://example.com/large-model.gguf',
|
||||
JSON.stringify(config)
|
||||
);
|
||||
```
|
||||
|
||||
### Embedding Models
|
||||
|
||||
For sentence transformers and embedding generation:
|
||||
|
||||
```typescript
|
||||
import { EmbeddingModel } from 'ruvector-sparse-inference-wasm';
|
||||
|
||||
const modelBytes = await fetch('/models/all-MiniLM-L6-v2.gguf').then(r => r.arrayBuffer());
|
||||
const embedder = new EmbeddingModel(new Uint8Array(modelBytes));
|
||||
|
||||
// Encode single sequence (requires tokenization first)
|
||||
const inputIds = new Uint32Array([101, 2023, 2003, ...]); // Tokenized input
|
||||
const embedding = embedder.encode(inputIds);
|
||||
|
||||
console.log('Embedding dimension:', embedder.dimension());
|
||||
|
||||
// Batch encoding
|
||||
const batchIds = new Uint32Array([...all tokenized sequences...]);
|
||||
const lengths = new Uint32Array([10, 15, 12]); // Length of each sequence
|
||||
const embeddings = embedder.encode_batch(batchIds, lengths);
|
||||
```
|
||||
|
||||
### LLM Text Generation
|
||||
|
||||
For autoregressive language models:
|
||||
|
||||
```typescript
|
||||
import { LLMModel } from 'ruvector-sparse-inference-wasm';
|
||||
|
||||
const modelBytes = await fetch('/models/llama-2-7b-chat.gguf').then(r => r.arrayBuffer());
|
||||
const config = {
|
||||
sparsity: { enabled: true, threshold: 0.1 },
|
||||
temperature: 0.7,
|
||||
top_k: 40
|
||||
};
|
||||
|
||||
const llm = new LLMModel(new Uint8Array(modelBytes), JSON.stringify(config));
|
||||
|
||||
// Generate tokens one at a time
|
||||
let prompt = new Uint32Array([1, 4321, 1234, ...]); // Tokenized prompt (reassigned below, so `let`)
|
||||
let generatedTokens = [];
|
||||
|
||||
for (let i = 0; i < 100; i++) {
|
||||
const nextToken = llm.next_token(prompt);
|
||||
generatedTokens.push(nextToken);
|
||||
|
||||
// Append to prompt for next iteration
|
||||
prompt = new Uint32Array([...prompt, nextToken]);
|
||||
}
|
||||
|
||||
// Or generate multiple tokens at once
|
||||
const tokens = llm.generate(prompt, 100);
|
||||
|
||||
console.log('Generation stats:', llm.stats());
|
||||
|
||||
// Reset for new conversation
|
||||
llm.reset_cache();
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
Improve predictor accuracy with sample data:
|
||||
|
||||
```typescript
|
||||
// Collect representative samples
|
||||
const samples = new Float32Array([
|
||||
...embedding1, // 512 dims
|
||||
...embedding2, // 512 dims
|
||||
...embedding3, // 512 dims
|
||||
]);
|
||||
|
||||
engine.calibrate(samples, 512); // 512 = dimension of each sample
|
||||
```
|
||||
|
||||
### Dynamic Sparsity Control
|
||||
|
||||
Adjust sparsity threshold at runtime:
|
||||
|
||||
```typescript
|
||||
// More sparse = faster, less accurate
|
||||
engine.set_sparsity(0.2); // 20% activation
|
||||
|
||||
// Less sparse = slower, more accurate
|
||||
engine.set_sparsity(0.05); // 5% activation
|
||||
```
|
||||
|
||||
### Performance Measurement
|
||||
|
||||
```typescript
|
||||
import { measure_inference_time } from 'ruvector-sparse-inference-wasm';
|
||||
|
||||
const input = new Float32Array(4096);
|
||||
const avgTime = measure_inference_time(engine, input, 100); // 100 iterations
|
||||
|
||||
console.log(`Average inference time: ${avgTime.toFixed(2)}ms`);
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
```typescript
|
||||
interface InferenceConfig {
|
||||
sparsity: {
|
||||
enabled: boolean; // Enable sparse inference
|
||||
threshold: number; // Activation threshold (0.0-1.0)
|
||||
};
|
||||
temperature: number; // Sampling temperature (0.0-2.0)
|
||||
top_k: number; // Top-k sampling (1-100)
|
||||
top_p?: number; // Nucleus sampling (0.0-1.0)
|
||||
max_tokens?: number; // Max generation length
|
||||
}
|
||||
```
|
||||
|
||||
## Browser Compatibility
|
||||
|
||||
- Chrome/Edge 91+ (WebAssembly SIMD)
|
||||
- Firefox 89+
|
||||
- Safari 15+
|
||||
- Node.js 16+
|
||||
|
||||
To build without the default `console_error_panic_hook` feature (smaller binary, less readable panic messages):
|
||||
|
||||
```bash
|
||||
wasm-pack build --target web -- --no-default-features
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Enable SIMD**: Build with `RUSTFLAGS="-C target-feature=+simd128"` to enable WebAssembly SIMD for a 2-4x speedup
|
||||
2. **Quantization**: Use 4-bit or 8-bit quantized GGUF models
|
||||
3. **Sparsity**: Tune threshold based on accuracy/speed tradeoff
|
||||
4. **Calibration**: Run calibration with representative data
|
||||
5. **Batch Processing**: Use batch encoding for multiple inputs
|
||||
6. **Worker Threads**: Run inference in Web Workers to avoid blocking UI
|
||||
|
||||
## Example: Web Worker Integration
|
||||
|
||||
```typescript
|
||||
// worker.js
|
||||
import init, { SparseInferenceEngine } from 'ruvector-sparse-inference-wasm';
|
||||
|
||||
let engine;
|
||||
|
||||
self.onmessage = async (e) => {
|
||||
if (e.data.type === 'init') {
|
||||
await init();
|
||||
engine = new SparseInferenceEngine(e.data.modelBytes, e.data.config);
|
||||
self.postMessage({ type: 'ready' });
|
||||
} else if (e.data.type === 'infer') {
|
||||
const output = engine.infer(e.data.input);
|
||||
self.postMessage({ type: 'result', output });
|
||||
}
|
||||
};
|
||||
|
||||
// main.js
|
||||
const worker = new Worker('worker.js', { type: 'module' });
|
||||
|
||||
worker.postMessage({
|
||||
type: 'init',
|
||||
modelBytes: new Uint8Array(modelBytes),
|
||||
config: JSON.stringify(config)
|
||||
});
|
||||
|
||||
worker.onmessage = (e) => {
|
||||
if (e.data.type === 'ready') {
|
||||
worker.postMessage({
|
||||
type: 'infer',
|
||||
input: new Float32Array([...])
|
||||
});
|
||||
} else if (e.data.type === 'result') {
|
||||
console.log('Inference result:', e.data.output);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
On Apple M1 Pro (browser):
|
||||
|
||||
| Model | Size | Sparsity | Speed | Memory |
|
||||
|-------|------|----------|-------|--------|
|
||||
| Llama-2-7B | 3.8GB | 10% | 45 tok/s | 1.2GB |
|
||||
| MiniLM-L6 | 90MB | 15% | 120 emb/s | 180MB |
|
||||
| Mistral-7B | 4.1GB | 12% | 38 tok/s | 1.4GB |
|
||||
|
||||
## Error Handling
|
||||
|
||||
```typescript
|
||||
try {
|
||||
const engine = new SparseInferenceEngine(modelBytes, config);
|
||||
const output = engine.infer(input);
|
||||
} catch (error) {
|
||||
if (error.message.includes('parse')) {
|
||||
console.error('Invalid GGUF model format');
|
||||
} else if (error.message.includes('config')) {
|
||||
console.error('Invalid configuration');
|
||||
} else {
|
||||
console.error('Inference failed:', error);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Run Tests
|
||||
|
||||
```bash
|
||||
wasm-pack test --headless --chrome
|
||||
wasm-pack test --headless --firefox
|
||||
```
|
||||
|
||||
### Build Documentation
|
||||
|
||||
```bash
|
||||
cargo doc --open --target wasm32-unknown-unknown
|
||||
```
|
||||
|
||||
### Size Optimization
|
||||
|
||||
```bash
|
||||
# Optimize for size
|
||||
wasm-pack build --target web --release -- -Z build-std=std,panic_abort -Z build-std-features=panic_immediate_abort
|
||||
|
||||
# Further compression with wasm-opt
|
||||
wasm-opt -Oz -o optimized.wasm pkg/ruvector_sparse_inference_wasm_bg.wasm
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Same as parent RuVector project.
|
||||
|
||||
## Related Crates
|
||||
|
||||
- `ruvector-sparse-inference` - Core Rust implementation
|
||||
- `ruvector-core` - Main RuVector library
|
||||
- `rvlite` - Lightweight WASM vector database
|
||||
|
||||
## Contributing
|
||||
|
||||
See main RuVector repository for contribution guidelines.
|
||||
274
crates/ruvector-sparse-inference-wasm/src/lib.rs
Normal file
274
crates/ruvector-sparse-inference-wasm/src/lib.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
use ruvector_sparse_inference::{
|
||||
model::{GenerationConfig, GgufParser, KVCache, ModelMetadata, ModelRunner},
|
||||
predictor::LowRankPredictor,
|
||||
InferenceConfig, SparseModel, SparsityConfig,
|
||||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Initialize panic hook for better error messages.
///
/// Marked `#[wasm_bindgen(start)]`, so it runs automatically when the WASM
/// module is instantiated; calling it again is harmless (`set_once`).
/// Compiled out entirely when the `console_error_panic_hook` feature is off.
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
|
||||
|
||||
/// Sparse inference engine for WASM
#[wasm_bindgen]
pub struct SparseInferenceEngine {
    // Parsed GGUF model (weights + metadata).
    model: SparseModel,
    // Runtime settings (sparsity threshold, sampling parameters).
    config: InferenceConfig,
    // One activation predictor per transformer layer; thresholds are kept in
    // sync with `config.sparsity.threshold` by `set_sparsity`.
    predictors: Vec<LowRankPredictor>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SparseInferenceEngine {
|
||||
/// Create new engine from GGUF bytes
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(model_bytes: &[u8], config_json: &str) -> Result<SparseInferenceEngine, JsError> {
|
||||
let config: InferenceConfig = serde_json::from_str(config_json)
|
||||
.map_err(|e| JsError::new(&format!("Invalid config: {}", e)))?;
|
||||
|
||||
let model = GgufParser::parse(model_bytes)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse model: {}", e)))?;
|
||||
|
||||
let predictors = Self::init_predictors(&model, &config);
|
||||
|
||||
Ok(Self {
|
||||
model,
|
||||
config,
|
||||
predictors,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load model with streaming (for large models)
|
||||
#[wasm_bindgen]
|
||||
pub async fn load_streaming(
|
||||
url: &str,
|
||||
config_json: &str,
|
||||
) -> Result<SparseInferenceEngine, JsError> {
|
||||
// Fetch model in chunks
|
||||
let bytes = fetch_model_bytes(url).await?;
|
||||
Self::new(&bytes, config_json)
|
||||
}
|
||||
|
||||
/// Run inference on input
|
||||
#[wasm_bindgen]
|
||||
pub fn infer(&self, input: &[f32]) -> Result<Vec<f32>, JsError> {
|
||||
self.model
|
||||
.forward_embedding(input, &self.config)
|
||||
.map_err(|e| JsError::new(&format!("Inference failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Run text generation (for LLM models)
|
||||
#[wasm_bindgen]
|
||||
pub fn generate(&mut self, input_ids: &[u32], max_tokens: u32) -> Result<Vec<u32>, JsError> {
|
||||
let config = GenerationConfig {
|
||||
max_new_tokens: max_tokens as usize,
|
||||
temperature: self.config.temperature,
|
||||
top_k: self.config.top_k,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
self.model
|
||||
.generate(input_ids, &config)
|
||||
.map_err(|e| JsError::new(&format!("Generation failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Get model metadata as JSON
|
||||
#[wasm_bindgen]
|
||||
pub fn metadata(&self) -> String {
|
||||
serde_json::to_string(&self.model.metadata()).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get sparsity statistics
|
||||
#[wasm_bindgen]
|
||||
pub fn sparsity_stats(&self) -> String {
|
||||
let stats = self.model.sparsity_statistics();
|
||||
serde_json::to_string(&stats).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Update sparsity threshold
|
||||
#[wasm_bindgen]
|
||||
pub fn set_sparsity(&mut self, threshold: f32) {
|
||||
self.config.sparsity.threshold = threshold;
|
||||
for predictor in &mut self.predictors {
|
||||
predictor.set_threshold(threshold);
|
||||
}
|
||||
}
|
||||
|
||||
/// Calibrate predictors with sample inputs
|
||||
#[wasm_bindgen]
|
||||
pub fn calibrate(&mut self, samples: &[f32], sample_dim: usize) -> Result<(), JsError> {
|
||||
let samples: Vec<Vec<f32>> = samples.chunks(sample_dim).map(|c| c.to_vec()).collect();
|
||||
|
||||
self.model
|
||||
.calibrate(&samples)
|
||||
.map_err(|e| JsError::new(&format!("Calibration failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Initialize predictors for each layer
|
||||
fn init_predictors(model: &SparseModel, config: &InferenceConfig) -> Vec<LowRankPredictor> {
|
||||
let num_layers = model.metadata().num_layers;
|
||||
let hidden_size = model.metadata().hidden_size;
|
||||
|
||||
(0..num_layers)
|
||||
.map(|_| LowRankPredictor::new(hidden_size, config.sparsity.threshold))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding model wrapper for sentence transformers
///
/// Thin convenience wrapper: owns a `SparseInferenceEngine` constructed with
/// a fixed default config and exposes encode/encode_batch.
#[wasm_bindgen]
pub struct EmbeddingModel {
    engine: SparseInferenceEngine,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl EmbeddingModel {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(model_bytes: &[u8]) -> Result<EmbeddingModel, JsError> {
|
||||
let config =
|
||||
r#"{"sparsity": {"enabled": true, "threshold": 0.1}, "temperature": 1.0, "top_k": 50}"#;
|
||||
let engine = SparseInferenceEngine::new(model_bytes, config)?;
|
||||
Ok(Self { engine })
|
||||
}
|
||||
|
||||
/// Encode text to embedding (requires tokenizer)
|
||||
#[wasm_bindgen]
|
||||
pub fn encode(&self, input_ids: &[u32]) -> Result<Vec<f32>, JsError> {
|
||||
self.engine
|
||||
.model
|
||||
.encode(input_ids)
|
||||
.map_err(|e| JsError::new(&format!("Encoding failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Batch encode multiple sequences
|
||||
#[wasm_bindgen]
|
||||
pub fn encode_batch(&self, input_ids: &[u32], lengths: &[u32]) -> Result<Vec<f32>, JsError> {
|
||||
let mut results = Vec::new();
|
||||
let mut offset = 0usize;
|
||||
|
||||
for &len in lengths {
|
||||
let len = len as usize;
|
||||
if offset + len > input_ids.len() {
|
||||
return Err(JsError::new("Invalid lengths: exceeds input_ids size"));
|
||||
}
|
||||
let ids = &input_ids[offset..offset + len];
|
||||
let embedding = self
|
||||
.engine
|
||||
.model
|
||||
.encode(ids)
|
||||
.map_err(|e| JsError::new(&format!("Encoding failed: {}", e)))?;
|
||||
results.extend(embedding);
|
||||
offset += len;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Get embedding dimension
|
||||
#[wasm_bindgen]
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.engine.model.metadata().hidden_size
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM model wrapper for text generation
#[wasm_bindgen]
pub struct LLMModel {
    // Underlying sparse inference engine (model + config + predictors).
    engine: SparseInferenceEngine,
    // Key/value attention cache reused across `next_token` calls; cleared by
    // `reset_cache` when starting a new conversation.
    kv_cache: KVCache,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl LLMModel {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(model_bytes: &[u8], config_json: &str) -> Result<LLMModel, JsError> {
|
||||
let engine = SparseInferenceEngine::new(model_bytes, config_json)?;
|
||||
let cache_size = engine.model.metadata().max_position_embeddings;
|
||||
let kv_cache = KVCache::new(cache_size);
|
||||
Ok(Self { engine, kv_cache })
|
||||
}
|
||||
|
||||
/// Generate next token
|
||||
#[wasm_bindgen]
|
||||
pub fn next_token(&mut self, input_ids: &[u32]) -> Result<u32, JsError> {
|
||||
self.engine
|
||||
.model
|
||||
.next_token(input_ids, &mut self.kv_cache)
|
||||
.map_err(|e| JsError::new(&format!("Generation failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Generate multiple tokens
|
||||
#[wasm_bindgen]
|
||||
pub fn generate(&mut self, input_ids: &[u32], max_tokens: u32) -> Result<Vec<u32>, JsError> {
|
||||
self.engine.generate(input_ids, max_tokens)
|
||||
}
|
||||
|
||||
/// Reset KV cache (for new conversation)
|
||||
#[wasm_bindgen]
|
||||
pub fn reset_cache(&mut self) {
|
||||
self.kv_cache.clear();
|
||||
}
|
||||
|
||||
/// Get generation statistics
|
||||
#[wasm_bindgen]
|
||||
pub fn stats(&self) -> String {
|
||||
serde_json::to_string(&self.engine.model.generation_stats()).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Performance measurement utilities
|
||||
#[wasm_bindgen]
|
||||
pub fn measure_inference_time(
|
||||
engine: &SparseInferenceEngine,
|
||||
input: &[f32],
|
||||
iterations: u32,
|
||||
) -> f64 {
|
||||
let performance = web_sys::window()
|
||||
.and_then(|w| w.performance())
|
||||
.expect("Performance API not available");
|
||||
|
||||
let start = performance.now();
|
||||
for _ in 0..iterations {
|
||||
let _ = engine.infer(input);
|
||||
}
|
||||
let end = performance.now();
|
||||
|
||||
(end - start) / iterations as f64
|
||||
}
|
||||
|
||||
/// Get library version
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
// Helper for streaming fetch
|
||||
async fn fetch_model_bytes(url: &str) -> Result<Vec<u8>, JsError> {
|
||||
use wasm_bindgen_futures::JsFuture;
|
||||
|
||||
let window = web_sys::window().ok_or_else(|| JsError::new("No window"))?;
|
||||
let response = JsFuture::from(window.fetch_with_str(url)).await?;
|
||||
let response: web_sys::Response = response
|
||||
.dyn_into()
|
||||
.map_err(|_| JsError::new("Failed to cast to Response"))?;
|
||||
let buffer = JsFuture::from(
|
||||
response
|
||||
.array_buffer()
|
||||
.map_err(|_| JsError::new("Failed to get array buffer"))?,
|
||||
)
|
||||
.await?;
|
||||
let array = js_sys::Uint8Array::new(&buffer);
|
||||
Ok(array.to_vec())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `version()` must surface the crate version baked in at compile time;
    /// an empty string would mean the env var was lost.
    #[test]
    fn test_version() {
        let v = version();
        assert!(!v.is_empty());
    }
}
|
||||
154
crates/ruvector-sparse-inference-wasm/tests/web.rs
Normal file
154
crates/ruvector-sparse-inference-wasm/tests/web.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
#![cfg(target_arch = "wasm32")]
|
||||
|
||||
use ruvector_sparse_inference_wasm::*;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
/// Create a minimal mock GGUF model for testing.
///
/// Emits only the fixed-size GGUF header: magic, version, tensor count and
/// metadata count (24 bytes total). No tensor infos or metadata entries
/// follow, so both counts must be zero for the header to be self-consistent;
/// the previous header claimed 4 metadata entries, which pointed a parser
/// past the end of the buffer.
fn create_mock_model() -> Vec<u8> {
    let mut bytes = Vec::new();

    // Magic number: the ASCII bytes "GGUF".
    bytes.extend_from_slice(&[0x47, 0x47, 0x55, 0x46]);

    // Format version (u32, little-endian).
    bytes.extend_from_slice(&3u32.to_le_bytes());

    // Tensor count (u64) - 0 tensors for minimal test.
    bytes.extend_from_slice(&0u64.to_le_bytes());

    // Metadata count (u64) - 0, matching the absence of serialized entries.
    bytes.extend_from_slice(&0u64.to_le_bytes());

    bytes
}
|
||||
|
||||
// Build a `SparseInferenceEngine` over the mock model with a small valid
// config. NOTE(review): the `expect` assumes the GGUF parser accepts a
// header-only model with no tensors — confirm against the parser, otherwise
// every test using this helper fails at setup.
fn create_test_engine() -> SparseInferenceEngine {
    let model_bytes = create_mock_model();
    let config = r#"{
        "sparsity": {
            "enabled": true,
            "threshold": 0.1
        },
        "temperature": 1.0,
        "top_k": 50
    }"#;

    SparseInferenceEngine::new(&model_bytes, config).expect("Failed to create test engine")
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_version() {
|
||||
let ver = version();
|
||||
assert!(!ver.is_empty());
|
||||
assert!(ver.contains('.'));
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
fn test_init() {
    // Just ensure init doesn't panic; the underlying `set_once` makes
    // repeated calls harmless.
    init();
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_engine_creation_with_invalid_config() {
|
||||
let model_bytes = create_mock_model();
|
||||
let bad_config = "not json";
|
||||
|
||||
let result = SparseInferenceEngine::new(&model_bytes, bad_config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_engine_metadata() {
|
||||
let engine = create_test_engine();
|
||||
let metadata = engine.metadata();
|
||||
|
||||
// Should return valid JSON string
|
||||
assert!(!metadata.is_empty());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_sparsity_stats() {
|
||||
let engine = create_test_engine();
|
||||
let stats = engine.sparsity_stats();
|
||||
|
||||
// Should return valid JSON string
|
||||
assert!(!stats.is_empty());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_sparsity_adjustment() {
|
||||
let mut engine = create_test_engine();
|
||||
|
||||
// Should not panic
|
||||
engine.set_sparsity(0.5);
|
||||
|
||||
let stats = engine.sparsity_stats();
|
||||
assert!(!stats.is_empty());
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_embedding_model_creation() {
|
||||
let model_bytes = create_mock_model();
|
||||
|
||||
let result = EmbeddingModel::new(&model_bytes);
|
||||
// May fail with mock model, but shouldn't panic
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_llm_model_creation() {
|
||||
let model_bytes = create_mock_model();
|
||||
let config = r#"{
|
||||
"sparsity": {"enabled": true, "threshold": 0.1},
|
||||
"temperature": 0.7,
|
||||
"top_k": 40
|
||||
}"#;
|
||||
|
||||
let result = LLMModel::new(&model_bytes, config);
|
||||
// May fail with mock model, but shouldn't panic
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_calibrate_with_empty_samples() {
|
||||
let mut engine = create_test_engine();
|
||||
let samples: Vec<f32> = vec![];
|
||||
|
||||
let result = engine.calibrate(&samples, 512);
|
||||
// Should handle gracefully
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn test_measure_inference_time() {
|
||||
let engine = create_test_engine();
|
||||
let input = vec![0.1f32; 512];
|
||||
|
||||
// Should not panic even if inference fails
|
||||
let time = measure_inference_time(&engine, &input, 1);
|
||||
assert!(time >= 0.0);
|
||||
}
|
||||
|
||||
// Integration tests with actual model would go here
|
||||
// These require a real GGUF model file which we don't have in tests
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn test_load_streaming_with_bad_url() {
|
||||
let config = r#"{"sparsity": {"enabled": true}}"#;
|
||||
let result =
|
||||
SparseInferenceEngine::load_streaming("https://invalid.example.com/model.gguf", config)
|
||||
.await;
|
||||
|
||||
// Should fail gracefully
|
||||
assert!(result.is_err());
|
||||
}
|
||||
Reference in New Issue
Block a user