Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,806 @@
//! DAG Attention Mechanisms (from ruvector-dag)
//!
//! Re-exports the 7 DAG-specific attention mechanisms:
//! - Topological Attention
//! - Causal Cone Attention
//! - Critical Path Attention
//! - MinCut-Gated Attention
//! - Hierarchical Lorentz Attention
//! - Parallel Branch Attention
//! - Temporal BTSP Attention
use ruvector_dag::{OperatorNode, QueryDag};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use wasm_bindgen::prelude::*;
// ============================================================================
// Minimal DAG for WASM
// ============================================================================
/// Minimal DAG structure for WASM attention computation
///
/// Thin wrapper around `ruvector_dag::QueryDag`, exposing construction and
/// inspection to JavaScript via `wasm_bindgen`. All attention mechanisms in
/// this module take their input through this wrapper.
#[wasm_bindgen]
pub struct WasmQueryDag {
    // The wrapped core DAG; accessed by the attention `forward` methods.
    inner: QueryDag,
}
#[wasm_bindgen]
impl WasmQueryDag {
    /// Create a new empty DAG
    #[wasm_bindgen(constructor)]
    pub fn new() -> WasmQueryDag {
        let inner = QueryDag::new();
        WasmQueryDag { inner }
    }

    /// Add a node with operator type and cost
    ///
    /// # Arguments
    /// * `op_type` - Operator type: "scan", "filter", "join", "aggregate", "project", "sort"
    /// * `cost` - Estimated execution cost
    ///
    /// # Returns
    /// Node ID
    #[wasm_bindgen(js_name = addNode)]
    pub fn add_node(&mut self, op_type: &str, cost: f32) -> u32 {
        // The next node index doubles as a synthetic table/operator id.
        let id = self.inner.node_count() as usize;
        let mut operator = match op_type {
            "scan" => OperatorNode::seq_scan(id, &format!("table_{}", id)),
            "filter" => OperatorNode::filter(id, "condition"),
            "join" => OperatorNode::hash_join(id, "join_key"),
            "aggregate" => OperatorNode::aggregate(id, vec!["*".to_string()]),
            "project" => OperatorNode::project(id, vec!["*".to_string()]),
            "sort" => OperatorNode::sort(id, vec!["col".to_string()]),
            // Unrecognized operator names degrade to a sequential scan.
            _ => OperatorNode::seq_scan(id, "unknown"),
        };
        operator.estimated_cost = cost as f64;
        self.inner.add_node(operator) as u32
    }

    /// Add an edge between nodes
    ///
    /// # Arguments
    /// * `from` - Source node ID
    /// * `to` - Target node ID
    ///
    /// # Returns
    /// True if edge was added successfully
    #[wasm_bindgen(js_name = addEdge)]
    pub fn add_edge(&mut self, from: u32, to: u32) -> bool {
        let outcome = self.inner.add_edge(from as usize, to as usize);
        outcome.is_ok()
    }

    /// Get the number of nodes
    #[wasm_bindgen(getter, js_name = nodeCount)]
    pub fn node_count(&self) -> u32 {
        self.inner.node_count() as u32
    }

    /// Get the number of edges
    #[wasm_bindgen(getter, js_name = edgeCount)]
    pub fn edge_count(&self) -> u32 {
        self.inner.edge_count() as u32
    }

    /// Serialize to JSON
    ///
    /// Returns a compact `{node_count, edge_count}` summary; an empty string
    /// if serialization fails.
    #[wasm_bindgen(js_name = toJson)]
    pub fn to_json(&self) -> String {
        let summary = DagSummary {
            node_count: self.inner.node_count(),
            edge_count: self.inner.edge_count(),
        };
        serde_json::to_string(&summary).unwrap_or_default()
    }
}
impl WasmQueryDag {
    /// Get internal reference
    ///
    /// Crate-private accessor so sibling modules can run attention over the
    /// wrapped `QueryDag` without going through the wasm interface.
    pub(crate) fn inner(&self) -> &QueryDag {
        &self.inner
    }
}
/// Compact serializable summary used by `WasmQueryDag::to_json`.
#[derive(Serialize, Deserialize)]
struct DagSummary {
    // Total nodes in the DAG.
    node_count: usize,
    // Total edges in the DAG.
    edge_count: usize,
}
// ============================================================================
// Helper function for converting a sparse HashMap of scores to a dense Vec
// ============================================================================
/// Expand a sparse node-id → score map into a dense `Vec` of length `n`.
///
/// Node ids absent from `scores` (and any ids >= `n`) yield `0.0`.
fn hashmap_to_vec(scores: &HashMap<usize, f32>, n: usize) -> Vec<f32> {
    let mut dense = vec![0.0f32; n];
    for (&id, &score) in scores {
        if id < n {
            dense[id] = score;
        }
    }
    dense
}
// ============================================================================
// Topological Attention
// ============================================================================
/// Topological attention based on DAG position
///
/// Assigns attention scores based on node position in topological order.
/// Earlier nodes (closer to sources) get higher attention.
///
/// NOTE(review): `forward` computes `decay_factor.powf(1.0 - normalized_depth)`,
/// so for a decay factor below 1.0 *deeper* nodes actually receive the higher
/// raw score — confirm which direction is intended versus the doc above.
#[wasm_bindgen]
pub struct WasmTopologicalAttention {
    // Base of the position-dependent exponential, expected in (0.0, 1.0].
    decay_factor: f32,
}
#[wasm_bindgen]
impl WasmTopologicalAttention {
    /// Create a new topological attention instance
    ///
    /// # Arguments
    /// * `decay_factor` - Decay factor for position-based attention (0.0-1.0)
    #[wasm_bindgen(constructor)]
    pub fn new(decay_factor: f32) -> WasmTopologicalAttention {
        WasmTopologicalAttention { decay_factor }
    }

    /// Compute attention scores for the DAG
    ///
    /// Each node's raw score is `decay_factor ^ (1 - depth/max_depth)`;
    /// the scores are then normalized to sum to 1.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        let depths = dag.inner.compute_depths();
        // Guard against division by zero for single-level DAGs.
        let deepest = depths.values().max().copied().unwrap_or(0).max(1) as f32;
        let mut scores: HashMap<usize, f32> = depths
            .iter()
            .map(|(&id, &depth)| {
                let relative = depth as f32 / deepest;
                (id, self.decay_factor.powf(1.0 - relative))
            })
            .collect();
        // Normalize into a probability distribution.
        let total: f32 = scores.values().sum();
        if total > 0.0 {
            for s in scores.values_mut() {
                *s /= total;
            }
        }
        Ok(hashmap_to_vec(&scores, n))
    }
}
// ============================================================================
// Causal Cone Attention
// ============================================================================
/// Causal cone attention based on dependency lightcones
///
/// Nodes can only attend to ancestors in the DAG (causal predecessors).
/// Attention strength decays with causal distance.
#[wasm_bindgen]
pub struct WasmCausalConeAttention {
    // Geometric discount applied per unit of depth (raised to the node's depth).
    future_discount: f32,
    // Linear weight per ancestor in the node's causal cone.
    ancestor_weight: f32,
}
#[wasm_bindgen]
impl WasmCausalConeAttention {
    /// Create a new causal cone attention instance
    ///
    /// # Arguments
    /// * `future_discount` - Discount for future nodes
    /// * `ancestor_weight` - Weight for ancestor influence
    #[wasm_bindgen(constructor)]
    pub fn new(future_discount: f32, ancestor_weight: f32) -> WasmCausalConeAttention {
        WasmCausalConeAttention {
            future_discount,
            ancestor_weight,
        }
    }

    /// Compute attention scores for the DAG
    ///
    /// A node's raw score grows linearly with the size of its ancestor set
    /// (its causal cone) and is discounted geometrically by its depth; the
    /// scores are then normalized to sum to 1.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        let depths = dag.inner.compute_depths();
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        for id in 0..n {
            // Skip ids with no backing node (sparse id space).
            if dag.inner.get_node(id).is_none() {
                continue;
            }
            // Wider causal cones attract more attention.
            let cone_size = dag.inner.ancestors(id).len() as f32;
            let mut weight = 1.0 + cone_size * self.ancestor_weight;
            // Discount deeper (later) nodes geometrically.
            if let Some(&depth) = depths.get(&id) {
                weight *= self.future_discount.powi(depth as i32);
            }
            scores.insert(id, weight);
            total += weight;
        }
        if total > 0.0 {
            for w in scores.values_mut() {
                *w /= total;
            }
        }
        Ok(hashmap_to_vec(&scores, n))
    }
}
// ============================================================================
// Critical Path Attention
// ============================================================================
/// Critical path attention weighted by path criticality
///
/// Nodes on or near the critical path (longest execution path)
/// receive higher attention scores.
#[wasm_bindgen]
pub struct WasmCriticalPathAttention {
    // Raw score assigned to nodes on the critical path (off-path nodes get 1.0).
    path_weight: f32,
    // Per-extra-child multiplier applied to branching nodes.
    branch_penalty: f32,
}
#[wasm_bindgen]
impl WasmCriticalPathAttention {
    /// Create a new critical path attention instance
    ///
    /// # Arguments
    /// * `path_weight` - Weight for critical path membership
    /// * `branch_penalty` - Penalty for branching nodes
    #[wasm_bindgen(constructor)]
    pub fn new(path_weight: f32, branch_penalty: f32) -> WasmCriticalPathAttention {
        WasmCriticalPathAttention {
            path_weight,
            branch_penalty,
        }
    }
    /// Compute the critical path (longest path by cost)
    ///
    /// Dynamic programming over the reversed topological order: for each node,
    /// the best (highest-cost) path down to a leaf is the node's own cost plus
    /// the best path among its children. Returns the node ids along the overall
    /// highest-cost path, or an empty vector if no topological order exists.
    fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
        // node id -> (best total cost from this node to a leaf, path taken)
        let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
        // Seed the DP with each leaf's own cost.
        for &leaf in &dag.leaves() {
            if let Some(node) = dag.get_node(leaf) {
                longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
            }
        }
        if let Ok(topo_order) = dag.topological_sort() {
            // Reverse topological order guarantees children are processed first.
            for &node_id in topo_order.iter().rev() {
                let node = match dag.get_node(node_id) {
                    Some(n) => n,
                    None => continue,
                };
                let mut max_cost = node.estimated_cost;
                let mut max_path = vec![node_id];
                for &child in dag.children(node_id) {
                    if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
                        let total_cost = node.estimated_cost + child_cost;
                        if total_cost > max_cost {
                            max_cost = total_cost;
                            max_path = vec![node_id];
                            max_path.extend(child_path);
                        }
                    }
                }
                longest_path.insert(node_id, (max_cost, max_path));
            }
        }
        // Pick the most expensive path overall (non-comparable costs tie).
        longest_path
            .into_iter()
            .max_by(|a, b| {
                a.1 .0
                    .partial_cmp(&b.1 .0)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(_, (_, path))| path)
            .unwrap_or_default()
    }
    /// Compute attention scores for the DAG
    ///
    /// Nodes on the critical path score `path_weight`, others 1.0; branching
    /// nodes are additionally scaled per extra child. Scores are normalized
    /// to sum to 1.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        let critical = self.compute_critical_path(&dag.inner);
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        for node_id in 0..n {
            if dag.inner.get_node(node_id).is_none() {
                continue;
            }
            let is_on_critical_path = critical.contains(&node_id);
            let num_children = dag.inner.children(node_id).len();
            let mut score = if is_on_critical_path {
                self.path_weight
            } else {
                1.0
            };
            // NOTE(review): for positive branch_penalty this multiplier is > 1,
            // so "penalty" actually *boosts* branching nodes — confirm intended.
            if num_children > 1 {
                score *= 1.0 + (num_children as f32 - 1.0) * self.branch_penalty;
            }
            scores.insert(node_id, score);
            total += score;
        }
        if total > 0.0 {
            for score in scores.values_mut() {
                *score /= total;
            }
        }
        Ok(hashmap_to_vec(&scores, n))
    }
}
// ============================================================================
// MinCut-Gated Attention
// ============================================================================
/// MinCut-gated attention using flow-based bottleneck detection
///
/// Uses minimum cut analysis to identify bottleneck nodes
/// and gates attention through these critical points.
///
/// NOTE(review): the current `forward` approximates min-cut with a simple
/// degree heuristic rather than an actual max-flow/min-cut computation.
#[wasm_bindgen]
pub struct WasmMinCutGatedAttention {
    // Fraction of the node count a node's total degree must reach to be
    // treated as a bottleneck (expected 0.0-1.0).
    gate_threshold: f32,
}
#[wasm_bindgen]
impl WasmMinCutGatedAttention {
    /// Create a new MinCut-gated attention instance
    ///
    /// # Arguments
    /// * `gate_threshold` - Threshold for gating (0.0-1.0)
    #[wasm_bindgen(constructor)]
    pub fn new(gate_threshold: f32) -> WasmMinCutGatedAttention {
        WasmMinCutGatedAttention { gate_threshold }
    }

