Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
# Advanced Attention Mechanisms - Implementation Notes
## Agent #3 Implementation Summary
This implementation provides 5 advanced attention mechanisms for DAG query optimization:
### 1. Hierarchical Lorentz Attention (`hierarchical_lorentz.rs`)
- **Lines**: 274
- **Complexity**: O(n²·d)
- Uses hyperbolic geometry (Lorentz model) to embed DAG nodes
- Deeper nodes in hierarchy receive higher attention
- Implements Lorentz inner product and distance metrics
### 2. Parallel Branch Attention (`parallel_branch.rs`)
- **Lines**: 291
- **Complexity**: O(n² + b·n)
- Detects and coordinates parallel execution branches
- Balances workload across branches
- Applies synchronization penalties
### 3. Temporal BTSP Attention (`temporal_btsp.rs`)
- **Lines**: 291
- **Complexity**: O(n + t)
- Behavioral Timescale Synaptic Plasticity
- Uses eligibility traces for temporal learning
- Implements plateau state boosting
### 4. Attention Selector (`selector.rs`)
- **Lines**: 281
- **Complexity**: O(1) for selection
- UCB1 bandit algorithm for mechanism selection
- Tracks rewards and counts for each mechanism
- Balances exploration vs exploitation
### 5. Attention Cache (`cache.rs`)
- **Lines**: 316
- **Complexity**: O(1) for get/insert
- LRU eviction policy
- TTL support for entries
- Hash-based DAG fingerprinting
## Trait Definition
**`trait_def.rs`** (75 lines):
- Defines `DagAttentionMechanism` trait
- `AttentionScores` structure with edge weights
- `AttentionError` for error handling
## Integration
All mechanisms are exported from `mod.rs` with appropriate type aliases to avoid conflicts with existing attention system.
## Tests
Each mechanism includes comprehensive unit tests:
- Hierarchical Lorentz: 3 tests
- Parallel Branch: 2 tests
- Temporal BTSP: 3 tests
- Selector: 4 tests
- Cache: 7 tests
Total: 19 new test functions
## Performance Targets
| Mechanism | Target | Notes |
|-----------|--------|-------|
| HierarchicalLorentz | <150μs | For 100 nodes |
| ParallelBranch | <100μs | For 100 nodes |
| TemporalBTSP | <120μs | For 100 nodes |
| Selector.select() | <1μs | UCB1 computation |
| Cache.get() | <1μs | LRU lookup |
## Compatibility Notes
The implementation is designed to work with the updated QueryDag API that uses:
- HashMap-based node storage
- `estimated_cost` and `estimated_rows` instead of `cost` and `selectivity`
- `children(id)` and `parents(id)` methods
- No direct edges iterator
Some test fixtures may need updates to match the current OperatorNode structure.
## Future Work
- Add SIMD optimizations for hyperbolic distance calculations
- Implement GPU acceleration for large DAGs
- Add more sophisticated caching strategies
- Integrate with existing TopologicalAttention, CausalConeAttention, etc.

View File

@@ -0,0 +1,322 @@
//! Attention Cache: LRU cache for computed attention scores
//!
//! Caches attention scores to avoid redundant computation for identical DAGs.
//! Uses LRU eviction policy to manage memory usage.
use super::trait_def::AttentionScores;
use crate::dag::QueryDag;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};
use std::time::{Duration, Instant};
/// Configuration for [`AttentionCache`]: bounded entry count plus optional TTL.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries
    pub capacity: usize,
    /// Time-to-live for entries
    pub ttl: Option<Duration>,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
capacity: 1000,
ttl: Some(Duration::from_secs(300)), // 5 minutes
}
}
}
/// A single cached result plus the bookkeeping needed for TTL and LRU stats.
#[derive(Debug)]
struct CacheEntry {
    /// Cached attention scores (cloned out on every hit).
    scores: AttentionScores,
    /// Insertion time, compared against the configured TTL.
    timestamp: Instant,
    /// Number of hits served from this entry.
    access_count: usize,
}
/// LRU + TTL cache keyed by a structural hash of (DAG topology, mechanism name).
pub struct AttentionCache {
    /// Capacity and TTL settings.
    config: CacheConfig,
    /// Entries keyed by the structural DAG hash.
    cache: HashMap<u64, CacheEntry>,
    /// Keys ordered least- to most-recently-used (front = next eviction candidate).
    access_order: Vec<u64>,
    /// Count of successful lookups.
    hits: usize,
    /// Count of failed or expired lookups.
    misses: usize,
}
impl AttentionCache {
    /// Create an empty cache with the given capacity/TTL configuration.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            cache: HashMap::with_capacity(config.capacity),
            access_order: Vec::with_capacity(config.capacity),
            config,
            hits: 0,
            misses: 0,
        }
    }

    /// Derive the cache key from the DAG structure and mechanism name.
    ///
    /// Only topology is hashed (node count + sorted edge list), so two DAGs
    /// that differ solely in node payloads intentionally share a key.
    fn hash_dag(dag: &QueryDag, mechanism: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        // Hash mechanism name so different mechanisms never collide.
        mechanism.hash(&mut hasher);
        // Hash number of nodes.
        dag.node_count().hash(&mut hasher);
        // Hash the edge structure; sort first so the key does not depend on
        // the DAG's internal iteration order.
        let mut edge_list: Vec<(usize, usize)> = Vec::new();
        for node_id in dag.node_ids() {
            for &child in dag.children(node_id) {
                edge_list.push((node_id, child));
            }
        }
        edge_list.sort_unstable();
        for (from, to) in edge_list {
            from.hash(&mut hasher);
            to.hash(&mut hasher);
        }
        hasher.finish()
    }

    /// True when the entry has outlived the configured TTL (never, if no TTL).
    fn is_expired(&self, entry: &CacheEntry) -> bool {
        self.config
            .ttl
            .map_or(false, |ttl| entry.timestamp.elapsed() > ttl)
    }

    /// Get cached scores for a DAG and mechanism.
    ///
    /// A stale entry is evicted on access; both absence and staleness count
    /// as misses.
    pub fn get(&mut self, dag: &QueryDag, mechanism: &str) -> Option<AttentionScores> {
        let key = Self::hash_dag(dag, mechanism);
        // Absent entries are treated as "expired" so both cases fall through
        // to the single miss path below.
        let expired = self
            .cache
            .get(&key)
            .map(|entry| self.is_expired(entry))
            .unwrap_or(true);
        if expired {
            self.cache.remove(&key);
            self.access_order.retain(|&k| k != key);
            self.misses += 1;
            return None;
        }
        // Hit: move the key to the most-recently-used position.
        self.access_order.retain(|&k| k != key);
        self.access_order.push(key);
        let entry = self
            .cache
            .get_mut(&key)
            .expect("entry verified present and fresh above");
        entry.access_count += 1;
        self.hits += 1;
        Some(entry.scores.clone())
    }

    /// Insert scores into the cache, evicting least-recently-used entries if
    /// the capacity would otherwise be exceeded.
    pub fn insert(&mut self, dag: &QueryDag, mechanism: &str, scores: AttentionScores) {
        let key = Self::hash_dag(dag, mechanism);
        // Fix: drop any previous occurrence of this key first. The original
        // pushed the key unconditionally, so re-inserting an existing key left
        // duplicate entries in `access_order`; a later eviction could then
        // remove a live entry while a stale order entry survived.
        self.cache.remove(&key);
        self.access_order.retain(|&k| k != key);
        // Evict from the least-recently-used end until there is room.
        while self.cache.len() >= self.config.capacity && !self.access_order.is_empty() {
            let oldest = self.access_order.remove(0);
            self.cache.remove(&oldest);
        }
        self.cache.insert(
            key,
            CacheEntry {
                scores,
                timestamp: Instant::now(),
                access_count: 0,
            },
        );
        self.access_order.push(key);
    }

    /// Clear all entries and reset hit/miss statistics.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
        self.hits = 0;
        self.misses = 0;
    }

    /// Remove every entry whose TTL has elapsed.
    pub fn evict_expired(&mut self) {
        let expired_keys: Vec<u64> = self
            .cache
            .iter()
            .filter(|(_, entry)| self.is_expired(entry))
            .map(|(&k, _)| k)
            .collect();
        for key in expired_keys {
            self.cache.remove(&key);
            self.access_order.retain(|&k| k != key);
        }
    }

    /// Snapshot of current occupancy and hit-rate statistics.
    pub fn stats(&self) -> CacheStats {
        let lookups = self.hits + self.misses;
        CacheStats {
            size: self.cache.len(),
            capacity: self.config.capacity,
            hits: self.hits,
            misses: self.misses,
            // Avoid 0/0 before the first lookup.
            hit_rate: if lookups > 0 {
                self.hits as f64 / lookups as f64
            } else {
                0.0
            },
        }
    }

    /// Key and access count of the most frequently hit entry, if any.
    pub fn most_accessed(&self) -> Option<(&u64, usize)> {
        self.cache
            .iter()
            .max_by_key(|(_, entry)| entry.access_count)
            .map(|(k, entry)| (k, entry.access_count))
    }
}
/// Point-in-time cache statistics returned by [`AttentionCache::stats`].
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Current number of live entries.
    pub size: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
    /// Lookups served from cache.
    pub hits: usize,
    /// Lookups that found nothing (or an expired entry).
    pub misses: usize,
    /// hits / (hits + misses); 0.0 before any lookup.
    pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    /// Build a small scan-only DAG with `n` nodes and (for n > 1) one 0->1 edge;
    /// varying `n` also varies the structural hash, giving distinct cache keys.
    fn create_test_dag(n: usize) -> QueryDag {
        let mut dag = QueryDag::new();
        for i in 0..n {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = (i + 1) as f64;
            dag.add_node(node);
        }
        if n > 1 {
            let _ = dag.add_edge(0, 1);
        }
        dag
    }
    // Round-trip: inserted scores come back unchanged on a hit.
    #[test]
    fn test_cache_insert_and_get() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(3);
        let scores = AttentionScores::new(vec![0.5, 0.3, 0.2]);
        let expected_scores = scores.scores.clone();
        cache.insert(&dag, "test_mechanism", scores);
        let retrieved = cache.get(&dag, "test_mechanism").unwrap();
        assert_eq!(retrieved.scores, expected_scores);
    }
    // An unknown mechanism name must miss.
    #[test]
    fn test_cache_miss() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(3);
        let result = cache.get(&dag, "nonexistent");
        assert!(result.is_none());
    }
    // With capacity 2, the third insert must evict the least-recently-used key.
    #[test]
    fn test_lru_eviction() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 2,
            ttl: None,
        });
        let dag1 = create_test_dag(1);
        let dag2 = create_test_dag(2);
        let dag3 = create_test_dag(3);
        cache.insert(&dag1, "mech", AttentionScores::new(vec![0.5]));
        cache.insert(&dag2, "mech", AttentionScores::new(vec![0.3, 0.7]));
        cache.insert(&dag3, "mech", AttentionScores::new(vec![0.2, 0.3, 0.5]));
        // dag1 should be evicted (LRU), dag2 and dag3 should still be present
        let result1 = cache.get(&dag1, "mech");
        let result2 = cache.get(&dag2, "mech");
        let result3 = cache.get(&dag3, "mech");
        assert!(result1.is_none());
        assert!(result2.is_some());
        assert!(result3.is_some());
    }
    // One hit plus one miss should give a 0.5 hit rate.
    #[test]
    fn test_cache_stats() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(2);
        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));
        cache.get(&dag, "mech"); // hit
        cache.get(&dag, "nonexistent"); // miss
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 0.01);
    }
    // An entry must survive within its TTL and vanish after it elapses.
    #[test]
    fn test_ttl_expiration() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 100,
            ttl: Some(Duration::from_millis(50)),
        });
        let dag = create_test_dag(2);
        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));
        // Should be present immediately
        assert!(cache.get(&dag, "mech").is_some());
        // Wait for expiration
        std::thread::sleep(Duration::from_millis(60));
        // Should be expired
        assert!(cache.get(&dag, "mech").is_none());
    }
    // Hashing the same DAG twice must be deterministic.
    #[test]
    fn test_hash_consistency() {
        let dag = create_test_dag(3);
        let hash1 = AttentionCache::hash_dag(&dag, "mechanism");
        let hash2 = AttentionCache::hash_dag(&dag, "mechanism");
        assert_eq!(hash1, hash2);
    }
    // The mechanism name must be part of the key.
    #[test]
    fn test_hash_different_mechanisms() {
        let dag = create_test_dag(3);
        let hash1 = AttentionCache::hash_dag(&dag, "mechanism1");
        let hash2 = AttentionCache::hash_dag(&dag, "mechanism2");
        assert_ne!(hash1, hash2);
    }
}

