Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
806
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/dag.rs
vendored
Normal file
806
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/dag.rs
vendored
Normal file
@@ -0,0 +1,806 @@
|
||||
//! DAG Attention Mechanisms (from ruvector-dag)
|
||||
//!
|
||||
//! Re-exports the 7 DAG-specific attention mechanisms:
|
||||
//! - Topological Attention
|
||||
//! - Causal Cone Attention
|
||||
//! - Critical Path Attention
|
||||
//! - MinCut-Gated Attention
|
||||
//! - Hierarchical Lorentz Attention
|
||||
//! - Parallel Branch Attention
|
||||
//! - Temporal BTSP Attention
|
||||
|
||||
use ruvector_dag::{OperatorNode, QueryDag};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Minimal DAG for WASM
|
||||
// ============================================================================
|
||||
|
||||
/// Minimal DAG structure for WASM attention computation
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmQueryDag {
|
||||
inner: QueryDag,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmQueryDag {
|
||||
/// Create a new empty DAG
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> WasmQueryDag {
|
||||
WasmQueryDag {
|
||||
inner: QueryDag::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node with operator type and cost
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `op_type` - Operator type: "scan", "filter", "join", "aggregate", "project", "sort"
|
||||
/// * `cost` - Estimated execution cost
|
||||
///
|
||||
/// # Returns
|
||||
/// Node ID
|
||||
#[wasm_bindgen(js_name = addNode)]
|
||||
pub fn add_node(&mut self, op_type: &str, cost: f32) -> u32 {
|
||||
let table_id = self.inner.node_count() as usize;
|
||||
let mut node = match op_type {
|
||||
"scan" => OperatorNode::seq_scan(table_id, &format!("table_{}", table_id)),
|
||||
"filter" => OperatorNode::filter(table_id, "condition"),
|
||||
"join" => OperatorNode::hash_join(table_id, "join_key"),
|
||||
"aggregate" => OperatorNode::aggregate(table_id, vec!["*".to_string()]),
|
||||
"project" => OperatorNode::project(table_id, vec!["*".to_string()]),
|
||||
"sort" => OperatorNode::sort(table_id, vec!["col".to_string()]),
|
||||
_ => OperatorNode::seq_scan(table_id, "unknown"),
|
||||
};
|
||||
node.estimated_cost = cost as f64;
|
||||
self.inner.add_node(node) as u32
|
||||
}
|
||||
|
||||
/// Add an edge between nodes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `from` - Source node ID
|
||||
/// * `to` - Target node ID
|
||||
///
|
||||
/// # Returns
|
||||
/// True if edge was added successfully
|
||||
#[wasm_bindgen(js_name = addEdge)]
|
||||
pub fn add_edge(&mut self, from: u32, to: u32) -> bool {
|
||||
self.inner.add_edge(from as usize, to as usize).is_ok()
|
||||
}
|
||||
|
||||
/// Get the number of nodes
|
||||
#[wasm_bindgen(getter, js_name = nodeCount)]
|
||||
pub fn node_count(&self) -> u32 {
|
||||
self.inner.node_count() as u32
|
||||
}
|
||||
|
||||
/// Get the number of edges
|
||||
#[wasm_bindgen(getter, js_name = edgeCount)]
|
||||
pub fn edge_count(&self) -> u32 {
|
||||
self.inner.edge_count() as u32
|
||||
}
|
||||
|
||||
/// Serialize to JSON
|
||||
#[wasm_bindgen(js_name = toJson)]
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string(&DagSummary {
|
||||
node_count: self.inner.node_count(),
|
||||
edge_count: self.inner.edge_count(),
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl WasmQueryDag {
|
||||
/// Get internal reference
|
||||
pub(crate) fn inner(&self) -> &QueryDag {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct DagSummary {
|
||||
node_count: usize,
|
||||
edge_count: usize,
|
||||
}
|
||||
|
||||
// ============================================================================
// Helper function for converting HashMap scores to Vec
// ============================================================================
|
||||
|
||||
/// Densify a sparse score map into a length-`n` vector.
///
/// Indices missing from `scores` receive 0.0.
fn hashmap_to_vec(scores: &HashMap<usize, f32>, n: usize) -> Vec<f32> {
    let mut dense = Vec::with_capacity(n);
    for idx in 0..n {
        dense.push(*scores.get(&idx).unwrap_or(&0.0));
    }
    dense
}
|
||||
|
||||
// ============================================================================
|
||||
// Topological Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Topological attention based on DAG position
|
||||
///
|
||||
/// Assigns attention scores based on node position in topological order.
|
||||
/// Earlier nodes (closer to sources) get higher attention.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmTopologicalAttention {
|
||||
decay_factor: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmTopologicalAttention {
|
||||
/// Create a new topological attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `decay_factor` - Decay factor for position-based attention (0.0-1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(decay_factor: f32) -> WasmTopologicalAttention {
|
||||
WasmTopologicalAttention { decay_factor }
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
///
|
||||
/// # Returns
|
||||
/// Attention scores for each node
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let depths = dag.inner.compute_depths();
|
||||
let max_depth = depths.values().max().copied().unwrap_or(0);
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for (&node_id, &depth) in &depths {
|
||||
let normalized_depth = depth as f32 / (max_depth.max(1) as f32);
|
||||
let score = self.decay_factor.powf(1.0 - normalized_depth);
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Causal Cone Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Causal cone attention based on dependency lightcones
|
||||
///
|
||||
/// Nodes can only attend to ancestors in the DAG (causal predecessors).
|
||||
/// Attention strength decays with causal distance.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmCausalConeAttention {
|
||||
future_discount: f32,
|
||||
ancestor_weight: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmCausalConeAttention {
|
||||
/// Create a new causal cone attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `future_discount` - Discount for future nodes
|
||||
/// * `ancestor_weight` - Weight for ancestor influence
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(future_discount: f32, ancestor_weight: f32) -> WasmCausalConeAttention {
|
||||
WasmCausalConeAttention {
|
||||
future_discount,
|
||||
ancestor_weight,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
let depths = dag.inner.compute_depths();
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ancestors = dag.inner.ancestors(node_id);
|
||||
let ancestor_count = ancestors.len();
|
||||
|
||||
let mut score = 1.0 + (ancestor_count as f32 * self.ancestor_weight);
|
||||
|
||||
if let Some(&depth) = depths.get(&node_id) {
|
||||
score *= self.future_discount.powi(depth as i32);
|
||||
}
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Critical Path Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Critical path attention weighted by path criticality
|
||||
///
|
||||
/// Nodes on or near the critical path (longest execution path)
|
||||
/// receive higher attention scores.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmCriticalPathAttention {
|
||||
path_weight: f32,
|
||||
branch_penalty: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmCriticalPathAttention {
|
||||
/// Create a new critical path attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path_weight` - Weight for critical path membership
|
||||
/// * `branch_penalty` - Penalty for branching nodes
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(path_weight: f32, branch_penalty: f32) -> WasmCriticalPathAttention {
|
||||
WasmCriticalPathAttention {
|
||||
path_weight,
|
||||
branch_penalty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the critical path (longest path by cost)
|
||||
fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
|
||||
let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
|
||||
|
||||
for &leaf in &dag.leaves() {
|
||||
if let Some(node) = dag.get_node(leaf) {
|
||||
longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(topo_order) = dag.topological_sort() {
|
||||
for &node_id in topo_order.iter().rev() {
|
||||
let node = match dag.get_node(node_id) {
|
||||
Some(n) => n,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut max_cost = node.estimated_cost;
|
||||
let mut max_path = vec![node_id];
|
||||
|
||||
for &child in dag.children(node_id) {
|
||||
if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
|
||||
let total_cost = node.estimated_cost + child_cost;
|
||||
if total_cost > max_cost {
|
||||
max_cost = total_cost;
|
||||
max_path = vec![node_id];
|
||||
max_path.extend(child_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
longest_path.insert(node_id, (max_cost, max_path));
|
||||
}
|
||||
}
|
||||
|
||||
longest_path
|
||||
.into_iter()
|
||||
.max_by(|a, b| {
|
||||
a.1 .0
|
||||
.partial_cmp(&b.1 .0)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(_, (_, path))| path)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let critical = self.compute_critical_path(&dag.inner);
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let is_on_critical_path = critical.contains(&node_id);
|
||||
let num_children = dag.inner.children(node_id).len();
|
||||
|
||||
let mut score = if is_on_critical_path {
|
||||
self.path_weight
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
if num_children > 1 {
|
||||
score *= 1.0 + (num_children as f32 - 1.0) * self.branch_penalty;
|
||||
}
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MinCut-Gated Attention
|
||||
// ============================================================================
|
||||
|
||||
/// MinCut-gated attention using flow-based bottleneck detection
|
||||
///
|
||||
/// Uses minimum cut analysis to identify bottleneck nodes
|
||||
/// and gates attention through these critical points.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmMinCutGatedAttention {
|
||||
gate_threshold: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMinCutGatedAttention {
|
||||
/// Create a new MinCut-gated attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `gate_threshold` - Threshold for gating (0.0-1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(gate_threshold: f32) -> WasmMinCutGatedAttention {
|
||||
WasmMinCutGatedAttention { gate_threshold }
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
// Simple bottleneck detection: nodes with high in-degree and out-degree
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let in_degree = dag.inner.parents(node_id).len();
|
||||
let out_degree = dag.inner.children(node_id).len();
|
||||
|
||||
// Bottleneck score: higher for nodes with high connectivity
|
||||
let connectivity = (in_degree + out_degree) as f32;
|
||||
let is_bottleneck = connectivity >= self.gate_threshold * n as f32;
|
||||
|
||||
let score = if is_bottleneck {
|
||||
2.0 + connectivity * 0.1
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hierarchical Lorentz Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hierarchical Lorentz attention in hyperbolic space
|
||||
///
|
||||
/// Combines DAG hierarchy with Lorentz (hyperboloid) geometry
|
||||
/// for multi-scale hierarchical attention.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmHierarchicalLorentzAttention {
|
||||
curvature: f32,
|
||||
temperature: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHierarchicalLorentzAttention {
|
||||
/// Create a new hierarchical Lorentz attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `curvature` - Hyperbolic curvature parameter
|
||||
/// * `temperature` - Temperature for softmax
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(curvature: f32, temperature: f32) -> WasmHierarchicalLorentzAttention {
|
||||
WasmHierarchicalLorentzAttention {
|
||||
curvature,
|
||||
temperature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let depths = dag.inner.compute_depths();
|
||||
let max_depth = depths.values().max().copied().unwrap_or(0);
|
||||
|
||||
// Compute hyperbolic distances from origin
|
||||
let mut distances: Vec<f32> = Vec::with_capacity(n);
|
||||
for node_id in 0..n {
|
||||
let depth = depths.get(&node_id).copied().unwrap_or(0);
|
||||
// In hyperbolic space, distance grows exponentially with depth
|
||||
let radial = (depth as f32 * 0.5).tanh();
|
||||
let distance = (1.0 + radial).acosh() * self.curvature.abs();
|
||||
distances.push(distance);
|
||||
}
|
||||
|
||||
// Convert to attention scores using softmax
|
||||
let max_neg_dist = distances
|
||||
.iter()
|
||||
.map(|&d| -d / self.temperature)
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = distances
|
||||
.iter()
|
||||
.map(|&d| ((-d / self.temperature) - max_neg_dist).exp())
|
||||
.sum();
|
||||
|
||||
let scores: Vec<f32> = distances
|
||||
.iter()
|
||||
.map(|&d| ((-d / self.temperature) - max_neg_dist).exp() / exp_sum.max(1e-10))
|
||||
.collect();
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Parallel Branch Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Parallel branch attention for concurrent DAG branches
|
||||
///
|
||||
/// Identifies parallel branches in the DAG and applies
|
||||
/// attention patterns that respect branch independence.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmParallelBranchAttention {
|
||||
max_branches: usize,
|
||||
sync_penalty: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmParallelBranchAttention {
|
||||
/// Create a new parallel branch attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_branches` - Maximum number of branches to consider
|
||||
/// * `sync_penalty` - Penalty for synchronization between branches
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(max_branches: usize, sync_penalty: f32) -> WasmParallelBranchAttention {
|
||||
WasmParallelBranchAttention {
|
||||
max_branches,
|
||||
sync_penalty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
// Detect branch points (nodes with multiple children)
|
||||
let mut branch_starts: Vec<usize> = Vec::new();
|
||||
for node_id in 0..n {
|
||||
if dag.inner.children(node_id).len() > 1 {
|
||||
branch_starts.push(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
if dag.inner.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if node is part of a parallel branch
|
||||
let parents = dag.inner.parents(node_id);
|
||||
let is_branch_child = parents.iter().any(|&p| branch_starts.contains(&p));
|
||||
|
||||
let children = dag.inner.children(node_id);
|
||||
let is_sync_point = children.len() == 0 && parents.len() > 1;
|
||||
|
||||
let score = if is_branch_child {
|
||||
1.5 // Boost parallel branch nodes
|
||||
} else if is_sync_point {
|
||||
1.0 * (1.0 - self.sync_penalty) // Penalize sync points
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hashmap_to_vec(&scores, n))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Temporal BTSP Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Temporal BTSP (Behavioral Time-Series Pattern) attention
|
||||
///
|
||||
/// Incorporates temporal patterns and behavioral sequences
|
||||
/// for time-aware DAG attention.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmTemporalBTSPAttention {
|
||||
eligibility_decay: f32,
|
||||
baseline_attention: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmTemporalBTSPAttention {
|
||||
/// Create a new temporal BTSP attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `eligibility_decay` - Decay rate for eligibility traces (0.0-1.0)
|
||||
/// * `baseline_attention` - Baseline attention for nodes without history
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(eligibility_decay: f32, baseline_attention: f32) -> WasmTemporalBTSPAttention {
|
||||
WasmTemporalBTSPAttention {
|
||||
eligibility_decay,
|
||||
baseline_attention,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute attention scores for the DAG
|
||||
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
||||
let n = dag.inner.node_count();
|
||||
if n == 0 {
|
||||
return Err(JsError::new("Empty DAG"));
|
||||
}
|
||||
|
||||
let mut scores = Vec::with_capacity(n);
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for node_id in 0..n {
|
||||
let node = match dag.inner.get_node(node_id) {
|
||||
Some(n) => n,
|
||||
None => {
|
||||
scores.push(0.0);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Base score from cost and rows
|
||||
let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
|
||||
let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
|
||||
let score = self.baseline_attention * (0.5 * cost_factor + 0.5 * rows_factor + 0.5);
|
||||
|
||||
scores.push(score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
// Normalize
|
||||
if total > 0.0 {
|
||||
for score in scores.iter_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// DAG Attention Factory
|
||||
// ============================================================================
|
||||
|
||||
/// Factory for creating DAG attention mechanisms
|
||||
#[wasm_bindgen]
|
||||
pub struct DagAttentionFactory;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl DagAttentionFactory {
|
||||
/// Get available DAG attention types
|
||||
#[wasm_bindgen(js_name = availableTypes)]
|
||||
pub fn available_types() -> JsValue {
|
||||
let types = vec![
|
||||
"topological",
|
||||
"causal_cone",
|
||||
"critical_path",
|
||||
"mincut_gated",
|
||||
"hierarchical_lorentz",
|
||||
"parallel_branch",
|
||||
"temporal_btsp",
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&types).unwrap()
|
||||
}
|
||||
|
||||
/// Get description for a DAG attention type
|
||||
#[wasm_bindgen(js_name = getDescription)]
|
||||
pub fn get_description(attention_type: &str) -> String {
|
||||
match attention_type {
|
||||
"topological" => "Position-based attention following DAG topological order".to_string(),
|
||||
"causal_cone" => "Lightcone-based attention respecting causal dependencies".to_string(),
|
||||
"critical_path" => "Attention weighted by critical execution path distance".to_string(),
|
||||
"mincut_gated" => "Flow-based gating through bottleneck nodes".to_string(),
|
||||
"hierarchical_lorentz" => {
|
||||
"Multi-scale hyperbolic attention for DAG hierarchies".to_string()
|
||||
}
|
||||
"parallel_branch" => "Branch-aware attention for parallel DAG structures".to_string(),
|
||||
"temporal_btsp" => "Time-series pattern attention for temporal DAGs".to_string(),
|
||||
_ => "Unknown attention type".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Build the two-node scan -> filter DAG used by most attention tests.
    fn scan_filter_dag() -> WasmQueryDag {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        dag
    }

    #[wasm_bindgen_test]
    fn test_dag_creation() {
        let mut dag = WasmQueryDag::new();
        let scan = dag.add_node("scan", 1.0);
        let filter = dag.add_node("filter", 0.5);
        dag.add_edge(scan, filter);

        assert_eq!(dag.node_count(), 2);
        assert_eq!(dag.edge_count(), 1);
    }

    #[wasm_bindgen_test]
    fn test_topological_attention() {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_node("project", 0.3);
        dag.add_edge(0, 1);
        dag.add_edge(1, 2);

        let scores = WasmTopologicalAttention::new(0.9).forward(&dag);
        assert!(scores.is_ok());
        assert_eq!(scores.unwrap().len(), 3);
    }

    #[wasm_bindgen_test]
    fn test_causal_cone_attention() {
        let dag = scan_filter_dag();
        let scores = WasmCausalConeAttention::new(0.8, 0.9).forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_critical_path_attention() {
        let dag = scan_filter_dag();
        let scores = WasmCriticalPathAttention::new(2.0, 0.5).forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_mincut_gated_attention() {
        let dag = scan_filter_dag();
        let scores = WasmMinCutGatedAttention::new(0.5).forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_hierarchical_lorentz_attention() {
        let dag = scan_filter_dag();
        let scores = WasmHierarchicalLorentzAttention::new(-1.0, 0.1).forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_parallel_branch_attention() {
        let dag = scan_filter_dag();
        let scores = WasmParallelBranchAttention::new(8, 0.2).forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_temporal_btsp_attention() {
        let dag = scan_filter_dag();
        let scores = WasmTemporalBTSPAttention::new(0.95, 0.5).forward(&dag);
        assert!(scores.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_factory_types() {
        let types_js = DagAttentionFactory::available_types();
        assert!(!types_js.is_null());
    }
}
|
||||
417
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/graph.rs
vendored
Normal file
417
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/graph.rs
vendored
Normal file
@@ -0,0 +1,417 @@
|
||||
//! Graph Attention Mechanisms (from ruvector-gnn)
|
||||
//!
|
||||
//! Re-exports graph neural network attention mechanisms:
|
||||
//! - GAT (Graph Attention Networks)
|
||||
//! - GCN (Graph Convolutional Networks)
|
||||
//! - GraphSAGE (Sample and Aggregate)
|
||||
|
||||
use ruvector_gnn::{
|
||||
differentiable_search as core_differentiable_search,
|
||||
hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel,
|
||||
RuvectorLayer, TensorCompress,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// GNN Layer (GAT-based)
|
||||
// ============================================================================
|
||||
|
||||
/// Graph Neural Network layer with attention mechanism
|
||||
///
|
||||
/// Implements Graph Attention Networks (GAT) for HNSW topology.
|
||||
/// Each node aggregates information from neighbors using learned attention weights.
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmGNNLayer {
|
||||
inner: RuvectorLayer,
|
||||
hidden_dim: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmGNNLayer {
|
||||
/// Create a new GNN layer with attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_dim` - Dimension of input node embeddings
|
||||
/// * `hidden_dim` - Dimension of hidden representations
|
||||
/// * `heads` - Number of attention heads
|
||||
/// * `dropout` - Dropout rate (0.0 to 1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
heads: usize,
|
||||
dropout: f32,
|
||||
) -> Result<WasmGNNLayer, JsError> {
|
||||
let inner = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout)
|
||||
.map_err(|e| JsError::new(&e.to_string()))?;
|
||||
|
||||
Ok(WasmGNNLayer { inner, hidden_dim })
|
||||
}
|
||||
|
||||
/// Forward pass through the GNN layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `node_embedding` - Current node's embedding (Float32Array)
|
||||
/// * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays)
|
||||
/// * `edge_weights` - Weights of edges to neighbors (Float32Array)
|
||||
///
|
||||
/// # Returns
|
||||
/// Updated node embedding (Float32Array)
|
||||
pub fn forward(
|
||||
&self,
|
||||
node_embedding: Vec<f32>,
|
||||
neighbor_embeddings: JsValue,
|
||||
edge_weights: Vec<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let neighbors: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(neighbor_embeddings)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse neighbor embeddings: {}", e)))?;
|
||||
|
||||
if neighbors.len() != edge_weights.len() {
|
||||
return Err(JsError::new(&format!(
|
||||
"Number of neighbors ({}) must match number of edge weights ({})",
|
||||
neighbors.len(),
|
||||
edge_weights.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.forward(&node_embedding, &neighbors, &edge_weights);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get the output dimension
|
||||
#[wasm_bindgen(getter, js_name = outputDim)]
|
||||
pub fn output_dim(&self) -> usize {
|
||||
self.hidden_dim
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tensor Compression (for efficient GNN)
|
||||
// ============================================================================
|
||||
|
||||
/// Tensor compressor with adaptive level selection
|
||||
///
|
||||
/// Compresses embeddings based on access frequency for memory-efficient GNN
|
||||
#[wasm_bindgen]
|
||||
pub struct WasmTensorCompress {
|
||||
inner: TensorCompress,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmTensorCompress {
|
||||
/// Create a new tensor compressor
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: TensorCompress::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compress an embedding based on access frequency
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector
|
||||
/// * `access_freq` - Access frequency in range [0.0, 1.0]
|
||||
/// - f > 0.8: Full precision (hot data)
|
||||
/// - f > 0.4: Half precision (warm data)
|
||||
/// - f > 0.1: 8-bit PQ (cool data)
|
||||
/// - f > 0.01: 4-bit PQ (cold data)
|
||||
/// - f <= 0.01: Binary (archive)
|
||||
pub fn compress(&self, embedding: Vec<f32>, access_freq: f32) -> Result<JsValue, JsError> {
|
||||
let compressed = self
|
||||
.inner
|
||||
.compress(&embedding, access_freq)
|
||||
.map_err(|e| JsError::new(&format!("Compression failed: {}", e)))?;
|
||||
|
||||
serde_wasm_bindgen::to_value(&compressed)
|
||||
.map_err(|e| JsError::new(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Compress with explicit compression level
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - The input embedding vector
|
||||
/// * `level` - Compression level: "none", "half", "pq8", "pq4", "binary"
|
||||
#[wasm_bindgen(js_name = compressWithLevel)]
|
||||
pub fn compress_with_level(
|
||||
&self,
|
||||
embedding: Vec<f32>,
|
||||
level: &str,
|
||||
) -> Result<JsValue, JsError> {
|
||||
let compression_level = match level {
|
||||
"none" => CompressionLevel::None,
|
||||
"half" => CompressionLevel::Half { scale: 1.0 },
|
||||
"pq8" => CompressionLevel::PQ8 {
|
||||
subvectors: 8,
|
||||
centroids: 16,
|
||||
},
|
||||
"pq4" => CompressionLevel::PQ4 {
|
||||
subvectors: 8,
|
||||
outlier_threshold: 3.0,
|
||||
},
|
||||
"binary" => CompressionLevel::Binary { threshold: 0.0 },
|
||||
_ => {
|
||||
return Err(JsError::new(&format!(
|
||||
"Unknown compression level: {}",
|
||||
level
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let compressed = self
|
||||
.inner
|
||||
.compress_with_level(&embedding, &compression_level)
|
||||
.map_err(|e| JsError::new(&format!("Compression failed: {}", e)))?;
|
||||
|
||||
serde_wasm_bindgen::to_value(&compressed)
|
||||
.map_err(|e| JsError::new(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Decompress a compressed tensor
|
||||
pub fn decompress(&self, compressed: JsValue) -> Result<Vec<f32>, JsError> {
|
||||
let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed)
|
||||
.map_err(|e| JsError::new(&format!("Deserialization failed: {}", e)))?;
|
||||
|
||||
self.inner
|
||||
.decompress(&compressed_tensor)
|
||||
.map_err(|e| JsError::new(&format!("Decompression failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Get compression ratio estimate for a given access frequency
|
||||
#[wasm_bindgen(js_name = getCompressionRatio)]
|
||||
pub fn get_compression_ratio(&self, access_freq: f32) -> f32 {
|
||||
if access_freq > 0.8 {
|
||||
1.0
|
||||
} else if access_freq > 0.4 {
|
||||
2.0
|
||||
} else if access_freq > 0.1 {
|
||||
4.0
|
||||
} else if access_freq > 0.01 {
|
||||
8.0
|
||||
} else {
|
||||
32.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Search Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Search configuration for differentiable search
///
/// Plain data holder passed to `graphDifferentiableSearch`; both fields are
/// exposed directly to JavaScript.
#[wasm_bindgen]
pub struct WasmSearchConfig {
    /// Number of top results to return
    pub k: usize,
    /// Temperature for softmax
    // NOTE(review): semantics of extreme values (e.g. <= 0) are defined by
    // core_differentiable_search, which is not visible here — confirm there.
    pub temperature: f32,
}

#[wasm_bindgen]
impl WasmSearchConfig {
    /// Create a new search configuration
    ///
    /// # Arguments
    /// * `k` - number of top candidates to keep
    /// * `temperature` - softmax temperature used when weighting candidates
    #[wasm_bindgen(constructor)]
    pub fn new(k: usize, temperature: f32) -> Self {
        Self { k, temperature }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Differentiable Search
|
||||
// ============================================================================
|
||||
|
||||
/// Differentiable search using soft attention mechanism
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `candidate_embeddings` - List of candidate embedding vectors
|
||||
/// * `config` - Search configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// Object with indices and weights for top-k candidates
|
||||
#[wasm_bindgen(js_name = graphDifferentiableSearch)]
|
||||
pub fn differentiable_search(
|
||||
query: Vec<f32>,
|
||||
candidate_embeddings: JsValue,
|
||||
config: &WasmSearchConfig,
|
||||
) -> Result<JsValue, JsError> {
|
||||
let candidates: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(candidate_embeddings)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse candidate embeddings: {}", e)))?;
|
||||
|
||||
let (indices, weights) =
|
||||
core_differentiable_search(&query, &candidates, config.k, config.temperature);
|
||||
|
||||
let result = SearchResult { indices, weights };
|
||||
serde_wasm_bindgen::to_value(&result)
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize result: {}", e)))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SearchResult {
|
||||
indices: Vec<usize>,
|
||||
weights: Vec<f32>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hierarchical Forward
|
||||
// ============================================================================
|
||||
|
||||
/// Hierarchical forward pass through multiple GNN layers
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `layer_embeddings` - Embeddings organized by layer
|
||||
/// * `gnn_layers` - Array of GNN layers
|
||||
///
|
||||
/// # Returns
|
||||
/// Final embedding after hierarchical processing
|
||||
#[wasm_bindgen(js_name = graphHierarchicalForward)]
|
||||
pub fn hierarchical_forward(
|
||||
query: Vec<f32>,
|
||||
layer_embeddings: JsValue,
|
||||
gnn_layers: Vec<WasmGNNLayer>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let embeddings: Vec<Vec<Vec<f32>>> = serde_wasm_bindgen::from_value(layer_embeddings)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse layer embeddings: {}", e)))?;
|
||||
|
||||
let core_layers: Vec<RuvectorLayer> = gnn_layers.iter().map(|l| l.inner.clone()).collect();
|
||||
|
||||
let result = core_hierarchical_forward(&query, &embeddings, &core_layers);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Graph Attention Types
|
||||
// ============================================================================
|
||||
|
||||
/// Graph attention mechanism types
///
/// Exposed to JavaScript as an enum; used as a discriminant for the
/// informational helpers on [`GraphAttentionFactory`].
#[wasm_bindgen]
pub enum GraphAttentionType {
    /// Graph Attention Networks (Velickovic et al., 2018)
    GAT,
    /// Graph Convolutional Networks (Kipf & Welling, 2017)
    GCN,
    /// GraphSAGE (Hamilton et al., 2017)
    GraphSAGE,
}

/// Factory for graph attention information
///
/// Stateless namespace: all methods are associated functions that return
/// static metadata (names, descriptions, use cases).
#[wasm_bindgen]
pub struct GraphAttentionFactory;
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl GraphAttentionFactory {
|
||||
/// Get available graph attention types
|
||||
#[wasm_bindgen(js_name = availableTypes)]
|
||||
pub fn available_types() -> JsValue {
|
||||
let types = vec!["gat", "gcn", "graphsage"];
|
||||
serde_wasm_bindgen::to_value(&types).unwrap()
|
||||
}
|
||||
|
||||
/// Get description for a graph attention type
|
||||
#[wasm_bindgen(js_name = getDescription)]
|
||||
pub fn get_description(attention_type: &str) -> String {
|
||||
match attention_type {
|
||||
"gat" => {
|
||||
"Graph Attention Networks - learns attention weights over neighbors".to_string()
|
||||
}
|
||||
"gcn" => "Graph Convolutional Networks - spectral convolution on graphs".to_string(),
|
||||
"graphsage" => "GraphSAGE - sample and aggregate neighbor features".to_string(),
|
||||
_ => "Unknown graph attention type".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recommended use cases for a graph attention type
|
||||
#[wasm_bindgen(js_name = getUseCases)]
|
||||
pub fn get_use_cases(attention_type: &str) -> JsValue {
|
||||
let cases = match attention_type {
|
||||
"gat" => vec![
|
||||
"Node classification with varying neighbor importance",
|
||||
"Link prediction in heterogeneous graphs",
|
||||
"Knowledge graph reasoning",
|
||||
],
|
||||
"gcn" => vec![
|
||||
"Semi-supervised node classification",
|
||||
"Graph-level classification",
|
||||
"Spectral clustering",
|
||||
],
|
||||
"graphsage" => vec![
|
||||
"Inductive learning on new nodes",
|
||||
"Large-scale graph processing",
|
||||
"Dynamic graphs with new vertices",
|
||||
],
|
||||
_ => vec!["Unknown type"],
|
||||
};
|
||||
serde_wasm_bindgen::to_value(&cases).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
// Browser-targeted unit tests for the graph/compression WASM bindings.
// Run via `wasm-pack test` (wasm_bindgen_test harness), not plain `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Happy path: valid dims/heads/dropout produce a layer with the
    // requested output dimension.
    #[wasm_bindgen_test]
    fn test_gnn_layer_creation() {
        let layer = WasmGNNLayer::new(4, 8, 2, 0.1);
        assert!(layer.is_ok());
        let l = layer.unwrap();
        assert_eq!(l.output_dim(), 8);
    }

    // Dropout must lie in [0, 1]; 1.5 is rejected.
    #[wasm_bindgen_test]
    fn test_gnn_layer_invalid_dropout() {
        let layer = WasmGNNLayer::new(4, 8, 2, 1.5);
        assert!(layer.is_err());
    }

    // Output dim 7 is not divisible by 3 heads, so construction fails.
    #[wasm_bindgen_test]
    fn test_gnn_layer_invalid_heads() {
        let layer = WasmGNNLayer::new(4, 7, 3, 0.1);
        assert!(layer.is_err());
    }

    // Compression-ratio tiers must match the documented access-frequency
    // thresholds (hot=1x ... archive=32x).
    #[wasm_bindgen_test]
    fn test_tensor_compress_creation() {
        let compressor = WasmTensorCompress::new();
        assert_eq!(compressor.get_compression_ratio(1.0), 1.0);
        assert_eq!(compressor.get_compression_ratio(0.5), 2.0);
        assert_eq!(compressor.get_compression_ratio(0.2), 4.0);
        assert_eq!(compressor.get_compression_ratio(0.05), 8.0);
        assert_eq!(compressor.get_compression_ratio(0.005), 32.0);
    }

    // Constructor stores both public fields verbatim.
    #[wasm_bindgen_test]
    fn test_search_config() {
        let config = WasmSearchConfig::new(5, 1.0);
        assert_eq!(config.k, 5);
        assert_eq!(config.temperature, 1.0);
    }

    // The factory exposes a non-null JS array of type names.
    #[wasm_bindgen_test]
    fn test_factory_types() {
        let types_js = GraphAttentionFactory::available_types();
        assert!(!types_js.is_null());
    }

    // Each known selector yields its matching description string.
    #[wasm_bindgen_test]
    fn test_factory_descriptions() {
        let desc = GraphAttentionFactory::get_description("gat");
        assert!(desc.contains("Graph Attention"));

        let desc = GraphAttentionFactory::get_description("gcn");
        assert!(desc.contains("Graph Convolutional"));

        let desc = GraphAttentionFactory::get_description("graphsage");
        assert!(desc.contains("GraphSAGE"));
    }
}
|
||||
382
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/lib.rs
vendored
Normal file
382
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,382 @@
|
||||
//! Unified WebAssembly Attention Library
|
||||
//!
|
||||
//! This crate provides a unified WASM interface for 18+ attention mechanisms:
|
||||
//!
|
||||
//! ## Neural Attention (from ruvector-attention)
|
||||
//! - **Scaled Dot-Product**: Standard transformer attention
|
||||
//! - **Multi-Head**: Parallel attention heads
|
||||
//! - **Hyperbolic**: Attention in hyperbolic space for hierarchical data
|
||||
//! - **Linear**: O(n) Performer-style attention
|
||||
//! - **Flash**: Memory-efficient blocked attention
|
||||
//! - **Local-Global**: Sparse attention with global tokens
|
||||
//! - **MoE**: Mixture of Experts attention
|
||||
//!
|
||||
//! ## DAG Attention (from ruvector-dag)
|
||||
//! - **Topological**: Position-aware attention in DAG order
|
||||
//! - **Causal Cone**: Lightcone-based causal attention
|
||||
//! - **Critical Path**: Attention weighted by critical path distance
|
||||
//! - **MinCut-Gated**: Flow-based gating attention
|
||||
//! - **Hierarchical Lorentz**: Multi-scale hyperbolic DAG attention
|
||||
//! - **Parallel Branch**: Attention for parallel DAG branches
|
||||
//! - **Temporal BTSP**: Behavioral Time-Series Pattern attention
|
||||
//!
|
||||
//! ## Graph Attention (from ruvector-gnn)
|
||||
//! - **GAT**: Graph Attention Networks
|
||||
//! - **GCN**: Graph Convolutional Networks
|
||||
//! - **GraphSAGE**: Sampling and Aggregating graph embeddings
|
||||
//!
|
||||
//! ## State Space Models
|
||||
//! - **Mamba SSM**: Selective State Space Model attention
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// Use wee_alloc for smaller WASM binary (~10KB reduction).
// Only compiled in when the crate is built with the `wee_alloc` feature;
// otherwise the default Rust allocator is used.
#[cfg(feature = "wee_alloc")]
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
|
||||
|
||||
// ============================================================================
|
||||
// Module declarations
|
||||
// ============================================================================
|
||||
|
||||
pub mod mamba;
|
||||
|
||||
mod dag;
|
||||
mod graph;
|
||||
mod neural;
|
||||
|
||||
// ============================================================================
|
||||
// Re-exports for convenient access
|
||||
// ============================================================================
|
||||
|
||||
pub use dag::*;
|
||||
pub use graph::*;
|
||||
pub use mamba::*;
|
||||
pub use neural::*;
|
||||
|
||||
// ============================================================================
|
||||
// Initialization
|
||||
// ============================================================================
|
||||
|
||||
/// Initialize the WASM module with panic hook for better error messages
///
/// Runs automatically when the module is instantiated (`wasm_bindgen(start)`).
/// With the `console_error_panic_hook` feature enabled, Rust panics are
/// forwarded to the browser console instead of surfacing as an opaque trap.
#[wasm_bindgen(start)]
pub fn init() {
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
|
||||
|
||||
// ============================================================================
|
||||
// Version and Info
|
||||
// ============================================================================
|
||||
|
||||
/// Get the version of the unified attention WASM crate
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get information about all available attention mechanisms
|
||||
#[wasm_bindgen(js_name = availableMechanisms)]
|
||||
pub fn available_mechanisms() -> JsValue {
|
||||
let mechanisms = AttentionMechanisms {
|
||||
neural: vec![
|
||||
"scaled_dot_product".into(),
|
||||
"multi_head".into(),
|
||||
"hyperbolic".into(),
|
||||
"linear".into(),
|
||||
"flash".into(),
|
||||
"local_global".into(),
|
||||
"moe".into(),
|
||||
],
|
||||
dag: vec![
|
||||
"topological".into(),
|
||||
"causal_cone".into(),
|
||||
"critical_path".into(),
|
||||
"mincut_gated".into(),
|
||||
"hierarchical_lorentz".into(),
|
||||
"parallel_branch".into(),
|
||||
"temporal_btsp".into(),
|
||||
],
|
||||
graph: vec!["gat".into(), "gcn".into(), "graphsage".into()],
|
||||
ssm: vec!["mamba".into()],
|
||||
};
|
||||
serde_wasm_bindgen::to_value(&mechanisms).unwrap()
|
||||
}
|
||||
|
||||
/// Get summary statistics about the unified attention library
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats() -> JsValue {
|
||||
let stats = UnifiedStats {
|
||||
total_mechanisms: 18,
|
||||
neural_count: 7,
|
||||
dag_count: 7,
|
||||
graph_count: 3,
|
||||
ssm_count: 1,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
serde_wasm_bindgen::to_value(&stats).unwrap()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Internal Types
|
||||
// ============================================================================
|
||||
|
||||
// Serialized payload of `available_mechanisms()`: one list of mechanism
// names per category.
#[derive(serde::Serialize)]
struct AttentionMechanisms {
    neural: Vec<String>,
    dag: Vec<String>,
    graph: Vec<String>,
    ssm: Vec<String>,
}

// Serialized payload of `get_stats()`: per-category counts plus the crate
// version string.
#[derive(serde::Serialize)]
struct UnifiedStats {
    total_mechanisms: usize,
    neural_count: usize,
    dag_count: usize,
    graph_count: usize,
    ssm_count: usize,
    version: String,
}
|
||||
|
||||
// ============================================================================
|
||||
// Unified Attention Selector
|
||||
// ============================================================================
|
||||
|
||||
/// Unified attention mechanism selector
|
||||
/// Automatically routes to the appropriate attention implementation
|
||||
#[wasm_bindgen]
|
||||
pub struct UnifiedAttention {
|
||||
mechanism_type: String,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl UnifiedAttention {
|
||||
/// Create a new unified attention selector
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(mechanism: &str) -> Result<UnifiedAttention, JsError> {
|
||||
let valid_mechanisms = [
|
||||
// Neural
|
||||
"scaled_dot_product",
|
||||
"multi_head",
|
||||
"hyperbolic",
|
||||
"linear",
|
||||
"flash",
|
||||
"local_global",
|
||||
"moe",
|
||||
// DAG
|
||||
"topological",
|
||||
"causal_cone",
|
||||
"critical_path",
|
||||
"mincut_gated",
|
||||
"hierarchical_lorentz",
|
||||
"parallel_branch",
|
||||
"temporal_btsp",
|
||||
// Graph
|
||||
"gat",
|
||||
"gcn",
|
||||
"graphsage",
|
||||
// SSM
|
||||
"mamba",
|
||||
];
|
||||
|
||||
if !valid_mechanisms.contains(&mechanism) {
|
||||
return Err(JsError::new(&format!(
|
||||
"Unknown mechanism: {}. Valid options: {:?}",
|
||||
mechanism, valid_mechanisms
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
mechanism_type: mechanism.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the currently selected mechanism type
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn mechanism(&self) -> String {
|
||||
self.mechanism_type.clone()
|
||||
}
|
||||
|
||||
/// Get the category of the selected mechanism
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn category(&self) -> String {
|
||||
match self.mechanism_type.as_str() {
|
||||
"scaled_dot_product" | "multi_head" | "hyperbolic" | "linear" | "flash"
|
||||
| "local_global" | "moe" => "neural".to_string(),
|
||||
|
||||
"topological"
|
||||
| "causal_cone"
|
||||
| "critical_path"
|
||||
| "mincut_gated"
|
||||
| "hierarchical_lorentz"
|
||||
| "parallel_branch"
|
||||
| "temporal_btsp" => "dag".to_string(),
|
||||
|
||||
"gat" | "gcn" | "graphsage" => "graph".to_string(),
|
||||
|
||||
"mamba" => "ssm".to_string(),
|
||||
|
||||
_ => "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this mechanism supports sequence processing
|
||||
#[wasm_bindgen(js_name = supportsSequences)]
|
||||
pub fn supports_sequences(&self) -> bool {
|
||||
matches!(
|
||||
self.mechanism_type.as_str(),
|
||||
"scaled_dot_product" | "multi_head" | "linear" | "flash" | "local_global" | "mamba"
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this mechanism supports graph/DAG structures
|
||||
#[wasm_bindgen(js_name = supportsGraphs)]
|
||||
pub fn supports_graphs(&self) -> bool {
|
||||
matches!(
|
||||
self.mechanism_type.as_str(),
|
||||
"topological"
|
||||
| "causal_cone"
|
||||
| "critical_path"
|
||||
| "mincut_gated"
|
||||
| "hierarchical_lorentz"
|
||||
| "parallel_branch"
|
||||
| "temporal_btsp"
|
||||
| "gat"
|
||||
| "gcn"
|
||||
| "graphsage"
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this mechanism supports hyperbolic geometry
|
||||
#[wasm_bindgen(js_name = supportsHyperbolic)]
|
||||
pub fn supports_hyperbolic(&self) -> bool {
|
||||
matches!(
|
||||
self.mechanism_type.as_str(),
|
||||
"hyperbolic" | "hierarchical_lorentz"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Utility Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
#[wasm_bindgen(js_name = cosineSimilarity)]
|
||||
pub fn cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> Result<f32, JsError> {
|
||||
if a.len() != b.len() {
|
||||
return Err(JsError::new(&format!(
|
||||
"Vector dimensions must match: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
Ok(dot / (norm_a * norm_b))
|
||||
}
|
||||
}
|
||||
|
||||
/// Softmax normalization
|
||||
#[wasm_bindgen]
|
||||
pub fn softmax(values: Vec<f32>) -> Vec<f32> {
|
||||
let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
let exp_values: Vec<f32> = values.iter().map(|&x| (x - max_val).exp()).collect();
|
||||
let sum: f32 = exp_values.iter().sum();
|
||||
exp_values.iter().map(|&x| x / sum).collect()
|
||||
}
|
||||
|
||||
/// Temperature-scaled softmax
|
||||
#[wasm_bindgen(js_name = temperatureSoftmax)]
|
||||
pub fn temperature_softmax(values: Vec<f32>, temperature: f32) -> Vec<f32> {
|
||||
if temperature <= 0.0 {
|
||||
// Return one-hot for the maximum
|
||||
let max_idx = values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(0);
|
||||
let mut result = vec![0.0; values.len()];
|
||||
result[max_idx] = 1.0;
|
||||
return result;
|
||||
}
|
||||
|
||||
let scaled: Vec<f32> = values.iter().map(|&x| x / temperature).collect();
|
||||
softmax(scaled)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
// Browser-targeted unit tests for the unified attention facade.
// Run via `wasm-pack test` (wasm_bindgen_test harness), not plain `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // The crate version baked in from Cargo.toml is never empty.
    #[wasm_bindgen_test]
    fn test_version() {
        assert!(!version().is_empty());
    }

    // Known mechanism names construct; unknown ones are rejected.
    #[wasm_bindgen_test]
    fn test_unified_attention_creation() {
        let attention = UnifiedAttention::new("multi_head");
        assert!(attention.is_ok());

        let invalid = UnifiedAttention::new("invalid_mechanism");
        assert!(invalid.is_err());
    }

    // One representative mechanism per category maps to that category.
    #[wasm_bindgen_test]
    fn test_mechanism_categories() {
        let neural = UnifiedAttention::new("multi_head").unwrap();
        assert_eq!(neural.category(), "neural");

        let dag = UnifiedAttention::new("topological").unwrap();
        assert_eq!(dag.category(), "dag");

        let graph = UnifiedAttention::new("gat").unwrap();
        assert_eq!(graph.category(), "graph");

        let ssm = UnifiedAttention::new("mamba").unwrap();
        assert_eq!(ssm.category(), "ssm");
    }

    // Softmax output is a probability distribution that preserves ordering.
    #[wasm_bindgen_test]
    fn test_softmax() {
        let input = vec![1.0, 2.0, 3.0];
        let output = softmax(input);

        // Sum should be 1.0
        let sum: f32 = output.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Should be monotonically increasing
        assert!(output[0] < output[1]);
        assert!(output[1] < output[2]);
    }

    // Identical vectors score ~1, orthogonal vectors ~0.
    #[wasm_bindgen_test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(a, b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        let sim2 = cosine_similarity(c, d).unwrap();
        assert!(sim2.abs() < 1e-6);
    }
}
|
||||
554
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/mamba.rs
vendored
Normal file
554
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/mamba.rs
vendored
Normal file
@@ -0,0 +1,554 @@
|
||||
//! Mamba SSM (Selective State Space Model) Attention Mechanism
|
||||
//!
|
||||
//! Implements the Mamba architecture's selective scan mechanism for efficient
|
||||
//! sequence modeling with linear time complexity O(n).
|
||||
//!
|
||||
//! Key Features:
|
||||
//! - **Selective Scan**: Input-dependent state transitions
|
||||
//! - **Linear Complexity**: O(n) vs O(n^2) for standard attention
|
||||
//! - **Hardware Efficient**: Optimized for parallel scan operations
|
||||
//! - **Long Context**: Handles very long sequences efficiently
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! Mamba uses a selective state space model:
|
||||
//! ```text
|
||||
//! h_t = A_t * h_{t-1} + B_t * x_t
|
||||
//! y_t = C_t * h_t
|
||||
//! ```
|
||||
//!
|
||||
//! Where A_t, B_t, C_t are input-dependent (selective), computed from x_t.
|
||||
//!
|
||||
//! ## References
|
||||
//!
|
||||
//! - Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, 2023)
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for Mamba SSM attention
///
/// Construct with [`MambaConfig::new`] (sensible defaults) and refine via
/// the consuming `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct MambaConfig {
    /// Model dimension (d_model)
    pub dim: usize,
    /// State space dimension (n)
    pub state_dim: usize,
    /// Expansion factor for inner dimension
    // inner_dim = dim * expand_factor (see MambaSSMAttention::new).
    pub expand_factor: usize,
    /// Convolution kernel size
    pub conv_kernel_size: usize,
    /// Delta (discretization step) range minimum
    pub dt_min: f32,
    /// Delta range maximum
    pub dt_max: f32,
    /// Whether to use learnable D skip connection
    pub use_d_skip: bool,
}

#[wasm_bindgen]
impl MambaConfig {
    /// Create a new Mamba configuration
    ///
    /// Defaults: state_dim=16, expand_factor=2, conv_kernel_size=4,
    /// dt range [0.001, 0.1], D skip enabled.
    #[wasm_bindgen(constructor)]
    pub fn new(dim: usize) -> MambaConfig {
        MambaConfig {
            dim,
            state_dim: 16,
            expand_factor: 2,
            conv_kernel_size: 4,
            dt_min: 0.001,
            dt_max: 0.1,
            use_d_skip: true,
        }
    }

    /// Set state space dimension
    ///
    /// Consuming builder: returns the modified configuration.
    /// NOTE(review): from JS, wasm-bindgen moves ownership here, so the
    /// original handle is invalidated after the call — confirm callers chain.
    #[wasm_bindgen(js_name = withStateDim)]
    pub fn with_state_dim(mut self, state_dim: usize) -> MambaConfig {
        self.state_dim = state_dim;
        self
    }

    /// Set expansion factor
    ///
    /// Consuming builder: returns the modified configuration.
    #[wasm_bindgen(js_name = withExpandFactor)]
    pub fn with_expand_factor(mut self, factor: usize) -> MambaConfig {
        self.expand_factor = factor;
        self
    }

    /// Set convolution kernel size
    ///
    /// Consuming builder: returns the modified configuration.
    #[wasm_bindgen(js_name = withConvKernelSize)]
    pub fn with_conv_kernel_size(mut self, size: usize) -> MambaConfig {
        self.conv_kernel_size = size;
        self
    }
}

// Default model dimension of 256 with the standard defaults from `new`.
impl Default for MambaConfig {
    fn default() -> Self {
        MambaConfig::new(256)
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// State Space Parameters
|
||||
// ============================================================================
|
||||
|
||||
/// Selective state space parameters (input-dependent)
///
/// Intermediate tensors produced per forward pass; all stored as nested
/// `Vec`s.
// NOTE(review): the original comments labeled these as (batch, seq_len,
// state_dim), but the visible allocations in `compute_selective_params`
// shape a_bar/b_bar/c as (seq_len, inner_dim, state_dim) and delta as
// (seq_len, 1, inner_dim) — confirm and reconcile.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SelectiveSSMParams {
    /// Discretized A matrix diagonal
    a_bar: Vec<Vec<Vec<f32>>>,
    /// Discretized B matrix
    b_bar: Vec<Vec<Vec<f32>>>,
    /// Output projection C
    c: Vec<Vec<Vec<f32>>>,
    /// Discretization step delta
    delta: Vec<Vec<Vec<f32>>>,
}
|
||||
|
||||
// ============================================================================
|
||||
// Mamba SSM Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Mamba Selective State Space Model for sequence attention
///
/// Provides O(n) attention-like mechanism using selective state spaces
#[wasm_bindgen]
pub struct MambaSSMAttention {
    config: MambaConfig,
    /// Inner dimension after expansion
    // inner_dim = config.dim * config.expand_factor (set in `new`).
    inner_dim: usize,
    /// A parameter (state_dim,) - diagonal of continuous A
    // Initialized to -ln(i+1), i.e. strictly negative for stability.
    a_log: Vec<f32>,
    /// D skip connection (inner_dim,)
    // Initialized to all ones; added elementwise after the scan.
    d_skip: Vec<f32>,
    /// Projection weights (simplified for WASM)
    // Both are identity-like stub matrices, not learned weights.
    in_proj: Vec<Vec<f32>>,
    out_proj: Vec<Vec<f32>>,
}
|
||||
|
||||
#[wasm_bindgen]
impl MambaSSMAttention {
    /// Create a new Mamba SSM attention layer
    ///
    /// Builds identity-like stub projections and a fixed decay spectrum;
    /// nothing here is trained.
    #[wasm_bindgen(constructor)]
    pub fn new(config: MambaConfig) -> MambaSSMAttention {
        let inner_dim = config.dim * config.expand_factor;

        // Initialize A as negative values (for stability) - log of eigenvalues:
        // a_log[i] = -ln(i + 1), so the i-th state channel decays faster.
        let a_log: Vec<f32> = (0..config.state_dim)
            .map(|i| -((i + 1) as f32).ln())
            .collect();

        // D skip connection (all ones: pass-through residual).
        let d_skip = vec![1.0; inner_dim];

        // Simplified projection matrices (identity-like for stub):
        // in_proj maps dim -> inner_dim, copying the first `dim` channels.
        let in_proj: Vec<Vec<f32>> = (0..inner_dim)
            .map(|i| {
                let mut row = vec![0.0; config.dim];
                if i < config.dim {
                    row[i] = 1.0;
                }
                row
            })
            .collect();

        // out_proj maps inner_dim -> dim.
        // NOTE(review): `i < inner_dim` is always true here whenever
        // expand_factor >= 1 (i ranges over 0..dim) — likely redundant.
        let out_proj: Vec<Vec<f32>> = (0..config.dim)
            .map(|i| {
                let mut row = vec![0.0; inner_dim];
                if i < inner_dim {
                    row[i] = 1.0;
                }
                row
            })
            .collect();

        MambaSSMAttention {
            config,
            inner_dim,
            a_log,
            d_skip,
            in_proj,
            out_proj,
        }
    }

    /// Create with default configuration
    #[wasm_bindgen(js_name = withDefaults)]
    pub fn with_defaults(dim: usize) -> MambaSSMAttention {
        MambaSSMAttention::new(MambaConfig::new(dim))
    }

    /// Forward pass through Mamba SSM
    ///
    /// # Arguments
    /// * `input` - Input sequence (seq_len, dim) flattened to 1D
    /// * `seq_len` - Sequence length
    ///
    /// # Returns
    /// Output sequence (seq_len, dim) flattened to 1D
    ///
    /// # Errors
    /// Returns a `JsError` when `input.len() != seq_len * dim`.
    #[wasm_bindgen]
    pub fn forward(&self, input: Vec<f32>, seq_len: usize) -> Result<Vec<f32>, JsError> {
        let dim = self.config.dim;

        if input.len() != seq_len * dim {
            return Err(JsError::new(&format!(
                "Input size mismatch: expected {} ({}x{}), got {}",
                seq_len * dim,
                seq_len,
                dim,
                input.len()
            )));
        }

        // Reshape the flat input into (seq_len, dim) rows.
        let input_2d: Vec<Vec<f32>> = (0..seq_len)
            .map(|t| input[t * dim..(t + 1) * dim].to_vec())
            .collect();

        // Step 1: Input projection to inner_dim
        let projected = self.project_in(&input_2d);

        // Step 2: Compute selective SSM parameters from input
        let ssm_params = self.compute_selective_params(&projected);

        // Step 3: Run selective scan
        let ssm_output = self.selective_scan(&projected, &ssm_params);

        // Step 4: Apply D skip connection (y + D ⊙ x, elementwise).
        // NOTE(review): applied unconditionally — `config.use_d_skip` is not
        // consulted here; confirm whether that flag should gate this step.
        let with_skip: Vec<Vec<f32>> = ssm_output
            .iter()
            .zip(projected.iter())
            .map(|(y, x)| {
                y.iter()
                    .zip(x.iter())
                    .zip(self.d_skip.iter())
                    .map(|((yi, xi), di)| yi + di * xi)
                    .collect()
            })
            .collect();

        // Step 5: Output projection back to (seq_len, dim)
        let output = self.project_out(&with_skip);

        // Flatten output to the 1D layout callers expect.
        Ok(output.into_iter().flatten().collect())
    }

    /// Get the configuration
    #[wasm_bindgen(getter)]
    pub fn config(&self) -> MambaConfig {
        self.config.clone()
    }

    /// Get the inner dimension
    #[wasm_bindgen(getter, js_name = innerDim)]
    pub fn inner_dim(&self) -> usize {
        self.inner_dim
    }

    /// Compute attention-like scores (for visualization/analysis)
    ///
    /// Returns pseudo-attention scores showing which positions influence output.
    /// Output is a flattened (seq_len, seq_len) lower-triangular matrix where
    /// entry (t, s) is the mean exponential decay exp(a * (t - s)) over the
    /// `a_log` spectrum.
    ///
    /// Note: the scores depend only on positions and the fixed `a_log`
    /// parameters — the `input` values themselves are used only for the
    /// size check.
    #[wasm_bindgen(js_name = getAttentionScores)]
    pub fn get_attention_scores(
        &self,
        input: Vec<f32>,
        seq_len: usize,
    ) -> Result<Vec<f32>, JsError> {
        let dim = self.config.dim;

        if input.len() != seq_len * dim {
            return Err(JsError::new(&format!(
                "Input size mismatch: expected {}, got {}",
                seq_len * dim,
                input.len()
            )));
        }

        // Compute approximate attention scores based on state decay.
        // This shows how much each position can "attend to" previous positions.
        let mut scores = vec![0.0f32; seq_len * seq_len];

        for t in 0..seq_len {
            for s in 0..=t {
                // Exponential decay based on distance and A parameters
                // (a_log entries are negative, so decay falls off with distance).
                let distance = (t - s) as f32;
                let decay: f32 = self
                    .a_log
                    .iter()
                    .map(|&a| (a * distance).exp())
                    .sum::<f32>()
                    / self.config.state_dim as f32;

                scores[t * seq_len + s] = decay;
            }
        }

        Ok(scores)
    }
}
|
||||
|
||||
// Internal implementation methods
impl MambaSSMAttention {
    /// Project input from dim to inner_dim
    ///
    /// Plain per-token matrix-vector product: out = in_proj · x.
    fn project_in(&self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|x| {
                self.in_proj
                    .iter()
                    .map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum())
                    .collect()
            })
            .collect()
    }

    /// Project from inner_dim back to dim
    ///
    /// Same per-token matrix-vector product shape as `project_in`,
    /// using `out_proj` as the weight matrix.
    fn project_out(&self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|x| {
                self.out_proj
                    .iter()
                    .map(|row| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum())
                    .collect()
            })
            .collect()
    }

    /// Compute selective SSM parameters from input
    ///
    /// Builds the discretized, input-dependent SSM tensors used by
    /// `selective_scan`. Shapes: `a_bar`/`b_bar`/`c` are
    /// `[seq_len][inner_dim][state_dim]`; `delta` is `[seq_len][1][inner_dim]`.
    fn compute_selective_params(&self, input: &[Vec<f32>]) -> SelectiveSSMParams {
        let seq_len = input.len();
        let state_dim = self.config.state_dim;

        // Compute input-dependent delta, B, C
        // Simplified: use sigmoid/tanh of input projections

        let mut a_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
        let mut b_bar = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
        let mut c = vec![vec![vec![0.0; state_dim]; self.inner_dim]; seq_len];
        let mut delta = vec![vec![vec![0.0; self.inner_dim]; 1]; seq_len];

        for (t, x) in input.iter().enumerate() {
            // Compute delta from input (softplus of projection),
            // clamped to the configured [dt_min, dt_max] range.
            let dt: Vec<f32> = x
                .iter()
                .map(|&xi| {
                    let raw = xi * 0.1; // Simple scaling
                    let dt_val = (1.0 + raw.exp()).ln(); // Softplus
                    dt_val.clamp(self.config.dt_min, self.config.dt_max)
                })
                .collect();
            delta[t][0] = dt.clone();

            for d in 0..self.inner_dim.min(x.len()) {
                // Reuse the last dt entry when x has fewer entries than inner_dim.
                let dt_d = dt[d.min(dt.len() - 1)];

                for n in 0..state_dim {
                    // Discretize A: A_bar = exp(delta * A)
                    // NOTE(review): a_log is exponentiated here, which yields a
                    // positive value despite the "Negative" comment — confirm
                    // the intended sign convention (Mamba typically uses
                    // A = -exp(a_log)).
                    let a_continuous = self.a_log[n].exp(); // Negative
                    a_bar[t][d][n] = (dt_d * a_continuous).exp();

                    // Discretize B: B_bar = delta * B (simplified)
                    // B is input-dependent
                    let b_input = if d < x.len() { x[d] } else { 0.0 };
                    b_bar[t][d][n] = dt_d * Self::sigmoid(b_input * 0.1);

                    // C is input-dependent
                    c[t][d][n] = Self::tanh(b_input * 0.1);
                }
            }
        }

        SelectiveSSMParams {
            a_bar,
            b_bar,
            c,
            delta,
        }
    }

    /// Run selective scan (parallel associative scan in practice)
    ///
    /// Sequential recurrence over time, per channel d:
    ///   h_t = A_bar(t) * h_{t-1} + B_bar(t) * x_t   (elementwise over state)
    ///   y_t = <C(t), h_t>                           (dot product over state)
    fn selective_scan(&self, input: &[Vec<f32>], params: &SelectiveSSMParams) -> Vec<Vec<f32>> {
        let seq_len = input.len();
        let state_dim = self.config.state_dim;

        // Initialize hidden state: one zeroed state vector per inner channel.
        let mut hidden = vec![vec![0.0f32; state_dim]; self.inner_dim];
        let mut output = vec![vec![0.0f32; self.inner_dim]; seq_len];

        for t in 0..seq_len {
            for d in 0..self.inner_dim {
                // Channels beyond the input row's length read as 0.
                let x_d = if d < input[t].len() { input[t][d] } else { 0.0 };

                // Update hidden state: h_t = A_bar * h_{t-1} + B_bar * x_t
                for n in 0..state_dim {
                    hidden[d][n] =
                        params.a_bar[t][d][n] * hidden[d][n] + params.b_bar[t][d][n] * x_d;
                }

                // Compute output: y_t = C * h_t
                output[t][d] = hidden[d]
                    .iter()
                    .zip(params.c[t][d].iter())
                    .map(|(h, c)| h * c)
                    .sum();
            }
        }

        output
    }

    /// Logistic sigmoid: 1 / (1 + e^{-x}).
    #[inline]
    fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Hyperbolic tangent (delegates to the std implementation).
    #[inline]
    fn tanh(x: f32) -> f32 {
        x.tanh()
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Hybrid Mamba-Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hybrid layer combining Mamba SSM with standard attention
///
/// Uses Mamba for long-range dependencies and attention for local patterns
#[wasm_bindgen]
pub struct HybridMambaAttention {
    // Global (long-range) sequence mixer.
    mamba: MambaSSMAttention,
    // Width of the local window blended into each position.
    local_window: usize,
    // When true, forward() mixes a local average into the Mamba output.
    use_attention_for_local: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl HybridMambaAttention {
|
||||
/// Create a new hybrid Mamba-Attention layer
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(config: MambaConfig, local_window: usize) -> HybridMambaAttention {
|
||||
HybridMambaAttention {
|
||||
mamba: MambaSSMAttention::new(config),
|
||||
local_window,
|
||||
use_attention_for_local: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: Vec<f32>, seq_len: usize) -> Result<Vec<f32>, JsError> {
|
||||
let dim = self.mamba.config.dim;
|
||||
|
||||
// Run Mamba for global context
|
||||
let mamba_output = self.mamba.forward(input.clone(), seq_len)?;
|
||||
|
||||
// Apply local attention mixing (simplified)
|
||||
let mut output = mamba_output.clone();
|
||||
|
||||
if self.use_attention_for_local {
|
||||
for t in 0..seq_len {
|
||||
let start = t.saturating_sub(self.local_window / 2);
|
||||
let end = (t + self.local_window / 2 + 1).min(seq_len);
|
||||
|
||||
// Simple local averaging
|
||||
for d in 0..dim {
|
||||
let mut local_sum = 0.0;
|
||||
let mut count = 0;
|
||||
for s in start..end {
|
||||
local_sum += input[s * dim + d];
|
||||
count += 1;
|
||||
}
|
||||
// Mix global (Mamba) and local
|
||||
let local_avg = local_sum / count as f32;
|
||||
output[t * dim + d] = 0.7 * output[t * dim + d] + 0.3 * local_avg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Get local window size
|
||||
#[wasm_bindgen(getter, js_name = localWindow)]
|
||||
pub fn local_window(&self) -> usize {
|
||||
self.local_window
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A fresh config uses the given dim with default state_dim/expand_factor.
    #[wasm_bindgen_test]
    fn test_mamba_config() {
        let config = MambaConfig::new(256);
        assert_eq!(config.dim, 256);
        assert_eq!(config.state_dim, 16);
        assert_eq!(config.expand_factor, 2);
    }

    /// inner_dim is the expanded width: dim * expand_factor.
    #[wasm_bindgen_test]
    fn test_mamba_creation() {
        let config = MambaConfig::new(64);
        let mamba = MambaSSMAttention::new(config);
        assert_eq!(mamba.inner_dim(), 128); // 64 * 2
    }

    /// Forward preserves the flattened [seq_len * dim] shape.
    #[wasm_bindgen_test]
    fn test_mamba_forward() {
        let config = MambaConfig::new(8);
        let mamba = MambaSSMAttention::new(config);

        // Input: 4 tokens of dimension 8
        let input = vec![0.1f32; 32];
        let output = mamba.forward(input, 4);

        assert!(output.is_ok());
        let out = output.unwrap();
        assert_eq!(out.len(), 32); // Same shape as input
    }

    /// Scores form a causal (lower-triangular) seq_len x seq_len matrix.
    #[wasm_bindgen_test]
    fn test_attention_scores() {
        let config = MambaConfig::new(8);
        let mamba = MambaSSMAttention::new(config);

        let input = vec![0.1f32; 24]; // 3 tokens
        let scores = mamba.get_attention_scores(input, 3);

        assert!(scores.is_ok());
        let s = scores.unwrap();
        assert_eq!(s.len(), 9); // 3x3 attention matrix

        // Causal: upper triangle should be 0
        assert_eq!(s[0 * 3 + 1], 0.0); // t=0 cannot attend to t=1
        assert_eq!(s[0 * 3 + 2], 0.0); // t=0 cannot attend to t=2
    }

    /// The hybrid layer also preserves the input shape.
    #[wasm_bindgen_test]
    fn test_hybrid_mamba() {
        let config = MambaConfig::new(8);
        let hybrid = HybridMambaAttention::new(config, 4);

        let input = vec![0.5f32; 40]; // 5 tokens
        let output = hybrid.forward(input, 5);

        assert!(output.is_ok());
        assert_eq!(output.unwrap().len(), 40);
    }
}
|
||||
439
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/neural.rs
vendored
Normal file
439
vendor/ruvector/crates/ruvector-attention-unified-wasm/src/neural.rs
vendored
Normal file
@@ -0,0 +1,439 @@
|
||||
//! Neural Attention Mechanisms (from ruvector-attention)
|
||||
//!
|
||||
//! Re-exports the 7 core neural attention mechanisms:
|
||||
//! - Scaled Dot-Product Attention
|
||||
//! - Multi-Head Attention
|
||||
//! - Hyperbolic Attention
|
||||
//! - Linear Attention (Performer)
|
||||
//! - Flash Attention
|
||||
//! - Local-Global Attention
|
||||
//! - Mixture of Experts (MoE) Attention
|
||||
|
||||
use ruvector_attention::{
|
||||
attention::{MultiHeadAttention, ScaledDotProductAttention},
|
||||
hyperbolic::{HyperbolicAttention, HyperbolicAttentionConfig},
|
||||
moe::{MoEAttention, MoEConfig},
|
||||
sparse::{FlashAttention, LinearAttention, LocalGlobalAttention},
|
||||
traits::Attention,
|
||||
};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// Scaled Dot-Product Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Compute scaled dot-product attention
|
||||
///
|
||||
/// Standard transformer attention: softmax(QK^T / sqrt(d)) * V
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector (Float32Array)
|
||||
/// * `keys` - Array of key vectors (JsValue - array of Float32Arrays)
|
||||
/// * `values` - Array of value vectors (JsValue - array of Float32Arrays)
|
||||
/// * `scale` - Optional scaling factor (defaults to 1/sqrt(dim))
|
||||
///
|
||||
/// # Returns
|
||||
/// Attention-weighted output vector
|
||||
#[wasm_bindgen(js_name = scaledDotAttention)]
|
||||
pub fn scaled_dot_attention(
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
scale: Option<f32>,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse keys: {}", e)))?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)
|
||||
.map_err(|e| JsError::new(&format!("Failed to parse values: {}", e)))?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
let attention = ScaledDotProductAttention::new(query.len());
|
||||
attention
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Multi-Head Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Multi-head attention mechanism
///
/// Splits input into multiple heads, applies attention, and concatenates results
#[wasm_bindgen]
pub struct WasmMultiHeadAttention {
    // Wrapped native implementation from ruvector-attention.
    inner: MultiHeadAttention,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMultiHeadAttention {
|
||||
/// Create a new multi-head attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension (must be divisible by num_heads)
|
||||
/// * `num_heads` - Number of parallel attention heads
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_heads: usize) -> Result<WasmMultiHeadAttention, JsError> {
|
||||
if dim % num_heads != 0 {
|
||||
return Err(JsError::new(&format!(
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
dim, num_heads
|
||||
)));
|
||||
}
|
||||
Ok(Self {
|
||||
inner: MultiHeadAttention::new(dim, num_heads),
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `keys` - Array of key vectors
|
||||
/// * `values` - Array of value vectors
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the number of attention heads
|
||||
#[wasm_bindgen(getter, js_name = numHeads)]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.inner.num_heads()
|
||||
}
|
||||
|
||||
/// Get the embedding dimension
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.inner.dim()
|
||||
}
|
||||
|
||||
/// Get the dimension per head
|
||||
#[wasm_bindgen(getter, js_name = headDim)]
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.inner.dim() / self.inner.num_heads()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hyperbolic Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Hyperbolic attention mechanism for hierarchical data
///
/// Operates in hyperbolic space (Poincare ball model) which naturally
/// represents tree-like hierarchical structures with exponential capacity
#[wasm_bindgen]
pub struct WasmHyperbolicAttention {
    // Wrapped native implementation from ruvector-attention.
    inner: HyperbolicAttention,
    // Cached at construction so the `curvature` getter can return it
    // without reaching into the inner config.
    curvature_value: f32,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmHyperbolicAttention {
|
||||
/// Create a new hyperbolic attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `curvature` - Hyperbolic curvature parameter (negative for hyperbolic space)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, curvature: f32) -> WasmHyperbolicAttention {
|
||||
let config = HyperbolicAttentionConfig {
|
||||
dim,
|
||||
curvature,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
inner: HyperbolicAttention::new(config),
|
||||
curvature_value: curvature,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute hyperbolic attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the curvature parameter
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.curvature_value
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Linear Attention (Performer)
|
||||
// ============================================================================
|
||||
|
||||
/// Linear attention using random feature approximation
///
/// Achieves O(n) complexity instead of O(n^2) by approximating
/// the softmax kernel with random Fourier features
#[wasm_bindgen]
pub struct WasmLinearAttention {
    // Wrapped native implementation from ruvector-attention.
    inner: LinearAttention,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLinearAttention {
|
||||
/// Create a new linear attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_features` - Number of random features for kernel approximation
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_features: usize) -> WasmLinearAttention {
|
||||
Self {
|
||||
inner: LinearAttention::new(dim, num_features),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute linear attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Flash Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Flash attention with memory-efficient tiling
///
/// Reduces memory usage from O(n^2) to O(n) by computing attention
/// in blocks and fusing operations
#[wasm_bindgen]
pub struct WasmFlashAttention {
    // Wrapped native implementation from ruvector-attention.
    inner: FlashAttention,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmFlashAttention {
|
||||
/// Create a new flash attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `block_size` - Block size for tiled computation
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, block_size: usize) -> WasmFlashAttention {
|
||||
Self {
|
||||
inner: FlashAttention::new(dim, block_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute flash attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Local-Global Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Local-global sparse attention (Longformer-style)
///
/// Combines local sliding window attention with global tokens
/// for efficient long-range dependencies
#[wasm_bindgen]
pub struct WasmLocalGlobalAttention {
    // Wrapped native implementation from ruvector-attention.
    inner: LocalGlobalAttention,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmLocalGlobalAttention {
|
||||
/// Create a new local-global attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `local_window` - Size of local attention window
|
||||
/// * `global_tokens` - Number of global attention tokens
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, local_window: usize, global_tokens: usize) -> WasmLocalGlobalAttention {
|
||||
Self {
|
||||
inner: LocalGlobalAttention::new(dim, local_window, global_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute local-global attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mixture of Experts (MoE) Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Mixture of Experts attention mechanism
///
/// Routes queries to specialized expert attention heads based on
/// learned gating functions for capacity-efficient computation
#[wasm_bindgen]
pub struct WasmMoEAttention {
    // Wrapped native implementation from ruvector-attention.
    inner: MoEAttention,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WasmMoEAttention {
|
||||
/// Create a new MoE attention instance
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `dim` - Embedding dimension
|
||||
/// * `num_experts` - Number of expert attention mechanisms
|
||||
/// * `top_k` - Number of experts to activate per query
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_experts: usize, top_k: usize) -> WasmMoEAttention {
|
||||
let config = MoEConfig::builder()
|
||||
.dim(dim)
|
||||
.num_experts(num_experts)
|
||||
.top_k(top_k)
|
||||
.build();
|
||||
Self {
|
||||
inner: MoEAttention::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute MoE attention
|
||||
pub fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: JsValue,
|
||||
values: JsValue,
|
||||
) -> Result<Vec<f32>, JsError> {
|
||||
let keys_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(keys)?;
|
||||
let values_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(values)?;
|
||||
|
||||
let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
|
||||
let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
|
||||
|
||||
self.inner
|
||||
.compute(query, &keys_refs, &values_refs)
|
||||
.map_err(|e| JsError::new(&e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Valid configuration: 64 dims split across 8 heads of 8.
    #[wasm_bindgen_test]
    fn test_multi_head_creation() {
        let mha = WasmMultiHeadAttention::new(64, 8);
        assert!(mha.is_ok());
        let mha = mha.unwrap();
        assert_eq!(mha.dim(), 64);
        assert_eq!(mha.num_heads(), 8);
        assert_eq!(mha.head_dim(), 8);
    }

    /// 65 is not divisible by 8, so construction must fail.
    #[wasm_bindgen_test]
    fn test_multi_head_invalid_dims() {
        let mha = WasmMultiHeadAttention::new(65, 8);
        assert!(mha.is_err());
    }

    /// The curvature passed at construction is exposed by the getter.
    #[wasm_bindgen_test]
    fn test_hyperbolic_attention() {
        let hyp = WasmHyperbolicAttention::new(32, -1.0);
        assert_eq!(hyp.curvature(), -1.0);
    }

    // The creation tests below verify that construction does not panic;
    // the previous `assert!(true)` statements asserted nothing and the
    // unbound locals triggered unused-variable warnings.

    #[wasm_bindgen_test]
    fn test_linear_attention_creation() {
        let _linear = WasmLinearAttention::new(64, 128);
    }

    #[wasm_bindgen_test]
    fn test_flash_attention_creation() {
        let _flash = WasmFlashAttention::new(64, 16);
    }

    #[wasm_bindgen_test]
    fn test_local_global_creation() {
        let _lg = WasmLocalGlobalAttention::new(64, 128, 4);
    }

    #[wasm_bindgen_test]
    fn test_moe_attention_creation() {
        let _moe = WasmMoEAttention::new(64, 8, 2);
    }
}
|
||||
Reference in New Issue
Block a user