    /// Compute attention scores for the DAG
    ///
    /// Heuristic bottleneck detection: a node whose combined in/out degree
    /// reaches `gate_threshold * node_count` gets a boosted score; everything
    /// else gets 1.0. Scores are normalized to sum to 1.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        // Degree cutoff scales with graph size.
        let cutoff = self.gate_threshold * n as f32;
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        for id in 0..n {
            if dag.inner.get_node(id).is_none() {
                continue;
            }
            let degree = (dag.inner.parents(id).len() + dag.inner.children(id).len()) as f32;
            let score = if degree >= cutoff {
                // Bottleneck: boosted base plus a small degree-proportional bonus.
                2.0 + degree * 0.1
            } else {
                1.0
            };
            scores.insert(id, score);
            total += score;
        }
        if total > 0.0 {
            for s in scores.values_mut() {
                *s /= total;
            }
        }
        Ok(hashmap_to_vec(&scores, n))
    }
}
// ============================================================================
// Hierarchical Lorentz Attention
// ============================================================================
/// Hierarchical Lorentz attention in hyperbolic space
///
/// Combines DAG hierarchy with Lorentz (hyperboloid) geometry
/// for multi-scale hierarchical attention.
#[wasm_bindgen]
pub struct WasmHierarchicalLorentzAttention {
    // Hyperbolic curvature; only its absolute value is used as a distance scale.
    curvature: f32,
    // Softmax temperature; smaller values sharpen the distribution.
    temperature: f32,
}
#[wasm_bindgen]
impl WasmHierarchicalLorentzAttention {
    /// Create a new hierarchical Lorentz attention instance
    ///
    /// # Arguments
    /// * `curvature` - Hyperbolic curvature parameter
    /// * `temperature` - Temperature for softmax
    #[wasm_bindgen(constructor)]
    pub fn new(curvature: f32, temperature: f32) -> WasmHierarchicalLorentzAttention {
        WasmHierarchicalLorentzAttention {
            curvature,
            temperature,
        }
    }

    /// Compute attention scores for the DAG
    ///
    /// Maps each node's depth to a radial distance on the hyperboloid, then
    /// applies a max-shifted (numerically stable) softmax over the negated
    /// distances, so shallower nodes receive more attention.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        let depths = dag.inner.compute_depths();
        // Distance grows with depth; `tanh` keeps the `acosh` argument in
        // [1, 2), i.e. inside its domain and numerically stable.
        let distances: Vec<f32> = (0..n)
            .map(|node_id| {
                let depth = depths.get(&node_id).copied().unwrap_or(0);
                let radial = (depth as f32 * 0.5).tanh();
                (1.0 + radial).acosh() * self.curvature.abs()
            })
            .collect();
        // Stable softmax over -distance / temperature. Each shifted
        // exponential is computed exactly once (the previous version
        // evaluated `exp` twice per node: once for the sum, once per score).
        let max_logit = distances
            .iter()
            .map(|&d| -d / self.temperature)
            .fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = distances
            .iter()
            .map(|&d| ((-d / self.temperature) - max_logit).exp())
            .collect();
        let exp_sum: f32 = exps.iter().sum();
        // Floor the denominator to avoid division by zero.
        let scores: Vec<f32> = exps.iter().map(|&e| e / exp_sum.max(1e-10)).collect();
        Ok(scores)
    }
}
// ============================================================================
// Parallel Branch Attention
// ============================================================================
/// Parallel branch attention for concurrent DAG branches
///
/// Identifies parallel branches in the DAG and applies
/// attention patterns that respect branch independence.
#[wasm_bindgen]
pub struct WasmParallelBranchAttention {
    // NOTE(review): not consulted by `forward` at present — confirm whether a
    // branch cap is still intended.
    max_branches: usize,
    // Score reduction applied to synchronization points (0.0-1.0).
    sync_penalty: f32,
}
#[wasm_bindgen]
impl WasmParallelBranchAttention {
    /// Create a new parallel branch attention instance
    ///
    /// # Arguments
    /// * `max_branches` - Maximum number of branches to consider
    /// * `sync_penalty` - Penalty for synchronization between branches
    #[wasm_bindgen(constructor)]
    pub fn new(max_branches: usize, sync_penalty: f32) -> WasmParallelBranchAttention {
        WasmParallelBranchAttention {
            max_branches,
            sync_penalty,
        }
    }

    /// Compute attention scores for the DAG
    ///
    /// Children of fan-out nodes are boosted (parallel branches); leaf nodes
    /// with multiple parents are penalized as synchronization points. Scores
    /// are normalized to sum to 1.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        // Branch points: nodes that fan out into multiple children. A HashSet
        // makes the per-node membership test below O(1) instead of a linear
        // scan over a Vec for every node.
        let branch_starts: std::collections::HashSet<usize> = (0..n)
            .filter(|&id| dag.inner.children(id).len() > 1)
            .collect();
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        for node_id in 0..n {
            if dag.inner.get_node(node_id).is_none() {
                continue;
            }
            let parents = dag.inner.parents(node_id);
            // Node sits directly under a fan-out point -> part of a parallel branch.
            let is_branch_child = parents.iter().any(|&p| branch_starts.contains(&p));
            // Leaf with multiple parents -> branches synchronize here.
            let is_sync_point = dag.inner.children(node_id).is_empty() && parents.len() > 1;
            let score = if is_branch_child {
                1.5 // Boost parallel branch nodes
            } else if is_sync_point {
                1.0 - self.sync_penalty // Penalize sync points
            } else {
                1.0
            };
            scores.insert(node_id, score);
            total += score;
        }
        if total > 0.0 {
            for score in scores.values_mut() {
                *score /= total;
            }
        }
        Ok(hashmap_to_vec(&scores, n))
    }
}
// ============================================================================
// Temporal BTSP Attention
// ============================================================================
/// Temporal BTSP (Behavioral Time-Series Pattern) attention
///
/// Incorporates temporal patterns and behavioral sequences
/// for time-aware DAG attention.
#[wasm_bindgen]
pub struct WasmTemporalBTSPAttention {
    // NOTE(review): not consulted by `forward` at present — presumably intended
    // for an eligibility-trace update path; confirm.
    eligibility_decay: f32,
    // Base multiplier applied to every node's score.
    baseline_attention: f32,
}
#[wasm_bindgen]
impl WasmTemporalBTSPAttention {
    /// Create a new temporal BTSP attention instance
    ///
    /// # Arguments
    /// * `eligibility_decay` - Decay rate for eligibility traces (0.0-1.0)
    /// * `baseline_attention` - Baseline attention for nodes without history
    #[wasm_bindgen(constructor)]
    pub fn new(eligibility_decay: f32, baseline_attention: f32) -> WasmTemporalBTSPAttention {
        WasmTemporalBTSPAttention {
            eligibility_decay,
            baseline_attention,
        }
    }

    /// Compute attention scores for the DAG
    ///
    /// Scores each node from its estimated cost and row count (each capped at
    /// a saturation point), scaled by the baseline; missing nodes contribute
    /// 0.0. Scores are normalized to sum to 1.
    ///
    /// # Errors
    /// Returns an error if the DAG is empty.
    pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
        let n = dag.inner.node_count();
        if n == 0 {
            return Err(JsError::new("Empty DAG"));
        }
        let mut scores = Vec::with_capacity(n);
        let mut total = 0.0f32;
        for id in 0..n {
            let score = match dag.inner.get_node(id) {
                Some(node) => {
                    // Saturating factors: cost capped at 100, rows at 1000.
                    let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
                    let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
                    self.baseline_attention * (0.5 * cost_factor + 0.5 * rows_factor + 0.5)
                }
                // Missing node ids carry zero attention.
                None => 0.0,
            };
            total += score;
            scores.push(score);
        }
        // Normalize into a probability distribution.
        if total > 0.0 {
            for s in scores.iter_mut() {
                *s /= total;
            }
        }
        Ok(scores)
    }
}
// ============================================================================
// DAG Attention Factory
// ============================================================================
/// Factory for creating DAG attention mechanisms
///
/// Stateless; exposes static metadata (type names and descriptions) to JS.
#[wasm_bindgen]
pub struct DagAttentionFactory;
#[wasm_bindgen]
impl DagAttentionFactory {
    /// Get available DAG attention types
    #[wasm_bindgen(js_name = availableTypes)]
    pub fn available_types() -> JsValue {
        // Keep in sync with the match arms in `get_description`.
        let types = vec![
            "topological",
            "causal_cone",
            "critical_path",
            "mincut_gated",
            "hierarchical_lorentz",
            "parallel_branch",
            "temporal_btsp",
        ];
        serde_wasm_bindgen::to_value(&types).unwrap()
    }

