//! Neural Attention Mechanisms (from ruvector-attention)
//!
//! Wraps the 7 core neural attention mechanisms as WebAssembly bindings:
//! - Scaled Dot-Product Attention
//! - Multi-Head Attention
//! - Hyperbolic Attention
//! - Linear Attention (Performer)
//! - Flash Attention
//! - Local-Global Attention
//! - Mixture of Experts (MoE) Attention

use ruvector_attention::{
    attention::{MultiHeadAttention, ScaledDotProductAttention},
    hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
    moe::{MoEAttention, MoEConfig},
    sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
    traits::Attention,
};
use wasm_bindgen::prelude::*;

// ============================================================================
// Scaled Dot-Product Attention
// ============================================================================

/// Compute scaled dot-product attention
///
/// Standard transformer attention: softmax(QK^T / sqrt(d)) * V
///
/// # Arguments
/// * `query` - Query vector (Float32Array)
/// * `keys` - Array of key vectors (JsValue - array of Float32Arrays)
/// * `values` - Array of value vectors (JsValue - array of Float32Arrays)
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
///
/// # Returns
/// Attention-weighted output vector
#[wasm_bindgen(js_name = scaledDotAttention)]
pub fn scaled_dot_attention(
    query: &[f32],
    keys: JsValue,
    values: JsValue,
    scale: Option<f32>,
) -> Result<Vec<f32>, JsError> {
    let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
        .map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
    let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
        .map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;

    let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
    let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

    // NOTE: `scale` is currently not forwarded; ScaledDotProductAttention applies
    // its default 1/sqrt(dim) scaling.
    let _ = scale;

    let attention = ScaledDotProductAttention::new(query.len());
    attention
        .compute(query, &keys_refs, &values_refs)
        .map_err(|e| JsError::new(&e.to_string()))
}

// ============================================================================
// Multi-Head Attention
// ============================================================================

/// Multi-head attention mechanism
///
/// Splits input into multiple heads, applies attention, and concatenates results
#[wasm_bindgen]
pub struct WasmMultiHeadAttention {
    inner: MultiHeadAttention,
}

#[wasm_bindgen]
impl WasmMultiHeadAttention {
    /// Create a new multi-head attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension (must be divisible by num_heads)
    /// * `num_heads` - Number of parallel attention heads
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
        if dim % num_heads != 0 {
            return Err(JsError::new(&format!(
                "Dimension {} must be divisible by number of heads {}",
                dim, num_heads
            )));
        }
        Ok(Self {
            inner: MultiHeadAttention::new(dim, num_heads),
        })
    }

    /// Compute multi-head attention
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `keys` - Array of key vectors
    /// * `values` - Array of value vectors
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;

        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

        self.inner
            .compute(query, &keys_refs, &values_refs)
            .map_err(|e| JsError::new(&e.to_string()))
    }

    /// Get the number of attention heads
    #[wasm_bindgen(getter, js_name = numHeads)]
    pub fn num_heads(&self) -> usize {
        self.inner.num_heads()
    }

    /// Get the embedding dimension
    #[wasm_bindgen(getter)]
    pub fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// Get the dimension per head
    #[wasm_bindgen(getter, js_name = headDim)]
    pub fn head_dim(&self) -> usize {
        self.inner.dim() / self.inner.num_heads()
    }
}
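// ---------------------------------------------------------------------------
// Illustrative reference (not part of the ruvector-attention API): a plain-Rust
// sketch of the softmax(qK^T / sqrt(d)) * V computation that `scaledDotAttention`
// above exposes for a single query. Useful for sanity-checking binding outputs;
// the crate's own implementation may differ in details such as masking or the
// optional `scale` handling.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
fn reference_scaled_dot_attention(query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let d = query.len() as f32;

    // Raw attention scores: dot(query, key) / sqrt(d) for each key.
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| {
            let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
            dot / d.sqrt()
        })
        .collect();

    // Numerically stable softmax over the scores.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exp.iter().sum();

    // Attention-weighted sum of the value vectors.
    let dim_v = values.first().map(|v| v.len()).unwrap_or(0);
    let mut out = vec![0.0f32; dim_v];
    for (w, v) in exp.iter().zip(values.iter()) {
        for (o, x) in out.iter_mut().zip(v.iter()) {
            *o += (w / sum) * x;
        }
    }
    out
}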
// ============================================================================
// Hyperbolic Attention
// ============================================================================

/// Hyperbolic attention mechanism for hierarchical data
///
/// Operates in hyperbolic space (Poincare ball model), which naturally
/// represents tree-like hierarchical structures with exponential capacity
#[wasm_bindgen]
pub struct WasmHyperbolicAttention {
    inner: HyperbolicAttention,
    curvature_value: f32,
}

#[wasm_bindgen]
impl WasmHyperbolicAttention {
    /// Create a new hyperbolic attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `curvature` - Hyperbolic curvature parameter (negative for hyperbolic space)
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
        let config = HyperbolicAttentionConfig {
            dim,
            curvature,
            ..Default::default()
        };
        Self {
            inner: HyperbolicAttention::new(config),
            curvature_value: curvature,
        }
    }

    /// Compute hyperbolic attention
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;

        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

        self.inner
            .compute(query, &keys_refs, &values_refs)
            .map_err(|e| JsError::new(&e.to_string()))
    }

    /// Get the curvature parameter
    #[wasm_bindgen(getter)]
    pub fn curvature(&self) -> f32 {
        self.curvature_value
    }
}

// ============================================================================
// Linear Attention (Performer)
// ============================================================================

/// Linear attention using random feature approximation
///
/// Achieves O(n) complexity instead of O(n^2) by approximating
/// the softmax kernel with random Fourier features
#[wasm_bindgen]
pub struct WasmLinearAttention {
    inner: LinearAttention,
}

#[wasm_bindgen]
impl WasmLinearAttention {
    /// Create a new linear attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `num_features` - Number of random features for kernel approximation
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
        Self {
            inner: LinearAttention::new(dim, num_features),
        }
    }

    /// Compute linear attention
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;

        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

        self.inner
            .compute(query, &keys_refs, &values_refs)
            .map_err(|e| JsError::new(&e.to_string()))
    }
}
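// ---------------------------------------------------------------------------
// Illustrative reference (not part of the ruvector-attention API): why linear
// attention is O(n). With a positive feature map phi, attention becomes
//   out = phi(q)^T (sum_i phi(k_i) v_i^T) / (phi(q)^T sum_i phi(k_i)),
// so the key/value sums are accumulated once instead of materialising an
// n x n score matrix. This sketch uses the simple elu(x)+1 feature map as an
// illustration, not the random Fourier features of the crate's Performer kernel.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
fn reference_linear_attention(query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    // elu(x) + 1 keeps features strictly positive, so the normaliser stays safe.
    let phi = |x: &[f32]| -> Vec<f32> {
        x.iter()
            .map(|&v| if v > 0.0 { v + 1.0 } else { v.exp() })
            .collect()
    };

    let q = phi(query);
    let dim_v = values.first().map(|v| v.len()).unwrap_or(0);

    // Accumulate sum_i phi(k_i) v_i^T (a d x dim_v matrix) and sum_i phi(k_i).
    let mut kv = vec![vec![0.0f32; dim_v]; q.len()];
    let mut k_sum = vec![0.0f32; q.len()];
    for (k, v) in keys.iter().zip(values.iter()) {
        let fk = phi(k);
        for (row, &fk_j) in kv.iter_mut().zip(fk.iter()) {
            for (cell, &v_j) in row.iter_mut().zip(v.iter()) {
                *cell += fk_j * v_j;
            }
        }
        for (s, &fk_j) in k_sum.iter_mut().zip(fk.iter()) {
            *s += fk_j;
        }
    }

    // Numerator phi(q)^T (sum phi(k) v^T) and denominator phi(q)^T sum phi(k).
    let denom: f32 = q.iter().zip(k_sum.iter()).map(|(a, b)| a * b).sum::<f32>() + 1e-6;
    let mut out = vec![0.0f32; dim_v];
    for (q_j, row) in q.iter().zip(kv.iter()) {
        for (o, &cell) in out.iter_mut().zip(row.iter()) {
            *o += q_j * cell;
        }
    }
    out.iter_mut().for_each(|o| *o /= denom);
    out
}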
// ============================================================================
// Flash Attention
// ============================================================================

/// Flash attention with memory-efficient tiling
///
/// Reduces memory usage from O(n^2) to O(n) by computing attention
/// in blocks and fusing operations
#[wasm_bindgen]
pub struct WasmFlashAttention {
    inner: FlashAttention,
}

#[wasm_bindgen]
impl WasmFlashAttention {
    /// Create a new flash attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `block_size` - Block size for tiled computation
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
        Self {
            inner: FlashAttention::new(dim, block_size),
        }
    }

    /// Compute flash attention
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;

        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

        self.inner
            .compute(query, &keys_refs, &values_refs)
            .map_err(|e| JsError::new(&e.to_string()))
    }
}

// ============================================================================
// Local-Global Attention
// ============================================================================

/// Local-global sparse attention (Longformer-style)
///
/// Combines local sliding-window attention with global tokens
/// for efficient long-range dependencies
#[wasm_bindgen]
pub struct WasmLocalGlobalAttention {
    inner: LocalGlobalAttention,
}

#[wasm_bindgen]
impl WasmLocalGlobalAttention {
    /// Create a new local-global attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `local_window` - Size of local attention window
    /// * `global_tokens` - Number of global attention tokens
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
        Self {
            inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
        }
    }

    /// Compute local-global attention
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;

        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

        self.inner
            .compute(query, &keys_refs, &values_refs)
            .map_err(|e| JsError::new(&e.to_string()))
    }
}
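// ---------------------------------------------------------------------------
// Illustrative reference (not part of the ruvector-attention API): the sparsity
// pattern behind Longformer-style local-global attention. Position i may attend
// to position j when j falls inside a sliding window around i, or when either
// position is one of the designated global tokens (here, assumed to be the
// first `global_tokens` positions). The crate's LocalGlobalAttention may choose
// global positions differently; this only shows the shape of the mask.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
fn reference_local_global_mask(
    seq_len: usize,
    local_window: usize,
    global_tokens: usize,
) -> Vec<Vec<bool>> {
    let half = local_window / 2;
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| {
                    // Inside the sliding window around position i.
                    let local = i.abs_diff(j) <= half;
                    // Rows/columns belonging to global tokens attend everywhere.
                    let global = i < global_tokens || j < global_tokens;
                    local || global
                })
                .collect()
        })
        .collect()
}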
// ============================================================================
// Mixture of Experts (MoE) Attention
// ============================================================================

/// Mixture of Experts attention mechanism
///
/// Routes queries to specialized expert attention heads based on
/// learned gating functions for capacity-efficient computation
#[wasm_bindgen]
pub struct WasmMoEAttention {
    inner: MoEAttention,
}

#[wasm_bindgen]
impl WasmMoEAttention {
    /// Create a new MoE attention instance
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension
    /// * `num_experts` - Number of expert attention mechanisms
    /// * `top_k` - Number of experts to activate per query
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
        let config = MoEConfig::builder()
            .dim(dim)
            .num_experts(num_experts)
            .top_k(top_k)
            .build();
        Self {
            inner: MoEAttention::new(config),
        }
    }

    /// Compute MoE attention
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;

        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();

        self.inner
            .compute(query, &keys_refs, &values_refs)
            .map_err(|e| JsError::new(&e.to_string()))
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    #[wasm_bindgen_test]
    fn test_multi_head_creation() {
        let mha = WasmMultiHeadAttention::new(64, 8);
        assert!(mha.is_ok());
        let mha = mha.unwrap();
        assert_eq!(mha.dim(), 64);
        assert_eq!(mha.num_heads(), 8);
        assert_eq!(mha.head_dim(), 8);
    }

    #[wasm_bindgen_test]
    fn test_multi_head_invalid_dims() {
        let mha = WasmMultiHeadAttention::new(65, 8);
        assert!(mha.is_err());
    }

    #[wasm_bindgen_test]
    fn test_hyperbolic_attention() {
        let hyp = WasmHyperbolicAttention::new(32, -1.0);
        assert_eq!(hyp.curvature(), -1.0);
    }

    #[wasm_bindgen_test]
    fn test_linear_attention_creation() {
        // Smoke test: construction must not panic.
        let _linear = WasmLinearAttention::new(64, 128);
    }

    #[wasm_bindgen_test]
    fn test_flash_attention_creation() {
        // Smoke test: construction must not panic.
        let _flash = WasmFlashAttention::new(64, 16);
    }

    #[wasm_bindgen_test]
    fn test_local_global_creation() {
        // Smoke test: construction must not panic.
        let _lg = WasmLocalGlobalAttention::new(64, 128, 4);
    }

    #[wasm_bindgen_test]
    fn test_moe_attention_creation() {
        // Smoke test: construction must not panic.
        let _moe = WasmMoEAttention::new(64, 8, 2);
    }
}
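// ---------------------------------------------------------------------------
// Illustrative reference (not part of the ruvector-attention API): the top-k
// gating arithmetic behind mixture-of-experts routing, as configured by
// `WasmMoEAttention::new`. Gate logits for a query are softmaxed, the `top_k`
// highest-scoring experts are selected, and their weights are renormalised to
// sum to 1. The crate's MoEAttention learns its gate; this only sketches the
// routing step.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
fn reference_top_k_gating(gate_logits: &[f32], top_k: usize) -> Vec<(usize, f32)> {
    // Numerically stable softmax over the expert logits.
    let max = gate_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = gate_logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exp.iter().sum();
    let probs: Vec<f32> = exp.iter().map(|e| e / sum).collect();

    // Rank experts by probability and keep the top_k.
    let mut ranked: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.truncate(top_k);

    // Renormalise so the selected experts' weights sum to 1.
    let selected_sum: f32 = ranked.iter().map(|(_, p)| *p).sum();
    ranked
        .into_iter()
        .map(|(idx, p)| (idx, p / selected_sum))
        .collect()
}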