View File

@@ -0,0 +1,127 @@
//! Causal Cone Attention: Focuses on ancestors with temporal discount
use super::{AttentionError, AttentionScores, DagAttention};
use crate::dag::QueryDag;
use std::collections::HashMap;
/// Configuration for [`CausalConeAttention`].
#[derive(Debug, Clone)]
pub struct CausalConeConfig {
    /// Time window in milliseconds. NOTE(review): not read by `forward` or
    /// `update` in this file — confirm intended use before relying on it.
    pub time_window_ms: u64,
    /// Per-depth geometric discount applied to node scores.
    pub future_discount: f32,
    /// Weight contributed by each ancestor to a node's causal-influence score.
    pub ancestor_weight: f32,
}
impl Default for CausalConeConfig {
fn default() -> Self {
Self {
time_window_ms: 1000,
future_discount: 0.8,
ancestor_weight: 0.9,
}
}
}
/// Attention mechanism that emphasizes nodes with large causal (ancestor) cones,
/// discounted by their depth in the DAG.
pub struct CausalConeAttention {
    config: CausalConeConfig,
}
impl CausalConeAttention {
pub fn new(config: CausalConeConfig) -> Self {
Self { config }
}
pub fn with_defaults() -> Self {
Self::new(CausalConeConfig::default())
}
}
impl DagAttention for CausalConeAttention {
    /// Score each node by its causal influence (ancestor count), geometrically
    /// discounted by depth, then normalize the scores to sum to 1.
    ///
    /// Returns `AttentionError::EmptyDag` for an empty DAG.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.node_count() == 0 {
            return Err(AttentionError::EmptyDag);
        }
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        // Fix: depths were recomputed inside the per-node loop even though
        // they never change — hoist the loop-invariant call so the DAG's
        // depth map is computed exactly once.
        let depths = dag.compute_depths();
        for node_id in 0..dag.node_count() {
            if dag.get_node(node_id).is_none() {
                continue;
            }
            // Causal influence: more ancestors means more upstream work feeds
            // into this node.
            let ancestor_count = dag.ancestors(node_id).len();
            let mut score = 1.0 + (ancestor_count as f32 * self.config.ancestor_weight);
            // Temporal decay: deeper nodes are discounted geometrically.
            if let Some(&depth) = depths.get(&node_id) {
                score *= self.config.future_discount.powi(depth as i32);
            }
            scores.insert(node_id, score);
            total += score;
        }
        // Normalize into a probability distribution.
        if total > 0.0 {
            for score in scores.values_mut() {
                *score /= total;
            }
        }
        Ok(scores)
    }
    fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {
        // Could update temporal discount based on actual execution times.
        // For now, static configuration.
    }
    fn name(&self) -> &'static str {
        "causal_cone"
    }
    fn complexity(&self) -> &'static str {
        "O(n^2)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    // Diamond-ish DAG: two scans joined, then projected. Scores must form a
    // normalized distribution over all four nodes.
    #[test]
    fn test_causal_cone_attention() {
        let mut dag = QueryDag::new();
        // Create a DAG with multiple paths
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let id2 = dag.add_node(OperatorNode::hash_join(0, "id"));
        let id3 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));
        dag.add_edge(id0, id2).unwrap();
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        let attention = CausalConeAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();
        // Check normalization
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        // All scores should be in [0, 1]
        for &score in scores.values() {
            assert!(score >= 0.0 && score <= 1.0);
        }
    }
}

View File

@@ -0,0 +1,198 @@
//! Critical Path Attention: Focuses on bottleneck nodes
use super::{AttentionError, AttentionScores, DagAttention};
use crate::dag::QueryDag;
use std::collections::HashMap;
/// Configuration for [`CriticalPathAttention`].
#[derive(Debug, Clone)]
pub struct CriticalPathConfig {
    /// Attention multiplier for nodes on the critical (most expensive) path.
    pub path_weight: f32,
    /// Extra weight per additional child, applied to branching nodes.
    pub branch_penalty: f32,
}
impl Default for CriticalPathConfig {
fn default() -> Self {
Self {
path_weight: 2.0,
branch_penalty: 0.5,
}
}
}
/// Attention mechanism that boosts nodes along the DAG's most expensive path.
pub struct CriticalPathAttention {
    config: CriticalPathConfig,
    /// Most recently computed critical path (refreshed by `update`).
    critical_path: Vec<usize>,
}
impl CriticalPathAttention {
    /// Build from an explicit configuration; the cached critical path starts
    /// empty and is filled in by `update`.
    pub fn new(config: CriticalPathConfig) -> Self {
        Self {
            config,
            critical_path: Vec::new(),
        }
    }
    /// Build with `CriticalPathConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(CriticalPathConfig::default())
    }
    /// Compute the critical path (longest path by cost)
    ///
    /// Dynamic program: for each node, the most expensive node-to-leaf path is
    /// the node's own cost plus the best child path. Iterating the topological
    /// order in reverse guarantees children are finalized before parents.
    fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
        // node -> (total cost of the best path starting here, the path itself)
        let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
        // Initialize leaves
        for &leaf in &dag.leaves() {
            if let Some(node) = dag.get_node(leaf) {
                longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
            }
        }
        // Process nodes in reverse topological order
        if let Ok(topo_order) = dag.topological_sort() {
            for &node_id in topo_order.iter().rev() {
                let node = match dag.get_node(node_id) {
                    Some(n) => n,
                    None => continue,
                };
                // Start from "this node alone" in case no child entry exists.
                let mut max_cost = node.estimated_cost;
                let mut max_path = vec![node_id];
                // Check all children
                for &child in dag.children(node_id) {
                    if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
                        let total_cost = node.estimated_cost + child_cost;
                        if total_cost > max_cost {
                            max_cost = total_cost;
                            max_path = vec![node_id];
                            max_path.extend(child_path);
                        }
                    }
                }
                longest_path.insert(node_id, (max_cost, max_path));
            }
        }
        // Find the path with maximum cost
        longest_path
            .into_iter()
            .max_by(|a, b| {
                a.1 .0
                    .partial_cmp(&b.1 .0)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(_, (_, path))| path)
            .unwrap_or_default()
    }
}
impl DagAttention for CriticalPathAttention {
    /// Score nodes by critical-path membership and branching factor, then
    /// normalize to a distribution summing to 1.
    ///
    /// Returns `AttentionError::EmptyDag` for an empty DAG.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.node_count() == 0 {
            return Err(AttentionError::EmptyDag);
        }
        // Fix: collect the path into a HashSet so the per-node membership test
        // is O(1). The previous `Vec::contains` made forward O(n^2), at odds
        // with the advertised "O(n + e)" complexity.
        let critical: std::collections::HashSet<usize> =
            self.compute_critical_path(dag).into_iter().collect();
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        for node_id in 0..dag.node_count() {
            if dag.get_node(node_id).is_none() {
                continue;
            }
            // Critical-path nodes get the configured boost; others baseline 1.0.
            let mut score = if critical.contains(&node_id) {
                self.config.path_weight
            } else {
                1.0
            };
            // Branching nodes are potential bottlenecks: add weight per extra child.
            let num_children = dag.children(node_id).len();
            if num_children > 1 {
                score *= 1.0 + (num_children as f32 - 1.0) * self.config.branch_penalty;
            }
            scores.insert(node_id, score);
            total += score;
        }
        // Normalize into a probability distribution.
        if total > 0.0 {
            for score in scores.values_mut() {
                *score /= total;
            }
        }
        Ok(scores)
    }
    /// Refresh the cached critical path and adapt `path_weight` when observed
    /// execution times show high variance (max/avg > 2), capped at 5.0.
    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
        // Recompute critical path based on the static cost model.
        self.critical_path = self.compute_critical_path(dag);
        if !execution_times.is_empty() {
            let max_time = execution_times.values().fold(0.0f64, |a, &b| a.max(b));
            let avg_time: f64 =
                execution_times.values().sum::<f64>() / execution_times.len() as f64;
            if max_time > 0.0 && avg_time > 0.0 {
                // High variance suggests a dominant bottleneck; lean harder on
                // the critical path.
                let variance_ratio = max_time / avg_time;
                if variance_ratio > 2.0 {
                    self.config.path_weight = (self.config.path_weight * 1.1).min(5.0);
                }
            }
        }
    }
    fn name(&self) -> &'static str {
        "critical_path"
    }
    fn complexity(&self) -> &'static str {
        "O(n + e)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    // Scores must normalize to 1, and every node on the computed critical
    // path must receive non-zero attention.
    #[test]
    fn test_critical_path_attention() {
        let mut dag = QueryDag::new();
        // Create a DAG with different costs
        let id0 =
            dag.add_node(OperatorNode::seq_scan(0, "large_table").with_estimates(10000.0, 10.0));
        let id1 =
            dag.add_node(OperatorNode::filter(0, "status = 'active'").with_estimates(1000.0, 1.0));
        let id2 = dag.add_node(OperatorNode::hash_join(0, "user_id").with_estimates(5000.0, 5.0));
        dag.add_edge(id0, id2).unwrap();
        dag.add_edge(id1, id2).unwrap();
        let attention = CriticalPathAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();
        // Check normalization
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        // Nodes on critical path should have higher attention
        let critical = attention.compute_critical_path(&dag);
        for &node_id in &critical {
            let score = scores.get(&node_id).unwrap();
            assert!(*score > 0.0);
        }
    }
}