    /// Get description for a DAG attention type
    ///
    /// Unknown type names yield a generic "Unknown attention type" string.
    #[wasm_bindgen(js_name = getDescription)]
    pub fn get_description(attention_type: &str) -> String {
        let text = match attention_type {
            "topological" => "Position-based attention following DAG topological order",
            "causal_cone" => "Lightcone-based attention respecting causal dependencies",
            "critical_path" => "Attention weighted by critical execution path distance",
            "mincut_gated" => "Flow-based gating through bottleneck nodes",
            "hierarchical_lorentz" => "Multi-scale hyperbolic attention for DAG hierarchies",
            "parallel_branch" => "Branch-aware attention for parallel DAG structures",
            "temporal_btsp" => "Time-series pattern attention for temporal DAGs",
            _ => "Unknown attention type",
        };
        text.to_string()
    }
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;
    wasm_bindgen_test_configure!(run_in_browser);

    // Basic construction: two nodes joined by one edge.
    #[wasm_bindgen_test]
    fn test_dag_creation() {
        let mut dag = WasmQueryDag::new();
        let n1 = dag.add_node("scan", 1.0);
        let n2 = dag.add_node("filter", 0.5);
        dag.add_edge(n1, n2);
        assert_eq!(dag.node_count(), 2);
        assert_eq!(dag.edge_count(), 1);
    }

    // Each mechanism below is smoke-tested on a tiny linear DAG; these check
    // that forward() succeeds and (where asserted) yields one score per node,
    // not the numeric values of the scores.
    #[wasm_bindgen_test]
    fn test_topological_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_node("project", 0.3);
        dag.add_edge(0, 1);
        dag.add_edge(1, 2);
        let attention = WasmTopologicalAttention::new(0.9);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
        let s = scores.unwrap();
        assert_eq!(s.len(), 3);
    }

    #[wasm_bindgen_test]
    fn test_causal_cone_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        let attention = WasmCausalConeAttention::new(0.8, 0.9);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_critical_path_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        let attention = WasmCriticalPathAttention::new(2.0, 0.5);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_mincut_gated_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        let attention = WasmMinCutGatedAttention::new(0.5);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
    }

    // Negative curvature is the conventional sign for hyperbolic space;
    // forward() only uses its absolute value.
    #[wasm_bindgen_test]
    fn test_hierarchical_lorentz_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        let attention = WasmHierarchicalLorentzAttention::new(-1.0, 0.1);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_parallel_branch_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        let attention = WasmParallelBranchAttention::new(8, 0.2);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_temporal_btsp_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        let attention = WasmTemporalBTSPAttention::new(0.95, 0.5);
        let scores = attention.forward(&dag);
        assert!(scores.is_ok());
    }

    // The factory should serialize its type list into a non-null JsValue.
    #[wasm_bindgen_test]
    fn test_factory_types() {
        let types_js = DagAttentionFactory::available_types();
        assert!(!types_js.is_null());
    }
}

View File

@@ -0,0 +1,417 @@
//! Graph Attention Mechanisms (from ruvector-gnn)
//!
//! Re-exports graph neural network attention mechanisms:
//! - GAT (Graph Attention Networks)
//! - GCN (Graph Convolutional Networks)
//! - GraphSAGE (Sample and Aggregate)
use ruvector_gnn::{
differentiable_search as core_differentiable_search,
hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel,
RuvectorLayer, TensorCompress,
};
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
// ============================================================================
// GNN Layer (GAT-based)
// ============================================================================
/// Graph Neural Network layer with attention mechanism
///
/// Implements Graph Attention Networks (GAT) for HNSW topology.
/// Each node aggregates information from neighbors using learned attention weights.
#[wasm_bindgen]
pub struct WasmGNNLayer {
    // Wrapped core layer that performs the actual aggregation.
    inner: RuvectorLayer,
    // Cached so the `outputDim` getter doesn't need to query `inner`.
    hidden_dim: usize,
}
#[wasm_bindgen]
impl WasmGNNLayer {
    /// Create a new GNN layer with attention
    ///
    /// # Arguments
    /// * `input_dim` - Dimension of input node embeddings
    /// * `hidden_dim` - Dimension of hidden representations
    /// * `heads` - Number of attention heads
    /// * `dropout` - Dropout rate (0.0 to 1.0)
    ///
    /// # Errors
    /// Propagates any construction error from the core layer as a `JsError`.
    #[wasm_bindgen(constructor)]
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dropout: f32,
    ) -> Result<WasmGNNLayer, JsError> {
        match RuvectorLayer::new(input_dim, hidden_dim, heads, dropout) {
            Ok(inner) => Ok(WasmGNNLayer { inner, hidden_dim }),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }

    /// Forward pass through the GNN layer
    ///
    /// # Arguments
    /// * `node_embedding` - Current node's embedding (Float32Array)
    /// * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays)
    /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
    ///
    /// # Returns
    /// Updated node embedding (Float32Array)
    ///
    /// # Errors
    /// Fails if the neighbor list cannot be deserialized or its length does
    /// not match the edge-weight list.
    pub fn forward(
        &self,
        node_embedding: Vec<f32>,
        neighbor_embeddings: JsValue,
        edge_weights: Vec<f32>,
    ) -> Result<Vec<f32>, JsError> {
        let neighbors: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(neighbor_embeddings)
            .map_err(|e| JsError::new(&format!("Failed to parse neighbor embeddings: {}", e)))?;
        // One edge weight is required per neighbor.
        if neighbors.len() != edge_weights.len() {
            return Err(JsError::new(&format!(
                "Number of neighbors ({}) must match number of edge weights ({})",
                neighbors.len(),
                edge_weights.len()
            )));
        }
        Ok(self.inner.forward(&node_embedding, &neighbors, &edge_weights))
    }

    /// Get the output dimension
    #[wasm_bindgen(getter, js_name = outputDim)]
    pub fn output_dim(&self) -> usize {
        self.hidden_dim
    }
}
// ============================================================================
// Tensor Compression (for efficient GNN)
// ============================================================================
/// Tensor compressor with adaptive level selection
///
/// Compresses embeddings based on access frequency for memory-efficient GNN
#[wasm_bindgen]
pub struct WasmTensorCompress {
inner: TensorCompress,
}
#[wasm_bindgen]
impl WasmTensorCompress {
    /// Create a new tensor compressor
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        let inner = TensorCompress::new();
        Self { inner }
    }

    /// Compress an embedding based on access frequency
    ///
    /// # Arguments
    /// * `embedding` - The input embedding vector
    /// * `access_freq` - Access frequency in range [0.0, 1.0]
    ///   - f > 0.8: Full precision (hot data)
    ///   - f > 0.4: Half precision (warm data)
    ///   - f > 0.1: 8-bit PQ (cool data)
    ///   - f > 0.01: 4-bit PQ (cold data)
    ///   - f <= 0.01: Binary (archive)
    ///
    /// # Errors
    /// Fails if the core compressor rejects the input or the result cannot be
    /// serialized into a `JsValue`.
    pub fn compress(&self, embedding: Vec<f32>, access_freq: f32) -> Result<JsValue, JsError> {
        match self.inner.compress(&embedding, access_freq) {
            Ok(compressed) => serde_wasm_bindgen::to_value(&compressed)
                .map_err(|e| JsError::new(&format!("Serialization failed: {}", e))),
            Err(e) => Err(JsError::new(&format!("Compression failed: {}", e))),
        }
    }

    /// Compress with explicit compression level
    ///
    /// # Arguments
    /// * `embedding` - The input embedding vector
    /// * `level` - Compression level: "none", "half", "pq8", "pq4", "binary"
    ///
    /// # Errors
    /// Fails on an unrecognized level name, a core compression error, or a
    /// serialization error.
    #[wasm_bindgen(js_name = compressWithLevel)]
    pub fn compress_with_level(
        &self,
        embedding: Vec<f32>,
        level: &str,
    ) -> Result<JsValue, JsError> {
        // Map the string tag onto the core enum, with fixed default parameters.
        let compression_level = match level {
            "none" => CompressionLevel::None,
            "half" => CompressionLevel::Half { scale: 1.0 },
            "pq8" => CompressionLevel::PQ8 {
                subvectors: 8,
                centroids: 16,
            },
            "pq4" => CompressionLevel::PQ4 {
                subvectors: 8,
                outlier_threshold: 3.0,
            },
            "binary" => CompressionLevel::Binary { threshold: 0.0 },
            _ => {
                return Err(JsError::new(&format!(
                    "Unknown compression level: {}",
                    level
                )))
            }
        };
        match self.inner.compress_with_level(&embedding, &compression_level) {
            Ok(compressed) => serde_wasm_bindgen::to_value(&compressed)
                .map_err(|e| JsError::new(&format!("Serialization failed: {}", e))),
            Err(e) => Err(JsError::new(&format!("Compression failed: {}", e))),
        }
    }

    /// Decompress a compressed tensor
    ///
    /// # Errors
    /// Fails if the input is not a serialized `CompressedTensor` or the core
    /// decompression fails.
    pub fn decompress(&self, compressed: JsValue) -> Result<Vec<f32>, JsError> {
        let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed)
            .map_err(|e| JsError::new(&format!("Deserialization failed: {}", e)))?;
        self.inner
            .decompress(&compressed_tensor)
            .map_err(|e| JsError::new(&format!("Decompression failed: {}", e)))
    }

    /// Get compression ratio estimate for a given access frequency
    ///
    /// Mirrors the frequency bands documented on `compress`.
    #[wasm_bindgen(js_name = getCompressionRatio)]
    pub fn get_compression_ratio(&self, access_freq: f32) -> f32 {
        match access_freq {
            f if f > 0.8 => 1.0,
            f if f > 0.4 => 2.0,
            f if f > 0.1 => 4.0,
            f if f > 0.01 => 8.0,
            _ => 32.0,
        }
    }
}
// ============================================================================
// Search Configuration
// ============================================================================
/// Search configuration for differentiable search
#[wasm_bindgen]
pub struct WasmSearchConfig {
    /// Number of top results to return
    pub k: usize,
    /// Temperature for softmax
    pub temperature: f32,
}
#[wasm_bindgen]
impl WasmSearchConfig {
    /// Create a new search configuration
    ///
    /// # Arguments
    /// * `k` - Number of top results to return
    /// * `temperature` - Softmax temperature for the attention weights
    #[wasm_bindgen(constructor)]
    pub fn new(k: usize, temperature: f32) -> Self {
        Self { k, temperature }
    }
}
// ============================================================================
// Differentiable Search
// ============================================================================
/// Differentiable search using soft attention mechanism
///
/// # Arguments
/// * `query` - The query vector
/// * `candidate_embeddings` - List of candidate embedding vectors
/// * `config` - Search configuration
///
/// # Returns
/// Object with indices and weights for top-k candidates
///
/// # Errors
/// Fails if the candidates cannot be deserialized or the result cannot be
/// serialized back into a `JsValue`.
#[wasm_bindgen(js_name = graphDifferentiableSearch)]
pub fn differentiable_search(
    query: Vec<f32>,
    candidate_embeddings: JsValue,
    config: &WasmSearchConfig,
) -> Result<JsValue, JsError> {
    let candidates: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(candidate_embeddings)
        .map_err(|e| JsError::new(&format!("Failed to parse candidate embeddings: {}", e)))?;
    // Delegate the actual soft top-k to the core implementation.
    let (indices, weights) =
        core_differentiable_search(&query, &candidates, config.k, config.temperature);
    serde_wasm_bindgen::to_value(&SearchResult { indices, weights })
        .map_err(|e| JsError::new(&format!("Failed to serialize result: {}", e)))
}
/// Serializable result payload for `differentiable_search`.
#[derive(Serialize, Deserialize)]
struct SearchResult {
    // Candidate indices of the top-k matches.
    indices: Vec<usize>,
    // Soft attention weight for each returned index.
    weights: Vec<f32>,
}
// ============================================================================
// Hierarchical Forward
// ============================================================================
/// Hierarchical forward pass through multiple GNN layers
///
/// # Arguments
/// * `query` - The query vector
/// * `layer_embeddings` - Embeddings organized by layer
/// * `gnn_layers` - Array of GNN layers
///
/// # Returns
/// Final embedding after hierarchical processing
///
/// # Errors
/// Fails if `layer_embeddings` cannot be deserialized.
#[wasm_bindgen(js_name = graphHierarchicalForward)]
pub fn hierarchical_forward(
    query: Vec<f32>,
    layer_embeddings: JsValue,
    gnn_layers: Vec<WasmGNNLayer>,
) -> Result<Vec<f32>, JsError> {
    let embeddings: Vec<Vec<Vec<f32>>> = serde_wasm_bindgen::from_value(layer_embeddings)
        .map_err(|e| JsError::new(&format!("Failed to parse layer embeddings: {}", e)))?;
    // Unwrap the WASM wrappers into core layers for the core forward pass.
    let core_layers: Vec<RuvectorLayer> =
        gnn_layers.iter().map(|layer| layer.inner.clone()).collect();
    Ok(core_hierarchical_forward(&query, &embeddings, &core_layers))
}
// ============================================================================
// Graph Attention Types
// ============================================================================
/// Graph attention mechanism types
///
/// Exposed to JavaScript as an enum of the supported GNN aggregation styles.
#[wasm_bindgen]
pub enum GraphAttentionType {
    /// Graph Attention Networks (Velickovic et al., 2018)
    GAT,
    /// Graph Convolutional Networks (Kipf & Welling, 2017)
    GCN,
    /// GraphSAGE (Hamilton et al., 2017)
    GraphSAGE,
}
/// Factory for graph attention information
///
/// Stateless; exposes static metadata (type names, descriptions, use cases).
#[wasm_bindgen]
pub struct GraphAttentionFactory;
#[wasm_bindgen]
impl GraphAttentionFactory {
    /// Get available graph attention types
    #[wasm_bindgen(js_name = availableTypes)]
    pub fn available_types() -> JsValue {
        // Keep in sync with the match arms in `get_description`.
        let types = vec!["gat", "gcn", "graphsage"];
        serde_wasm_bindgen::to_value(&types).unwrap()
    }

