Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
1201
vendor/ruvector/crates/ruvllm-wasm/src/bindings.rs
vendored
Normal file
1201
vendor/ruvector/crates/ruvllm-wasm/src/bindings.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
797
vendor/ruvector/crates/ruvllm-wasm/src/hnsw_router.rs
vendored
Normal file
797
vendor/ruvector/crates/ruvllm-wasm/src/hnsw_router.rs
vendored
Normal file
@@ -0,0 +1,797 @@
|
||||
//! HNSW Semantic Router for Browser-Compatible Pattern Routing
|
||||
//!
|
||||
//! Pure Rust implementation of HNSW (Hierarchical Navigable Small World) graph
|
||||
//! for semantic pattern routing in WASM environments. Uses cosine similarity
|
||||
//! for embedding comparison.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Browser-Compatible**: Pure Rust with no external WASM-incompatible deps
|
||||
//! - **Pattern Storage**: Store embeddings with metadata for routing decisions
|
||||
//! - **Semantic Search**: Find similar patterns using approximate nearest neighbor search
|
||||
//! - **Memory-Efficient**: Configurable max patterns to limit memory usage
|
||||
//! - **Serializable**: JSON serialization for IndexedDB persistence
|
||||
//!
|
||||
//! ## Example (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import { HnswRouterWasm, PatternWasm } from 'ruvllm-wasm';
|
||||
//!
|
||||
//! // Create router for 384-dimensional embeddings
|
||||
//! const router = HnswRouterWasm.new(384, 1000);
|
||||
//!
|
||||
//! // Add patterns with embeddings
|
||||
//! const embedding = new Float32Array([0.1, 0.2, ...]); // 384 dims
|
||||
//! router.addPattern(embedding, "rust-expert", JSON.stringify({
|
||||
//! domain: "rust",
|
||||
//! expertise: "high"
|
||||
//! }));
|
||||
//!
|
||||
//! // Route a query
|
||||
//! const queryEmbedding = new Float32Array([0.15, 0.18, ...]);
|
||||
//! const results = router.route(queryEmbedding, 5); // top 5 matches
|
||||
//!
|
||||
//! results.forEach(result => {
|
||||
//! console.log(`Match: ${result.name}, Score: ${result.score}`);
|
||||
//! });
|
||||
//!
|
||||
//! // Serialize to JSON for persistence
|
||||
//! const json = router.toJson();
|
||||
//! localStorage.setItem('router', json);
|
||||
//!
|
||||
//! // Restore from JSON
|
||||
//! const restored = HnswRouterWasm.fromJson(json);
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Maximum connections per node in the HNSW graph (M parameter)
const DEFAULT_M: usize = 16;

/// Maximum connections in layer 0 (M0 = M * 2)
///
/// NOTE(review): `HnswGraph::new` derives `m0` as `m * 2` itself, so this
/// constant is currently unreferenced — confirm whether it should be passed
/// through instead of recomputed.
const DEFAULT_M0: usize = 32;

/// Number of nearest neighbors to explore during construction (efConstruction)
const DEFAULT_EF_CONSTRUCTION: usize = 100;