View File

@@ -0,0 +1,288 @@
//! Hierarchical Lorentz Attention: Hyperbolic geometry for tree-like structures
//!
//! This mechanism embeds DAG nodes in hyperbolic space using the Lorentz (hyperboloid) model,
//! where hierarchical relationships are naturally represented by distance from the origin.
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::HashMap;
/// Configuration for [`HierarchicalLorentzAttention`].
#[derive(Debug, Clone)]
pub struct HierarchicalLorentzConfig {
    /// Curvature parameter (-1.0 for standard Poincaré ball)
    pub curvature: f32,
    /// Scale factor for temporal dimension
    pub time_scale: f32,
    /// Embedding dimension
    pub dim: usize,
    /// Temperature for softmax
    pub temperature: f32,
}
impl Default for HierarchicalLorentzConfig {
fn default() -> Self {
Self {
curvature: -1.0,
time_scale: 1.0,
dim: 64,
temperature: 0.1,
}
}
}
/// Attention mechanism embedding DAG nodes in the Lorentz (hyperboloid) model
/// of hyperbolic space and scoring them by distance from the origin.
pub struct HierarchicalLorentzAttention {
    config: HierarchicalLorentzConfig,
}
impl HierarchicalLorentzAttention {
    /// Build the mechanism from an explicit configuration.
    pub fn new(config: HierarchicalLorentzConfig) -> Self {
        Self { config }
    }
    /// Lorentz (Minkowski) inner product: -x0*y0 + x1*y1 + ... + xn*yn.
    ///
    /// Returns 0.0 for empty inputs as a defensive fallback.
    fn lorentz_inner(&self, x: &[f32], y: &[f32]) -> f32 {
        if x.is_empty() || y.is_empty() {
            return 0.0;
        }
        -x[0] * y[0] + x[1..].iter().zip(&y[1..]).map(|(a, b)| a * b).sum::<f32>()
    }
    /// Geodesic distance in the hyperboloid model: acosh(-<x, y>_L),
    /// scaled by |curvature|.
    fn lorentz_distance(&self, x: &[f32], y: &[f32]) -> f32 {
        let inner = self.lorentz_inner(x, y);
        // acosh is only defined on [1, inf); clamp to guard against numerical
        // drift producing NaN.
        let clamped = (-inner).max(1.0);
        clamped.acosh() * self.config.curvature.abs()
    }
    /// Lift a Euclidean point onto the hyperboloid:
    /// [sqrt(1 + ||x||^2) * time_scale, x1, x2, ..., xn].
    fn project_to_hyperboloid(&self, x: &[f32]) -> Vec<f32> {
        let spatial_norm_sq: f32 = x.iter().map(|v| v * v).sum();
        let time_coord = (1.0 + spatial_norm_sq).sqrt();
        let mut result = Vec::with_capacity(x.len() + 1);
        result.push(time_coord * self.config.time_scale);
        result.extend_from_slice(x);
        result
    }
    /// Compute each node's hierarchical depth as the length of the longest
    /// path reaching it from any root (node without incoming edges).
    ///
    /// Fix: the previous implementation used a Vec-as-stack with a `visited`
    /// flag, so a node reachable along several paths kept the depth of
    /// whichever parent happened to be written last instead of its maximum
    /// depth. This version runs a Kahn-style topological sweep, which assigns
    /// deterministic longest-path depths.
    fn compute_depths(&self, dag: &QueryDag) -> Vec<usize> {
        let n = dag.node_count();
        let mut depths = vec![0usize; n];
        let mut indegree = vec![0usize; n];
        let mut adj_list: HashMap<usize, Vec<usize>> = HashMap::new();
        // Build adjacency and indegree, ignoring out-of-range ids defensively.
        for node_id in dag.node_ids() {
            if node_id >= n {
                continue;
            }
            for &child in dag.children(node_id) {
                if child < n {
                    adj_list.entry(node_id).or_insert_with(Vec::new).push(child);
                    indegree[child] += 1;
                }
            }
        }
        // Seed with the roots; process each node only once all parents are done.
        let mut queue: std::collections::VecDeque<usize> =
            (0..n).filter(|&i| indegree[i] == 0).collect();
        while let Some(node) = queue.pop_front() {
            if let Some(children) = adj_list.get(&node) {
                for &child in children {
                    // Longest-path depth: keep the maximum over all parents.
                    depths[child] = depths[child].max(depths[node] + 1);
                    indegree[child] -= 1;
                    if indegree[child] == 0 {
                        queue.push_back(child);
                    }
                }
            }
        }
        depths
    }
    /// Embed a node in Euclidean space from its depth and id: depth sets the
    /// (tanh-bounded) radius, id sets the angle; higher dimensions carry
    /// deterministic pseudo-noise for separation.
    fn embed_node(&self, node_id: usize, depth: usize, total_nodes: usize) -> Vec<f32> {
        let dim = self.config.dim;
        let mut embedding = vec![0.0; dim];
        let radial = (depth as f32 * 0.5).tanh();
        let angle = 2.0 * std::f32::consts::PI * (node_id as f32) / (total_nodes as f32).max(1.0);
        // Robustness: guard dim == 0 / dim == 1 instead of panicking on indexing.
        if dim > 0 {
            embedding[0] = radial * angle.cos();
        }
        if dim > 1 {
            embedding[1] = radial * angle.sin();
        }
        for i in 2..dim {
            embedding[i] = 0.1 * ((node_id + i) as f32).sin();
        }
        embedding
    }
    /// Softmax over negated, temperature-scaled distances: closer to the
    /// reference point means higher attention.
    fn compute_attention_from_distances(&self, distances: &[f32]) -> Vec<f32> {
        if distances.is_empty() {
            return vec![];
        }
        let neg_distances: Vec<f32> = distances
            .iter()
            .map(|&d| -d / self.config.temperature)
            .collect();
        // Numerically stable softmax: subtract the maximum before exp.
        let max_val = neg_distances
            .iter()
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = neg_distances.iter().map(|&x| (x - max_val).exp()).sum();
        if exp_sum == 0.0 {
            // Uniform fallback if everything underflows.
            return vec![1.0 / distances.len() as f32; distances.len()];
        }
        neg_distances
            .iter()
            .map(|&x| (x - max_val).exp() / exp_sum)
            .collect()
    }
}
impl DagAttentionMechanism for HierarchicalLorentzAttention {
    /// Embed every node in hyperbolic space, score it by its distance to the
    /// origin (softmax over negated distances), and attach dense pairwise edge
    /// weights derived from the same distances plus an `avg_depth` metadata entry.
    ///
    /// Returns `AttentionError::InvalidDag` for an empty DAG.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.node_count() == 0 {
            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
        }
        let n = dag.node_count();
        // Step 1: Compute hierarchical depths
        let depths = self.compute_depths(dag);
        // Step 2: Embed each node in Euclidean space
        let euclidean_embeddings: Vec<Vec<f32>> =
            (0..n).map(|i| self.embed_node(i, depths[i], n)).collect();
        // Step 3: Project to hyperboloid
        let hyperbolic_embeddings: Vec<Vec<f32>> = euclidean_embeddings
            .iter()
            .map(|emb| self.project_to_hyperboloid(emb))
            .collect();
        // Step 4: Compute pairwise distances from a reference point (origin-like)
        let origin = self.project_to_hyperboloid(&vec![0.0; self.config.dim]);
        let distances: Vec<f32> = hyperbolic_embeddings
            .iter()
            .map(|emb| self.lorentz_distance(emb, &origin))
            .collect();
        // Step 5: Convert distances to attention scores
        let scores = self.compute_attention_from_distances(&distances);
        // Step 6: Compute edge weights (optional)
        // This O(n^2) pass dominates the runtime, matching the advertised
        // O(n²·d) complexity.
        let mut edge_weights = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in 0..n {
                let dist =
                    self.lorentz_distance(&hyperbolic_embeddings[i], &hyperbolic_embeddings[j]);
                edge_weights[i][j] = (-dist / self.config.temperature).exp();
            }
        }
        let mut result = AttentionScores::new(scores)
            .with_edge_weights(edge_weights)
            .with_metadata("mechanism".to_string(), "hierarchical_lorentz".to_string());
        result.metadata.insert(
            "avg_depth".to_string(),
            format!("{:.2}", depths.iter().sum::<usize>() as f32 / n as f32),
        );
        Ok(result)
    }
    fn name(&self) -> &'static str {
        "hierarchical_lorentz"
    }
    fn complexity(&self) -> &'static str {
        "O(n²·d)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    // Distances must never be negative.
    #[test]
    fn test_lorentz_distance() {
        let config = HierarchicalLorentzConfig::default();
        let attention = HierarchicalLorentzAttention::new(config);
        let x = vec![1.0, 0.5, 0.3];
        let y = vec![1.2, 0.6, 0.4];
        let dist = attention.lorentz_distance(&x, &y);
        assert!(dist >= 0.0, "Distance should be non-negative");
    }
    // Projection prepends exactly one positive time coordinate.
    #[test]
    fn test_project_to_hyperboloid() {
        let config = HierarchicalLorentzConfig::default();
        let attention = HierarchicalLorentzAttention::new(config);
        let x = vec![0.5, 0.3, 0.2];
        let projected = attention.project_to_hyperboloid(&x);
        assert_eq!(projected.len(), 4);
        assert!(projected[0] > 0.0, "Time coordinate should be positive");
    }
    // End-to-end on a two-node chain: one score per node, summing to ~1.
    #[test]
    fn test_hierarchical_attention() {
        let config = HierarchicalLorentzConfig::default();
        let attention = HierarchicalLorentzAttention::new(config);
        let mut dag = QueryDag::new();
        let mut node0 = OperatorNode::new(0, OperatorType::Scan);
        node0.estimated_cost = 1.0;
        dag.add_node(node0);
        let mut node1 = OperatorNode::new(
            1,
            OperatorType::Filter {
                predicate: "x > 0".to_string(),
            },
        );
        node1.estimated_cost = 2.0;
        dag.add_node(node1);
        dag.add_edge(0, 1).unwrap();
        let result = attention.forward(&dag).unwrap();
        assert_eq!(result.scores.len(), 2);
        assert!((result.scores.iter().sum::<f32>() - 1.0).abs() < 0.01);
    }
}

