# DAG Attention Mechanisms
Overview
This document specifies seven specialized attention mechanisms designed for Directed Acyclic Graph (DAG) structures. Each mechanism leverages unique DAG properties for specific optimization scenarios.
Attention Trait Definition
/// Core trait implemented by every DAG attention mechanism in this module.
///
/// Implementors must be `Send + Sync` so one instance can be shared across
/// executor threads.
pub trait DagAttention: Send + Sync {
    /// Compute attention weights for a query node over its context.
    ///
    /// Returns the per-node weights and the weighted value aggregation,
    /// or an `AttentionError` on failure.
    fn forward(
        &self,
        query: &DagNode,
        context: &DagContext,
        config: &AttentionConfig,
    ) -> Result<AttentionOutput, AttentionError>;

    /// Update internal state based on execution feedback
    /// (used as a learning signal after a plan has run).
    fn update(&mut self, feedback: &AttentionFeedback) -> Result<(), AttentionError>;

    /// Get attention type identifier.
    fn attention_type(&self) -> DagAttentionType;

    /// Estimated computation complexity for a context of `context_size` nodes.
    fn complexity(&self, context_size: usize) -> usize;
}
/// Output from attention computation
/// Output from attention computation.
pub struct AttentionOutput {
    /// Attention weights, one per context node (sum to 1.0 after softmax).
    pub weights: Vec<f32>,
    /// Weighted aggregation of context values.
    pub aggregated: Vec<f32>,
    /// Auxiliary information for learning and diagnostics.
    pub metadata: AttentionMetadata,
}
/// Context for DAG attention
/// Context for DAG attention: the DAG topology plus optional per-node
/// side information consumed by specific mechanisms.
pub struct DagContext {
    /// All nodes in the DAG.
    pub nodes: Vec<DagNode>,
    /// Adjacency list (node_id -> children).
    pub edges: HashMap<NodeId, Vec<NodeId>>,
    /// Reverse adjacency (node_id -> parents).
    pub reverse_edges: HashMap<NodeId, Vec<NodeId>>,
    /// Node depths (topological distance from roots).
    pub depths: HashMap<NodeId, usize>,
    /// Optional: timestamps for temporal attention (presumably milliseconds,
    /// matching `coincidence_window_ms` — TODO confirm unit).
    pub timestamps: Option<HashMap<NodeId, f64>>,
    /// Optional: min-cut criticality scores.
    pub criticalities: Option<HashMap<NodeId, f32>>,
}
1. Topological Attention
Purpose
Respects DAG ordering by allowing nodes to only attend to their ancestors. This maintains causal consistency in query plans.
Algorithm
/// Attention restricted to a node's ancestors, preserving topological order.
pub struct TopologicalAttention {
    /// Hidden dimension for projections.
    hidden_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Query/Key/Value projection weights.
    w_query: Array2<f32>, // [hidden_dim, hidden_dim]
    w_key: Array2<f32>,
    w_value: Array2<f32>,
    /// Precomputed ancestor masks (lazily computed, concurrent map).
    ancestor_cache: DashMap<NodeId, BitVec>,
}
impl TopologicalAttention {
    /// Compute masked attention where `query_node` attends only to its
    /// ancestors, preserving causal (topological) consistency.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        // 1. Get or compute the ancestor mask for the query node.
        let ancestors = self.get_ancestors(query_node.id, ctx);

        // 2. Project the query embedding.
        let q = self.project_query(&query_node.embedding);

        // 3. Scaled dot-product scores; non-ancestors are masked with -inf
        //    so they receive zero weight after softmax.
        let mut scores = Vec::with_capacity(ctx.nodes.len());
        let scale = (self.hidden_dim as f32).sqrt();
        for (i, node) in ctx.nodes.iter().enumerate() {
            if ancestors.get(i).unwrap_or(false) {
                let k = self.project_key(&node.embedding);
                scores.push(dot(&q, &k) / scale);
            } else {
                scores.push(f32::NEG_INFINITY); // Mask non-ancestors
            }
        }

        // 4. Softmax over masked scores.
        let weights = softmax(&scores);

