# DAG Attention Mechanisms ## Overview This document specifies seven specialized attention mechanisms designed for Directed Acyclic Graph (DAG) structures. Each mechanism leverages unique DAG properties for specific optimization scenarios. ## Attention Trait Definition ```rust /// Core trait for all DAG attention mechanisms pub trait DagAttention: Send + Sync { /// Compute attention weights for a query node over its context fn forward( &self, query: &DagNode, context: &DagContext, config: &AttentionConfig, ) -> Result; /// Update internal state based on execution feedback fn update(&mut self, feedback: &AttentionFeedback) -> Result<(), AttentionError>; /// Get attention type identifier fn attention_type(&self) -> DagAttentionType; /// Estimated computation complexity fn complexity(&self, context_size: usize) -> usize; } /// Output from attention computation pub struct AttentionOutput { /// Attention weights (sum to 1.0) pub weights: Vec, /// Weighted aggregation of context values pub aggregated: Vec, /// Auxiliary information for learning pub metadata: AttentionMetadata, } /// Context for DAG attention pub struct DagContext { /// All nodes in the DAG pub nodes: Vec, /// Adjacency list (node_id -> children) pub edges: HashMap>, /// Reverse adjacency (node_id -> parents) pub reverse_edges: HashMap>, /// Node depths (topological distance from roots) pub depths: HashMap, /// Optional: timestamps for temporal attention pub timestamps: Option>, /// Optional: min-cut criticality scores pub criticalities: Option>, } ``` --- ## 1. Topological Attention ### Purpose Respects DAG ordering by allowing nodes to only attend to their ancestors. This maintains causal consistency in query plans. ### Algorithm ```rust pub struct TopologicalAttention { /// Hidden dimension for projections hidden_dim: usize, /// Number of attention heads num_heads: usize, /// Query/Key/Value projection weights w_query: Array2, // [hidden_dim, hidden_dim] w_key: Array2, w_value: Array2, /// Precomputed ancestor masks (lazily computed) ancestor_cache: DashMap, } impl TopologicalAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { // 1. Get or compute ancestor mask let ancestors = self.get_ancestors(query_node.id, ctx); // 2. Project query let q = self.project_query(&query_node.embedding); // 3. Compute attention scores (only for ancestors) let mut scores = Vec::with_capacity(ctx.nodes.len()); let scale = (self.hidden_dim as f32).sqrt(); for (i, node) in ctx.nodes.iter().enumerate() { if ancestors.get(i).unwrap_or(false) { let k = self.project_key(&node.embedding); let score = dot(&q, &k) / scale; scores.push(score); } else { scores.push(f32::NEG_INFINITY); // Mask non-ancestors } } // 4. Softmax let weights = softmax(&scores); // 5. Weighted aggregation of values let mut aggregated = vec![0.0; self.hidden_dim]; for (i, node) in ctx.nodes.iter().enumerate() { if weights[i] > 1e-8 { let v = self.project_value(&node.embedding); for (j, val) in v.iter().enumerate() { aggregated[j] += weights[i] * val; } } } AttentionOutput { weights, aggregated, metadata: AttentionMetadata::new(DagAttentionType::Topological), } } fn get_ancestors(&self, node_id: NodeId, ctx: &DagContext) -> BitVec { if let Some(cached) = self.ancestor_cache.get(&node_id) { return cached.clone(); } // BFS to find all ancestors let mut ancestors = BitVec::repeat(false, ctx.nodes.len()); let mut queue = VecDeque::new(); // Start with direct parents if let Some(parents) = ctx.reverse_edges.get(&node_id) { for &parent in parents { queue.push_back(parent); } } while let Some(current) = queue.pop_front() { let idx = ctx.node_index(current); if !ancestors.get(idx).unwrap_or(false) { ancestors.set(idx, true); if let Some(parents) = ctx.reverse_edges.get(¤t) { for &parent in parents { queue.push_back(parent); } } } } self.ancestor_cache.insert(node_id, ancestors.clone()); ancestors } } ``` ### Complexity - Time: O(n·k) where n = nodes, k = average ancestors - Space: O(n) for ancestor mask ### SQL Interface ```sql -- Compute topological attention SELECT ruvector_attention_topological( query_embedding, -- VECTOR: query node embedding ancestor_embeddings, -- VECTOR[]: ancestor embeddings '{"num_heads": 8, "hidden_dim": 256}'::jsonb -- config ) AS attention_weights; ``` --- ## 2. Causal Cone Attention ### Purpose Extends topological attention with distance-weighted decay. Closer ancestors get more attention. ### Algorithm ```rust pub struct CausalConeAttention { hidden_dim: usize, num_heads: usize, /// Decay rate per DAG hop decay_rate: f32, /// Maximum lookback depth max_depth: usize, /// Projection weights w_query: Array2, w_key: Array2, w_value: Array2, } impl CausalConeAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { let query_depth = ctx.depths[&query_node.id]; // 1. Compute distances from query to all ancestors let distances = self.compute_distances(query_node.id, ctx); // 2. Project query let q = self.project_query(&query_node.embedding); let scale = (self.hidden_dim as f32).sqrt(); // 3. Compute distance-weighted attention let mut scores = Vec::with_capacity(ctx.nodes.len()); for (i, node) in ctx.nodes.iter().enumerate() { if let Some(&dist) = distances.get(&node.id) { if dist > 0 && dist <= self.max_depth { let k = self.project_key(&node.embedding); let base_score = dot(&q, &k) / scale; // Exponential decay with distance let decay = (-self.decay_rate * dist as f32).exp(); scores.push(base_score * decay); } else { scores.push(f32::NEG_INFINITY); } } else { scores.push(f32::NEG_INFINITY); } } // 4. Softmax and aggregate let weights = softmax(&scores); let aggregated = self.aggregate_values(&weights, ctx); AttentionOutput { weights, aggregated, metadata: AttentionMetadata::new(DagAttentionType::CausalCone), } } fn compute_distances(&self, from: NodeId, ctx: &DagContext) -> HashMap { let mut distances = HashMap::new(); let mut queue = VecDeque::new(); // BFS from query node going backward if let Some(parents) = ctx.reverse_edges.get(&from) { for &parent in parents { queue.push_back((parent, 1)); } } while let Some((current, dist)) = queue.pop_front() { if dist > self.max_depth { continue; } if !distances.contains_key(¤t) { distances.insert(current, dist); if let Some(parents) = ctx.reverse_edges.get(¤t) { for &parent in parents { queue.push_back((parent, dist + 1)); } } } } distances } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `decay_rate` | 0.1 | Exponential decay per hop | | `max_depth` | 10 | Maximum ancestor distance | ### SQL Interface ```sql SELECT ruvector_attention_causal_cone( query_embedding, ancestor_embeddings, ancestor_distances, -- INT[]: hop distances 0.1, -- decay_rate 10 -- max_depth ) AS attention_weights; ``` --- ## 3. Critical Path Attention ### Purpose Focuses attention on DAG critical path (longest path). Useful for identifying and optimizing bottleneck operators. ### Algorithm ```rust pub struct CriticalPathAttention { hidden_dim: usize, num_heads: usize, /// Boost factor for critical path nodes critical_path_boost: f32, /// Projection weights w_query: Array2, w_key: Array2, w_value: Array2, /// Cached critical paths critical_path_cache: DashMap>, } impl CriticalPathAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { // 1. Compute critical path through query node let critical_path = self.find_critical_path(query_node.id, ctx); // 2. Project query let q = self.project_query(&query_node.embedding); let scale = (self.hidden_dim as f32).sqrt(); // 3. Compute attention with critical path boost let ancestors = self.get_ancestors(query_node.id, ctx); let mut scores = Vec::with_capacity(ctx.nodes.len()); for (i, node) in ctx.nodes.iter().enumerate() { if ancestors.contains(&node.id) { let k = self.project_key(&node.embedding); let mut score = dot(&q, &k) / scale; // Boost critical path nodes if critical_path.contains(&node.id) { score += self.critical_path_boost; } scores.push(score); } else { scores.push(f32::NEG_INFINITY); } } let weights = softmax(&scores); let aggregated = self.aggregate_values(&weights, ctx); AttentionOutput { weights, aggregated, metadata: AttentionMetadata::with_critical_path(critical_path), } } fn find_critical_path(&self, through: NodeId, ctx: &DagContext) -> HashSet { if let Some(cached) = self.critical_path_cache.get(&through) { return cached.clone(); } // Find longest path through this node using DP let mut longest_to = HashMap::new(); // longest path ending at node let mut longest_from = HashMap::new(); // longest path starting from node // Topological order let topo_order = self.topological_sort(ctx); // Forward pass: longest path TO each node for &node in &topo_order { let mut max_len = 0; if let Some(parents) = ctx.reverse_edges.get(&node) { for &parent in parents { max_len = max_len.max(longest_to.get(&parent).unwrap_or(&0) + 1); } } longest_to.insert(node, max_len); } // Backward pass: longest path FROM each node for &node in topo_order.iter().rev() { let mut max_len = 0; if let Some(children) = ctx.edges.get(&node) { for &child in children { max_len = max_len.max(longest_from.get(&child).unwrap_or(&0) + 1); } } longest_from.insert(node, max_len); } // Find nodes on critical path through 'through' let total_through = longest_to[&through] + longest_from[&through]; let global_longest = topo_order.iter() .map(|n| longest_to[n] + longest_from[n]) .max() .unwrap_or(0); let mut critical = HashSet::new(); if total_through == global_longest { // 'through' is on a global critical path for &node in &topo_order { if longest_to[&node] + longest_from[&node] == global_longest { critical.insert(node); } } } self.critical_path_cache.insert(through, critical.clone()); critical } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `critical_path_boost` | 1.5 | Additive boost for critical nodes | ### SQL Interface ```sql SELECT ruvector_attention_critical_path( query_embedding, ancestor_embeddings, is_critical, -- BOOLEAN[]: critical path membership 1.5 -- boost factor ) AS attention_weights; ``` --- ## 4. MinCut Gated Attention ### Purpose Uses min-cut analysis to gate information flow. Focuses learning on bottleneck operators that dominate execution time. ### Algorithm ```rust pub struct MinCutGatedAttention { hidden_dim: usize, num_heads: usize, /// Threshold for full attention (criticality > threshold) gate_threshold: f32, /// Min-cut engine for criticality computation mincut_engine: Arc, /// Projection weights w_query: Array2, w_key: Array2, w_value: Array2, /// Gate projection (learns to predict criticality) w_gate: Array2, } impl MinCutGatedAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { // 1. Get or compute criticalities let criticalities = self.get_criticalities(ctx); // 2. Project query let q = self.project_query(&query_node.embedding); let scale = (self.hidden_dim as f32).sqrt(); // 3. Compute gated attention let ancestors = self.get_ancestors(query_node.id, ctx); let mut scores = Vec::with_capacity(ctx.nodes.len()); let mut gates = Vec::with_capacity(ctx.nodes.len()); for (i, node) in ctx.nodes.iter().enumerate() { if ancestors.contains(&node.id) { let k = self.project_key(&node.embedding); let base_score = dot(&q, &k) / scale; // Compute gate based on criticality let criticality = criticalities.get(&node.id).unwrap_or(&0.0); let gate = if *criticality > self.gate_threshold { 1.0 // Full attention for critical nodes } else { criticality / self.gate_threshold // Scaled attention }; scores.push(base_score); gates.push(gate); } else { scores.push(f32::NEG_INFINITY); gates.push(0.0); } } // 4. Apply gates before softmax let gated_scores: Vec = scores.iter() .zip(gates.iter()) .map(|(s, g)| if *s > f32::NEG_INFINITY { s * g } else { *s }) .collect(); let weights = softmax(&gated_scores); let aggregated = self.aggregate_values(&weights, ctx); AttentionOutput { weights, aggregated, metadata: AttentionMetadata::with_criticalities(criticalities), } } fn get_criticalities(&self, ctx: &DagContext) -> HashMap { // Use precomputed if available if let Some(ref crit) = ctx.criticalities { return crit.clone(); } // Compute via min-cut engine let mut criticalities = HashMap::new(); let global_cut = self.mincut_engine.query(); for node in &ctx.nodes { // LocalKCut query around this node let local_query = LocalKCutQuery { seed_vertices: vec![node.id], budget_k: global_cut, radius: 3, }; match self.mincut_engine.local_query(local_query) { LocalKCutResult::Found { cut_value, .. } => { let criticality = (global_cut - cut_value) as f32 / global_cut as f32; criticalities.insert(node.id, criticality); } LocalKCutResult::NoneInLocality => { criticalities.insert(node.id, 0.0); } } } criticalities } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `gate_threshold` | 0.5 | Criticality threshold for full attention | ### Complexity - Time: O(n^0.12 + n·k) - subpolynomial min-cut + attention - Space: O(n) for criticality map ### SQL Interface ```sql SELECT ruvector_attention_mincut_gated( query_embedding, ancestor_embeddings, ruvector_mincut_criticality(dag_id), -- precomputed criticalities 0.5 -- threshold ) AS attention_weights; ``` --- ## 5. Hierarchical Lorentz Attention ### Purpose Uses hyperbolic geometry (Lorentz hyperboloid) to naturally represent DAG hierarchy. Deeper nodes embed further from origin. ### Algorithm ```rust pub struct HierarchicalLorentzAttention { hidden_dim: usize, num_heads: usize, /// Curvature of hyperbolic space (negative) curvature: f32, /// Temperature for attention softmax temperature: f32, /// Lorentz model for hyperbolic operations lorentz: LorentzModel, /// Projection to hyperbolic space w_to_hyperbolic: Array2, } impl HierarchicalLorentzAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { let query_depth = ctx.depths[&query_node.id]; // 1. Map query to hyperbolic space at its depth let query_hyper = self.to_lorentz( &query_node.embedding, query_depth, ); // 2. Compute Lorentz attention scores let ancestors = self.get_ancestors(query_node.id, ctx); let mut scores = Vec::with_capacity(ctx.nodes.len()); for (i, node) in ctx.nodes.iter().enumerate() { if ancestors.contains(&node.id) { let node_depth = ctx.depths[&node.id]; let node_hyper = self.to_lorentz(&node.embedding, node_depth); // Busemann function for O(d) hierarchy scoring let score = self.lorentz.busemann_score(&query_hyper, &node_hyper); scores.push(score / self.temperature); } else { scores.push(f32::NEG_INFINITY); } } // 3. Hyperbolic softmax let weights = self.lorentz.hyperbolic_softmax(&scores); // 4. Einstein midpoint aggregation (closed-form, no iteration) let aggregated = self.einstein_midpoint_aggregation(&weights, ctx); AttentionOutput { weights, aggregated, metadata: AttentionMetadata::new(DagAttentionType::HierarchicalLorentz), } } fn to_lorentz(&self, embedding: &[f32], depth: usize) -> LorentzPoint { // Project to hyperbolic space let projected = self.w_to_hyperbolic.dot(&Array1::from_vec(embedding.to_vec())); // Map depth to hyperbolic radius let radius = (depth as f32 + 1.0).ln() / (-self.curvature).sqrt(); self.lorentz.exp_map_at_origin(&projected.to_vec(), radius) } fn einstein_midpoint_aggregation( &self, weights: &[f32], ctx: &DagContext, ) -> Vec { // Closed-form weighted centroid in Lorentz model // Much faster than iterative Fréchet mean let mut numerator = vec![0.0; self.hidden_dim + 1]; let mut denominator = 0.0; for (i, node) in ctx.nodes.iter().enumerate() { if weights[i] > 1e-8 { let depth = ctx.depths[&node.id]; let hyper = self.to_lorentz(&node.embedding, depth); // Lorentz factor let gamma = hyper.lorentz_factor(); let weight = weights[i] * gamma; for (j, &coord) in hyper.coordinates().iter().enumerate() { numerator[j] += weight * coord; } denominator += weight; } } // Normalize and project back to Euclidean let midpoint: Vec = numerator.iter() .map(|x| x / denominator) .collect(); self.lorentz.log_map_to_euclidean(&midpoint) } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `curvature` | -1.0 | Hyperbolic space curvature | | `temperature` | 1.0 | Softmax temperature | ### Benefits - 5-10x faster than Poincaré attention - Numerically stable (no tanh clipping) - Natural hierarchy representation ### SQL Interface ```sql SELECT ruvector_attention_hierarchical_lorentz( query_embedding, ancestor_embeddings, ancestor_depths, -- INT[]: DAG depths -1.0, -- curvature 1.0 -- temperature ) AS attention_weights; ``` --- ## 6. Parallel Branch Attention ### Purpose Coordinates across parallel DAG branches. Essential for parallel query execution where branches need to share information. ### Algorithm ```rust pub struct ParallelBranchAttention { hidden_dim: usize, num_heads: usize, /// Weight for cross-branch attention cross_branch_weight: f32, /// Boost for common ancestors (synchronization points) common_ancestor_boost: f32, /// Projection weights w_query: Array2, w_key: Array2, w_value: Array2, /// Cross-attention projection w_cross: Array2, } impl ParallelBranchAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { let query_depth = ctx.depths[&query_node.id]; // 1. Find parallel branches (same depth, different lineage) let parallel_nodes = self.find_parallel_nodes(query_node.id, query_depth, ctx); // 2. Find common ancestors (synchronization points) let common_ancestors = self.find_common_ancestors( query_node.id, ¶llel_nodes, ctx, ); // 3. Project query let q = self.project_query(&query_node.embedding); let scale = (self.hidden_dim as f32).sqrt(); // 4. Compute multi-source attention let ancestors = self.get_ancestors(query_node.id, ctx); let mut scores = Vec::new(); let mut node_types = Vec::new(); // Track source type // Own ancestors for node in &ctx.nodes { if ancestors.contains(&node.id) { let k = self.project_key(&node.embedding); let mut score = dot(&q, &k) / scale; // Boost common ancestors if common_ancestors.contains(&node.id) { score += self.common_ancestor_boost; } scores.push(score); node_types.push(NodeType::Ancestor); } } // Parallel branch nodes (cross-attention) for ¶llel_id in ¶llel_nodes { let parallel_node = ctx.get_node(parallel_id); let k = self.project_cross(¶llel_node.embedding); let score = dot(&q, &k) / scale * self.cross_branch_weight; scores.push(score); node_types.push(NodeType::ParallelBranch); } // 5. Softmax over all sources let weights = softmax(&scores); // 6. Aggregation (separate paths for different types) let aggregated = self.multi_source_aggregate(&weights, &node_types, ctx); AttentionOutput { weights, aggregated, metadata: AttentionMetadata::with_parallel_info(parallel_nodes, common_ancestors), } } fn find_parallel_nodes( &self, query_id: NodeId, query_depth: usize, ctx: &DagContext, ) -> Vec { let query_ancestors = self.get_ancestors(query_id, ctx); ctx.nodes.iter() .filter(|n| { n.id != query_id && ctx.depths[&n.id] == query_depth && !query_ancestors.contains(&n.id) // Not an ancestor }) .map(|n| n.id) .collect() } fn find_common_ancestors( &self, query_id: NodeId, parallel_ids: &[NodeId], ctx: &DagContext, ) -> HashSet { if parallel_ids.is_empty() { return HashSet::new(); } // Start with query's ancestors let mut common = self.get_ancestors(query_id, ctx); // Intersect with each parallel node's ancestors for ¶llel_id in parallel_ids { let parallel_ancestors = self.get_ancestors(parallel_id, ctx); common = common.intersection(¶llel_ancestors).cloned().collect(); } common } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `cross_branch_weight` | 0.5 | Weight for cross-branch attention | | `common_ancestor_boost` | 2.0 | Boost for synchronization points | ### SQL Interface ```sql SELECT ruvector_attention_parallel_branch( query_embedding, ancestor_embeddings, parallel_embeddings, -- VECTOR[]: parallel branch embeddings common_ancestor_mask, -- BOOLEAN[]: common ancestor indicators 0.5, -- cross_weight 2.0 -- common_ancestor_boost ) AS attention_weights; ``` --- ## 7. Temporal BTSP Attention ### Purpose Combines DAG structure with temporal spike patterns. Uses BTSP (Behavioral Timescale Synaptic Plasticity) for one-shot learning of time-correlated query patterns. ### Algorithm ```rust pub struct TemporalBTSPAttention { hidden_dim: usize, num_heads: usize, /// Coincidence window for temporal grouping (ms) coincidence_window_ms: f32, /// Boost for temporally coincident nodes coincidence_boost: f32, /// BTSP memory for one-shot pattern recall btsp_memory: BTSPLayer, /// Projection weights w_query: Array2, w_key: Array2, w_value: Array2, } impl TemporalBTSPAttention { pub fn forward( &self, query_node: &DagNode, ctx: &DagContext, ) -> AttentionOutput { let timestamps = ctx.timestamps.as_ref() .expect("Temporal attention requires timestamps"); let query_time = timestamps[&query_node.id]; // 1. Try BTSP recall first (one-shot memory) if let Some(recalled) = self.btsp_memory.recall(&query_node.embedding) { return AttentionOutput { weights: recalled.weights, aggregated: recalled.aggregated, metadata: AttentionMetadata::btsp_recalled(), }; } // 2. Find temporally coincident nodes let coincident: HashSet = ctx.nodes.iter() .filter(|n| { let t = timestamps[&n.id]; (query_time - t).abs() < self.coincidence_window_ms as f64 }) .map(|n| n.id) .collect(); // 3. Compute attention with temporal boost let ancestors = self.get_ancestors(query_node.id, ctx); let q = self.project_query(&query_node.embedding); let scale = (self.hidden_dim as f32).sqrt(); let mut scores = Vec::with_capacity(ctx.nodes.len()); for node in &ctx.nodes { if ancestors.contains(&node.id) { let k = self.project_key(&node.embedding); let mut score = dot(&q, &k) / scale; // Boost temporally coincident nodes if coincident.contains(&node.id) { score *= self.coincidence_boost; } scores.push(score); } else { scores.push(f32::NEG_INFINITY); } } let weights = softmax(&scores); let aggregated = self.aggregate_values(&weights, ctx); // 4. One-shot learning via BTSP plateau if self.should_learn(&weights) { self.btsp_memory.associate( &query_node.embedding, &weights, &aggregated, ); } AttentionOutput { weights, aggregated, metadata: AttentionMetadata::with_temporal_info(coincident), } } fn should_learn(&self, weights: &[f32]) -> bool { // Learn if attention is confident (low entropy) let entropy: f32 = weights.iter() .filter(|&&w| w > 1e-8) .map(|&w| -w * w.ln()) .sum(); let max_entropy = (weights.len() as f32).ln(); let normalized_entropy = entropy / max_entropy; normalized_entropy < 0.5 // Confident attention pattern } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `coincidence_window_ms` | 50.0 | Temporal coincidence window | | `coincidence_boost` | 1.5 | Boost for coincident nodes | ### BTSP Memory Structure ```rust pub struct BTSPLayer { /// Synaptic weights: pattern → attention synapses: Vec, /// Number of stored patterns pattern_count: usize, /// Maximum patterns before consolidation max_patterns: usize, /// Similarity threshold for recall recall_threshold: f32, } pub struct BTSPSynapse { /// Input pattern embedding pattern: Vec, /// Learned attention weights attention_weights: Vec, /// Learned aggregated output aggregated: Vec, /// Confidence score confidence: f32, /// Usage count usage_count: usize, } ``` ### SQL Interface ```sql SELECT ruvector_attention_temporal_btsp( query_embedding, ancestor_embeddings, ancestor_timestamps, -- FLOAT[]: timestamps in ms 50.0, -- coincidence_window_ms 1.5 -- coincidence_boost ) AS attention_weights; ``` --- ## Ensemble Attention ### Purpose Combines multiple attention types for robust performance. ### Algorithm ```rust pub struct EnsembleAttention { /// Component attention mechanisms components: Vec>, /// Combination weights (learned or fixed) weights: Vec, /// Weight learning mode weight_mode: WeightMode, } pub enum WeightMode { /// Fixed weights Fixed, /// Learn weights via gradient descent Learned, /// Adaptive based on query pattern Adaptive(AdaptiveWeightSelector), } impl EnsembleAttention { pub fn forward(&self, query_node: &DagNode, ctx: &DagContext) -> AttentionOutput { // Get weights for this query let weights = match &self.weight_mode { WeightMode::Fixed => self.weights.clone(), WeightMode::Learned => self.weights.clone(), WeightMode::Adaptive(selector) => { selector.select_weights(&query_node.embedding) } }; // Compute attention from each component let outputs: Vec = self.components.iter() .map(|c| c.forward(query_node, ctx)) .collect(); // Weighted combination let n = outputs[0].weights.len(); let mut combined_weights = vec![0.0; n]; let mut combined_aggregated = vec![0.0; outputs[0].aggregated.len()]; for (i, output) in outputs.iter().enumerate() { let w = weights[i]; for (j, &ow) in output.weights.iter().enumerate() { combined_weights[j] += w * ow; } for (j, &oa) in output.aggregated.iter().enumerate() { combined_aggregated[j] += w * oa; } } // Renormalize weights let sum: f32 = combined_weights.iter().sum(); if sum > 0.0 { for w in &mut combined_weights { *w /= sum; } } AttentionOutput { weights: combined_weights, aggregated: combined_aggregated, metadata: AttentionMetadata::ensemble( self.components.iter().map(|c| c.attention_type()).collect() ), } } } ``` ### SQL Interface ```sql SELECT ruvector_attention_ensemble( query_embedding, ancestor_embeddings, ARRAY['topological', 'critical_path', 'mincut_gated'], -- types ARRAY[0.4, 0.3, 0.3]::FLOAT[] -- weights (optional) ) AS attention_weights; ``` --- ## Attention Selector (UCB Bandit) ### Purpose Automatically selects the best attention type for each query pattern. ### Algorithm ```rust pub struct AttentionSelector { /// Performance history per (pattern_type, attention_type) history: DashMap<(PatternTypeId, DagAttentionType), PerformanceStats>, /// UCB exploration coefficient ucb_c: f32, /// Exploration probability (epsilon-greedy) epsilon: f32, /// Pattern type classifier pattern_classifier: PatternClassifier, } impl AttentionSelector { pub fn select(&self, query_embedding: &[f32]) -> DagAttentionType { // Classify query pattern let pattern_type = self.pattern_classifier.classify(query_embedding); // Epsilon-greedy exploration if rand::random::() < self.epsilon { return DagAttentionType::random(); } // UCB selection let total_trials: usize = self.history.iter() .filter(|e| e.key().0 == pattern_type) .map(|e| e.value().trials) .sum(); let mut best_ucb = f32::NEG_INFINITY; let mut best_type = DagAttentionType::Topological; for attention_type in DagAttentionType::all() { let key = (pattern_type, attention_type.clone()); let (mean_reward, trials) = self.history.get(&key) .map(|s| (s.mean_reward(), s.trials)) .unwrap_or((0.5, 1)); // Optimistic initialization // UCB formula let exploration = self.ucb_c * ((total_trials as f32).ln() / trials as f32).sqrt(); let ucb = mean_reward + exploration; if ucb > best_ucb { best_ucb = ucb; best_type = attention_type; } } best_type } pub fn update( &self, query_embedding: &[f32], attention_type: DagAttentionType, reward: f32, ) { let pattern_type = self.pattern_classifier.classify(query_embedding); let key = (pattern_type, attention_type); self.history.entry(key) .or_insert_with(PerformanceStats::new) .record(reward); } } ``` ### Configuration | Parameter | Default | Description | |-----------|---------|-------------| | `ucb_c` | 1.414 | UCB exploration coefficient (√2) | | `epsilon` | 0.1 | Random exploration probability | --- ## Performance Summary | Attention Type | Time Complexity | Space Complexity | Best For | |----------------|-----------------|------------------|----------| | Topological | O(n·k) | O(n) | Causal consistency | | Causal Cone | O(n·d) | O(n) | Distance-weighted | | Critical Path | O(n + crit_len) | O(n) | Bottleneck focus | | MinCut Gated | O(n^0.12 + n·k) | O(n) | Bottleneck detection | | Hierarchical Lorentz | O(n·d) | O(n·d) | Deep hierarchies | | Parallel Branch | O(n·b) | O(n) | Parallel execution | | Temporal BTSP | O(n·w) | O(patterns) | Repeated queries | | Ensemble | O(e·n·k) | O(e·n) | Robust performance | Where: - n = number of nodes - k = average ancestors - d = DAG depth - b = parallel branches - w = coincidence window - e = ensemble components