Files
wifi-densepose/vendor/ruvector/crates/ruvllm-wasm/src/bindings.rs

1202 lines
35 KiB
Rust

//! JavaScript/WASM Bindings for RuvLLM
//!
//! This module provides JavaScript-friendly wrappers around the RuvLLM
//! inference runtime. All types are designed to work seamlessly with
//! JavaScript through wasm-bindgen.
//!
//! # Example (JavaScript)
//!
//! ```javascript
//! import init, { RuvLLMWasm, GenerateConfig, KvCacheWasm } from 'ruvllm-wasm';
//!
//! await init();
//!
//! // Create inference engine
//! const llm = new RuvLLMWasm();
//!
//! // Configure generation
//! const config = new GenerateConfig();
//! config.maxTokens = 256;
//! config.temperature = 0.7;
//!
//! // Format a chat conversation
//! const template = ChatTemplateWasm.llama3();
//! const messages = [
//! ChatMessageWasm.system("You are helpful."),
//! ChatMessageWasm.user("Hello!"),
//! ];
//! const prompt = template.format(messages);
//! ```
use crate::utils::log;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use wasm_bindgen::prelude::*;
// ============================================================================
// Types (re-implemented for WASM self-containment)
// ============================================================================
/// Model size variants
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelSize {
Tiny,
Small,
Medium,
Large,
}
impl Default for ModelSize {
fn default() -> Self {
Self::Small
}
}
/// Precision levels for quantization
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Precision {
FP32,
FP16,
Q8,
Q4K,
Q4,
}
impl Default for Precision {
fn default() -> Self {
Self::FP16
}
}
impl Precision {
pub fn bytes_per_element(&self) -> f32 {
match self {
Self::FP32 => 4.0,
Self::FP16 => 2.0,
Self::Q8 => 1.0,
Self::Q4K | Self::Q4 => 0.5,
}
}
}
// ============================================================================
// Configuration Types
// ============================================================================
/// Generation configuration for text generation.
///
/// Controls sampling parameters and output constraints.
/// TypeScript-friendly with getter/setter methods.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateConfig {
/// Maximum tokens to generate
#[wasm_bindgen(skip)]
pub max_tokens: usize,
/// Temperature for sampling (0.0 = deterministic)
#[wasm_bindgen(skip)]
pub temperature: f32,
/// Top-p (nucleus) sampling threshold
#[wasm_bindgen(skip)]
pub top_p: f32,
/// Top-k sampling (0 = disabled)
#[wasm_bindgen(skip)]
pub top_k: usize,
/// Repetition penalty (1.0 = no penalty)
#[wasm_bindgen(skip)]
pub repetition_penalty: f32,
/// Stop sequences (JSON array of strings)
#[wasm_bindgen(skip)]
pub stop_sequences: Vec<String>,
}
#[wasm_bindgen]
impl GenerateConfig {
/// Create a new GenerateConfig with default values.
#[wasm_bindgen(constructor)]
pub fn new() -> GenerateConfig {
GenerateConfig {
max_tokens: 256,
temperature: 0.7,
top_p: 0.9,
top_k: 40,
repetition_penalty: 1.1,
stop_sequences: Vec::new(),
}
}
/// Get maximum tokens.
#[wasm_bindgen(getter, js_name = maxTokens)]
pub fn max_tokens(&self) -> usize {
self.max_tokens
}
/// Set maximum tokens.
#[wasm_bindgen(setter, js_name = maxTokens)]
pub fn set_max_tokens(&mut self, value: usize) {
self.max_tokens = value;
}
/// Get temperature.
#[wasm_bindgen(getter)]
pub fn temperature(&self) -> f32 {
self.temperature
}
/// Set temperature.
#[wasm_bindgen(setter)]
pub fn set_temperature(&mut self, value: f32) {
self.temperature = value;
}
/// Get top-p value.
#[wasm_bindgen(getter, js_name = topP)]
pub fn top_p(&self) -> f32 {
self.top_p
}
/// Set top-p value.
#[wasm_bindgen(setter, js_name = topP)]
pub fn set_top_p(&mut self, value: f32) {
self.top_p = value;
}
/// Get top-k value.
#[wasm_bindgen(getter, js_name = topK)]
pub fn top_k(&self) -> usize {
self.top_k
}
/// Set top-k value.
#[wasm_bindgen(setter, js_name = topK)]
pub fn set_top_k(&mut self, value: usize) {
self.top_k = value;
}
/// Get repetition penalty.
#[wasm_bindgen(getter, js_name = repetitionPenalty)]
pub fn repetition_penalty(&self) -> f32 {
self.repetition_penalty
}
/// Set repetition penalty.
#[wasm_bindgen(setter, js_name = repetitionPenalty)]
pub fn set_repetition_penalty(&mut self, value: f32) {
self.repetition_penalty = value;
}
/// Add a stop sequence.
#[wasm_bindgen(js_name = addStopSequence)]
pub fn add_stop_sequence(&mut self, sequence: &str) {
self.stop_sequences.push(sequence.to_string());
}
/// Clear all stop sequences.
#[wasm_bindgen(js_name = clearStopSequences)]
pub fn clear_stop_sequences(&mut self) {
self.stop_sequences.clear();
}
/// Convert to JSON string.
#[wasm_bindgen(js_name = toJson)]
pub fn to_json(&self) -> Result<String, JsValue> {
serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string()))
}
/// Create from JSON string.
#[wasm_bindgen(js_name = fromJson)]
pub fn from_json(json: &str) -> Result<GenerateConfig, JsValue> {
serde_json::from_str(json).map_err(|e| JsValue::from_str(&e.to_string()))
}
}
impl Default for GenerateConfig {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Chat Message Types
// ============================================================================
/// Message role in a conversation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
System,
User,
Assistant,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
}
}
}
/// Internal chat message
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: Role,
pub content: String,
}
impl ChatMessage {
pub fn system(content: &str) -> Self {
Self {
role: Role::System,
content: content.to_string(),
}
}
pub fn user(content: &str) -> Self {
Self {
role: Role::User,
content: content.to_string(),
}
}
pub fn assistant(content: &str) -> Self {
Self {
role: Role::Assistant,
content: content.to_string(),
}
}
}
/// Chat message for instruction-tuned models.
///
/// Used to construct conversations for chat-based inference.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct ChatMessageWasm {
inner: ChatMessage,
}
#[wasm_bindgen]
impl ChatMessageWasm {
/// Create a system message.
#[wasm_bindgen(js_name = system)]
pub fn system(content: &str) -> ChatMessageWasm {
ChatMessageWasm {
inner: ChatMessage::system(content),
}
}
/// Create a user message.
#[wasm_bindgen(js_name = user)]
pub fn user(content: &str) -> ChatMessageWasm {
ChatMessageWasm {
inner: ChatMessage::user(content),
}
}
/// Create an assistant message.
#[wasm_bindgen(js_name = assistant)]
pub fn assistant(content: &str) -> ChatMessageWasm {
ChatMessageWasm {
inner: ChatMessage::assistant(content),
}
}
/// Get the role as a string.
#[wasm_bindgen(getter)]
pub fn role(&self) -> String {
self.inner.role.as_str().to_string()
}
/// Get the message content.
#[wasm_bindgen(getter)]
pub fn content(&self) -> String {
self.inner.content.clone()
}
}
// ============================================================================
// Chat Templates
// ============================================================================
/// Chat template variants
#[derive(Debug, Clone)]
pub enum ChatTemplate {
Llama3,
Llama2,
Mistral,
Qwen,
ChatML,
Phi,
Gemma,
Custom(String),
}
impl ChatTemplate {
/// Detect template from model ID
pub fn detect_from_model_id(model_id: &str) -> Self {
let model_lower = model_id.to_lowercase();
if model_lower.contains("llama-3") || model_lower.contains("llama3") {
Self::Llama3
} else if model_lower.contains("llama-2") || model_lower.contains("llama2") {
Self::Llama2
} else if model_lower.contains("mistral") || model_lower.contains("mixtral") {
Self::Mistral
} else if model_lower.contains("qwen") {
Self::Qwen
} else if model_lower.contains("phi") {
Self::Phi
} else if model_lower.contains("gemma") {
Self::Gemma
} else {
Self::ChatML
}
}
/// Format messages using this template
pub fn format(&self, messages: &[ChatMessage]) -> String {
match self {
Self::Llama3 => self.format_llama3(messages),
Self::Llama2 => self.format_llama2(messages),
Self::Mistral => self.format_mistral(messages),
Self::Qwen => self.format_qwen(messages),
Self::ChatML => self.format_chatml(messages),
Self::Phi => self.format_phi(messages),
Self::Gemma => self.format_gemma(messages),
Self::Custom(template) => self.format_custom(messages, template),
}
}
fn format_llama3(&self, messages: &[ChatMessage]) -> String {
let mut output = String::from("<|begin_of_text|>");
for msg in messages {
let role = msg.role.as_str();
output.push_str(&format!(
"<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>",
role, msg.content
));
}
output.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
output
}
fn format_llama2(&self, messages: &[ChatMessage]) -> String {
let mut output = String::new();
let mut system_msg = String::new();
for msg in messages {
match msg.role {
Role::System => {
system_msg = msg.content.clone();
}
Role::User => {
if !system_msg.is_empty() {
output.push_str(&format!(
"<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]",
system_msg, msg.content
));
system_msg.clear();
} else {
output.push_str(&format!("<s>[INST] {} [/INST]", msg.content));
}
}
Role::Assistant => {
output.push_str(&format!(" {} </s>", msg.content));
}
}
}
output
}
fn format_mistral(&self, messages: &[ChatMessage]) -> String {
let mut output = String::new();
for msg in messages {
match msg.role {
Role::System | Role::User => {
output.push_str(&format!("[INST] {} [/INST]", msg.content));
}
Role::Assistant => {
output.push_str(&format!("{}</s>", msg.content));
}
}
}
output
}
fn format_qwen(&self, messages: &[ChatMessage]) -> String {
self.format_chatml(messages)
}
fn format_chatml(&self, messages: &[ChatMessage]) -> String {
let mut output = String::new();
for msg in messages {
output.push_str(&format!(
"<|im_start|>{}\n{}<|im_end|>\n",
msg.role.as_str(),
msg.content
));
}
output.push_str("<|im_start|>assistant\n");
output
}
fn format_phi(&self, messages: &[ChatMessage]) -> String {
let mut output = String::new();
for msg in messages {
match msg.role {
Role::System => {
output.push_str(&format!("<|system|>\n{}<|end|>\n", msg.content));
}
Role::User => {
output.push_str(&format!("<|user|>\n{}<|end|>\n", msg.content));
}
Role::Assistant => {
output.push_str(&format!("<|assistant|>\n{}<|end|>\n", msg.content));
}
}
}
output.push_str("<|assistant|>\n");
output
}
fn format_gemma(&self, messages: &[ChatMessage]) -> String {
let mut output = String::new();
for msg in messages {
match msg.role {
Role::User => {
output.push_str(&format!(
"<start_of_turn>user\n{}<end_of_turn>\n",
msg.content
));
}
Role::Assistant => {
output.push_str(&format!(
"<start_of_turn>model\n{}<end_of_turn>\n",
msg.content
));
}
Role::System => {
// Gemma doesn't have native system support, prepend to first user
output.push_str(&format!("<start_of_turn>user\n{}\n", msg.content));
}
}
}
output.push_str("<start_of_turn>model\n");
output
}
fn format_custom(&self, _messages: &[ChatMessage], _template: &str) -> String {
// Simplified custom template support
String::new()
}
}
/// Chat template for formatting conversations.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct ChatTemplateWasm {
inner: ChatTemplate,
}
#[wasm_bindgen]
impl ChatTemplateWasm {
/// Create a Llama 3 chat template.
#[wasm_bindgen(js_name = llama3)]
pub fn llama3() -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::Llama3,
}
}
/// Create a Mistral chat template.
#[wasm_bindgen(js_name = mistral)]
pub fn mistral() -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::Mistral,
}
}
/// Create a Qwen/ChatML chat template.
#[wasm_bindgen(js_name = chatml)]
pub fn chatml() -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::ChatML,
}
}
/// Create a Phi chat template.
#[wasm_bindgen(js_name = phi)]
pub fn phi() -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::Phi,
}
}
/// Create a Gemma chat template.
#[wasm_bindgen(js_name = gemma)]
pub fn gemma() -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::Gemma,
}
}
/// Create a custom chat template.
#[wasm_bindgen(js_name = custom)]
pub fn custom(template: &str) -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::Custom(template.to_string()),
}
}
/// Detect template from model ID.
#[wasm_bindgen(js_name = detectFromModelId)]
pub fn detect_from_model_id(model_id: &str) -> ChatTemplateWasm {
ChatTemplateWasm {
inner: ChatTemplate::detect_from_model_id(model_id),
}
}
/// Format messages using this template.
#[wasm_bindgen(js_name = format)]
pub fn format(&self, messages: Vec<ChatMessageWasm>) -> String {
let inner_messages: Vec<ChatMessage> = messages.into_iter().map(|m| m.inner).collect();
self.inner.format(&inner_messages)
}
/// Get the template name.
#[wasm_bindgen(getter)]
pub fn name(&self) -> String {
match &self.inner {
ChatTemplate::Llama3 => "llama3".to_string(),
ChatTemplate::Llama2 => "llama2".to_string(),
ChatTemplate::Mistral => "mistral".to_string(),
ChatTemplate::Qwen => "qwen".to_string(),
ChatTemplate::ChatML => "chatml".to_string(),
ChatTemplate::Phi => "phi".to_string(),
ChatTemplate::Gemma => "gemma".to_string(),
ChatTemplate::Custom(_) => "custom".to_string(),
}
}
}
// ============================================================================
// KV Cache
// ============================================================================
/// KV cache configuration for WASM.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct KvCacheConfigWasm {
tail_length: usize,
max_tokens: usize,
num_kv_heads: usize,
head_dim: usize,
}
#[wasm_bindgen]
impl KvCacheConfigWasm {
/// Create a new KV cache configuration.
#[wasm_bindgen(constructor)]
pub fn new() -> KvCacheConfigWasm {
KvCacheConfigWasm {
tail_length: 256,
max_tokens: 4096,
num_kv_heads: 8,
head_dim: 128,
}
}
/// Get tail length.
#[wasm_bindgen(getter, js_name = tailLength)]
pub fn tail_length(&self) -> usize {
self.tail_length
}
/// Set tail length.
#[wasm_bindgen(setter, js_name = tailLength)]
pub fn set_tail_length(&mut self, value: usize) {
self.tail_length = value;
}
/// Get max tokens.
#[wasm_bindgen(getter, js_name = maxTokens)]
pub fn max_tokens(&self) -> usize {
self.max_tokens
}
/// Set max tokens.
#[wasm_bindgen(setter, js_name = maxTokens)]
pub fn set_max_tokens(&mut self, value: usize) {
self.max_tokens = value;
}
/// Get number of KV heads.
#[wasm_bindgen(getter, js_name = numKvHeads)]
pub fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
/// Set number of KV heads.
#[wasm_bindgen(setter, js_name = numKvHeads)]
pub fn set_num_kv_heads(&mut self, value: usize) {
self.num_kv_heads = value;
}
/// Get head dimension.
#[wasm_bindgen(getter, js_name = headDim)]
pub fn head_dim(&self) -> usize {
self.head_dim
}
/// Set head dimension.
#[wasm_bindgen(setter, js_name = headDim)]
pub fn set_head_dim(&mut self, value: usize) {
self.head_dim = value;
}
}
impl Default for KvCacheConfigWasm {
fn default() -> Self {
Self::new()
}
}
/// KV cache statistics.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvCacheStatsWasm {
total_tokens: usize,
tail_tokens: usize,
store_tokens: usize,
tail_bytes: usize,
store_bytes: usize,
compression_ratio: f32,
}
#[wasm_bindgen]
impl KvCacheStatsWasm {
/// Get total tokens.
#[wasm_bindgen(getter, js_name = totalTokens)]
pub fn total_tokens(&self) -> usize {
self.total_tokens
}
/// Get tail tokens.
#[wasm_bindgen(getter, js_name = tailTokens)]
pub fn tail_tokens(&self) -> usize {
self.tail_tokens
}
/// Get store tokens.
#[wasm_bindgen(getter, js_name = storeTokens)]
pub fn store_tokens(&self) -> usize {
self.store_tokens
}
/// Get compression ratio.
#[wasm_bindgen(getter, js_name = compressionRatio)]
pub fn compression_ratio(&self) -> f32 {
self.compression_ratio
}
/// Convert to JSON.
#[wasm_bindgen(js_name = toJson)]
pub fn to_json(&self) -> Result<String, JsValue> {
serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string()))
}
}
/// Two-tier KV cache for WASM.
///
/// Provides memory-efficient caching with a high-precision tail
/// and quantized store for older tokens.
#[wasm_bindgen]
pub struct KvCacheWasm {
// FP16 tail cache (recent tokens)
tail_keys: RefCell<VecDeque<Vec<f32>>>,
tail_values: RefCell<VecDeque<Vec<f32>>>,
// Quantized store (older tokens)
store_keys: RefCell<VecDeque<Vec<u8>>>,
store_values: RefCell<VecDeque<Vec<u8>>>,
// Configuration
config: KvCacheConfigWasm,
}
#[wasm_bindgen]
impl KvCacheWasm {
/// Create a new KV cache with the given configuration.
#[wasm_bindgen(constructor)]
pub fn new(config: &KvCacheConfigWasm) -> KvCacheWasm {
KvCacheWasm {
tail_keys: RefCell::new(VecDeque::new()),
tail_values: RefCell::new(VecDeque::new()),
store_keys: RefCell::new(VecDeque::new()),
store_values: RefCell::new(VecDeque::new()),
config: config.clone(),
}
}
/// Create with default configuration.
#[wasm_bindgen(js_name = withDefaults)]
pub fn with_defaults() -> KvCacheWasm {
KvCacheWasm::new(&KvCacheConfigWasm::default())
}
/// Append KV pairs to the cache.
#[wasm_bindgen]
pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<(), JsValue> {
let stride = self.config.num_kv_heads * self.config.head_dim;
if keys.len() != stride || values.len() != stride {
return Err(JsValue::from_str(&format!(
"Key/value length must be {} (num_kv_heads * head_dim)",
stride
)));
}
let mut tail_keys = self.tail_keys.borrow_mut();
let mut tail_values = self.tail_values.borrow_mut();
// Add to tail
tail_keys.push_back(keys.to_vec());
tail_values.push_back(values.to_vec());
// Migrate from tail to store if needed
while tail_keys.len() > self.config.tail_length {
if let (Some(k), Some(v)) = (tail_keys.pop_front(), tail_values.pop_front()) {
// Simple quantization: convert f32 to u8
let quantized_k: Vec<u8> = k.iter().map(|&x| ((x + 1.0) * 127.5) as u8).collect();
let quantized_v: Vec<u8> = v.iter().map(|&x| ((x + 1.0) * 127.5) as u8).collect();
self.store_keys.borrow_mut().push_back(quantized_k);
self.store_values.borrow_mut().push_back(quantized_v);
}
}
// Evict from store if exceeds max tokens
let total = tail_keys.len() + self.store_keys.borrow().len();
if total > self.config.max_tokens {
let excess = total - self.config.max_tokens;
for _ in 0..excess {
self.store_keys.borrow_mut().pop_front();
self.store_values.borrow_mut().pop_front();
}
}
Ok(())
}
/// Get all cached KV pairs.
#[wasm_bindgen(js_name = getAllKv)]
pub fn get_all_kv(&self) -> Result<JsValue, JsValue> {
let stride = self.config.num_kv_heads * self.config.head_dim;
// Dequantize store
let store_keys = self.store_keys.borrow();
let store_values = self.store_values.borrow();
let tail_keys = self.tail_keys.borrow();
let tail_values = self.tail_values.borrow();
let total_tokens = store_keys.len() + tail_keys.len();
let mut all_keys = Vec::with_capacity(total_tokens * stride);
let mut all_values = Vec::with_capacity(total_tokens * stride);
// Dequantize store
for k in store_keys.iter() {
for &b in k {
all_keys.push((b as f32 / 127.5) - 1.0);
}
}
for v in store_values.iter() {
for &b in v {
all_values.push((b as f32 / 127.5) - 1.0);
}
}
// Add tail (already f32)
for k in tail_keys.iter() {
all_keys.extend(k);
}
for v in tail_values.iter() {
all_values.extend(v);
}
let obj = js_sys::Object::new();
let keys_array = js_sys::Float32Array::from(all_keys.as_slice());
let values_array = js_sys::Float32Array::from(all_values.as_slice());
js_sys::Reflect::set(&obj, &"keys".into(), &keys_array)?;
js_sys::Reflect::set(&obj, &"values".into(), &values_array)?;
Ok(obj.into())
}
/// Get cache statistics.
#[wasm_bindgen]
pub fn stats(&self) -> KvCacheStatsWasm {
let stride = self.config.num_kv_heads * self.config.head_dim;
let tail_tokens = self.tail_keys.borrow().len();
let store_tokens = self.store_keys.borrow().len();
let tail_bytes = tail_tokens * stride * 4; // f32
let store_bytes = store_tokens * stride * 1; // u8
let full_precision_bytes = (tail_tokens + store_tokens) * stride * 4;
let actual_bytes = tail_bytes + store_bytes;
let compression_ratio = if actual_bytes > 0 {
full_precision_bytes as f32 / actual_bytes as f32
} else {
1.0
};
KvCacheStatsWasm {
total_tokens: tail_tokens + store_tokens,
tail_tokens,
store_tokens,
tail_bytes,
store_bytes,
compression_ratio,
}
}
/// Clear the cache.
#[wasm_bindgen]
pub fn clear(&self) {
self.tail_keys.borrow_mut().clear();
self.tail_values.borrow_mut().clear();
self.store_keys.borrow_mut().clear();
self.store_values.borrow_mut().clear();
}
/// Get the total number of cached tokens.
#[wasm_bindgen(getter, js_name = tokenCount)]
pub fn token_count(&self) -> usize {
self.tail_keys.borrow().len() + self.store_keys.borrow().len()
}
}
// ============================================================================
// Memory Arena
// ============================================================================
const DEFAULT_ALIGNMENT: usize = 64;
/// Arena allocator for inference buffers.
///
/// Provides fast bump allocation with O(1) reset for
/// generation-step temporaries.
#[wasm_bindgen]
pub struct InferenceArenaWasm {
data: RefCell<Vec<u8>>,
offset: AtomicUsize,
high_water_mark: AtomicUsize,
allocation_count: AtomicUsize,
}
#[wasm_bindgen]
impl InferenceArenaWasm {
/// Create a new arena with the specified capacity in bytes.
#[wasm_bindgen(constructor)]
pub fn new(capacity: usize) -> InferenceArenaWasm {
let aligned_capacity = (capacity + DEFAULT_ALIGNMENT - 1) & !(DEFAULT_ALIGNMENT - 1);
InferenceArenaWasm {
data: RefCell::new(vec![0u8; aligned_capacity]),
offset: AtomicUsize::new(0),
high_water_mark: AtomicUsize::new(0),
allocation_count: AtomicUsize::new(0),
}
}
/// Create an arena sized for model dimensions.
#[wasm_bindgen(js_name = forModel)]
pub fn for_model(
hidden_dim: usize,
vocab_size: usize,
batch_size: usize,
) -> InferenceArenaWasm {
let activations = hidden_dim * batch_size * 4;
let logits = vocab_size * batch_size * 4;
let scratch = hidden_dim * 4 * 4;
let total = (activations + logits + scratch) * 2;
InferenceArenaWasm::new(total)
}
/// Reset the arena, making all memory available for reuse.
#[wasm_bindgen]
pub fn reset(&self) {
self.offset.store(0, Ordering::Release);
self.allocation_count.store(0, Ordering::Relaxed);
}
/// Get current bytes used.
#[wasm_bindgen(getter)]
pub fn used(&self) -> usize {
self.offset.load(Ordering::Acquire)
}
/// Get total capacity.
#[wasm_bindgen(getter)]
pub fn capacity(&self) -> usize {
self.data.borrow().len()
}
/// Get remaining available bytes.
#[wasm_bindgen(getter)]
pub fn remaining(&self) -> usize {
self.capacity() - self.used()
}
/// Get high water mark (maximum bytes ever used).
#[wasm_bindgen(getter, js_name = highWaterMark)]
pub fn high_water_mark(&self) -> usize {
self.high_water_mark.load(Ordering::Relaxed)
}
/// Get statistics as JSON.
#[wasm_bindgen(js_name = statsJson)]
pub fn stats_json(&self) -> Result<String, JsValue> {
let capacity = self.capacity();
let used = self.used();
let stats = serde_json::json!({
"capacity": capacity,
"used": used,
"remaining": capacity - used,
"high_water_mark": self.high_water_mark(),
"allocation_count": self.allocation_count.load(Ordering::Relaxed),
"utilization": if capacity > 0 { used as f64 / capacity as f64 } else { 0.0 }
});
serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string()))
}
}
// ============================================================================
// Buffer Pool
// ============================================================================
/// Buffer pool for efficient memory reuse.
#[wasm_bindgen]
pub struct BufferPoolWasm {
free_lists: RefCell<[Vec<Vec<u8>>; 5]>,
max_per_class: usize,
hits: AtomicUsize,
misses: AtomicUsize,
}
const BUFFER_SIZES: [usize; 5] = [1024, 4096, 16384, 65536, 262144];
#[wasm_bindgen]
impl BufferPoolWasm {
/// Create a new buffer pool with default settings.
#[wasm_bindgen(constructor)]
pub fn new() -> BufferPoolWasm {
BufferPoolWasm::with_capacity(32)
}
/// Create with specified max buffers per size class.
#[wasm_bindgen(js_name = withCapacity)]
pub fn with_capacity(max_buffers_per_class: usize) -> BufferPoolWasm {
BufferPoolWasm {
free_lists: RefCell::new([
Vec::with_capacity(max_buffers_per_class),
Vec::with_capacity(max_buffers_per_class),
Vec::with_capacity(max_buffers_per_class),
Vec::with_capacity(max_buffers_per_class),
Vec::with_capacity(max_buffers_per_class),
]),
max_per_class: max_buffers_per_class,
hits: AtomicUsize::new(0),
misses: AtomicUsize::new(0),
}
}
/// Pre-warm the pool by allocating buffers.
#[wasm_bindgen(js_name = prewarmAll)]
pub fn prewarm_all(&self, count_per_class: usize) {
let mut lists = self.free_lists.borrow_mut();
for (i, size) in BUFFER_SIZES.iter().enumerate() {
for _ in 0..count_per_class.min(self.max_per_class) {
if lists[i].len() < self.max_per_class {
lists[i].push(vec![0u8; *size]);
}
}
}
}
/// Get pool statistics as JSON.
#[wasm_bindgen(js_name = statsJson)]
pub fn stats_json(&self) -> Result<String, JsValue> {
let lists = self.free_lists.borrow();
let free_buffers: Vec<usize> = lists.iter().map(|l| l.len()).collect();
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total = hits + misses;
let stats = serde_json::json!({
"hits": hits,
"misses": misses,
"allocations": misses,
"returns": hits,
"drops": 0,
"free_buffers": free_buffers,
"hit_rate": if total > 0 { hits as f64 / total as f64 } else { 0.0 }
});
serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string()))
}
/// Get the hit rate (0.0 - 1.0).
#[wasm_bindgen(getter, js_name = hitRate)]
pub fn hit_rate(&self) -> f64 {
let hits = self.hits.load(Ordering::Relaxed);
let total = hits + self.misses.load(Ordering::Relaxed);
if total > 0 {
hits as f64 / total as f64
} else {
0.0
}
}
/// Clear all pooled buffers.
#[wasm_bindgen]
pub fn clear(&self) {
let mut lists = self.free_lists.borrow_mut();
for list in lists.iter_mut() {
list.clear();
}
}
}
impl Default for BufferPoolWasm {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Main RuvLLM WASM Interface
// ============================================================================
/// Main RuvLLM WASM interface.
///
/// Provides the primary entry point for LLM inference in the browser.
/// Manages KV cache, memory pools, and inference state.
#[wasm_bindgen]
pub struct RuvLLMWasm {
kv_cache: Option<KvCacheWasm>,
buffer_pool: BufferPoolWasm,
initialized: bool,
}
#[wasm_bindgen]
impl RuvLLMWasm {
/// Create a new RuvLLM WASM instance.
#[wasm_bindgen(constructor)]
pub fn new() -> RuvLLMWasm {
crate::utils::set_panic_hook();
RuvLLMWasm {
kv_cache: None,
buffer_pool: BufferPoolWasm::new(),
initialized: false,
}
}
/// Initialize the engine with default configuration.
#[wasm_bindgen]
pub fn initialize(&mut self) -> Result<(), JsValue> {
self.initialize_with_config(&KvCacheConfigWasm::default())
}
/// Initialize with custom KV cache configuration.
#[wasm_bindgen(js_name = initializeWithConfig)]
pub fn initialize_with_config(&mut self, config: &KvCacheConfigWasm) -> Result<(), JsValue> {
log("Initializing RuvLLM WASM...");
self.kv_cache = Some(KvCacheWasm::new(config));
self.buffer_pool.prewarm_all(4);
self.initialized = true;
log("RuvLLM WASM initialized successfully");
Ok(())
}
/// Check if the engine is initialized.
#[wasm_bindgen(getter, js_name = isInitialized)]
pub fn is_initialized(&self) -> bool {
self.initialized
}
/// Get buffer pool statistics.
#[wasm_bindgen(js_name = getPoolStats)]
pub fn get_pool_stats(&self) -> Result<String, JsValue> {
self.buffer_pool.stats_json()
}
/// Clear all caches and reset state.
#[wasm_bindgen]
pub fn reset(&mut self) {
if let Some(cache) = &self.kv_cache {
cache.clear();
}
self.buffer_pool.clear();
log("RuvLLM WASM state reset");
}
/// Get version information.
#[wasm_bindgen(js_name = version)]
pub fn version() -> String {
"2.0.0".to_string()
}
/// Format a chat conversation using a template.
#[wasm_bindgen(js_name = formatChat)]
pub fn format_chat(template: &ChatTemplateWasm, messages: Vec<ChatMessageWasm>) -> String {
let inner_messages: Vec<ChatMessage> = messages.into_iter().map(|m| m.inner).collect();
template.inner.format(&inner_messages)
}
}
impl Default for RuvLLMWasm {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// Utility Exports
// ============================================================================
/// Get the WASM module version.
#[wasm_bindgen(js_name = getVersion)]
pub fn get_version() -> String {
"2.0.0".to_string()
}
/// Check if the WASM module is ready.
#[wasm_bindgen(js_name = isReady)]
pub fn is_ready() -> bool {
true
}
/// Detect chat template from model ID.
#[wasm_bindgen(js_name = detectChatTemplate)]
pub fn detect_chat_template(model_id: &str) -> ChatTemplateWasm {
ChatTemplateWasm::detect_from_model_id(model_id)
}