    /// Get description for a graph attention type
    ///
    /// Unknown type names yield "Unknown graph attention type".
    #[wasm_bindgen(js_name = getDescription)]
    pub fn get_description(attention_type: &str) -> String {
        let text = match attention_type {
            "gat" => "Graph Attention Networks - learns attention weights over neighbors",
            "gcn" => "Graph Convolutional Networks - spectral convolution on graphs",
            "graphsage" => "GraphSAGE - sample and aggregate neighbor features",
            _ => "Unknown graph attention type",
        };
        text.to_string()
    }

    /// Get recommended use cases for a graph attention type
    ///
    /// Returns a JS array of short use-case strings.
    #[wasm_bindgen(js_name = getUseCases)]
    pub fn get_use_cases(attention_type: &str) -> JsValue {
        let cases = match attention_type {
            "gat" => vec![
                "Node classification with varying neighbor importance",
                "Link prediction in heterogeneous graphs",
                "Knowledge graph reasoning",
            ],
            "gcn" => vec![
                "Semi-supervised node classification",
                "Graph-level classification",
                "Spectral clustering",
            ],
            "graphsage" => vec![
                "Inductive learning on new nodes",
                "Large-scale graph processing",
                "Dynamic graphs with new vertices",
            ],
            _ => vec!["Unknown type"],
        };
        serde_wasm_bindgen::to_value(&cases).unwrap()
    }
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;
    wasm_bindgen_test_configure!(run_in_browser);

    // A valid layer reports the configured hidden dimension as its output dim.
    #[wasm_bindgen_test]
    fn test_gnn_layer_creation() {
        let layer = WasmGNNLayer::new(4, 8, 2, 0.1);
        assert!(layer.is_ok());
        let l = layer.unwrap();
        assert_eq!(l.output_dim(), 8);
    }

    // Dropout outside [0, 1] must be rejected by the core layer.
    #[wasm_bindgen_test]
    fn test_gnn_layer_invalid_dropout() {
        let layer = WasmGNNLayer::new(4, 8, 2, 1.5);
        assert!(layer.is_err());
    }

    // hidden_dim (7) not divisible by heads (3) must be rejected.
    #[wasm_bindgen_test]
    fn test_gnn_layer_invalid_heads() {
        let layer = WasmGNNLayer::new(4, 7, 3, 0.1);
        assert!(layer.is_err());
    }

    // Ratio estimates at one sample point inside each frequency band.
    #[wasm_bindgen_test]
    fn test_tensor_compress_creation() {
        let compressor = WasmTensorCompress::new();
        assert_eq!(compressor.get_compression_ratio(1.0), 1.0);
        assert_eq!(compressor.get_compression_ratio(0.5), 2.0);
        assert_eq!(compressor.get_compression_ratio(0.2), 4.0);
        assert_eq!(compressor.get_compression_ratio(0.05), 8.0);
        assert_eq!(compressor.get_compression_ratio(0.005), 32.0);
    }

    // Public fields round-trip through the constructor unchanged.
    #[wasm_bindgen_test]
    fn test_search_config() {
        let config = WasmSearchConfig::new(5, 1.0);
        assert_eq!(config.k, 5);
        assert_eq!(config.temperature, 1.0);
    }

    #[wasm_bindgen_test]
    fn test_factory_types() {
        let types_js = GraphAttentionFactory::available_types();
        assert!(!types_js.is_null());
    }

    #[wasm_bindgen_test]
    fn test_factory_descriptions() {
        let desc = GraphAttentionFactory::get_description("gat");
        assert!(desc.contains("Graph Attention"));
        let desc = GraphAttentionFactory::get_description("gcn");
        assert!(desc.contains("Graph Convolutional"));
        let desc = GraphAttentionFactory::get_description("graphsage");
        assert!(desc.contains("GraphSAGE"));
    }
}

View File