View File

@@ -0,0 +1,253 @@
//! MinCut Gated Attention: Gates attention by graph cut criticality
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::{HashMap, HashSet, VecDeque};
/// How edge capacities are derived for the max-flow computation.
#[derive(Debug, Clone)]
pub enum FlowCapacity {
    /// Every edge has capacity 1 (pure topology).
    UnitCapacity,
    /// Capacity taken from the parent node's `estimated_cost`.
    CostBased,
    /// Capacity taken from the parent node's `estimated_rows`.
    RowBased,
}
/// Configuration for [`MinCutGatedAttention`].
#[derive(Debug, Clone)]
pub struct MinCutConfig {
    /// Attention given to nodes NOT on the min cut (cut nodes get 1.0).
    pub gate_threshold: f32,
    /// Capacity model used when computing the max flow / min cut.
    pub flow_capacity: FlowCapacity,
}
impl Default for MinCutConfig {
fn default() -> Self {
Self {
gate_threshold: 0.5,
flow_capacity: FlowCapacity::UnitCapacity,
}
}
}
/// Attention mechanism that gates scores by membership in a minimum cut
/// between the DAG root and a leaf (cut nodes = bottlenecks).
pub struct MinCutGatedAttention {
    config: MinCutConfig,
}
impl MinCutGatedAttention {
    /// Build from an explicit configuration.
    pub fn new(config: MinCutConfig) -> Self {
        Self { config }
    }
    /// Build with `MinCutConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(MinCutConfig::default())
    }
    /// Compute the set of nodes incident to a minimum s-t cut between the DAG
    /// root (source) and its first leaf (sink), via Ford-Fulkerson max flow
    /// with BFS augmenting paths (Edmonds-Karp).
    ///
    /// Returns an empty set when the DAG has no root or no leaves.
    fn compute_min_cut(&self, dag: &QueryDag) -> HashSet<usize> {
        let mut cut_nodes = HashSet::new();
        // Forward capacities derived from the configured capacity model, plus
        // a residual-graph adjacency list. Fix: each capacity edge must be
        // walkable in BOTH directions during the search — the original BFS
        // followed only `dag.children`, never reverse (flow-cancellation)
        // edges, so it could stop at a non-maximum flow and report the wrong
        // cut.
        let mut capacity: HashMap<(usize, usize), f64> = HashMap::new();
        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
        for node_id in 0..dag.node_count() {
            if dag.get_node(node_id).is_none() {
                continue;
            }
            for &child in dag.children(node_id) {
                let cap = match self.config.flow_capacity {
                    FlowCapacity::UnitCapacity => 1.0,
                    FlowCapacity::CostBased => dag
                        .get_node(node_id)
                        .map(|n| n.estimated_cost)
                        .unwrap_or(1.0),
                    FlowCapacity::RowBased => dag
                        .get_node(node_id)
                        .map(|n| n.estimated_rows)
                        .unwrap_or(1.0),
                };
                capacity.insert((node_id, child), cap);
                adjacency.entry(node_id).or_insert_with(Vec::new).push(child);
                adjacency.entry(child).or_insert_with(Vec::new).push(node_id);
            }
        }
        // Source is the root; sink is the first leaf.
        // NOTE(review): only the first leaf is used as sink — a multi-leaf DAG
        // gets a cut w.r.t. that leaf alone; confirm this matches intent.
        let source = match dag.root() {
            Some(root) => root,
            None => return cut_nodes,
        };
        let leaves = dag.leaves();
        if leaves.is_empty() {
            return cut_nodes;
        }
        let sink = leaves[0];
        let mut residual = capacity.clone();
        loop {
            // BFS for an augmenting path in the residual graph (both forward
            // and reverse residual edges).
            let mut parent: HashMap<usize, usize> = HashMap::new();
            let mut visited = HashSet::new();
            let mut queue = VecDeque::new();
            queue.push_back(source);
            visited.insert(source);
            'bfs: while let Some(u) = queue.pop_front() {
                if let Some(neighbors) = adjacency.get(&u) {
                    for &v in neighbors {
                        if !visited.contains(&v)
                            && residual.get(&(u, v)).copied().unwrap_or(0.0) > 0.0
                        {
                            visited.insert(v);
                            parent.insert(v, u);
                            if v == sink {
                                break 'bfs;
                            }
                            queue.push_back(v);
                        }
                    }
                }
            }
            // No augmenting path found: flow is maximum.
            if !parent.contains_key(&sink) {
                break;
            }
            // Bottleneck capacity along the augmenting path.
            let mut path_flow = f64::INFINITY;
            let mut v = sink;
            while v != source {
                let u = parent[&v];
                path_flow = path_flow.min(residual.get(&(u, v)).copied().unwrap_or(0.0));
                v = u;
            }
            // Push flow: decrease forward residuals, increase reverse ones.
            v = sink;
            while v != source {
                let u = parent[&v];
                *residual.entry((u, v)).or_insert(0.0) -= path_flow;
                *residual.entry((v, u)).or_insert(0.0) += path_flow;
                v = u;
            }
        }
        // Min-cut side: nodes still reachable from the source in the residual
        // graph (again following both edge directions).
        let mut reachable = HashSet::new();
        let mut queue = VecDeque::new();
        queue.push_back(source);
        reachable.insert(source);
        while let Some(u) = queue.pop_front() {
            if let Some(neighbors) = adjacency.get(&u) {
                for &v in neighbors {
                    if !reachable.contains(&v)
                        && residual.get(&(u, v)).copied().unwrap_or(0.0) > 0.0
                    {
                        reachable.insert(v);
                        queue.push_back(v);
                    }
                }
            }
        }
        // The cut consists of the endpoints of original edges crossing from
        // the reachable side to the unreachable side.
        for node_id in 0..dag.node_count() {
            if dag.get_node(node_id).is_none() {
                continue;
            }
            for &child in dag.children(node_id) {
                if reachable.contains(&node_id) && !reachable.contains(&child) {
                    cut_nodes.insert(node_id);
                    cut_nodes.insert(child);
                }
            }
        }
        cut_nodes
    }
}
impl DagAttentionMechanism for MinCutGatedAttention {
    /// Score nodes by cut membership: min-cut (bottleneck) nodes get full
    /// weight 1.0, all others are gated down to `gate_threshold`; the scores
    /// are then normalized to sum to 1.
    ///
    /// Returns `AttentionError::InvalidDag` for an empty DAG.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.node_count() == 0 {
            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
        }
        let n = dag.node_count();
        let cut_nodes = self.compute_min_cut(dag);
        let mut score_vec = vec![0.0f32; n];
        for (node_id, slot) in score_vec.iter_mut().enumerate() {
            // Missing node ids keep their zero score.
            if dag.get_node(node_id).is_some() {
                *slot = if cut_nodes.contains(&node_id) {
                    1.0 // critical bottleneck: sits on the min cut
                } else {
                    self.config.gate_threshold // gated-down background attention
                };
            }
        }
        // Normalize into a probability distribution.
        let total: f32 = score_vec.iter().sum();
        if total > 0.0 {
            score_vec.iter_mut().for_each(|s| *s /= total);
        }
        Ok(AttentionScores::new(score_vec))
    }
    fn name(&self) -> &'static str {
        "mincut_gated"
    }
    fn complexity(&self) -> &'static str {
        "O(n * e^2)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    // Join node id2 is the single bottleneck between two scans and two
    // consumers; scores must still be a normalized distribution.
    #[test]
    fn test_mincut_gated_attention() {
        let mut dag = QueryDag::new();
        // Create a simple bottleneck DAG
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let id2 = dag.add_node(OperatorNode::hash_join(0, "id"));
        let id3 = dag.add_node(OperatorNode::filter(0, "status = 'active'"));
        let id4 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));
        // Create bottleneck at node id2
        dag.add_edge(id0, id2).unwrap();
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        dag.add_edge(id2, id4).unwrap();
        let attention = MinCutGatedAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();
        // Check normalization
        let sum: f32 = scores.scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        // All scores should be in [0, 1]
        for &score in &scores.scores {
            assert!(score >= 0.0 && score <= 1.0);
        }
    }
}

