//! WebAssembly bindings for RuVector GNN //! //! This module provides high-performance browser bindings for Graph Neural Network //! operations on HNSW topology, including: //! - GNN layer forward passes //! - Tensor compression with adaptive level selection //! - Differentiable search with soft attention //! - Hierarchical forward propagation use ruvector_gnn::{ differentiable_search as core_differentiable_search, hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel, RuvectorLayer, TensorCompress, }; use serde::{Deserialize, Serialize}; use wasm_bindgen::prelude::*; /// Initialize panic hook for better error messages #[wasm_bindgen(start)] pub fn init() { #[cfg(feature = "console_error_panic_hook")] console_error_panic_hook::set_once(); } // ============================================================================ // Type Definitions for WASM // ============================================================================ /// Query configuration for differentiable search #[derive(Debug, Clone, Serialize, Deserialize)] #[wasm_bindgen] pub struct SearchConfig { /// Number of top results to return pub k: usize, /// Temperature for softmax (lower = sharper, higher = smoother) pub temperature: f32, } #[wasm_bindgen] impl SearchConfig { /// Create a new search configuration #[wasm_bindgen(constructor)] pub fn new(k: usize, temperature: f32) -> Self { Self { k, temperature } } } /// Search results with indices and weights (internal) #[derive(Debug, Clone, Serialize, Deserialize)] struct SearchResultInternal { /// Indices of top-k candidates indices: Vec, /// Soft weights for each result weights: Vec, } // ============================================================================ // JsRuvectorLayer - GNN Layer Wrapper // ============================================================================ /// Graph Neural Network layer for HNSW topology #[wasm_bindgen] pub struct JsRuvectorLayer { inner: RuvectorLayer, hidden_dim: usize, } #[wasm_bindgen] impl JsRuvectorLayer { /// Create a new GNN layer /// /// # Arguments /// * `input_dim` - Dimension of input node embeddings /// * `hidden_dim` - Dimension of hidden representations /// * `heads` - Number of attention heads /// * `dropout` - Dropout rate (0.0 to 1.0) #[wasm_bindgen(constructor)] pub fn new( input_dim: usize, hidden_dim: usize, heads: usize, dropout: f32, ) -> Result { let inner = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout) .map_err(|e| JsValue::from_str(&e.to_string()))?; Ok(JsRuvectorLayer { inner, hidden_dim }) } /// Forward pass through the GNN layer /// /// # Arguments /// * `node_embedding` - Current node's embedding (Float32Array) /// * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays) /// * `edge_weights` - Weights of edges to neighbors (Float32Array) /// /// # Returns /// Updated node embedding (Float32Array) #[wasm_bindgen] pub fn forward( &self, node_embedding: Vec, neighbor_embeddings: JsValue, edge_weights: Vec, ) -> Result, JsValue> { // Convert neighbor embeddings from JS value let neighbors: Vec> = serde_wasm_bindgen::from_value(neighbor_embeddings) .map_err(|e| { JsValue::from_str(&format!("Failed to parse neighbor embeddings: {}", e)) })?; // Validate inputs if neighbors.len() != edge_weights.len() { return Err(JsValue::from_str(&format!( "Number of neighbors ({}) must match number of edge weights ({})", neighbors.len(), edge_weights.len() ))); } // Call core forward let result = self .inner .forward(&node_embedding, &neighbors, &edge_weights); Ok(result) } /// Get the output dimension of this layer #[wasm_bindgen(getter, js_name = outputDim)] pub fn output_dim(&self) -> usize { self.hidden_dim } } // ============================================================================ // JsTensorCompress - Tensor Compression Wrapper // ============================================================================ /// Tensor compressor with adaptive level selection #[wasm_bindgen] pub struct JsTensorCompress { inner: TensorCompress, } #[wasm_bindgen] impl JsTensorCompress { /// Create a new tensor compressor #[wasm_bindgen(constructor)] pub fn new() -> Self { Self { inner: TensorCompress::new(), } } /// Compress an embedding based on access frequency /// /// # Arguments /// * `embedding` - The input embedding vector (Float32Array) /// * `access_freq` - Access frequency in range [0.0, 1.0] /// - f > 0.8: Full precision (hot data) /// - f > 0.4: Half precision (warm data) /// - f > 0.1: 8-bit PQ (cool data) /// - f > 0.01: 4-bit PQ (cold data) /// - f <= 0.01: Binary (archive) /// /// # Returns /// Compressed tensor as JsValue #[wasm_bindgen] pub fn compress(&self, embedding: Vec, access_freq: f32) -> Result { let compressed = self .inner .compress(&embedding, access_freq) .map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?; // Serialize using serde_wasm_bindgen serde_wasm_bindgen::to_value(&compressed) .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e))) } /// Compress with explicit compression level /// /// # Arguments /// * `embedding` - The input embedding vector /// * `level` - Compression level ("none", "half", "pq8", "pq4", "binary") /// /// # Returns /// Compressed tensor as JsValue #[wasm_bindgen(js_name = compressWithLevel)] pub fn compress_with_level( &self, embedding: Vec, level: &str, ) -> Result { let compression_level = match level { "none" => CompressionLevel::None, "half" => CompressionLevel::Half { scale: 1.0 }, "pq8" => CompressionLevel::PQ8 { subvectors: 8, centroids: 16, }, "pq4" => CompressionLevel::PQ4 { subvectors: 8, outlier_threshold: 3.0, }, "binary" => CompressionLevel::Binary { threshold: 0.0 }, _ => { return Err(JsValue::from_str(&format!( "Unknown compression level: {}", level ))) } }; let compressed = self .inner .compress_with_level(&embedding, &compression_level) .map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?; // Serialize using serde_wasm_bindgen serde_wasm_bindgen::to_value(&compressed) .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e))) } /// Decompress a compressed tensor /// /// # Arguments /// * `compressed` - Serialized compressed tensor (JsValue) /// /// # Returns /// Decompressed embedding vector (Float32Array) #[wasm_bindgen] pub fn decompress(&self, compressed: JsValue) -> Result, JsValue> { let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed) .map_err(|e| JsValue::from_str(&format!("Deserialization failed: {}", e)))?; let decompressed = self .inner .decompress(&compressed_tensor) .map_err(|e| JsValue::from_str(&format!("Decompression failed: {}", e)))?; Ok(decompressed) } /// Get compression ratio estimate for a given access frequency /// /// # Arguments /// * `access_freq` - Access frequency in range [0.0, 1.0] /// /// # Returns /// Estimated compression ratio (original_size / compressed_size) #[wasm_bindgen(js_name = getCompressionRatio)] pub fn get_compression_ratio(&self, access_freq: f32) -> f32 { if access_freq > 0.8 { 1.0 // No compression } else if access_freq > 0.4 { 2.0 // Half precision } else if access_freq > 0.1 { 4.0 // 8-bit PQ } else if access_freq > 0.01 { 8.0 // 4-bit PQ } else { 32.0 // Binary } } } // ============================================================================ // Standalone Functions // ============================================================================ /// Differentiable search using soft attention mechanism /// /// # Arguments /// * `query` - The query vector (Float32Array) /// * `candidate_embeddings` - List of candidate embedding vectors (array of Float32Arrays) /// * `config` - Search configuration (k and temperature) /// /// # Returns /// Object with indices and weights for top-k candidates #[wasm_bindgen(js_name = differentiableSearch)] pub fn differentiable_search( query: Vec, candidate_embeddings: JsValue, config: &SearchConfig, ) -> Result { // Convert candidate embeddings from JS value let candidates: Vec> = serde_wasm_bindgen::from_value(candidate_embeddings) .map_err(|e| JsValue::from_str(&format!("Failed to parse candidate embeddings: {}", e)))?; // Call core search function let (indices, weights) = core_differentiable_search(&query, &candidates, config.k, config.temperature); let result = SearchResultInternal { indices, weights }; serde_wasm_bindgen::to_value(&result) .map_err(|e| JsValue::from_str(&format!("Failed to serialize result: {}", e))) } /// Hierarchical forward pass through multiple GNN layers /// /// # Arguments /// * `query` - The query vector (Float32Array) /// * `layer_embeddings` - Embeddings organized by layer (array of arrays of Float32Arrays) /// * `gnn_layers` - Array of GNN layers to process through /// /// # Returns /// Final embedding after hierarchical processing (Float32Array) #[wasm_bindgen(js_name = hierarchicalForward)] pub fn hierarchical_forward( query: Vec, layer_embeddings: JsValue, gnn_layers: Vec, ) -> Result, JsValue> { // Convert layer embeddings from JS value let embeddings: Vec>> = serde_wasm_bindgen::from_value(layer_embeddings) .map_err(|e| JsValue::from_str(&format!("Failed to parse layer embeddings: {}", e)))?; // Extract inner layers let core_layers: Vec = gnn_layers.iter().map(|l| l.inner.clone()).collect(); // Call core function let result = core_hierarchical_forward(&query, &embeddings, &core_layers); Ok(result) } // ============================================================================ // Utility Functions // ============================================================================ /// Get version information #[wasm_bindgen] pub fn version() -> String { env!("CARGO_PKG_VERSION").to_string() } /// Compute cosine similarity between two vectors /// /// # Arguments /// * `a` - First vector (Float32Array) /// * `b` - Second vector (Float32Array) /// /// # Returns /// Cosine similarity score [-1.0, 1.0] #[wasm_bindgen(js_name = cosineSimilarity)] pub fn cosine_similarity(a: Vec, b: Vec) -> Result { if a.len() != b.len() { return Err(JsValue::from_str(&format!( "Vector dimensions must match: {} vs {}", a.len(), b.len() ))); } let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); if norm_a == 0.0 || norm_b == 0.0 { Ok(0.0) } else { Ok(dot_product / (norm_a * norm_b)) } } // ============================================================================ // Tests // ============================================================================ #[cfg(test)] mod tests { use super::*; use wasm_bindgen_test::*; wasm_bindgen_test_configure!(run_in_browser); #[wasm_bindgen_test] fn test_version() { assert!(!version().is_empty()); } #[wasm_bindgen_test] fn test_ruvector_layer_creation() { let layer = JsRuvectorLayer::new(4, 8, 2, 0.1); assert!(layer.is_ok()); } #[wasm_bindgen_test] fn test_tensor_compress_creation() { let compressor = JsTensorCompress::new(); assert_eq!(compressor.get_compression_ratio(1.0), 1.0); assert_eq!(compressor.get_compression_ratio(0.5), 2.0); } #[wasm_bindgen_test] fn test_cosine_similarity() { let a = vec![1.0, 0.0, 0.0]; let b = vec![1.0, 0.0, 0.0]; let sim = cosine_similarity(a, b).unwrap(); assert!((sim - 1.0).abs() < 1e-6); } #[wasm_bindgen_test] fn test_search_config() { let config = SearchConfig::new(5, 1.0); assert_eq!(config.k, 5); assert_eq!(config.temperature, 1.0); } }