Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
282
crates/ruvector-fpga-transformer/src/ffi/wasm_bindgen.rs
Normal file
282
crates/ruvector-fpga-transformer/src/ffi/wasm_bindgen.rs
Normal file
@@ -0,0 +1,282 @@
|
||||
//! WASM bindings via wasm-bindgen
|
||||
//!
|
||||
//! Provides the same API shape for browser and Node.js environments.
|
||||
|
||||
#![cfg(feature = "wasm")]
|
||||
|
||||
use js_sys::{Array, Int16Array, Object, Reflect, Uint16Array, Uint8Array};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use crate::artifact::{unpack_artifact, ModelArtifact};
|
||||
use crate::backend::native_sim::{NativeSimBackend, NativeSimConfig};
|
||||
use crate::backend::TransformerBackend;
|
||||
use crate::gating::DefaultCoherenceGate;
|
||||
use crate::types::{ComputeClass, FixedShape, GateHint, InferenceRequest, ModelId};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// WASM Engine for transformer inference
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmEngine {
|
||||
backend: NativeSimBackend,
|
||||
loaded_models: Vec<ModelId>,
|
||||
last_witness: Option<crate::types::WitnessLog>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmEngine {
|
||||
/// Create a new WASM engine
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
// Use permissive config for WASM
|
||||
let config = NativeSimConfig {
|
||||
max_models: 4,
|
||||
trace: false,
|
||||
lut_softmax: true,
|
||||
max_layers: 0,
|
||||
};
|
||||
|
||||
let gate = Arc::new(DefaultCoherenceGate::new());
|
||||
let backend = NativeSimBackend::with_config(gate, config);
|
||||
|
||||
Self {
|
||||
backend,
|
||||
loaded_models: Vec::new(),
|
||||
last_witness: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a model artifact from bytes
|
||||
///
|
||||
/// Returns the model ID as a Uint8Array on success
|
||||
#[wasm_bindgen(js_name = loadArtifact)]
|
||||
pub fn load_artifact(&mut self, artifact_bytes: &[u8]) -> Result<Uint8Array, JsValue> {
|
||||
let artifact = unpack_artifact(artifact_bytes)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to unpack artifact: {}", e)))?;
|
||||
|
||||
let model_id = self
|
||||
.backend
|
||||
.load(&artifact)
|
||||
.map_err(|e| JsValue::from_str(&format!("Failed to load model: {}", e)))?;
|
||||
|
||||
self.loaded_models.push(model_id);
|
||||
|
||||
// Return model ID as Uint8Array
|
||||
let id_array = Uint8Array::new_with_length(32);
|
||||
id_array.copy_from(model_id.as_bytes());
|
||||
Ok(id_array)
|
||||
}
|
||||
|
||||
/// Run inference
|
||||
///
|
||||
/// Returns an object with logits, topk, and witness
|
||||
#[wasm_bindgen]
|
||||
pub fn infer(
|
||||
&mut self,
|
||||
model_id: &[u8],
|
||||
tokens: &[u16],
|
||||
mask: &[u8],
|
||||
coherence_score_q: i16,
|
||||
boundary_crossed: bool,
|
||||
max_compute_class: u8,
|
||||
) -> Result<JsValue, JsValue> {
|
||||
// Parse model ID
|
||||
if model_id.len() != 32 {
|
||||
return Err(JsValue::from_str("Model ID must be 32 bytes"));
|
||||
}
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(model_id);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
// Get shape from loaded model
|
||||
// For WASM, we use micro shape by default
|
||||
let shape = FixedShape::micro();
|
||||
|
||||
// Validate input length
|
||||
if tokens.len() != shape.seq_len as usize {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Token length mismatch: expected {}, got {}",
|
||||
shape.seq_len,
|
||||
tokens.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Build gate hint
|
||||
let compute_class =
|
||||
ComputeClass::from_u8(max_compute_class).unwrap_or(ComputeClass::Deliberative);
|
||||
let gate_hint = GateHint::new(coherence_score_q, boundary_crossed, compute_class);
|
||||
|
||||
// Create request
|
||||
let req = InferenceRequest::new(model, shape, tokens, mask, gate_hint);
|
||||
|
||||
// Run inference
|
||||
let result = self
|
||||
.backend
|
||||
.infer(req)
|
||||
.map_err(|e| JsValue::from_str(&format!("Inference failed: {}", e)))?;
|
||||
|
||||
// Store witness
|
||||
self.last_witness = Some(result.witness.clone());
|
||||
|
||||
// Build result object
|
||||
let obj = Object::new();
|
||||
|
||||
// Add logits
|
||||
let logits = Int16Array::new_with_length(result.logits_q.len() as u32);
|
||||
logits.copy_from(&result.logits_q);
|
||||
Reflect::set(&obj, &"logits".into(), &logits)?;
|
||||
|
||||
// Add top-K if available
|
||||
if let Some(topk) = &result.topk {
|
||||
let topk_array = Array::new();
|
||||
for (token, logit) in topk {
|
||||
let pair = Array::new();
|
||||
pair.push(&JsValue::from(*token));
|
||||
pair.push(&JsValue::from(*logit));
|
||||
topk_array.push(&pair);
|
||||
}
|
||||
Reflect::set(&obj, &"topk".into(), &topk_array)?;
|
||||
}
|
||||
|
||||
// Add witness info
|
||||
let witness = Object::new();
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"backend".into(),
|
||||
&format!("{:?}", result.witness.backend).into(),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"cycles".into(),
|
||||
&JsValue::from(result.witness.cycles),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"latency_ns".into(),
|
||||
&JsValue::from(result.witness.latency_ns),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&witness,
|
||||
&"gate_decision".into(),
|
||||
&format!("{:?}", result.witness.gate_decision).into(),
|
||||
)?;
|
||||
Reflect::set(&obj, &"witness".into(), &witness)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Get the last witness log as JSON
|
||||
#[wasm_bindgen(js_name = getWitness)]
|
||||
pub fn get_witness(&self) -> Result<JsValue, JsValue> {
|
||||
match &self.last_witness {
|
||||
Some(witness) => {
|
||||
let json = serde_json::to_string(witness)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))?;
|
||||
Ok(JsValue::from_str(&json))
|
||||
}
|
||||
None => Ok(JsValue::NULL),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of loaded model IDs
|
||||
#[wasm_bindgen(js_name = getLoadedModels)]
|
||||
pub fn get_loaded_models(&self) -> Array {
|
||||
let arr = Array::new();
|
||||
for id in &self.loaded_models {
|
||||
let id_array = Uint8Array::new_with_length(32);
|
||||
id_array.copy_from(id.as_bytes());
|
||||
arr.push(&id_array);
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
/// Unload a model
|
||||
#[wasm_bindgen]
|
||||
pub fn unload(&mut self, model_id: &[u8]) -> Result<(), JsValue> {
|
||||
if model_id.len() != 32 {
|
||||
return Err(JsValue::from_str("Model ID must be 32 bytes"));
|
||||
}
|
||||
let mut id_bytes = [0u8; 32];
|
||||
id_bytes.copy_from_slice(model_id);
|
||||
let model = ModelId::new(id_bytes);
|
||||
|
||||
self.backend
|
||||
.unload(model)
|
||||
.map_err(|e| JsValue::from_str(&format!("Unload failed: {}", e)))?;
|
||||
|
||||
self.loaded_models.retain(|id| *id != model);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get backend statistics
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats(&self) -> Result<JsValue, JsValue> {
|
||||
let stats = self.backend.stats();
|
||||
let obj = Object::new();
|
||||
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"models_loaded".into(),
|
||||
&JsValue::from(stats.models_loaded as u32),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"total_inferences".into(),
|
||||
&JsValue::from(stats.total_inferences as f64),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"avg_latency_ns".into(),
|
||||
&JsValue::from(stats.avg_latency_ns as f64),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"early_exits".into(),
|
||||
&JsValue::from(stats.early_exits as f64),
|
||||
)?;
|
||||
Reflect::set(
|
||||
&obj,
|
||||
&"skipped".into(),
|
||||
&JsValue::from(stats.skipped as f64),
|
||||
)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WasmEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to create a micro shape configuration
|
||||
#[wasm_bindgen(js_name = microShape)]
|
||||
pub fn micro_shape() -> Result<JsValue, JsValue> {
|
||||
let shape = FixedShape::micro();
|
||||
let obj = Object::new();
|
||||
|
||||
Reflect::set(&obj, &"seq_len".into(), &JsValue::from(shape.seq_len))?;
|
||||
Reflect::set(&obj, &"d_model".into(), &JsValue::from(shape.d_model))?;
|
||||
Reflect::set(&obj, &"heads".into(), &JsValue::from(shape.heads))?;
|
||||
Reflect::set(&obj, &"d_head".into(), &JsValue::from(shape.d_head))?;
|
||||
Reflect::set(&obj, &"vocab".into(), &JsValue::from(shape.vocab))?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Utility function to validate an artifact without loading
|
||||
#[wasm_bindgen(js_name = validateArtifact)]
|
||||
pub fn validate_artifact(artifact_bytes: &[u8]) -> Result<JsValue, JsValue> {
|
||||
let artifact = unpack_artifact(artifact_bytes)
|
||||
.map_err(|e| JsValue::from_str(&format!("Invalid artifact: {}", e)))?;
|
||||
|
||||
artifact
|
||||
.validate()
|
||||
.map_err(|e| JsValue::from_str(&format!("Validation failed: {}", e)))?;
|
||||
|
||||
let obj = Object::new();
|
||||
Reflect::set(&obj, &"name".into(), &artifact.manifest.name.into())?;
|
||||
Reflect::set(&obj, &"valid".into(), &JsValue::TRUE)?;
|
||||
|
||||
Ok(obj.into())
|
||||
}
|
||||
Reference in New Issue
Block a user