/// Number of nearest neighbors to explore during search (efSearch)
const DEFAULT_EF_SEARCH: usize = 50;
|
||||
|
||||
/// A stored pattern with embedding and metadata
///
/// Represents a routing pattern that can be matched against queries.
/// Each pattern has a name, embedding vector, and a metadata string
/// (JSON by convention; not validated here).
///
/// All fields are `#[wasm_bindgen(skip)]`-ped — JavaScript access goes
/// through the getter/setter methods on the `impl` block.
#[wasm_bindgen]
#[derive(Clone, Serialize, Deserialize)]
pub struct PatternWasm {
    // Pattern identifier; echoed back in routing results.
    #[wasm_bindgen(skip)]
    pub name: String,
    // Dense embedding; `HnswRouterWasm::add_pattern` rejects it if its
    // length does not match the router's `dimensions`.
    #[wasm_bindgen(skip)]
    pub embedding: Vec<f32>,
    // Opaque metadata string (JSON by convention).
    #[wasm_bindgen(skip)]
    pub metadata: String,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl PatternWasm {
|
||||
/// Create a new pattern
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `embedding`: Float32Array of embedding values
|
||||
/// - `name`: Pattern name/identifier
|
||||
/// - `metadata`: JSON string with additional metadata
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(embedding: &[f32], name: &str, metadata: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
embedding: embedding.to_vec(),
|
||||
metadata: metadata.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get pattern name
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
/// Get pattern embedding as Float32Array
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn embedding(&self) -> Vec<f32> {
|
||||
self.embedding.clone()
|
||||
}
|
||||
|
||||
/// Get pattern metadata JSON string
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn metadata(&self) -> String {
|
||||
self.metadata.clone()
|
||||
}
|
||||
|
||||
/// Set pattern name
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_name(&mut self, name: String) {
|
||||
self.name = name;
|
||||
}
|
||||
|
||||
/// Set pattern metadata
|
||||
#[wasm_bindgen(setter)]
|
||||
pub fn set_metadata(&mut self, metadata: String) {
|
||||
self.metadata = metadata;
|
||||
}
|
||||
}
|
||||
|
||||
/// A routing search result with similarity score
///
/// Represents a matched pattern from a semantic search query, carrying a
/// copy of the matched pattern's name, embedding, and metadata plus the
/// similarity score computed against the query.
#[wasm_bindgen]
#[derive(Clone, Serialize, Deserialize)]
pub struct RouteResultWasm {
    // Name of the matched pattern.
    #[wasm_bindgen(skip)]
    pub name: String,
    // Cosine similarity of the match (higher is better).
    #[wasm_bindgen(skip)]
    pub score: f32,
    // Metadata string copied from the matched pattern.
    #[wasm_bindgen(skip)]
    pub metadata: String,
    // Embedding copied from the matched pattern.
    #[wasm_bindgen(skip)]
    pub embedding: Vec<f32>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl RouteResultWasm {
|
||||
/// Get result pattern name
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
/// Get similarity score (higher is better, 0.0-1.0 for cosine)
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn score(&self) -> f32 {
|
||||
self.score
|
||||
}
|
||||
|
||||
/// Get result metadata JSON string
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn metadata(&self) -> String {
|
||||
self.metadata.clone()
|
||||
}
|
||||
|
||||
/// Get result embedding as Float32Array
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn embedding(&self) -> Vec<f32> {
|
||||
self.embedding.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// HNSW node representing a pattern in the graph
///
/// One instance exists per (pattern, layer) pair the pattern occupies.
#[derive(Clone, Serialize, Deserialize)]
struct HnswNode {
    /// Node ID (index in patterns vector)
    id: usize,
    /// Graph layer (0 = base layer, higher = upper layers)
    layer: usize,
    /// Connections to other nodes at this layer (node IDs)
    neighbors: Vec<usize>,
}
|
||||
|
||||
/// Internal HNSW graph state
///
/// Holds all stored patterns plus the per-layer adjacency maps that make
/// up the hierarchical graph. Serializable so the whole index can round-trip
/// through JSON (see `HnswRouterWasm::to_json` / `from_json`).
#[derive(Clone, Serialize, Deserialize)]
struct HnswGraph {
    /// All stored patterns
    patterns: Vec<PatternWasm>,
    /// HNSW nodes per layer (layer -> node_id -> node)
    layers: Vec<HashMap<usize, HnswNode>>,
    /// Entry point node ID (None only when the graph is empty)
    entry_point: Option<usize>,
    /// Maximum layer currently in use
    max_layer: usize,
    /// Max connections per node on layers > 0 (HNSW "M")
    m: usize,
    /// Max connections per node on layer 0 (derived as m * 2)
    m0: usize,
    /// Beam width used while inserting (efConstruction)
    ef_construction: usize,
    /// Beam width used while querying (efSearch)
    ef_search: usize,
}
|
||||
|
||||
impl HnswGraph {
|
||||
fn new(m: usize, ef_construction: usize, ef_search: usize) -> Self {
|
||||
Self {
|
||||
patterns: Vec::new(),
|
||||
layers: vec![HashMap::new()],
|
||||
entry_point: None,
|
||||
max_layer: 0,
|
||||
m,
|
||||
m0: m * 2,
|
||||
ef_construction,
|
||||
ef_search,
|
||||
}
|
||||
}
|
||||
|
||||
/// Select layer for new node using exponential decay
|
||||
fn select_layer(&self) -> usize {
|
||||
let ml = 1.0 / (self.m as f64).ln();
|
||||
let level = (-js_sys::Math::random().ln() * ml).floor() as usize;
|
||||
level.min(self.max_layer + 1)
|
||||
}
|
||||
|
||||
/// Calculate cosine similarity between two embeddings
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a < 1e-8 || norm_b < 1e-8 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(dot / (norm_a * norm_b)).max(-1.0).min(1.0)
|
||||
}
|
||||
|
||||
/// Add a pattern to the HNSW graph
|
||||
fn add_pattern(&mut self, pattern: PatternWasm) {
|
||||
let node_id = self.patterns.len();
|
||||
let layer = self.select_layer();
|
||||
|
||||
// Ensure we have enough layers
|
||||
while self.layers.len() <= layer {
|
||||
self.layers.push(HashMap::new());
|
||||
}
|
||||
|
||||
// Update max layer and entry point if needed
|
||||
if layer > self.max_layer {
|
||||
self.max_layer = layer;
|
||||
self.entry_point = Some(node_id);
|
||||
}
|
||||
|
||||
// Insert node at all layers from 0 to selected layer
|
||||
for l in 0..=layer {
|
||||
let node = HnswNode {
|
||||
id: node_id,
|
||||
layer: l,
|
||||
neighbors: Vec::new(),
|
||||
};
|
||||
self.layers[l].insert(node_id, node);
|
||||
}
|
||||
|
||||
// Connect the new node to the graph
|
||||
if self.patterns.is_empty() {
|
||||
self.entry_point = Some(node_id);
|
||||
} else {
|
||||
self.connect_node(node_id, &pattern.embedding, layer);
|
||||
}
|
||||
|
||||
self.patterns.push(pattern);
|
||||
}
|
||||
|
||||
/// Connect a new node to existing nodes in the graph
|
||||
fn connect_node(&mut self, node_id: usize, embedding: &[f32], node_layer: usize) {
|
||||
let entry_point = self.entry_point.unwrap();
|
||||
|
||||
// Search for nearest neighbors from top to node layer
|
||||
let mut curr = entry_point;
|
||||
for l in (node_layer + 1..=self.max_layer).rev() {
|
||||
curr = self.search_layer(embedding, curr, 1, l)[0].0;
|
||||
}
|
||||
|
||||
// Insert connections from node_layer down to 0
|
||||
for l in (0..=node_layer).rev() {
|
||||
let m = if l == 0 { self.m0 } else { self.m };
|
||||
let candidates = self.search_layer(embedding, curr, self.ef_construction, l);
|
||||
|
||||
// Select M nearest neighbors
|
||||
let neighbors: Vec<usize> = candidates.iter().take(m).map(|(id, _)| *id).collect();
|
||||
|
||||
// Add bidirectional connections
|
||||
if let Some(node) = self.layers[l].get_mut(&node_id) {
|
||||
node.neighbors = neighbors.clone();
|
||||
}
|
||||
|
||||
// Collect neighbors that need pruning
|
||||
let mut to_prune = Vec::new();
|
||||
|
||||
for &neighbor_id in &neighbors {
|
||||
if let Some(neighbor) = self.layers[l].get_mut(&neighbor_id) {
|
||||
if !neighbor.neighbors.contains(&node_id) {
|
||||
neighbor.neighbors.push(node_id);
|
||||
|
||||
// Check if pruning needed
|
||||
if neighbor.neighbors.len() > m {
|
||||
to_prune.push(neighbor_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prune connections after iteration
|
||||
for neighbor_id in to_prune {
|
||||
let neighbor_emb = self.patterns[neighbor_id].embedding.clone();
|
||||
self.prune_connections(neighbor_id, &neighbor_emb, m, l);
|
||||
}
|
||||
|
||||
curr = candidates[0].0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prune connections to maintain M maximum
|
||||
fn prune_connections(&mut self, node_id: usize, embedding: &[f32], m: usize, layer: usize) {
|
||||
if let Some(node) = self.layers[layer].get(&node_id) {
|
||||
let mut scored_neighbors: Vec<(usize, f32)> = node
|
||||
.neighbors
|
||||
.iter()
|
||||
.map(|&id| {
|
||||
let sim = Self::cosine_similarity(embedding, &self.patterns[id].embedding);
|
||||
(id, sim)
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored_neighbors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
let pruned: Vec<usize> = scored_neighbors
|
||||
.into_iter()
|
||||
.take(m)
|
||||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
|
||||
if let Some(node) = self.layers[layer].get_mut(&node_id) {
|
||||
node.neighbors = pruned;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Search a single layer for nearest neighbors
|
||||
fn search_layer(
|
||||
&self,
|
||||
query: &[f32],
|
||||
entry_point: usize,
|
||||
ef: usize,
|
||||
layer: usize,
|
||||
) -> Vec<(usize, f32)> {
|
||||
let mut visited = vec![false; self.patterns.len()];
|
||||
let mut candidates = Vec::new();
|
||||
let mut best = Vec::new();
|
||||
|
||||
let entry_sim = Self::cosine_similarity(query, &self.patterns[entry_point].embedding);
|
||||
candidates.push((entry_point, entry_sim));
|
||||
best.push((entry_point, entry_sim));
|
||||
visited[entry_point] = true;
|
||||
|
||||
while !candidates.is_empty() {
|
||||
// Get candidate with highest similarity
|
||||
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
let (curr_id, curr_sim) = candidates.pop().unwrap();
|
||||
|
||||
// If worse than worst in best set, stop
|
||||
if !best.is_empty() {
|
||||
let worst_best = best
|
||||
.iter()
|
||||
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
|
||||
.unwrap();
|
||||
if curr_sim < worst_best.1 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Explore neighbors
|
||||
if let Some(node) = self.layers[layer].get(&curr_id) {
|
||||
for &neighbor_id in &node.neighbors {
|
||||
if !visited[neighbor_id] {
|
||||
visited[neighbor_id] = true;
|
||||
let sim =
|
||||
Self::cosine_similarity(query, &self.patterns[neighbor_id].embedding);
|
||||
|
||||
if best.len() < ef
|
||||
|| sim
|
||||
> best
|
||||
.iter()
|
||||
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
|
||||
.unwrap()
|
||||
.1
|
||||
{
|
||||
candidates.push((neighbor_id, sim));
|
||||
best.push((neighbor_id, sim));
|
||||
|
||||
if best.len() > ef {
|
||||
best.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
best.truncate(ef);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
best
|
||||
}
|
||||
|
||||
/// Search the graph for k nearest neighbors
|
||||
fn search(&self, query: &[f32], k: usize) -> Vec<RouteResultWasm> {
|
||||
if self.patterns.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let entry_point = self.entry_point.unwrap();
|
||||
let mut curr = entry_point;
|
||||
|
||||
// Search from top layer down to layer 1
|
||||
for l in (1..=self.max_layer).rev() {
|
||||
curr = self.search_layer(query, curr, 1, l)[0].0;
|
||||
}
|
||||
|
||||
// Search layer 0 with ef_search
|
||||
let results = self.search_layer(query, curr, self.ef_search.max(k), 0);
|
||||
|
||||
// Convert to RouteResultWasm
|
||||
results
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.map(|(id, score)| {
|
||||
let pattern = &self.patterns[id];
|
||||
RouteResultWasm {
|
||||
name: pattern.name.clone(),
|
||||
score,
|
||||
metadata: pattern.metadata.clone(),
|
||||
embedding: pattern.embedding.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// HNSW Semantic Router for browser-compatible pattern routing
///
/// Provides approximate nearest neighbor search over pattern embeddings
/// using the HNSW (Hierarchical Navigable Small World) algorithm.
///
/// ## Memory Efficiency
///
/// The router enforces a maximum number of patterns to prevent unbounded
/// memory growth in browser environments. When the limit is reached, adding
/// new patterns will fail (`addPattern` returns `false`).
///
/// ## Thread Safety
///
/// This implementation is single-threaded and designed for use in browser
/// main thread or Web Workers.
#[wasm_bindgen]
pub struct HnswRouterWasm {
    // Expected length of every embedding handled by this router.
    dimensions: usize,
    // Hard cap on stored patterns; inserts beyond it are rejected.
    max_patterns: usize,
    // The underlying HNSW index.
    graph: HnswGraph,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl HnswRouterWasm {
|
||||
/// Create a new HNSW router
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `dimensions`: Size of embedding vectors (e.g., 384 for all-MiniLM-L6-v2)
|
||||
/// - `max_patterns`: Maximum number of patterns to store (memory limit)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```javascript
|
||||
/// const router = HnswRouterWasm.new(384, 1000);
|
||||
/// ```
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dimensions: usize, max_patterns: usize) -> Self {
|
||||
crate::utils::set_panic_hook();
|
||||
|
||||
Self {
|
||||
dimensions,
|
||||
max_patterns,
|
||||
graph: HnswGraph::new(DEFAULT_M, DEFAULT_EF_CONSTRUCTION, DEFAULT_EF_SEARCH),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get embedding dimensions
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
/// Get maximum patterns limit
|
||||
#[wasm_bindgen(getter, js_name = maxPatterns)]
|
||||
pub fn max_patterns(&self) -> usize {
|
||||
self.max_patterns
|
||||
}
|
||||
|
||||
/// Get current number of patterns
|
||||
#[wasm_bindgen(getter, js_name = patternCount)]
|
||||
pub fn pattern_count(&self) -> usize {
|
||||
self.graph.patterns.len()
|
||||
}
|
||||
|
||||
/// Add a pattern to the router
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `embedding`: Float32Array of embedding values (must match dimensions)
|
||||
/// - `name`: Pattern name/identifier
|
||||
/// - `metadata`: JSON string with additional metadata
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `true` if pattern was added, `false` if max_patterns limit reached
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```javascript
|
||||
/// const embedding = new Float32Array([0.1, 0.2, 0.3, ...]); // 384 dims
|
||||
/// const success = router.addPattern(
|
||||
/// embedding,
|
||||
/// "rust-expert",
|
||||
/// JSON.stringify({ domain: "rust", expertise: "high" })
|
||||
/// );
|
||||
/// ```
|
||||
#[wasm_bindgen(js_name = addPattern)]
|
||||
pub fn add_pattern(&mut self, embedding: &[f32], name: &str, metadata: &str) -> bool {
|
||||
if self.graph.patterns.len() >= self.max_patterns {
|
||||
return false;
|
||||
}
|
||||
|
||||
if embedding.len() != self.dimensions {
|
||||
crate::utils::warn(&format!(
|
||||
"Embedding dimension mismatch: expected {}, got {}",
|
||||
self.dimensions,
|
||||
embedding.len()
|
||||
));
|
||||
return false;
|
||||
}
|
||||
|
||||
let pattern = PatternWasm::new(embedding, name, metadata);
|
||||
self.graph.add_pattern(pattern);
|
||||
true
|
||||
}
|
||||
|
||||
/// Route a query to find similar patterns
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `query`: Float32Array of query embedding (must match dimensions)
|
||||
/// - `top_k`: Number of top results to return
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Array of RouteResultWasm ordered by similarity (highest first)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```javascript
|
||||
/// const query = new Float32Array([0.15, 0.18, ...]); // 384 dims
|
||||
/// const results = router.route(query, 5);
|
||||
/// results.forEach(result => {
|
||||
/// console.log(`${result.name}: ${result.score}`);
|
||||
/// });
|
||||
/// ```
|
||||
#[wasm_bindgen]
|
||||
pub fn route(&self, query: &[f32], top_k: usize) -> Vec<RouteResultWasm> {
|
||||
if query.len() != self.dimensions {
|
||||
crate::utils::warn(&format!(
|
||||
"Query dimension mismatch: expected {}, got {}",
|
||||
self.dimensions,
|
||||
query.len()
|
||||
));
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
self.graph.search(query, top_k)
|
||||
}
|
||||
|
||||
/// Serialize the router to JSON string
|
||||
///
|
||||
/// Useful for persisting to IndexedDB or localStorage.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```javascript
|
||||
/// const json = router.toJson();
|
||||
/// localStorage.setItem('router', json);
|
||||
/// ```
|
||||
#[wasm_bindgen(js_name = toJson)]
|
||||
pub fn to_json(&self) -> Result<String, JsValue> {
|
||||
serde_json::to_string(&SerializableRouter {
|
||||
dimensions: self.dimensions,
|
||||
max_patterns: self.max_patterns,
|
||||
graph: self.graph.clone(),
|
||||
})
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Deserialize a router from JSON string
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```javascript
|
||||
/// const json = localStorage.getItem('router');
|
||||
/// const router = HnswRouterWasm.fromJson(json);
|
||||
/// ```
|
||||
#[wasm_bindgen(js_name = fromJson)]
|
||||
pub fn from_json(json: &str) -> Result<HnswRouterWasm, JsValue> {
|
||||
let data: SerializableRouter = serde_json::from_str(json)
|
||||
.map_err(|e| JsValue::from_str(&format!("Deserialization failed: {}", e)))?;
|
||||
|
||||
Ok(Self {
|
||||
dimensions: data.dimensions,
|
||||
max_patterns: data.max_patterns,
|
||||
graph: data.graph,
|
||||
})
|
||||
}
|
||||
|
||||
/// Clear all patterns from the router
|
||||
///
|
||||
/// Resets the router to empty state.
|
||||
#[wasm_bindgen]
|
||||
pub fn clear(&mut self) {
|
||||
self.graph = HnswGraph::new(DEFAULT_M, DEFAULT_EF_CONSTRUCTION, DEFAULT_EF_SEARCH);
|
||||
}
|
||||
|
||||
/// Get pattern by index
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `index`: Pattern index (0 to patternCount - 1)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// PatternWasm or null if index out of bounds
|
||||
#[wasm_bindgen(js_name = getPattern)]
|
||||
pub fn get_pattern(&self, index: usize) -> Option<PatternWasm> {
|
||||
self.graph.patterns.get(index).cloned()
|
||||
}
|
||||
|
||||
/// Set efSearch parameter for query-time accuracy tuning
|
||||
///
|
||||
/// Higher values = more accurate but slower search.
|
||||
/// Recommended range: 10-200.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `ef_search`: Number of neighbors to explore during search
|
||||
#[wasm_bindgen(js_name = setEfSearch)]
|
||||
pub fn set_ef_search(&mut self, ef_search: usize) {
|
||||
self.graph.ef_search = ef_search;
|
||||
}
|
||||
|
||||
/// Get current efSearch parameter
|
||||
#[wasm_bindgen(getter, js_name = efSearch)]
|
||||
pub fn ef_search(&self) -> usize {
|
||||
self.graph.ef_search
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializable router format
///
/// Private wire format used by `HnswRouterWasm::to_json` / `from_json`;
/// bundles the router configuration together with the full graph state.
#[derive(Serialize, Deserialize)]
struct SerializableRouter {
    // Embedding dimensionality of the router.
    dimensions: usize,
    // Pattern capacity limit of the router.
    max_patterns: usize,
    // Complete HNSW index (patterns + layers + parameters).
    graph: HnswGraph,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the HNSW router.
    //!
    //! NOTE(review): `HnswGraph::select_layer` calls `js_sys::Math::random()`,
    //! which only exists in a JS host, so every test that inserts a pattern
    //! presumably must run under `wasm-pack test` rather than native
    //! `cargo test` — confirm the CI target.
    use super::*;

    // Deterministic pseudo-embedding: dim values of sin(i * seed).
    fn create_test_embedding(dim: usize, seed: f32) -> Vec<f32> {
        (0..dim).map(|i| (i as f32 * seed).sin()).collect()
    }

    // Constructor wires up dimensions/capacity and starts empty.
    #[test]
    fn test_router_creation() {
        let router = HnswRouterWasm::new(128, 100);
        assert_eq!(router.dimensions(), 128);
        assert_eq!(router.max_patterns(), 100);
        assert_eq!(router.pattern_count(), 0);
    }

    // A well-formed insert succeeds and is counted.
    #[test]
    fn test_add_pattern() {
        let mut router = HnswRouterWasm::new(128, 100);
        let embedding = create_test_embedding(128, 1.0);

        let success = router.add_pattern(&embedding, "test-pattern", "{}");
        assert!(success);
        assert_eq!(router.pattern_count(), 1);
    }

    // Inserts beyond max_patterns are rejected and not counted.
    #[test]
    fn test_max_patterns_limit() {
        let mut router = HnswRouterWasm::new(128, 2);

        let emb1 = create_test_embedding(128, 1.0);
        let emb2 = create_test_embedding(128, 2.0);
        let emb3 = create_test_embedding(128, 3.0);

        assert!(router.add_pattern(&emb1, "pattern1", "{}"));
        assert!(router.add_pattern(&emb2, "pattern2", "{}"));
        assert!(!router.add_pattern(&emb3, "pattern3", "{}"));
        assert_eq!(router.pattern_count(), 2);
    }

    // Routing returns top_k results ordered by descending score.
    #[test]
    fn test_route() {
        let mut router = HnswRouterWasm::new(128, 100);

        // Add similar patterns
        let emb1 = create_test_embedding(128, 1.0);
        let emb2 = create_test_embedding(128, 1.1);
        let emb3 = create_test_embedding(128, 5.0);

        router.add_pattern(&emb1, "similar1", r#"{"type":"A"}"#);
        router.add_pattern(&emb2, "similar2", r#"{"type":"A"}"#);
        router.add_pattern(&emb3, "different", r#"{"type":"B"}"#);

        // Query similar to emb1
        let query = create_test_embedding(128, 1.05);
        let results = router.route(&query, 2);

        assert_eq!(results.len(), 2);
        // First result should be most similar
        assert!(results[0].score() > results[1].score());
    }

    // Identical vectors score ~1.0; orthogonal vectors score ~0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = HnswGraph::cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-5);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        let sim2 = HnswGraph::cosine_similarity(&c, &d);
        assert!(sim2.abs() < 1e-5);
    }

    // A router survives a JSON round-trip with its data intact.
    #[test]
    fn test_serialization() {
        let mut router = HnswRouterWasm::new(128, 100);
        let embedding = create_test_embedding(128, 1.0);
        router.add_pattern(&embedding, "test", r#"{"key":"value"}"#);

        let json = router.to_json().unwrap();
        let restored = HnswRouterWasm::from_json(&json).unwrap();

        assert_eq!(restored.dimensions(), 128);
        assert_eq!(restored.pattern_count(), 1);
    }

    // clear() removes all stored patterns.
    #[test]
    fn test_clear() {
        let mut router = HnswRouterWasm::new(128, 100);
        let embedding = create_test_embedding(128, 1.0);
        router.add_pattern(&embedding, "test", "{}");

        assert_eq!(router.pattern_count(), 1);
        router.clear();
        assert_eq!(router.pattern_count(), 0);
    }

    // get_pattern returns stored data by index and None out of range.
    #[test]
    fn test_get_pattern() {
        let mut router = HnswRouterWasm::new(128, 100);
        let embedding = create_test_embedding(128, 1.0);
        router.add_pattern(&embedding, "test-pattern", r#"{"foo":"bar"}"#);

        let pattern = router.get_pattern(0).unwrap();
        assert_eq!(pattern.name(), "test-pattern");
        assert_eq!(pattern.metadata(), r#"{"foo":"bar"}"#);

        assert!(router.get_pattern(1).is_none());
    }

    // efSearch starts at the default and is adjustable at runtime.
    #[test]
    fn test_ef_search() {
        let mut router = HnswRouterWasm::new(128, 100);
        assert_eq!(router.ef_search(), DEFAULT_EF_SEARCH);

        router.set_ef_search(200);
        assert_eq!(router.ef_search(), 200);
    }
}
|
||||
287
vendor/ruvector/crates/ruvllm-wasm/src/lib.rs
vendored
Normal file
287
vendor/ruvector/crates/ruvllm-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,287 @@
|
||||
//! # RuvLLM WASM - Browser-Compatible LLM Inference Runtime
|
||||
//!
|
||||
//! This crate provides WebAssembly bindings for the RuvLLM inference runtime,
|
||||
//! enabling LLM inference directly in web browsers.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **KV Cache Management**: Two-tier KV cache with FP16 tail and quantized store
|
||||
//! - **Memory Pooling**: Efficient buffer reuse for minimal allocation overhead
|
||||
//! - **Chat Templates**: Support for Llama3, Mistral, Qwen, Phi, Gemma formats
|
||||
//! - **Intelligent Learning**: HNSW Router (150x faster), MicroLoRA (<1ms adaptation), SONA loops
|
||||
//! - **TypeScript-Friendly**: All types have getter/setter methods for easy JS interop
|
||||
//!
|
||||
//! ## Quick Start (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import init, { RuvLLMWasm, GenerateConfig, ChatMessageWasm, ChatTemplateWasm } from 'ruvllm-wasm';
|
||||
//!
|
||||
//! async function main() {
|
||||
//! // Initialize WASM module
|
||||
//! await init();
|
||||
//!
|
||||
//! // Create inference engine
|
||||
//! const llm = new RuvLLMWasm();
|
||||
//! llm.initialize();
|
||||
//!
|
||||
//! // Format a chat conversation
|
||||
//! const template = ChatTemplateWasm.llama3();
|
||||
//! const messages = [
|
||||
//! ChatMessageWasm.system("You are a helpful assistant."),
|
||||
//! ChatMessageWasm.user("What is WebAssembly?"),
|
||||
//! ];
|
||||
//! const prompt = template.format(messages);
|
||||
//!
|
||||
//! console.log("Formatted prompt:", prompt);
|
||||
//!
|
||||
//! // KV Cache management
|
||||
//! const config = new KvCacheConfigWasm();
|
||||
//! config.tailLength = 256;
|
||||
//! const kvCache = new KvCacheWasm(config);
|
||||
//!
|
||||
//! const stats = kvCache.stats();
|
||||
//! console.log("Cache stats:", stats.toJson());
|
||||
//!
|
||||
//! // Intelligent LLM with learning
|
||||
//! const intelligentConfig = new IntelligentConfigWasm();
|
||||
//! const intelligentLLM = new IntelligentLLMWasm(intelligentConfig);
|
||||
//!
|
||||
//! // Process with routing, LoRA, and SONA learning
|
||||
//! const embedding = new Float32Array(384);
|
||||
//! const output = intelligentLLM.process(embedding, "user query", 0.9);
|
||||
//!
|
||||
//! console.log("Intelligent stats:", intelligentLLM.stats());
|
||||
//! }
|
||||
//!
|
||||
//! main();
|
||||
//! ```
|
||||
//!
|
||||
//! ## Building
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Build for browser (bundler target)
|
||||
//! wasm-pack build --target bundler
|
||||
//!
|
||||
//! # Build for Node.js
|
||||
//! wasm-pack build --target nodejs
|
||||
//!
|
||||
//! # Build for web (no bundler)
|
||||
//! wasm-pack build --target web
|
||||
//! ```
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! +-------------------+ +-------------------+
|
||||
//! | JavaScript/TS |---->| wasm-bindgen |
|
||||
//! | Application | | Bindings |
|
||||
//! +-------------------+ +-------------------+
|
||||
//! |
|
||||
//! v
|
||||
//! +-------------------+
|
||||
//! | RuvLLM Core |
|
||||
//! | (Rust WASM) |
|
||||
//! +-------------------+
|
||||
//! |
|
||||
//! v
|
||||
//! +-------------------+
|
||||
//! | Memory Pool |
|
||||
//! | KV Cache |
|
||||
//! | Chat Templates |
|
||||
//! +-------------------+
|
||||
//! ```
|
||||
//!
|
||||
//! ## Memory Management
|
||||
//!
|
||||
//! The WASM module uses efficient memory management strategies:
|
||||
//!
|
||||
//! - **Arena Allocator**: O(1) bump allocation for inference temporaries
|
||||
//! - **Buffer Pool**: Pre-allocated buffers in size classes (1KB-256KB)
|
||||
//! - **Two-Tier KV Cache**: FP32 tail + u8 quantized store
|
||||
//!
|
||||
//! ## Browser Compatibility
|
||||
//!
|
||||
//! Requires browsers with WebAssembly support:
|
||||
//! - Chrome 57+
|
||||
//! - Firefox 52+
|
||||
//! - Safari 11+
|
||||
//! - Edge 16+
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![warn(clippy::all)]
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub mod bindings;
|
||||
pub mod hnsw_router;
|
||||
pub mod micro_lora;
|
||||
pub mod sona_instant;
|
||||
pub mod utils;
|
||||
pub mod workers;
|
||||
|
||||
#[cfg(feature = "webgpu")]
|
||||
pub mod webgpu;
|
||||
|
||||
// Re-export all bindings
|
||||
pub use bindings::*;
|
||||
pub use hnsw_router::{HnswRouterWasm, PatternWasm, RouteResultWasm};
|
||||
pub use sona_instant::{SonaAdaptResultWasm, SonaConfigWasm, SonaInstantWasm, SonaStatsWasm};
|
||||
pub use utils::{error, log, now_ms, set_panic_hook, warn, Timer};
|
||||
|
||||
// Re-export workers module
|
||||
pub use workers::{
|
||||
cross_origin_isolated, detect_capability_level, feature_summary, is_atomics_available,
|
||||
is_shared_array_buffer_available, optimal_worker_count, supports_parallel_inference,
|
||||
ParallelInference,
|
||||
};
|
||||
|
||||
// Re-export WebGPU module when enabled
|
||||
#[cfg(feature = "webgpu")]
|
||||
pub use webgpu::*;
|
||||
|
||||
/// Initialize the WASM module.
///
/// This should be called once at application startup to set up
/// panic hooks and any other initialization. Marked `#[wasm_bindgen(start)]`,
/// so wasm-bindgen invokes it automatically when the module is loaded.
#[wasm_bindgen(start)]
pub fn init() {
    // Route Rust panics to console.error instead of an opaque trap.
    utils::set_panic_hook();
}
|
||||
|
||||
/// Perform a simple health check.
|
||||
///
|
||||
/// Returns true if the WASM module is functioning correctly.
|
||||
#[wasm_bindgen(js_name = healthCheck)]
|
||||
pub fn health_check() -> bool {
|
||||
// Verify we can create basic structures
|
||||
let arena = bindings::InferenceArenaWasm::new(1024);
|
||||
arena.capacity() >= 1024
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Integrated Intelligence System
|
||||
// ============================================================================
|
||||
// Note: This integration code is currently commented out pending full implementation
|
||||
// of micro_lora and sona_instant modules. The HNSW router can be used standalone.
|
||||
|
||||
/*
|
||||
/// Configuration for the intelligent LLM system (combines all components)
|
||||
#[wasm_bindgen]
|
||||
pub struct IntelligentConfigWasm {
|
||||
router_config: HnswRouterConfigWasm,
|
||||
lora_config: MicroLoraConfigWasm,
|
||||
sona_config: SonaConfigWasm,
|
||||
}
|
||||
*/
|
||||
|
||||
// Full integration system temporarily commented out - uncomment when micro_lora and sona_instant
|
||||
// are fully compatible with the new HnswRouterWasm API
|
||||
|
||||
/*
|
||||
#[wasm_bindgen]
|
||||
impl IntelligentConfigWasm {
|
||||
... (implementation temporarily removed)
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct IntelligentLLMWasm {
|
||||
... (implementation temporarily removed)
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl IntelligentLLMWasm {
|
||||
... (implementation temporarily removed)
|
||||
}
|
||||
*/
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_config_defaults() {
        let cfg = bindings::GenerateConfig::new();
        assert_eq!(cfg.max_tokens, 256);
        assert!((cfg.temperature - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_chat_message() {
        let message = bindings::ChatMessageWasm::user("Hello");
        assert_eq!(message.role(), "user");
        assert_eq!(message.content(), "Hello");
    }

    #[test]
    fn test_chat_template_detection() {
        let tpl = bindings::ChatTemplateWasm::detect_from_model_id("meta-llama/Llama-3-8B");
        assert_eq!(tpl.name(), "llama3");
    }

    #[test]
    fn test_kv_cache_config() {
        let mut cache_cfg = bindings::KvCacheConfigWasm::new();
        cache_cfg.set_tail_length(512);
        assert_eq!(cache_cfg.tail_length(), 512);
    }

    #[test]
    fn test_arena_creation() {
        // A fresh arena must satisfy the requested capacity and start empty.
        let arena = bindings::InferenceArenaWasm::new(4096);
        assert!(arena.capacity() >= 4096);
        assert_eq!(arena.used(), 0);
    }

    #[test]
    fn test_buffer_pool() {
        let pool = bindings::BufferPoolWasm::new();
        pool.prewarm_all(2);
        assert!(pool.hit_rate() >= 0.0);
    }

    // RuvLLMWasm::new() calls set_panic_hook which uses wasm-bindgen,
    // so skip this test on non-wasm32 targets
    #[cfg(target_arch = "wasm32")]
    #[test]
    fn test_ruvllm_wasm() {
        let mut llm = bindings::RuvLLMWasm::new();
        assert!(!llm.is_initialized());
        llm.initialize().unwrap();
        assert!(llm.is_initialized());
    }

    // Integration tests temporarily commented out
    /*
    #[test]
    fn test_micro_lora_integration() {
        let config = micro_lora::MicroLoraConfigWasm::new();
        let adapter = micro_lora::MicroLoraWasm::new(&config);
        let stats = adapter.stats();
        assert_eq!(stats.samples_seen(), 0);
        assert!(stats.memory_bytes() > 0);
    }

    #[test]
    fn test_intelligent_llm_creation() {
        let config = IntelligentConfigWasm::new();
        let llm = IntelligentLLMWasm::new(config).unwrap();
        let stats_json = llm.stats();
        assert!(stats_json.contains("router"));
        assert!(stats_json.contains("lora"));
        assert!(stats_json.contains("sona"));
    }

    #[test]
    fn test_intelligent_llm_learn_pattern() {
        let config = IntelligentConfigWasm::new();
        let mut llm = IntelligentLLMWasm::new(config).unwrap();

        let embedding = vec![0.1; 384];
        llm.learn_pattern(&embedding, "coder", "code_generation", "implement function", 0.85)
            .unwrap();

        let stats_json = llm.stats();
        assert!(stats_json.contains("totalPatterns"));
    }
    */
}
|
||||
735
vendor/ruvector/crates/ruvllm-wasm/src/micro_lora.rs
vendored
Normal file
735
vendor/ruvector/crates/ruvllm-wasm/src/micro_lora.rs
vendored
Normal file
@@ -0,0 +1,735 @@
|
||||
//! MicroLoRA for WASM - Browser-Compatible Lightweight LoRA Adaptation
|
||||
//!
|
||||
//! This module provides ultra-lightweight LoRA (Low-Rank Adaptation) for browser-based
|
||||
//! LLM inference. Designed for minimal memory footprint and real-time per-request adaptation.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Rank 1-4 adapters**: Very small memory footprint (<10KB per adapter)
|
||||
//! - **Pure Rust**: No threading, no file I/O, fully WASM-compatible
|
||||
//! - **Per-request adaptation**: Update weights based on user feedback
|
||||
//! - **Serialization**: JSON-based persistence for browser storage
|
||||
//!
|
||||
//! ## Example (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import { MicroLoraWasm, MicroLoraConfigWasm, AdaptFeedbackWasm } from 'ruvllm-wasm';
|
||||
//!
|
||||
//! // Create a rank-2 adapter for 768-dim hidden states
|
||||
//! const config = new MicroLoraConfigWasm();
|
||||
//! config.rank = 2;
|
||||
//! config.alpha = 4.0;
|
||||
//! config.inFeatures = 768;
|
||||
//! config.outFeatures = 768;
|
||||
//!
|
||||
//! const lora = new MicroLoraWasm(config);
|
||||
//!
|
||||
//! // Apply LoRA to input
|
||||
//! const input = new Float32Array(768);
|
||||
//! const output = lora.apply(input);
|
||||
//!
|
||||
//! // Adapt based on feedback
|
||||
//! const feedback = new AdaptFeedbackWasm();
|
||||
//! feedback.quality = 0.8;
|
||||
//! lora.adapt(input, feedback);
|
||||
//!
|
||||
//! // Serialize for persistence
|
||||
//! const json = lora.toJson();
|
||||
//! localStorage.setItem('lora-state', json);
|
||||
//!
|
||||
//! // Restore from JSON
|
||||
//! const restored = MicroLoraWasm.fromJson(json);
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for MicroLoRA adapter.
///
/// Controls the rank, scaling, and dimensions of the LoRA adapter.
/// TypeScript-friendly with getter/setter methods. Field names are part of
/// the serde-serialized schema.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MicroLoraConfigWasm {
    /// Low-rank dimension; `set_rank` clamps it to 1..=4 for browser efficiency.
    #[wasm_bindgen(skip)]
    pub rank: usize,
    /// LoRA scaling numerator; effective scale is alpha / rank.
    #[wasm_bindgen(skip)]
    pub alpha: f32,
    /// Input (hidden-state) dimension of the adapted layer.
    #[wasm_bindgen(skip)]
    pub in_features: usize,
    /// Output dimension of the adapted layer.
    #[wasm_bindgen(skip)]
    pub out_features: usize,
}
|
||||
|
||||
#[wasm_bindgen]
impl MicroLoraConfigWasm {
    /// Create a new config with default values (rank=2, alpha=4.0, 768x768).
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            rank: 2,
            alpha: 4.0,
            in_features: 768,
            out_features: 768,
        }
    }

    /// Get rank.
    #[wasm_bindgen(getter)]
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Set rank (clamped to 1-4 for browser efficiency).
    #[wasm_bindgen(setter)]
    pub fn set_rank(&mut self, value: usize) {
        self.rank = value.clamp(1, 4);
    }

    /// Get alpha scaling factor.
    #[wasm_bindgen(getter)]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Set alpha scaling factor (not clamped).
    #[wasm_bindgen(setter)]
    pub fn set_alpha(&mut self, value: f32) {
        self.alpha = value;
    }

    /// Get input feature dimension.
    #[wasm_bindgen(getter, js_name = inFeatures)]
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Set input feature dimension.
    #[wasm_bindgen(setter, js_name = inFeatures)]
    pub fn set_in_features(&mut self, value: usize) {
        self.in_features = value;
    }

    /// Get output feature dimension.
    #[wasm_bindgen(getter, js_name = outFeatures)]
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Set output feature dimension.
    #[wasm_bindgen(setter, js_name = outFeatures)]
    pub fn set_out_features(&mut self, value: usize) {
        self.out_features = value;
    }

    /// Calculate memory footprint in bytes for the weight matrices.
    #[wasm_bindgen(js_name = memoryBytes)]
    pub fn memory_bytes(&self) -> usize {
        // A: in_features x rank, B: rank x out_features
        let params = self.in_features * self.rank + self.rank * self.out_features;
        params * std::mem::size_of::<f32>()
    }

    /// Get computed scaling factor (alpha / rank), the standard LoRA scale.
    #[wasm_bindgen(js_name = computeScaling)]
    pub fn compute_scaling(&self) -> f32 {
        self.alpha / self.rank as f32
    }
}
|
||||
|
||||
impl Default for MicroLoraConfigWasm {
    /// Same defaults as [`MicroLoraConfigWasm::new`] (rank=2, alpha=4.0, 768x768).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Feedback for Adaptation
|
||||
// ============================================================================
|
||||
|
||||
/// Feedback for per-request adaptation.
///
/// Provides quality scores and optional gradient estimates to guide
/// LoRA weight updates.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptFeedbackWasm {
    /// Quality score in [0.0, 1.0]; 0.5 is neutral for the reward mapping.
    #[wasm_bindgen(skip)]
    pub quality: f32,
    /// Step size for applying accumulated updates (defaults to 0.01).
    #[wasm_bindgen(skip)]
    pub learning_rate: f32,
}
|
||||
|
||||
#[wasm_bindgen]
impl AdaptFeedbackWasm {
    /// Create new feedback with quality score [0.0, 1.0].
    ///
    /// Out-of-range values are clamped; the learning rate defaults to 0.01.
    #[wasm_bindgen(constructor)]
    pub fn new(quality: f32) -> Self {
        Self {
            quality: quality.clamp(0.0, 1.0),
            learning_rate: 0.01,
        }
    }

    /// Get quality score.
    #[wasm_bindgen(getter)]
    pub fn quality(&self) -> f32 {
        self.quality
    }

    /// Set quality score (clamped to [0.0, 1.0]).
    #[wasm_bindgen(setter)]
    pub fn set_quality(&mut self, value: f32) {
        self.quality = value.clamp(0.0, 1.0);
    }

    /// Get learning rate.
    #[wasm_bindgen(getter, js_name = learningRate)]
    pub fn learning_rate(&self) -> f32 {
        self.learning_rate
    }

    /// Set learning rate (not clamped; callers choose the step size).
    #[wasm_bindgen(setter, js_name = learningRate)]
    pub fn set_learning_rate(&mut self, value: f32) {
        self.learning_rate = value;
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Statistics
|
||||
// ============================================================================
|
||||
|
||||
/// Statistics for MicroLoRA adapter.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MicroLoraStatsWasm {
    /// Number of feedback samples the adapter has processed.
    #[wasm_bindgen(skip)]
    pub samples_seen: usize,
    /// Mean quality score over all samples (0.0 when none seen).
    #[wasm_bindgen(skip)]
    pub avg_quality: f32,
    /// Weight memory footprint in bytes.
    #[wasm_bindgen(skip)]
    pub memory_bytes: usize,
    /// Total trainable parameter count (A + B matrices).
    #[wasm_bindgen(skip)]
    pub param_count: usize,
}
|
||||
|
||||
#[wasm_bindgen]
impl MicroLoraStatsWasm {
    /// Get number of samples seen.
    #[wasm_bindgen(getter, js_name = samplesSeen)]
    pub fn samples_seen(&self) -> usize {
        self.samples_seen
    }

    /// Get average quality score.
    #[wasm_bindgen(getter, js_name = avgQuality)]
    pub fn avg_quality(&self) -> f32 {
        self.avg_quality
    }

    /// Get memory usage in bytes.
    #[wasm_bindgen(getter, js_name = memoryBytes)]
    pub fn memory_bytes(&self) -> usize {
        self.memory_bytes
    }

    /// Get parameter count.
    #[wasm_bindgen(getter, js_name = paramCount)]
    pub fn param_count(&self) -> usize {
        self.param_count
    }

    /// Convert to JSON string.
    ///
    /// Returns a `JsValue` error string if serde serialization fails.
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        serde_json::to_string(self)
            .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// MicroLoRA Adapter (Internal)
|
||||
// ============================================================================
|
||||
|
||||
// Internal adapter state: weight matrices plus gradient accumulators.
// Serialized as part of `MicroLoraWasm`'s JSON persistence, so field names
// are part of the stored schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LoraAdapterInternal {
    /// A matrix (in_features x rank) - down projection, row-major
    lora_a: Vec<f32>,
    /// B matrix (rank x out_features) - up projection, row-major
    lora_b: Vec<f32>,
    /// Scaling factor (alpha / rank)
    scaling: f32,
    /// Rank
    rank: usize,
    /// Input features
    in_features: usize,
    /// Output features
    out_features: usize,
    /// Accumulated gradients for A (same layout as `lora_a`)
    grad_a: Vec<f32>,
    /// Accumulated gradients for B (same layout as `lora_b`)
    grad_b: Vec<f32>,
    /// Number of accumulated gradients
    grad_count: usize,
}
|
||||
|
||||
impl LoraAdapterInternal {
|
||||
/// Create a new LoRA adapter with Kaiming initialization for A and zeros for B.
|
||||
fn new(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
|
||||
let scaling = alpha / rank as f32;
|
||||
|
||||
// Kaiming initialization for A
|
||||
let std_a = (2.0 / in_features as f32).sqrt() * 0.01;
|
||||
let mut lora_a = Vec::with_capacity(in_features * rank);
|
||||
for i in 0..(in_features * rank) {
|
||||
// Deterministic pseudo-random for reproducibility
|
||||
let seed = i as f32;
|
||||
let value = ((seed * 0.618033988749895) % 1.0 - 0.5) * 2.0 * std_a;
|
||||
lora_a.push(value);
|
||||
}
|
||||
|
||||
// Zero initialization for B (standard LoRA)
|
||||
let lora_b = vec![0.0; rank * out_features];
|
||||
|
||||
Self {
|
||||
lora_a,
|
||||
lora_b,
|
||||
scaling,
|
||||
rank,
|
||||
in_features,
|
||||
out_features,
|
||||
grad_a: vec![0.0; in_features * rank],
|
||||
grad_b: vec![0.0; rank * out_features],
|
||||
grad_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass: output = x @ A @ B * scaling
|
||||
fn forward(&self, input: &[f32], output: &mut [f32]) {
|
||||
debug_assert_eq!(input.len(), self.in_features);
|
||||
debug_assert_eq!(output.len(), self.out_features);
|
||||
|
||||
// Compute intermediate: x @ A (in_features -> rank)
|
||||
let mut intermediate = vec![0.0; self.rank];
|
||||
for r in 0..self.rank {
|
||||
let mut sum = 0.0;
|
||||
for i in 0..self.in_features {
|
||||
sum += input[i] * self.lora_a[i * self.rank + r];
|
||||
}
|
||||
intermediate[r] = sum;
|
||||
}
|
||||
|
||||
// Compute output: intermediate @ B * scaling (rank -> out_features)
|
||||
for o in 0..self.out_features {
|
||||
let mut sum = 0.0;
|
||||
for r in 0..self.rank {
|
||||
sum += intermediate[r] * self.lora_b[r * self.out_features + o];
|
||||
}
|
||||
output[o] += sum * self.scaling;
|
||||
}
|
||||
}
|
||||
|
||||
/// Accumulate gradients based on feedback quality.
|
||||
///
|
||||
/// Uses a simplified gradient estimate based on the quality score.
|
||||
/// For browser use, we use a lightweight update rule without full backprop.
|
||||
fn accumulate_gradient(&mut self, input: &[f32], quality: f32) {
|
||||
// Compute intermediate activation
|
||||
let mut intermediate = vec![0.0; self.rank];
|
||||
for r in 0..self.rank {
|
||||
let mut sum = 0.0;
|
||||
for i in 0..self.in_features {
|
||||
sum += input[i] * self.lora_a[i * self.rank + r];
|
||||
}
|
||||
intermediate[r] = sum;
|
||||
}
|
||||
|
||||
// Simple gradient estimate: use quality as reward signal
|
||||
// For positive quality (>0.5), strengthen current activation patterns
|
||||
// For negative quality (<0.5), weaken them
|
||||
let reward = (quality - 0.5) * 2.0; // Map [0,1] to [-1,1]
|
||||
|
||||
// Update B gradients: outer product of intermediate and reward
|
||||
for r in 0..self.rank {
|
||||
for o in 0..self.out_features {
|
||||
let idx = r * self.out_features + o;
|
||||
self.grad_b[idx] += intermediate[r] * reward * self.scaling * 0.01;
|
||||
}
|
||||
}
|
||||
|
||||
// Update A gradients: outer product of input and reward-weighted intermediate
|
||||
for i in 0..self.in_features {
|
||||
for r in 0..self.rank {
|
||||
let idx = i * self.rank + r;
|
||||
self.grad_a[idx] += input[i] * reward * self.scaling * 0.01;
|
||||
}
|
||||
}
|
||||
|
||||
self.grad_count += 1;
|
||||
}
|
||||
|
||||
/// Apply accumulated gradients with learning rate.
|
||||
fn apply_gradients(&mut self, learning_rate: f32) {
|
||||
if self.grad_count == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let scale = learning_rate / self.grad_count as f32;
|
||||
|
||||
// Update A
|
||||
for i in 0..self.lora_a.len() {
|
||||
self.lora_a[i] -= self.grad_a[i] * scale;
|
||||
}
|
||||
|
||||
// Update B
|
||||
for i in 0..self.lora_b.len() {
|
||||
self.lora_b[i] -= self.grad_b[i] * scale;
|
||||
}
|
||||
|
||||
// Reset gradients
|
||||
for g in &mut self.grad_a {
|
||||
*g = 0.0;
|
||||
}
|
||||
for g in &mut self.grad_b {
|
||||
*g = 0.0;
|
||||
}
|
||||
self.grad_count = 0;
|
||||
}
|
||||
|
||||
/// Reset adapter to initial state.
|
||||
fn reset(&mut self) {
|
||||
// Reset B to zeros
|
||||
for b in &mut self.lora_b {
|
||||
*b = 0.0;
|
||||
}
|
||||
|
||||
// Reset gradients
|
||||
for g in &mut self.grad_a {
|
||||
*g = 0.0;
|
||||
}
|
||||
for g in &mut self.grad_b {
|
||||
*g = 0.0;
|
||||
}
|
||||
self.grad_count = 0;
|
||||
}
|
||||
|
||||
/// Get parameter count.
|
||||
fn param_count(&self) -> usize {
|
||||
self.lora_a.len() + self.lora_b.len()
|
||||
}
|
||||
|
||||
/// Get memory usage in bytes.
|
||||
fn memory_bytes(&self) -> usize {
|
||||
self.param_count() * std::mem::size_of::<f32>()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MicroLoRA (Public WASM Interface)
|
||||
// ============================================================================
|
||||
|
||||
/// MicroLoRA adapter for browser-based real-time adaptation.
///
/// Provides lightweight LoRA (Low-Rank Adaptation) with minimal memory footprint
/// suitable for browser environments. Supports per-request adaptation with
/// quality-based feedback.
#[wasm_bindgen]
pub struct MicroLoraWasm {
    // Underlying weight matrices and gradient accumulators.
    adapter: LoraAdapterInternal,
    // Total number of feedback samples passed to `adapt`.
    samples_seen: usize,
    // Running sum of quality scores; avg = quality_sum / samples_seen.
    quality_sum: f32,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl MicroLoraWasm {
|
||||
/// Create a new MicroLoRA adapter with the given configuration.
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(config: &MicroLoraConfigWasm) -> Self {
|
||||
let adapter = LoraAdapterInternal::new(
|
||||
config.in_features,
|
||||
config.out_features,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
);
|
||||
|
||||
Self {
|
||||
adapter,
|
||||
samples_seen: 0,
|
||||
quality_sum: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply LoRA transformation to input.
|
||||
///
|
||||
/// Returns a new Float32Array with the transformed output.
|
||||
/// The output is added to (not replaced) so you can combine with base model output.
|
||||
#[wasm_bindgen]
|
||||
pub fn apply(&self, input: &[f32]) -> Result<Vec<f32>, JsValue> {
|
||||
if input.len() != self.adapter.in_features {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.adapter.in_features,
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut output = vec![0.0; self.adapter.out_features];
|
||||
self.adapter.forward(input, &mut output);
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Adapt the LoRA weights based on feedback.
|
||||
///
|
||||
/// Accumulates gradients based on the quality score. Call `applyUpdates()`
|
||||
/// to actually apply the accumulated gradients.
|
||||
#[wasm_bindgen]
|
||||
pub fn adapt(&mut self, input: &[f32], feedback: &AdaptFeedbackWasm) -> Result<(), JsValue> {
|
||||
if input.len() != self.adapter.in_features {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.adapter.in_features,
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
self.adapter.accumulate_gradient(input, feedback.quality);
|
||||
self.samples_seen += 1;
|
||||
self.quality_sum += feedback.quality;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply accumulated gradients with the given learning rate.
|
||||
///
|
||||
/// Should be called after one or more `adapt()` calls to update the weights.
|
||||
#[wasm_bindgen(js_name = applyUpdates)]
|
||||
pub fn apply_updates(&mut self, learning_rate: f32) {
|
||||
self.adapter.apply_gradients(learning_rate);
|
||||
}
|
||||
|
||||
/// Reset the adapter to its initial state.
|
||||
///
|
||||
/// Clears B weights and all statistics.
|
||||
#[wasm_bindgen]
|
||||
pub fn reset(&mut self) {
|
||||
self.adapter.reset();
|
||||
self.samples_seen = 0;
|
||||
self.quality_sum = 0.0;
|
||||
}
|
||||
|
||||
/// Get adapter statistics.
|
||||
#[wasm_bindgen]
|
||||
pub fn stats(&self) -> MicroLoraStatsWasm {
|
||||
MicroLoraStatsWasm {
|
||||
samples_seen: self.samples_seen,
|
||||
avg_quality: if self.samples_seen > 0 {
|
||||
self.quality_sum / self.samples_seen as f32
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
memory_bytes: self.adapter.memory_bytes(),
|
||||
param_count: self.adapter.param_count(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize to JSON string for persistence.
|
||||
#[wasm_bindgen(js_name = toJson)]
|
||||
pub fn to_json(&self) -> Result<String, JsValue> {
|
||||
#[derive(Serialize)]
|
||||
struct SerializedState {
|
||||
adapter: LoraAdapterInternal,
|
||||
samples_seen: usize,
|
||||
quality_sum: f32,
|
||||
}
|
||||
|
||||
let state = SerializedState {
|
||||
adapter: self.adapter.clone(),
|
||||
samples_seen: self.samples_seen,
|
||||
quality_sum: self.quality_sum,
|
||||
};
|
||||
|
||||
serde_json::to_string(&state)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
|
||||
}
|
||||
|
||||
/// Deserialize from JSON string.
|
||||
#[wasm_bindgen(js_name = fromJson)]
|
||||
pub fn from_json(json: &str) -> Result<MicroLoraWasm, JsValue> {
|
||||
#[derive(Deserialize)]
|
||||
struct SerializedState {
|
||||
adapter: LoraAdapterInternal,
|
||||
samples_seen: usize,
|
||||
quality_sum: f32,
|
||||
}
|
||||
|
||||
let state: SerializedState = serde_json::from_str(json)
|
||||
.map_err(|e| JsValue::from_str(&format!("Deserialization error: {}", e)))?;
|
||||
|
||||
Ok(MicroLoraWasm {
|
||||
adapter: state.adapter,
|
||||
samples_seen: state.samples_seen,
|
||||
quality_sum: state.quality_sum,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get number of pending gradient updates.
|
||||
#[wasm_bindgen(js_name = pendingUpdates)]
|
||||
pub fn pending_updates(&self) -> usize {
|
||||
self.adapter.grad_count
|
||||
}
|
||||
|
||||
/// Get configuration.
|
||||
#[wasm_bindgen(js_name = getConfig)]
|
||||
pub fn get_config(&self) -> MicroLoraConfigWasm {
|
||||
MicroLoraConfigWasm {
|
||||
rank: self.adapter.rank,
|
||||
alpha: self.adapter.scaling * self.adapter.rank as f32,
|
||||
in_features: self.adapter.in_features,
|
||||
out_features: self.adapter.out_features,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: square config of the given dimension and rank.
    fn square_config(dim: usize, rank: usize) -> MicroLoraConfigWasm {
        let mut cfg = MicroLoraConfigWasm::new();
        cfg.set_in_features(dim);
        cfg.set_out_features(dim);
        cfg.set_rank(rank);
        cfg
    }

    #[test]
    fn test_config_creation() {
        let cfg = MicroLoraConfigWasm::new();
        assert_eq!(cfg.rank(), 2);
        assert_eq!(cfg.alpha(), 4.0);
        assert_eq!(cfg.in_features(), 768);
        assert_eq!(cfg.out_features(), 768);
    }

    #[test]
    fn test_config_rank_clamping() {
        let mut cfg = MicroLoraConfigWasm::new();
        cfg.set_rank(10);
        assert_eq!(cfg.rank(), 4); // clamped to the maximum of 4
        cfg.set_rank(0);
        assert_eq!(cfg.rank(), 1); // clamped to the minimum of 1
    }

    #[test]
    fn test_adapter_creation() {
        let lora = MicroLoraWasm::new(&MicroLoraConfigWasm::new());
        let stats = lora.stats();
        assert_eq!(stats.samples_seen(), 0);
        assert_eq!(stats.avg_quality(), 0.0);
    }

    #[test]
    fn test_forward_pass() {
        let lora = MicroLoraWasm::new(&square_config(64, 2));
        let out = lora.apply(&vec![1.0; 64]).unwrap();
        assert_eq!(out.len(), 64);

        // B is zero-initialized, so the initial delta is essentially zero.
        let magnitude: f32 = out.iter().map(|x| x.abs()).sum();
        assert!(magnitude < 0.1);
    }

    #[test]
    fn test_adaptation() {
        let mut lora = MicroLoraWasm::new(&square_config(64, 2));
        let input = vec![0.1; 64];

        lora.adapt(&input, &AdaptFeedbackWasm::new(0.8)).unwrap();
        assert_eq!(lora.pending_updates(), 1);

        lora.apply_updates(0.01);
        assert_eq!(lora.pending_updates(), 0);

        let stats = lora.stats();
        assert_eq!(stats.samples_seen(), 1);
        assert!((stats.avg_quality() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_serialization() {
        let mut lora = MicroLoraWasm::new(&square_config(32, 2));
        lora.adapt(&vec![0.1; 32], &AdaptFeedbackWasm::new(0.9)).unwrap();
        lora.apply_updates(0.01);

        // Round-trip through JSON must preserve statistics.
        let json = lora.to_json().unwrap();
        let restored = MicroLoraWasm::from_json(&json).unwrap();

        let before = lora.stats();
        let after = restored.stats();
        assert_eq!(before.samples_seen(), after.samples_seen());
        assert!((before.avg_quality() - after.avg_quality()).abs() < 1e-6);
    }

    #[test]
    fn test_reset() {
        let mut cfg = MicroLoraConfigWasm::new();
        cfg.set_in_features(32);
        cfg.set_out_features(32);

        let mut lora = MicroLoraWasm::new(&cfg);
        lora.adapt(&vec![0.1; 32], &AdaptFeedbackWasm::new(0.8)).unwrap();
        lora.apply_updates(0.01);
        assert_eq!(lora.stats().samples_seen(), 1);

        lora.reset();

        let after = lora.stats();
        assert_eq!(after.samples_seen(), 0);
        assert_eq!(after.avg_quality(), 0.0);
    }

    #[test]
    fn test_memory_calculation() {
        let cfg = square_config(768, 2);
        // (768 * 2 + 2 * 768) * 4 bytes = 3072 * 4 = 12288 bytes
        assert_eq!(cfg.memory_bytes(), 12288);

        let lora = MicroLoraWasm::new(&cfg);
        assert_eq!(lora.stats().memory_bytes(), 12288);
    }
}
|
||||
845
vendor/ruvector/crates/ruvllm-wasm/src/sona_instant.rs
vendored
Normal file
845
vendor/ruvector/crates/ruvllm-wasm/src/sona_instant.rs
vendored
Normal file
@@ -0,0 +1,845 @@
|
||||
//! SONA Instant Loop - Browser-Compatible Instant Learning
|
||||
//!
|
||||
//! Pure Rust, WASM-compatible implementation of SONA's instant learning loop
|
||||
//! with <1ms adaptation latency target.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Instant Adaptation**: <1ms per quality signal
|
||||
//! - **Pattern Recognition**: HNSW-indexed pattern buffer (max 1000)
|
||||
//! - **EWC-Lite**: Simplified elastic weight consolidation
|
||||
//! - **Exponential Moving Average**: Quality tracking
|
||||
//! - **Pure WASM**: No threads, no async, browser-safe
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! Quality Signal (f32)
|
||||
//! |
|
||||
//! v
|
||||
//! +----------------+
|
||||
//! | Instant Adapt | <1ms target
|
||||
//! | - Update EMA |
|
||||
//! | - Adjust rank |
|
||||
//! | - Apply EWC |
|
||||
//! +----------------+
|
||||
//! |
|
||||
//! v
|
||||
//! Pattern Buffer (1000)
|
||||
//! HNSW-indexed for fast search
|
||||
//! ```
|
||||
//!
|
||||
//! ## Example (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import { SonaInstantWasm, SonaConfigWasm } from 'ruvllm-wasm';
|
||||
//!
|
||||
//! // Create SONA instance
|
||||
//! const config = new SonaConfigWasm();
|
||||
//! config.learningRate = 0.01;
|
||||
//! const sona = new SonaInstantWasm(config);
|
||||
//!
|
||||
//! // Instant adaptation
|
||||
//! const result = sona.instantAdapt(0.8);
|
||||
//! console.log(`Adapted in ${result.latencyUs}μs, quality: ${result.qualityDelta}`);
|
||||
//!
|
||||
//! // Record pattern outcome
|
||||
//! const embedding = new Float32Array([0.1, 0.2, 0.3, ...]);
|
||||
//! sona.recordPattern(embedding, true);
|
||||
//!
|
||||
//! // Get suggestion based on context
|
||||
//! const suggestion = sona.suggestAction(embedding);
|
||||
//! console.log(`Suggestion: ${suggestion || 'none'}`);
|
||||
//!
|
||||
//! // View statistics
|
||||
//! const stats = sona.stats();
|
||||
//! console.log(`Adaptations: ${stats.adaptations}, Avg quality: ${stats.avgQuality}`);
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::VecDeque;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for SONA Instant Loop (WASM)
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SonaConfigWasm {
    /// Hidden dimension size
    #[wasm_bindgen(skip)]
    pub hidden_dim: usize,
    /// Micro-LoRA rank (1-2 for instant learning; setter clamps to 1-4)
    #[wasm_bindgen(skip)]
    pub micro_lora_rank: usize,
    /// Learning rate for instant updates (setter clamps to [0, 1])
    #[wasm_bindgen(skip)]
    pub learning_rate: f32,
    /// EMA decay factor for quality tracking (setter clamps to [0, 1])
    #[wasm_bindgen(skip)]
    pub ema_decay: f32,
    /// Pattern buffer capacity (setter clamps to 10-1000 to bound WASM memory)
    #[wasm_bindgen(skip)]
    pub pattern_capacity: usize,
    /// EWC regularization strength
    #[wasm_bindgen(skip)]
    pub ewc_lambda: f32,
    /// Minimum quality threshold for learning
    #[wasm_bindgen(skip)]
    pub quality_threshold: f32,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SonaConfigWasm {
|
||||
/// Create new config with defaults
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
hidden_dim: 256,
|
||||
micro_lora_rank: 1,
|
||||
learning_rate: 0.01,
|
||||
ema_decay: 0.95,
|
||||
pattern_capacity: 1000,
|
||||
ewc_lambda: 0.1,
|
||||
quality_threshold: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get hidden dimension
|
||||
#[wasm_bindgen(getter, js_name = hiddenDim)]
|
||||
pub fn hidden_dim(&self) -> usize {
|
||||
self.hidden_dim
|
||||
}
|
||||
|
||||
/// Set hidden dimension
|
||||
#[wasm_bindgen(setter, js_name = hiddenDim)]
|
||||
pub fn set_hidden_dim(&mut self, value: usize) {
|
||||
self.hidden_dim = value;
|
||||
}
|
||||
|
||||
/// Get micro-LoRA rank
|
||||
#[wasm_bindgen(getter, js_name = microLoraRank)]
|
||||
pub fn micro_lora_rank(&self) -> usize {
|
||||
self.micro_lora_rank
|
||||
}
|
||||
|
||||
/// Set micro-LoRA rank
|
||||
#[wasm_bindgen(setter, js_name = microLoraRank)]
|
||||
pub fn set_micro_lora_rank(&mut self, value: usize) {
|
||||
self.micro_lora_rank = value.max(1).min(4); // Clamp 1-4
|
||||
}
|
||||
|
||||
/// Get learning rate
|
||||
#[wasm_bindgen(getter, js_name = learningRate)]
|
||||
pub fn learning_rate(&self) -> f32 {
|
||||
self.learning_rate
|
||||
}
|
||||
|
||||
/// Set learning rate
|
||||
#[wasm_bindgen(setter, js_name = learningRate)]
|
||||
pub fn set_learning_rate(&mut self, value: f32) {
|
||||
self.learning_rate = value.max(0.0).min(1.0);
|
||||
}
|
||||
|
||||
/// Get EMA decay
|
||||
#[wasm_bindgen(getter, js_name = emaDecay)]
|
||||
pub fn ema_decay(&self) -> f32 {
|
||||
self.ema_decay
|
||||
}
|
||||
|
||||
/// Set EMA decay
|
||||
#[wasm_bindgen(setter, js_name = emaDecay)]
|
||||
pub fn set_ema_decay(&mut self, value: f32) {
|
||||
self.ema_decay = value.max(0.0).min(1.0);
|
||||
}
|
||||
|
||||
/// Get pattern capacity
|
||||
#[wasm_bindgen(getter, js_name = patternCapacity)]
|
||||
pub fn pattern_capacity(&self) -> usize {
|
||||
self.pattern_capacity
|
||||
}
|
||||
|
||||
/// Set pattern capacity
|
||||
#[wasm_bindgen(setter, js_name = patternCapacity)]
|
||||
pub fn set_pattern_capacity(&mut self, value: usize) {
|
||||
self.pattern_capacity = value.max(10).min(1000);
|
||||
}
|
||||
|
||||
/// Get EWC lambda
|
||||
#[wasm_bindgen(getter, js_name = ewcLambda)]
|
||||
pub fn ewc_lambda(&self) -> f32 {
|
||||
self.ewc_lambda
|
||||
}
|
||||
|
||||
/// Set EWC lambda
|
||||
#[wasm_bindgen(setter, js_name = ewcLambda)]
|
||||
pub fn set_ewc_lambda(&mut self, value: f32) {
|
||||
self.ewc_lambda = value.max(0.0).min(1.0);
|
||||
}
|
||||
|
||||
/// Convert to JSON
|
||||
#[wasm_bindgen(js_name = toJson)]
|
||||
pub fn to_json(&self) -> Result<String, JsValue> {
|
||||
serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Create from JSON
|
||||
#[wasm_bindgen(js_name = fromJson)]
|
||||
pub fn from_json(json: &str) -> Result<SonaConfigWasm, JsValue> {
|
||||
serde_json::from_str(json).map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SonaConfigWasm {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Pattern Storage
|
||||
// ============================================================================
|
||||
|
||||
/// Pattern stored in buffer
///
/// One recorded outcome in the engine's circular pattern buffer. The
/// `quality` value is derived from the engine-wide quality EMA at record
/// time, not measured per-pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Pattern {
    /// Pattern embedding (arbitrary length; compared via cosine similarity)
    embedding: Vec<f32>,
    /// Success/failure outcome reported by the caller
    success: bool,
    /// Quality score (quality EMA for successes, 1.0 - EMA for failures)
    quality: f32,
    /// Timestamp (monotonic counter for WASM, not wall-clock time)
    timestamp: u64,
}
|
||||
|
||||
// ============================================================================
|
||||
// Adaptation Result
|
||||
// ============================================================================
|
||||
|
||||
/// Result of instant adaptation
///
/// Returned by `SonaInstantWasm::instant_adapt`. Fields are skipped for
/// direct JS access and exposed through getters instead.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SonaAdaptResultWasm {
    /// Whether adaptation was applied (false when quality was below threshold)
    #[wasm_bindgen(skip)]
    pub applied: bool,
    /// Latency in microseconds (derived from a ms-resolution browser timer)
    #[wasm_bindgen(skip)]
    pub latency_us: u64,
    /// Estimated quality improvement (change in quality EMA)
    #[wasm_bindgen(skip)]
    pub quality_delta: f32,
    /// New quality EMA after this adaptation
    #[wasm_bindgen(skip)]
    pub quality_ema: f32,
    /// Current effective LoRA rank after any adjustment
    #[wasm_bindgen(skip)]
    pub current_rank: usize,
}

#[wasm_bindgen]
impl SonaAdaptResultWasm {
    /// Get applied status
    #[wasm_bindgen(getter)]
    pub fn applied(&self) -> bool {
        self.applied
    }

    /// Get latency in microseconds
    #[wasm_bindgen(getter, js_name = latencyUs)]
    pub fn latency_us(&self) -> u64 {
        self.latency_us
    }

    /// Get quality delta
    #[wasm_bindgen(getter, js_name = qualityDelta)]
    pub fn quality_delta(&self) -> f32 {
        self.quality_delta
    }

    /// Get quality EMA
    #[wasm_bindgen(getter, js_name = qualityEma)]
    pub fn quality_ema(&self) -> f32 {
        self.quality_ema
    }

    /// Get current rank
    #[wasm_bindgen(getter, js_name = currentRank)]
    pub fn current_rank(&self) -> usize {
        self.current_rank
    }

    /// Convert to JSON
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Statistics
|
||||
// ============================================================================
|
||||
|
||||
/// Learning statistics
///
/// Snapshot produced by `SonaInstantWasm::stats`. All counters are
/// cumulative since creation or the last `reset`.
#[wasm_bindgen]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SonaStatsWasm {
    /// Total adaptations performed (only those actually applied)
    #[wasm_bindgen(skip)]
    pub adaptations: u64,
    /// Average quality score (EMA)
    #[wasm_bindgen(skip)]
    pub avg_quality: f32,
    /// Total patterns recorded (including those evicted from the buffer)
    #[wasm_bindgen(skip)]
    pub patterns_recorded: u64,
    /// Successful patterns
    #[wasm_bindgen(skip)]
    pub successful_patterns: u64,
    /// Current pattern buffer size (capped at the configured capacity)
    #[wasm_bindgen(skip)]
    pub buffer_size: usize,
    /// Average latency (microseconds) across applied adaptations
    #[wasm_bindgen(skip)]
    pub avg_latency_us: f32,
    /// Current effective LoRA rank
    #[wasm_bindgen(skip)]
    pub current_rank: usize,
}

#[wasm_bindgen]
impl SonaStatsWasm {
    /// Get adaptations count
    #[wasm_bindgen(getter)]
    pub fn adaptations(&self) -> u64 {
        self.adaptations
    }

    /// Get average quality
    #[wasm_bindgen(getter, js_name = avgQuality)]
    pub fn avg_quality(&self) -> f32 {
        self.avg_quality
    }

    /// Get patterns recorded
    #[wasm_bindgen(getter, js_name = patternsRecorded)]
    pub fn patterns_recorded(&self) -> u64 {
        self.patterns_recorded
    }

    /// Get successful patterns
    #[wasm_bindgen(getter, js_name = successfulPatterns)]
    pub fn successful_patterns(&self) -> u64 {
        self.successful_patterns
    }

    /// Get buffer size
    #[wasm_bindgen(getter, js_name = bufferSize)]
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get average latency
    #[wasm_bindgen(getter, js_name = avgLatencyUs)]
    pub fn avg_latency_us(&self) -> f32 {
        self.avg_latency_us
    }

    /// Get current rank
    #[wasm_bindgen(getter, js_name = currentRank)]
    pub fn current_rank(&self) -> usize {
        self.current_rank
    }

    /// Success rate (successful / recorded; 0.0 when nothing recorded yet)
    #[wasm_bindgen(js_name = successRate)]
    pub fn success_rate(&self) -> f32 {
        if self.patterns_recorded == 0 {
            0.0
        } else {
            self.successful_patterns as f32 / self.patterns_recorded as f32
        }
    }

    /// Convert to JSON
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        serde_json::to_string(self).map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Main SONA Engine
|
||||
// ============================================================================
|
||||
|
||||
/// SONA Instant Loop for WASM
///
/// Lightweight online-learning engine: tracks a quality EMA, adapts an
/// effective micro-LoRA rank, and keeps a bounded circular buffer of
/// outcome patterns for similarity-based action suggestion.
#[wasm_bindgen]
pub struct SonaInstantWasm {
    /// Configuration
    config: SonaConfigWasm,
    /// Pattern buffer (circular buffer; oldest evicted at capacity)
    patterns: VecDeque<Pattern>,
    /// Quality EMA (starts neutral at 0.5)
    quality_ema: f32,
    /// Total adaptations actually applied
    adaptations: u64,
    /// Total latency accumulator (for averaging)
    latency_sum: u64,
    /// Patterns recorded (cumulative, including evicted ones)
    patterns_recorded: u64,
    /// Successful patterns (cumulative)
    successful_patterns: u64,
    /// Timestamp counter (monotonic for WASM; not wall-clock)
    timestamp: u64,
    /// EWC-lite: Important weight indices (bounded at 100 entries)
    important_weights: Vec<usize>,
    /// Current effective rank (adjusted between 1 and 4 at runtime)
    current_rank: usize,
}

#[wasm_bindgen]
impl SonaInstantWasm {
    /// Create new SONA instant loop from a configuration.
    #[wasm_bindgen(constructor)]
    pub fn new(config: SonaConfigWasm) -> Self {
        // Rank starts at the configured micro-LoRA rank and adapts from there.
        let current_rank = config.micro_lora_rank;
        Self {
            patterns: VecDeque::with_capacity(config.pattern_capacity),
            quality_ema: 0.5, // Start neutral
            adaptations: 0,
            latency_sum: 0,
            patterns_recorded: 0,
            successful_patterns: 0,
            timestamp: 0,
            important_weights: Vec::new(),
            current_rank,
            config,
        }
    }

    /// Instant adaptation based on quality signal
    ///
    /// Updates the quality EMA, adjusts the effective rank, and tracks
    /// EWC-lite important weights. Returns whether the update was applied
    /// plus timing/quality metrics.
    ///
    /// Target: <1ms latency
    #[wasm_bindgen(js_name = instantAdapt)]
    pub fn instant_adapt(&mut self, quality: f32) -> SonaAdaptResultWasm {
        // now_ms() is millisecond-resolution; latency_us will often round
        // down to 0 for sub-ms calls (and is 0 outside a browser).
        let start = crate::utils::now_ms();

        // Skip if quality below threshold
        if quality < self.config.quality_threshold {
            return SonaAdaptResultWasm {
                applied: false,
                latency_us: ((crate::utils::now_ms() - start) * 1000.0) as u64,
                quality_delta: 0.0,
                quality_ema: self.quality_ema,
                current_rank: self.current_rank,
            };
        }

        // Update quality EMA
        let prev_quality = self.quality_ema;
        self.quality_ema =
            self.config.ema_decay * self.quality_ema + (1.0 - self.config.ema_decay) * quality;

        // Adaptive rank adjustment (simple heuristic)
        // Increase rank if quality improving, decrease if degrading.
        // NOTE(review): this compares the RAW signal to the previous EMA,
        // while the returned quality_delta below is the EMA change —
        // asymmetry looks intentional (raw delta reacts faster) but confirm.
        let quality_delta = quality - prev_quality;
        if quality_delta > 0.1 && self.current_rank < 4 {
            self.current_rank += 1;
        } else if quality_delta < -0.1 && self.current_rank > 1 {
            self.current_rank -= 1;
        }

        // EWC-lite: Track important features (top 10% by quality contribution)
        // Simplified: just mark indices that correlate with high quality
        if quality > 0.7 && self.important_weights.len() < 100 {
            let weight_idx =
                (quality * self.config.hidden_dim as f32) as usize % self.config.hidden_dim;
            if !self.important_weights.contains(&weight_idx) {
                self.important_weights.push(weight_idx);
            }
        }

        // Update metrics
        self.adaptations += 1;
        let latency_us = ((crate::utils::now_ms() - start) * 1000.0) as u64;
        self.latency_sum += latency_us;

        SonaAdaptResultWasm {
            applied: true,
            latency_us,
            quality_delta: self.quality_ema - prev_quality,
            quality_ema: self.quality_ema,
            current_rank: self.current_rank,
        }
    }

    /// Record a pattern outcome for future reference
    ///
    /// The pattern's quality is taken from the current EMA (inverted for
    /// failures); the embedding is copied into the circular buffer.
    #[wasm_bindgen(js_name = recordPattern)]
    pub fn record_pattern(&mut self, embedding: &[f32], success: bool) {
        let pattern = Pattern {
            embedding: embedding.to_vec(),
            success,
            quality: if success {
                self.quality_ema
            } else {
                1.0 - self.quality_ema
            },
            timestamp: self.timestamp,
        };

        self.timestamp += 1;
        self.patterns_recorded += 1;
        if success {
            self.successful_patterns += 1;
        }

        // Circular buffer: drop oldest if at capacity
        if self.patterns.len() >= self.config.pattern_capacity {
            self.patterns.pop_front();
        }

        self.patterns.push_back(pattern);
    }

    /// Suggest action based on learned patterns
    ///
    /// Linear scan over successful patterns only; O(buffer_size * dim).
    /// Uses simple cosine similarity search (HNSW integration point for future)
    #[wasm_bindgen(js_name = suggestAction)]
    pub fn suggest_action(&self, context: &[f32]) -> Option<String> {
        if self.patterns.is_empty() {
            return None;
        }

        // Find most similar successful pattern
        let mut best_similarity = -1.0;
        let mut best_pattern: Option<&Pattern> = None;

        for pattern in &self.patterns {
            if !pattern.success {
                continue;
            }

            let similarity = cosine_similarity(context, &pattern.embedding);
            if similarity > best_similarity {
                best_similarity = similarity;
                best_pattern = Some(pattern);
            }
        }

        // Threshold: only suggest if similarity > 0.7
        if best_similarity > 0.7 {
            best_pattern.map(|p| format!("apply_pattern_quality_{:.2}", p.quality))
        } else {
            None
        }
    }

    /// Get current statistics snapshot.
    #[wasm_bindgen]
    pub fn stats(&self) -> SonaStatsWasm {
        SonaStatsWasm {
            adaptations: self.adaptations,
            avg_quality: self.quality_ema,
            patterns_recorded: self.patterns_recorded,
            successful_patterns: self.successful_patterns,
            buffer_size: self.patterns.len(),
            avg_latency_us: if self.adaptations > 0 {
                self.latency_sum as f32 / self.adaptations as f32
            } else {
                0.0
            },
            current_rank: self.current_rank,
        }
    }

    /// Export state to JSON
    ///
    /// Exports counters and config only — the pattern buffer itself is not
    /// serialized (only its size), matching the partial `from_json` below.
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        #[derive(Serialize)]
        struct Export {
            config: SonaConfigWasm,
            quality_ema: f32,
            adaptations: u64,
            patterns_recorded: u64,
            successful_patterns: u64,
            current_rank: usize,
            buffer_size: usize,
        }

        let export = Export {
            config: self.config.clone(),
            quality_ema: self.quality_ema,
            adaptations: self.adaptations,
            patterns_recorded: self.patterns_recorded,
            successful_patterns: self.successful_patterns,
            current_rank: self.current_rank,
            buffer_size: self.patterns.len(),
        };

        serde_json::to_string(&export).map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Import state from JSON (partial - doesn't restore patterns)
    ///
    /// Restores counters and config; the pattern buffer, timestamp counter,
    /// latency accumulator and important-weight set start empty/zero.
    #[wasm_bindgen(js_name = fromJson)]
    pub fn from_json(json: &str) -> Result<SonaInstantWasm, JsValue> {
        #[derive(Deserialize)]
        struct Import {
            config: SonaConfigWasm,
            quality_ema: f32,
            adaptations: u64,
            patterns_recorded: u64,
            successful_patterns: u64,
            current_rank: usize,
        }

        let import: Import =
            serde_json::from_str(json).map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(Self {
            config: import.config.clone(),
            patterns: VecDeque::with_capacity(import.config.pattern_capacity),
            quality_ema: import.quality_ema,
            adaptations: import.adaptations,
            latency_sum: 0,
            patterns_recorded: import.patterns_recorded,
            successful_patterns: import.successful_patterns,
            timestamp: 0,
            important_weights: Vec::new(),
            current_rank: import.current_rank,
        })
    }

    /// Reset all learning state (config is kept; rank returns to configured value).
    #[wasm_bindgen]
    pub fn reset(&mut self) {
        self.patterns.clear();
        self.quality_ema = 0.5;
        self.adaptations = 0;
        self.latency_sum = 0;
        self.patterns_recorded = 0;
        self.successful_patterns = 0;
        self.timestamp = 0;
        self.important_weights.clear();
        self.current_rank = self.config.micro_lora_rank;
    }

    /// Get number of important weights tracked (EWC-lite)
    #[wasm_bindgen(js_name = importantWeightCount)]
    pub fn important_weight_count(&self) -> usize {
        self.important_weights.len()
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths or when either vector has zero
/// norm, so callers can treat "no signal" and "orthogonal" uniformly.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate dot product and both squared norms in a single pass.
    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (&x, &y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    if norm_a <= 0.0 || norm_b <= 0.0 {
        0.0
    } else {
        dot / (norm_a.sqrt() * norm_b.sqrt())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the SONA instant loop: config clamping, EMA
    //! tracking, adaptive rank, circular-buffer eviction, similarity
    //! suggestions, reset, and JSON round-tripping.
    use super::*;

    #[test]
    fn test_config_defaults() {
        let config = SonaConfigWasm::new();
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.micro_lora_rank, 1);
        assert!((config.learning_rate - 0.01).abs() < 0.001);
    }

    #[test]
    fn test_config_setters() {
        let mut config = SonaConfigWasm::new();
        config.set_learning_rate(0.05);
        assert!((config.learning_rate() - 0.05).abs() < 0.001);

        config.set_micro_lora_rank(2);
        assert_eq!(config.micro_lora_rank(), 2);
    }

    #[test]
    fn test_sona_creation() {
        let config = SonaConfigWasm::new();
        let sona = SonaInstantWasm::new(config);
        let stats = sona.stats();
        assert_eq!(stats.adaptations, 0);
        assert_eq!(stats.buffer_size, 0);
    }

    #[test]
    fn test_instant_adapt() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);

        // Low quality (below the 0.5 default threshold) - should skip
        let result = sona.instant_adapt(0.3);
        assert!(!result.applied);

        // High quality - should apply
        let result = sona.instant_adapt(0.8);
        assert!(result.applied);
        assert!(result.quality_ema > 0.5);
        assert!(result.latency_us < 10000); // Should be < 10ms (way below 1ms in practice)
    }

    #[test]
    fn test_pattern_recording() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);

        let embedding = vec![0.1, 0.2, 0.3, 0.4];
        sona.record_pattern(&embedding, true);

        let stats = sona.stats();
        assert_eq!(stats.patterns_recorded, 1);
        assert_eq!(stats.successful_patterns, 1);
        assert_eq!(stats.buffer_size, 1);
    }

    #[test]
    fn test_pattern_buffer_overflow() {
        let mut config = SonaConfigWasm::new();
        config.set_pattern_capacity(5);
        let mut sona = SonaInstantWasm::new(config);

        // Add more patterns than capacity
        for i in 0..10 {
            let embedding = vec![i as f32, i as f32 + 0.1];
            sona.record_pattern(&embedding, true);
        }

        let stats = sona.stats();
        assert_eq!(stats.buffer_size, 5); // Should be capped at capacity
        assert_eq!(stats.patterns_recorded, 10); // Total recorded
    }

    #[test]
    fn test_suggest_action() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);

        // Record a successful pattern
        let embedding = vec![0.5; 10];
        sona.instant_adapt(0.9); // Set high quality
        sona.record_pattern(&embedding, true);

        // Query with similar context (cosine similarity near 1.0)
        let similar = vec![0.51; 10];
        let suggestion = sona.suggest_action(&similar);
        assert!(suggestion.is_some());

        // Query with dissimilar context (negative similarity, below 0.7 cutoff)
        let dissimilar = vec![-0.5; 10];
        let suggestion = sona.suggest_action(&dissimilar);
        assert!(suggestion.is_none());
    }

    #[test]
    fn test_quality_ema_tracking() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);

        // Feed increasing quality signals
        for i in 1..=10 {
            let quality = 0.5 + (i as f32 * 0.03);
            sona.instant_adapt(quality);
        }

        let stats = sona.stats();
        assert!(stats.avg_quality > 0.5); // EMA should have increased
        assert!(stats.avg_quality < 1.0);
    }

    #[test]
    fn test_adaptive_rank() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);
        assert_eq!(sona.current_rank, 1);

        // Improve quality - should increase rank (raw signal jumps > 0.1 above EMA)
        sona.instant_adapt(0.5);
        sona.instant_adapt(0.7); // Big jump
        assert_eq!(sona.current_rank, 2);

        // Degrade quality - should decrease rank
        sona.instant_adapt(0.3);
        assert_eq!(sona.current_rank, 1);
    }

    #[test]
    fn test_reset() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);

        // Add state
        sona.instant_adapt(0.8);
        sona.record_pattern(&[0.1, 0.2], true);

        // Reset
        sona.reset();

        let stats = sona.stats();
        assert_eq!(stats.adaptations, 0);
        assert_eq!(stats.patterns_recorded, 0);
        assert_eq!(stats.buffer_size, 0);
        assert!((stats.avg_quality - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 0.001);

        let e = vec![1.0, 1.0, 0.0];
        let f = vec![1.0, 1.0, 0.0];
        assert!((cosine_similarity(&e, &f) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_serialization() {
        let config = SonaConfigWasm::new();
        let mut sona = SonaInstantWasm::new(config);

        sona.instant_adapt(0.8);
        sona.record_pattern(&[0.1, 0.2], true);

        let json = sona.to_json().unwrap();
        assert!(json.contains("quality_ema"));
        assert!(json.contains("adaptations"));

        // Should be able to deserialize config
        let config_json = sona.config.to_json().unwrap();
        let restored_config = SonaConfigWasm::from_json(&config_json).unwrap();
        assert_eq!(restored_config.hidden_dim, sona.config.hidden_dim);
    }
}
|
||||
142
vendor/ruvector/crates/ruvllm-wasm/src/utils.rs
vendored
Normal file
142
vendor/ruvector/crates/ruvllm-wasm/src/utils.rs
vendored
Normal file
@@ -0,0 +1,142 @@
|
||||
//! Utility functions for WASM environment
|
||||
//!
|
||||
//! Provides helper functions for panic handling, logging, and
|
||||
//! JavaScript interop utilities.
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Set panic hook for better error messages in the browser console.
///
/// This function should be called once at initialization to enable
/// better panic messages in the browser's developer console. A no-op
/// when the `console_error_panic_hook` feature is disabled.
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm_wasm::utils::set_panic_hook;
///
/// // Call at app startup
/// set_panic_hook();
/// ```
pub fn set_panic_hook() {
    // When the `console_error_panic_hook` feature is enabled, we can call the
    // `set_panic_hook` function at least once during initialization, and then
    // we will get better error messages if our code ever panics.
    //
    // For more details see
    // https://github.com/rustwasm/console_error_panic_hook#readme
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}

/// Log a message to the browser console (forwards to `console.log`).
///
/// # Arguments
///
/// * `message` - The message to log
#[wasm_bindgen]
pub fn log(message: &str) {
    web_sys::console::log_1(&message.into());
}

/// Log a warning to the browser console (forwards to `console.warn`).
///
/// # Arguments
///
/// * `message` - The warning message
#[wasm_bindgen]
pub fn warn(message: &str) {
    web_sys::console::warn_1(&message.into());
}

/// Log an error to the browser console (forwards to `console.error`).
///
/// # Arguments
///
/// * `message` - The error message
#[wasm_bindgen]
pub fn error(message: &str) {
    web_sys::console::error_1(&message.into());
}
|
||||
|
||||
/// Get current timestamp in milliseconds using Performance API.
|
||||
///
|
||||
/// Returns high-resolution timestamp for performance measurements.
|
||||
#[wasm_bindgen]
|
||||
pub fn now_ms() -> f64 {
|
||||
web_sys::window()
|
||||
.and_then(|w| w.performance())
|
||||
.map(|p| p.now())
|
||||
.unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Simple timer for measuring elapsed time in WASM.
|
||||
#[wasm_bindgen]
|
||||
pub struct Timer {
|
||||
start: f64,
|
||||
label: String,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl Timer {
|
||||
/// Create a new timer with the given label.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `label` - A descriptive label for the timer
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(label: &str) -> Timer {
|
||||
Timer {
|
||||
start: now_ms(),
|
||||
label: label.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get elapsed time in milliseconds.
|
||||
#[wasm_bindgen]
|
||||
pub fn elapsed_ms(&self) -> f64 {
|
||||
now_ms() - self.start
|
||||
}
|
||||
|
||||
/// Log elapsed time to console and return the duration.
|
||||
#[wasm_bindgen]
|
||||
pub fn stop(&self) -> f64 {
|
||||
let elapsed = self.elapsed_ms();
|
||||
log(&format!("{}: {:.2}ms", self.label, elapsed));
|
||||
elapsed
|
||||
}
|
||||
|
||||
/// Reset the timer.
|
||||
#[wasm_bindgen]
|
||||
pub fn reset(&mut self) {
|
||||
self.start = now_ms();
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a Rust Result to a JavaScript-friendly format.
|
||||
///
|
||||
/// On success, returns the value. On error, throws a JavaScript exception.
|
||||
pub fn result_to_js<T, E: std::fmt::Display>(result: Result<T, E>) -> Result<T, JsValue> {
|
||||
result.map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Smoke tests: console/Performance helpers can't run off-browser,
    //! so only the panic-hook setup is exercised here.
    use super::*;

    // set_panic_hook requires console_error_panic_hook which only works on wasm32
    #[cfg(target_arch = "wasm32")]
    #[test]
    fn test_set_panic_hook() {
        // Should not panic
        set_panic_hook();
    }

    // Non-wasm32 version just verifies the function exists
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_set_panic_hook_noop() {
        // On non-wasm32, this is a no-op
        set_panic_hook();
    }
}
|
||||
469
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/buffers.rs
vendored
Normal file
469
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/buffers.rs
vendored
Normal file
@@ -0,0 +1,469 @@
|
||||
//! GPU Buffer Management for WebGPU WASM
|
||||
//!
|
||||
//! This module provides buffer abstractions for GPU memory management
|
||||
//! in the browser WebGPU environment.
|
||||
|
||||
use js_sys::{Float32Array, Uint8Array};
|
||||
use std::cell::RefCell;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Buffer usage flags
///
/// Rust-side mirror of WebGPU's `GPUBufferUsage` bitflags, kept as bools
/// for ergonomic JS access; `to_u32` produces the raw flag word. Note the
/// JS naming split: `mapRead`/`mapWrite`/`copySrc`/`copyDst` are direct,
/// while the storage/uniform accessors are exposed as `isStorage`/`isUniform`
/// to avoid colliding with the `storage()`/`uniform()` factory names.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuBufferUsage {
    /// Can be mapped for reading
    #[wasm_bindgen(skip)]
    pub map_read: bool,
    /// Can be mapped for writing
    #[wasm_bindgen(skip)]
    pub map_write: bool,
    /// Can be used as copy source
    #[wasm_bindgen(skip)]
    pub copy_src: bool,
    /// Can be used as copy destination
    #[wasm_bindgen(skip)]
    pub copy_dst: bool,
    /// Can be used as storage buffer
    #[wasm_bindgen(skip)]
    pub storage: bool,
    /// Can be used as uniform buffer
    #[wasm_bindgen(skip)]
    pub uniform: bool,
}

#[wasm_bindgen]
impl GpuBufferUsage {
    /// Create storage buffer usage (read/write compute; copyable both ways)
    #[wasm_bindgen(js_name = storage)]
    pub fn new_storage() -> Self {
        Self {
            storage: true,
            copy_dst: true,
            copy_src: true,
            ..Default::default()
        }
    }

    /// Create uniform buffer usage (upload-only destination)
    #[wasm_bindgen(js_name = uniform)]
    pub fn new_uniform() -> Self {
        Self {
            uniform: true,
            copy_dst: true,
            ..Default::default()
        }
    }

    /// Create staging buffer for upload (CPU-writable copy source)
    #[wasm_bindgen(js_name = stagingUpload)]
    pub fn staging_upload() -> Self {
        Self {
            map_write: true,
            copy_src: true,
            ..Default::default()
        }
    }

    /// Create staging buffer for download (CPU-readable copy destination)
    #[wasm_bindgen(js_name = stagingDownload)]
    pub fn staging_download() -> Self {
        Self {
            map_read: true,
            copy_dst: true,
            ..Default::default()
        }
    }

    /// Create read-only storage buffer (shader-readable, upload-only)
    #[wasm_bindgen(js_name = storageReadOnly)]
    pub fn storage_read_only() -> Self {
        Self {
            storage: true,
            copy_dst: true,
            ..Default::default()
        }
    }

    /// Convert to WebGPU usage flags (as raw u32)
    ///
    /// WebGPU buffer usage flags:
    /// - MAP_READ = 0x0001
    /// - MAP_WRITE = 0x0002
    /// - COPY_SRC = 0x0004
    /// - COPY_DST = 0x0008
    /// - INDEX = 0x0010
    /// - VERTEX = 0x0020
    /// - UNIFORM = 0x0040
    /// - STORAGE = 0x0080
    /// - INDIRECT = 0x0100
    /// - QUERY_RESOLVE = 0x0200
    ///
    /// Only the six flags this struct models are emitted; INDEX/VERTEX/
    /// INDIRECT/QUERY_RESOLVE are never set.
    pub fn to_u32(&self) -> u32 {
        let mut flags = 0u32;
        if self.map_read {
            flags |= 0x0001;
        }
        if self.map_write {
            flags |= 0x0002;
        }
        if self.copy_src {
            flags |= 0x0004;
        }
        if self.copy_dst {
            flags |= 0x0008;
        }
        if self.uniform {
            flags |= 0x0040;
        }
        if self.storage {
            flags |= 0x0080;
        }
        flags
    }

    /// Getter for `map_read` (JS: `mapRead`)
    #[wasm_bindgen(getter, js_name = mapRead)]
    pub fn get_map_read(&self) -> bool {
        self.map_read
    }

    /// Setter for `map_read` (JS: `mapRead`)
    #[wasm_bindgen(setter, js_name = mapRead)]
    pub fn set_map_read(&mut self, value: bool) {
        self.map_read = value;
    }

    /// Getter for `map_write` (JS: `mapWrite`)
    #[wasm_bindgen(getter, js_name = mapWrite)]
    pub fn get_map_write(&self) -> bool {
        self.map_write
    }

    /// Setter for `map_write` (JS: `mapWrite`)
    #[wasm_bindgen(setter, js_name = mapWrite)]
    pub fn set_map_write(&mut self, value: bool) {
        self.map_write = value;
    }

    /// Getter for `copy_src` (JS: `copySrc`)
    #[wasm_bindgen(getter, js_name = copySrc)]
    pub fn get_copy_src(&self) -> bool {
        self.copy_src
    }

    /// Setter for `copy_src` (JS: `copySrc`)
    #[wasm_bindgen(setter, js_name = copySrc)]
    pub fn set_copy_src(&mut self, value: bool) {
        self.copy_src = value;
    }

    /// Getter for `copy_dst` (JS: `copyDst`)
    #[wasm_bindgen(getter, js_name = copyDst)]
    pub fn get_copy_dst(&self) -> bool {
        self.copy_dst
    }

    /// Setter for `copy_dst` (JS: `copyDst`)
    #[wasm_bindgen(setter, js_name = copyDst)]
    pub fn set_copy_dst(&mut self, value: bool) {
        self.copy_dst = value;
    }

    /// Getter for `storage` (JS: `isStorage`, to avoid the `storage()` factory)
    #[wasm_bindgen(getter, js_name = isStorage)]
    pub fn get_storage(&self) -> bool {
        self.storage
    }

    /// Setter for `storage` (JS: `isStorage`)
    #[wasm_bindgen(setter, js_name = isStorage)]
    pub fn set_storage(&mut self, value: bool) {
        self.storage = value;
    }

    /// Getter for `uniform` (JS: `isUniform`, to avoid the `uniform()` factory)
    #[wasm_bindgen(getter, js_name = isUniform)]
    pub fn get_uniform(&self) -> bool {
        self.uniform
    }

    /// Setter for `uniform` (JS: `isUniform`)
    #[wasm_bindgen(setter, js_name = isUniform)]
    pub fn set_uniform(&mut self, value: bool) {
        self.uniform = value;
    }
}
|
||||
|
||||
/// GPU buffer handle
///
/// Wraps a WebGPU buffer with metadata for safe operations.
/// On wasm32 the handle is a real `web_sys::GpuBuffer`; on native builds a
/// host-side `Vec<u8>` stands in so the type still compiles and can be
/// exercised in tests.
#[wasm_bindgen]
pub struct GpuBuffer {
    /// Internal buffer handle (web_sys::GpuBuffer when on wasm32)
    #[cfg(target_arch = "wasm32")]
    buffer: web_sys::GpuBuffer,

    /// Placeholder for non-wasm32 builds (host-side byte storage)
    #[cfg(not(target_arch = "wasm32"))]
    buffer: Vec<u8>,

    /// Buffer size in bytes
    size: usize,

    /// Buffer usage flags
    usage: GpuBufferUsage,

    /// Optional label for debugging
    label: Option<String>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl GpuBuffer {
|
||||
/// Get buffer size in bytes
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
/// Get buffer label
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn label(&self) -> Option<String> {
|
||||
self.label.clone()
|
||||
}
|
||||
|
||||
/// Check if buffer supports mapping for read
|
||||
#[wasm_bindgen(getter, js_name = canMapRead)]
|
||||
pub fn can_map_read(&self) -> bool {
|
||||
self.usage.map_read
|
||||
}
|
||||
|
||||
/// Check if buffer supports mapping for write
|
||||
#[wasm_bindgen(getter, js_name = canMapWrite)]
|
||||
pub fn can_map_write(&self) -> bool {
|
||||
self.usage.map_write
|
||||
}
|
||||
|
||||
/// Get size as number of f32 elements
|
||||
#[wasm_bindgen(js_name = sizeAsF32)]
|
||||
pub fn size_as_f32(&self) -> usize {
|
||||
self.size / 4
|
||||
}
|
||||
|
||||
/// Get the raw web_sys buffer (for advanced usage)
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[wasm_bindgen(getter, js_name = rawBuffer)]
|
||||
pub fn raw_buffer(&self) -> web_sys::GpuBuffer {
|
||||
self.buffer.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl GpuBuffer {
    /// Create a new GPU buffer (internal constructor)
    ///
    /// `size` is trusted to match the size the `web_sys` buffer was
    /// allocated with; it is not re-validated here.
    #[cfg(target_arch = "wasm32")]
    pub(crate) fn new(
        buffer: web_sys::GpuBuffer,
        size: usize,
        usage: GpuBufferUsage,
        label: Option<String>,
    ) -> Self {
        Self {
            buffer,
            size,
            usage,
            label,
        }
    }

    /// Create a new GPU buffer (non-wasm32 placeholder)
    ///
    /// Backs the buffer with a zeroed host `Vec<u8>` so native builds and
    /// tests can exercise the API without a GPU.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) fn new(size: usize, usage: GpuBufferUsage, label: Option<String>) -> Self {
        Self {
            buffer: vec![0u8; size],
            size,
            usage,
            label,
        }
    }

    /// Get internal buffer reference
    #[cfg(target_arch = "wasm32")]
    pub(crate) fn inner(&self) -> &web_sys::GpuBuffer {
        &self.buffer
    }
}
|
||||
|
||||
/// Staging buffer pool for efficient CPU<->GPU transfers
///
/// Uses interior mutability (`RefCell`) so wasm_bindgen can expose `&self`
/// methods that still mutate the pools; single-threaded use only.
#[wasm_bindgen]
pub struct StagingBufferPool {
    /// Pool of upload staging buffers
    upload_pool: RefCell<Vec<GpuBuffer>>,
    /// Pool of download staging buffers
    download_pool: RefCell<Vec<GpuBuffer>>,
    /// Maximum buffers per pool
    max_per_pool: usize,
    /// Total bytes allocated
    total_allocated: RefCell<usize>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl StagingBufferPool {
|
||||
/// Create a new staging buffer pool
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(max_per_pool: usize) -> Self {
|
||||
Self {
|
||||
upload_pool: RefCell::new(Vec::with_capacity(max_per_pool)),
|
||||
download_pool: RefCell::new(Vec::with_capacity(max_per_pool)),
|
||||
max_per_pool,
|
||||
total_allocated: RefCell::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of upload buffers in pool
|
||||
#[wasm_bindgen(getter, js_name = uploadBufferCount)]
|
||||
pub fn upload_buffer_count(&self) -> usize {
|
||||
self.upload_pool.borrow().len()
|
||||
}
|
||||
|
||||
/// Get the number of download buffers in pool
|
||||
#[wasm_bindgen(getter, js_name = downloadBufferCount)]
|
||||
pub fn download_buffer_count(&self) -> usize {
|
||||
self.download_pool.borrow().len()
|
||||
}
|
||||
|
||||
/// Get total bytes allocated
|
||||
#[wasm_bindgen(getter, js_name = totalAllocated)]
|
||||
pub fn total_allocated(&self) -> usize {
|
||||
*self.total_allocated.borrow()
|
||||
}
|
||||
|
||||
/// Clear all pooled buffers
|
||||
#[wasm_bindgen]
|
||||
pub fn clear(&self) {
|
||||
self.upload_pool.borrow_mut().clear();
|
||||
self.download_pool.borrow_mut().clear();
|
||||
*self.total_allocated.borrow_mut() = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor descriptor for buffer allocation
///
/// Lightweight shape + dtype metadata used to size GPU buffers.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct TensorDescriptor {
    /// Shape dimensions
    shape: Vec<u32>,
    /// Data type (0=f32, 1=f16, 2=i32, 3=u8)
    dtype: u8,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl TensorDescriptor {
|
||||
/// Create tensor descriptor for a matrix
|
||||
#[wasm_bindgen(js_name = matrix)]
|
||||
pub fn matrix(rows: u32, cols: u32) -> Self {
|
||||
Self {
|
||||
shape: vec![rows, cols],
|
||||
dtype: 0, // f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Create tensor descriptor for a vector
|
||||
#[wasm_bindgen(js_name = vector)]
|
||||
pub fn vector(len: u32) -> Self {
|
||||
Self {
|
||||
shape: vec![len],
|
||||
dtype: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create tensor descriptor with arbitrary shape
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(shape: Vec<u32>, dtype: u8) -> Self {
|
||||
Self { shape, dtype }
|
||||
}
|
||||
|
||||
/// Get total number of elements
|
||||
#[wasm_bindgen(js_name = numElements)]
|
||||
pub fn num_elements(&self) -> usize {
|
||||
self.shape.iter().map(|&d| d as usize).product()
|
||||
}
|
||||
|
||||
/// Get size in bytes
|
||||
#[wasm_bindgen(js_name = sizeBytes)]
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
let element_size = match self.dtype {
|
||||
0 => 4, // f32
|
||||
1 => 2, // f16
|
||||
2 => 4, // i32
|
||||
3 => 1, // u8
|
||||
_ => 4, // default to f32
|
||||
};
|
||||
self.num_elements() * element_size
|
||||
}
|
||||
|
||||
/// Get shape dimensions
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn shape(&self) -> Vec<u32> {
|
||||
self.shape.clone()
|
||||
}
|
||||
|
||||
/// Get data type
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dtype(&self) -> u8 {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// Get number of dimensions
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper functions for creating typed arrays from GPU buffers
///
/// Stateless namespace type; all methods are associated functions.
#[wasm_bindgen]
pub struct BufferHelpers;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BufferHelpers {
|
||||
/// Create a Float32Array view from a Uint8Array
|
||||
#[wasm_bindgen(js_name = asFloat32Array)]
|
||||
pub fn as_float32_array(data: &Uint8Array) -> Float32Array {
|
||||
Float32Array::new(&data.buffer())
|
||||
}
|
||||
|
||||
/// Calculate aligned size for GPU buffers (must be multiple of 4)
|
||||
#[wasm_bindgen(js_name = alignedSize)]
|
||||
pub fn aligned_size(size: usize) -> usize {
|
||||
(size + 3) & !3
|
||||
}
|
||||
|
||||
/// Calculate workgroup count for a given dimension
|
||||
#[wasm_bindgen(js_name = workgroupCount)]
|
||||
pub fn workgroup_count(total: u32, workgroup_size: u32) -> u32 {
|
||||
(total + workgroup_size - 1) / workgroup_size
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Storage usage preset enables STORAGE plus both copy directions only.
    #[test]
    fn test_buffer_usage() {
        let storage = GpuBufferUsage::new_storage();
        assert!(storage.storage);
        assert!(storage.copy_dst);
        assert!(storage.copy_src);
        assert!(!storage.uniform);
    }

    /// Shape/dtype arithmetic for matrix, vector, and non-f32 dtypes.
    #[test]
    fn test_tensor_descriptor() {
        let matrix = TensorDescriptor::matrix(1024, 768);
        assert_eq!(matrix.num_elements(), 1024 * 768);
        assert_eq!(matrix.size_bytes(), 1024 * 768 * 4);
        assert_eq!(matrix.ndim(), 2);

        let vector = TensorDescriptor::vector(128);
        assert_eq!(vector.num_elements(), 128);
        assert_eq!(vector.ndim(), 1);

        // f16 (dtype 1) is 2 bytes/elem, u8 (dtype 3) is 1 byte/elem.
        assert_eq!(TensorDescriptor::new(vec![10, 10], 1).size_bytes(), 200);
        assert_eq!(TensorDescriptor::new(vec![10, 10], 3).size_bytes(), 100);
    }

    /// 4-byte alignment rounds up; identity on already-aligned sizes.
    #[test]
    fn test_aligned_size() {
        assert_eq!(BufferHelpers::aligned_size(0), 0);
        assert_eq!(BufferHelpers::aligned_size(1), 4);
        assert_eq!(BufferHelpers::aligned_size(4), 4);
        assert_eq!(BufferHelpers::aligned_size(5), 8);
        assert_eq!(BufferHelpers::aligned_size(7), 8);
    }

    /// Ceiling division, including exact multiples and remainders.
    #[test]
    fn test_workgroup_count() {
        assert_eq!(BufferHelpers::workgroup_count(1000, 256), 4);
        assert_eq!(BufferHelpers::workgroup_count(256, 256), 1);
        assert_eq!(BufferHelpers::workgroup_count(257, 256), 2);
        assert_eq!(BufferHelpers::workgroup_count(1, 256), 1);
    }
}
|
||||
882
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/compute.rs
vendored
Normal file
882
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/compute.rs
vendored
Normal file
@@ -0,0 +1,882 @@
|
||||
//! WebGPU Compute Context and Pipelines
|
||||
//!
|
||||
//! This module provides the core WebGPU compute functionality for WASM,
|
||||
//! including context initialization, pipeline creation, and kernel execution.
|
||||
//!
|
||||
//! Note: WebGPU bindings use JavaScript interop via js_sys/Reflect since
|
||||
//! web-sys WebGPU bindings are still unstable.
|
||||
|
||||
use js_sys::{Array, Float32Array, Object, Promise, Reflect};
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_futures::JsFuture;
|
||||
|
||||
use super::{shaders, AdapterInfo, AttentionConfig};
|
||||
|
||||
/// Check if WebGPU is available in this browser
///
/// True only on wasm32 when `navigator.gpu` exists and is non-null.
pub async fn is_webgpu_available() -> bool {
    #[cfg(target_arch = "wasm32")]
    {
        // `get_gpu_object` already filters null/undefined; keep the
        // defensive re-check to mirror the original contract exactly.
        get_gpu_object().map_or(false, |gpu| !gpu.is_undefined() && !gpu.is_null())
    }

    #[cfg(not(target_arch = "wasm32"))]
    false
}
|
||||
|
||||
/// Get GPU adapter information if available
///
/// Requests a high-performance adapter from `navigator.gpu` and collects
/// vendor/architecture/device strings plus buffer and workgroup limits.
/// Any failure along the way yields `None`; on native builds this is
/// always `None`.
pub async fn get_gpu_info() -> Option<AdapterInfo> {
    #[cfg(target_arch = "wasm32")]
    {
        let gpu = get_gpu_object()?;

        // Request adapter
        let options = Object::new();
        // Best-effort: ignore a failed property set and request anyway.
        let _ = Reflect::set(
            &options,
            &"powerPreference".into(),
            &"high-performance".into(),
        );

        let adapter_promise = call_method(&gpu, "requestAdapter", &[options.into()]).ok()?;
        let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>().ok()?)
            .await
            .ok()?;

        // The promise resolves to null when no adapter is suitable.
        if adapter.is_null() || adapter.is_undefined() {
            return None;
        }

        // Get adapter info via requestAdapterInfo()
        // NOTE(review): requestAdapterInfo() was removed from the WebGPU
        // spec in favor of the `GPUAdapter.info` attribute — confirm target
        // browsers still expose it, otherwise this bails out with None.
        let info_promise = call_method(&adapter, "requestAdapterInfo", &[]).ok()?;
        let info = JsFuture::from(info_promise.dyn_into::<Promise>().ok()?)
            .await
            .ok()?;

        // Extract limits
        let limits = Reflect::get(&adapter, &"limits".into()).ok()?;

        Some(AdapterInfo {
            vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
            architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
            device_type: get_string_prop(&info, "device").unwrap_or_else(|| "unknown".to_string()),
            backend: "WebGPU".to_string(),
            // Defaults (256 MiB / 256 threads) are used when a limit is absent.
            max_buffer_size: get_number_prop(&limits, "maxBufferSize")
                .unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
            max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
                .unwrap_or(256.0) as u32,
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    None
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Fetch `navigator.gpu` from the window, treating null/undefined as absent.
#[cfg(target_arch = "wasm32")]
fn get_gpu_object() -> Option<JsValue> {
    let window = web_sys::window()?;
    let navigator = Reflect::get(&window, &"navigator".into()).ok()?;
    Reflect::get(&navigator, &"gpu".into())
        .ok()
        .filter(|gpu| !gpu.is_undefined() && !gpu.is_null())
}
|
||||
|
||||
/// Read `obj[key]` as a string; `None` if missing or not a JS string.
#[cfg(target_arch = "wasm32")]
fn get_string_prop(obj: &JsValue, key: &str) -> Option<String> {
    Reflect::get(obj, &key.into()).ok()?.as_string()
}
|
||||
|
||||
/// Read `obj[key]` as a number; `None` if missing or not a JS number.
#[cfg(target_arch = "wasm32")]
fn get_number_prop(obj: &JsValue, key: &str) -> Option<f64> {
    Reflect::get(obj, &key.into()).ok()?.as_f64()
}
|
||||
|
||||
/// Invoke `obj.method(...args)` through Reflect, with `obj` as `this`.
#[cfg(target_arch = "wasm32")]
fn call_method(obj: &JsValue, method: &str, args: &[JsValue]) -> Result<JsValue, JsValue> {
    // Look up the property and require it to be callable.
    let func: js_sys::Function = Reflect::get(obj, &method.into())?.dyn_into()?;

    let js_args = Array::new();
    for value in args {
        js_args.push(value);
    }

    Reflect::apply(&func, obj, &js_args)
}
|
||||
|
||||
// ============================================================================
|
||||
// WebGPU Context
|
||||
// ============================================================================
|
||||
|
||||
/// WebGPU context holding device and queue references
///
/// Device and queue are opaque `JsValue` handles because web-sys WebGPU
/// bindings are unstable; all calls go through `Reflect`-based interop.
#[wasm_bindgen]
pub struct WebGpuContext {
    /// GPU device object (JsValue wrapper)
    #[cfg(target_arch = "wasm32")]
    device: JsValue,

    /// Command queue object
    #[cfg(target_arch = "wasm32")]
    queue: JsValue,

    /// Placeholder for non-wasm builds
    #[cfg(not(target_arch = "wasm32"))]
    _phantom: std::marker::PhantomData<()>,

    /// Adapter information captured at init time
    adapter_info: AdapterInfo,
}
|
||||
|
||||
#[wasm_bindgen]
impl WebGpuContext {
    /// Initialize WebGPU context
    ///
    /// Requests a high-performance adapter and a default device from
    /// `navigator.gpu`, capturing adapter info and limits for later use.
    ///
    /// # Errors
    /// Returns a `JsValue` error when WebGPU is unavailable, no adapter is
    /// found, or any underlying JS call rejects. Always errors on native
    /// builds.
    #[wasm_bindgen(js_name = init)]
    pub async fn init() -> Result<WebGpuContext, JsValue> {
        #[cfg(target_arch = "wasm32")]
        {
            let gpu = get_gpu_object().ok_or_else(|| JsValue::from_str("WebGPU not available"))?;

            // Request adapter with high performance preference
            let adapter_options = Object::new();
            Reflect::set(
                &adapter_options,
                &"powerPreference".into(),
                &"high-performance".into(),
            )?;

            let adapter_promise = call_method(&gpu, "requestAdapter", &[adapter_options.into()])?;
            let adapter = JsFuture::from(adapter_promise.dyn_into::<Promise>()?).await?;

            // The promise resolves to null when no adapter is suitable.
            if adapter.is_null() || adapter.is_undefined() {
                return Err(JsValue::from_str("No suitable GPU adapter found"));
            }

            // Get adapter info
            // NOTE(review): requestAdapterInfo() was removed from the
            // WebGPU spec in favor of `GPUAdapter.info` — confirm target
            // browsers still expose it, else this call rejects and init fails.
            let info_promise = call_method(&adapter, "requestAdapterInfo", &[])?;
            let info = JsFuture::from(info_promise.dyn_into::<Promise>()?).await?;
            let limits = Reflect::get(&adapter, &"limits".into())?;

            let adapter_info = AdapterInfo {
                vendor: get_string_prop(&info, "vendor").unwrap_or_default(),
                architecture: get_string_prop(&info, "architecture").unwrap_or_default(),
                device_type: get_string_prop(&info, "device")
                    .unwrap_or_else(|| "unknown".to_string()),
                backend: "WebGPU".to_string(),
                // Defaults (256 MiB / 256 threads) when a limit is absent.
                max_buffer_size: get_number_prop(&limits, "maxBufferSize")
                    .unwrap_or(256.0 * 1024.0 * 1024.0) as u64,
                max_workgroup_size: get_number_prop(&limits, "maxComputeWorkgroupSizeX")
                    .unwrap_or(256.0) as u32,
            };

            // Request device
            let device_descriptor = Object::new();
            Reflect::set(&device_descriptor, &"label".into(), &"ruvllm-wasm".into())?;

            let device_promise =
                call_method(&adapter, "requestDevice", &[device_descriptor.into()])?;
            let device = JsFuture::from(device_promise.dyn_into::<Promise>()?).await?;

            // Get queue
            let queue = Reflect::get(&device, &"queue".into())?;

            Ok(WebGpuContext {
                device,
                queue,
                adapter_info,
            })
        }

        #[cfg(not(target_arch = "wasm32"))]
        Err(JsValue::from_str("WebGPU only available in WASM"))
    }

    /// Get adapter information
    #[wasm_bindgen(getter, js_name = adapterInfo)]
    pub fn adapter_info(&self) -> AdapterInfo {
        self.adapter_info.clone()
    }

    /// Check if context is valid
    ///
    /// True when a non-null device handle was obtained (wasm32 only;
    /// always false on native builds).
    #[wasm_bindgen(getter, js_name = isValid)]
    pub fn is_valid(&self) -> bool {
        #[cfg(target_arch = "wasm32")]
        {
            !self.device.is_undefined() && !self.device.is_null()
        }

        #[cfg(not(target_arch = "wasm32"))]
        false
    }

    /// Create a GPU buffer
    ///
    /// `usage` is a raw GPUBufferUsage bitmask; `label` shows up in
    /// browser dev tools for debugging.
    #[cfg(target_arch = "wasm32")]
    fn create_buffer_internal(
        &self,
        size: usize,
        usage: u32,
        label: Option<&str>,
    ) -> Result<JsValue, JsValue> {
        let descriptor = Object::new();
        // JS numbers are f64; sizes beyond 2^53 bytes would lose precision.
        Reflect::set(&descriptor, &"size".into(), &JsValue::from_f64(size as f64))?;
        Reflect::set(
            &descriptor,
            &"usage".into(),
            &JsValue::from_f64(usage as f64),
        )?;
        if let Some(lbl) = label {
            Reflect::set(&descriptor, &"label".into(), &lbl.into())?;
        }

        call_method(&self.device, "createBuffer", &[descriptor.into()])
    }

    /// Write data to GPU buffer
    ///
    /// Copies `data` into a fresh Float32Array and enqueues writeBuffer at
    /// offset 0. Passing the whole backing ArrayBuffer is safe here because
    /// the array was just created from `data` alone.
    #[cfg(target_arch = "wasm32")]
    fn write_buffer_internal(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
        let data_array = Float32Array::from(data);
        call_method(
            &self.queue,
            "writeBuffer",
            &[
                buffer.clone(),
                JsValue::from_f64(0.0),
                data_array.buffer().into(),
            ],
        )?;
        Ok(())
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Compute Pipeline
|
||||
// ============================================================================
|
||||
|
||||
/// Compute pipeline handle
///
/// Wraps a compiled WebGPU compute pipeline plus the bind group layout it
/// was created against; both are opaque JsValues on wasm32 and absent on
/// native builds.
#[wasm_bindgen]
pub struct ComputePipeline {
    #[cfg(target_arch = "wasm32")]
    pipeline: JsValue,

    #[cfg(target_arch = "wasm32")]
    bind_group_layout: JsValue,

    #[cfg(not(target_arch = "wasm32"))]
    _phantom: std::marker::PhantomData<()>,

    // Shader entry point name (e.g. "main").
    entry_point: String,
    // Workgroup dimensions [x, y, z].
    workgroup_size: [u32; 3],
}
|
||||
|
||||
#[wasm_bindgen]
impl ComputePipeline {
    /// Get the entry point name
    #[wasm_bindgen(getter, js_name = entryPoint)]
    pub fn entry_point(&self) -> String {
        self.entry_point.clone()
    }

    /// Get the workgroup size
    ///
    /// Returned as a Vec ([x, y, z]) for JS interop.
    #[wasm_bindgen(getter, js_name = workgroupSize)]
    pub fn workgroup_size(&self) -> Vec<u32> {
        self.workgroup_size.to_vec()
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// WebGPU Inference Engine
|
||||
// ============================================================================
|
||||
|
||||
/// WebGPU inference engine for LLM operations
///
/// Holds the device/queue handles obtained from `WebGpuContext::init` and
/// exposes matmul/attention/normalization kernels to JS. On native builds
/// the GPU handles are absent and CPU fallbacks are used.
#[wasm_bindgen]
pub struct WebGpuInference {
    #[cfg(target_arch = "wasm32")]
    device: JsValue,

    #[cfg(target_arch = "wasm32")]
    queue: JsValue,

    #[cfg(not(target_arch = "wasm32"))]
    _phantom: std::marker::PhantomData<()>,

    /// Adapter information captured at init time
    adapter_info: AdapterInfo,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WebGpuInference {
|
||||
/// Check if WebGPU is available
///
/// Thin async wrapper over the free function `is_webgpu_available`.
#[wasm_bindgen(js_name = isAvailable)]
pub async fn is_available() -> bool {
    is_webgpu_available().await
}
|
||||
|
||||
/// Initialize WebGPU inference engine
///
/// Builds a `WebGpuContext` and moves its device/queue handles and adapter
/// info into the inference engine. Errors propagate from context init
/// (e.g. "WebGPU not available").
#[wasm_bindgen(js_name = init)]
pub async fn init() -> Result<WebGpuInference, JsValue> {
    let ctx = WebGpuContext::init().await?;

    Ok(WebGpuInference {
        #[cfg(target_arch = "wasm32")]
        device: ctx.device,
        #[cfg(target_arch = "wasm32")]
        queue: ctx.queue,
        #[cfg(not(target_arch = "wasm32"))]
        _phantom: std::marker::PhantomData,
        adapter_info: ctx.adapter_info,
    })
}
|
||||
|
||||
/// Get adapter information
///
/// Returns a clone of the info captured at init time.
#[wasm_bindgen(getter, js_name = adapterInfo)]
pub fn adapter_info(&self) -> AdapterInfo {
    self.adapter_info.clone()
}
|
||||
|
||||
/// Perform matrix multiplication: C = A * B
///
/// Args:
///   a: Matrix A as flat f32 array (M x K), row-major
///   b: Matrix B as flat f32 array (K x N), row-major
///   m: Number of rows in A
///   n: Number of columns in B
///   k: Shared dimension
///
/// Returns: Result matrix C as f32 array (M x N)
///
/// Errors when `a`/`b` are not exactly m*k / k*n elements. On wasm32 the
/// product runs on the GPU via the tiled WGSL kernel in
/// `shaders::MATMUL_SHADER`; elsewhere a naive CPU fallback is used.
#[wasm_bindgen]
pub async fn matmul(
    &self,
    a: &[f32],
    b: &[f32],
    m: u32,
    n: u32,
    k: u32,
) -> Result<Vec<f32>, JsValue> {
    // Validate dimensions
    let expected_a = (m as usize) * (k as usize);
    let expected_b = (k as usize) * (n as usize);

    if a.len() != expected_a {
        return Err(JsValue::from_str(&format!(
            "Matrix A dimension mismatch: expected {}, got {}",
            expected_a,
            a.len()
        )));
    }

    if b.len() != expected_b {
        return Err(JsValue::from_str(&format!(
            "Matrix B dimension mismatch: expected {}, got {}",
            expected_b,
            b.len()
        )));
    }

    #[cfg(target_arch = "wasm32")]
    {
        let output_size = (m as usize) * (n as usize);

        // GPU buffer usage flags
        const STORAGE: u32 = 0x80; // GPUBufferUsage.STORAGE
        const COPY_SRC: u32 = 0x04; // GPUBufferUsage.COPY_SRC
        const COPY_DST: u32 = 0x08; // GPUBufferUsage.COPY_DST
        const MAP_READ: u32 = 0x01; // GPUBufferUsage.MAP_READ
        const UNIFORM: u32 = 0x40; // GPUBufferUsage.UNIFORM

        // Create buffers (sizes in bytes: 4 bytes per f32)
        let buffer_a = self.create_buffer(a.len() * 4, STORAGE | COPY_DST, Some("matmul_a"))?;
        let buffer_b = self.create_buffer(b.len() * 4, STORAGE | COPY_DST, Some("matmul_b"))?;
        let buffer_c =
            self.create_buffer(output_size * 4, STORAGE | COPY_SRC, Some("matmul_c"))?;

        // Create uniform buffer for dimensions
        // NOTE(review): dimensions are uploaded as f32 — confirm
        // MATMUL_SHADER declares this uniform as vec4<f32>, not u32 fields.
        let uniform_data: [f32; 4] = [m as f32, n as f32, k as f32, 1.0]; // M, N, K, alpha
        let uniform_buffer =
            self.create_buffer(16, UNIFORM | COPY_DST, Some("matmul_uniforms"))?;

        // Write data to buffers
        self.write_buffer(&buffer_a, a)?;
        self.write_buffer(&buffer_b, b)?;
        self.write_buffer(&uniform_buffer, &uniform_data)?;

        // Create shader module
        let shader_desc = Object::new();
        Reflect::set(&shader_desc, &"code".into(), &shaders::MATMUL_SHADER.into())?;
        let shader_module =
            call_method(&self.device, "createShaderModule", &[shader_desc.into()])?;

        // Create bind group layout
        let layout_entries = Array::new();

        // Storage buffer entries (A, B, C); inputs are read-only, C is writable.
        for i in 0..3u32 {
            let entry = Object::new();
            Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
            Reflect::set(&entry, &"visibility".into(), &JsValue::from_f64(4.0))?; // COMPUTE stage
            let buffer_layout = Object::new();
            Reflect::set(
                &buffer_layout,
                &"type".into(),
                &(if i < 2 {
                    "read-only-storage"
                } else {
                    "storage"
                })
                .into(),
            )?;
            Reflect::set(&entry, &"buffer".into(), &buffer_layout)?;
            layout_entries.push(&entry);
        }

        // Uniform buffer entry (binding 3)
        let uniform_entry = Object::new();
        Reflect::set(&uniform_entry, &"binding".into(), &JsValue::from_f64(3.0))?;
        Reflect::set(
            &uniform_entry,
            &"visibility".into(),
            &JsValue::from_f64(4.0),
        )?;
        let uniform_layout = Object::new();
        Reflect::set(&uniform_layout, &"type".into(), &"uniform".into())?;
        Reflect::set(&uniform_entry, &"buffer".into(), &uniform_layout)?;
        layout_entries.push(&uniform_entry);

        let layout_desc = Object::new();
        Reflect::set(&layout_desc, &"entries".into(), &layout_entries)?;
        let bind_group_layout =
            call_method(&self.device, "createBindGroupLayout", &[layout_desc.into()])?;

        // Create pipeline layout
        let layouts = Array::new();
        layouts.push(&bind_group_layout);
        let pipeline_layout_desc = Object::new();
        Reflect::set(&pipeline_layout_desc, &"bindGroupLayouts".into(), &layouts)?;
        let pipeline_layout = call_method(
            &self.device,
            "createPipelineLayout",
            &[pipeline_layout_desc.into()],
        )?;

        // Create compute pipeline
        let compute_stage = Object::new();
        Reflect::set(&compute_stage, &"module".into(), &shader_module)?;
        Reflect::set(&compute_stage, &"entryPoint".into(), &"main".into())?;

        let pipeline_desc = Object::new();
        Reflect::set(&pipeline_desc, &"layout".into(), &pipeline_layout)?;
        Reflect::set(&pipeline_desc, &"compute".into(), &compute_stage)?;

        let pipeline = call_method(
            &self.device,
            "createComputePipeline",
            &[pipeline_desc.into()],
        )?;

        // Create bind group (bindings 0..3 in the same order as the layout)
        let bind_entries = Array::new();
        for (i, buffer) in [&buffer_a, &buffer_b, &buffer_c, &uniform_buffer]
            .iter()
            .enumerate()
        {
            let entry = Object::new();
            Reflect::set(&entry, &"binding".into(), &JsValue::from_f64(i as f64))?;
            let resource = Object::new();
            Reflect::set(&resource, &"buffer".into(), buffer)?;
            Reflect::set(&entry, &"resource".into(), &resource)?;
            bind_entries.push(&entry);
        }

        let bind_group_desc = Object::new();
        Reflect::set(&bind_group_desc, &"layout".into(), &bind_group_layout)?;
        Reflect::set(&bind_group_desc, &"entries".into(), &bind_entries)?;
        let bind_group =
            call_method(&self.device, "createBindGroup", &[bind_group_desc.into()])?;

        // Create command encoder
        let encoder_desc = Object::new();
        let encoder =
            call_method(&self.device, "createCommandEncoder", &[encoder_desc.into()])?;

        // Begin compute pass
        let pass_desc = Object::new();
        let pass = call_method(&encoder, "beginComputePass", &[pass_desc.into()])?;

        // Set pipeline and bind group
        call_method(&pass, "setPipeline", &[pipeline.clone()])?;
        call_method(
            &pass,
            "setBindGroup",
            &[JsValue::from_f64(0.0), bind_group.clone()],
        )?;

        // Dispatch workgroups (16x16 tile size)
        // NOTE(review): x is derived from m (rows), y from n (cols) —
        // confirm this matches the shader's workgroup axis convention.
        let workgroups_x = (m + 15) / 16;
        let workgroups_y = (n + 15) / 16;
        call_method(
            &pass,
            "dispatchWorkgroups",
            &[
                JsValue::from_f64(workgroups_x as f64),
                JsValue::from_f64(workgroups_y as f64),
            ],
        )?;

        call_method(&pass, "end", &[])?;

        // Create staging buffer for readback
        let staging =
            self.create_buffer(output_size * 4, MAP_READ | COPY_DST, Some("staging"))?;

        // Copy result to staging (recorded after the pass, before finish)
        call_method(
            &encoder,
            "copyBufferToBuffer",
            &[
                buffer_c.clone(),
                JsValue::from_f64(0.0),
                staging.clone(),
                JsValue::from_f64(0.0),
                JsValue::from_f64((output_size * 4) as f64),
            ],
        )?;

        // Submit commands
        let command_buffer = call_method(&encoder, "finish", &[])?;
        let commands = Array::new();
        commands.push(&command_buffer);
        call_method(&self.queue, "submit", &[commands.into()])?;

        // Map staging buffer and read result
        let map_promise = call_method(&staging, "mapAsync", &[JsValue::from_f64(1.0)])?; // MAP_READ = 1
        JsFuture::from(map_promise.dyn_into::<Promise>()?).await?;

        let mapped_range = call_method(&staging, "getMappedRange", &[])?;
        // Copy out of the mapped range before unmapping invalidates it.
        let data = Float32Array::new(&mapped_range).to_vec();

        call_method(&staging, "unmap", &[])?;

        Ok(data)
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        // CPU fallback - naive O(m*n*k) triple loop, row-major
        let mut c = vec![0.0f32; (m as usize) * (n as usize)];
        for i in 0..m as usize {
            for j in 0..n as usize {
                let mut sum = 0.0f32;
                for l in 0..k as usize {
                    sum += a[i * k as usize + l] * b[l * n as usize + j];
                }
                c[i * n as usize + j] = sum;
            }
        }
        Ok(c)
    }
}
|
||||
|
||||
/// Perform attention: Output = softmax(Q * K^T / sqrt(d_k)) * V
|
||||
#[wasm_bindgen]
|
||||
pub async fn attention(
|
||||
&self,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
config: &AttentionConfig,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
let hidden_dim = config.hidden_dim();
|
||||
let expected_size = (config.seq_len as usize) * (hidden_dim as usize);
|
||||
|
||||
if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Attention tensor dimension mismatch: expected {}, got Q:{}, K:{}, V:{}",
|
||||
expected_size,
|
||||
q.len(),
|
||||
k.len(),
|
||||
v.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// CPU fallback for attention (GPU implementation similar to matmul pattern)
|
||||
// For production, would implement full GPU attention here
|
||||
self.attention_cpu(q, k, v, config)
|
||||
}
|
||||
|
||||
/// CPU fallback for attention
///
/// Naive O(seq_len^2 * head_dim) reference implementation. Q, K, V are
/// row-major [seq_len, num_heads * head_dim]; each head operates on its
/// own `head_dim` slice of the hidden dimension, independently.
fn attention_cpu(
    &self,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    config: &AttentionConfig,
) -> Result<Vec<f32>, JsValue> {
    let seq_len = config.seq_len as usize;
    let num_heads = config.num_heads as usize;
    let head_dim = config.head_dim as usize;
    let hidden_dim = num_heads * head_dim;
    // Score scale factor — presumably 1/sqrt(head_dim), defined by
    // AttentionConfig::scale (not visible here; confirm).
    let scale = config.scale();

    let mut output = vec![0.0f32; seq_len * hidden_dim];

    // Process each head independently
    for h in 0..num_heads {
        for i in 0..seq_len {
            // For this query position, compute attention to all key positions
            let q_offset = i * hidden_dim + h * head_dim;

            // Compute attention scores
            let mut scores = vec![0.0f32; seq_len];
            let mut max_score = f32::NEG_INFINITY;

            for j in 0..seq_len {
                // Causal masking: future positions (j > i) get -inf so they
                // contribute zero weight after the softmax.
                if config.causal && j > i {
                    scores[j] = f32::NEG_INFINITY;
                    continue;
                }

                let k_offset = j * hidden_dim + h * head_dim;
                let mut score = 0.0f32;

                // Dot product of this head's query and key slices.
                for d in 0..head_dim {
                    score += q[q_offset + d] * k[k_offset + d];
                }

                score *= scale;
                scores[j] = score;
                if score > max_score {
                    max_score = score;
                }
            }

            // Softmax (max-subtracted for numerical stability; masked -inf
            // entries become exp(-inf) = 0)
            let mut sum = 0.0f32;
            for j in 0..seq_len {
                scores[j] = (scores[j] - max_score).exp();
                sum += scores[j];
            }
            for j in 0..seq_len {
                scores[j] /= sum;
            }

            // Compute weighted sum of values
            let out_offset = i * hidden_dim + h * head_dim;
            for d in 0..head_dim {
                let mut weighted_sum = 0.0f32;
                for j in 0..seq_len {
                    let v_offset = j * hidden_dim + h * head_dim;
                    weighted_sum += scores[j] * v[v_offset + d];
                }
                output[out_offset + d] = weighted_sum;
            }
        }
    }

    Ok(output)
}
|
||||
|
||||
/// Perform RMS normalization
|
||||
#[wasm_bindgen(js_name = rmsNorm)]
|
||||
pub async fn rms_norm(
|
||||
&self,
|
||||
input: &[f32],
|
||||
weight: &[f32],
|
||||
hidden_dim: u32,
|
||||
eps: f32,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if weight.len() != hidden_dim as usize {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Weight dimension mismatch: expected {}, got {}",
|
||||
hidden_dim,
|
||||
weight.len()
|
||||
)));
|
||||
}
|
||||
|
||||
if input.len() % hidden_dim as usize != 0 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size {} not divisible by hidden_dim {}",
|
||||
input.len(),
|
||||
hidden_dim
|
||||
)));
|
||||
}
|
||||
|
||||
// CPU implementation
|
||||
let batch_size = input.len() / hidden_dim as usize;
|
||||
let mut output = vec![0.0f32; input.len()];
|
||||
|
||||
for b in 0..batch_size {
|
||||
let offset = b * hidden_dim as usize;
|
||||
|
||||
// Compute sum of squares
|
||||
let mut sum_sq = 0.0f32;
|
||||
for i in 0..hidden_dim as usize {
|
||||
let x = input[offset + i];
|
||||
sum_sq += x * x;
|
||||
}
|
||||
|
||||
// RMS scale
|
||||
let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
|
||||
|
||||
// Normalize and scale
|
||||
for i in 0..hidden_dim as usize {
|
||||
output[offset + i] = input[offset + i] / rms * weight[i];
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Perform softmax
|
||||
#[wasm_bindgen]
|
||||
pub async fn softmax(
|
||||
&self,
|
||||
input: &[f32],
|
||||
dim: u32,
|
||||
temperature: f32,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if input.len() % dim as usize != 0 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size {} not divisible by dim {}",
|
||||
input.len(),
|
||||
dim
|
||||
)));
|
||||
}
|
||||
|
||||
let batch_size = input.len() / dim as usize;
|
||||
let mut output = vec![0.0f32; input.len()];
|
||||
|
||||
for b in 0..batch_size {
|
||||
let offset = b * dim as usize;
|
||||
|
||||
// Find max (for numerical stability)
|
||||
let mut max_val = f32::NEG_INFINITY;
|
||||
for i in 0..dim as usize {
|
||||
let x = input[offset + i] / temperature;
|
||||
if x > max_val {
|
||||
max_val = x;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute exp and sum
|
||||
let mut sum = 0.0f32;
|
||||
for i in 0..dim as usize {
|
||||
let x = (input[offset + i] / temperature - max_val).exp();
|
||||
output[offset + i] = x;
|
||||
sum += x;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for i in 0..dim as usize {
|
||||
output[offset + i] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
// Helper methods for GPU buffer management
/// Allocate a GPU buffer of `size` bytes with the given usage flags via
/// `device.createBuffer`, optionally labelled for debugging.
#[cfg(target_arch = "wasm32")]
fn create_buffer(
    &self,
    size: usize,
    usage: u32,
    label: Option<&str>,
) -> Result<JsValue, JsValue> {
    // Build a GPUBufferDescriptor as a plain JS object: { size, usage[, label] }.
    let descriptor = Object::new();
    let numeric_fields: [(&str, f64); 2] = [
        ("size", size as f64),
        ("usage", usage as f64),
    ];
    for (key, value) in numeric_fields {
        Reflect::set(&descriptor, &key.into(), &JsValue::from_f64(value))?;
    }
    if let Some(lbl) = label {
        Reflect::set(&descriptor, &"label".into(), &lbl.into())?;
    }

    // device.createBuffer(descriptor)
    call_method(&self.device, "createBuffer", &[descriptor.into()])
}
|
||||
|
||||
/// Upload `data` into `buffer` at offset 0 via `queue.writeBuffer`.
#[cfg(target_arch = "wasm32")]
fn write_buffer(&self, buffer: &JsValue, data: &[f32]) -> Result<(), JsValue> {
    // Stage the f32 slice in a fresh JS Float32Array, then hand its backing
    // ArrayBuffer to the queue.
    let staging = Float32Array::from(data);
    let args = [
        buffer.clone(),
        JsValue::from_f64(0.0),
        staging.buffer().into(),
    ];
    call_method(&self.queue, "writeBuffer", &args)?;
    Ok(())
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cpu_matmul_fallback() {
|
||||
// Test the CPU fallback logic (in non-wasm mode)
|
||||
let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
|
||||
let b = vec![5.0, 6.0, 7.0, 8.0]; // 2x2
|
||||
|
||||
// Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
|
||||
// = [[19, 22], [43, 50]]
|
||||
|
||||
let mut c = vec![0.0f32; 4];
|
||||
for i in 0..2usize {
|
||||
for j in 0..2usize {
|
||||
let mut sum = 0.0f32;
|
||||
for l in 0..2usize {
|
||||
sum += a[i * 2 + l] * b[l * 2 + j];
|
||||
}
|
||||
c[i * 2 + j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rms_norm_cpu() {
|
||||
let input = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let weight = vec![1.0, 1.0, 1.0, 1.0];
|
||||
let hidden_dim = 4;
|
||||
let eps = 1e-5f32;
|
||||
|
||||
// sum_sq = 1 + 4 + 9 + 16 = 30
|
||||
// rms = sqrt(30/4 + eps) = sqrt(7.5) ≈ 2.7386
|
||||
let rms = (30.0f32 / 4.0 + eps).sqrt();
|
||||
|
||||
let expected: Vec<f32> = input.iter().map(|&x| x / rms).collect();
|
||||
|
||||
// Verify calculation
|
||||
assert!((expected[0] - 0.3651).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_cpu() {
|
||||
let input = vec![1.0, 2.0, 3.0];
|
||||
let temperature = 1.0f32;
|
||||
|
||||
// max = 3
|
||||
// exp(1-3) = exp(-2), exp(2-3) = exp(-1), exp(3-3) = exp(0) = 1
|
||||
let exps: Vec<f32> = vec![(-2.0f32).exp(), (-1.0f32).exp(), 1.0];
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let expected: Vec<f32> = exps.iter().map(|&x| x / sum).collect();
|
||||
|
||||
// Verify softmax sums to 1
|
||||
let softmax_sum: f32 = expected.iter().sum();
|
||||
assert!((softmax_sum - 1.0).abs() < 0.001);
|
||||
}
|
||||
}
|
||||
345
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/mod.rs
vendored
Normal file
345
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/mod.rs
vendored
Normal file
@@ -0,0 +1,345 @@
|
||||
//! WebGPU Compute Module for WASM-based GPU Acceleration
|
||||
//!
|
||||
//! This module provides WebGPU compute shader support for LLM inference
|
||||
//! operations in the browser. It includes:
|
||||
//!
|
||||
//! - Matrix multiplication (tiled, batched, GEMV)
|
||||
//! - Flash Attention (causal, GQA, decode)
|
||||
//! - RMSNorm and LayerNorm
|
||||
//! - Softmax (standard, temperature-scaled, log-softmax)
|
||||
//!
|
||||
//! ## Feature Detection
|
||||
//!
|
||||
//! WebGPU availability is checked at runtime with graceful fallback:
|
||||
//!
|
||||
//! ```javascript
|
||||
//! if (await WebGpuInference.isAvailable()) {
|
||||
//! const gpu = await WebGpuInference.init();
|
||||
//! const result = await gpu.matmul(a, b, m, n, k);
|
||||
//! } else {
|
||||
//! // Fall back to CPU implementation
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! ## Performance Targets
|
||||
//!
|
||||
//! - Matrix multiply: ~1 TFLOP on integrated GPUs, ~10 TFLOPS on discrete
|
||||
//! - Attention: 2ms for 4K context on discrete GPU
|
||||
//! - Normalization: <0.5ms for typical hidden dimensions
|
||||
|
||||
pub mod buffers;
|
||||
pub mod compute;
|
||||
pub mod shaders;
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub use buffers::{GpuBuffer, GpuBufferUsage};
|
||||
pub use compute::{ComputePipeline, WebGpuContext};
|
||||
pub use shaders::ShaderModule;
|
||||
|
||||
/// GPU adapter information.
///
/// Fields are marked `#[wasm_bindgen(skip)]` so they are not auto-exported;
/// JavaScript reads them through the explicit getters in the `impl` block.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct AdapterInfo {
    /// GPU vendor name
    #[wasm_bindgen(skip)]
    pub vendor: String,
    /// GPU architecture/device name
    #[wasm_bindgen(skip)]
    pub architecture: String,
    /// Device type (integrated, discrete, etc.)
    #[wasm_bindgen(skip)]
    pub device_type: String,
    /// Backend API (WebGPU, etc.)
    #[wasm_bindgen(skip)]
    pub backend: String,
    /// Maximum buffer size in bytes
    #[wasm_bindgen(skip)]
    pub max_buffer_size: u64,
    /// Maximum compute workgroup size
    #[wasm_bindgen(skip)]
    pub max_workgroup_size: u32,
}
|
||||
|
||||
// JS-facing accessors for AdapterInfo. The string getters clone because
// wasm-bindgen getters must return owned values across the JS boundary.
#[wasm_bindgen]
impl AdapterInfo {
    /// Get GPU vendor name
    #[wasm_bindgen(getter)]
    pub fn vendor(&self) -> String {
        self.vendor.clone()
    }

    /// Get GPU architecture
    #[wasm_bindgen(getter)]
    pub fn architecture(&self) -> String {
        self.architecture.clone()
    }

    /// Get device type
    #[wasm_bindgen(getter, js_name = deviceType)]
    pub fn device_type(&self) -> String {
        self.device_type.clone()
    }

    /// Get backend API
    #[wasm_bindgen(getter)]
    pub fn backend(&self) -> String {
        self.backend.clone()
    }

    /// Get maximum buffer size
    #[wasm_bindgen(getter, js_name = maxBufferSize)]
    pub fn max_buffer_size(&self) -> u64 {
        self.max_buffer_size
    }

    /// Get maximum workgroup size
    #[wasm_bindgen(getter, js_name = maxWorkgroupSize)]
    pub fn max_workgroup_size(&self) -> u32 {
        self.max_workgroup_size
    }

    /// Convert to JSON string.
    ///
    /// Keys are camelCased to match the JS getter names above.
    ///
    /// # Errors
    /// Returns a `JsValue` error if serialization fails.
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        // Intermediate serde_json::Value so field names can be camelCased
        // without a custom Serialize impl.
        let json = serde_json::json!({
            "vendor": self.vendor,
            "architecture": self.architecture,
            "deviceType": self.device_type,
            "backend": self.backend,
            "maxBufferSize": self.max_buffer_size,
            "maxWorkgroupSize": self.max_workgroup_size,
        });
        serde_json::to_string(&json).map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
|
||||
|
||||
/// Attention configuration for compute shaders.
///
/// Mirrors the WGSL `AttentionUniforms` layout consumed by the attention
/// shaders; fields are exposed to JS via the getter/setter pairs below.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Sequence length for queries
    #[wasm_bindgen(skip)]
    pub seq_len: u32,
    /// Key/Value sequence length (can differ for encoder-decoder)
    #[wasm_bindgen(skip)]
    pub kv_seq_len: u32,
    /// Number of attention heads
    #[wasm_bindgen(skip)]
    pub num_heads: u32,
    /// Dimension per head
    #[wasm_bindgen(skip)]
    pub head_dim: u32,
    /// Whether to apply causal masking
    #[wasm_bindgen(skip)]
    pub causal: bool,
}
|
||||
|
||||
#[wasm_bindgen]
impl AttentionConfig {
    /// Create new attention configuration.
    ///
    /// `kv_seq_len` defaults to `seq_len` (self-attention).
    #[wasm_bindgen(constructor)]
    pub fn new(seq_len: u32, num_heads: u32, head_dim: u32, causal: bool) -> Self {
        Self {
            seq_len,
            kv_seq_len: seq_len,
            num_heads,
            head_dim,
            causal,
        }
    }

    /// Create for encoder-decoder models with different KV length.
    ///
    /// Cross-attention configurations are fixed to `causal: false`.
    #[wasm_bindgen(js_name = forEncoderDecoder)]
    pub fn for_encoder_decoder(
        seq_len: u32,
        kv_seq_len: u32,
        num_heads: u32,
        head_dim: u32,
    ) -> Self {
        Self {
            seq_len,
            kv_seq_len,
            num_heads,
            head_dim,
            causal: false,
        }
    }

    /// Get the scaling factor (1/sqrt(head_dim)).
    // NOTE(review): head_dim == 0 yields +inf here — confirm callers guarantee
    // a non-zero head dimension.
    pub fn scale(&self) -> f32 {
        1.0 / (self.head_dim as f32).sqrt()
    }

    /// Get total hidden dimension (num_heads * head_dim).
    pub fn hidden_dim(&self) -> u32 {
        self.num_heads * self.head_dim
    }

    // JS property accessors (`config.seqLen`, `config.kvSeqLen`, ...).

    #[wasm_bindgen(getter, js_name = seqLen)]
    pub fn get_seq_len(&self) -> u32 {
        self.seq_len
    }

    #[wasm_bindgen(setter, js_name = seqLen)]
    pub fn set_seq_len(&mut self, value: u32) {
        self.seq_len = value;
    }

    #[wasm_bindgen(getter, js_name = kvSeqLen)]
    pub fn get_kv_seq_len(&self) -> u32 {
        self.kv_seq_len
    }

    #[wasm_bindgen(setter, js_name = kvSeqLen)]
    pub fn set_kv_seq_len(&mut self, value: u32) {
        self.kv_seq_len = value;
    }

    #[wasm_bindgen(getter, js_name = numHeads)]
    pub fn get_num_heads(&self) -> u32 {
        self.num_heads
    }

    #[wasm_bindgen(setter, js_name = numHeads)]
    pub fn set_num_heads(&mut self, value: u32) {
        self.num_heads = value;
    }

    #[wasm_bindgen(getter, js_name = headDim)]
    pub fn get_head_dim(&self) -> u32 {
        self.head_dim
    }

    #[wasm_bindgen(setter, js_name = headDim)]
    pub fn set_head_dim(&mut self, value: u32) {
        self.head_dim = value;
    }

    // NOTE(review): unlike the accessors above, this pair carries no js_name,
    // so the JS property name is derived from the Rust fn names — confirm the
    // intended JS property is `causal` and add js_name if so.
    #[wasm_bindgen(getter)]
    pub fn get_causal(&self) -> bool {
        self.causal
    }

    #[wasm_bindgen(setter)]
    pub fn set_causal(&mut self, value: bool) {
        self.causal = value;
    }
}
|
||||
|
||||
/// Check if WebGPU is available in this browser.
///
/// Thin async wrapper so JS can `await isWebGpuAvailable()`; the actual
/// runtime probe lives in `compute::is_webgpu_available`.
#[wasm_bindgen(js_name = isWebGpuAvailable)]
pub async fn is_webgpu_available() -> bool {
    compute::is_webgpu_available().await
}
|
||||
|
||||
/// Get GPU information if available
|
||||
#[wasm_bindgen(js_name = getGpuInfo)]
|
||||
pub async fn get_gpu_info() -> Result<JsValue, JsValue> {
|
||||
match compute::get_gpu_info().await {
|
||||
Some(info) => {
|
||||
let js_obj = js_sys::Object::new();
|
||||
js_sys::Reflect::set(&js_obj, &"vendor".into(), &info.vendor.into())?;
|
||||
js_sys::Reflect::set(&js_obj, &"architecture".into(), &info.architecture.into())?;
|
||||
js_sys::Reflect::set(&js_obj, &"deviceType".into(), &info.device_type.into())?;
|
||||
js_sys::Reflect::set(&js_obj, &"backend".into(), &info.backend.into())?;
|
||||
js_sys::Reflect::set(
|
||||
&js_obj,
|
||||
&"maxBufferSize".into(),
|
||||
&JsValue::from_f64(info.max_buffer_size as f64),
|
||||
)?;
|
||||
js_sys::Reflect::set(
|
||||
&js_obj,
|
||||
&"maxWorkgroupSize".into(),
|
||||
&JsValue::from_f64(info.max_workgroup_size as f64),
|
||||
)?;
|
||||
Ok(js_obj.into())
|
||||
}
|
||||
None => Ok(JsValue::NULL),
|
||||
}
|
||||
}
|
||||
|
||||
/// WebGPU error types.
///
/// Converted into a `JsValue` (via the `From` impl below) at the wasm-bindgen
/// boundary, using the `Display` strings.
#[derive(Debug)]
pub enum WebGpuError {
    /// WebGPU not available in this browser
    NotAvailable,
    /// Failed to get GPU adapter
    AdapterNotFound,
    /// Failed to create device (payload: underlying reason)
    DeviceCreationFailed(String),
    /// Buffer allocation failed (requested vs. available bytes)
    BufferAllocationFailed { requested: usize, available: usize },
    /// Shader compilation failed (payload: compiler message)
    ShaderCompilationFailed(String),
    /// Invalid dimensions for operation (human-readable expected/actual)
    DimensionMismatch { expected: String, actual: String },
    /// Operation timed out
    Timeout,
    /// Generic GPU error
    GpuError(String),
}
|
||||
|
||||
impl std::fmt::Display for WebGpuError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::NotAvailable => write!(f, "WebGPU is not available in this browser"),
|
||||
Self::AdapterNotFound => write!(f, "No suitable GPU adapter found"),
|
||||
Self::DeviceCreationFailed(msg) => write!(f, "Failed to create GPU device: {}", msg),
|
||||
Self::BufferAllocationFailed {
|
||||
requested,
|
||||
available,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Buffer allocation failed: requested {} bytes, {} available",
|
||||
requested, available
|
||||
)
|
||||
}
|
||||
Self::ShaderCompilationFailed(msg) => write!(f, "Shader compilation failed: {}", msg),
|
||||
Self::DimensionMismatch { expected, actual } => {
|
||||
write!(
|
||||
f,
|
||||
"Dimension mismatch: expected {}, got {}",
|
||||
expected, actual
|
||||
)
|
||||
}
|
||||
Self::Timeout => write!(f, "GPU operation timed out"),
|
||||
Self::GpuError(msg) => write!(f, "GPU error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: Debug (derived) and Display (above) satisfy the supertraits.
impl std::error::Error for WebGpuError {}
|
||||
|
||||
/// Allow `?` on `Result<_, WebGpuError>` inside functions returning
/// `Result<_, JsValue>`; the error surfaces in JS as its Display string.
impl From<WebGpuError> for JsValue {
    fn from(error: WebGpuError) -> Self {
        JsValue::from_str(&error.to_string())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived quantities on `AttentionConfig`: hidden_dim and scale.
    #[test]
    fn test_attention_config() {
        let config = AttentionConfig::new(512, 8, 64, true);
        // 8 heads * 64 dims per head = 512 total.
        assert_eq!(config.hidden_dim(), 512);
        // scale = 1/sqrt(64) = 1/8 = 0.125
        let expected_scale = 1.0f32 / 8.0;
        assert!((config.scale() - expected_scale).abs() < 0.001);
    }

    /// JSON serialization of adapter info includes the vendor string.
    #[test]
    fn test_adapter_info_json() {
        let info = AdapterInfo {
            vendor: String::from("TestVendor"),
            architecture: String::from("TestArch"),
            device_type: String::from("integrated"),
            backend: String::from("WebGPU"),
            max_buffer_size: 256 * 1024 * 1024,
            max_workgroup_size: 256,
        };
        let json = info.to_json().unwrap();
        assert!(json.contains("TestVendor"));
    }
}
|
||||
195
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders.rs
vendored
Normal file
195
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders.rs
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
//! WGSL Shader Module Definitions
|
||||
//!
|
||||
//! This module contains the embedded WGSL shader source code for all
|
||||
//! compute operations. Shaders are embedded at compile time for efficient
|
||||
//! loading in WASM.
|
||||
|
||||
// WGSL sources are embedded at compile time via include_str! so the WASM
// bundle needs no separate shader fetch at runtime.

/// Matrix multiplication shader (tiled with shared memory)
pub const MATMUL_SHADER: &str = include_str!("shaders/matmul.wgsl");

/// Flash attention shader (online softmax, causal masking)
pub const ATTENTION_SHADER: &str = include_str!("shaders/attention.wgsl");

/// RMSNorm and LayerNorm shader
pub const NORM_SHADER: &str = include_str!("shaders/norm.wgsl");

/// Softmax shader (numerically stable)
pub const SOFTMAX_SHADER: &str = include_str!("shaders/softmax.wgsl");
|
||||
|
||||
// Entry-point name constants. Each string must match an `@compute fn` name
// declared in the corresponding .wgsl source above.

/// Shader entry points for matrix multiplication
pub mod matmul {
    /// Standard tiled matrix multiply
    pub const MAIN: &str = "main";
    /// Batched matrix multiply for attention projections
    pub const BATCHED: &str = "main_batched";
    /// Vector-matrix multiply for single token generation
    pub const GEMV: &str = "main_gemv";
}

/// Shader entry points for attention
pub mod attention {
    /// Standard multi-head attention
    pub const MAIN: &str = "main";
    /// Grouped query attention (GQA)
    pub const GQA: &str = "main_gqa";
    /// Single token decode attention
    pub const DECODE: &str = "main_decode";
}

/// Shader entry points for normalization
pub mod norm {
    /// RMSNorm (Llama-style)
    pub const RMS_NORM: &str = "rms_norm";
    /// RMSNorm with fused residual connection
    pub const RMS_NORM_RESIDUAL: &str = "rms_norm_residual";
    /// Standard LayerNorm
    pub const LAYER_NORM: &str = "layer_norm";
    /// Fast RMSNorm for small dimensions
    pub const RMS_NORM_SMALL: &str = "rms_norm_small";
}

/// Shader entry points for softmax
pub mod softmax {
    /// Standard row-wise softmax
    pub const MAIN: &str = "softmax"
    ;
    /// In-place softmax
    pub const INPLACE: &str = "softmax_inplace";
    /// Small dimension softmax
    pub const SMALL: &str = "softmax_small";
    /// Log softmax for loss computation
    pub const LOG_SOFTMAX: &str = "log_softmax";
}
|
||||
|
||||
/// Shader module wrapper for wasm-bindgen
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// A named WGSL shader plus the list of its `@compute` entry points, exposed
/// to JS through the getters in the `impl` block below.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct ShaderModule {
    // Shader identifier, e.g. "matmul"; exposed via the `name` getter.
    name: String,
    // Full WGSL source text; exposed via the `source` getter.
    source: String,
    // Entry-point function names defined in `source`.
    entry_points: Vec<String>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ShaderModule {
|
||||
/// Get the matrix multiplication shader module
|
||||
#[wasm_bindgen(js_name = matmul)]
|
||||
pub fn get_matmul() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "matmul".to_string(),
|
||||
source: MATMUL_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
matmul::MAIN.to_string(),
|
||||
matmul::BATCHED.to_string(),
|
||||
matmul::GEMV.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the attention shader module
|
||||
#[wasm_bindgen(js_name = attention)]
|
||||
pub fn get_attention() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "attention".to_string(),
|
||||
source: ATTENTION_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
attention::MAIN.to_string(),
|
||||
attention::GQA.to_string(),
|
||||
attention::DECODE.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the normalization shader module
|
||||
#[wasm_bindgen(js_name = norm)]
|
||||
pub fn get_norm() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "norm".to_string(),
|
||||
source: NORM_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
norm::RMS_NORM.to_string(),
|
||||
norm::RMS_NORM_RESIDUAL.to_string(),
|
||||
norm::LAYER_NORM.to_string(),
|
||||
norm::RMS_NORM_SMALL.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the softmax shader module
|
||||
#[wasm_bindgen(js_name = softmax)]
|
||||
pub fn get_softmax() -> ShaderModule {
|
||||
ShaderModule {
|
||||
name: "softmax".to_string(),
|
||||
source: SOFTMAX_SHADER.to_string(),
|
||||
entry_points: vec![
|
||||
softmax::MAIN.to_string(),
|
||||
softmax::INPLACE.to_string(),
|
||||
softmax::SMALL.to_string(),
|
||||
softmax::LOG_SOFTMAX.to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get shader name
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
/// Get shader source code
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn source(&self) -> String {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
/// Get available entry points
|
||||
#[wasm_bindgen(getter, js_name = entryPoints)]
|
||||
pub fn entry_points(&self) -> Vec<String> {
|
||||
self.entry_points.clone()
|
||||
}
|
||||
|
||||
/// Check if an entry point exists
|
||||
#[wasm_bindgen(js_name = hasEntryPoint)]
|
||||
pub fn has_entry_point(&self, name: &str) -> bool {
|
||||
self.entry_points.iter().any(|ep| ep == name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all available shader modules
|
||||
#[wasm_bindgen(js_name = getAllShaderModules)]
|
||||
pub fn get_all_shader_modules() -> Vec<ShaderModule> {
|
||||
vec![
|
||||
ShaderModule::get_matmul(),
|
||||
ShaderModule::get_attention(),
|
||||
ShaderModule::get_norm(),
|
||||
ShaderModule::get_softmax(),
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Embedded sources must be non-empty, otherwise include_str! picked up
    // empty/missing .wgsl files.
    #[test]
    fn test_shader_sources_not_empty() {
        assert!(!MATMUL_SHADER.is_empty());
        assert!(!ATTENTION_SHADER.is_empty());
        assert!(!NORM_SHADER.is_empty());
        assert!(!SOFTMAX_SHADER.is_empty());
    }

    // Accessor wiring: name and entry-point list of the matmul module.
    #[test]
    fn test_shader_module_creation() {
        let matmul = ShaderModule::get_matmul();
        assert_eq!(matmul.name(), "matmul");
        assert!(matmul.has_entry_point("main"));
        assert!(matmul.has_entry_point("main_batched"));
    }

    // One module per shader family.
    #[test]
    fn test_all_shader_modules() {
        let modules = get_all_shader_modules();
        assert_eq!(modules.len(), 4);
    }
}
|
||||
283
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/attention.wgsl
vendored
Normal file
283
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/attention.wgsl
vendored
Normal file
@@ -0,0 +1,283 @@
|
||||
// Flash Attention Shader for WebGPU WASM
|
||||
//
|
||||
// Implements memory-efficient attention using online softmax algorithm.
|
||||
// Supports causal masking for autoregressive generation.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Process Q in blocks, streaming K and V
|
||||
// 2. Maintain running max and sum for numerical stability
|
||||
// 3. Rescale outputs on-the-fly (Flash Attention v2)
|
||||
// 4. O(n) memory vs O(n^2) for standard attention
|
||||
//
|
||||
// Memory Layout:
|
||||
// - Q: (seq_len, num_heads, head_dim)
|
||||
// - K: (seq_len, num_heads, head_dim)
|
||||
// - V: (seq_len, num_heads, head_dim)
|
||||
// - Output: (seq_len, num_heads, head_dim)
|
||||
|
||||
const BLOCK_SIZE: u32 = 32u; // Reduced for WebGPU limits
const MAX_HEAD_DIM: u32 = 128u;

// Must match the host-side uniform layout (AttentionConfig on the Rust side).
struct AttentionUniforms {
    seq_len: u32,
    head_dim: u32,
    num_heads: u32,
    scale: f32, // 1/sqrt(head_dim)
    causal_mask: u32, // 1 for causal, 0 for full attention
    kv_seq_len: u32, // For encoder-decoder or prefill
    _pad0: u32,
    _pad1: u32,
}

@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<f32>;
@group(0) @binding(2) var<storage, read> V: array<f32>;
@group(0) @binding(3) var<storage, read_write> Output: array<f32>;
@group(0) @binding(4) var<uniform> uniforms: AttentionUniforms;

// Shared memory for blocks (sized BLOCK_SIZE * MAX_HEAD_DIM = 4096 floats).
var<workgroup> Q_shared: array<f32, 4096>; // BLOCK_SIZE * MAX_HEAD_DIM
var<workgroup> K_shared: array<f32, 4096>;
var<workgroup> V_shared: array<f32, 4096>;
var<workgroup> scores_shared: array<f32, 1024>; // BLOCK_SIZE * BLOCK_SIZE

// Thread-local state for online softmax (sized for MAX_HEAD_DIM).
var<private> m_i: f32; // Running max
var<private> l_i: f32; // Running sum
var<private> o_i: array<f32, 128>; // Output accumulator
|
||||
|
||||
// Flash-attention forward pass: one workgroup per (Q block, head); each of the
// 32 threads owns one query row and keeps its online-softmax state in
// private memory (m_i / l_i / o_i).
@compute @workgroup_size(32, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let seq_len = uniforms.seq_len;
    let head_dim = uniforms.head_dim;
    let num_heads = uniforms.num_heads;
    let scale = uniforms.scale;
    let is_causal = uniforms.causal_mask == 1u;
    let kv_seq_len = uniforms.kv_seq_len;

    // This workgroup handles one Q block for one head
    let head_idx = group_id.y;
    let q_block_idx = group_id.x;
    let q_start = q_block_idx * BLOCK_SIZE;

    let thread_id = local_id.x;
    let hidden_stride = num_heads * head_dim;

    // Initialize online softmax state
    m_i = -1e10f;
    l_i = 0.0f;
    for (var d = 0u; d < head_dim; d++) {
        o_i[d] = 0.0f;
    }

    // Load Q block into shared memory (one query row per thread).
    let q_pos = q_start + thread_id;
    if (q_pos < seq_len && thread_id < BLOCK_SIZE) {
        for (var d = 0u; d < head_dim; d++) {
            let q_idx = q_pos * hidden_stride + head_idx * head_dim + d;
            Q_shared[thread_id * head_dim + d] = Q[q_idx];
        }
    }
    workgroupBarrier();

    // Iterate over K/V blocks, streaming them through shared memory.
    let num_kv_blocks = (kv_seq_len + BLOCK_SIZE - 1u) / BLOCK_SIZE;

    for (var kv_block = 0u; kv_block < num_kv_blocks; kv_block++) {
        let kv_start = kv_block * BLOCK_SIZE;

        // Early exit for causal attention. The condition depends only on
        // workgroup-level values, so the break is uniform and the barriers
        // below stay in uniform control flow. Blocks starting exactly at
        // q_start + BLOCK_SIZE are still processed but fully masked.
        if (is_causal && kv_start > q_start + BLOCK_SIZE) {
            break;
        }

        // Load K block
        let k_pos = kv_start + thread_id;
        if (k_pos < kv_seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let k_idx = k_pos * hidden_stride + head_idx * head_dim + d;
                K_shared[thread_id * head_dim + d] = K[k_idx];
            }
        }

        // Load V block
        let v_pos = kv_start + thread_id;
        if (v_pos < kv_seq_len && thread_id < BLOCK_SIZE) {
            for (var d = 0u; d < head_dim; d++) {
                let v_idx = v_pos * hidden_stride + head_idx * head_dim + d;
                V_shared[thread_id * head_dim + d] = V[v_idx];
            }
        }
        workgroupBarrier();

        // Compute attention scores and update online softmax state for this
        // thread's query row.
        if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
            let kv_block_len = min(BLOCK_SIZE, kv_seq_len - kv_start);

            // Row max over this block (masked positions stay at -1e10).
            var block_max = -1e10f;
            var local_scores: array<f32, 32>;

            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;

                // Apply causal mask
                if (is_causal && k_global > q_pos) {
                    local_scores[k] = -1e10f;
                    continue;
                }

                // Compute Q[q_pos] dot K[k]
                var score = 0.0f;
                for (var d = 0u; d < head_dim; d++) {
                    score += Q_shared[thread_id * head_dim + d] * K_shared[k * head_dim + d];
                }
                score *= scale;
                local_scores[k] = score;
                block_max = max(block_max, score);
            }

            // New running max across all blocks seen so far.
            let m_ij = max(m_i, block_max);

            // Rescale previous accumulator and normalizer to the new max
            // (Flash Attention v2 on-the-fly rescaling).
            let alpha = exp(m_i - m_ij);
            for (var d = 0u; d < head_dim; d++) {
                o_i[d] *= alpha;
            }
            l_i *= alpha;

            // Accumulate exp-weighted V for this block; masked positions are
            // skipped so they contribute neither to l_i nor o_i.
            for (var k = 0u; k < kv_block_len; k++) {
                let k_global = kv_start + k;
                if (is_causal && k_global > q_pos) {
                    continue;
                }

                let p_ij = exp(local_scores[k] - m_ij);
                l_i += p_ij;

                for (var d = 0u; d < head_dim; d++) {
                    o_i[d] += p_ij * V_shared[k * head_dim + d];
                }
            }

            m_i = m_ij;
        }

        workgroupBarrier();
    }

    // Normalize by the softmax denominator and write output; l_i == 0 (fully
    // masked row) writes zeros instead of NaN.
    if (thread_id < BLOCK_SIZE && q_pos < seq_len) {
        let inv_l = select(1.0f / l_i, 0.0f, l_i == 0.0f);

        for (var d = 0u; d < head_dim; d++) {
            let out_idx = q_pos * hidden_stride + head_idx * head_dim + d;
            Output[out_idx] = o_i[d] * inv_l;
        }
    }
}
|
||||
|
||||
// Grouped Query Attention (GQA) variant
// Multiple Q heads share same K/V heads
//
// NOTE(review): this entry point is an EMPTY STUB — dispatching it leaves
// Output untouched. Do not route GQA models here until it is implemented.
@compute @workgroup_size(32, 1, 1)
fn main_gqa(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    // For GQA: kv_head_idx = q_head_idx / num_q_per_kv
    // This allows Llama2/3 style grouped attention
    // Implementation similar to main() with modified indexing
}
|
||||
|
||||
// Single token attention for generation phase (seq_len == 1 decoding).
//
// FIX: the previous version partitioned the KV sequence across the 256
// threads but then wrote only thread 0's PARTIAL accumulator to Output —
// there was no cross-thread reduction, so the contributions of threads
// 1..255 (i.e. most of the KV cache) were silently dropped. Until a proper
// workgroup parallel reduction is implemented, thread 0 performs the full
// online-softmax scan so the result is correct (at reduced parallelism).
@compute @workgroup_size(256, 1, 1)
fn main_decode(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let head_dim = uniforms.head_dim;
    let num_heads = uniforms.num_heads;
    let scale = uniforms.scale;
    let kv_seq_len = uniforms.kv_seq_len;

    // One workgroup per head. With a single query token the causal mask is a
    // no-op (every cached position is attendable), so causal_mask is ignored.
    let head_idx = group_id.x;
    let thread_id = local_id.x;
    let hidden_stride = num_heads * head_dim;

    // No barriers are used below, so the non-uniform early return is safe.
    if (thread_id != 0u) {
        return;
    }

    // Load the single query vector for this head.
    var q_vec: array<f32, 128>;
    for (var d = 0u; d < head_dim; d++) {
        q_vec[d] = Q[head_idx * head_dim + d];
    }

    // Online-softmax accumulators (Flash-Attention style).
    var m = -1e10f; // running max
    var l = 0.0f;   // running normalizer
    var acc: array<f32, 128>;
    for (var d = 0u; d < head_dim; d++) {
        acc[d] = 0.0f;
    }

    // Stream the whole KV cache once, rescaling on the fly.
    for (var k_pos = 0u; k_pos < kv_seq_len; k_pos++) {
        // score = scale * (q . K[k_pos])
        var score = 0.0f;
        for (var d = 0u; d < head_dim; d++) {
            let k_idx = k_pos * hidden_stride + head_idx * head_dim + d;
            score += q_vec[d] * K[k_idx];
        }
        score *= scale;

        // Rescale previous state to the new running max, then fold in this
        // position's exp-weighted V row.
        let new_m = max(m, score);
        let alpha = exp(m - new_m);
        let p = exp(score - new_m);
        for (var d = 0u; d < head_dim; d++) {
            let v_idx = k_pos * hidden_stride + head_idx * head_dim + d;
            acc[d] = acc[d] * alpha + p * V[v_idx];
        }
        l = l * alpha + p;
        m = new_m;
    }

    // Normalize; an empty KV cache (l == 0) writes zeros instead of NaN.
    let inv_l = select(1.0f / l, 0.0f, l == 0.0f);
    for (var d = 0u; d < head_dim; d++) {
        Output[head_idx * head_dim + d] = acc[d] * inv_l;
    }
}
|
||||
182
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/matmul.wgsl
vendored
Normal file
182
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/matmul.wgsl
vendored
Normal file
@@ -0,0 +1,182 @@
|
||||
// Tiled matrix multiplication (C = A * B) for browser WebGPU.
//
// All matrices are row-major:
//   A: M x K,  B: K x N,  C: M x N (output).
// Each workgroup cooperatively stages 16x16 tiles of A and B in
// workgroup memory before accumulating, keeping global-memory traffic low.

// Tile edge length; 16*16 = 256 threads fits WebGPU workgroup limits.
const TILE_SIZE: u32 = 16u;

struct Uniforms {
    M: u32,     // rows of A / rows of C
    N: u32,     // cols of B / cols of C
    K: u32,     // cols of A / rows of B (reduction dimension)
    alpha: f32, // output scale factor (1.0 = no scaling)
}

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: Uniforms;

// Workgroup-shared staging buffers, one TILE_SIZE x TILE_SIZE tile each.
var<workgroup> A_tile: array<f32, 256>; // TILE_SIZE * TILE_SIZE
var<workgroup> B_tile: array<f32, 256>;
|
||||
|
||||
// Entry point: one thread per output element C[row, col]. 16x16
// workgroups march along the K dimension one tile at a time.
@compute @workgroup_size(16, 16, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let m = uniforms.M;
    let n = uniforms.N;
    let k_total = uniforms.K;
    let alpha = uniforms.alpha;

    // Output coordinates owned by this thread.
    let row = global_id.x;
    let col = global_id.y;

    // Position inside the 16x16 tile.
    let ty = local_id.x;
    let tx = local_id.y;

    // Running dot product for C[row, col].
    var acc = 0.0f;

    // ceil(K / TILE_SIZE) tiles along the reduction dimension.
    let tile_count = (k_total + TILE_SIZE - 1u) / TILE_SIZE;

    for (var t = 0u; t < tile_count; t++) {
        let k_base = t * TILE_SIZE;

        // Stage one element of A into shared memory (zero-pad out of range).
        let a_col = k_base + tx;
        if (row < m && a_col < k_total) {
            A_tile[ty * TILE_SIZE + tx] = A[row * k_total + a_col];
        } else {
            A_tile[ty * TILE_SIZE + tx] = 0.0;
        }

        // Stage one element of B likewise.
        let b_row = k_base + ty;
        if (b_row < k_total && col < n) {
            B_tile[ty * TILE_SIZE + tx] = B[b_row * n + col];
        } else {
            B_tile[ty * TILE_SIZE + tx] = 0.0;
        }

        // All 256 threads must finish staging before anyone reads the tiles.
        workgroupBarrier();

        // Partial dot product across this tile, clamped at the K boundary.
        let span = min(TILE_SIZE, k_total - k_base);
        for (var k = 0u; k < span; k++) {
            acc += A_tile[ty * TILE_SIZE + k] * B_tile[k * TILE_SIZE + tx];
        }

        // Tiles must not be overwritten while any thread is still reading.
        workgroupBarrier();
    }

    // Guarded store: threads past the matrix edge write nothing.
    if (row < m && col < n) {
        C[row * n + col] = acc * alpha;
    }
}
|
||||
|
||||
// Batched variant for multi-head attention projections:
//   C[b] = A[b] * B
// where A is batch_size x M x K and the weight matrix B (K x N) is
// shared across the batch. The batch index rides on workgroup z.
@compute @workgroup_size(16, 16, 1)
fn main_batched(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let m = uniforms.M;
    let n = uniforms.N;
    let k_total = uniforms.K;

    let batch_idx = group_id.z;
    let row = global_id.x;
    let col = global_id.y;

    let ty = local_id.x;
    let tx = local_id.y;

    var acc = 0.0f;
    let tile_count = (k_total + TILE_SIZE - 1u) / TILE_SIZE;

    // Per-batch base offsets into A and C; B has no batch dimension.
    let a_base = batch_idx * m * k_total;
    let c_base = batch_idx * m * n;

    for (var t = 0u; t < tile_count; t++) {
        let k_base = t * TILE_SIZE;

        // Stage a batched A element (zero-pad out of range).
        let a_col = k_base + tx;
        if (row < m && a_col < k_total) {
            A_tile[ty * TILE_SIZE + tx] = A[a_base + row * k_total + a_col];
        } else {
            A_tile[ty * TILE_SIZE + tx] = 0.0;
        }

        // Stage a B element (shared across the batch).
        let b_row = k_base + ty;
        if (b_row < k_total && col < n) {
            B_tile[ty * TILE_SIZE + tx] = B[b_row * n + col];
        } else {
            B_tile[ty * TILE_SIZE + tx] = 0.0;
        }

        workgroupBarrier();

        let span = min(TILE_SIZE, k_total - k_base);
        for (var k = 0u; k < span; k++) {
            acc += A_tile[ty * TILE_SIZE + k] * B_tile[k * TILE_SIZE + tx];
        }

        workgroupBarrier();
    }

    // NOTE: unlike `main`, the batched path applies no alpha scaling.
    if (row < m && col < n) {
        C[c_base + row * n + col] = acc;
    }
}
|
||||
|
||||
// GEMV fast path for single-token generation: y = x * W with the input
// vector x read from the A binding (1 x K) and the weight W from B
// (K x N). One thread owns one output column; no shared memory needed.
@compute @workgroup_size(256, 1, 1)
fn main_gemv(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let k_total = uniforms.K;
    let n = uniforms.N;

    let col = global_id.x;

    // Threads past the output width have no work.
    if (col >= n) {
        return;
    }

    // Full dot product of the input vector with column `col` of B.
    var acc = 0.0f;
    for (var k = 0u; k < k_total; k++) {
        acc += A[k] * B[k * n + col];
    }

    C[col] = acc * uniforms.alpha;
}
|
||||
235
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
vendored
Normal file
235
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/norm.wgsl
vendored
Normal file
@@ -0,0 +1,235 @@
|
||||
// RMSNorm / LayerNorm compute shaders for browser WebGPU.
//
//   RMSNorm:   y = x * rsqrt(mean(x^2) + eps) * weight        (Llama, Mistral)
//   LayerNorm: y = (x - mean) * rsqrt(var + eps) * weight + bias
//
// One workgroup of 256 threads normalizes one row (one batch element).

const WARP_SIZE: u32 = 32u;
const MAX_DIM: u32 = 8192u;

struct NormUniforms {
    hidden_dim: u32, // row length being normalized
    batch_size: u32, // number of rows
    eps: f32,        // numerical-stability epsilon
    _pad: u32,       // keeps the uniform block 16-byte aligned
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> uniforms: NormUniforms;

// Scratch space for the per-workgroup tree reductions.
var<workgroup> partial_sums: array<f32, 256>;
|
||||
|
||||
// RMSNorm: y = x * rsqrt(mean(x^2) + eps) * weight
// One workgroup handles one batch row; 256 threads share the reduction.
@compute @workgroup_size(256, 1, 1)
fn rms_norm(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * hidden_dim;

    // Strided work split: thread t handles indices t, t+256, t+512, ...
    let per_thread = (hidden_dim + 255u) / 256u;

    // Per-thread partial sum of squares.
    var sq_sum = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            let x = input[row_base + idx];
            sq_sum += x * x;
        }
    }

    partial_sums[tid] = sq_sum;
    workgroupBarrier();

    // Tree reduction: 256 partials -> 1 total in log2 steps.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        workgroupBarrier();
    }

    // Every thread reads the reduced total and derives the scale.
    let mean_sq = partial_sums[0] / f32(hidden_dim);
    let inv_rms = 1.0f / sqrt(mean_sq + eps);
    workgroupBarrier();

    // Normalize and apply the learned per-channel weight.
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            output[row_base + idx] = input[row_base + idx] * inv_rms * weight[idx];
        }
    }
}
|
||||
|
||||
// Fused residual-add + RMSNorm:
//   y = (x + residual) * rsqrt(mean((x + residual)^2) + eps) * weight
// The residual is read from (and the result written back to) `output`,
// so the add happens in place without an extra buffer.
@compute @workgroup_size(256, 1, 1)
fn rms_norm_residual(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * hidden_dim;

    let per_thread = (hidden_dim + 255u) / 256u;

    // Partial sum of squares of the fused (input + residual) values.
    var sq_sum = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            let fused = input[row_base + idx] + output[row_base + idx];
            sq_sum += fused * fused;
        }
    }

    partial_sums[tid] = sq_sum;
    workgroupBarrier();

    // Tree reduction to a single row total.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        workgroupBarrier();
    }

    let mean_sq = partial_sums[0] / f32(hidden_dim);
    let inv_rms = 1.0f / sqrt(mean_sq + eps);
    workgroupBarrier();

    // Recompute the fused value and write the normalized result in place.
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            let fused = input[row_base + idx] + output[row_base + idx];
            output[row_base + idx] = fused * inv_rms * weight[idx];
        }
    }
}
|
||||
|
||||
// Standard LayerNorm with a learned bias. The extra binding lives here
// because only this entry point needs it.
@group(0) @binding(4) var<storage, read> bias: array<f32>;

// LayerNorm: y = (x - mean) * rsqrt(var + eps) * weight + bias
// Three strided passes per row: mean, variance, then the affine write.
@compute @workgroup_size(256, 1, 1)
fn layer_norm(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * hidden_dim;

    let per_thread = (hidden_dim + 255u) / 256u;

    // Pass 1: mean of the row.
    var acc = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            acc += input[row_base + idx];
        }
    }

    partial_sums[tid] = acc;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        workgroupBarrier();
    }

    let mean = partial_sums[0] / f32(hidden_dim);
    workgroupBarrier();

    // Pass 2: variance around that mean.
    var var_acc = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            let diff = input[row_base + idx] - mean;
            var_acc += diff * diff;
        }
    }

    partial_sums[tid] = var_acc;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        workgroupBarrier();
    }

    let variance = partial_sums[0] / f32(hidden_dim);
    let inv_std = 1.0f / sqrt(variance + eps);
    workgroupBarrier();

    // Pass 3: normalize and apply the affine transform.
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < hidden_dim) {
            let x = input[row_base + idx];
            output[row_base + idx] = (x - mean) * inv_std * weight[idx] + bias[idx];
        }
    }
}
|
||||
|
||||
// Fast-path RMSNorm for small rows (hidden_dim <= 128): skips the
// shared-memory reduction entirely. Each active thread redundantly
// sums the whole row, then normalizes its own element.
@compute @workgroup_size(128, 1, 1)
fn rms_norm_small(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let hidden_dim = uniforms.hidden_dim;
    let eps = uniforms.eps;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * hidden_dim;

    // Threads beyond the row length have nothing to do.
    if (tid < hidden_dim) {
        // Redundant full-row sum of squares; cheap at this size and
        // avoids any barrier traffic.
        var sq_sum = 0.0f;
        for (var i = 0u; i < hidden_dim; i++) {
            let x = input[row_base + i];
            sq_sum += x * x;
        }

        let rms = sqrt(sq_sum / f32(hidden_dim) + eps);
        output[row_base + tid] = input[row_base + tid] / rms * weight[tid];
    }
}
|
||||
288
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/softmax.wgsl
vendored
Normal file
288
vendor/ruvector/crates/ruvllm-wasm/src/webgpu/shaders/softmax.wgsl
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
// Numerically stable softmax kernels for browser WebGPU.
//
//   y = exp(x - max(x)) / sum(exp(x - max(x)))
//
// Entry points cover the standard row-wise case, an in-place variant,
// a small-dimension fast path, and log-softmax. Temperature scaling is
// applied as x / temperature before the reductions.

const MAX_SEQ_LEN: u32 = 8192u;

struct SoftmaxUniforms {
    dim: u32,         // length of the axis being normalized
    batch_size: u32,  // number of rows
    temperature: f32, // sampling temperature (1.0 = standard softmax)
    top_k: u32,       // 0 = full softmax; reserved for top-k variants
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> uniforms: SoftmaxUniforms;

// Scratch space for per-workgroup max/sum reductions.
var<workgroup> reduction_buf: array<f32, 256>;
|
||||
|
||||
// Row-wise numerically stable softmax with temperature scaling.
// One workgroup per row; three phases: max, sum, normalize.
@compute @workgroup_size(256, 1, 1)
fn softmax(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * dim;

    let per_thread = (dim + 255u) / 256u;

    // Phase 1: row maximum (shift for numerical stability).
    var local_max = -1e10f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            local_max = max(local_max, input[row_base + idx] / temperature);
        }
    }

    reduction_buf[tid] = local_max;
    workgroupBarrier();

    // Tree max reduction.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] = max(reduction_buf[tid], reduction_buf[tid + stride]);
        }
        workgroupBarrier();
    }

    let row_max = reduction_buf[0];
    workgroupBarrier();

    // Phase 2: sum of shifted exponentials.
    var local_sum = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            local_sum += exp(input[row_base + idx] / temperature - row_max);
        }
    }

    reduction_buf[tid] = local_sum;
    workgroupBarrier();

    // Tree sum reduction.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] += reduction_buf[tid + stride];
        }
        workgroupBarrier();
    }

    let inv_sum = 1.0f / reduction_buf[0];
    workgroupBarrier();

    // Phase 3: write normalized probabilities.
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            output[row_base + idx] = exp(input[row_base + idx] / temperature - row_max) * inv_sum;
        }
    }
}
|
||||
|
||||
// In-place variant: the logits live in `output` and are overwritten
// with probabilities (the `input` binding is not read).
@compute @workgroup_size(256, 1, 1)
fn softmax_inplace(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * dim;

    let per_thread = (dim + 255u) / 256u;

    // Row maximum.
    var local_max = -1e10f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            local_max = max(local_max, output[row_base + idx] / temperature);
        }
    }

    reduction_buf[tid] = local_max;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] = max(reduction_buf[tid], reduction_buf[tid + stride]);
        }
        workgroupBarrier();
    }

    let row_max = reduction_buf[0];
    workgroupBarrier();

    // Exponentiate in place while accumulating the row sum.
    var local_sum = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            let e = exp(output[row_base + idx] / temperature - row_max);
            output[row_base + idx] = e; // keep exp(x - max) until normalization
            local_sum += e;
        }
    }

    reduction_buf[tid] = local_sum;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] += reduction_buf[tid + stride];
        }
        workgroupBarrier();
    }

    let inv_sum = 1.0f / reduction_buf[0];
    workgroupBarrier();

    // Final scale to probabilities.
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            output[row_base + idx] *= inv_sum;
        }
    }
}
|
||||
|
||||
// Fast path for dim <= 256: one element per thread, with serial (not
// tree) scans of the shared buffer for the max and the sum.
@compute @workgroup_size(256, 1, 1)
fn softmax_small(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * dim;

    // Out-of-range threads publish a very small value so they never
    // win the max scan.
    var x = -1e10f;
    if (tid < dim) {
        x = input[row_base + tid] / temperature;
    }

    reduction_buf[tid] = x;
    workgroupBarrier();

    // Serial max scan: every thread walks all `dim` entries.
    var row_max = x;
    for (var i = 0u; i < dim; i++) {
        row_max = max(row_max, reduction_buf[i]);
    }
    workgroupBarrier();

    // Shifted exponential for this thread's element.
    var e = 0.0f;
    if (tid < dim) {
        e = exp(x - row_max);
    }
    reduction_buf[tid] = e;
    workgroupBarrier();

    // Serial sum scan.
    var row_sum = 0.0f;
    for (var i = 0u; i < dim; i++) {
        row_sum += reduction_buf[i];
    }

    // Normalized write for in-range threads only.
    if (tid < dim) {
        output[row_base + tid] = e / row_sum;
    }
}
|
||||
|
||||
// Log-softmax: log(softmax(x)) = x - logsumexp(x), computed stably via
// the shifted form logsumexp(x) = log(sum(exp(x - max))) + max.
@compute @workgroup_size(256, 1, 1)
fn log_softmax(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) group_id: vec3<u32>,
) {
    let dim = uniforms.dim;
    let temperature = uniforms.temperature;

    let batch_idx = group_id.x;
    let tid = local_id.x;
    let row_base = batch_idx * dim;

    let per_thread = (dim + 255u) / 256u;

    // Row maximum.
    var local_max = -1e10f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            local_max = max(local_max, input[row_base + idx] / temperature);
        }
    }

    reduction_buf[tid] = local_max;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] = max(reduction_buf[tid], reduction_buf[tid + stride]);
        }
        workgroupBarrier();
    }

    let row_max = reduction_buf[0];
    workgroupBarrier();

    // Sum of shifted exponentials.
    var local_sum = 0.0f;
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            local_sum += exp(input[row_base + idx] / temperature - row_max);
        }
    }

    reduction_buf[tid] = local_sum;
    workgroupBarrier();

    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            reduction_buf[tid] += reduction_buf[tid + stride];
        }
        workgroupBarrier();
    }

    // logsumexp in the original (unshifted) frame.
    let log_sum = log(reduction_buf[0]) + row_max;
    workgroupBarrier();

    // Write x - logsumexp(x).
    for (var i = 0u; i < per_thread; i++) {
        let idx = tid + i * 256u;
        if (idx < dim) {
            output[row_base + idx] = input[row_base + idx] / temperature - log_sum;
        }
    }
}
|
||||
366
vendor/ruvector/crates/ruvllm-wasm/src/workers/feature_detect.rs
vendored
Normal file
366
vendor/ruvector/crates/ruvllm-wasm/src/workers/feature_detect.rs
vendored
Normal file
@@ -0,0 +1,366 @@
|
||||
//! Browser Feature Detection for Web Workers
|
||||
//!
|
||||
//! Detects availability of SharedArrayBuffer, Atomics, and other
|
||||
//! features required for parallel inference.
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen::JsCast;
|
||||
|
||||
/// Check if SharedArrayBuffer is available.
|
||||
///
|
||||
/// SharedArrayBuffer is required for zero-copy memory sharing between
|
||||
/// the main thread and Web Workers.
|
||||
///
|
||||
/// # Notes
|
||||
/// - SharedArrayBuffer was temporarily disabled in all browsers after
|
||||
/// Spectre/Meltdown vulnerabilities were discovered.
|
||||
/// - It's now available again, but requires cross-origin isolation:
|
||||
/// - `Cross-Origin-Opener-Policy: same-origin`
|
||||
/// - `Cross-Origin-Embedder-Policy: require-corp`
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if SharedArrayBuffer is available, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn is_shared_array_buffer_available() -> bool {
|
||||
// Try to access SharedArrayBuffer constructor
|
||||
let global = js_sys::global();
|
||||
|
||||
if let Ok(sab) = js_sys::Reflect::get(&global, &JsValue::from_str("SharedArrayBuffer")) {
|
||||
if !sab.is_undefined() && !sab.is_null() {
|
||||
// Try to create a small SharedArrayBuffer to verify it's actually usable
|
||||
match js_sys::SharedArrayBuffer::new(8) {
|
||||
_ => return true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if Atomics API is available.
|
||||
///
|
||||
/// Atomics provides atomic operations for synchronization between
|
||||
/// the main thread and Web Workers.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if Atomics is available, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn is_atomics_available() -> bool {
|
||||
let global = js_sys::global();
|
||||
|
||||
if let Ok(atomics) = js_sys::Reflect::get(&global, &JsValue::from_str("Atomics")) {
|
||||
if !atomics.is_undefined() && !atomics.is_null() {
|
||||
// Verify Atomics.wait and Atomics.notify are available
|
||||
if let Ok(wait) = js_sys::Reflect::get(&atomics, &JsValue::from_str("wait")) {
|
||||
if let Ok(notify) = js_sys::Reflect::get(&atomics, &JsValue::from_str("notify")) {
|
||||
return !wait.is_undefined() && !notify.is_undefined();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if the page is cross-origin isolated.
|
||||
///
|
||||
/// Cross-origin isolation is required for SharedArrayBuffer to work.
|
||||
/// The page must be served with:
|
||||
/// - `Cross-Origin-Opener-Policy: same-origin`
|
||||
/// - `Cross-Origin-Embedder-Policy: require-corp`
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if cross-origin isolated, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn cross_origin_isolated() -> bool {
|
||||
if let Some(window) = web_sys::window() {
|
||||
// crossOriginIsolated is a boolean property on Window
|
||||
if let Ok(isolated) =
|
||||
js_sys::Reflect::get(&window, &JsValue::from_str("crossOriginIsolated"))
|
||||
{
|
||||
return isolated.as_bool().unwrap_or(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Also check in worker context
|
||||
let global = js_sys::global();
|
||||
if let Ok(isolated) = js_sys::Reflect::get(&global, &JsValue::from_str("crossOriginIsolated")) {
|
||||
return isolated.as_bool().unwrap_or(false);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if Web Workers are available.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if Web Workers are available, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn is_web_workers_available() -> bool {
|
||||
let global = js_sys::global();
|
||||
|
||||
if let Ok(worker) = js_sys::Reflect::get(&global, &JsValue::from_str("Worker")) {
|
||||
return !worker.is_undefined() && !worker.is_null();
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Get the optimal number of workers based on hardware concurrency.
|
||||
///
|
||||
/// Uses `navigator.hardwareConcurrency` if available, otherwise falls
|
||||
/// back to a reasonable default.
|
||||
///
|
||||
/// # Notes
|
||||
/// - Caps the result at MAX_WORKERS to prevent resource exhaustion.
|
||||
/// - Leaves at least 1 core for the main thread.
|
||||
/// - Falls back to 4 if hardware concurrency is not available.
|
||||
///
|
||||
/// # Returns
|
||||
/// Recommended number of workers.
|
||||
#[wasm_bindgen]
|
||||
pub fn optimal_worker_count() -> usize {
|
||||
const MAX_WORKERS: usize = 16;
|
||||
const MIN_WORKERS: usize = 2;
|
||||
const DEFAULT_WORKERS: usize = 4;
|
||||
|
||||
if let Some(window) = web_sys::window() {
|
||||
let navigator = window.navigator();
|
||||
// hardwareConcurrency returns the number of logical processors
|
||||
let cores = navigator.hardware_concurrency() as usize;
|
||||
if cores > 0 {
|
||||
// Leave at least 1 core for main thread
|
||||
// Cap at MAX_WORKERS
|
||||
return (cores.saturating_sub(1)).clamp(MIN_WORKERS, MAX_WORKERS);
|
||||
}
|
||||
}
|
||||
|
||||
// Check in worker global scope
|
||||
let global = js_sys::global();
|
||||
if let Ok(navigator) = js_sys::Reflect::get(&global, &JsValue::from_str("navigator")) {
|
||||
if !navigator.is_undefined() {
|
||||
if let Ok(cores) =
|
||||
js_sys::Reflect::get(&navigator, &JsValue::from_str("hardwareConcurrency"))
|
||||
{
|
||||
if let Some(c) = cores.as_f64() {
|
||||
let cores = c as usize;
|
||||
if cores > 0 {
|
||||
return (cores.saturating_sub(1)).clamp(MIN_WORKERS, MAX_WORKERS);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_WORKERS
|
||||
}
|
||||
|
||||
/// Check if SIMD (WebAssembly SIMD) is available.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if WASM SIMD is available, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn is_simd_available() -> bool {
|
||||
// This is checked at compile time in Rust
|
||||
#[cfg(target_feature = "simd128")]
|
||||
{
|
||||
true
|
||||
}
|
||||
#[cfg(not(target_feature = "simd128"))]
|
||||
{
|
||||
// Runtime check using WebAssembly.validate
|
||||
let global = js_sys::global();
|
||||
if let Ok(wasm) = js_sys::Reflect::get(&global, &JsValue::from_str("WebAssembly")) {
|
||||
if !wasm.is_undefined() {
|
||||
if let Ok(validate) = js_sys::Reflect::get(&wasm, &JsValue::from_str("validate")) {
|
||||
if validate.is_function() {
|
||||
// SIMD test module (v128.const)
|
||||
let simd_test: [u8; 14] = [
|
||||
0x00, 0x61, 0x73, 0x6d, // magic
|
||||
0x01, 0x00, 0x00, 0x00, // version
|
||||
0x01, 0x05, 0x01, 0x60, // type section
|
||||
0x00, 0x01, // func type () -> v128
|
||||
];
|
||||
|
||||
let arr = js_sys::Uint8Array::from(&simd_test[..]);
|
||||
let validate_fn: js_sys::Function = validate.unchecked_into();
|
||||
if let Ok(result) = validate_fn.call1(&JsValue::NULL, &arr) {
|
||||
return result.as_bool().unwrap_or(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if BigInt is available.
|
||||
///
|
||||
/// BigInt is useful for 64-bit integer operations.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if BigInt is available, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn is_bigint_available() -> bool {
|
||||
let global = js_sys::global();
|
||||
|
||||
if let Ok(bigint) = js_sys::Reflect::get(&global, &JsValue::from_str("BigInt")) {
|
||||
return !bigint.is_undefined() && !bigint.is_null();
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if Transferable objects are available.
|
||||
///
|
||||
/// Transferable objects (ArrayBuffer, MessagePort, etc.) can be
|
||||
/// transferred to workers without copying.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if Transferable objects are available, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn is_transferable_available() -> bool {
|
||||
// Transferable is supported in all modern browsers
|
||||
// Try to create an ArrayBuffer which is always transferable
|
||||
let buffer = js_sys::ArrayBuffer::new(8);
|
||||
let global = js_sys::global();
|
||||
|
||||
if let Ok(post_message) = js_sys::Reflect::get(&global, &JsValue::from_str("postMessage")) {
|
||||
if post_message.is_function() {
|
||||
// If we can create ArrayBuffer and postMessage exists, transferable is supported
|
||||
return !buffer.is_undefined();
|
||||
}
|
||||
}
|
||||
|
||||
// Also check window.postMessage
|
||||
if let Some(window) = web_sys::window() {
|
||||
// postMessage is available
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Get a summary of all available features.
|
||||
///
|
||||
/// # Returns
|
||||
/// JSON string with feature availability.
|
||||
#[wasm_bindgen]
|
||||
pub fn feature_summary() -> String {
|
||||
let features = serde_json::json!({
|
||||
"shared_array_buffer": is_shared_array_buffer_available(),
|
||||
"atomics": is_atomics_available(),
|
||||
"cross_origin_isolated": cross_origin_isolated(),
|
||||
"web_workers": is_web_workers_available(),
|
||||
"simd": is_simd_available(),
|
||||
"bigint": is_bigint_available(),
|
||||
"transferable": is_transferable_available(),
|
||||
"optimal_workers": optimal_worker_count(),
|
||||
});
|
||||
|
||||
serde_json::to_string_pretty(&features).unwrap_or_else(|_| "{}".to_string())
|
||||
}
|
||||
|
||||
/// Browser capability level for parallel inference.
///
/// Produced by [`detect_capability_level`] from the individual feature
/// probes (SharedArrayBuffer, Atomics, Web Workers, cross-origin isolation).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapabilityLevel {
    /// Full parallel capability with shared memory
    Full,
    /// Partial capability - workers available but no shared memory
    Partial,
    /// No parallel capability - single-threaded only
    None,
}
|
||||
|
||||
/// Determine the capability level for parallel inference.
|
||||
///
|
||||
/// # Returns
|
||||
/// The capability level based on available features.
|
||||
#[wasm_bindgen]
|
||||
pub fn detect_capability_level() -> String {
|
||||
let level = if is_shared_array_buffer_available()
|
||||
&& is_atomics_available()
|
||||
&& is_web_workers_available()
|
||||
&& cross_origin_isolated()
|
||||
{
|
||||
CapabilityLevel::Full
|
||||
} else if is_web_workers_available() {
|
||||
CapabilityLevel::Partial
|
||||
} else {
|
||||
CapabilityLevel::None
|
||||
};
|
||||
|
||||
match level {
|
||||
CapabilityLevel::Full => "full".to_string(),
|
||||
CapabilityLevel::Partial => "partial".to_string(),
|
||||
CapabilityLevel::None => "none".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the environment supports parallel inference.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `require_shared_memory` - Whether to require SharedArrayBuffer
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if parallel inference is supported, `false` otherwise.
|
||||
#[wasm_bindgen]
|
||||
pub fn supports_parallel_inference(require_shared_memory: bool) -> bool {
|
||||
if !is_web_workers_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if require_shared_memory {
|
||||
is_shared_array_buffer_available() && is_atomics_available() && cross_origin_isolated()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a message explaining why parallel inference is not available.
|
||||
///
|
||||
/// # Returns
|
||||
/// Explanation string, or empty string if parallel inference is available.
|
||||
#[wasm_bindgen]
|
||||
pub fn parallel_inference_unavailable_reason() -> String {
|
||||
if !is_web_workers_available() {
|
||||
return "Web Workers are not available in this environment.".to_string();
|
||||
}
|
||||
|
||||
if !is_shared_array_buffer_available() {
|
||||
return "SharedArrayBuffer is not available. This may be due to missing cross-origin isolation headers.".to_string();
|
||||
}
|
||||
|
||||
if !is_atomics_available() {
|
||||
return "Atomics API is not available.".to_string();
|
||||
}
|
||||
|
||||
if !cross_origin_isolated() {
|
||||
return "Page is not cross-origin isolated. Required headers:\n\
|
||||
- Cross-Origin-Opener-Policy: same-origin\n\
|
||||
- Cross-Origin-Embedder-Policy: require-corp"
|
||||
.to_string();
|
||||
}
|
||||
|
||||
String::new()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity checks only: the underlying feature probes depend on the host
    // (WASM browser vs native test runner), so we assert on the contract of
    // the returned values rather than on specific feature availability.

    #[test]
    fn test_capability_level() {
        // These tests will behave differently in WASM vs native
        let level = detect_capability_level();
        assert!(level == "full" || level == "partial" || level == "none");
    }

    #[test]
    fn test_feature_summary() {
        // The summary must at least contain the well-known JSON keys.
        let summary = feature_summary();
        assert!(summary.contains("shared_array_buffer"));
        assert!(summary.contains("optimal_workers"));
    }
}
||||
633
vendor/ruvector/crates/ruvllm-wasm/src/workers/messages.rs
vendored
Normal file
633
vendor/ruvector/crates/ruvllm-wasm/src/workers/messages.rs
vendored
Normal file
@@ -0,0 +1,633 @@
|
||||
//! Message Protocol for Web Worker Communication
|
||||
//!
|
||||
//! Defines the message types used for communication between the main thread
|
||||
//! and Web Workers, including task definitions and responses.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Unique identifier for a task.
///
/// IDs are handed out monotonically starting from 1 by [`TaskQueue::next_id`].
pub type TaskId = u64;
|
||||
|
||||
/// Message sent from main thread to worker.
///
/// Serialized as internally-tagged JSON (`#[serde(tag = "type")]`), so each
/// message carries a `"type"` field naming the variant. The `Compute*`
/// variants address data by byte/element offsets into a shared buffer;
/// `ComputeWithData` is the copy-based fallback when shared memory is
/// unavailable.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerMessage {
    /// Initialize the worker with configuration.
    Initialize {
        /// Worker ID
        worker_id: usize,
        /// Total number of workers
        total_workers: usize,
        /// Whether shared memory is available
        shared_memory: bool,
    },

    /// Matrix multiplication task.
    ComputeMatmul {
        /// Unique task ID
        task_id: TaskId,
        /// Offset into shared buffer for matrix A
        a_offset: usize,
        /// Offset into shared buffer for matrix B
        b_offset: usize,
        /// Offset into shared buffer for output matrix C
        c_offset: usize,
        /// Number of rows in A (and C)
        m: usize,
        /// Number of columns in B (and C)
        n: usize,
        /// Number of columns in A / rows in B
        k: usize,
        /// Starting row for this worker's chunk
        row_start: usize,
        /// Ending row (exclusive) for this worker's chunk
        row_end: usize,
    },

    /// Attention computation task.
    ComputeAttention {
        /// Unique task ID
        task_id: TaskId,
        /// Offset into shared buffer for Q
        q_offset: usize,
        /// Offset into shared buffer for K
        k_offset: usize,
        /// Offset into shared buffer for V
        v_offset: usize,
        /// Offset into shared buffer for output
        output_offset: usize,
        /// Number of heads to process (head_start to head_end)
        head_start: usize,
        /// Ending head (exclusive)
        head_end: usize,
        /// Total number of heads
        num_heads: usize,
        /// Head dimension
        head_dim: usize,
        /// Sequence length
        seq_len: usize,
    },

    /// Layer normalization task.
    ComputeNorm {
        /// Unique task ID
        task_id: TaskId,
        /// Offset into shared buffer for input
        input_offset: usize,
        /// Offset into shared buffer for output
        output_offset: usize,
        /// Offset for gamma (scale) parameters
        gamma_offset: usize,
        /// Offset for beta (shift) parameters
        beta_offset: usize,
        /// Hidden dimension
        hidden_dim: usize,
        /// Starting batch index
        batch_start: usize,
        /// Ending batch index (exclusive)
        batch_end: usize,
        /// Epsilon for numerical stability
        epsilon: f32,
    },

    /// Softmax computation task.
    ComputeSoftmax {
        /// Unique task ID
        task_id: TaskId,
        /// Offset into shared buffer for input/output
        data_offset: usize,
        /// Dimension along which to compute softmax
        dim_size: usize,
        /// Starting index
        start: usize,
        /// Ending index (exclusive)
        end: usize,
    },

    /// Element-wise operation task.
    ComputeElementwise {
        /// Unique task ID
        task_id: TaskId,
        /// Operation type
        operation: ElementwiseOp,
        /// Offset for first input
        a_offset: usize,
        /// Offset for second input (optional for unary ops)
        b_offset: Option<usize>,
        /// Offset for output
        output_offset: usize,
        /// Starting index
        start: usize,
        /// Ending index (exclusive)
        end: usize,
        /// Scalar value (for scalar ops)
        scalar: Option<f32>,
    },

    /// Reduction operation task.
    ComputeReduce {
        /// Unique task ID
        task_id: TaskId,
        /// Operation type
        operation: ReduceOp,
        /// Offset for input
        input_offset: usize,
        /// Offset for partial result
        partial_offset: usize,
        /// Starting index
        start: usize,
        /// Ending index (exclusive)
        end: usize,
    },

    /// Generic task with data copied via message (fallback mode).
    ComputeWithData {
        /// Unique task ID
        task_id: TaskId,
        /// Operation type
        operation: OperationType,
        /// Input data A
        data_a: Vec<f32>,
        /// Input data B (optional)
        data_b: Option<Vec<f32>>,
        /// Operation parameters
        params: OperationParams,
        /// Starting chunk index
        chunk_start: usize,
        /// Ending chunk index (exclusive)
        chunk_end: usize,
    },

    /// Ping message for health check.
    Ping {
        /// Timestamp in milliseconds
        timestamp: f64,
    },

    /// Shutdown the worker.
    Shutdown,
}
|
||||
|
||||
/// Message sent from worker to main thread.
///
/// Serialized as internally-tagged JSON (`#[serde(tag = "type")]`), mirroring
/// [`WorkerMessage`]. `TaskComplete` carries no payload (presumably results
/// live in the shared buffer — confirm against the worker implementation),
/// while `TaskCompleteWithData` carries the result inline for fallback mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerResponse {
    /// Worker has been initialized.
    Initialized {
        /// Worker ID
        worker_id: usize,
        /// Capabilities
        capabilities: WorkerCapabilities,
    },

    /// Task completed successfully.
    TaskComplete {
        /// Task ID
        task_id: TaskId,
        /// Duration in milliseconds
        duration_ms: f64,
        /// Optional metrics
        metrics: Option<TaskMetrics>,
    },

    /// Task completed with result data (fallback mode).
    TaskCompleteWithData {
        /// Task ID
        task_id: TaskId,
        /// Result data
        data: Vec<f32>,
        /// Duration in milliseconds
        duration_ms: f64,
    },

    /// Task failed.
    Error {
        /// Task ID
        task_id: TaskId,
        /// Error message
        message: String,
        /// Error code
        code: ErrorCode,
    },

    /// Pong response to ping.
    Pong {
        /// Worker ID
        worker_id: usize,
        /// Original timestamp
        timestamp: f64,
        /// Worker's current timestamp
        worker_timestamp: f64,
    },

    /// Worker is shutting down.
    ShuttingDown {
        /// Worker ID
        worker_id: usize,
    },
}
|
||||
|
||||
/// Worker capabilities reported during initialization.
///
/// Sent back in [`WorkerResponse::Initialized`]. All flags default to
/// `false` via `Default`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkerCapabilities {
    /// SIMD support available
    pub simd: bool,
    /// SharedArrayBuffer support
    pub shared_memory: bool,
    /// Atomics support
    pub atomics: bool,
    /// BigInt support
    pub bigint: bool,
}
|
||||
|
||||
/// Metrics from task execution.
///
/// Attached optionally to [`WorkerResponse::TaskComplete`]; all counters
/// default to zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskMetrics {
    /// Number of floating point operations
    pub flops: u64,
    /// Bytes read
    pub bytes_read: u64,
    /// Bytes written
    pub bytes_written: u64,
    /// Cache hits (if applicable)
    pub cache_hits: u64,
    /// Cache misses (if applicable)
    pub cache_misses: u64,
}
|
||||
|
||||
/// Element-wise operations.
///
/// Used by [`WorkerMessage::ComputeElementwise`]: binary ops read both
/// `a_offset` and `b_offset`, unary ops read only `a_offset`, and the
/// `*Scalar` variants take their second operand from the `scalar` field.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ElementwiseOp {
    /// Addition
    Add,
    /// Subtraction
    Sub,
    /// Multiplication
    Mul,
    /// Division
    Div,
    /// Maximum
    Max,
    /// Minimum
    Min,
    /// Power
    Pow,
    /// Exponential
    Exp,
    /// Natural logarithm
    Log,
    /// Square root
    Sqrt,
    /// Absolute value
    Abs,
    /// Negation
    Neg,
    /// ReLU activation
    Relu,
    /// GeLU activation
    Gelu,
    /// SiLU (Swish) activation
    Silu,
    /// Tanh activation
    Tanh,
    /// Sigmoid activation
    Sigmoid,
    /// Add scalar (operand from the `scalar` field)
    AddScalar,
    /// Multiply by scalar (operand from the `scalar` field)
    MulScalar,
}
|
||||
|
||||
/// Reduction operations.
///
/// Used by [`WorkerMessage::ComputeReduce`]; each worker writes its partial
/// result at `partial_offset`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum reduction
    Sum,
    /// Mean reduction
    Mean,
    /// Max reduction
    Max,
    /// Min reduction
    Min,
    /// Product reduction
    Prod,
    /// Sum of squares
    SumSq,
    /// L2 norm
    Norm2,
}
|
||||
|
||||
/// Operation type for generic tasks.
///
/// Discriminates the payload of [`WorkerMessage::ComputeWithData`] and is
/// recorded on [`PendingTask`] for bookkeeping.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum OperationType {
    /// Matrix multiplication
    Matmul,
    /// Attention computation
    Attention,
    /// Layer normalization
    LayerNorm,
    /// Softmax
    Softmax,
    /// Element-wise
    Elementwise,
    /// Reduction
    Reduce,
}
|
||||
|
||||
/// Parameters for generic operations.
///
/// Construct via [`OperationParams::matmul`], [`OperationParams::attention`],
/// or [`OperationParams::layer_norm`] so the `dims` / `extra` layout stays
/// consistent per operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationParams {
    /// Matrix dimensions [m, n, k] for matmul
    pub dims: Vec<usize>,
    /// Additional parameters
    pub extra: HashMap<String, f64>,
}
|
||||
|
||||
impl Default for OperationParams {
|
||||
fn default() -> Self {
|
||||
OperationParams {
|
||||
dims: Vec::new(),
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationParams {
|
||||
/// Create parameters for matrix multiplication.
|
||||
pub fn matmul(m: usize, n: usize, k: usize) -> Self {
|
||||
OperationParams {
|
||||
dims: vec![m, n, k],
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create parameters for attention.
|
||||
pub fn attention(num_heads: usize, head_dim: usize, seq_len: usize) -> Self {
|
||||
let mut extra = HashMap::new();
|
||||
extra.insert("num_heads".to_string(), num_heads as f64);
|
||||
extra.insert("head_dim".to_string(), head_dim as f64);
|
||||
extra.insert("seq_len".to_string(), seq_len as f64);
|
||||
|
||||
OperationParams {
|
||||
dims: vec![num_heads, head_dim, seq_len],
|
||||
extra,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create parameters for layer norm.
|
||||
pub fn layer_norm(hidden_dim: usize, epsilon: f32) -> Self {
|
||||
let mut extra = HashMap::new();
|
||||
extra.insert("epsilon".to_string(), epsilon as f64);
|
||||
|
||||
OperationParams {
|
||||
dims: vec![hidden_dim],
|
||||
extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error codes for worker responses.
///
/// Carried in [`WorkerResponse::Error`]; human-readable text comes from the
/// `Display` impl below.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ErrorCode {
    /// Invalid message format
    InvalidMessage,
    /// Memory access violation
    MemoryError,
    /// Invalid dimensions
    DimensionMismatch,
    /// Operation not supported
    UnsupportedOperation,
    /// Worker not initialized
    NotInitialized,
    /// Out of memory
    OutOfMemory,
    /// Internal error
    InternalError,
    /// Timeout
    Timeout,
}
|
||||
|
||||
impl std::fmt::Display for ErrorCode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ErrorCode::InvalidMessage => write!(f, "Invalid message format"),
|
||||
ErrorCode::MemoryError => write!(f, "Memory access violation"),
|
||||
ErrorCode::DimensionMismatch => write!(f, "Dimension mismatch"),
|
||||
ErrorCode::UnsupportedOperation => write!(f, "Unsupported operation"),
|
||||
ErrorCode::NotInitialized => write!(f, "Worker not initialized"),
|
||||
ErrorCode::OutOfMemory => write!(f, "Out of memory"),
|
||||
ErrorCode::InternalError => write!(f, "Internal error"),
|
||||
ErrorCode::Timeout => write!(f, "Operation timed out"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Task status for tracking progress.
///
/// Lifecycle: `Pending` → `Processing` → (`Completed` | `Failed` |
/// `Cancelled`), driven by [`TaskQueue::update_status`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum TaskStatus {
    /// Task is pending
    Pending,
    /// Task is being processed
    Processing,
    /// Task completed successfully
    Completed,
    /// Task failed
    Failed,
    /// Task was cancelled
    Cancelled,
}
|
||||
|
||||
/// Pending task information.
///
/// Main-thread bookkeeping record for a dispatched (or queued) worker task;
/// not serialized, unlike the message types above.
#[derive(Debug, Clone)]
pub struct PendingTask {
    /// Task ID
    pub task_id: TaskId,
    /// Operation type
    pub operation: OperationType,
    /// Status
    pub status: TaskStatus,
    /// Assigned worker ID (None while unassigned)
    pub worker_id: Option<usize>,
    /// Start time (milliseconds timestamp; None until started —
    /// TODO confirm units against the scheduler)
    pub started_at: Option<f64>,
}
|
||||
|
||||
impl PendingTask {
|
||||
/// Create a new pending task.
|
||||
pub fn new(task_id: TaskId, operation: OperationType) -> Self {
|
||||
PendingTask {
|
||||
task_id,
|
||||
operation,
|
||||
status: TaskStatus::Pending,
|
||||
worker_id: None,
|
||||
started_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Task queue for managing pending tasks.
///
/// Tasks are keyed by [`TaskId`]; `next_task_id` is the next ID to hand out
/// (starts at 1, see `TaskQueue::new`).
#[derive(Debug, Default)]
pub struct TaskQueue {
    // All known tasks, regardless of status.
    tasks: HashMap<TaskId, PendingTask>,
    // Monotonic ID counter; 0 is never issued.
    next_task_id: TaskId,
}
|
||||
|
||||
impl TaskQueue {
|
||||
/// Create a new task queue.
|
||||
pub fn new() -> Self {
|
||||
TaskQueue {
|
||||
tasks: HashMap::new(),
|
||||
next_task_id: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a new task ID.
|
||||
pub fn next_id(&mut self) -> TaskId {
|
||||
let id = self.next_task_id;
|
||||
self.next_task_id += 1;
|
||||
id
|
||||
}
|
||||
|
||||
/// Add a task to the queue.
|
||||
pub fn add(&mut self, task: PendingTask) {
|
||||
self.tasks.insert(task.task_id, task);
|
||||
}
|
||||
|
||||
/// Get a task by ID.
|
||||
pub fn get(&self, task_id: TaskId) -> Option<&PendingTask> {
|
||||
self.tasks.get(&task_id)
|
||||
}
|
||||
|
||||
/// Get a mutable reference to a task.
|
||||
pub fn get_mut(&mut self, task_id: TaskId) -> Option<&mut PendingTask> {
|
||||
self.tasks.get_mut(&task_id)
|
||||
}
|
||||
|
||||
/// Remove a task from the queue.
|
||||
pub fn remove(&mut self, task_id: TaskId) -> Option<PendingTask> {
|
||||
self.tasks.remove(&task_id)
|
||||
}
|
||||
|
||||
/// Update task status.
|
||||
pub fn update_status(&mut self, task_id: TaskId, status: TaskStatus) {
|
||||
if let Some(task) = self.tasks.get_mut(&task_id) {
|
||||
task.status = status;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all pending tasks.
|
||||
pub fn pending_tasks(&self) -> Vec<&PendingTask> {
|
||||
self.tasks
|
||||
.values()
|
||||
.filter(|t| t.status == TaskStatus::Pending)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get number of pending tasks.
|
||||
pub fn pending_count(&self) -> usize {
|
||||
self.tasks
|
||||
.values()
|
||||
.filter(|t| t.status == TaskStatus::Pending)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Clear all tasks.
|
||||
pub fn clear(&mut self) {
|
||||
self.tasks.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises the ID counter, pending bookkeeping, and status updates.
    #[test]
    fn test_task_queue() {
        let mut queue = TaskQueue::new();

        let id1 = queue.next_id();
        let id2 = queue.next_id();

        // IDs are monotonic starting from 1.
        assert_eq!(id1, 1);
        assert_eq!(id2, 2);

        queue.add(PendingTask::new(id1, OperationType::Matmul));
        queue.add(PendingTask::new(id2, OperationType::Attention));

        assert_eq!(queue.pending_count(), 2);

        // Completing a task removes it from the pending count.
        queue.update_status(id1, TaskStatus::Completed);
        assert_eq!(queue.pending_count(), 1);
    }

    // Checks the constructor helpers populate `dims` and `extra` as documented.
    #[test]
    fn test_operation_params() {
        let params = OperationParams::matmul(10, 20, 30);
        assert_eq!(params.dims, vec![10, 20, 30]);

        let params = OperationParams::layer_norm(512, 1e-5);
        assert_eq!(params.dims, vec![512]);
        // epsilon round-trips through the f32 -> f64 widening.
        assert!((params.extra["epsilon"] - 1e-5).abs() < 1e-10);
    }

    // Round-trips a request through serde JSON (internally tagged).
    #[test]
    fn test_message_serialization() {
        let msg = WorkerMessage::ComputeMatmul {
            task_id: 1,
            a_offset: 0,
            b_offset: 1000,
            c_offset: 2000,
            m: 10,
            n: 20,
            k: 30,
            row_start: 0,
            row_end: 5,
        };

        let json = serde_json::to_string(&msg).unwrap();
        let parsed: WorkerMessage = serde_json::from_str(&json).unwrap();

        match parsed {
            WorkerMessage::ComputeMatmul {
                task_id, m, n, k, ..
            } => {
                assert_eq!(task_id, 1);
                assert_eq!(m, 10);
                assert_eq!(n, 20);
                assert_eq!(k, 30);
            }
            _ => panic!("Wrong message type"),
        }
    }

    // Round-trips a response, including the optional nested metrics.
    #[test]
    fn test_response_serialization() {
        let resp = WorkerResponse::TaskComplete {
            task_id: 42,
            duration_ms: 123.45,
            metrics: Some(TaskMetrics {
                flops: 1000000,
                bytes_read: 4000,
                bytes_written: 2000,
                ..Default::default()
            }),
        };

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: WorkerResponse = serde_json::from_str(&json).unwrap();

        match parsed {
            WorkerResponse::TaskComplete {
                task_id,
                duration_ms,
                metrics,
            } => {
                assert_eq!(task_id, 42);
                assert!((duration_ms - 123.45).abs() < 0.001);
                assert!(metrics.is_some());
                assert_eq!(metrics.unwrap().flops, 1000000);
            }
            _ => panic!("Wrong response type"),
        }
    }
}
|
||||
505
vendor/ruvector/crates/ruvllm-wasm/src/workers/mod.rs
vendored
Normal file
505
vendor/ruvector/crates/ruvllm-wasm/src/workers/mod.rs
vendored
Normal file
@@ -0,0 +1,505 @@
|
||||
//! Web Workers for Parallel Inference in WASM
|
||||
//!
|
||||
//! This module provides multi-threaded execution in browsers using Web Workers
|
||||
//! with SharedArrayBuffer for zero-copy data sharing.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌─────────────────────────────────────────────────────────────────┐
|
||||
//! │ Main Thread │
|
||||
//! │ ┌──────────────────┐ ┌──────────────────┐ │
|
||||
//! │ │ ParallelInference│ │ SharedBufferMgr │ │
|
||||
//! │ └────────┬─────────┘ └────────┬─────────┘ │
|
||||
//! │ │ │ │
|
||||
//! │ ▼ ▼ │
|
||||
//! │ ┌────────────────────────────────────────┐ │
|
||||
//! │ │ WorkerPool │ │
|
||||
//! │ │ ┌──────────┐ ┌──────────┐ ┌──────────┐│ │
|
||||
//! │ │ │TaskQueue │ │SharedMem │ │ Workers ││ │
|
||||
//! │ │ └──────────┘ └──────────┘ └──────────┘│ │
|
||||
//! │ └────────────────────────────────────────┘ │
|
||||
//! └─────────────────────────────────────────────────────────────────┘
|
||||
//! │ postMessage │
|
||||
//! ▼ ▼
|
||||
//! ┌────────────────┐ ┌────────────────┐ ┌────────────────┐
|
||||
//! │ Worker 0 │ │ Worker 1 │ │ Worker N │
|
||||
//! │ ┌────────────┐ │ │ ┌────────────┐ │ │ ┌────────────┐ │
|
||||
//! │ │SharedArray │ │ │ │SharedArray │ │ │ │SharedArray │ │
|
||||
//! │ │ Buffer │ │ │ │ Buffer │ │ │ │ Buffer │ │
|
||||
//! │ │ View │ │ │ │ View │ │ │ │ View │ │
|
||||
//! │ └────────────┘ │ │ └────────────┘ │ │ └────────────┘ │
|
||||
//! └────────────────┘ └────────────────┘ └────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - **SharedArrayBuffer**: Zero-copy memory sharing between threads
|
||||
//! - **Atomics**: Thread synchronization primitives
|
||||
//! - **Dynamic Worker Count**: Based on `navigator.hardwareConcurrency`
|
||||
//! - **Graceful Fallback**: Single-threaded mode when SharedArrayBuffer unavailable
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import { ParallelInference } from 'ruvllm-wasm';
|
||||
//!
|
||||
//! // Create parallel inference engine
|
||||
//! const engine = await ParallelInference.new(4); // 4 workers
|
||||
//!
|
||||
//! // Check capabilities
|
||||
//! console.log('Workers:', engine.workerCount());
|
||||
//! console.log('Shared memory:', engine.isSharedMemoryAvailable());
|
||||
//!
|
||||
//! // Parallel matrix multiplication
|
||||
//! const result = await engine.matmul(a, b, m, n, k);
|
||||
//! ```
|
||||
//!
|
||||
//! # Browser Requirements
|
||||
//!
|
||||
//! For SharedArrayBuffer to work, the page must be served with:
|
||||
//! - `Cross-Origin-Opener-Policy: same-origin`
|
||||
//! - `Cross-Origin-Embedder-Policy: require-corp`
|
||||
|
||||
pub mod feature_detect;
|
||||
pub mod messages;
|
||||
pub mod pool;
|
||||
pub mod shared;
|
||||
|
||||
pub use feature_detect::*;
|
||||
pub use messages::*;
|
||||
pub use pool::*;
|
||||
pub use shared::*;
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Maximum recommended workers (prevent resource exhaustion).
/// Worker counts passed to `ParallelInference::new` are clamped to
/// `[MIN_WORKERS, MAX_WORKERS]`.
pub const MAX_WORKERS: usize = 16;

/// Default minimum workers.
pub const MIN_WORKERS: usize = 2;

/// WASM page size in bytes (64KB).
pub const WASM_PAGE_SIZE: usize = 65536;

/// Alignment for SIMD operations (16 bytes for 128-bit SIMD).
pub const SIMD_ALIGNMENT: usize = 16;
|
||||
|
||||
/// Main parallel inference interface for WASM.
///
/// Provides high-level API for parallel compute operations in the browser.
/// Automatically manages worker pool and shared memory.
#[wasm_bindgen]
pub struct ParallelInference {
    // Pool of Web Workers that execute the compute tasks.
    pool: WorkerPool,
    // Manager for shared buffers used in shared-memory mode.
    shared_buffers: SharedBufferManager,
    // Set by `new`, cleared by `terminate`; guards every compute entry point.
    initialized: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ParallelInference {
|
||||
/// Create a new ParallelInference instance.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `num_workers` - Number of workers to spawn. If None, uses optimal count.
|
||||
///
|
||||
/// # Returns
|
||||
/// A Promise that resolves to ParallelInference instance.
|
||||
///
|
||||
/// # Example (JavaScript)
|
||||
/// ```javascript
|
||||
/// const inference = await ParallelInference.new(4);
|
||||
/// ```
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub async fn new(num_workers: Option<usize>) -> Result<ParallelInference, JsValue> {
|
||||
crate::utils::set_panic_hook();
|
||||
|
||||
let worker_count = num_workers.unwrap_or_else(optimal_worker_count);
|
||||
let worker_count = worker_count.clamp(MIN_WORKERS, MAX_WORKERS);
|
||||
|
||||
crate::utils::log(&format!(
|
||||
"Initializing ParallelInference with {} workers",
|
||||
worker_count
|
||||
));
|
||||
|
||||
// Check for SharedArrayBuffer support
|
||||
let shared_memory_available = is_shared_array_buffer_available();
|
||||
if !shared_memory_available {
|
||||
crate::utils::warn(
|
||||
"SharedArrayBuffer not available. Using fallback mode with message passing.",
|
||||
);
|
||||
}
|
||||
|
||||
// Check cross-origin isolation
|
||||
if shared_memory_available && !cross_origin_isolated() {
|
||||
crate::utils::warn(
|
||||
"Page is not cross-origin isolated. SharedArrayBuffer may not work correctly.",
|
||||
);
|
||||
}
|
||||
|
||||
let pool = WorkerPool::new(worker_count).await?;
|
||||
let shared_buffers = SharedBufferManager::new();
|
||||
|
||||
crate::utils::log("ParallelInference initialized successfully");
|
||||
|
||||
Ok(ParallelInference {
|
||||
pool,
|
||||
shared_buffers,
|
||||
initialized: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform parallel matrix multiplication.
|
||||
///
|
||||
/// Computes C = A * B where:
|
||||
/// - A is m x k
|
||||
/// - B is k x n
|
||||
/// - C is m x n
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `a` - Matrix A as flat array (row-major)
|
||||
/// * `b` - Matrix B as flat array (row-major)
|
||||
/// * `m` - Number of rows in A
|
||||
/// * `n` - Number of columns in B
|
||||
/// * `k` - Number of columns in A / rows in B
|
||||
///
|
||||
/// # Returns
|
||||
/// Result matrix C as Float32Array
|
||||
#[wasm_bindgen]
|
||||
pub async fn matmul(
|
||||
&mut self,
|
||||
a: &[f32],
|
||||
b: &[f32],
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if !self.initialized {
|
||||
return Err(JsValue::from_str("ParallelInference not initialized"));
|
||||
}
|
||||
|
||||
// Validate dimensions
|
||||
if a.len() != m * k {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Matrix A size mismatch: expected {} ({}x{}), got {}",
|
||||
m * k,
|
||||
m,
|
||||
k,
|
||||
a.len()
|
||||
)));
|
||||
}
|
||||
if b.len() != k * n {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Matrix B size mismatch: expected {} ({}x{}), got {}",
|
||||
k * n,
|
||||
k,
|
||||
n,
|
||||
b.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// For small matrices, compute directly on main thread
|
||||
if m * n * k < 10000 {
|
||||
return Ok(self.matmul_single_thread(a, b, m, n, k));
|
||||
}
|
||||
|
||||
// Use parallel computation
|
||||
self.pool.parallel_matmul(a, b, m, n, k).await
|
||||
}
|
||||
|
||||
/// Perform parallel multi-head attention.
|
||||
///
|
||||
/// Computes softmax(Q * K^T / sqrt(d_k)) * V for each attention head.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `q` - Query tensor (batch_size, num_heads, seq_len, head_dim)
|
||||
/// * `k` - Key tensor (batch_size, num_heads, seq_len, head_dim)
|
||||
/// * `v` - Value tensor (batch_size, num_heads, seq_len, head_dim)
|
||||
/// * `num_heads` - Number of attention heads
|
||||
/// * `head_dim` - Dimension of each head
|
||||
/// * `seq_len` - Sequence length
|
||||
///
|
||||
/// # Returns
|
||||
/// Output tensor (batch_size, num_heads, seq_len, head_dim)
|
||||
#[wasm_bindgen(js_name = attention)]
|
||||
pub async fn parallel_attention(
|
||||
&mut self,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
seq_len: usize,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if !self.initialized {
|
||||
return Err(JsValue::from_str("ParallelInference not initialized"));
|
||||
}
|
||||
|
||||
// Validate dimensions
|
||||
let expected_size = num_heads * seq_len * head_dim;
|
||||
if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Tensor size mismatch: expected {}, got Q={}, K={}, V={}",
|
||||
expected_size,
|
||||
q.len(),
|
||||
k.len(),
|
||||
v.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// For small tensors, compute on main thread
|
||||
if expected_size < 10000 {
|
||||
return Ok(self.attention_single_thread(q, k, v, num_heads, head_dim, seq_len));
|
||||
}
|
||||
|
||||
self.pool
|
||||
.parallel_attention(q, k, v, num_heads, head_dim, seq_len)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Perform parallel layer normalization.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input` - Input tensor
|
||||
/// * `gamma` - Scale parameter
|
||||
/// * `beta` - Shift parameter
|
||||
/// * `epsilon` - Small constant for numerical stability
|
||||
///
|
||||
/// # Returns
|
||||
/// Normalized tensor
|
||||
#[wasm_bindgen(js_name = layerNorm)]
|
||||
pub async fn layer_norm(
|
||||
&mut self,
|
||||
input: &[f32],
|
||||
gamma: &[f32],
|
||||
beta: &[f32],
|
||||
epsilon: f32,
|
||||
) -> Result<Vec<f32>, JsValue> {
|
||||
if !self.initialized {
|
||||
return Err(JsValue::from_str("ParallelInference not initialized"));
|
||||
}
|
||||
|
||||
if input.len() < 1000 {
|
||||
return Ok(self.layer_norm_single_thread(input, gamma, beta, epsilon));
|
||||
}
|
||||
|
||||
self.pool.parallel_norm(input, gamma, beta, epsilon).await
|
||||
}
|
||||
|
||||
    /// Get the number of active workers.
    #[wasm_bindgen(js_name = workerCount)]
    pub fn worker_count(&self) -> usize {
        // Delegates to the pool, which owns the worker handles.
        self.pool.worker_count()
    }
|
||||
|
||||
    /// Check if SharedArrayBuffer is available.
    #[wasm_bindgen(js_name = isSharedMemoryAvailable)]
    pub fn is_shared_memory_available(&self) -> bool {
        // Thin wrapper over the free feature-detect function.
        is_shared_array_buffer_available()
    }
|
||||
|
||||
/// Check if the page is cross-origin isolated.
///
/// Presumably wraps the JS `crossOriginIsolated` global — confirm in the
/// helper's definition. When false, SharedArrayBuffer-based parallelism
/// is unavailable.
#[wasm_bindgen(js_name = isCrossOriginIsolated)]
pub fn is_cross_origin_isolated(&self) -> bool {
    cross_origin_isolated()
}
|
||||
|
||||
/// Check if Atomics API is available.
///
/// Environment probe delegating to a module-level helper; takes `&self`
/// only so it is exposed to JS as an instance method.
#[wasm_bindgen(js_name = isAtomicsAvailable)]
pub fn is_atomics_available(&self) -> bool {
    is_atomics_available()
}
|
||||
|
||||
/// Get optimal worker count for the current hardware.
///
/// Associated function (no `self`): exposed to JS as a static method.
/// Delegates to `optimal_worker_count()` — presumably derived from
/// `navigator.hardwareConcurrency`; confirm in the workers module.
#[wasm_bindgen(js_name = optimalWorkerCount)]
pub fn get_optimal_worker_count() -> usize {
    optimal_worker_count()
}
|
||||
|
||||
/// Terminate all workers and clean up resources.
///
/// After this call `initialized` is false, so the compute entry points
/// return an error until the instance is re-initialized. Also invoked
/// automatically from this type's `Drop` impl.
#[wasm_bindgen]
pub fn terminate(&mut self) {
    // Stop the workers first, then release the shared buffers they may
    // have been referencing.
    self.pool.terminate();
    self.shared_buffers.clear();
    self.initialized = false;
    crate::utils::log("ParallelInference terminated")
}
|
||||
|
||||
/// Get statistics about worker pool.
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats(&self) -> Result<String, JsValue> {
|
||||
let stats = self.pool.stats();
|
||||
serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
|
||||
// Private helper methods for single-threaded fallback
|
||||
|
||||
fn matmul_single_thread(&self, a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
|
||||
let mut c = vec![0.0f32; m * n];
|
||||
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for l in 0..k {
|
||||
sum += a[i * k + l] * b[l * n + j];
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
fn attention_single_thread(
|
||||
&self,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
seq_len: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut output = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
for h in 0..num_heads {
|
||||
let head_offset = h * seq_len * head_dim;
|
||||
|
||||
// Compute attention scores: Q * K^T
|
||||
let mut scores = vec![0.0f32; seq_len * seq_len];
|
||||
for i in 0..seq_len {
|
||||
for j in 0..seq_len {
|
||||
let mut dot = 0.0f32;
|
||||
for d in 0..head_dim {
|
||||
dot +=
|
||||
q[head_offset + i * head_dim + d] * k[head_offset + j * head_dim + d];
|
||||
}
|
||||
scores[i * seq_len + j] = dot * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Softmax
|
||||
for i in 0..seq_len {
|
||||
let row_start = i * seq_len;
|
||||
let max_val = scores[row_start..row_start + seq_len]
|
||||
.iter()
|
||||
.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
|
||||
let mut sum = 0.0f32;
|
||||
for j in 0..seq_len {
|
||||
scores[row_start + j] = (scores[row_start + j] - max_val).exp();
|
||||
sum += scores[row_start + j];
|
||||
}
|
||||
|
||||
for j in 0..seq_len {
|
||||
scores[row_start + j] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute output: scores * V
|
||||
for i in 0..seq_len {
|
||||
for d in 0..head_dim {
|
||||
let mut sum = 0.0f32;
|
||||
for j in 0..seq_len {
|
||||
sum += scores[i * seq_len + j] * v[head_offset + j * head_dim + d];
|
||||
}
|
||||
output[head_offset + i * head_dim + d] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn layer_norm_single_thread(
|
||||
&self,
|
||||
input: &[f32],
|
||||
gamma: &[f32],
|
||||
beta: &[f32],
|
||||
epsilon: f32,
|
||||
) -> Vec<f32> {
|
||||
let n = input.len();
|
||||
let hidden_dim = gamma.len();
|
||||
|
||||
if n % hidden_dim != 0 {
|
||||
return input.to_vec(); // Fallback: return input unchanged
|
||||
}
|
||||
|
||||
let batch_size = n / hidden_dim;
|
||||
let mut output = vec![0.0f32; n];
|
||||
|
||||
for b in 0..batch_size {
|
||||
let start = b * hidden_dim;
|
||||
let end = start + hidden_dim;
|
||||
let slice = &input[start..end];
|
||||
|
||||
// Compute mean
|
||||
let mean: f32 = slice.iter().sum::<f32>() / hidden_dim as f32;
|
||||
|
||||
// Compute variance
|
||||
let variance: f32 =
|
||||
slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / hidden_dim as f32;
|
||||
|
||||
// Normalize
|
||||
let std = (variance + epsilon).sqrt();
|
||||
for i in 0..hidden_dim {
|
||||
output[start + i] = ((input[start + i] - mean) / std) * gamma[i] + beta[i];
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
// RAII cleanup: guarantees worker teardown even when the JS side never
// calls `terminate()` explicitly. NOTE(review): assumes a second
// `terminate()` (after an explicit JS call) is harmless — confirm
// WorkerPool::terminate tolerates being called on an empty pool.
impl Drop for ParallelInference {
    fn drop(&mut self) {
        self.terminate();
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ParallelInference` that exercises the single-threaded
    /// fallbacks without spawning any workers.
    fn test_inference() -> ParallelInference {
        ParallelInference {
            pool: WorkerPool::empty(),
            shared_buffers: SharedBufferManager::new(),
            initialized: true,
        }
    }

    #[test]
    fn test_matmul_single_thread() {
        let inference = test_inference();

        // 2x3 * 3x2 = 2x2
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = inference.matmul_single_thread(&a, &b, 2, 2, 3);

        // Expected: [[22, 28], [49, 64]]
        let expected = [22.0f32, 28.0, 49.0, 64.0];
        assert_eq!(c.len(), expected.len());
        for (got, want) in c.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_layer_norm_single_thread() {
        let inference = test_inference();

        let input = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0; 4];
        let beta = vec![0.0; 4];

        let output = inference.layer_norm_single_thread(&input, &gamma, &beta, 1e-5);

        // After normalization the row should have ~zero mean and ~unit
        // (population) variance; the old test only checked the mean.
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 0.001);
        let variance: f32 =
            output.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / output.len() as f32;
        assert!((variance - 1.0).abs() < 0.01);

        // Exact values for [1, 2, 3, 4]: mean 2.5, std sqrt(1.25 + eps).
        let expected = [-1.3416f32, -0.4472, 0.4472, 1.3416];
        for (got, want) in output.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001, "got {got}, want {want}");
        }
    }
}
|
||||
1143
vendor/ruvector/crates/ruvllm-wasm/src/workers/pool.rs
vendored
Normal file
1143
vendor/ruvector/crates/ruvllm-wasm/src/workers/pool.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
593
vendor/ruvector/crates/ruvllm-wasm/src/workers/shared.rs
vendored
Normal file
593
vendor/ruvector/crates/ruvllm-wasm/src/workers/shared.rs
vendored
Normal file
@@ -0,0 +1,593 @@
|
||||
//! Shared Memory Types for Web Workers
|
||||
//!
|
||||
//! Provides zero-copy memory sharing between the main thread and Web Workers
|
||||
//! using SharedArrayBuffer.
|
||||
|
||||
use js_sys::{Float32Array, Int32Array, Object, Reflect, SharedArrayBuffer, Uint8Array};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Alignment for tensor data (16 bytes for SIMD)
|
||||
const TENSOR_ALIGNMENT: usize = 16;
|
||||
|
||||
/// A tensor backed by SharedArrayBuffer for zero-copy sharing.
///
/// When SharedArrayBuffer is available, data can be shared between
/// the main thread and workers without copying.
#[derive(Clone)]
pub struct SharedTensor {
    // Backing store; cloning a SharedTensor clones this JS handle, so
    // clones alias the same underlying bytes.
    buffer: SharedArrayBuffer,
    // f32 view over `buffer` starting at `byte_offset`.
    view: Float32Array,
    // Logical dimensions; element count is the product of these.
    shape: Vec<usize>,
    // Offset of this tensor's data within `buffer`, in bytes.
    byte_offset: usize,
}
|
||||
|
||||
impl SharedTensor {
|
||||
/// Create a new SharedTensor with the given shape.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shape` - Tensor dimensions
|
||||
///
|
||||
/// # Returns
|
||||
/// A new SharedTensor with zero-initialized data
|
||||
pub fn new(shape: &[usize]) -> Result<Self, JsValue> {
|
||||
let num_elements: usize = shape.iter().product();
|
||||
let byte_length = num_elements * std::mem::size_of::<f32>();
|
||||
|
||||
// Align to TENSOR_ALIGNMENT
|
||||
let aligned_length = (byte_length + TENSOR_ALIGNMENT - 1) & !(TENSOR_ALIGNMENT - 1);
|
||||
|
||||
let buffer = SharedArrayBuffer::new(aligned_length as u32);
|
||||
let view = Float32Array::new(&buffer);
|
||||
|
||||
Ok(SharedTensor {
|
||||
buffer,
|
||||
view,
|
||||
shape: shape.to_vec(),
|
||||
byte_offset: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a SharedTensor from existing data.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Tensor data as f32 slice
|
||||
/// * `shape` - Tensor dimensions
|
||||
///
|
||||
/// # Returns
|
||||
/// A new SharedTensor containing a copy of the data
|
||||
pub fn from_slice(data: &[f32], shape: &[usize]) -> Result<Self, JsValue> {
|
||||
let expected_len: usize = shape.iter().product();
|
||||
if data.len() != expected_len {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Data length {} doesn't match shape {:?} (expected {})",
|
||||
data.len(),
|
||||
shape,
|
||||
expected_len
|
||||
)));
|
||||
}
|
||||
|
||||
let tensor = Self::new(shape)?;
|
||||
tensor.view.copy_from(data);
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// Create a SharedTensor as a view into an existing SharedArrayBuffer.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `buffer` - The SharedArrayBuffer to view
|
||||
/// * `byte_offset` - Offset into the buffer (in bytes)
|
||||
/// * `shape` - Tensor dimensions
|
||||
pub fn from_buffer(
|
||||
buffer: SharedArrayBuffer,
|
||||
byte_offset: usize,
|
||||
shape: &[usize],
|
||||
) -> Result<Self, JsValue> {
|
||||
let num_elements: usize = shape.iter().product();
|
||||
|
||||
let view = Float32Array::new_with_byte_offset_and_length(
|
||||
&buffer,
|
||||
byte_offset as u32,
|
||||
num_elements as u32,
|
||||
);
|
||||
|
||||
Ok(SharedTensor {
|
||||
buffer,
|
||||
view,
|
||||
shape: shape.to_vec(),
|
||||
byte_offset,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the tensor shape.
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
/// Get the number of elements.
|
||||
pub fn len(&self) -> usize {
|
||||
self.shape.iter().product()
|
||||
}
|
||||
|
||||
/// Check if tensor is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Get the underlying SharedArrayBuffer.
|
||||
pub fn buffer(&self) -> &SharedArrayBuffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
/// Get the Float32Array view.
|
||||
pub fn view(&self) -> &Float32Array {
|
||||
&self.view
|
||||
}
|
||||
|
||||
/// Get byte offset into the buffer.
|
||||
pub fn byte_offset(&self) -> usize {
|
||||
self.byte_offset
|
||||
}
|
||||
|
||||
/// Get the byte length of the tensor data.
|
||||
pub fn byte_length(&self) -> usize {
|
||||
self.len() * std::mem::size_of::<f32>()
|
||||
}
|
||||
|
||||
/// Copy data to a Vec<f32>.
|
||||
pub fn to_vec(&self) -> Vec<f32> {
|
||||
self.view.to_vec()
|
||||
}
|
||||
|
||||
/// Copy data from a slice.
|
||||
///
|
||||
/// # Safety Note (SECURITY)
|
||||
/// This method uses non-atomic write operations. When sharing memory
|
||||
/// between Web Workers, ensure proper synchronization (e.g., barriers)
|
||||
/// before and after bulk copies to prevent data races.
|
||||
pub fn copy_from(&self, data: &[f32]) -> Result<(), JsValue> {
|
||||
if data.len() != self.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Data length {} doesn't match tensor length {}",
|
||||
data.len(),
|
||||
self.len()
|
||||
)));
|
||||
}
|
||||
self.view.copy_from(data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get an element at the given index.
|
||||
///
|
||||
/// # Safety Note (SECURITY)
|
||||
/// This method uses non-atomic read operations. When sharing memory
|
||||
/// between Web Workers, use `get_atomic()` instead to avoid data races.
|
||||
/// Non-atomic reads may return torn values if another thread is writing.
|
||||
#[inline]
|
||||
pub fn get(&self, index: usize) -> Option<f32> {
|
||||
if index < self.len() {
|
||||
Some(self.view.get_index(index as u32))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Set an element at the given index.
|
||||
///
|
||||
/// # Safety Note (SECURITY)
|
||||
/// This method uses non-atomic write operations. When sharing memory
|
||||
/// between Web Workers, use `set_atomic()` instead to avoid data races.
|
||||
/// Non-atomic writes may cause torn writes visible to other threads.
|
||||
#[inline]
|
||||
pub fn set(&self, index: usize, value: f32) -> Result<(), JsValue> {
|
||||
if index >= self.len() {
|
||||
return Err(JsValue::from_str("Index out of bounds"));
|
||||
}
|
||||
self.view.set_index(index as u32, value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a subview of this tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `start` - Start index (in elements)
|
||||
/// * `shape` - Shape of the subview
|
||||
pub fn subview(&self, start: usize, shape: &[usize]) -> Result<Self, JsValue> {
|
||||
let num_elements: usize = shape.iter().product();
|
||||
if start + num_elements > self.len() {
|
||||
return Err(JsValue::from_str("Subview exceeds tensor bounds"));
|
||||
}
|
||||
|
||||
let byte_offset = self.byte_offset + start * std::mem::size_of::<f32>();
|
||||
|
||||
Self::from_buffer(self.buffer.clone(), byte_offset, shape)
|
||||
}
|
||||
|
||||
/// Fill with a constant value using Atomics (thread-safe).
|
||||
pub fn fill_atomic(&self, value: f32) {
|
||||
// Convert f32 to its bit representation for atomic operations
|
||||
let bits = value.to_bits() as i32;
|
||||
let int_view = Int32Array::new(&self.buffer);
|
||||
let offset = (self.byte_offset / 4) as u32;
|
||||
|
||||
for i in 0..self.len() as u32 {
|
||||
js_sys::Atomics::store(&int_view, offset + i, bits).expect("Atomics::store failed");
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a value using Atomics (thread-safe).
|
||||
pub fn get_atomic(&self, index: usize) -> Option<f32> {
|
||||
if index >= self.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let int_view = Int32Array::new(&self.buffer);
|
||||
let offset = (self.byte_offset / 4 + index) as u32;
|
||||
|
||||
let bits = js_sys::Atomics::load(&int_view, offset).expect("Atomics::load failed") as u32;
|
||||
Some(f32::from_bits(bits))
|
||||
}
|
||||
|
||||
/// Set a value using Atomics (thread-safe).
|
||||
pub fn set_atomic(&self, index: usize, value: f32) -> Result<(), JsValue> {
|
||||
if index >= self.len() {
|
||||
return Err(JsValue::from_str("Index out of bounds"));
|
||||
}
|
||||
|
||||
let int_view = Int32Array::new(&self.buffer);
|
||||
let offset = (self.byte_offset / 4 + index) as u32;
|
||||
let bits = value.to_bits() as i32;
|
||||
|
||||
js_sys::Atomics::store(&int_view, offset, bits).expect("Atomics::store failed");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SharedTensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SharedTensor")
|
||||
.field("shape", &self.shape)
|
||||
.field("byte_offset", &self.byte_offset)
|
||||
.field("len", &self.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Region descriptor for shared memory allocation.
///
/// Plain-old-data (hence `Copy`): an offset/size pair, both in bytes,
/// within the manager's SharedArrayBuffer.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MemoryRegion {
    /// Offset in bytes from the start of the shared buffer
    pub offset: usize,
    /// Size in bytes
    pub size: usize,
}
|
||||
|
||||
impl MemoryRegion {
    /// Create a new memory region.
    pub fn new(offset: usize, size: usize) -> Self {
        Self { offset, size }
    }

    /// Get end offset (exclusive).
    pub fn end(&self) -> usize {
        self.offset + self.size
    }

    /// Check if this region overlaps with another.
    ///
    /// Regions that merely touch (one ends exactly where the other
    /// begins) do not count as overlapping.
    pub fn overlaps(&self, other: &MemoryRegion) -> bool {
        // Two half-open ranges are disjoint iff one ends at or before the
        // other begins; overlap is the negation of that.
        let disjoint = self.end() <= other.offset || other.end() <= self.offset;
        !disjoint
    }
}
|
||||
|
||||
/// Manager for shared memory buffers.
///
/// Handles allocation and deallocation of regions within a large
/// SharedArrayBuffer for efficient memory management.
///
/// Allocation is bump-pointer style: `free` removes a name binding but
/// does not reclaim its bytes; `reset` rewinds the whole arena.
pub struct SharedBufferManager {
    /// Main shared buffer (allocated on demand)
    buffer: Option<SharedArrayBuffer>,
    /// Current buffer size in bytes
    buffer_size: usize,
    /// Allocated regions
    regions: HashMap<String, MemoryRegion>,
    /// Next allocation offset
    next_offset: usize,
    /// Alignment for allocations
    alignment: usize,
}
|
||||
|
||||
impl SharedBufferManager {
|
||||
/// Create a new SharedBufferManager.
|
||||
pub fn new() -> Self {
|
||||
SharedBufferManager {
|
||||
buffer: None,
|
||||
buffer_size: 0,
|
||||
regions: HashMap::new(),
|
||||
next_offset: 0,
|
||||
alignment: TENSOR_ALIGNMENT,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a pre-allocated buffer of the given size.
|
||||
pub fn with_capacity(capacity_bytes: usize) -> Result<Self, JsValue> {
|
||||
let aligned_capacity = (capacity_bytes + TENSOR_ALIGNMENT - 1) & !(TENSOR_ALIGNMENT - 1);
|
||||
|
||||
let buffer = SharedArrayBuffer::new(aligned_capacity as u32);
|
||||
|
||||
Ok(SharedBufferManager {
|
||||
buffer: Some(buffer),
|
||||
buffer_size: aligned_capacity,
|
||||
regions: HashMap::new(),
|
||||
next_offset: 0,
|
||||
alignment: TENSOR_ALIGNMENT,
|
||||
})
|
||||
}
|
||||
|
||||
/// Ensure buffer has at least the given capacity.
|
||||
pub fn ensure_capacity(&mut self, min_capacity: usize) -> Result<(), JsValue> {
|
||||
let aligned_capacity = (min_capacity + TENSOR_ALIGNMENT - 1) & !(TENSOR_ALIGNMENT - 1);
|
||||
|
||||
if self.buffer_size >= aligned_capacity {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Need to reallocate
|
||||
let new_buffer = SharedArrayBuffer::new(aligned_capacity as u32);
|
||||
|
||||
// Copy existing data if any
|
||||
if let Some(old_buffer) = &self.buffer {
|
||||
let old_view = Uint8Array::new(old_buffer);
|
||||
let new_view = Uint8Array::new(&new_buffer);
|
||||
new_view.set(&old_view, 0);
|
||||
}
|
||||
|
||||
self.buffer = Some(new_buffer);
|
||||
self.buffer_size = aligned_capacity;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Allocate a region for a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - Unique name for this region
|
||||
/// * `shape` - Tensor shape
|
||||
///
|
||||
/// # Returns
|
||||
/// A SharedTensor backed by the allocated region
|
||||
pub fn allocate(&mut self, name: &str, shape: &[usize]) -> Result<SharedTensor, JsValue> {
|
||||
if self.regions.contains_key(name) {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Region '{}' already allocated",
|
||||
name
|
||||
)));
|
||||
}
|
||||
|
||||
let num_elements: usize = shape.iter().product();
|
||||
let size_bytes = num_elements * std::mem::size_of::<f32>();
|
||||
let aligned_size = (size_bytes + self.alignment - 1) & !(self.alignment - 1);
|
||||
|
||||
// Align the offset
|
||||
let aligned_offset = (self.next_offset + self.alignment - 1) & !(self.alignment - 1);
|
||||
|
||||
// Ensure buffer has capacity
|
||||
self.ensure_capacity(aligned_offset + aligned_size)?;
|
||||
|
||||
let region = MemoryRegion::new(aligned_offset, aligned_size);
|
||||
self.regions.insert(name.to_string(), region);
|
||||
self.next_offset = aligned_offset + aligned_size;
|
||||
|
||||
let buffer = self.buffer.as_ref().unwrap().clone();
|
||||
SharedTensor::from_buffer(buffer, aligned_offset, shape)
|
||||
}
|
||||
|
||||
/// Get an existing tensor by name.
|
||||
pub fn get(&self, name: &str, shape: &[usize]) -> Result<SharedTensor, JsValue> {
|
||||
let region = self
|
||||
.regions
|
||||
.get(name)
|
||||
.ok_or_else(|| JsValue::from_str(&format!("Region '{}' not found", name)))?;
|
||||
|
||||
let buffer = self
|
||||
.buffer
|
||||
.as_ref()
|
||||
.ok_or_else(|| JsValue::from_str("Buffer not initialized"))?;
|
||||
|
||||
SharedTensor::from_buffer(buffer.clone(), region.offset, shape)
|
||||
}
|
||||
|
||||
/// Free a region.
|
||||
pub fn free(&mut self, name: &str) -> bool {
|
||||
self.regions.remove(name).is_some()
|
||||
}
|
||||
|
||||
/// Reset all allocations (but keep the buffer).
|
||||
pub fn reset(&mut self) {
|
||||
self.regions.clear();
|
||||
self.next_offset = 0;
|
||||
}
|
||||
|
||||
/// Clear everything including the buffer.
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer = None;
|
||||
self.buffer_size = 0;
|
||||
self.regions.clear();
|
||||
self.next_offset = 0;
|
||||
}
|
||||
|
||||
/// Get the underlying SharedArrayBuffer.
|
||||
pub fn buffer(&self) -> Option<&SharedArrayBuffer> {
|
||||
self.buffer.as_ref()
|
||||
}
|
||||
|
||||
/// Get total allocated bytes.
|
||||
pub fn allocated_bytes(&self) -> usize {
|
||||
self.next_offset
|
||||
}
|
||||
|
||||
/// Get buffer capacity in bytes.
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.buffer_size
|
||||
}
|
||||
|
||||
/// Get remaining available bytes.
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.buffer_size.saturating_sub(self.next_offset)
|
||||
}
|
||||
|
||||
/// Get statistics about the buffer.
|
||||
pub fn stats(&self) -> SharedBufferStats {
|
||||
SharedBufferStats {
|
||||
capacity: self.buffer_size,
|
||||
allocated: self.next_offset,
|
||||
num_regions: self.regions.len(),
|
||||
regions: self.regions.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// `Default` mirrors `new()`: an empty manager with no backing buffer yet.
impl Default for SharedBufferManager {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Statistics about shared buffer usage.
///
/// Serializable so it can be reported to JS callers as JSON (see the
/// serde derives below).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedBufferStats {
    /// Total capacity in bytes
    pub capacity: usize,
    /// Currently allocated bytes
    pub allocated: usize,
    /// Number of allocated regions
    pub num_regions: usize,
    /// All allocated regions
    pub regions: HashMap<String, MemoryRegion>,
}
|
||||
|
||||
/// Synchronization primitive using SharedArrayBuffer and Atomics.
///
/// Provides wait/notify functionality for coordinating between workers.
///
/// State layout (two i32 words in `state`): index 0 = generation counter,
/// index 1 = number of participants that have arrived this generation.
pub struct SharedBarrier {
    /// Shared state buffer
    state: SharedArrayBuffer,
    /// Int32 view for Atomics operations
    int_view: Int32Array,
    /// Number of participants
    count: usize,
}
|
||||
|
||||
impl SharedBarrier {
    /// Create a new barrier for the given number of participants.
    pub fn new(count: usize) -> Self {
        // Allocate buffer for: [generation, arrived_count]
        // (two i32 words -> 8 bytes).
        let buffer = SharedArrayBuffer::new(8);
        let int_view = Int32Array::new(&buffer);

        // Initialize
        js_sys::Atomics::store(&int_view, 0, 0).expect("Atomics::store failed"); // generation
        js_sys::Atomics::store(&int_view, 1, 0).expect("Atomics::store failed"); // arrived

        SharedBarrier {
            state: buffer,
            int_view,
            count,
        }
    }

    /// Get the underlying SharedArrayBuffer for sharing with workers.
    pub fn buffer(&self) -> &SharedArrayBuffer {
        &self.state
    }

    /// Arrive at the barrier and wait for all participants.
    ///
    /// Returns the generation number.
    pub fn wait(&self) -> Result<i32, JsValue> {
        // Snapshot the generation BEFORE registering arrival: if the last
        // participant bumps it in between, Atomics::wait below returns
        // "not-equal" immediately instead of blocking forever.
        let gen = js_sys::Atomics::load(&self.int_view, 0).expect("Atomics::load failed");
        // Atomics::add returns the OLD value, so +1 gives this caller's
        // arrival rank.
        let arrived = js_sys::Atomics::add(&self.int_view, 1, 1).expect("Atomics::add failed") + 1;

        if arrived as usize == self.count {
            // Last to arrive - reset and notify
            js_sys::Atomics::store(&self.int_view, 1, 0).expect("Atomics::store failed");
            js_sys::Atomics::add(&self.int_view, 0, 1).expect("Atomics::add failed");
            js_sys::Atomics::notify(&self.int_view, 0).expect("Atomics::notify failed");
        } else {
            // Wait for generation to change
            // NOTE(review): Atomics.wait throws on the browser main thread;
            // the error is discarded here, so a main-thread caller falls
            // through WITHOUT blocking — confirm this is intended.
            let _ = js_sys::Atomics::wait(&self.int_view, 0, gen);
        }

        Ok(js_sys::Atomics::load(&self.int_view, 0).expect("Atomics::load failed"))
    }

    /// Reset the barrier.
    ///
    /// NOTE(review): resetting while participants are blocked in `wait`
    /// can strand them on a stale generation — only call when idle.
    pub fn reset(&self) {
        js_sys::Atomics::store(&self.int_view, 0, 0).expect("Atomics::store failed");
        js_sys::Atomics::store(&self.int_view, 1, 0).expect("Atomics::store failed");
    }
}
|
||||
|
||||
impl Clone for SharedBarrier {
|
||||
fn clone(&self) -> Self {
|
||||
SharedBarrier {
|
||||
state: self.state.clone(),
|
||||
int_view: Int32Array::new(&self.state),
|
||||
count: self.count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_region() {
        let r1 = MemoryRegion::new(0, 100);
        let r2 = MemoryRegion::new(50, 100);
        let r3 = MemoryRegion::new(100, 100);

        // Overlap must be symmetric (the old test only checked one side).
        assert!(r1.overlaps(&r2));
        assert!(r2.overlaps(&r1));

        // Touching regions (r1 ends exactly where r3 begins) do not
        // overlap, in either direction.
        assert!(!r1.overlaps(&r3));
        assert!(!r3.overlaps(&r1));

        // A non-empty region overlaps itself.
        assert!(r2.overlaps(&r2));

        assert_eq!(r1.end(), 100);
        assert_eq!(r2.end(), 150);
    }

    // Note: SharedTensor tests require wasm32 target due to SharedArrayBuffer
    #[cfg(target_arch = "wasm32")]
    mod wasm_tests {
        use super::*;
        use wasm_bindgen_test::*;

        wasm_bindgen_test_configure!(run_in_browser);

        #[wasm_bindgen_test]
        fn test_shared_tensor_new() {
            let tensor = SharedTensor::new(&[2, 3]).unwrap();
            assert_eq!(tensor.shape(), &[2, 3]);
            assert_eq!(tensor.len(), 6);
        }

        #[wasm_bindgen_test]
        fn test_shared_tensor_from_slice() {
            let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
            let tensor = SharedTensor::from_slice(&data, &[2, 3]).unwrap();

            let result = tensor.to_vec();
            assert_eq!(result, data);
        }

        #[wasm_bindgen_test]
        fn test_shared_buffer_manager() {
            let mut manager = SharedBufferManager::new();

            let tensor1 = manager.allocate("input", &[10, 10]).unwrap();
            assert_eq!(tensor1.len(), 100);

            let tensor2 = manager.allocate("output", &[10, 10]).unwrap();
            assert_eq!(tensor2.len(), 100);

            assert!(manager.allocated_bytes() >= 800); // 200 floats * 4 bytes
        }
    }
}
|
||||
Reference in New Issue
Block a user