//! JavaScript/WASM Bindings for RuvLLM //! //! This module provides JavaScript-friendly wrappers around the RuvLLM //! inference runtime. All types are designed to work seamlessly with //! JavaScript through wasm-bindgen. //! //! # Example (JavaScript) //! //! ```javascript //! import init, { RuvLLMWasm, GenerateConfig, KvCacheWasm } from 'ruvllm-wasm'; //! //! await init(); //! //! // Create inference engine //! const llm = new RuvLLMWasm(); //! //! // Configure generation //! const config = new GenerateConfig(); //! config.maxTokens = 256; //! config.temperature = 0.7; //! //! // Format a chat conversation //! const template = ChatTemplateWasm.llama3(); //! const messages = [ //! ChatMessageWasm.system("You are helpful."), //! ChatMessageWasm.user("Hello!"), //! ]; //! const prompt = template.format(messages); //! ``` use crate::utils::log; use serde::{Deserialize, Serialize}; use std::cell::RefCell; use std::collections::VecDeque; use std::sync::atomic::{AtomicUsize, Ordering}; use wasm_bindgen::prelude::*; // ============================================================================ // Types (re-implemented for WASM self-containment) // ============================================================================ /// Model size variants #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum ModelSize { Tiny, Small, Medium, Large, } impl Default for ModelSize { fn default() -> Self { Self::Small } } /// Precision levels for quantization #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum Precision { FP32, FP16, Q8, Q4K, Q4, } impl Default for Precision { fn default() -> Self { Self::FP16 } } impl Precision { pub fn bytes_per_element(&self) -> f32 { match self { Self::FP32 => 4.0, Self::FP16 => 2.0, Self::Q8 => 1.0, Self::Q4K | Self::Q4 => 0.5, } } } // ============================================================================ // Configuration Types // ============================================================================ /// Generation configuration for text generation. /// /// Controls sampling parameters and output constraints. /// TypeScript-friendly with getter/setter methods. #[wasm_bindgen] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GenerateConfig { /// Maximum tokens to generate #[wasm_bindgen(skip)] pub max_tokens: usize, /// Temperature for sampling (0.0 = deterministic) #[wasm_bindgen(skip)] pub temperature: f32, /// Top-p (nucleus) sampling threshold #[wasm_bindgen(skip)] pub top_p: f32, /// Top-k sampling (0 = disabled) #[wasm_bindgen(skip)] pub top_k: usize, /// Repetition penalty (1.0 = no penalty) #[wasm_bindgen(skip)] pub repetition_penalty: f32, /// Stop sequences (JSON array of strings) #[wasm_bindgen(skip)] pub stop_sequences: Vec, } #[wasm_bindgen] impl GenerateConfig { /// Create a new GenerateConfig with default values. #[wasm_bindgen(constructor)] pub fn new() -> GenerateConfig { GenerateConfig { max_tokens: 256, temperature: 0.7, top_p: 0.9, top_k: 40, repetition_penalty: 1.1, stop_sequences: Vec::new(), } } /// Get maximum tokens. #[wasm_bindgen(getter, js_name = maxTokens)] pub fn max_tokens(&self) -> usize { self.max_tokens } /// Set maximum tokens. #[wasm_bindgen(setter, js_name = maxTokens)] pub fn set_max_tokens(&mut self, value: usize) { self.max_tokens = value; } /// Get temperature. #[wasm_bindgen(getter)] pub fn temperature(&self) -> f32 { self.temperature } /// Set temperature. #[wasm_bindgen(setter)] pub fn set_temperature(&mut self, value: f32) { self.temperature = value; } /// Get top-p value. #[wasm_bindgen(getter, js_name = topP)] pub fn top_p(&self) -> f32 { self.top_p } /// Set top-p value. #[wasm_bindgen(setter, js_name = topP)] pub fn set_top_p(&mut self, value: f32) { self.top_p = value; } /// Get top-k value. #[wasm_bindgen(getter, js_name = topK)] pub fn top_k(&self) -> usize { self.top_k } /// Set top-k value. #[wasm_bindgen(setter, js_name = topK)] pub fn set_top_k(&mut self, value: usize) { self.top_k = value; } /// Get repetition penalty. #[wasm_bindgen(getter, js_name = repetitionPenalty)] pub fn repetition_penalty(&self) -> f32 { self.repetition_penalty } /// Set repetition penalty. #[wasm_bindgen(setter, js_name = repetitionPenalty)] pub fn set_repetition_penalty(&mut self, value: f32) { self.repetition_penalty = value; } /// Add a stop sequence. #[wasm_bindgen(js_name = addStopSequence)] pub fn add_stop_sequence(&mut self, sequence: &str) { self.stop_sequences.push(sequence.to_string()); } /// Clear all stop sequences. #[wasm_bindgen(js_name = clearStopSequences)] pub fn clear_stop_sequences(&mut self) { self.stop_sequences.clear(); } /// Convert to JSON string. #[wasm_bindgen(js_name = toJson)] pub fn to_json(&self) -> Result { serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string())) } /// Create from JSON string. #[wasm_bindgen(js_name = fromJson)] pub fn from_json(json: &str) -> Result { serde_json::from_str(json).map_err(|e| JsValue::from_str(&e.to_string())) } } impl Default for GenerateConfig { fn default() -> Self { Self::new() } } // ============================================================================ // Chat Message Types // ============================================================================ /// Message role in a conversation #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Role { System, User, Assistant, } impl Role { pub fn as_str(&self) -> &'static str { match self { Role::System => "system", Role::User => "user", Role::Assistant => "assistant", } } } /// Internal chat message #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: Role, pub content: String, } impl ChatMessage { pub fn system(content: &str) -> Self { Self { role: Role::System, content: content.to_string(), } } pub fn user(content: &str) -> Self { Self { role: Role::User, content: content.to_string(), } } pub fn assistant(content: &str) -> Self { Self { role: Role::Assistant, content: content.to_string(), } } } /// Chat message for instruction-tuned models. /// /// Used to construct conversations for chat-based inference. #[wasm_bindgen] #[derive(Debug, Clone)] pub struct ChatMessageWasm { inner: ChatMessage, } #[wasm_bindgen] impl ChatMessageWasm { /// Create a system message. #[wasm_bindgen(js_name = system)] pub fn system(content: &str) -> ChatMessageWasm { ChatMessageWasm { inner: ChatMessage::system(content), } } /// Create a user message. #[wasm_bindgen(js_name = user)] pub fn user(content: &str) -> ChatMessageWasm { ChatMessageWasm { inner: ChatMessage::user(content), } } /// Create an assistant message. #[wasm_bindgen(js_name = assistant)] pub fn assistant(content: &str) -> ChatMessageWasm { ChatMessageWasm { inner: ChatMessage::assistant(content), } } /// Get the role as a string. #[wasm_bindgen(getter)] pub fn role(&self) -> String { self.inner.role.as_str().to_string() } /// Get the message content. #[wasm_bindgen(getter)] pub fn content(&self) -> String { self.inner.content.clone() } } // ============================================================================ // Chat Templates // ============================================================================ /// Chat template variants #[derive(Debug, Clone)] pub enum ChatTemplate { Llama3, Llama2, Mistral, Qwen, ChatML, Phi, Gemma, Custom(String), } impl ChatTemplate { /// Detect template from model ID pub fn detect_from_model_id(model_id: &str) -> Self { let model_lower = model_id.to_lowercase(); if model_lower.contains("llama-3") || model_lower.contains("llama3") { Self::Llama3 } else if model_lower.contains("llama-2") || model_lower.contains("llama2") { Self::Llama2 } else if model_lower.contains("mistral") || model_lower.contains("mixtral") { Self::Mistral } else if model_lower.contains("qwen") { Self::Qwen } else if model_lower.contains("phi") { Self::Phi } else if model_lower.contains("gemma") { Self::Gemma } else { Self::ChatML } } /// Format messages using this template pub fn format(&self, messages: &[ChatMessage]) -> String { match self { Self::Llama3 => self.format_llama3(messages), Self::Llama2 => self.format_llama2(messages), Self::Mistral => self.format_mistral(messages), Self::Qwen => self.format_qwen(messages), Self::ChatML => self.format_chatml(messages), Self::Phi => self.format_phi(messages), Self::Gemma => self.format_gemma(messages), Self::Custom(template) => self.format_custom(messages, template), } } fn format_llama3(&self, messages: &[ChatMessage]) -> String { let mut output = String::from("<|begin_of_text|>"); for msg in messages { let role = msg.role.as_str(); output.push_str(&format!( "<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", role, msg.content )); } output.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n"); output } fn format_llama2(&self, messages: &[ChatMessage]) -> String { let mut output = String::new(); let mut system_msg = String::new(); for msg in messages { match msg.role { Role::System => { system_msg = msg.content.clone(); } Role::User => { if !system_msg.is_empty() { output.push_str(&format!( "[INST] <>\n{}\n<>\n\n{} [/INST]", system_msg, msg.content )); system_msg.clear(); } else { output.push_str(&format!("[INST] {} [/INST]", msg.content)); } } Role::Assistant => { output.push_str(&format!(" {} ", msg.content)); } } } output } fn format_mistral(&self, messages: &[ChatMessage]) -> String { let mut output = String::new(); for msg in messages { match msg.role { Role::System | Role::User => { output.push_str(&format!("[INST] {} [/INST]", msg.content)); } Role::Assistant => { output.push_str(&format!("{}", msg.content)); } } } output } fn format_qwen(&self, messages: &[ChatMessage]) -> String { self.format_chatml(messages) } fn format_chatml(&self, messages: &[ChatMessage]) -> String { let mut output = String::new(); for msg in messages { output.push_str(&format!( "<|im_start|>{}\n{}<|im_end|>\n", msg.role.as_str(), msg.content )); } output.push_str("<|im_start|>assistant\n"); output } fn format_phi(&self, messages: &[ChatMessage]) -> String { let mut output = String::new(); for msg in messages { match msg.role { Role::System => { output.push_str(&format!("<|system|>\n{}<|end|>\n", msg.content)); } Role::User => { output.push_str(&format!("<|user|>\n{}<|end|>\n", msg.content)); } Role::Assistant => { output.push_str(&format!("<|assistant|>\n{}<|end|>\n", msg.content)); } } } output.push_str("<|assistant|>\n"); output } fn format_gemma(&self, messages: &[ChatMessage]) -> String { let mut output = String::new(); for msg in messages { match msg.role { Role::User => { output.push_str(&format!( "user\n{}\n", msg.content )); } Role::Assistant => { output.push_str(&format!( "model\n{}\n", msg.content )); } Role::System => { // Gemma doesn't have native system support, prepend to first user output.push_str(&format!("user\n{}\n", msg.content)); } } } output.push_str("model\n"); output } fn format_custom(&self, _messages: &[ChatMessage], _template: &str) -> String { // Simplified custom template support String::new() } } /// Chat template for formatting conversations. #[wasm_bindgen] #[derive(Debug, Clone)] pub struct ChatTemplateWasm { inner: ChatTemplate, } #[wasm_bindgen] impl ChatTemplateWasm { /// Create a Llama 3 chat template. #[wasm_bindgen(js_name = llama3)] pub fn llama3() -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::Llama3, } } /// Create a Mistral chat template. #[wasm_bindgen(js_name = mistral)] pub fn mistral() -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::Mistral, } } /// Create a Qwen/ChatML chat template. #[wasm_bindgen(js_name = chatml)] pub fn chatml() -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::ChatML, } } /// Create a Phi chat template. #[wasm_bindgen(js_name = phi)] pub fn phi() -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::Phi, } } /// Create a Gemma chat template. #[wasm_bindgen(js_name = gemma)] pub fn gemma() -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::Gemma, } } /// Create a custom chat template. #[wasm_bindgen(js_name = custom)] pub fn custom(template: &str) -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::Custom(template.to_string()), } } /// Detect template from model ID. #[wasm_bindgen(js_name = detectFromModelId)] pub fn detect_from_model_id(model_id: &str) -> ChatTemplateWasm { ChatTemplateWasm { inner: ChatTemplate::detect_from_model_id(model_id), } } /// Format messages using this template. #[wasm_bindgen(js_name = format)] pub fn format(&self, messages: Vec) -> String { let inner_messages: Vec = messages.into_iter().map(|m| m.inner).collect(); self.inner.format(&inner_messages) } /// Get the template name. #[wasm_bindgen(getter)] pub fn name(&self) -> String { match &self.inner { ChatTemplate::Llama3 => "llama3".to_string(), ChatTemplate::Llama2 => "llama2".to_string(), ChatTemplate::Mistral => "mistral".to_string(), ChatTemplate::Qwen => "qwen".to_string(), ChatTemplate::ChatML => "chatml".to_string(), ChatTemplate::Phi => "phi".to_string(), ChatTemplate::Gemma => "gemma".to_string(), ChatTemplate::Custom(_) => "custom".to_string(), } } } // ============================================================================ // KV Cache // ============================================================================ /// KV cache configuration for WASM. #[wasm_bindgen] #[derive(Debug, Clone)] pub struct KvCacheConfigWasm { tail_length: usize, max_tokens: usize, num_kv_heads: usize, head_dim: usize, } #[wasm_bindgen] impl KvCacheConfigWasm { /// Create a new KV cache configuration. #[wasm_bindgen(constructor)] pub fn new() -> KvCacheConfigWasm { KvCacheConfigWasm { tail_length: 256, max_tokens: 4096, num_kv_heads: 8, head_dim: 128, } } /// Get tail length. #[wasm_bindgen(getter, js_name = tailLength)] pub fn tail_length(&self) -> usize { self.tail_length } /// Set tail length. #[wasm_bindgen(setter, js_name = tailLength)] pub fn set_tail_length(&mut self, value: usize) { self.tail_length = value; } /// Get max tokens. #[wasm_bindgen(getter, js_name = maxTokens)] pub fn max_tokens(&self) -> usize { self.max_tokens } /// Set max tokens. #[wasm_bindgen(setter, js_name = maxTokens)] pub fn set_max_tokens(&mut self, value: usize) { self.max_tokens = value; } /// Get number of KV heads. #[wasm_bindgen(getter, js_name = numKvHeads)] pub fn num_kv_heads(&self) -> usize { self.num_kv_heads } /// Set number of KV heads. #[wasm_bindgen(setter, js_name = numKvHeads)] pub fn set_num_kv_heads(&mut self, value: usize) { self.num_kv_heads = value; } /// Get head dimension. #[wasm_bindgen(getter, js_name = headDim)] pub fn head_dim(&self) -> usize { self.head_dim } /// Set head dimension. #[wasm_bindgen(setter, js_name = headDim)] pub fn set_head_dim(&mut self, value: usize) { self.head_dim = value; } } impl Default for KvCacheConfigWasm { fn default() -> Self { Self::new() } } /// KV cache statistics. #[wasm_bindgen] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KvCacheStatsWasm { total_tokens: usize, tail_tokens: usize, store_tokens: usize, tail_bytes: usize, store_bytes: usize, compression_ratio: f32, } #[wasm_bindgen] impl KvCacheStatsWasm { /// Get total tokens. #[wasm_bindgen(getter, js_name = totalTokens)] pub fn total_tokens(&self) -> usize { self.total_tokens } /// Get tail tokens. #[wasm_bindgen(getter, js_name = tailTokens)] pub fn tail_tokens(&self) -> usize { self.tail_tokens } /// Get store tokens. #[wasm_bindgen(getter, js_name = storeTokens)] pub fn store_tokens(&self) -> usize { self.store_tokens } /// Get compression ratio. #[wasm_bindgen(getter, js_name = compressionRatio)] pub fn compression_ratio(&self) -> f32 { self.compression_ratio } /// Convert to JSON. #[wasm_bindgen(js_name = toJson)] pub fn to_json(&self) -> Result { serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string())) } } /// Two-tier KV cache for WASM. /// /// Provides memory-efficient caching with a high-precision tail /// and quantized store for older tokens. #[wasm_bindgen] pub struct KvCacheWasm { // FP16 tail cache (recent tokens) tail_keys: RefCell>>, tail_values: RefCell>>, // Quantized store (older tokens) store_keys: RefCell>>, store_values: RefCell>>, // Configuration config: KvCacheConfigWasm, } #[wasm_bindgen] impl KvCacheWasm { /// Create a new KV cache with the given configuration. #[wasm_bindgen(constructor)] pub fn new(config: &KvCacheConfigWasm) -> KvCacheWasm { KvCacheWasm { tail_keys: RefCell::new(VecDeque::new()), tail_values: RefCell::new(VecDeque::new()), store_keys: RefCell::new(VecDeque::new()), store_values: RefCell::new(VecDeque::new()), config: config.clone(), } } /// Create with default configuration. #[wasm_bindgen(js_name = withDefaults)] pub fn with_defaults() -> KvCacheWasm { KvCacheWasm::new(&KvCacheConfigWasm::default()) } /// Append KV pairs to the cache. #[wasm_bindgen] pub fn append(&self, keys: &[f32], values: &[f32]) -> Result<(), JsValue> { let stride = self.config.num_kv_heads * self.config.head_dim; if keys.len() != stride || values.len() != stride { return Err(JsValue::from_str(&format!( "Key/value length must be {} (num_kv_heads * head_dim)", stride ))); } let mut tail_keys = self.tail_keys.borrow_mut(); let mut tail_values = self.tail_values.borrow_mut(); // Add to tail tail_keys.push_back(keys.to_vec()); tail_values.push_back(values.to_vec()); // Migrate from tail to store if needed while tail_keys.len() > self.config.tail_length { if let (Some(k), Some(v)) = (tail_keys.pop_front(), tail_values.pop_front()) { // Simple quantization: convert f32 to u8 let quantized_k: Vec = k.iter().map(|&x| ((x + 1.0) * 127.5) as u8).collect(); let quantized_v: Vec = v.iter().map(|&x| ((x + 1.0) * 127.5) as u8).collect(); self.store_keys.borrow_mut().push_back(quantized_k); self.store_values.borrow_mut().push_back(quantized_v); } } // Evict from store if exceeds max tokens let total = tail_keys.len() + self.store_keys.borrow().len(); if total > self.config.max_tokens { let excess = total - self.config.max_tokens; for _ in 0..excess { self.store_keys.borrow_mut().pop_front(); self.store_values.borrow_mut().pop_front(); } } Ok(()) } /// Get all cached KV pairs. #[wasm_bindgen(js_name = getAllKv)] pub fn get_all_kv(&self) -> Result { let stride = self.config.num_kv_heads * self.config.head_dim; // Dequantize store let store_keys = self.store_keys.borrow(); let store_values = self.store_values.borrow(); let tail_keys = self.tail_keys.borrow(); let tail_values = self.tail_values.borrow(); let total_tokens = store_keys.len() + tail_keys.len(); let mut all_keys = Vec::with_capacity(total_tokens * stride); let mut all_values = Vec::with_capacity(total_tokens * stride); // Dequantize store for k in store_keys.iter() { for &b in k { all_keys.push((b as f32 / 127.5) - 1.0); } } for v in store_values.iter() { for &b in v { all_values.push((b as f32 / 127.5) - 1.0); } } // Add tail (already f32) for k in tail_keys.iter() { all_keys.extend(k); } for v in tail_values.iter() { all_values.extend(v); } let obj = js_sys::Object::new(); let keys_array = js_sys::Float32Array::from(all_keys.as_slice()); let values_array = js_sys::Float32Array::from(all_values.as_slice()); js_sys::Reflect::set(&obj, &"keys".into(), &keys_array)?; js_sys::Reflect::set(&obj, &"values".into(), &values_array)?; Ok(obj.into()) } /// Get cache statistics. #[wasm_bindgen] pub fn stats(&self) -> KvCacheStatsWasm { let stride = self.config.num_kv_heads * self.config.head_dim; let tail_tokens = self.tail_keys.borrow().len(); let store_tokens = self.store_keys.borrow().len(); let tail_bytes = tail_tokens * stride * 4; // f32 let store_bytes = store_tokens * stride * 1; // u8 let full_precision_bytes = (tail_tokens + store_tokens) * stride * 4; let actual_bytes = tail_bytes + store_bytes; let compression_ratio = if actual_bytes > 0 { full_precision_bytes as f32 / actual_bytes as f32 } else { 1.0 }; KvCacheStatsWasm { total_tokens: tail_tokens + store_tokens, tail_tokens, store_tokens, tail_bytes, store_bytes, compression_ratio, } } /// Clear the cache. #[wasm_bindgen] pub fn clear(&self) { self.tail_keys.borrow_mut().clear(); self.tail_values.borrow_mut().clear(); self.store_keys.borrow_mut().clear(); self.store_values.borrow_mut().clear(); } /// Get the total number of cached tokens. #[wasm_bindgen(getter, js_name = tokenCount)] pub fn token_count(&self) -> usize { self.tail_keys.borrow().len() + self.store_keys.borrow().len() } } // ============================================================================ // Memory Arena // ============================================================================ const DEFAULT_ALIGNMENT: usize = 64; /// Arena allocator for inference buffers. /// /// Provides fast bump allocation with O(1) reset for /// generation-step temporaries. #[wasm_bindgen] pub struct InferenceArenaWasm { data: RefCell>, offset: AtomicUsize, high_water_mark: AtomicUsize, allocation_count: AtomicUsize, } #[wasm_bindgen] impl InferenceArenaWasm { /// Create a new arena with the specified capacity in bytes. #[wasm_bindgen(constructor)] pub fn new(capacity: usize) -> InferenceArenaWasm { let aligned_capacity = (capacity + DEFAULT_ALIGNMENT - 1) & !(DEFAULT_ALIGNMENT - 1); InferenceArenaWasm { data: RefCell::new(vec![0u8; aligned_capacity]), offset: AtomicUsize::new(0), high_water_mark: AtomicUsize::new(0), allocation_count: AtomicUsize::new(0), } } /// Create an arena sized for model dimensions. #[wasm_bindgen(js_name = forModel)] pub fn for_model( hidden_dim: usize, vocab_size: usize, batch_size: usize, ) -> InferenceArenaWasm { let activations = hidden_dim * batch_size * 4; let logits = vocab_size * batch_size * 4; let scratch = hidden_dim * 4 * 4; let total = (activations + logits + scratch) * 2; InferenceArenaWasm::new(total) } /// Reset the arena, making all memory available for reuse. #[wasm_bindgen] pub fn reset(&self) { self.offset.store(0, Ordering::Release); self.allocation_count.store(0, Ordering::Relaxed); } /// Get current bytes used. #[wasm_bindgen(getter)] pub fn used(&self) -> usize { self.offset.load(Ordering::Acquire) } /// Get total capacity. #[wasm_bindgen(getter)] pub fn capacity(&self) -> usize { self.data.borrow().len() } /// Get remaining available bytes. #[wasm_bindgen(getter)] pub fn remaining(&self) -> usize { self.capacity() - self.used() } /// Get high water mark (maximum bytes ever used). #[wasm_bindgen(getter, js_name = highWaterMark)] pub fn high_water_mark(&self) -> usize { self.high_water_mark.load(Ordering::Relaxed) } /// Get statistics as JSON. #[wasm_bindgen(js_name = statsJson)] pub fn stats_json(&self) -> Result { let capacity = self.capacity(); let used = self.used(); let stats = serde_json::json!({ "capacity": capacity, "used": used, "remaining": capacity - used, "high_water_mark": self.high_water_mark(), "allocation_count": self.allocation_count.load(Ordering::Relaxed), "utilization": if capacity > 0 { used as f64 / capacity as f64 } else { 0.0 } }); serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string())) } } // ============================================================================ // Buffer Pool // ============================================================================ /// Buffer pool for efficient memory reuse. #[wasm_bindgen] pub struct BufferPoolWasm { free_lists: RefCell<[Vec>; 5]>, max_per_class: usize, hits: AtomicUsize, misses: AtomicUsize, } const BUFFER_SIZES: [usize; 5] = [1024, 4096, 16384, 65536, 262144]; #[wasm_bindgen] impl BufferPoolWasm { /// Create a new buffer pool with default settings. #[wasm_bindgen(constructor)] pub fn new() -> BufferPoolWasm { BufferPoolWasm::with_capacity(32) } /// Create with specified max buffers per size class. #[wasm_bindgen(js_name = withCapacity)] pub fn with_capacity(max_buffers_per_class: usize) -> BufferPoolWasm { BufferPoolWasm { free_lists: RefCell::new([ Vec::with_capacity(max_buffers_per_class), Vec::with_capacity(max_buffers_per_class), Vec::with_capacity(max_buffers_per_class), Vec::with_capacity(max_buffers_per_class), Vec::with_capacity(max_buffers_per_class), ]), max_per_class: max_buffers_per_class, hits: AtomicUsize::new(0), misses: AtomicUsize::new(0), } } /// Pre-warm the pool by allocating buffers. #[wasm_bindgen(js_name = prewarmAll)] pub fn prewarm_all(&self, count_per_class: usize) { let mut lists = self.free_lists.borrow_mut(); for (i, size) in BUFFER_SIZES.iter().enumerate() { for _ in 0..count_per_class.min(self.max_per_class) { if lists[i].len() < self.max_per_class { lists[i].push(vec![0u8; *size]); } } } } /// Get pool statistics as JSON. #[wasm_bindgen(js_name = statsJson)] pub fn stats_json(&self) -> Result { let lists = self.free_lists.borrow(); let free_buffers: Vec = lists.iter().map(|l| l.len()).collect(); let hits = self.hits.load(Ordering::Relaxed); let misses = self.misses.load(Ordering::Relaxed); let total = hits + misses; let stats = serde_json::json!({ "hits": hits, "misses": misses, "allocations": misses, "returns": hits, "drops": 0, "free_buffers": free_buffers, "hit_rate": if total > 0 { hits as f64 / total as f64 } else { 0.0 } }); serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string())) } /// Get the hit rate (0.0 - 1.0). #[wasm_bindgen(getter, js_name = hitRate)] pub fn hit_rate(&self) -> f64 { let hits = self.hits.load(Ordering::Relaxed); let total = hits + self.misses.load(Ordering::Relaxed); if total > 0 { hits as f64 / total as f64 } else { 0.0 } } /// Clear all pooled buffers. #[wasm_bindgen] pub fn clear(&self) { let mut lists = self.free_lists.borrow_mut(); for list in lists.iter_mut() { list.clear(); } } } impl Default for BufferPoolWasm { fn default() -> Self { Self::new() } } // ============================================================================ // Main RuvLLM WASM Interface // ============================================================================ /// Main RuvLLM WASM interface. /// /// Provides the primary entry point for LLM inference in the browser. /// Manages KV cache, memory pools, and inference state. #[wasm_bindgen] pub struct RuvLLMWasm { kv_cache: Option, buffer_pool: BufferPoolWasm, initialized: bool, } #[wasm_bindgen] impl RuvLLMWasm { /// Create a new RuvLLM WASM instance. #[wasm_bindgen(constructor)] pub fn new() -> RuvLLMWasm { crate::utils::set_panic_hook(); RuvLLMWasm { kv_cache: None, buffer_pool: BufferPoolWasm::new(), initialized: false, } } /// Initialize the engine with default configuration. #[wasm_bindgen] pub fn initialize(&mut self) -> Result<(), JsValue> { self.initialize_with_config(&KvCacheConfigWasm::default()) } /// Initialize with custom KV cache configuration. #[wasm_bindgen(js_name = initializeWithConfig)] pub fn initialize_with_config(&mut self, config: &KvCacheConfigWasm) -> Result<(), JsValue> { log("Initializing RuvLLM WASM..."); self.kv_cache = Some(KvCacheWasm::new(config)); self.buffer_pool.prewarm_all(4); self.initialized = true; log("RuvLLM WASM initialized successfully"); Ok(()) } /// Check if the engine is initialized. #[wasm_bindgen(getter, js_name = isInitialized)] pub fn is_initialized(&self) -> bool { self.initialized } /// Get buffer pool statistics. #[wasm_bindgen(js_name = getPoolStats)] pub fn get_pool_stats(&self) -> Result { self.buffer_pool.stats_json() } /// Clear all caches and reset state. #[wasm_bindgen] pub fn reset(&mut self) { if let Some(cache) = &self.kv_cache { cache.clear(); } self.buffer_pool.clear(); log("RuvLLM WASM state reset"); } /// Get version information. #[wasm_bindgen(js_name = version)] pub fn version() -> String { "2.0.0".to_string() } /// Format a chat conversation using a template. #[wasm_bindgen(js_name = formatChat)] pub fn format_chat(template: &ChatTemplateWasm, messages: Vec) -> String { let inner_messages: Vec = messages.into_iter().map(|m| m.inner).collect(); template.inner.format(&inner_messages) } } impl Default for RuvLLMWasm { fn default() -> Self { Self::new() } } // ============================================================================ // Utility Exports // ============================================================================ /// Get the WASM module version. #[wasm_bindgen(js_name = getVersion)] pub fn get_version() -> String { "2.0.0".to_string() } /// Check if the WASM module is ready. #[wasm_bindgen(js_name = isReady)] pub fn is_ready() -> bool { true } /// Detect chat template from model ID. #[wasm_bindgen(js_name = detectChatTemplate)] pub fn detect_chat_template(model_id: &str) -> ChatTemplateWasm { ChatTemplateWasm::detect_from_model_id(model_id) }