Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
439
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/neural.rs
vendored
Normal file
439
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/neural.rs
vendored
Normal file
@@ -0,0 +1,439 @@
|
||||
//! Neural Attention Mechanisms (from ruvector-attention)
|
||||
//!
|
||||
//! Re-exports the 7 core neural attention mechanisms:
|
||||
//! - Scaled Dot-Product Attention
|
||||
//! - Multi-Head Attention
|
||||
//! - Hyperbolic Attention
|
||||
//! - Linear Attention (Performer)
|
||||
//! - Flash Attention
|
||||
//! - Local-Global Attention
|
||||
//! - Mixture of Experts (MoE) Attention
|
||||
|
||||
use ruvector_attention::{
|
||||
attention::{MultiHeadAttention, ScaledDotProductAttention},
|
||||
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
|
||||
moe::{MoEAttention, MoEConfig},
|
||||
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
|
||||
traits::Attention,
|
||||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Scaled Dot-Product Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Compute scaled dot-product attention
|
||||
///
|
||||
/// Standard transformer attention: softmax(QK^T / sqrt(d)) * V
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector (Float32Array)
|
||||
/// * `keys` - Array of key vectors (JsValue - array of Float32Arrays)
|
||||
/// * `values` - Array of value vectors (JsValue - array of Float32Arrays)
|
||||
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
///
|
||||
/// # Returns
|
||||
/// Attention-weighted output vector
|
||||
#[wasm_bindgen(js_name = scaledDotAttention)]
|
||||
pub fn scaled_dot_attention(
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
scale: Option<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let attention = ScaledDotProductAttention::new(query.len());
|
||||
attention
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Multi-Head Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Multi-head attention mechanism
|
||||
///
|
||||
/// Splits input into multiple heads, applies attention, and concatenates results
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMultiHeadAttention {
|
||||
inner: MultiHeadAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMultiHeadAttention {
|
||||
/// Create a new multi-head attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension (must be divisible by num_heads)
|
||||
/// * `num_heads` - Number of parallel attention heads
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
|
||||
if dim % num_heads != 0 {
|
||||
return Err(JsError::new(&format!(
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
dim, num_heads
|
||||
)));
|
||||
}
|
||||
Ok(Self {
|
||||
inner: MultiHeadAttention::new(dim, num_heads),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the number of attention heads
|
||||
#[wasm_bindgen(getter, js_name = numHeads)]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.inner.num_heads()
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.inner.dim()
|
||||
}
|
||||
|
||||
/// Get the dimension per head
|
||||
#[wasm_bindgen(getter, js_name = headDim)]
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.inner.dim() / self.inner.num_heads()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hyperbolic Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hyperbolic attention mechanism for hierarchical data
|
||||
///
|
||||
/// Operates in hyperbolic space (Poincare ball model) which naturally
|
||||
/// represents tree-like hierarchical structures with exponential capacity
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmHyperbolicAttention {
|
||||
inner: HyperbolicAttention,
|
||||
curvature_value: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHyperbolicAttention {
|
||||
/// Create a new hyperbolic attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature parameter (negative for hyperbolic space)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim,
|
||||
curvature,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
inner: HyperbolicAttention::new(config),
|
||||
curvature_value: curvature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hyperbolic attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the curvature parameter
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.curvature_value
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Linear Attention (Performer)
|
||||
// ============================================================================
|
||||
|
||||
/// Linear attention using random feature approximation
|
||||
///
|
||||
/// Achieves O(n) complexity instead of O(n^2) by approximating
|
||||
/// the softmax kernel with random Fourier features
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLinearAttention {
|
||||
inner: LinearAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLinearAttention {
|
||||
/// Create a new linear attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_features` - Number of random features for kernel approximation
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
|
||||
Self {
|
||||
inner: LinearAttention::new(dim, num_features),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute linear attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Flash Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Flash attention with memory-efficient tiling
|
||||
///
|
||||
/// Reduces memory usage from O(n^2) to O(n) by computing attention
|
||||
/// in blocks and fusing operations
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmFlashAttention {
|
||||
inner: FlashAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmFlashAttention {
|
||||
/// Create a new flash attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `block_size` - Block size for tiled computation
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
|
||||
Self {
|
||||
inner: FlashAttention::new(dim, block_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flash attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Local-Global Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Local-global sparse attention (Longformer-style)
|
||||
///
|
||||
/// Combines local sliding window attention with global tokens
|
||||
/// for efficient long-range dependencies
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmLocalGlobalAttention {
|
||||
inner: LocalGlobalAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLocalGlobalAttention {
|
||||
/// Create a new local-global attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `local_window` - Size of local attention window
|
||||
/// * `global_tokens` - Number of global attention tokens
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
|
||||
Self {
|
||||
inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local-global attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mixture of Experts (MoE) Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Mixture of Experts attention mechanism
|
||||
///
|
||||
/// Routes queries to specialized expert attention heads based on
|
||||
/// learned gating functions for capacity-efficient computation
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMoEAttention {
|
||||
inner: MoEAttention,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMoEAttention {
|
||||
/// Create a new MoE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_experts` - Number of expert attention mechanisms
|
||||
/// * `top_k` - Number of experts to activate per query
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
|
||||
let config = MoEConfig::builder()
|
||||
.dim(dim)
|
||||
.num_experts(num_experts)
|
||||
.top_k(top_k)
|
||||
.build();
|
||||
Self {
|
||||
inner: MoEAttention::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute MoE attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A head count that evenly divides the dimension must construct
    /// successfully and report the expected derived sizes.
    #[wasm_bindgen_test]
    fn test_multi_head_creation() {
        let mha = WasmMultiHeadAttention::new(64, 8);
        assert!(mha.is_ok());
        let mha = mha.unwrap();
        assert_eq!(mha.dim(), 64);
        assert_eq!(mha.num_heads(), 8);
        assert_eq!(mha.head_dim(), 8);
    }

    /// 65 is not divisible by 8, so construction must be rejected.
    #[wasm_bindgen_test]
    fn test_multi_head_invalid_dims() {
        let mha = WasmMultiHeadAttention::new(65, 8);
        assert!(mha.is_err());
    }

    /// The getter must echo back the construction-time curvature.
    #[wasm_bindgen_test]
    fn test_hyperbolic_attention() {
        let hyp = WasmHyperbolicAttention::new(32, -1.0);
        assert_eq!(hyp.curvature(), -1.0);
    }

    // The creation tests below only verify that construction does not panic;
    // the previous `assert!(true)` statements were no-ops and have been removed,
    // and the bindings are `_`-prefixed to silence unused-variable warnings.

    #[wasm_bindgen_test]
    fn test_linear_attention_creation() {
        let _linear = WasmLinearAttention::new(64, 128);
    }

    #[wasm_bindgen_test]
    fn test_flash_attention_creation() {
        let _flash = WasmFlashAttention::new(64, 16);
    }

    #[wasm_bindgen_test]
    fn test_local_global_creation() {
        let _lg = WasmLocalGlobalAttention::new(64, 128, 4);
    }

    #[wasm_bindgen_test]
    fn test_moe_attention_creation() {
        let _moe = WasmMoEAttention::new(64, 8, 2);
    }
}
|
||||
Reference in New Issue
Block a user