@@ -0,0 +1,382 @@
//! Unified WebAssembly Attention Library
//!
//! This crate provides a unified WASM interface for 18+ attention mechanisms:
//!
//! ## Neural Attention (from ruvector-attention)
//! - **Scaled Dot-Product**: Standard transformer attention
//! - **Multi-Head**: Parallel attention heads
//! - **Hyperbolic**: Attention in hyperbolic space for hierarchical data
//! - **Linear**: O(n) Performer-style attention
//! - **Flash**: Memory-efficient blocked attention
//! - **Local-Global**: Sparse attention with global tokens
//! - **MoE**: Mixture of Experts attention
//!
//! ## DAG Attention (from ruvector-dag)
//! - **Topological**: Position-aware attention in DAG order
//! - **Causal Cone**: Lightcone-based causal attention
//! - **Critical Path**: Attention weighted by critical path distance
//! - **MinCut-Gated**: Flow-based gating attention
//! - **Hierarchical Lorentz**: Multi-scale hyperbolic DAG attention
//! - **Parallel Branch**: Attention for parallel DAG branches
//! - **Temporal BTSP**: Behavioral Time-Series Pattern attention
//!
//! ## Graph Attention (from ruvector-gnn)
//! - **GAT**: Graph Attention Networks
//! - **GCN**: Graph Convolutional Networks
//! - **GraphSAGE**: Sampling and Aggregating graph embeddings
//!
//! ## State Space Models
//! - **Mamba SSM**: Selective State Space Model attention
use wasm_bindgen::prelude::*;
// Use wee_alloc for smaller WASM binary (~10KB reduction)
#[cfg(feature = "wee_alloc")]
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
// ============================================================================
// Module declarations
// ============================================================================
pub mod mamba;
mod dag;
mod graph;
mod neural;
// ============================================================================
// Re-exports for convenient access
// ============================================================================
pub use dag::*;
pub use graph::*;
pub use mamba::*;
pub use neural::*;
// ============================================================================
// Initialization
// ============================================================================
/// Initialize the WASM module with panic hook for better error messages
///
/// Runs automatically when the WASM module is instantiated (`start` attribute).
/// The panic hook is only installed when the `console_error_panic_hook`
/// feature is enabled; otherwise this is a no-op.
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
// ============================================================================
// Version and Info
// ============================================================================
/// Crate version string, captured from Cargo metadata at compile time.
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Catalog of every attention mechanism id, grouped by category.
///
/// Returns a JS object with `neural`, `dag`, `graph` and `ssm` string arrays.
#[wasm_bindgen(js_name = availableMechanisms)]
pub fn available_mechanisms() -> JsValue {
    // Helper: turn a static name list into owned strings for serialization.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| n.to_string()).collect()
    }
    let mechanisms = AttentionMechanisms {
        neural: owned(&[
            "scaled_dot_product",
            "multi_head",
            "hyperbolic",
            "linear",
            "flash",
            "local_global",
            "moe",
        ]),
        dag: owned(&[
            "topological",
            "causal_cone",
            "critical_path",
            "mincut_gated",
            "hierarchical_lorentz",
            "parallel_branch",
            "temporal_btsp",
        ]),
        graph: owned(&["gat", "gcn", "graphsage"]),
        ssm: owned(&["mamba"]),
    };
    serde_wasm_bindgen::to_value(&mechanisms).unwrap()
}
/// Get summary statistics about the unified attention library
///
/// The total is derived from the per-category counts so the numbers cannot
/// drift apart when a category is extended.
#[wasm_bindgen(js_name = getStats)]
pub fn get_stats() -> JsValue {
    // Per-category mechanism counts; keep in sync with `available_mechanisms`.
    let (neural_count, dag_count, graph_count, ssm_count) = (7, 7, 3, 1);
    let stats = UnifiedStats {
        // 18 total today; computed rather than hard-coded.
        total_mechanisms: neural_count + dag_count + graph_count + ssm_count,
        neural_count,
        dag_count,
        graph_count,
        ssm_count,
        version: env!("CARGO_PKG_VERSION").to_string(),
    };
    serde_wasm_bindgen::to_value(&stats).unwrap()
}
// ============================================================================
// Internal Types
// ============================================================================
/// Serializable catalog of mechanism ids grouped by category.
/// Converted to a plain JS object by `available_mechanisms`.
#[derive(serde::Serialize)]
struct AttentionMechanisms {
    /// Transformer-style mechanisms (7 ids).
    neural: Vec<String>,
    /// DAG-structured mechanisms (7 ids).
    dag: Vec<String>,
    /// Graph neural network mechanisms (3 ids).
    graph: Vec<String>,
    /// State space model mechanisms (1 id).
    ssm: Vec<String>,
}
/// Serializable summary of the library, returned by `get_stats`.
#[derive(serde::Serialize)]
struct UnifiedStats {
    /// Sum of all category counts.
    total_mechanisms: usize,
    neural_count: usize,
    dag_count: usize,
    graph_count: usize,
    ssm_count: usize,
    /// Crate version, from CARGO_PKG_VERSION.
    version: String,
}
// ============================================================================
// Unified Attention Selector
// ============================================================================
/// Unified attention mechanism selector
/// Automatically routes to the appropriate attention implementation
#[wasm_bindgen]
pub struct UnifiedAttention {
    // Validated mechanism id — one of the 18 names accepted by `new`.
    mechanism_type: String,
}
#[wasm_bindgen]
impl UnifiedAttention {
    /// Create a new unified attention selector.
    ///
    /// # Errors
    /// Rejects any `mechanism` that is not one of the 18 known ids; the error
    /// message lists every valid option.
    #[wasm_bindgen(constructor)]
    pub fn new(mechanism: &str) -> Result<UnifiedAttention, JsError> {
        let neural = [
            "scaled_dot_product",
            "multi_head",
            "hyperbolic",
            "linear",
            "flash",
            "local_global",
            "moe",
        ];
        let dag = [
            "topological",
            "causal_cone",
            "critical_path",
            "mincut_gated",
            "hierarchical_lorentz",
            "parallel_branch",
            "temporal_btsp",
        ];
        let graph = ["gat", "gcn", "graphsage"];
        let ssm = ["mamba"];
        // Flatten categories (neural, dag, graph, ssm order) for validation
        // and for the error message listing.
        let valid_mechanisms: Vec<&str> = neural
            .iter()
            .chain(&dag)
            .chain(&graph)
            .chain(&ssm)
            .copied()
            .collect();
        if valid_mechanisms.contains(&mechanism) {
            Ok(Self {
                mechanism_type: mechanism.to_string(),
            })
        } else {
            Err(JsError::new(&format!(
                "Unknown mechanism: {}. Valid options: {:?}",
                mechanism, valid_mechanisms
            )))
        }
    }
    /// Currently selected mechanism id.
    #[wasm_bindgen(getter)]
    pub fn mechanism(&self) -> String {
        self.mechanism_type.clone()
    }
    /// Category of the selected mechanism: "neural", "dag", "graph" or "ssm".
    #[wasm_bindgen(getter)]
    pub fn category(&self) -> String {
        let m = self.mechanism_type.as_str();
        let neural = [
            "scaled_dot_product",
            "multi_head",
            "hyperbolic",
            "linear",
            "flash",
            "local_global",
            "moe",
        ];
        let dag = [
            "topological",
            "causal_cone",
            "critical_path",
            "mincut_gated",
            "hierarchical_lorentz",
            "parallel_branch",
            "temporal_btsp",
        ];
        let category = if neural.contains(&m) {
            "neural"
        } else if dag.contains(&m) {
            "dag"
        } else if ["gat", "gcn", "graphsage"].contains(&m) {
            "graph"
        } else if m == "mamba" {
            "ssm"
        } else {
            "unknown"
        };
        category.to_string()
    }
    /// Whether this mechanism processes flat token sequences.
    #[wasm_bindgen(js_name = supportsSequences)]
    pub fn supports_sequences(&self) -> bool {
        [
            "scaled_dot_product",
            "multi_head",
            "linear",
            "flash",
            "local_global",
            "mamba",
        ]
        .contains(&self.mechanism_type.as_str())
    }
    /// Whether this mechanism operates on graph/DAG structures.
    #[wasm_bindgen(js_name = supportsGraphs)]
    pub fn supports_graphs(&self) -> bool {
        [
            "topological",
            "causal_cone",
            "critical_path",
            "mincut_gated",
            "hierarchical_lorentz",
            "parallel_branch",
            "temporal_btsp",
            "gat",
            "gcn",
            "graphsage",
        ]
        .contains(&self.mechanism_type.as_str())
    }
    /// Whether this mechanism works in hyperbolic geometry.
    #[wasm_bindgen(js_name = supportsHyperbolic)]
    pub fn supports_hyperbolic(&self) -> bool {
        ["hyperbolic", "hierarchical_lorentz"].contains(&self.mechanism_type.as_str())
    }
}
// ============================================================================
// Utility Functions
// ============================================================================
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 when either vector has zero magnitude.
///
/// # Errors
/// Fails when the vectors differ in length.
#[wasm_bindgen(js_name = cosineSimilarity)]
pub fn cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> Result<f32, JsError> {
    if a.len() != b.len() {
        return Err(JsError::new(&format!(
            "Vector dimensions must match: {} vs {}",
            a.len(),
            b.len()
        )));
    }
    // Accumulate dot product and squared norms in a single pass.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        Ok(0.0)
    } else {
        Ok(dot / (norm_a * norm_b))
    }
}
/// Numerically stable softmax normalization.
///
/// Subtracts the maximum before exponentiating to avoid overflow.
#[wasm_bindgen]
pub fn softmax(values: Vec<f32>) -> Vec<f32> {
    let mut max_val = f32::NEG_INFINITY;
    for &v in &values {
        max_val = max_val.max(v);
    }
    let mut exps: Vec<f32> = Vec::with_capacity(values.len());
    for &v in &values {
        exps.push((v - max_val).exp());
    }
    let total: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / total).collect()
}
/// Temperature-scaled softmax.
///
/// `temperature > 0` rescales the logits before softmax (lower = sharper);
/// `temperature <= 0` degenerates to a one-hot vector at the argmax.
/// Empty input yields an empty vector.
#[wasm_bindgen(js_name = temperatureSoftmax)]
pub fn temperature_softmax(values: Vec<f32>, temperature: f32) -> Vec<f32> {
    // Guard: the one-hot branch below indexes into the result, which would
    // panic for an empty input.
    if values.is_empty() {
        return Vec::new();
    }
    if temperature <= 0.0 {
        // Return one-hot for the maximum. `total_cmp` provides a total order,
        // so NaN entries cannot panic (unlike `partial_cmp(..).unwrap()`).
        let max_idx = values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .unwrap_or(0);
        let mut result = vec![0.0; values.len()];
        result[max_idx] = 1.0;
        return result;
    }
    let scaled: Vec<f32> = values.iter().map(|&x| x / temperature).collect();
    softmax(scaled)
}
// ============================================================================
// Tests
// ============================================================================
// Unit tests; executed in a headless browser via wasm-bindgen-test.
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;
    wasm_bindgen_test_configure!(run_in_browser);
    // Version string comes from Cargo metadata and is never empty.
    #[wasm_bindgen_test]
    fn test_version() {
        assert!(!version().is_empty());
    }
    // Known mechanism ids construct; unknown ids are rejected.
    #[wasm_bindgen_test]
    fn test_unified_attention_creation() {
        let attention = UnifiedAttention::new("multi_head");
        assert!(attention.is_ok());
        let invalid = UnifiedAttention::new("invalid_mechanism");
        assert!(invalid.is_err());
    }
    // One representative id per category maps to the right category name.
    #[wasm_bindgen_test]
    fn test_mechanism_categories() {
        let neural = UnifiedAttention::new("multi_head").unwrap();
        assert_eq!(neural.category(), "neural");
        let dag = UnifiedAttention::new("topological").unwrap();
        assert_eq!(dag.category(), "dag");
        let graph = UnifiedAttention::new("gat").unwrap();
        assert_eq!(graph.category(), "graph");
        let ssm = UnifiedAttention::new("mamba").unwrap();
        assert_eq!(ssm.category(), "ssm");
    }
    // Softmax output is a probability distribution preserving the input order.
    #[wasm_bindgen_test]
    fn test_softmax() {
        let input = vec![1.0, 2.0, 3.0];
        let output = softmax(input);
        // Sum should be 1.0
        let sum: f32 = output.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        // Should be monotonically increasing
        assert!(output[0] < output[1]);
        assert!(output[1] < output[2]);
    }
    // Identical vectors have similarity 1; orthogonal vectors have 0.
    #[wasm_bindgen_test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(a, b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        let sim2 = cosine_similarity(c, d).unwrap();
        assert!(sim2.abs() < 1e-6);
    }
}