        // 5. Weighted aggregation of value projections; negligible weights are
        //    skipped to avoid projecting values that contribute nothing.
        let mut aggregated = vec![0.0; self.hidden_dim];
        for (i, node) in ctx.nodes.iter().enumerate() {
            if weights[i] > 1e-8 {
                let v = self.project_value(&node.embedding);
                for (j, val) in v.iter().enumerate() {
                    aggregated[j] += weights[i] * val;
                }
            }
        }

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::new(DagAttentionType::Topological),
        }
    }

    /// Return the (transitive) ancestor set of `node_id` as a bit mask
    /// indexed by node position, using a cached result when available.
    fn get_ancestors(&self, node_id: NodeId, ctx: &DagContext) -> BitVec {
        if let Some(cached) = self.ancestor_cache.get(&node_id) {
            return cached.clone();
        }

        // BFS over reverse edges to collect all ancestors.
        let mut ancestors = BitVec::repeat(false, ctx.nodes.len());
        let mut queue = VecDeque::new();

        // Seed with the direct parents.
        if let Some(parents) = ctx.reverse_edges.get(&node_id) {
            for &parent in parents {
                queue.push_back(parent);
            }
        }

        while let Some(current) = queue.pop_front() {
            let idx = ctx.node_index(current);
            // Only expand unvisited nodes: a DAG can reach the same ancestor
            // via multiple paths.
            if !ancestors.get(idx).unwrap_or(false) {
                ancestors.set(idx, true);
                // FIX: this lookup was corrupted to `get(¤t)` —
                // HTML-entity mangling (`&curren;` = ¤) of `&current`.
                if let Some(parents) = ctx.reverse_edges.get(&current) {
                    for &parent in parents {
                        queue.push_back(parent);
                    }
                }
            }
        }

        self.ancestor_cache.insert(node_id, ancestors.clone());
        ancestors
    }
}
Complexity
- Time: O(n·k) where n = nodes, k = average ancestors
- Space: O(n) for ancestor mask
SQL Interface
-- Compute topological attention
SELECT ruvector_attention_topological(
query_embedding, -- VECTOR: query node embedding
ancestor_embeddings, -- VECTOR[]: ancestor embeddings
'{"num_heads": 8, "hidden_dim": 256}'::jsonb -- config
) AS attention_weights;
2. Causal Cone Attention
Purpose
Extends topological attention with distance-weighted decay. Closer ancestors get more attention.
Algorithm
/// Topological attention with exponential distance decay: closer ancestors
/// receive more attention than distant ones.
pub struct CausalConeAttention {
    hidden_dim: usize,
    num_heads: usize,
    /// Decay rate per DAG hop (applied as exp(-decay_rate * distance)).
    decay_rate: f32,
    /// Maximum lookback depth; ancestors farther than this are masked out.
    max_depth: usize,
    /// Projection weights.
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
}
impl CausalConeAttention {
    /// Attention over the query node's "causal cone": ancestors within
    /// `max_depth` hops, with exponentially decaying influence per hop.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        // 1. Hop distances from the query node to each reachable ancestor.
        //    (The original also read the query's depth here but never used it.)
        let distances = self.compute_distances(query_node.id, ctx);

        // 2. Project the query embedding.
        let q = self.project_query(&query_node.embedding);
        let scale = (self.hidden_dim as f32).sqrt();

        // 3. Distance-weighted scores; anything outside the cone is masked.
        let mut scores = Vec::with_capacity(ctx.nodes.len());
        for node in &ctx.nodes {
            match distances.get(&node.id) {
                Some(&dist) if dist > 0 && dist <= self.max_depth => {
                    let k = self.project_key(&node.embedding);
                    let base_score = dot(&q, &k) / scale;
                    // Exponential decay with hop distance.
                    let decay = (-self.decay_rate * dist as f32).exp();
                    scores.push(base_score * decay);
                }
                _ => scores.push(f32::NEG_INFINITY),
            }
        }