View File

@@ -0,0 +1,38 @@
//! DAG Attention Mechanisms
//!
//! This module provides graph-topology-aware attention mechanisms for DAG-based
//! query optimization. Unlike traditional neural attention, these mechanisms
//! leverage the structural properties of the DAG (topology, paths, cuts) to
//! compute attention scores.
// Team 2 (Agent #2) - Base attention mechanisms
mod causal_cone;
mod critical_path;
mod mincut_gated;
mod topological;
mod traits;
// Team 2 (Agent #3) - Advanced attention mechanisms
mod cache;
mod hierarchical_lorentz;
mod parallel_branch;
mod selector;
mod temporal_btsp;
mod trait_def;
// Export base mechanisms
pub use causal_cone::{CausalConeAttention, CausalConeConfig};
pub use critical_path::{CriticalPathAttention, CriticalPathConfig};
pub use mincut_gated::{FlowCapacity, MinCutConfig, MinCutGatedAttention};
pub use topological::{TopologicalAttention, TopologicalConfig};
pub use traits::{AttentionConfig, AttentionError, AttentionScores, DagAttention};
// Export advanced mechanisms
pub use cache::{AttentionCache, CacheConfig, CacheStats};
pub use hierarchical_lorentz::{HierarchicalLorentzAttention, HierarchicalLorentzConfig};
pub use parallel_branch::{ParallelBranchAttention, ParallelBranchConfig};
pub use selector::{AttentionSelector, MechanismStats, SelectorConfig};
pub use temporal_btsp::{TemporalBTSPAttention, TemporalBTSPConfig};
// The Agent #3 trait types collide by name with `traits::*`, so re-export
// them under `V2` aliases to keep both generations usable side by side.
pub use trait_def::{
    AttentionError as AttentionErrorV2, AttentionScores as AttentionScoresV2, DagAttentionMechanism,
};

View File

@@ -0,0 +1,303 @@
//! Parallel Branch Attention: Coordinates attention across parallel execution branches
//!
//! This mechanism identifies parallel branches in the DAG and distributes attention
//! to balance workload and minimize synchronization overhead.
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::{HashMap, HashSet};
/// Configuration for [`ParallelBranchAttention`].
#[derive(Debug, Clone)]
pub struct ParallelBranchConfig {
    /// Maximum number of parallel branches to consider
    pub max_branches: usize,
    /// Penalty for synchronization between branches
    pub sync_penalty: f32,
    /// Weight for branch balance in attention computation
    pub balance_weight: f32,
    /// Temperature for softmax
    pub temperature: f32,
}
impl Default for ParallelBranchConfig {
    /// Defaults: up to 8 branches, a 20% sync penalty, balance weighted at
    /// 0.5, and a sharp softmax temperature of 0.1.
    fn default() -> Self {
        ParallelBranchConfig {
            max_branches: 8,
            sync_penalty: 0.2,
            balance_weight: 0.5,
            temperature: 0.1,
        }
    }
}
/// Attention mechanism that coordinates focus across parallel execution
/// branches of the DAG, favoring costly branches and penalizing
/// synchronization points between them.
pub struct ParallelBranchAttention {
    config: ParallelBranchConfig,
}
impl ParallelBranchAttention {
    /// Create a mechanism with the given configuration.
    pub fn new(config: ParallelBranchConfig) -> Self {
        Self { config }
    }
    /// Detect parallel branches (nodes with the same parent and no edges
    /// between them).
    ///
    /// Parents are scanned in increasing node-id order; each child joins at
    /// most one group (tracked via `visited`), and only groups with two or
    /// more members count as parallel branches.
    fn detect_branches(&self, dag: &QueryDag) -> Vec<Vec<usize>> {
        let n = dag.node_count();
        // Child lists keyed by parent id, rebuilt for O(1) lookup below.
        // (The previous version also built an unused parents_of map.)
        let mut children_of: HashMap<usize, Vec<usize>> = HashMap::new();
        for node_id in dag.node_ids() {
            for &child in dag.children(node_id) {
                children_of.entry(node_id).or_default().push(child);
            }
        }
        let mut branches = Vec::new();
        let mut visited = HashSet::new();
        // NOTE(review): assumes node ids are contiguous in 0..n, which holds
        // for sequential add_node usage — confirm against QueryDag.
        for node_id in 0..n {
            let children = match children_of.get(&node_id) {
                Some(c) if c.len() > 1 => c,
                _ => continue,
            };
            // Keep only children with no edge to any sibling: those are
            // truly parallel with each other.
            let mut parallel_group = Vec::new();
            for &child in children {
                if visited.contains(&child) {
                    continue;
                }
                let child_children = dag.children(child);
                let has_sibling_edge = children
                    .iter()
                    .any(|&other| other != child && child_children.contains(&other));
                if !has_sibling_edge {
                    parallel_group.push(child);
                    visited.insert(child);
                }
            }
            if parallel_group.len() > 1 {
                branches.push(parallel_group);
            }
        }
        branches
    }
    /// Compute the branch balance score as the root of the mean per-branch
    /// cost variance (lower is better balanced). Returns 1.0 when there are
    /// no branches at all.
    fn branch_balance(&self, branches: &[Vec<usize>], dag: &QueryDag) -> f32 {
        if branches.is_empty() {
            return 1.0;
        }
        let mut total_variance = 0.0;
        for branch in branches {
            // Single-node branches contribute zero variance by definition.
            if branch.len() <= 1 {
                continue;
            }
            // Per-node cost estimates; nodes missing from the DAG are skipped.
            let costs: Vec<f64> = branch
                .iter()
                .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
                .collect();
            if costs.is_empty() {
                continue;
            }
            let mean = costs.iter().sum::<f64>() / costs.len() as f64;
            let variance =
                costs.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / costs.len() as f64;
            total_variance += variance as f32;
        }
        // `branches` is non-empty here thanks to the early return above.
        (total_variance / branches.len() as f32).sqrt()
    }
    /// Compute the criticality of a branch: total estimated cost scaled by a
    /// capped average-row factor — high-cost, high-cardinality branches are
    /// the most critical.
    fn branch_criticality(&self, branch: &[usize], dag: &QueryDag) -> f32 {
        if branch.is_empty() {
            return 0.0;
        }
        // Sum of cost estimates across the branch.
        let total_cost: f64 = branch
            .iter()
            .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
            .sum();
        // Average estimated rows (higher rows = more critical for filtering).
        let avg_rows: f64 = branch
            .iter()
            .filter_map(|&id| dag.get_node(id).map(|n| n.estimated_rows))
            .sum::<f64>()
            / branch.len().max(1) as f64;
        (total_cost * (avg_rows / 1000.0).min(1.0)) as f32
    }
    /// Compute softmax-normalized attention scores from the detected
    /// branches: branch members score by criticality (discounted by the
    /// balance penalty), cross-branch edges apply the sync penalty to their
    /// target node, and all other nodes keep a neutral base score.
    fn compute_branch_attention(&self, dag: &QueryDag, branches: &[Vec<usize>]) -> Vec<f32> {
        let n = dag.node_count();
        // Nodes outside every branch keep this neutral base score.
        let base_score = 0.5;
        let mut scores = vec![base_score; n];
        let balance_penalty = self.branch_balance(branches, dag);
        // Precompute node -> branch index once instead of scanning every
        // branch for every edge below.
        let mut branch_of: HashMap<usize, usize> = HashMap::new();
        for (branch_idx, branch) in branches.iter().enumerate() {
            for &node_id in branch {
                branch_of.insert(node_id, branch_idx);
            }
        }
        for branch in branches {
            let criticality = self.branch_criticality(branch, dag);
            // Higher criticality draws more attention; poor balance dampens it.
            let branch_score = criticality * (1.0 - self.config.balance_weight * balance_penalty);
            for &node_id in branch {
                if node_id < n {
                    scores[node_id] = branch_score;
                }
            }
        }
        // Edges crossing from one branch into another mark synchronization
        // points; penalize the receiving node.
        for from in dag.node_ids() {
            for &to in dag.children(from) {
                if from < n && to < n {
                    if let (Some(fb), Some(tb)) = (branch_of.get(&from), branch_of.get(&to)) {
                        if fb != tb {
                            scores[to] *= 1.0 - self.config.sync_penalty;
                        }
                    }
                }
            }
        }
        // Temperature-scaled softmax, max-subtracted for numerical stability.
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = scores
            .iter()
            .map(|&s| ((s - max_score) / self.config.temperature).exp())
            .sum();
        if exp_sum > 0.0 {
            for score in scores.iter_mut() {
                *score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
            }
        } else {
            // Degenerate case: fall back to a uniform distribution.
            scores.fill(1.0 / n as f32);
        }
        scores
    }
}
impl DagAttentionMechanism for ParallelBranchAttention {
    /// Compute per-node attention by detecting parallel branches and scoring
    /// them by criticality, balance and synchronization cost.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.node_count() == 0 {
            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
        }
        // Detect the branch structure, then derive the balance metric and
        // the branch-aware scores from it.
        let branches = self.detect_branches(dag);
        let balance = self.branch_balance(&branches, dag);
        let scores = self.compute_branch_attention(dag, &branches);
        // Attach diagnostic metadata for downstream inspection.
        let mut result = AttentionScores::new(scores);
        result
            .metadata
            .insert("mechanism".to_string(), "parallel_branch".to_string());
        result
            .metadata
            .insert("num_branches".to_string(), branches.len().to_string());
        result
            .metadata
            .insert("balance_score".to_string(), format!("{:.4}", balance));
        Ok(result)
    }
    fn name(&self) -> &'static str {
        "parallel_branch"
    }
    fn complexity(&self) -> &'static str {
        "O(n² + b·n)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    /// A diamond DAG (0 fans out to 1 and 2, which rejoin at 3) must yield
    /// at least one detected parallel group.
    #[test]
    fn test_detect_branches() {
        let config = ParallelBranchConfig::default();
        let attention = ParallelBranchAttention::new(config);
        let mut dag = QueryDag::new();
        for i in 0..4 {
            dag.add_node(OperatorNode::new(i, OperatorType::Scan));
        }
        // Create parallel branches: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
        dag.add_edge(0, 1).unwrap();
        dag.add_edge(0, 2).unwrap();
        dag.add_edge(1, 3).unwrap();
        dag.add_edge(2, 3).unwrap();
        let branches = attention.detect_branches(&dag);
        assert!(!branches.is_empty());
    }
    /// End-to-end forward pass on a two-way fan-out returns one score per
    /// node (exact values depend on costs, so only the shape is asserted).
    #[test]
    fn test_parallel_attention() {
        let config = ParallelBranchConfig::default();
        let attention = ParallelBranchAttention::new(config);
        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            // Distinct costs so the branch criticalities differ.
            node.estimated_cost = (i + 1) as f64;
            dag.add_node(node);
        }
        dag.add_edge(0, 1).unwrap();
        dag.add_edge(0, 2).unwrap();
        let result = attention.forward(&dag).unwrap();
        assert_eq!(result.scores.len(), 3);
    }
}

