807 lines
25 KiB
Rust
807 lines
25 KiB
Rust
//! DAG Attention Mechanisms (from ruvector-dag)
//!
//! Exposes WASM bindings for the 7 DAG-specific attention mechanisms:
//! - Topological Attention
//! - Causal Cone Attention
//! - Critical Path Attention
//! - MinCut-Gated Attention
//! - Hierarchical Lorentz Attention
//! - Parallel Branch Attention
//! - Temporal BTSP Attention
|
|
|
|
use ruvector_dag::{OperatorNode, QueryDag};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use wasm_bindgen::prelude::*;
|
|
|
|
// ============================================================================
|
|
// Minimal DAG for WASM
|
|
// ============================================================================
|
|
|
|
/// Minimal DAG structure for WASM attention computation
|
|
#[wasm_bindgen]
|
|
pub struct WasmQueryDag {
|
|
inner: QueryDag,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmQueryDag {
|
|
/// Create a new empty DAG
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new() -> WasmQueryDag {
|
|
WasmQueryDag {
|
|
inner: QueryDag::new(),
|
|
}
|
|
}
|
|
|
|
/// Add a node with operator type and cost
|
|
///
|
|
/// # Arguments
|
|
/// * `op_type` - Operator type: "scan", "filter", "join", "aggregate", "project", "sort"
|
|
/// * `cost` - Estimated execution cost
|
|
///
|
|
/// # Returns
|
|
/// Node ID
|
|
#[wasm_bindgen(js_name = addNode)]
|
|
pub fn add_node(&mut self, op_type: &str, cost: f32) -> u32 {
|
|
let table_id = self.inner.node_count() as usize;
|
|
let mut node = match op_type {
|
|
"scan" => OperatorNode::seq_scan(table_id, &format!("table_{}", table_id)),
|
|
"filter" => OperatorNode::filter(table_id, "condition"),
|
|
"join" => OperatorNode::hash_join(table_id, "join_key"),
|
|
"aggregate" => OperatorNode::aggregate(table_id, vec!["*".to_string()]),
|
|
"project" => OperatorNode::project(table_id, vec!["*".to_string()]),
|
|
"sort" => OperatorNode::sort(table_id, vec!["col".to_string()]),
|
|
_ => OperatorNode::seq_scan(table_id, "unknown"),
|
|
};
|
|
node.estimated_cost = cost as f64;
|
|
self.inner.add_node(node) as u32
|
|
}
|
|
|
|
/// Add an edge between nodes
|
|
///
|
|
/// # Arguments
|
|
/// * `from` - Source node ID
|
|
/// * `to` - Target node ID
|
|
///
|
|
/// # Returns
|
|
/// True if edge was added successfully
|
|
#[wasm_bindgen(js_name = addEdge)]
|
|
pub fn add_edge(&mut self, from: u32, to: u32) -> bool {
|
|
self.inner.add_edge(from as usize, to as usize).is_ok()
|
|
}
|
|
|
|
/// Get the number of nodes
|
|
#[wasm_bindgen(getter, js_name = nodeCount)]
|
|
pub fn node_count(&self) -> u32 {
|
|
self.inner.node_count() as u32
|
|
}
|
|
|
|
/// Get the number of edges
|
|
#[wasm_bindgen(getter, js_name = edgeCount)]
|
|
pub fn edge_count(&self) -> u32 {
|
|
self.inner.edge_count() as u32
|
|
}
|
|
|
|
/// Serialize to JSON
|
|
#[wasm_bindgen(js_name = toJson)]
|
|
pub fn to_json(&self) -> String {
|
|
serde_json::to_string(&DagSummary {
|
|
node_count: self.inner.node_count(),
|
|
edge_count: self.inner.edge_count(),
|
|
})
|
|
.unwrap_or_default()
|
|
}
|
|
}
|
|
|
|
impl WasmQueryDag {
|
|
/// Get internal reference
|
|
pub(crate) fn inner(&self) -> &QueryDag {
|
|
&self.inner
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
struct DagSummary {
|
|
node_count: usize,
|
|
edge_count: usize,
|
|
}
|
|
|
|
// ============================================================================
|
|
// Helper for converting HashMap scores to Vec
|
|
// ============================================================================
|
|
|
|
/// Flattens a sparse `node id -> score` map into a dense vector of length `n`.
///
/// Slot `i` holds the score stored under key `i`, or `0.0` when the map has
/// no entry for it. Keys greater than or equal to `n` are ignored.
fn hashmap_to_vec(scores: &HashMap<usize, f32>, n: usize) -> Vec<f32> {
    let mut dense = Vec::with_capacity(n);
    for idx in 0..n {
        let value = match scores.get(&idx) {
            Some(&score) => score,
            None => 0.0,
        };
        dense.push(value);
    }
    dense
}
|
|
|
|
// ============================================================================
|
|
// Topological Attention
|
|
// ============================================================================
|
|
|
|
/// Topological attention based on DAG position
|
|
///
|
|
/// Assigns attention scores based on node position in topological order.
|
|
/// Earlier nodes (closer to sources) get higher attention.
|
|
#[wasm_bindgen]
|
|
pub struct WasmTopologicalAttention {
|
|
decay_factor: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmTopologicalAttention {
|
|
/// Create a new topological attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `decay_factor` - Decay factor for position-based attention (0.0-1.0)
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(decay_factor: f32) -> WasmTopologicalAttention {
|
|
WasmTopologicalAttention { decay_factor }
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
///
|
|
/// # Returns
|
|
/// Attention scores for each node
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
let depths = dag.inner.compute_depths();
|
|
let max_depth = depths.values().max().copied().unwrap_or(0);
|
|
|
|
let mut scores = HashMap::new();
|
|
let mut total = 0.0f32;
|
|
|
|
for (&node_id, &depth) in &depths {
|
|
let normalized_depth = depth as f32 / (max_depth.max(1) as f32);
|
|
let score = self.decay_factor.powf(1.0 - normalized_depth);
|
|
scores.insert(node_id, score);
|
|
total += score;
|
|
}
|
|
|
|
if total > 0.0 {
|
|
for score in scores.values_mut() {
|
|
*score /= total;
|
|
}
|
|
}
|
|
|
|
Ok(hashmap_to_vec(&scores, n))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Causal Cone Attention
|
|
// ============================================================================
|
|
|
|
/// Causal cone attention based on dependency lightcones
|
|
///
|
|
/// Nodes can only attend to ancestors in the DAG (causal predecessors).
|
|
/// Attention strength decays with causal distance.
|
|
#[wasm_bindgen]
|
|
pub struct WasmCausalConeAttention {
|
|
future_discount: f32,
|
|
ancestor_weight: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmCausalConeAttention {
|
|
/// Create a new causal cone attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `future_discount` - Discount for future nodes
|
|
/// * `ancestor_weight` - Weight for ancestor influence
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(future_discount: f32, ancestor_weight: f32) -> WasmCausalConeAttention {
|
|
WasmCausalConeAttention {
|
|
future_discount,
|
|
ancestor_weight,
|
|
}
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
let mut scores = HashMap::new();
|
|
let mut total = 0.0f32;
|
|
|
|
let depths = dag.inner.compute_depths();
|
|
|
|
for node_id in 0..n {
|
|
if dag.inner.get_node(node_id).is_none() {
|
|
continue;
|
|
}
|
|
|
|
let ancestors = dag.inner.ancestors(node_id);
|
|
let ancestor_count = ancestors.len();
|
|
|
|
let mut score = 1.0 + (ancestor_count as f32 * self.ancestor_weight);
|
|
|
|
if let Some(&depth) = depths.get(&node_id) {
|
|
score *= self.future_discount.powi(depth as i32);
|
|
}
|
|
|
|
scores.insert(node_id, score);
|
|
total += score;
|
|
}
|
|
|
|
if total > 0.0 {
|
|
for score in scores.values_mut() {
|
|
*score /= total;
|
|
}
|
|
}
|
|
|
|
Ok(hashmap_to_vec(&scores, n))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Critical Path Attention
|
|
// ============================================================================
|
|
|
|
/// Critical path attention weighted by path criticality
|
|
///
|
|
/// Nodes on or near the critical path (longest execution path)
|
|
/// receive higher attention scores.
|
|
#[wasm_bindgen]
|
|
pub struct WasmCriticalPathAttention {
|
|
path_weight: f32,
|
|
branch_penalty: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmCriticalPathAttention {
|
|
/// Create a new critical path attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `path_weight` - Weight for critical path membership
|
|
/// * `branch_penalty` - Penalty for branching nodes
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(path_weight: f32, branch_penalty: f32) -> WasmCriticalPathAttention {
|
|
WasmCriticalPathAttention {
|
|
path_weight,
|
|
branch_penalty,
|
|
}
|
|
}
|
|
|
|
/// Compute the critical path (longest path by cost)
|
|
fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
|
|
let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
|
|
|
|
for &leaf in &dag.leaves() {
|
|
if let Some(node) = dag.get_node(leaf) {
|
|
longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
|
|
}
|
|
}
|
|
|
|
if let Ok(topo_order) = dag.topological_sort() {
|
|
for &node_id in topo_order.iter().rev() {
|
|
let node = match dag.get_node(node_id) {
|
|
Some(n) => n,
|
|
None => continue,
|
|
};
|
|
|
|
let mut max_cost = node.estimated_cost;
|
|
let mut max_path = vec![node_id];
|
|
|
|
for &child in dag.children(node_id) {
|
|
if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
|
|
let total_cost = node.estimated_cost + child_cost;
|
|
if total_cost > max_cost {
|
|
max_cost = total_cost;
|
|
max_path = vec![node_id];
|
|
max_path.extend(child_path);
|
|
}
|
|
}
|
|
}
|
|
|
|
longest_path.insert(node_id, (max_cost, max_path));
|
|
}
|
|
}
|
|
|
|
longest_path
|
|
.into_iter()
|
|
.max_by(|a, b| {
|
|
a.1 .0
|
|
.partial_cmp(&b.1 .0)
|
|
.unwrap_or(std::cmp::Ordering::Equal)
|
|
})
|
|
.map(|(_, (_, path))| path)
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
let critical = self.compute_critical_path(&dag.inner);
|
|
let mut scores = HashMap::new();
|
|
let mut total = 0.0f32;
|
|
|
|
for node_id in 0..n {
|
|
if dag.inner.get_node(node_id).is_none() {
|
|
continue;
|
|
}
|
|
|
|
let is_on_critical_path = critical.contains(&node_id);
|
|
let num_children = dag.inner.children(node_id).len();
|
|
|
|
let mut score = if is_on_critical_path {
|
|
self.path_weight
|
|
} else {
|
|
1.0
|
|
};
|
|
|
|
if num_children > 1 {
|
|
score *= 1.0 + (num_children as f32 - 1.0) * self.branch_penalty;
|
|
}
|
|
|
|
scores.insert(node_id, score);
|
|
total += score;
|
|
}
|
|
|
|
if total > 0.0 {
|
|
for score in scores.values_mut() {
|
|
*score /= total;
|
|
}
|
|
}
|
|
|
|
Ok(hashmap_to_vec(&scores, n))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// MinCut-Gated Attention
|
|
// ============================================================================
|
|
|
|
/// MinCut-gated attention using flow-based bottleneck detection
|
|
///
|
|
/// Uses minimum cut analysis to identify bottleneck nodes
|
|
/// and gates attention through these critical points.
|
|
#[wasm_bindgen]
|
|
pub struct WasmMinCutGatedAttention {
|
|
gate_threshold: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmMinCutGatedAttention {
|
|
/// Create a new MinCut-gated attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `gate_threshold` - Threshold for gating (0.0-1.0)
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(gate_threshold: f32) -> WasmMinCutGatedAttention {
|
|
WasmMinCutGatedAttention { gate_threshold }
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
// Simple bottleneck detection: nodes with high in-degree and out-degree
|
|
let mut scores = HashMap::new();
|
|
let mut total = 0.0f32;
|
|
|
|
for node_id in 0..n {
|
|
if dag.inner.get_node(node_id).is_none() {
|
|
continue;
|
|
}
|
|
|
|
let in_degree = dag.inner.parents(node_id).len();
|
|
let out_degree = dag.inner.children(node_id).len();
|
|
|
|
// Bottleneck score: higher for nodes with high connectivity
|
|
let connectivity = (in_degree + out_degree) as f32;
|
|
let is_bottleneck = connectivity >= self.gate_threshold * n as f32;
|
|
|
|
let score = if is_bottleneck {
|
|
2.0 + connectivity * 0.1
|
|
} else {
|
|
1.0
|
|
};
|
|
|
|
scores.insert(node_id, score);
|
|
total += score;
|
|
}
|
|
|
|
if total > 0.0 {
|
|
for score in scores.values_mut() {
|
|
*score /= total;
|
|
}
|
|
}
|
|
|
|
Ok(hashmap_to_vec(&scores, n))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Hierarchical Lorentz Attention
|
|
// ============================================================================
|
|
|
|
/// Hierarchical Lorentz attention in hyperbolic space
|
|
///
|
|
/// Combines DAG hierarchy with Lorentz (hyperboloid) geometry
|
|
/// for multi-scale hierarchical attention.
|
|
#[wasm_bindgen]
|
|
pub struct WasmHierarchicalLorentzAttention {
|
|
curvature: f32,
|
|
temperature: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmHierarchicalLorentzAttention {
|
|
/// Create a new hierarchical Lorentz attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `curvature` - Hyperbolic curvature parameter
|
|
/// * `temperature` - Temperature for softmax
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(curvature: f32, temperature: f32) -> WasmHierarchicalLorentzAttention {
|
|
WasmHierarchicalLorentzAttention {
|
|
curvature,
|
|
temperature,
|
|
}
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
let depths = dag.inner.compute_depths();
|
|
let max_depth = depths.values().max().copied().unwrap_or(0);
|
|
|
|
// Compute hyperbolic distances from origin
|
|
let mut distances: Vec<f32> = Vec::with_capacity(n);
|
|
for node_id in 0..n {
|
|
let depth = depths.get(&node_id).copied().unwrap_or(0);
|
|
// In hyperbolic space, distance grows exponentially with depth
|
|
let radial = (depth as f32 * 0.5).tanh();
|
|
let distance = (1.0 + radial).acosh() * self.curvature.abs();
|
|
distances.push(distance);
|
|
}
|
|
|
|
// Convert to attention scores using softmax
|
|
let max_neg_dist = distances
|
|
.iter()
|
|
.map(|&d| -d / self.temperature)
|
|
.fold(f32::NEG_INFINITY, f32::max);
|
|
let exp_sum: f32 = distances
|
|
.iter()
|
|
.map(|&d| ((-d / self.temperature) - max_neg_dist).exp())
|
|
.sum();
|
|
|
|
let scores: Vec<f32> = distances
|
|
.iter()
|
|
.map(|&d| ((-d / self.temperature) - max_neg_dist).exp() / exp_sum.max(1e-10))
|
|
.collect();
|
|
|
|
Ok(scores)
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Parallel Branch Attention
|
|
// ============================================================================
|
|
|
|
/// Parallel branch attention for concurrent DAG branches
|
|
///
|
|
/// Identifies parallel branches in the DAG and applies
|
|
/// attention patterns that respect branch independence.
|
|
#[wasm_bindgen]
|
|
pub struct WasmParallelBranchAttention {
|
|
max_branches: usize,
|
|
sync_penalty: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmParallelBranchAttention {
|
|
/// Create a new parallel branch attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `max_branches` - Maximum number of branches to consider
|
|
/// * `sync_penalty` - Penalty for synchronization between branches
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(max_branches: usize, sync_penalty: f32) -> WasmParallelBranchAttention {
|
|
WasmParallelBranchAttention {
|
|
max_branches,
|
|
sync_penalty,
|
|
}
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
// Detect branch points (nodes with multiple children)
|
|
let mut branch_starts: Vec<usize> = Vec::new();
|
|
for node_id in 0..n {
|
|
if dag.inner.children(node_id).len() > 1 {
|
|
branch_starts.push(node_id);
|
|
}
|
|
}
|
|
|
|
let mut scores = HashMap::new();
|
|
let mut total = 0.0f32;
|
|
|
|
for node_id in 0..n {
|
|
if dag.inner.get_node(node_id).is_none() {
|
|
continue;
|
|
}
|
|
|
|
// Check if node is part of a parallel branch
|
|
let parents = dag.inner.parents(node_id);
|
|
let is_branch_child = parents.iter().any(|&p| branch_starts.contains(&p));
|
|
|
|
let children = dag.inner.children(node_id);
|
|
let is_sync_point = children.len() == 0 && parents.len() > 1;
|
|
|
|
let score = if is_branch_child {
|
|
1.5 // Boost parallel branch nodes
|
|
} else if is_sync_point {
|
|
1.0 * (1.0 - self.sync_penalty) // Penalize sync points
|
|
} else {
|
|
1.0
|
|
};
|
|
|
|
scores.insert(node_id, score);
|
|
total += score;
|
|
}
|
|
|
|
if total > 0.0 {
|
|
for score in scores.values_mut() {
|
|
*score /= total;
|
|
}
|
|
}
|
|
|
|
Ok(hashmap_to_vec(&scores, n))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Temporal BTSP Attention
|
|
// ============================================================================
|
|
|
|
/// Temporal BTSP (Behavioral Time-Series Pattern) attention
|
|
///
|
|
/// Incorporates temporal patterns and behavioral sequences
|
|
/// for time-aware DAG attention.
|
|
#[wasm_bindgen]
|
|
pub struct WasmTemporalBTSPAttention {
|
|
eligibility_decay: f32,
|
|
baseline_attention: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmTemporalBTSPAttention {
|
|
/// Create a new temporal BTSP attention instance
|
|
///
|
|
/// # Arguments
|
|
/// * `eligibility_decay` - Decay rate for eligibility traces (0.0-1.0)
|
|
/// * `baseline_attention` - Baseline attention for nodes without history
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(eligibility_decay: f32, baseline_attention: f32) -> WasmTemporalBTSPAttention {
|
|
WasmTemporalBTSPAttention {
|
|
eligibility_decay,
|
|
baseline_attention,
|
|
}
|
|
}
|
|
|
|
/// Compute attention scores for the DAG
|
|
pub fn forward(&self, dag: &WasmQueryDag) -> Result<Vec<f32>, JsError> {
|
|
let n = dag.inner.node_count();
|
|
if n == 0 {
|
|
return Err(JsError::new("Empty DAG"));
|
|
}
|
|
|
|
let mut scores = Vec::with_capacity(n);
|
|
let mut total = 0.0f32;
|
|
|
|
for node_id in 0..n {
|
|
let node = match dag.inner.get_node(node_id) {
|
|
Some(n) => n,
|
|
None => {
|
|
scores.push(0.0);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
// Base score from cost and rows
|
|
let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
|
|
let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
|
|
let score = self.baseline_attention * (0.5 * cost_factor + 0.5 * rows_factor + 0.5);
|
|
|
|
scores.push(score);
|
|
total += score;
|
|
}
|
|
|
|
// Normalize
|
|
if total > 0.0 {
|
|
for score in scores.iter_mut() {
|
|
*score /= total;
|
|
}
|
|
}
|
|
|
|
Ok(scores)
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// DAG Attention Factory
|
|
// ============================================================================
|
|
|
|
/// Factory for creating DAG attention mechanisms
|
|
#[wasm_bindgen]
|
|
pub struct DagAttentionFactory;
|
|
|
|
#[wasm_bindgen]
|
|
impl DagAttentionFactory {
|
|
/// Get available DAG attention types
|
|
#[wasm_bindgen(js_name = availableTypes)]
|
|
pub fn available_types() -> JsValue {
|
|
let types = vec![
|
|
"topological",
|
|
"causal_cone",
|
|
"critical_path",
|
|
"mincut_gated",
|
|
"hierarchical_lorentz",
|
|
"parallel_branch",
|
|
"temporal_btsp",
|
|
];
|
|
serde_wasm_bindgen::to_value(&types).unwrap()
|
|
}
|
|
|
|
/// Get description for a DAG attention type
|
|
#[wasm_bindgen(js_name = getDescription)]
|
|
pub fn get_description(attention_type: &str) -> String {
|
|
match attention_type {
|
|
"topological" => "Position-based attention following DAG topological order".to_string(),
|
|
"causal_cone" => "Lightcone-based attention respecting causal dependencies".to_string(),
|
|
"critical_path" => "Attention weighted by critical execution path distance".to_string(),
|
|
"mincut_gated" => "Flow-based gating through bottleneck nodes".to_string(),
|
|
"hierarchical_lorentz" => {
|
|
"Multi-scale hyperbolic attention for DAG hierarchies".to_string()
|
|
}
|
|
"parallel_branch" => "Branch-aware attention for parallel DAG structures".to_string(),
|
|
"temporal_btsp" => "Time-series pattern attention for temporal DAGs".to_string(),
|
|
_ => "Unknown attention type".to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Tests
|
|
// ============================================================================
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Builds the two-node `scan -> filter` chain shared by most tests below.
    fn scan_filter_chain() -> WasmQueryDag {
        let mut dag = WasmQueryDag::new();
        dag.add_node("scan", 1.0);
        dag.add_node("filter", 0.5);
        dag.add_edge(0, 1);
        dag
    }

    // Node and edge counts reflect what was inserted.
    #[wasm_bindgen_test]
    fn test_dag_creation() {
        let mut dag = WasmQueryDag::new();
        let scan = dag.add_node("scan", 1.0);
        let filter = dag.add_node("filter", 0.5);
        dag.add_edge(scan, filter);

        assert_eq!(dag.node_count(), 2);
        assert_eq!(dag.edge_count(), 1);
    }

    // A three-node chain yields exactly one score per node.
    #[wasm_bindgen_test]
    fn test_topological_attention() {
        let mut dag = scan_filter_chain();
        dag.add_node("project", 0.3);
        dag.add_edge(1, 2);

        let result = WasmTopologicalAttention::new(0.9).forward(&dag);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 3);
    }

    #[wasm_bindgen_test]
    fn test_causal_cone_attention() {
        let dag = scan_filter_chain();
        let result = WasmCausalConeAttention::new(0.8, 0.9).forward(&dag);
        assert!(result.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_critical_path_attention() {
        let dag = scan_filter_chain();
        let result = WasmCriticalPathAttention::new(2.0, 0.5).forward(&dag);
        assert!(result.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_mincut_gated_attention() {
        let dag = scan_filter_chain();
        let result = WasmMinCutGatedAttention::new(0.5).forward(&dag);
        assert!(result.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_hierarchical_lorentz_attention() {
        let dag = scan_filter_chain();
        let result = WasmHierarchicalLorentzAttention::new(-1.0, 0.1).forward(&dag);
        assert!(result.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_parallel_branch_attention() {
        let dag = scan_filter_chain();
        let result = WasmParallelBranchAttention::new(8, 0.2).forward(&dag);
        assert!(result.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_temporal_btsp_attention() {
        let dag = scan_filter_chain();
        let result = WasmTemporalBTSPAttention::new(0.95, 0.5).forward(&dag);
        assert!(result.is_ok());
    }

    // The factory's type list serializes to a non-null JS value.
    #[wasm_bindgen_test]
    fn test_factory_types() {
        let listed = DagAttentionFactory::available_types();
        assert!(!listed.is_null());
    }
}
|