View File

@@ -0,0 +1,554 @@
//! Mamba SSM (Selective State Space Model) Attention Mechanism
//!
//! Implements the Mamba architecture's selective scan mechanism for efficient
//! sequence modeling with linear time complexity O(n).
//!
//! Key Features:
//! - **Selective Scan**: Input-dependent state transitions
//! - **Linear Complexity**: O(n) vs O(n^2) for standard attention
//! - **Hardware Efficient**: Optimized for parallel scan operations
//! - **Long Context**: Handles very long sequences efficiently
//!
//! ## Architecture
//!
//! Mamba uses a selective state space model:
//! ```text
//! h_t = A_t * h_{t-1} + B_t * x_t
//! y_t = C_t * h_t
//! ```
//!
//! Where A_t, B_t, C_t are input-dependent (selective), computed from x_t.
//!
//! ## References
//!
//! - Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, 2023)
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
// ============================================================================
// Configuration
// ============================================================================
/// Configuration for Mamba SSM attention
///
/// Construct with `new` and refine via the `with_*` builder methods; all
/// fields are public and directly writable from JS as well.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct MambaConfig {
    /// Model dimension (d_model)
    pub dim: usize,
    /// State space dimension (n) — state channels per inner channel in the scan
    pub state_dim: usize,
    /// Expansion factor for inner dimension (inner_dim = dim * expand_factor)
    pub expand_factor: usize,
    /// Convolution kernel size
    /// NOTE(review): not referenced by the simplified forward path in this
    /// file — confirm whether it is consumed elsewhere.
    pub conv_kernel_size: usize,
    /// Delta (discretization step) range minimum — lower clamp for dt
    pub dt_min: f32,
    /// Delta range maximum — upper clamp for dt
    pub dt_max: f32,
    /// Whether to use learnable D skip connection
    pub use_d_skip: bool,
}
#[wasm_bindgen]
impl MambaConfig {
    /// Build a configuration for model dimension `dim` with library defaults:
    /// 16 state channels, 2x expansion, kernel size 4, dt clamped to
    /// [0.001, 0.1], and the D skip connection enabled.
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize) -> MambaConfig {
        MambaConfig {
            dim,
            state_dim: 16,
            expand_factor: 2,
            conv_kernel_size: 4,
            dt_min: 0.001,
            dt_max: 0.1,
            use_d_skip: true,
        }
    }
    /// Builder-style setter for the state space dimension.
    #[wasm_bindgen(js_name = withStateDim)]
    pub fn with_state_dim(self, state_dim: usize) -> MambaConfig {
        MambaConfig { state_dim, ..self }
    }
    /// Builder-style setter for the expansion factor.
    #[wasm_bindgen(js_name = withExpandFactor)]
    pub fn with_expand_factor(self, factor: usize) -> MambaConfig {
        MambaConfig {
            expand_factor: factor,
            ..self
        }
    }
    /// Builder-style setter for the convolution kernel size.
    #[wasm_bindgen(js_name = withConvKernelSize)]
    pub fn with_conv_kernel_size(self, size: usize) -> MambaConfig {
        MambaConfig {
            conv_kernel_size: size,
            ..self
        }
    }
}
impl Default for MambaConfig {
fn default() -> Self {
MambaConfig::new(256)
}
}
// ============================================================================
// State Space Parameters
// ============================================================================
/// Selective state space parameters (input-dependent)
///
/// All tensors are indexed `[seq_len][inner_dim][state_dim]` except `delta`
/// (no batch axis in this WASM port), matching how `compute_selective_params`
/// fills them.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SelectiveSSMParams {
    /// Discretized A diagonal: `a_bar[t][d][n] = exp(delta * A)`.
    a_bar: Vec<Vec<Vec<f32>>>,
    /// Discretized input matrix: `b_bar[t][d][n] = delta * B(x_t)`.
    b_bar: Vec<Vec<Vec<f32>>>,
    /// Input-dependent output projection `c[t][d][n]`.
    c: Vec<Vec<Vec<f32>>>,
    /// Discretization step, indexed `[seq_len][1][inner_dim]`.
    delta: Vec<Vec<Vec<f32>>>,
}
// ============================================================================
// Mamba SSM Attention
// ============================================================================
/// Mamba Selective State Space Model for sequence attention
///
/// Provides O(n) attention-like mechanism using selective state spaces
#[wasm_bindgen]
pub struct MambaSSMAttention {
    config: MambaConfig,
    /// Inner dimension after expansion (dim * expand_factor)
    inner_dim: usize,
    /// A parameter (state_dim,): a_log[i] = -ln(i+1), the log-magnitudes of
    /// the continuous A diagonal
    a_log: Vec<f32>,
    /// D skip connection (inner_dim,), initialized to all ones
    d_skip: Vec<f32>,
    /// Input projection (inner_dim x dim); identity-like stub for WASM
    in_proj: Vec<Vec<f32>>,
    /// Output projection (dim x inner_dim); identity-like stub for WASM
    out_proj: Vec<Vec<f32>>,
}
#[wasm_bindgen]
impl MambaSSMAttention {
    /// Create a new Mamba SSM attention layer
    ///
    /// Weights are deterministic stubs: identity-like projections, unit D
    /// skip, and a_log[i] = -ln(i+1).
    #[wasm_bindgen(constructor)]
    pub fn new(config: MambaConfig) -> MambaSSMAttention {
        let inner_dim = config.dim * config.expand_factor;
        // a_log[i] = -ln(i+1) (<= 0); log-magnitudes of the A diagonal, also
        // used directly as decay rates in get_attention_scores.
        let a_log: Vec<f32> = (0..config.state_dim)
            .map(|i| -((i + 1) as f32).ln())
            .collect();
        // D skip connection
        let d_skip = vec![1.0; inner_dim];
        // Simplified projection matrices (identity-like for stub): in_proj is
        // (inner_dim x dim) with ones on the leading diagonal.
        let in_proj: Vec<Vec<f32>> = (0..inner_dim)
            .map(|i| {
                let mut row = vec![0.0; config.dim];
                if i < config.dim {
                    row[i] = 1.0;
                }
                row
            })
            .collect();
        // out_proj is (dim x inner_dim), again identity-like.
        let out_proj: Vec<Vec<f32>> = (0..config.dim)
            .map(|i| {
                let mut row = vec![0.0; inner_dim];
                if i < inner_dim {
                    row[i] = 1.0;
                }
                row
            })
            .collect();
        MambaSSMAttention {
            config,
            inner_dim,
            a_log,
            d_skip,
            in_proj,
            out_proj,
        }
    }
    /// Create with default configuration
    #[wasm_bindgen(js_name = withDefaults)]
    pub fn with_defaults(dim: usize) -> MambaSSMAttention {
        MambaSSMAttention::new(MambaConfig::new(dim))
    }
    /// Forward pass through Mamba SSM
    ///
    /// # Arguments
    /// * `input` - Input sequence (seq_len, dim) flattened to 1D
    /// * `seq_len` - Sequence length
    ///
    /// # Returns
    /// Output sequence (seq_len, dim) flattened to 1D
    ///
    /// # Errors
    /// Fails when `input.len() != seq_len * dim`.
    #[wasm_bindgen]
    pub fn forward(&self, input: Vec<f32>, seq_len: usize) -> Result<Vec<f32>, JsError> {
        let dim = self.config.dim;
        if input.len() != seq_len * dim {
            return Err(JsError::new(&format!(
                "Input size mismatch: expected {} ({}x{}), got {}",
                seq_len * dim,
                seq_len,
                dim,
                input.len()
            )));
        }
        // Reshape flat input to (seq_len, dim)
        let input_2d: Vec<Vec<f32>> = (0..seq_len)
            .map(|t| input[t * dim..(t + 1) * dim].to_vec())
            .collect();
        // Step 1: Input projection to inner_dim
        let projected = self.project_in(&input_2d);
        // Step 2: Compute selective SSM parameters from input
        let ssm_params = self.compute_selective_params(&projected);
        // Step 3: Run selective scan
        let ssm_output = self.selective_scan(&projected, &ssm_params);
        // Step 4: Apply D skip connection (y += D * x, element-wise)
        // NOTE(review): applied unconditionally — `config.use_d_skip` is never
        // checked here; confirm whether the flag should gate this step.
        let with_skip: Vec<Vec<f32>> = ssm_output
            .iter()
            .zip(projected.iter())
            .map(|(y, x)| {
                y.iter()
                    .zip(x.iter())
                    .zip(self.d_skip.iter())
                    .map(|((yi, xi), di)| yi + di * xi)
                    .collect()
            })
            .collect();
        // Step 5: Output projection back to dim
        let output = self.project_out(&with_skip);
        // Flatten (seq_len, dim) back to 1D
        Ok(output.into_iter().flatten().collect())
    }
    /// Get the configuration
    #[wasm_bindgen(getter)]
    pub fn config(&self) -> MambaConfig {
        self.config.clone()
    }
    /// Get the inner dimension
    #[wasm_bindgen(getter, js_name = innerDim)]
    pub fn inner_dim(&self) -> usize {
        self.inner_dim
    }
    /// Compute attention-like scores (for visualization/analysis)
    ///
    /// Returns a flattened (seq_len x seq_len) lower-triangular matrix of
    /// pseudo-attention scores: position t's score over position s <= t is an
    /// exponential decay in (t - s), averaged over the a_log rates. Upper
    /// triangle stays 0 (causality).
    #[wasm_bindgen(js_name = getAttentionScores)]
    pub fn get_attention_scores(
        &self,
        input: Vec<f32>,
        seq_len: usize,
    ) -> Result<Vec<f32>, JsError> {
        let dim = self.config.dim;
        if input.len() != seq_len * dim {
            return Err(JsError::new(&format!(
                "Input size mismatch: expected {}, got {}",
                seq_len * dim,
                input.len()
            )));
        }
        // Compute approximate attention scores based on state decay
        // This shows how much each position can "attend to" previous positions
        let mut scores = vec![0.0f32; seq_len * seq_len];
        for t in 0..seq_len {
            for s in 0..=t {
                // Exponential decay based on distance and A parameters
                let distance = (t - s) as f32;
                let decay: f32 = self
                    .a_log
                    .iter()
                    .map(|&a| (a * distance).exp())
                    .sum::<f32>()
                    / self.config.state_dim as f32;
                scores[t * seq_len + s] = decay;
            }
        }
        Ok(scores)
    }
}
// Internal implementation methods
impl MambaSSMAttention {
    /// Project input from `dim` to `inner_dim`: matrix-vector product with
    /// `in_proj`, applied independently per timestep.
    fn project_in(&self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|x| {
                self.in_proj
                    .iter()
                    .map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum())
                    .collect()
            })
            .collect()
    }
    /// Project from `inner_dim` back to `dim` via `out_proj`.
    fn project_out(&self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|x| {
                self.out_proj
                    .iter()
                    .map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum())
                    .collect()
            })
            .collect()
    }
    /// Compute input-dependent (selective) SSM parameters.
    ///
    /// Shapes: `a_bar`, `b_bar`, `c` are `[seq_len][inner_dim][state_dim]`;
    /// `delta` is `[seq_len][1][inner_dim]`.
    fn compute_selective_params(&self, input: &[Vec<f32>]) -> SelectiveSSMParams {
        let seq_len = input.len();
        let state_dim = self.config.state_dim;
        let mut a_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
        let mut b_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
        let mut c = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
        let mut delta = vec![vec![vec![0.0; self.inner_dim]; 1]; seq_len];
        for (t, x) in input.iter().enumerate() {
            // Discretization step from the input: softplus keeps it positive,
            // then clamp into the configured [dt_min, dt_max] range.
            let dt: Vec<f32> = x
                .iter()
                .map(|&xi| {
                    let raw = xi * 0.1; // simple input scaling
                    let dt_val = (1.0 + raw.exp()).ln(); // softplus
                    dt_val.clamp(self.config.dt_min, self.config.dt_max)
                })
                .collect();
            delta[t][0] = dt.clone();
            for d in 0..self.inner_dim.min(x.len()) {
                let dt_d = dt[d.min(dt.len() - 1)];
                for n in 0..state_dim {
                    // Discretize A: A_bar = exp(delta * A).
                    // BUGFIX: Mamba parameterizes A = -exp(a_log) so the
                    // continuous A is strictly negative and |A_bar| < 1,
                    // giving a stable, decaying recurrence. The previous code
                    // used `a_log.exp()` directly, which is always positive
                    // and made the hidden state grow with sequence length.
                    let a_continuous = -self.a_log[n].exp();
                    a_bar[t][d][n] = (dt_d * a_continuous).exp();
                    // Discretize B: B_bar = delta * B(x), with B an
                    // input-dependent sigmoid gate (simplified).
                    let b_input = if d < x.len() { x[d] } else { 0.0 };
                    b_bar[t][d][n] = dt_d * Self::sigmoid(b_input * 0.1);
                    // C is likewise input-dependent (tanh squashing).
                    c[t][d][n] = Self::tanh(b_input * 0.1);
                }
            }
        }
        SelectiveSSMParams {
            a_bar,
            b_bar,
            c,
            delta,
        }
    }
    /// Sequential selective scan: h_t = A_bar * h_{t-1} + B_bar * x_t and
    /// y_t = <C_t, h_t> per inner channel. (Real implementations replace this
    /// with a parallel associative scan.)
    fn selective_scan(&self, input: &[Vec<f32>], params: &SelectiveSSMParams) -> Vec<Vec<f32>> {
        let seq_len = input.len();
        let state_dim = self.config.state_dim;
        // Hidden state: one state_dim vector per inner channel.
        let mut hidden = vec![vec![0.0f32; state_dim]; self.inner_dim];
        let mut output = vec![vec![0.0f32; self.inner_dim]; seq_len];
        for t in 0..seq_len {
            for d in 0..self.inner_dim {
                let x_d = if d < input[t].len() { input[t][d] } else { 0.0 };
                // Update hidden state: h_t = A_bar * h_{t-1} + B_bar * x_t
                for n in 0..state_dim {
                    hidden[d][n] =
                        params.a_bar[t][d][n] * hidden[d][n] + params.b_bar[t][d][n] * x_d;
                }
                // Compute output: y_t = C . h_t
                output[t][d] = hidden[d]
                    .iter()
                    .zip(params.c[t][d].iter())
                    .map(|(h, c)| h * c)
                    .sum();
            }
        }
        output
    }
    /// Logistic sigmoid.
    #[inline]
    fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    /// Hyperbolic tangent (thin wrapper kept for symmetry with `sigmoid`).
    #[inline]
    fn tanh(x: f32) -> f32 {
        x.tanh()
    }
}
// ============================================================================
// Hybrid Mamba-Attention
// ============================================================================
/// Hybrid layer combining Mamba SSM with standard attention
///
/// Uses Mamba for long-range dependencies and attention for local patterns
#[wasm_bindgen]
pub struct HybridMambaAttention {
    /// Global-context Mamba layer.
    mamba: MambaSSMAttention,
    /// Width of the sliding window used for local mixing.
    local_window: usize,
    /// When true (always set by `new`), blend in the local window average.
    use_attention_for_local: bool,
}
#[wasm_bindgen]
impl HybridMambaAttention {
    /// Create a new hybrid Mamba-Attention layer.
    ///
    /// `local_window` is the width of the sliding window for local mixing.
    #[wasm_bindgen(constructor)]
    pub fn new(config: MambaConfig, local_window: usize) -> HybridMambaAttention {
        HybridMambaAttention {
            mamba: MambaSSMAttention::new(config),
            local_window,
            use_attention_for_local: true,
        }
    }
    /// Forward pass: global context via Mamba, then blended 70/30 with a
    /// local sliding-window average of the raw input.
    #[wasm_bindgen]
    pub fn forward(&self, input: Vec<f32>, seq_len: usize) -> Result<Vec<f32>, JsError> {
        let dim = self.mamba.config.dim;
        // Global pass over the full sequence.
        let mut output = self.mamba.forward(input.clone(), seq_len)?;
        if !self.use_attention_for_local {
            return Ok(output);
        }
        for t in 0..seq_len {
            // Window [start, end) centered (clamped) on position t.
            let start = t.saturating_sub(self.local_window / 2);
            let end = (t + self.local_window / 2 + 1).min(seq_len);
            let window = (end - start) as f32;
            for d in 0..dim {
                // Mean of the raw inputs inside the window for channel d.
                let local_avg =
                    (start..end).map(|s| input[s * dim + d]).sum::<f32>() / window;
                // Mix global (Mamba) and local contributions.
                let cell = &mut output[t * dim + d];
                *cell = 0.7 * *cell + 0.3 * local_avg;
            }
        }
        Ok(output)
    }
    /// Get local window size
    #[wasm_bindgen(getter, js_name = localWindow)]
    pub fn local_window(&self) -> usize {
        self.local_window
    }
}
// ============================================================================
// Tests
// ============================================================================
// Unit tests; executed in a headless browser via wasm-bindgen-test.
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;
    wasm_bindgen_test_configure!(run_in_browser);
    // Default hyperparameters from MambaConfig::new.
    #[wasm_bindgen_test]
    fn test_mamba_config() {
        let config = MambaConfig::new(256);
        assert_eq!(config.dim, 256);
        assert_eq!(config.state_dim, 16);
        assert_eq!(config.expand_factor, 2);
    }
    // inner_dim = dim * expand_factor.
    #[wasm_bindgen_test]
    fn test_mamba_creation() {
        let config = MambaConfig::new(64);
        let mamba = MambaSSMAttention::new(config);
        assert_eq!(mamba.inner_dim(), 128); // 64 * 2
    }
    // Forward preserves the flattened (seq_len, dim) shape.
    #[wasm_bindgen_test]
    fn test_mamba_forward() {
        let config = MambaConfig::new(8);
        let mamba = MambaSSMAttention::new(config);
        // Input: 4 tokens of dimension 8
        let input = vec![0.1f32; 32];
        let output = mamba.forward(input, 4);
        assert!(output.is_ok());
        let out = output.unwrap();
        assert_eq!(out.len(), 32); // Same shape as input
    }
    // Pseudo-attention matrix is seq_len x seq_len and strictly causal.
    #[wasm_bindgen_test]
    fn test_attention_scores() {
        let config = MambaConfig::new(8);
        let mamba = MambaSSMAttention::new(config);
        let input = vec![0.1f32; 24]; // 3 tokens
        let scores = mamba.get_attention_scores(input, 3);
        assert!(scores.is_ok());
        let s = scores.unwrap();
        assert_eq!(s.len(), 9); // 3x3 attention matrix
        // Causal: upper triangle should be 0
        assert_eq!(s[0 * 3 + 1], 0.0); // t=0 cannot attend to t=1
        assert_eq!(s[0 * 3 + 2], 0.0); // t=0 cannot attend to t=2
    }
    // Hybrid layer also preserves shape.
    #[wasm_bindgen_test]
    fn test_hybrid_mamba() {
        let config = MambaConfig::new(8);
        let hybrid = HybridMambaAttention::new(config, 4);
        let input = vec![0.5f32; 40]; // 5 tokens
        let output = hybrid.forward(input, 5);
        assert!(output.is_ok());
        assert_eq!(output.unwrap().len(), 40);
    }
}