View File

@@ -0,0 +1,305 @@
//! Attention Selector: UCB Bandit for mechanism selection
//!
//! Implements Upper Confidence Bound (UCB1) algorithm to dynamically select
//! the best attention mechanism based on observed performance.
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::HashMap;
/// Configuration for the UCB1 attention-mechanism selector.
#[derive(Debug, Clone)]
pub struct SelectorConfig {
    /// UCB exploration constant (typically sqrt(2))
    pub exploration_factor: f32,
    /// Optimistic initialization value
    pub initial_value: f32,
    /// Minimum samples before exploitation
    pub min_samples: usize,
}
impl Default for SelectorConfig {
    /// Standard UCB1 defaults: sqrt(2) exploration, optimistic start at 1.0,
    /// and five forced samples per arm before exploitation kicks in.
    fn default() -> Self {
        SelectorConfig {
            exploration_factor: 2.0_f32.sqrt(),
            initial_value: 1.0,
            min_samples: 5,
        }
    }
}
/// UCB1 bandit that picks among several attention mechanisms based on
/// observed rewards, balancing exploration and exploitation.
pub struct AttentionSelector {
    config: SelectorConfig,
    mechanisms: Vec<Box<dyn DagAttentionMechanism>>,
    /// Cumulative rewards for each mechanism (seeded with `initial_value`)
    rewards: Vec<f32>,
    /// Number of times each mechanism was selected
    counts: Vec<usize>,
    /// Total number of selections
    total_count: usize,
}
impl AttentionSelector {
    /// Create a selector over `mechanisms`, seeding every arm's cumulative
    /// reward with the optimistic `initial_value` from `config`.
    pub fn new(mechanisms: Vec<Box<dyn DagAttentionMechanism>>, config: SelectorConfig) -> Self {
        let n = mechanisms.len();
        let initial_value = config.initial_value;
        Self {
            config,
            mechanisms,
            rewards: vec![initial_value; n],
            counts: vec![0; n],
            total_count: 0,
        }
    }
    /// Select a mechanism index using UCB1.
    ///
    /// Arms with fewer than `min_samples` pulls are force-sampled first (in
    /// index order); afterwards the arm with the highest average reward plus
    /// exploration bonus wins (ties go to the last maximal arm).
    pub fn select(&self) -> usize {
        if self.mechanisms.is_empty() {
            return 0;
        }
        // Force-sample any under-explored arm before trusting the averages.
        if let Some(idx) = self.counts.iter().position(|&c| c < self.config.min_samples) {
            return idx;
        }
        let ln_total = (self.total_count as f32).ln().max(1.0);
        self.counts
            .iter()
            .zip(self.rewards.iter())
            .map(|(&count, &reward)| {
                let count = count as f32;
                if count == 0.0 {
                    // Only reachable when min_samples == 0: an untried arm is
                    // infinitely attractive.
                    f32::INFINITY
                } else {
                    let exploitation = reward / count;
                    let exploration = self.config.exploration_factor * (ln_total / count).sqrt();
                    exploitation + exploration
                }
            })
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
    /// Record `reward` for the mechanism at `mechanism_idx`; out-of-range
    /// indices are silently ignored.
    pub fn update(&mut self, mechanism_idx: usize, reward: f32) {
        if mechanism_idx < self.rewards.len() {
            self.rewards[mechanism_idx] += reward;
            self.counts[mechanism_idx] += 1;
            self.total_count += 1;
        }
    }
    /// Get the mechanism at `idx`, if any.
    pub fn get_mechanism(&self, idx: usize) -> Option<&dyn DagAttentionMechanism> {
        self.mechanisms.get(idx).map(|m| m.as_ref())
    }
    /// Get a mutable reference to the mechanism at `idx` (e.g. to feed it
    /// execution feedback).
    pub fn get_mechanism_mut(&mut self, idx: usize) -> Option<&mut Box<dyn DagAttentionMechanism>> {
        self.mechanisms.get_mut(idx)
    }
    /// Per-mechanism reward statistics, keyed by mechanism name.
    ///
    /// `total_reward` includes the optimistic `initial_value` seed.
    pub fn stats(&self) -> HashMap<&'static str, MechanismStats> {
        self.mechanisms
            .iter()
            .enumerate()
            .map(|(i, m)| {
                let count = self.counts[i];
                let total_reward = self.rewards[i];
                let stats = MechanismStats {
                    total_reward,
                    count,
                    avg_reward: if count > 0 {
                        total_reward / count as f32
                    } else {
                        0.0
                    },
                };
                (m.name(), stats)
            })
            .collect()
    }
    /// Index of the best mechanism by average reward, considering only arms
    /// with at least `min_samples` pulls. Returns `None` when no arm
    /// qualifies.
    pub fn best_mechanism(&self) -> Option<usize> {
        self.mechanisms
            .iter()
            .enumerate()
            // The `> 0` guard avoids a 0/0 division (NaN) below when
            // `min_samples == 0` and an arm has never been pulled.
            .filter(|(i, _)| self.counts[*i] >= self.config.min_samples && self.counts[*i] > 0)
            .max_by(|(i, _), (j, _)| {
                let avg_i = self.rewards[*i] / self.counts[*i] as f32;
                let avg_j = self.rewards[*j] / self.counts[*j] as f32;
                avg_i
                    .partial_cmp(&avg_j)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
    }
    /// Reset all statistics to their optimistic initial state.
    pub fn reset(&mut self) {
        self.rewards.fill(self.config.initial_value);
        self.counts.fill(0);
        self.total_count = 0;
    }
    /// Run the UCB-selected mechanism on `dag`, returning its scores and the
    /// index that was chosen (so the caller can report a reward via
    /// [`AttentionSelector::update`]).
    pub fn forward(&mut self, dag: &QueryDag) -> Result<(AttentionScores, usize), AttentionError> {
        let selected = self.select();
        let mechanism = self
            .get_mechanism(selected)
            .ok_or_else(|| AttentionError::ConfigError("No mechanisms available".to_string()))?;
        let scores = mechanism.forward(dag)?;
        Ok((scores, selected))
    }
}
/// Aggregate reward statistics for one mechanism, as reported by
/// `AttentionSelector::stats`.
#[derive(Debug, Clone)]
pub struct MechanismStats {
    // Sum of all recorded rewards; includes the optimistic initial_value seed.
    pub total_reward: f32,
    // Number of times the mechanism was selected (i.e. update() calls).
    pub count: usize,
    // total_reward / count, or 0.0 when the mechanism was never selected.
    pub avg_reward: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType, QueryDag};
    /// Mock mechanism returning a constant score per node, used to exercise
    /// the selector without any real attention computation.
    struct MockMechanism {
        name: &'static str,
        score_value: f32,
    }
    impl DagAttentionMechanism for MockMechanism {
        fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
            let scores = vec![self.score_value; dag.nodes.len()];
            Ok(AttentionScores::new(scores))
        }
        fn name(&self) -> &'static str {
            self.name
        }
        fn complexity(&self) -> &'static str {
            "O(1)"
        }
    }
    /// 15 rounds with min_samples = 5 and three arms: every arm must be
    /// forced at least once before pure exploitation.
    #[test]
    fn test_ucb_selection() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![
            Box::new(MockMechanism {
                name: "mech1",
                score_value: 0.5,
            }),
            Box::new(MockMechanism {
                name: "mech2",
                score_value: 0.7,
            }),
            Box::new(MockMechanism {
                name: "mech3",
                score_value: 0.3,
            }),
        ];
        let mut selector = AttentionSelector::new(mechanisms, SelectorConfig::default());
        // First selections should explore all mechanisms
        for _ in 0..15 {
            let selected = selector.select();
            selector.update(selected, 0.5);
        }
        assert!(selector.total_count > 0);
        assert!(selector.counts.iter().all(|&c| c > 0));
    }
    /// With both arms at min_samples, the arm with the higher average reward
    /// must be reported as best.
    #[test]
    fn test_best_mechanism() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![
            Box::new(MockMechanism {
                name: "poor",
                score_value: 0.3,
            }),
            Box::new(MockMechanism {
                name: "good",
                score_value: 0.8,
            }),
        ];
        let mut selector = AttentionSelector::new(
            mechanisms,
            SelectorConfig {
                min_samples: 2,
                ..Default::default()
            },
        );
        // Simulate different rewards
        selector.update(0, 0.3);
        selector.update(0, 0.4);
        selector.update(1, 0.8);
        selector.update(1, 0.9);
        let best = selector.best_mechanism().unwrap();
        assert_eq!(best, 1);
    }
    /// forward() on a single-mechanism selector must run that mechanism and
    /// report its index.
    #[test]
    fn test_selector_forward() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![Box::new(MockMechanism {
            name: "test",
            score_value: 0.5,
        })];
        let mut selector = AttentionSelector::new(mechanisms, SelectorConfig::default());
        let mut dag = QueryDag::new();
        let node = OperatorNode::new(0, OperatorType::Scan);
        dag.add_node(node);
        let (scores, idx) = selector.forward(&dag).unwrap();
        assert_eq!(scores.scores.len(), 1);
        assert_eq!(idx, 0);
    }
    /// Stats must report cumulative reward and the arithmetic mean.
    #[test]
    fn test_stats() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![Box::new(MockMechanism {
            name: "mech1",
            score_value: 0.5,
        })];
        // Use initial_value = 0 so we can test pure update accumulation
        let config = SelectorConfig {
            initial_value: 0.0,
            ..Default::default()
        };
        let mut selector = AttentionSelector::new(mechanisms, config);
        selector.update(0, 1.0);
        selector.update(0, 2.0);
        let stats = selector.stats();
        let mech1_stats = stats.get("mech1").unwrap();
        assert_eq!(mech1_stats.count, 2);
        assert_eq!(mech1_stats.total_reward, 3.0);
        assert_eq!(mech1_stats.avg_reward, 1.5);
    }
}