        // 4. Softmax and weighted value aggregation.
        let weights = softmax(&scores);
        let aggregated = self.aggregate_values(&weights, ctx);

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::new(DagAttentionType::CausalCone),
        }
    }

    /// BFS backward from `from` over reverse edges, recording the first
    /// (shortest) hop distance at which each ancestor is reached.
    fn compute_distances(&self, from: NodeId, ctx: &DagContext) -> HashMap<NodeId, usize> {
        let mut distances = HashMap::new();
        let mut queue = VecDeque::new();

        // Seed with the direct parents at distance 1.
        if let Some(parents) = ctx.reverse_edges.get(&from) {
            for &parent in parents {
                queue.push_back((parent, 1));
            }
        }

        while let Some((current, dist)) = queue.pop_front() {
            if dist > self.max_depth {
                continue; // Prune anything beyond the lookback horizon.
            }
            // FIX: both lookups below were corrupted to `¤t` —
            // HTML-entity mangling (`&curren;` = ¤) of `&current`.
            if !distances.contains_key(&current) {
                distances.insert(current, dist);
                if let Some(parents) = ctx.reverse_edges.get(&current) {
                    for &parent in parents {
                        queue.push_back((parent, dist + 1));
                    }
                }
            }
        }

        distances
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `decay_rate` | 0.1 | Exponential decay per hop |
| `max_depth` | 10 | Maximum ancestor distance |
SQL Interface
SELECT ruvector_attention_causal_cone(
query_embedding,
ancestor_embeddings,
ancestor_distances, -- INT[]: hop distances
0.1, -- decay_rate
10 -- max_depth
) AS attention_weights;
3. Critical Path Attention
Purpose
Focuses attention on DAG critical path (longest path). Useful for identifying and optimizing bottleneck operators.
Algorithm
/// Attention that boosts nodes lying on the DAG's critical (longest) path,
/// focusing learning on bottleneck operators.
pub struct CriticalPathAttention {
    hidden_dim: usize,
    num_heads: usize,
    /// Additive boost applied to scores of critical-path nodes.
    critical_path_boost: f32,
    /// Projection weights.
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
    /// Cached critical-path membership per query node (concurrent map).
    critical_path_cache: DashMap<NodeId, HashSet<NodeId>>,
}
impl CriticalPathAttention {
    /// Attention over ancestors with an additive boost for nodes lying on a
    /// global critical (longest) path through the query node.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        // 1. Critical-path membership for paths through the query node.
        let critical_path = self.find_critical_path(query_node.id, ctx);

        // 2. Project the query embedding.
        let q = self.project_query(&query_node.embedding);
        let scale = (self.hidden_dim as f32).sqrt();

        // 3. Scaled dot-product scores over ancestors, boosted on the
        //    critical path; non-ancestors are masked out.
        let ancestors = self.get_ancestors(query_node.id, ctx);
        let mut scores = Vec::with_capacity(ctx.nodes.len());
        for node in &ctx.nodes {
            if ancestors.contains(&node.id) {
                let k = self.project_key(&node.embedding);
                let mut score = dot(&q, &k) / scale;
                // Boost critical-path nodes additively (pre-softmax).
                if critical_path.contains(&node.id) {
                    score += self.critical_path_boost;
                }
                scores.push(score);
            } else {
                scores.push(f32::NEG_INFINITY);
            }
        }

        let weights = softmax(&scores);
        let aggregated = self.aggregate_values(&weights, ctx);

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::with_critical_path(critical_path),
        }
    }

    /// Nodes on a global longest path passing through `through`.
    ///
    /// Classic two-pass DP on a topological order: `longest_to[v]` is the
    /// longest path ending at `v`, `longest_from[v]` the longest path starting
    /// at `v`. A node lies on a global critical path iff the sum of the two
    /// equals the global maximum. Returns the empty set when `through` itself
    /// is not on any global critical path.
    fn find_critical_path(&self, through: NodeId, ctx: &DagContext) -> HashSet<NodeId> {
        if let Some(cached) = self.critical_path_cache.get(&through) {
            return cached.clone();
        }

        let mut longest_to = HashMap::new();
        let mut longest_from = HashMap::new();
        let topo_order = self.topological_sort(ctx);

        // Forward pass: longest path TO each node.
        for &node in &topo_order {
            let mut max_len = 0;
            if let Some(parents) = ctx.reverse_edges.get(&node) {
                for &parent in parents {
                    max_len = max_len.max(longest_to.get(&parent).unwrap_or(&0) + 1);
                }
            }
            longest_to.insert(node, max_len);
        }

        // Backward pass: longest path FROM each node.
        for &node in topo_order.iter().rev() {
            let mut max_len = 0;
            if let Some(children) = ctx.edges.get(&node) {
                for &child in children {
                    max_len = max_len.max(longest_from.get(&child).unwrap_or(&0) + 1);
                }
            }
            longest_from.insert(node, max_len);
        }

        let total_through = longest_to[&through] + longest_from[&through];
        let global_longest = topo_order
            .iter()
            .map(|n| longest_to[n] + longest_from[n])
            .max()
            .unwrap_or(0);

        let mut critical = HashSet::new();
        if total_through == global_longest {
            // `through` is on a global critical path: collect every node whose
            // through-length matches the global maximum.
            for &node in &topo_order {
                if longest_to[&node] + longest_from[&node] == global_longest {
                    critical.insert(node);
                }
            }
        }

        self.critical_path_cache.insert(through, critical.clone());
        critical
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `critical_path_boost` | 1.5 | Additive boost for critical nodes |
SQL Interface
SELECT ruvector_attention_critical_path(
query_embedding,
ancestor_embeddings,
is_critical, -- BOOLEAN[]: critical path membership
1.5 -- boost factor
) AS attention_weights;
4. MinCut Gated Attention
Purpose
Uses min-cut analysis to gate information flow. Focuses learning on bottleneck operators that dominate execution time.
Algorithm
/// Attention gated by min-cut criticality: bottleneck operators that
/// dominate execution time receive full attention, the rest are scaled down.
pub struct MinCutGatedAttention {
    hidden_dim: usize,
    num_heads: usize,
    /// Threshold for full attention (criticality > threshold => gate = 1.0).
    gate_threshold: f32,
    /// Min-cut engine used to compute criticality when the context has none.
    mincut_engine: Arc<SubpolynomialMinCut>,
    /// Projection weights.
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
    /// Gate projection (learns to predict criticality).
    // NOTE(review): `w_gate` is declared but not used by the `forward` shown
    // in this document — confirm whether it belongs to a learned-gate variant.
    w_gate: Array2<f32>,
}
impl MinCutGatedAttention {
    /// Attention over ancestors, gated by min-cut criticality: nodes above
    /// `gate_threshold` keep their full score, others are scaled down.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        // 1. Per-node criticality scores (precomputed or via min-cut engine).
        let criticalities = self.get_criticalities(ctx);

        // 2. Project the query embedding.
        let q = self.project_query(&query_node.embedding);
        let scale = (self.hidden_dim as f32).sqrt();

        // 3. Base scores plus a criticality-derived gate per node.
        let ancestors = self.get_ancestors(query_node.id, ctx);
        let mut scores = Vec::with_capacity(ctx.nodes.len());
        let mut gates = Vec::with_capacity(ctx.nodes.len());
        for node in &ctx.nodes {
            if ancestors.contains(&node.id) {
                let k = self.project_key(&node.embedding);
                let base_score = dot(&q, &k) / scale;

                let criticality = *criticalities.get(&node.id).unwrap_or(&0.0);
                // Guard against a non-positive threshold: the original divided
                // by `gate_threshold` unconditionally, producing inf/NaN gates
                // when it is 0. With threshold <= 0 every node counts as critical.
                let gate = if self.gate_threshold <= 0.0 || criticality > self.gate_threshold {
                    1.0 // Full attention for critical nodes
                } else {
                    criticality / self.gate_threshold // Scaled attention
                };
                scores.push(base_score);
                gates.push(gate);
            } else {
                scores.push(f32::NEG_INFINITY);
                gates.push(0.0);
            }
        }

        // 4. Apply gates multiplicatively before softmax.
        // NOTE(review): multiplying a *negative* score by a gate < 1 makes it
        // less negative, i.e. it is boosted after softmax — confirm this is
        // the intended gating semantics.
        let gated_scores: Vec<f32> = scores
            .iter()
            .zip(gates.iter())
            .map(|(s, g)| if *s > f32::NEG_INFINITY { s * g } else { *s })
            .collect();

        let weights = softmax(&gated_scores);
        let aggregated = self.aggregate_values(&weights, ctx);

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::with_criticalities(criticalities),
        }
    }

    /// Criticality per node: precomputed values from the context if present,
    /// otherwise derived from local k-cut queries around each node.
    fn get_criticalities(&self, ctx: &DagContext) -> HashMap<NodeId, f32> {
        // Use precomputed criticalities if the context carries them.
        if let Some(ref crit) = ctx.criticalities {
            return crit.clone();
        }

        // Otherwise compute via the min-cut engine.
        let mut criticalities = HashMap::new();
        let global_cut = self.mincut_engine.query();
        for node in &ctx.nodes {
            // Local k-cut query in a small radius around this node.
            let local_query = LocalKCutQuery {
                seed_vertices: vec![node.id],
                budget_k: global_cut,
                radius: 3,
            };
            match self.mincut_engine.local_query(local_query) {
                LocalKCutResult::Found { cut_value, .. } => {
                    // NOTE(review): if `cut_value` can exceed `global_cut` and
                    // the cut type is unsigned, this subtraction underflows —
                    // confirm the engine guarantees cut_value <= budget_k.
                    let criticality = (global_cut - cut_value) as f32 / global_cut as f32;
                    criticalities.insert(node.id, criticality);
                }
                LocalKCutResult::NoneInLocality => {
                    criticalities.insert(node.id, 0.0);
                }
            }
        }
        criticalities
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `gate_threshold` | 0.5 | Criticality threshold for full attention |
Complexity
- Time: O(n^0.12 + n·k) - subpolynomial min-cut + attention
- Space: O(n) for criticality map
SQL Interface
SELECT ruvector_attention_mincut_gated(
query_embedding,
ancestor_embeddings,
ruvector_mincut_criticality(dag_id), -- precomputed criticalities
0.5 -- threshold
) AS attention_weights;
5. Hierarchical Lorentz Attention
Purpose
Uses hyperbolic geometry (Lorentz hyperboloid) to naturally represent DAG hierarchy. Deeper nodes embed further from origin.
Algorithm
/// Attention in hyperbolic space (Lorentz hyperboloid model), which embeds
/// hierarchies naturally: deeper nodes map further from the origin.
pub struct HierarchicalLorentzAttention {
    hidden_dim: usize,
    num_heads: usize,
    /// Curvature of the hyperbolic space (must be negative).
    curvature: f32,
    /// Temperature for the attention softmax.
    temperature: f32,
    /// Lorentz model providing hyperbolic operations.
    lorentz: LorentzModel,
    /// Projection from Euclidean embeddings into hyperbolic space.
    w_to_hyperbolic: Array2<f32>,
}
impl HierarchicalLorentzAttention {
    /// Attention over ancestors scored in the Lorentz (hyperboloid) model of
    /// hyperbolic space, so DAG depth maps to hyperbolic radius.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        let query_depth = ctx.depths[&query_node.id];

        // 1. Map the query into hyperbolic space at its depth.
        let query_hyper = self.to_lorentz(&query_node.embedding, query_depth);

        // 2. Busemann-function scores against each ancestor (O(d) per pair).
        let ancestors = self.get_ancestors(query_node.id, ctx);
        let mut scores = Vec::with_capacity(ctx.nodes.len());
        for node in &ctx.nodes {
            if ancestors.contains(&node.id) {
                let node_depth = ctx.depths[&node.id];
                let node_hyper = self.to_lorentz(&node.embedding, node_depth);
                let score = self.lorentz.busemann_score(&query_hyper, &node_hyper);
                scores.push(score / self.temperature);
            } else {
                scores.push(f32::NEG_INFINITY);
            }
        }

        // 3. Softmax adapted to the hyperbolic scores.
        let weights = self.lorentz.hyperbolic_softmax(&scores);

        // 4. Einstein midpoint aggregation (closed form, no iteration).
        let aggregated = self.einstein_midpoint_aggregation(&weights, ctx);

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::new(DagAttentionType::HierarchicalLorentz),
        }
    }

    /// Project a Euclidean embedding onto the hyperboloid, placing it at a
    /// radius that grows logarithmically with DAG depth.
    fn to_lorentz(&self, embedding: &[f32], depth: usize) -> LorentzPoint {
        let projected = self.w_to_hyperbolic.dot(&Array1::from_vec(embedding.to_vec()));
        // Map depth to hyperbolic radius; curvature is negative, so
        // (-curvature).sqrt() is well-defined.
        let radius = (depth as f32 + 1.0).ln() / (-self.curvature).sqrt();
        self.lorentz.exp_map_at_origin(&projected.to_vec(), radius)
    }

    /// Closed-form weighted centroid (Einstein midpoint) in the Lorentz
    /// model — much faster than an iterative Fréchet mean.
    fn einstein_midpoint_aggregation(
        &self,
        weights: &[f32],
        ctx: &DagContext,
    ) -> Vec<f32> {
        // Lorentz points have hidden_dim + 1 coordinates (time-like + space-like).
        let mut numerator = vec![0.0; self.hidden_dim + 1];
        let mut denominator = 0.0;

        for (i, node) in ctx.nodes.iter().enumerate() {
            if weights[i] > 1e-8 {
                let depth = ctx.depths[&node.id];
                let hyper = self.to_lorentz(&node.embedding, depth);
                // Each point contributes proportionally to weight × Lorentz factor.
                let gamma = hyper.lorentz_factor();
                let weight = weights[i] * gamma;
                for (j, &coord) in hyper.coordinates().iter().enumerate() {
                    numerator[j] += weight * coord;
                }
                denominator += weight;
            }
        }

        // Guard: if every weight was negligible the original divided by zero
        // and returned NaNs; fall back to a zero vector instead.
        if denominator <= 0.0 {
            return vec![0.0; self.hidden_dim];
        }

        // Normalize and map back to Euclidean coordinates.
        let midpoint: Vec<f32> = numerator.iter().map(|x| x / denominator).collect();
        self.lorentz.log_map_to_euclidean(&midpoint)
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `curvature` | -1.0 | Hyperbolic space curvature |
| `temperature` | 1.0 | Softmax temperature |
Benefits
- 5-10x faster than Poincaré attention
- Numerically stable (no tanh clipping)
- Natural hierarchy representation
SQL Interface
SELECT ruvector_attention_hierarchical_lorentz(
query_embedding,
ancestor_embeddings,
ancestor_depths, -- INT[]: DAG depths
-1.0, -- curvature
1.0 -- temperature
) AS attention_weights;
6. Parallel Branch Attention
Purpose
Coordinates across parallel DAG branches. Essential for parallel query execution where branches need to share information.
Algorithm
/// Attention that coordinates across parallel DAG branches: own ancestors
/// plus cross-attention to sibling branches at the same depth.
pub struct ParallelBranchAttention {
    hidden_dim: usize,
    num_heads: usize,
    /// Multiplicative weight for cross-branch attention scores.
    cross_branch_weight: f32,
    /// Additive boost for common ancestors (synchronization points).
    common_ancestor_boost: f32,
    /// Projection weights.
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
    /// Separate key projection used for cross-branch attention.
    w_cross: Array2<f32>,
}
impl ParallelBranchAttention {
    /// Multi-source attention: own ancestors (with a boost on common
    /// ancestors that act as synchronization points) plus cross-attention to
    /// parallel branches at the same depth.
    ///
    /// FIX: several identifiers in the original were corrupted to `¶llel…` —
    /// HTML-entity mangling (`&para;` = ¶) of `&parallel…`; restored below.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        let query_depth = ctx.depths[&query_node.id];

        // 1. Parallel branches: same depth, disjoint lineage.
        let parallel_nodes = self.find_parallel_nodes(query_node.id, query_depth, ctx);

        // 2. Common ancestors shared with those branches (sync points).
        let common_ancestors = self.find_common_ancestors(
            query_node.id,
            &parallel_nodes,
            ctx,
        );

        // 3. Project the query embedding.
        let q = self.project_query(&query_node.embedding);
        let scale = (self.hidden_dim as f32).sqrt();

        // 4. Multi-source scores, tagged by source type for aggregation.
        let ancestors = self.get_ancestors(query_node.id, ctx);
        let mut scores = Vec::new();
        let mut node_types = Vec::new();

        // Own ancestors, boosted at synchronization points.
        for node in &ctx.nodes {
            if ancestors.contains(&node.id) {
                let k = self.project_key(&node.embedding);
                let mut score = dot(&q, &k) / scale;
                if common_ancestors.contains(&node.id) {
                    score += self.common_ancestor_boost;
                }
                scores.push(score);
                node_types.push(NodeType::Ancestor);
            }
        }

        // Parallel-branch nodes via the cross-attention projection,
        // down-weighted by `cross_branch_weight`.
        for &parallel_id in &parallel_nodes {
            let parallel_node = ctx.get_node(parallel_id);
            let k = self.project_cross(&parallel_node.embedding);
            let score = dot(&q, &k) / scale * self.cross_branch_weight;
            scores.push(score);
            node_types.push(NodeType::ParallelBranch);
        }

        // 5. Softmax jointly over all sources.
        let weights = softmax(&scores);

        // 6. Aggregate, routing each source type through its own value path.
        let aggregated = self.multi_source_aggregate(&weights, &node_types, ctx);

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::with_parallel_info(parallel_nodes, common_ancestors),
        }
    }

    /// Nodes at the same depth as the query that are not its ancestors —
    /// i.e. independently executing parallel branches.
    fn find_parallel_nodes(
        &self,
        query_id: NodeId,
        query_depth: usize,
        ctx: &DagContext,
    ) -> Vec<NodeId> {
        let query_ancestors = self.get_ancestors(query_id, ctx);
        ctx.nodes
            .iter()
            .filter(|n| {
                n.id != query_id
                    && ctx.depths[&n.id] == query_depth
                    && !query_ancestors.contains(&n.id) // Not an ancestor
            })
            .map(|n| n.id)
            .collect()
    }

    /// Ancestors shared by the query and *every* parallel branch — these are
    /// the synchronization points where the branches fork.
    fn find_common_ancestors(
        &self,
        query_id: NodeId,
        parallel_ids: &[NodeId],
        ctx: &DagContext,
    ) -> HashSet<NodeId> {
        if parallel_ids.is_empty() {
            return HashSet::new();
        }

        // Start with the query's ancestors, then intersect with each branch's.
        let mut common = self.get_ancestors(query_id, ctx);
        for &parallel_id in parallel_ids {
            let parallel_ancestors = self.get_ancestors(parallel_id, ctx);
            common = common.intersection(&parallel_ancestors).cloned().collect();
        }
        common
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `cross_branch_weight` | 0.5 | Weight for cross-branch attention |
| `common_ancestor_boost` | 2.0 | Boost for synchronization points |
SQL Interface
SELECT ruvector_attention_parallel_branch(
query_embedding,
ancestor_embeddings,
parallel_embeddings, -- VECTOR[]: parallel branch embeddings
common_ancestor_mask, -- BOOLEAN[]: common ancestor indicators
0.5, -- cross_weight
2.0 -- common_ancestor_boost
) AS attention_weights;
7. Temporal BTSP Attention
Purpose
Combines DAG structure with temporal spike patterns. Uses BTSP (Behavioral Timescale Synaptic Plasticity) for one-shot learning of time-correlated query patterns.
Algorithm
/// Attention combining DAG structure with temporal spike patterns; BTSP
/// (Behavioral Timescale Synaptic Plasticity) provides one-shot recall of
/// time-correlated query patterns.
pub struct TemporalBTSPAttention {
    hidden_dim: usize,
    num_heads: usize,
    /// Coincidence window for temporal grouping (milliseconds).
    coincidence_window_ms: f32,
    /// Multiplicative boost for temporally coincident nodes.
    coincidence_boost: f32,
    /// BTSP memory for one-shot pattern recall.
    btsp_memory: BTSPLayer,
    /// Projection weights.
    w_query: Array2<f32>,
    w_key: Array2<f32>,
    w_value: Array2<f32>,
}
impl TemporalBTSPAttention {
    /// Attention over ancestors with a multiplicative boost for nodes whose
    /// timestamps fall within the coincidence window of the query, plus
    /// one-shot BTSP recall/learning of confident patterns.
    ///
    /// # Panics
    /// Panics if `ctx.timestamps` is `None` — temporal attention requires them.
    pub fn forward(
        &self,
        query_node: &DagNode,
        ctx: &DagContext,
    ) -> AttentionOutput {
        let timestamps = ctx.timestamps.as_ref()
            .expect("Temporal attention requires timestamps");
        let query_time = timestamps[&query_node.id];

        // 1. Fast path: one-shot BTSP recall for previously seen patterns.
        if let Some(recalled) = self.btsp_memory.recall(&query_node.embedding) {
            return AttentionOutput {
                weights: recalled.weights,
                aggregated: recalled.aggregated,
                metadata: AttentionMetadata::btsp_recalled(),
            };
        }

        // 2. Nodes whose timestamp falls within the coincidence window.
        let coincident: HashSet<NodeId> = ctx.nodes.iter()
            .filter(|n| {
                let t = timestamps[&n.id];
                (query_time - t).abs() < self.coincidence_window_ms as f64
            })
            .map(|n| n.id)
            .collect();

        // 3. Ancestor-masked scores with a temporal boost.
        let ancestors = self.get_ancestors(query_node.id, ctx);
        let q = self.project_query(&query_node.embedding);
        let scale = (self.hidden_dim as f32).sqrt();
        let mut scores = Vec::with_capacity(ctx.nodes.len());
        for node in &ctx.nodes {
            if ancestors.contains(&node.id) {
                let k = self.project_key(&node.embedding);
                let mut score = dot(&q, &k) / scale;
                // Multiplicative boost for temporally coincident nodes.
                if coincident.contains(&node.id) {
                    score *= self.coincidence_boost;
                }
                scores.push(score);
            } else {
                scores.push(f32::NEG_INFINITY);
            }
        }

        let weights = softmax(&scores);
        let aggregated = self.aggregate_values(&weights, ctx);

        // 4. One-shot learning via a BTSP plateau when the pattern is confident.
        // NOTE(review): `associate` is called through `&self`, so `BTSPLayer`
        // must use interior mutability — confirm against its implementation.
        if self.should_learn(&weights) {
            self.btsp_memory.associate(
                &query_node.embedding,
                &weights,
                &aggregated,
            );
        }

        AttentionOutput {
            weights,
            aggregated,
            metadata: AttentionMetadata::with_temporal_info(coincident),
        }
    }

    /// Learn only when the attention distribution is confident, i.e. its
    /// normalized entropy is below 0.5.
    fn should_learn(&self, weights: &[f32]) -> bool {
        let entropy: f32 = weights.iter()
            .filter(|&&w| w > 1e-8)
            .map(|&w| -w * w.ln())
            .sum();
        let max_entropy = (weights.len() as f32).ln();
        // Guard: with <= 1 weight, max entropy is ln(1) = 0 and the original
        // division produced NaN (silently treated as "not confident"). A
        // single-candidate distribution is maximally confident, so learn.
        if max_entropy <= 0.0 {
            return true;
        }
        entropy / max_entropy < 0.5 // Confident attention pattern
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `coincidence_window_ms` | 50.0 | Temporal coincidence window |
| `coincidence_boost` | 1.5 | Boost for coincident nodes |
BTSP Memory Structure
/// One-shot associative memory mapping query patterns to learned attention.
pub struct BTSPLayer {
    /// Synaptic associations: pattern → attention output.
    synapses: Vec<BTSPSynapse>,
    /// Number of stored patterns.
    pattern_count: usize,
    /// Maximum patterns before consolidation is triggered.
    max_patterns: usize,
    /// Similarity threshold a stored pattern must meet for recall.
    recall_threshold: f32,
}
/// A single stored pattern→attention association in the BTSP memory.
pub struct BTSPSynapse {
    /// Input pattern embedding (the recall key).
    pattern: Vec<f32>,
    /// Learned attention weights for this pattern.
    attention_weights: Vec<f32>,
    /// Learned aggregated output for this pattern.
    aggregated: Vec<f32>,
    /// Confidence score of the stored association.
    confidence: f32,
    /// Number of times this synapse has been recalled.
    usage_count: usize,
}
SQL Interface
SELECT ruvector_attention_temporal_btsp(
query_embedding,
ancestor_embeddings,
ancestor_timestamps, -- FLOAT[]: timestamps in ms
50.0, -- coincidence_window_ms
1.5 -- coincidence_boost
) AS attention_weights;
Ensemble Attention
Purpose
Combines multiple attention types for robust performance.
Algorithm
/// Weighted combination of multiple attention mechanisms for robustness.
pub struct EnsembleAttention {
    /// Component attention mechanisms (trait objects, heterogeneous).
    components: Vec<Box<dyn DagAttention>>,
    /// Combination weights, one per component (learned or fixed).
    weights: Vec<f32>,
    /// How combination weights are chosen per query.
    weight_mode: WeightMode,
}
/// Strategy for choosing ensemble combination weights.
pub enum WeightMode {
    /// Fixed weights supplied at construction.
    Fixed,
    /// Weights learned via gradient descent (stored in `EnsembleAttention::weights`).
    Learned,
    /// Weights chosen per query by an adaptive selector.
    Adaptive(AdaptiveWeightSelector),
}
impl EnsembleAttention {
    /// Run every component mechanism and combine their outputs as a weighted
    /// sum, renormalizing the combined attention weights to sum to 1.
    // NOTE(review): components are invoked as `forward(query_node, ctx)`, but
    // the `DagAttention` trait's `forward` also takes a config and returns a
    // `Result` — confirm which signature is authoritative.
    pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput {
        // Choose combination weights for this query.
        let weights = match &self.weight_mode {
            WeightMode::Fixed => self.weights.clone(),
            WeightMode::Learned => self.weights.clone(),
            WeightMode::Adaptive(selector) => {
                selector.select_weights(&query_node.embedding)
            }
        };

        // Run each component mechanism.
        let outputs: Vec<AttentionOutput> = self.components.iter()
            .map(|c| c.forward(query_node, ctx))
            .collect();

        // Guard: with no components there is nothing to combine — the
        // original indexed `outputs[0]` and panicked.
        if outputs.is_empty() {
            return AttentionOutput {
                weights: Vec::new(),
                aggregated: Vec::new(),
                metadata: AttentionMetadata::ensemble(Vec::new()),
            };
        }

        // Weighted sum of per-component weights and aggregations.
        let n = outputs[0].weights.len();
        let mut combined_weights = vec![0.0; n];
        let mut combined_aggregated = vec![0.0; outputs[0].aggregated.len()];
        for (i, output) in outputs.iter().enumerate() {
            let w = weights[i];
            for (j, &ow) in output.weights.iter().enumerate() {
                combined_weights[j] += w * ow;
            }
            for (j, &oa) in output.aggregated.iter().enumerate() {
                combined_aggregated[j] += w * oa;
            }
        }

        // Renormalize so the combined weights sum to 1.
        let sum: f32 = combined_weights.iter().sum();
        if sum > 0.0 {
            for w in &mut combined_weights {
                *w /= sum;
            }
        }

        AttentionOutput {
            weights: combined_weights,
            aggregated: combined_aggregated,
            metadata: AttentionMetadata::ensemble(
                self.components.iter().map(|c| c.attention_type()).collect()
            ),
        }
    }
}
SQL Interface
SELECT ruvector_attention_ensemble(
query_embedding,
ancestor_embeddings,
ARRAY['topological', 'critical_path', 'mincut_gated'], -- types
ARRAY[0.4, 0.3, 0.3]::FLOAT[] -- weights (optional)
) AS attention_weights;
Attention Selector (UCB Bandit)
Purpose
Automatically selects the best attention type for each query pattern.
Algorithm
/// UCB bandit that picks the best attention type per query pattern.
pub struct AttentionSelector {
    /// Performance history keyed by (pattern_type, attention_type).
    history: DashMap<(PatternTypeId, DagAttentionType), PerformanceStats>,
    /// UCB exploration coefficient (typically √2).
    ucb_c: f32,
    /// Epsilon-greedy random exploration probability.
    epsilon: f32,
    /// Classifier mapping query embeddings to pattern types.
    pattern_classifier: PatternClassifier,
}
impl AttentionSelector {
    /// Pick an attention type for the query via epsilon-greedy + UCB1 over
    /// the per-pattern performance history.
    pub fn select(&self, query_embedding: &[f32]) -> DagAttentionType {
        // Classify the query into a pattern type.
        let pattern_type = self.pattern_classifier.classify(query_embedding);

        // Epsilon-greedy random exploration.
        if rand::random::<f32>() < self.epsilon {
            return DagAttentionType::random();
        }

        // Total trials for this pattern across all attention types.
        let total_trials: usize = self.history.iter()
            .filter(|e| e.key().0 == pattern_type)
            .map(|e| e.value().trials)
            .sum();

        let mut best_ucb = f32::NEG_INFINITY;
        let mut best_type = DagAttentionType::Topological;
        for attention_type in DagAttentionType::all() {
            let key = (pattern_type, attention_type.clone());
            let (mean_reward, trials) = self.history.get(&key)
                .map(|s| (s.mean_reward(), s.trials))
                .unwrap_or((0.5, 1)); // Optimistic initialization

            // UCB1 bonus. The `.max(1)` guards fix the cold-start bug in the
            // original: with no history, `ln(0) = -inf` and `sqrt(-inf)` is
            // NaN, so every comparison failed and the default type always won.
            let exploration = self.ucb_c
                * ((total_trials.max(1) as f32).ln() / trials.max(1) as f32).sqrt();
            let ucb = mean_reward + exploration;

            if ucb > best_ucb {
                best_ucb = ucb;
                best_type = attention_type;
            }
        }
        best_type
    }

    /// Record an observed reward for (pattern of `query_embedding`,
    /// `attention_type`), creating the stats entry on first use.
    pub fn update(
        &self,
        query_embedding: &[f32],
        attention_type: DagAttentionType,
        reward: f32,
    ) {
        let pattern_type = self.pattern_classifier.classify(query_embedding);
        let key = (pattern_type, attention_type);
        self.history.entry(key)
            .or_insert_with(PerformanceStats::new)
            .record(reward);
    }
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| `ucb_c` | 1.414 | UCB exploration coefficient (√2) |
| `epsilon` | 0.1 | Random exploration probability |
Performance Summary
| Attention Type | Time Complexity | Space Complexity | Best For |
|---|---|---|---|
| Topological | O(n·k) | O(n) | Causal consistency |
| Causal Cone | O(n·d) | O(n) | Distance-weighted |
| Critical Path | O(n + crit_len) | O(n) | Bottleneck focus |
| MinCut Gated | O(n^0.12 + n·k) | O(n) | Bottleneck detection |
| Hierarchical Lorentz | O(n·d) | O(n·d) | Deep hierarchies |
| Parallel Branch | O(n·b) | O(n) | Parallel execution |
| Temporal BTSP | O(n·w) | O(patterns) | Repeated queries |
| Ensemble | O(e·n·k) | O(e·n) | Robust performance |
Where:
- n = number of nodes
- k = average ancestors
- d = DAG depth
- b = parallel branches
- w = coincidence window
- e = ensemble components