View File

@@ -0,0 +1,439 @@
//! Neural Attention Mechanisms (from ruvector-attention)
//!
//! Re-exports the 7 core neural attention mechanisms:
//! - Scaled Dot-Product Attention
//! - Multi-Head Attention
//! - Hyperbolic Attention
//! - Linear Attention (Performer)
//! - Flash Attention
//! - Local-Global Attention
//! - Mixture of Experts (MoE) Attention
use ruvector_attention::{
attention::{MultiHeadAttention, ScaledDotProductAttention},
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
moe::{MoEAttention, MoEConfig},
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
traits::Attention,
};
use wasm_bindgen::prelude::*;
// ============================================================================
// Scaled Dot-Product Attention
// ============================================================================
/// Compute scaled dot-product attention for a single query.
///
/// Implements the standard transformer primitive
/// `softmax(q . K^T / sqrt(d)) . V`.
///
/// # Arguments
/// * `query` - Query vector (Float32Array)
/// * `keys` - JS array of key vectors
/// * `values` - JS array of value vectors
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
///
/// # Returns
/// Attention-weighted output vector
#[wasm_bindgen(js_name = scaledDotAttention)]
pub fn scaled_dot_attention(
    query: &[f32],
    keys: JsValue,
    values: JsValue,
    scale: Option<f32>,
) -> Result<Vec<f32>, JsError> {
    let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
        .map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
    let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
        .map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;
    let key_slices: Vec<&[f32]> = keys_vec.iter().map(Vec::as_slice).collect();
    let value_slices: Vec<&[f32]> = values_vec.iter().map(Vec::as_slice).collect();
    // NOTE(review): `scale` is accepted but never used — the inner
    // implementation always applies its own default scaling. Confirm whether
    // it should be plumbed through to ScaledDotProductAttention.
    let _ = scale;
    ScaledDotProductAttention::new(query.len())
        .compute(query, &key_slices, &value_slices)
        .map_err(|e| JsError::new(&e.to_string()))
}
// ============================================================================
// Multi-Head Attention
// ============================================================================
/// Multi-head attention mechanism
///
/// Splits input into multiple heads, applies attention, and concatenates results
#[wasm_bindgen]
pub struct WasmMultiHeadAttention {
    /// Wrapped native implementation from `ruvector-attention`.
    inner: MultiHeadAttention,
}
#[wasm_bindgen]
impl WasmMultiHeadAttention {
    /// Create a new multi-head attention instance.
    ///
    /// # Arguments
    /// * `dim` - Embedding dimension (must be divisible by `num_heads`)
    /// * `num_heads` - Number of parallel attention heads
    ///
    /// # Errors
    /// Fails when `dim` is not a multiple of `num_heads`.
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
        if dim % num_heads == 0 {
            Ok(Self {
                inner: MultiHeadAttention::new(dim, num_heads),
            })
        } else {
            Err(JsError::new(&format!(
                "Dimension {} must be divisible by number of heads {}",
                dim, num_heads
            )))
        }
    }
    /// Compute multi-head attention for one query against key/value sets.
    pub fn compute(
        &self,
        query: &[f32],
        keys: JsValue,
        values: JsValue,
    ) -> Result<Vec<f32>, JsError> {
        let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
        let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
        let key_slices: Vec<&[f32]> = keys_vec.iter().map(Vec::as_slice).collect();
        let value_slices: Vec<&[f32]> = values_vec.iter().map(Vec::as_slice).collect();
        self.inner
            .compute(query, &key_slices, &value_slices)
            .map_err(|e| JsError::new(&e.to_string()))
    }
    /// Number of parallel attention heads.
    #[wasm_bindgen(getter, js_name = numHeads)]
    pub fn num_heads(&self) -> usize {
        self.inner.num_heads()
    }
    /// Embedding dimension.
    #[wasm_bindgen(getter)]
    pub fn dim(&self) -> usize {
        self.inner.dim()
    }
    /// Per-head dimension (dim / num_heads).
    #[wasm_bindgen(getter, js_name = headDim)]
    pub fn head_dim(&self) -> usize {
        self.inner.dim() / self.inner.num_heads()
    }
}
// ============================================================================
// Hyperbolic Attention
// ============================================================================
/// Hyperbolic attention mechanism for hierarchical data
///
/// Operates in hyperbolic space (Poincare ball model) which naturally
/// represents tree-like hierarchical structures with exponential capacity
#[wasm_bindgen]
pub struct WasmHyperbolicAttention {
    /// Wrapped native implementation from `ruvector-attention`.
    inner: HyperbolicAttention,
    /// Curvature passed at construction; cached for the `curvature` getter.
    curvature_value: f32,
}
#[wasm_bindgen]
impl WasmHyperbolicAttention {
/// Create a new hyperbolic attention instance
///
/// # Arguments
/// * `dim` - Embedding dimension
/// * `curvature` - Hyperbolic curvature parameter (negative for hyperbolic space)
#[wasm_bindgen(constructor)]
pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
let config = HyperbolicAttentionConfig {
dim,
curvature,
..Default::default()
};
Self {
inner: HyperbolicAttention::new(config),
curvature_value: curvature,
}
}
/// Compute hyperbolic attention
pub fn compute(
&self,
query: &[f32],
keys: JsValue,
values: JsValue,
) -> Result<Vec<f32>, JsError> {
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
self.inner
.compute(query, &keys_refs, &values_refs)
.map_err(|e| JsError::new(&e.to_string()))
}
/// Get the curvature parameter
#[wasm_bindgen(getter)]
pub fn curvature(&self) -> f32 {
self.curvature_value
}
}
// ============================================================================
// Linear Attention (Performer)
// ============================================================================
/// Linear attention using random feature approximation
///
/// Achieves O(n) complexity instead of O(n^2) by approximating
/// the softmax kernel with random Fourier features
#[wasm_bindgen]
pub struct WasmLinearAttention {
    /// Wrapped native implementation from `ruvector-attention`.
    inner: LinearAttention,
}
#[wasm_bindgen]
impl WasmLinearAttention {
/// Create a new linear attention instance
///
/// # Arguments
/// * `dim` - Embedding dimension
/// * `num_features` - Number of random features for kernel approximation
#[wasm_bindgen(constructor)]
pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
Self {
inner: LinearAttention::new(dim, num_features),
}
}
/// Compute linear attention
pub fn compute(
&self,
query: &[f32],
keys: JsValue,
values: JsValue,
) -> Result<Vec<f32>, JsError> {
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
self.inner
.compute(query, &keys_refs, &values_refs)
.map_err(|e| JsError::new(&e.to_string()))
}
}
// ============================================================================
// Flash Attention
// ============================================================================
/// Flash attention with memory-efficient tiling
///
/// Reduces memory usage from O(n^2) to O(n) by computing attention
/// in blocks and fusing operations
#[wasm_bindgen]
pub struct WasmFlashAttention {
    /// Wrapped native implementation from `ruvector-attention`.
    inner: FlashAttention,
}
#[wasm_bindgen]
impl WasmFlashAttention {
/// Create a new flash attention instance
///
/// # Arguments
/// * `dim` - Embedding dimension
/// * `block_size` - Block size for tiled computation
#[wasm_bindgen(constructor)]
pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
Self {
inner: FlashAttention::new(dim, block_size),
}
}
/// Compute flash attention
pub fn compute(
&self,
query: &[f32],
keys: JsValue,
values: JsValue,
) -> Result<Vec<f32>, JsError> {
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
self.inner
.compute(query, &keys_refs, &values_refs)
.map_err(|e| JsError::new(&e.to_string()))
}
}
// ============================================================================
// Local-Global Attention
// ============================================================================
/// Local-global sparse attention (Longformer-style)
///
/// Combines local sliding window attention with global tokens
/// for efficient long-range dependencies
/// (exposed to JavaScript via wasm-bindgen).
#[wasm_bindgen]
pub struct WasmLocalGlobalAttention {
    /// Core (non-WASM) local-global attention implementation.
    inner: LocalGlobalAttention,
}
#[wasm_bindgen]
impl WasmLocalGlobalAttention {
/// Create a new local-global attention instance
///
/// # Arguments
/// * `dim` - Embedding dimension
/// * `local_window` - Size of local attention window
/// * `global_tokens` - Number of global attention tokens
#[wasm_bindgen(constructor)]
pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
Self {
inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
}
}
/// Compute local-global attention
pub fn compute(
&self,
query: &[f32],
keys: JsValue,
values: JsValue,
) -> Result<Vec<f32>, JsError> {
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
self.inner
.compute(query, &keys_refs, &values_refs)
.map_err(|e| JsError::new(&e.to_string()))
}
}
// ============================================================================
// Mixture of Experts (MoE) Attention
// ============================================================================
/// Mixture of Experts attention mechanism
///
/// Routes queries to specialized expert attention heads based on
/// learned gating functions for capacity-efficient computation
/// (exposed to JavaScript via wasm-bindgen).
#[wasm_bindgen]
pub struct WasmMoEAttention {
    /// Core (non-WASM) MoE attention implementation.
    inner: MoEAttention,
}
#[wasm_bindgen]
impl WasmMoEAttention {
/// Create a new MoE attention instance
///
/// # Arguments
/// * `dim` - Embedding dimension
/// * `num_experts` - Number of expert attention mechanisms
/// * `top_k` - Number of experts to activate per query
#[wasm_bindgen(constructor)]
pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
let config = MoEConfig::builder()
.dim(dim)
.num_experts(num_experts)
.top_k(top_k)
.build();
Self {
inner: MoEAttention::new(config),
}
}
/// Compute MoE attention
pub fn compute(
&self,
query: &[f32],
keys: JsValue,
values: JsValue,
) -> Result<Vec<f32>, JsError> {
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
self.inner
.compute(query, &keys_refs, &values_refs)
.map_err(|e| JsError::new(&e.to_string()))
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Valid dim/head combination constructs and reports its geometry.
    #[wasm_bindgen_test]
    fn test_multi_head_creation() {
        let mha = WasmMultiHeadAttention::new(64, 8);
        assert!(mha.is_ok());
        let mha = mha.unwrap();
        assert_eq!(mha.dim(), 64);
        assert_eq!(mha.num_heads(), 8);
        assert_eq!(mha.head_dim(), 8);
    }

    /// 65 is not divisible by 8 heads, so construction must fail.
    #[wasm_bindgen_test]
    fn test_multi_head_invalid_dims() {
        assert!(WasmMultiHeadAttention::new(65, 8).is_err());
    }

    /// The curvature passed to the constructor is exposed via the getter.
    #[wasm_bindgen_test]
    fn test_hyperbolic_attention() {
        let hyp = WasmHyperbolicAttention::new(32, -1.0);
        assert_eq!(hyp.curvature(), -1.0);
    }

    // The remaining smoke tests verify that construction does not panic.
    // The previous versions asserted `true` (a no-op, flagged by clippy's
    // `assertions_on_constants`) and left unused bindings; a `_`-prefixed
    // binding makes the "construct and drop" intent explicit instead.

    #[wasm_bindgen_test]
    fn test_linear_attention_creation() {
        let _linear = WasmLinearAttention::new(64, 128);
    }

    #[wasm_bindgen_test]
    fn test_flash_attention_creation() {
        let _flash = WasmFlashAttention::new(64, 16);
    }

    #[wasm_bindgen_test]
    fn test_local_global_creation() {
        let _lg = WasmLocalGlobalAttention::new(64, 128, 4);
    }

    #[wasm_bindgen_test]
    fn test_moe_attention_creation() {
        let _moe = WasmMoEAttention::new(64, 8, 2);
    }
}