View File

@@ -0,0 +1,301 @@
//! Temporal BTSP Attention: Behavioral Timescale Synaptic Plasticity
//!
//! This mechanism implements a biologically-inspired attention mechanism based on
//! eligibility traces and plateau potentials, allowing the system to learn from
//! temporal patterns in query execution.
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::HashMap;
use std::time::Instant;
/// Configuration for [`TemporalBTSPAttention`].
#[derive(Debug, Clone)]
pub struct TemporalBTSPConfig {
    /// Duration of plateau state in milliseconds
    pub plateau_duration_ms: u64,
    /// Decay rate for eligibility traces (0.0 to 1.0)
    pub eligibility_decay: f32,
    /// Learning rate for trace updates
    pub learning_rate: f32,
    /// Temperature for softmax
    pub temperature: f32,
    /// Baseline attention for nodes without history
    pub baseline_attention: f32,
}
impl Default for TemporalBTSPConfig {
    /// Defaults: 500 ms plateaus, slow trace decay (0.95), a modest learning
    /// rate (0.1), a sharp softmax temperature (0.1), and a neutral 0.5
    /// baseline.
    fn default() -> Self {
        TemporalBTSPConfig {
            plateau_duration_ms: 500,
            eligibility_decay: 0.95,
            learning_rate: 0.1,
            temperature: 0.1,
            baseline_attention: 0.5,
        }
    }
}
/// Temporal BTSP attention: per-node eligibility traces learned from
/// execution feedback plus short-lived plateau boosts.
pub struct TemporalBTSPAttention {
    config: TemporalBTSPConfig,
    /// Eligibility traces for each node, kept clamped to [0, 1]
    eligibility_traces: HashMap<usize, f32>,
    /// Timestamp of last plateau for each node
    last_plateau: HashMap<usize, Instant>,
    /// Total number of `update` calls observed
    update_count: usize,
}
impl TemporalBTSPAttention {
    /// Create a mechanism with empty traces and no plateau history.
    pub fn new(config: TemporalBTSPConfig) -> Self {
        Self {
            config,
            eligibility_traces: HashMap::new(),
            last_plateau: HashMap::new(),
            update_count: 0,
        }
    }
    /// Decay the node's eligibility trace and blend in `signal`, keeping the
    /// result clamped to [0, 1].
    fn update_eligibility(&mut self, node_id: usize, signal: f32) {
        let trace = self.eligibility_traces.entry(node_id).or_insert(0.0);
        *trace = (*trace * self.config.eligibility_decay + signal * self.config.learning_rate)
            .clamp(0.0, 1.0);
    }
    /// Whether the node's last plateau is still within `plateau_duration_ms`.
    fn is_plateau(&self, node_id: usize) -> bool {
        self.last_plateau
            .get(&node_id)
            .map(|t| t.elapsed().as_millis() < self.config.plateau_duration_ms as u128)
            .unwrap_or(false)
    }
    /// Start (or restart) the plateau window for a node.
    fn trigger_plateau(&mut self, node_id: usize) {
        self.last_plateau.insert(node_id, Instant::now());
    }
    /// Compute base attention from topology: an even blend of capped cost
    /// and row factors. `baseline_attention` only survives for score slots
    /// whose id has no node entry.
    fn compute_topology_attention(&self, dag: &QueryDag) -> Vec<f32> {
        let n = dag.node_count();
        let mut scores = vec![self.config.baseline_attention; n];
        for node in dag.nodes() {
            if node.id < n {
                let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
                let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
                scores[node.id] = 0.5 * cost_factor + 0.5 * rows_factor;
            }
        }
        scores
    }
    /// Boost each traced node's score multiplicatively by (1 + trace).
    fn apply_eligibility_modulation(&self, base_scores: &mut [f32]) {
        for (&node_id, &trace) in &self.eligibility_traces {
            if node_id < base_scores.len() {
                base_scores[node_id] *= 1.0 + trace;
            }
        }
    }
    /// Apply a strong (1.5x) boost to nodes whose plateau window is open.
    fn apply_plateau_boost(&self, scores: &mut [f32]) {
        for &node_id in self.last_plateau.keys() {
            if node_id < scores.len() && self.is_plateau(node_id) {
                scores[node_id] *= 1.5;
            }
        }
    }
    /// Normalize scores in place with a temperature-scaled, max-subtracted
    /// softmax; falls back to uniform when the exponentials underflow.
    fn normalize_scores(&self, scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = scores
            .iter()
            .map(|&s| ((s - max_score) / self.config.temperature).exp())
            .sum();
        if exp_sum > 0.0 {
            for score in scores.iter_mut() {
                *score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
            }
        } else {
            let uniform = 1.0 / scores.len() as f32;
            scores.fill(uniform);
        }
    }
}
impl DagAttentionMechanism for TemporalBTSPAttention {
    /// Compute attention: topology base scores, modulated by eligibility
    /// traces, boosted for plateau nodes, then softmax-normalized.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        // Consistent with the other mechanisms: reject empty DAGs up front
        // via the accessor rather than reaching into `dag.nodes`.
        if dag.node_count() == 0 {
            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
        }
        let mut scores = self.compute_topology_attention(dag);
        self.apply_eligibility_modulation(&mut scores);
        self.apply_plateau_boost(&mut scores);
        self.normalize_scores(&mut scores);
        // Diagnostic metadata: how much learned state is currently active.
        let active_traces = self
            .eligibility_traces
            .values()
            .filter(|&&t| t > 0.01)
            .count();
        let active_plateaus = self
            .last_plateau
            .keys()
            .filter(|k| self.is_plateau(**k))
            .count();
        let result = AttentionScores::new(scores)
            .with_metadata("mechanism".to_string(), "temporal_btsp".to_string())
            .with_metadata("update_count".to_string(), self.update_count.to_string())
            .with_metadata("active_traces".to_string(), active_traces.to_string())
            .with_metadata("active_plateaus".to_string(), active_plateaus.to_string());
        Ok(result)
    }
    fn name(&self) -> &'static str {
        "temporal_btsp"
    }
    fn complexity(&self) -> &'static str {
        "O(n + t)"
    }
    /// Fold execution feedback into the eligibility traces: nodes that ran
    /// faster than their cost estimate get a positive signal (and a plateau
    /// when clearly faster), slower nodes a negative one, and unexecuted
    /// nodes simply decay.
    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
        self.update_count += 1;
        for (&node_id, &exec_time) in execution_times {
            let node = match dag.get_node(node_id) {
                Some(n) => n,
                None => continue,
            };
            let expected_time = node.estimated_cost;
            // Ratio < 1.0 means faster than the estimate; the floor on the
            // denominator avoids division by zero for zero-cost estimates.
            let time_ratio = exec_time / expected_time.max(0.001);
            let reward = if time_ratio < 1.0 {
                // Faster than expected - positive signal
                1.0 - time_ratio as f32
            } else {
                // Slower than expected - negative signal, capped at -1.0
                -(time_ratio as f32 - 1.0).min(1.0)
            };
            self.update_eligibility(node_id, reward);
            // Clear wins (>30% faster than estimated) open a plateau window.
            if reward > 0.3 {
                self.trigger_plateau(node_id);
            }
        }
        // Nodes that were not executed only decay (zero signal).
        let executed: std::collections::HashSet<usize> =
            execution_times.keys().copied().collect();
        for node_id in 0..dag.node_count() {
            if !executed.contains(&node_id) {
                self.update_eligibility(node_id, 0.0);
            }
        }
    }
    /// Forget all traces, plateau windows and the update counter.
    fn reset(&mut self) {
        self.eligibility_traces.clear();
        self.last_plateau.clear();
        self.update_count = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    use std::thread::sleep;
    use std::time::Duration;
    /// A positive signal must leave a strictly positive trace, and repeated
    /// signals must keep it positive (decay never drives it below zero).
    #[test]
    fn test_eligibility_update() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);
        attention.update_eligibility(0, 0.5);
        assert!(attention.eligibility_traces.get(&0).unwrap() > &0.0);
        attention.update_eligibility(0, 0.5);
        assert!(attention.eligibility_traces.get(&0).unwrap() > &0.0);
    }
    /// A plateau must be active immediately after triggering and expired
    /// after plateau_duration_ms has elapsed.
    /// NOTE(review): timing-based — the 150 ms sleep makes this test slow
    /// and potentially flaky on a heavily loaded machine.
    #[test]
    fn test_plateau_state() {
        let mut config = TemporalBTSPConfig::default();
        config.plateau_duration_ms = 100;
        let mut attention = TemporalBTSPAttention::new(config);
        attention.trigger_plateau(0);
        assert!(attention.is_plateau(0));
        sleep(Duration::from_millis(150));
        assert!(!attention.is_plateau(0));
    }
    /// End-to-end: forward, feed execution feedback, forward again; the
    /// fast node (0) must end up with a positive eligibility trace.
    #[test]
    fn test_temporal_attention() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);
        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = 10.0;
            dag.add_node(node);
        }
        // Initial forward pass
        let result1 = attention.forward(&dag).unwrap();
        assert_eq!(result1.scores.len(), 3);
        // Simulate execution feedback
        let mut exec_times = HashMap::new();
        exec_times.insert(0, 5.0); // Faster than expected
        exec_times.insert(1, 15.0); // Slower than expected
        attention.update(&dag, &exec_times);
        // Second forward pass should show different attention
        let result2 = attention.forward(&dag).unwrap();
        assert_eq!(result2.scores.len(), 3);
        // Node 0 should have higher attention due to positive feedback
        assert!(attention.eligibility_traces.get(&0).unwrap() > &0.0);
    }
}

View File

@@ -0,0 +1,109 @@
//! Topological Attention: Respects DAG ordering with depth-based decay
use super::{AttentionError, AttentionScores, DagAttention};
use crate::dag::QueryDag;
use std::collections::HashMap;
/// Configuration for [`TopologicalAttention`].
#[derive(Debug, Clone)]
pub struct TopologicalConfig {
    /// Per-level attention decay (default 0.9)
    pub decay_factor: f32,
    /// Maximum depth considered (default 10)
    pub max_depth: usize,
}
impl Default for TopologicalConfig {
    fn default() -> Self {
        TopologicalConfig {
            decay_factor: 0.9,
            max_depth: 10,
        }
    }
}
/// Attention mechanism that scores nodes by their topological depth.
pub struct TopologicalAttention {
    config: TopologicalConfig,
}
impl TopologicalAttention {
    /// Build with an explicit configuration.
    pub fn new(config: TopologicalConfig) -> Self {
        Self { config }
    }
    /// Build with [`TopologicalConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(TopologicalConfig::default())
    }
}
impl DagAttention for TopologicalAttention {
    /// Score each node by normalized depth: `decay_factor^(1 - depth/max)`,
    /// so attention shrinks by `decay_factor` toward the shallowest nodes,
    /// then the map is normalized to sum to 1.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.node_count() == 0 {
            return Err(AttentionError::EmptyDag);
        }
        let depths = dag.compute_depths();
        // Normalization denominator, at least 1 to avoid division by zero.
        let depth_span = depths.values().max().copied().unwrap_or(0).max(1) as f32;
        let mut scores = HashMap::new();
        let mut total = 0.0f32;
        for (&node_id, &depth) in &depths {
            let normalized = depth as f32 / depth_span;
            let weight = self.config.decay_factor.powf(1.0 - normalized);
            scores.insert(node_id, weight);
            total += weight;
        }
        // Turn raw weights into a probability distribution.
        if total > 0.0 {
            for weight in scores.values_mut() {
                *weight /= total;
            }
        }
        Ok(scores)
    }
    /// Static mechanism: execution feedback is ignored.
    fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {}
    fn name(&self) -> &'static str {
        "topological"
    }
    fn complexity(&self) -> &'static str {
        "O(n)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    /// On a three-node chain, the scores must form a valid probability
    /// distribution (sum to 1, each in [0, 1]).
    #[test]
    fn test_topological_attention() {
        let mut dag = QueryDag::new();
        // Create a simple DAG: 0 -> 1 -> 2
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "users").with_estimates(100.0, 1.0));
        let id1 = dag.add_node(OperatorNode::filter(0, "age > 18").with_estimates(50.0, 1.0));
        let id2 = dag
            .add_node(OperatorNode::project(0, vec!["name".to_string()]).with_estimates(50.0, 1.0));
        dag.add_edge(id0, id1).unwrap();
        dag.add_edge(id1, id2).unwrap();
        let attention = TopologicalAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();
        // Check that scores sum to ~1.0
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        // All scores should be in [0, 1]
        for &score in scores.values() {
            assert!(score >= 0.0 && score <= 1.0);
        }
    }
}

View File

@@ -0,0 +1,75 @@
//! DagAttention trait definition for pluggable attention mechanisms
use crate::dag::QueryDag;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Attention scores for each node in the DAG
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionScores {
    /// Attention score for each node (0.0 to 1.0)
    pub scores: Vec<f32>,
    /// Optional attention weights between nodes (adjacency-like)
    pub edge_weights: Option<Vec<Vec<f32>>>,
    /// Metadata for debugging
    pub metadata: HashMap<String, String>,
}
impl AttentionScores {
    /// Wrap per-node scores with no edge weights and empty metadata.
    pub fn new(scores: Vec<f32>) -> Self {
        AttentionScores {
            scores,
            edge_weights: None,
            metadata: HashMap::new(),
        }
    }
    /// Builder-style: attach pairwise edge weights.
    pub fn with_edge_weights(mut self, weights: Vec<Vec<f32>>) -> Self {
        self.edge_weights = Some(weights);
        self
    }
    /// Builder-style: record one metadata entry.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
/// Errors that can occur during attention computation
///
/// Display strings come from the `thiserror` derive attributes below.
#[derive(Debug, Error)]
pub enum AttentionError {
    // Structural precondition failed (e.g. the DAG was empty).
    #[error("Invalid DAG structure: {0}")]
    InvalidDag(String),
    // An expected dimension did not match the actual one.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },
    // The computation itself failed; the payload describes why.
    #[error("Computation failed: {0}")]
    ComputationFailed(String),
    // The mechanism was misconfigured (or had nothing to run).
    #[error("Configuration error: {0}")]
    ConfigError(String),
}
/// Trait for DAG attention mechanisms
///
/// Implementors are `Send + Sync` so they can be boxed and shared, e.g. as
/// `Box<dyn DagAttentionMechanism>` inside a selector.
pub trait DagAttentionMechanism: Send + Sync {
    /// Compute attention scores for the given DAG
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
    /// Get the mechanism name
    fn name(&self) -> &'static str;
    /// Get computational complexity as a string
    fn complexity(&self) -> &'static str;
    /// Optional: Update internal state based on execution feedback
    /// (a map of node id -> observed execution time).
    fn update(&mut self, _dag: &QueryDag, _execution_times: &HashMap<usize, f64>) {
        // Default: no-op
    }
    /// Optional: Reset internal state
    fn reset(&mut self) {
        // Default: no-op
    }
}

View File

@@ -0,0 +1,53 @@
//! Core traits and types for DAG attention mechanisms
use crate::dag::QueryDag;
use std::collections::HashMap;
/// Attention scores for DAG nodes, keyed by node id.
pub type AttentionScores = HashMap<usize, f32>;
/// Configuration for attention mechanisms
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Whether to normalize scores
    pub normalize: bool,
    /// Softmax temperature
    pub temperature: f32,
    /// Dropout rate
    pub dropout: f32,
}
impl Default for AttentionConfig {
    /// Defaults: normalization on, unit temperature, no dropout.
    fn default() -> Self {
        AttentionConfig {
            normalize: true,
            temperature: 1.0,
            dropout: 0.0,
        }
    }
}
/// Errors from attention computation
///
/// Display strings come from the `thiserror` derive attributes below.
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
    // The DAG contains no nodes.
    #[error("Empty DAG")]
    EmptyDag,
    // The graph is not acyclic.
    #[error("Cycle detected in DAG")]
    CycleDetected,
    // The referenced node id does not exist in the DAG.
    #[error("Node {0} not found")]
    NodeNotFound(usize),
    // Any other failure; the payload describes it.
    #[error("Computation failed: {0}")]
    ComputationFailed(String),
}
/// Trait for DAG attention mechanisms
///
/// Scores are a sparse map (`HashMap<usize, f32>`) here, unlike the dense
/// `Vec<f32>`-based variant in `trait_def`.
pub trait DagAttention: Send + Sync {
    /// Compute attention scores for all nodes
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
    /// Update internal state after execution feedback
    /// (a map of node id -> observed execution time).
    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>);
    /// Get mechanism name
    fn name(&self) -> &'static str;
    /// Get computational complexity description
    fn complexity(&self) -> &'static str;
}