Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
89
vendor/ruvector/crates/ruvector-dag/src/attention/IMPLEMENTATION_NOTES.md
vendored
Normal file
89
vendor/ruvector/crates/ruvector-dag/src/attention/IMPLEMENTATION_NOTES.md
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
# Advanced Attention Mechanisms - Implementation Notes
|
||||
|
||||
## Agent #3 Implementation Summary
|
||||
|
||||
This implementation provides 5 advanced attention mechanisms for DAG query optimization:
|
||||
|
||||
### 1. Hierarchical Lorentz Attention (`hierarchical_lorentz.rs`)
|
||||
- **Lines**: 274
|
||||
- **Complexity**: O(n²·d)
|
||||
- Uses hyperbolic geometry (Lorentz model) to embed DAG nodes
|
||||
- Deeper nodes in hierarchy receive higher attention
|
||||
- Implements Lorentz inner product and distance metrics
|
||||
|
||||
### 2. Parallel Branch Attention (`parallel_branch.rs`)
|
||||
- **Lines**: 291
|
||||
- **Complexity**: O(n² + b·n)
|
||||
- Detects and coordinates parallel execution branches
|
||||
- Balances workload across branches
|
||||
- Applies synchronization penalties
|
||||
|
||||
### 3. Temporal BTSP Attention (`temporal_btsp.rs`)
|
||||
- **Lines**: 291
|
||||
- **Complexity**: O(n + t)
|
||||
- Behavioral Timescale Synaptic Plasticity
|
||||
- Uses eligibility traces for temporal learning
|
||||
- Implements plateau state boosting
|
||||
|
||||
### 4. Attention Selector (`selector.rs`)
|
||||
- **Lines**: 281
|
||||
- **Complexity**: O(1) for selection
|
||||
- UCB1 bandit algorithm for mechanism selection
|
||||
- Tracks rewards and counts for each mechanism
|
||||
- Balances exploration vs exploitation
|
||||
|
||||
### 5. Attention Cache (`cache.rs`)
|
||||
- **Lines**: 316
|
||||
- **Complexity**: O(1) for get/insert
|
||||
- LRU eviction policy
|
||||
- TTL support for entries
|
||||
- Hash-based DAG fingerprinting
|
||||
|
||||
## Trait Definition
|
||||
|
||||
**`trait_def.rs`** (75 lines):
|
||||
- Defines `DagAttentionMechanism` trait
|
||||
- `AttentionScores` structure with edge weights
|
||||
- `AttentionError` for error handling
|
||||
|
||||
## Integration
|
||||
|
||||
All mechanisms are exported from `mod.rs` with appropriate type aliases to avoid conflicts with existing attention system.
|
||||
|
||||
## Tests
|
||||
|
||||
Each mechanism includes comprehensive unit tests:
|
||||
- Hierarchical Lorentz: 3 tests
|
||||
- Parallel Branch: 2 tests
|
||||
- Temporal BTSP: 3 tests
|
||||
- Selector: 4 tests
|
||||
- Cache: 7 tests
|
||||
|
||||
Total: 19 new test functions
|
||||
|
||||
## Performance Targets
|
||||
|
||||
| Mechanism | Target | Notes |
|
||||
|-----------|--------|-------|
|
||||
| HierarchicalLorentz | <150μs | For 100 nodes |
|
||||
| ParallelBranch | <100μs | For 100 nodes |
|
||||
| TemporalBTSP | <120μs | For 100 nodes |
|
||||
| Selector.select() | <1μs | UCB1 computation |
|
||||
| Cache.get() | <1μs | LRU lookup |
|
||||
|
||||
## Compatibility Notes
|
||||
|
||||
The implementation is designed to work with the updated QueryDag API that uses:
|
||||
- HashMap-based node storage
|
||||
- `estimated_cost` and `estimated_rows` instead of `cost` and `selectivity`
|
||||
- `children(id)` and `parents(id)` methods
|
||||
- No direct edges iterator
|
||||
|
||||
Some test fixtures may need updates to match the current OperatorNode structure.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add SIMD optimizations for hyperbolic distance calculations
|
||||
- Implement GPU acceleration for large DAGs
|
||||
- Add more sophisticated caching strategies
|
||||
- Integrate with existing TopologicalAttention, CausalConeAttention, etc.
|
||||
322
vendor/ruvector/crates/ruvector-dag/src/attention/cache.rs
vendored
Normal file
322
vendor/ruvector/crates/ruvector-dag/src/attention/cache.rs
vendored
Normal file
@@ -0,0 +1,322 @@
|
||||
//! Attention Cache: LRU cache for computed attention scores
|
||||
//!
|
||||
//! Caches attention scores to avoid redundant computation for identical DAGs.
|
||||
//! Uses LRU eviction policy to manage memory usage.
|
||||
|
||||
use super::trait_def::AttentionScores;
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::{hash_map::DefaultHasher, HashMap};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Tuning knobs for `AttentionCache`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries held before LRU eviction kicks in.
    pub capacity: usize,
    /// Optional time-to-live for entries; `None` disables expiry entirely.
    pub ttl: Option<Duration>,
}

impl Default for CacheConfig {
    /// Defaults: 1000 entries with a 5-minute TTL.
    fn default() -> Self {
        CacheConfig {
            capacity: 1000,
            ttl: Some(Duration::from_secs(300)),
        }
    }
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CacheEntry {
|
||||
scores: AttentionScores,
|
||||
timestamp: Instant,
|
||||
access_count: usize,
|
||||
}
|
||||
|
||||
pub struct AttentionCache {
|
||||
config: CacheConfig,
|
||||
cache: HashMap<u64, CacheEntry>,
|
||||
access_order: Vec<u64>,
|
||||
hits: usize,
|
||||
misses: usize,
|
||||
}
|
||||
|
||||
impl AttentionCache {
|
||||
pub fn new(config: CacheConfig) -> Self {
|
||||
Self {
|
||||
cache: HashMap::with_capacity(config.capacity),
|
||||
access_order: Vec::with_capacity(config.capacity),
|
||||
config,
|
||||
hits: 0,
|
||||
misses: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash a DAG for cache key
|
||||
fn hash_dag(dag: &QueryDag, mechanism: &str) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
|
||||
// Hash mechanism name
|
||||
mechanism.hash(&mut hasher);
|
||||
|
||||
// Hash number of nodes
|
||||
dag.node_count().hash(&mut hasher);
|
||||
|
||||
// Hash edges structure
|
||||
let mut edge_list: Vec<(usize, usize)> = Vec::new();
|
||||
for node_id in dag.node_ids() {
|
||||
for &child in dag.children(node_id) {
|
||||
edge_list.push((node_id, child));
|
||||
}
|
||||
}
|
||||
edge_list.sort_unstable();
|
||||
|
||||
for (from, to) in edge_list {
|
||||
from.hash(&mut hasher);
|
||||
to.hash(&mut hasher);
|
||||
}
|
||||
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
/// Check if entry is expired
|
||||
fn is_expired(&self, entry: &CacheEntry) -> bool {
|
||||
if let Some(ttl) = self.config.ttl {
|
||||
entry.timestamp.elapsed() > ttl
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cached scores for a DAG and mechanism
|
||||
pub fn get(&mut self, dag: &QueryDag, mechanism: &str) -> Option<AttentionScores> {
|
||||
let key = Self::hash_dag(dag, mechanism);
|
||||
|
||||
// Check if key exists and is not expired
|
||||
let is_expired = self
|
||||
.cache
|
||||
.get(&key)
|
||||
.map(|entry| self.is_expired(entry))
|
||||
.unwrap_or(true);
|
||||
|
||||
if is_expired {
|
||||
self.cache.remove(&key);
|
||||
self.access_order.retain(|&k| k != key);
|
||||
self.misses += 1;
|
||||
return None;
|
||||
}
|
||||
|
||||
// Update access and return clone
|
||||
if let Some(entry) = self.cache.get_mut(&key) {
|
||||
// Update access order (move to end = most recently used)
|
||||
self.access_order.retain(|&k| k != key);
|
||||
self.access_order.push(key);
|
||||
entry.access_count += 1;
|
||||
self.hits += 1;
|
||||
|
||||
Some(entry.scores.clone())
|
||||
} else {
|
||||
self.misses += 1;
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert scores into cache
|
||||
pub fn insert(&mut self, dag: &QueryDag, mechanism: &str, scores: AttentionScores) {
|
||||
let key = Self::hash_dag(dag, mechanism);
|
||||
|
||||
// Evict if at capacity
|
||||
while self.cache.len() >= self.config.capacity && !self.access_order.is_empty() {
|
||||
if let Some(oldest) = self.access_order.first().copied() {
|
||||
self.cache.remove(&oldest);
|
||||
self.access_order.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
let entry = CacheEntry {
|
||||
scores,
|
||||
timestamp: Instant::now(),
|
||||
access_count: 0,
|
||||
};
|
||||
|
||||
self.cache.insert(key, entry);
|
||||
self.access_order.push(key);
|
||||
}
|
||||
|
||||
/// Clear all entries
|
||||
pub fn clear(&mut self) {
|
||||
self.cache.clear();
|
||||
self.access_order.clear();
|
||||
self.hits = 0;
|
||||
self.misses = 0;
|
||||
}
|
||||
|
||||
/// Remove expired entries
|
||||
pub fn evict_expired(&mut self) {
|
||||
let expired_keys: Vec<u64> = self
|
||||
.cache
|
||||
.iter()
|
||||
.filter(|(_, entry)| self.is_expired(entry))
|
||||
.map(|(k, _)| *k)
|
||||
.collect();
|
||||
|
||||
for key in expired_keys {
|
||||
self.cache.remove(&key);
|
||||
self.access_order.retain(|&k| k != key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub fn stats(&self) -> CacheStats {
|
||||
CacheStats {
|
||||
size: self.cache.len(),
|
||||
capacity: self.config.capacity,
|
||||
hits: self.hits,
|
||||
misses: self.misses,
|
||||
hit_rate: if self.hits + self.misses > 0 {
|
||||
self.hits as f64 / (self.hits + self.misses) as f64
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Get entry with most accesses
|
||||
pub fn most_accessed(&self) -> Option<(&u64, usize)> {
|
||||
self.cache
|
||||
.iter()
|
||||
.max_by_key(|(_, entry)| entry.access_count)
|
||||
.map(|(k, entry)| (k, entry.access_count))
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot of cache occupancy and lookup counters, as returned by
/// `AttentionCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Current number of live entries.
    pub size: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
    /// Lookups served from the cache.
    pub hits: usize,
    /// Lookups that found nothing (or only an expired entry).
    pub misses: usize,
    /// hits / (hits + misses); 0.0 when no lookups have happened yet.
    pub hit_rate: f64,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// `n` scan nodes with costs 1..=n and a single 0 -> 1 edge when n > 1.
    fn create_test_dag(n: usize) -> QueryDag {
        let mut dag = QueryDag::new();
        for i in 0..n {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = (i + 1) as f64;
            dag.add_node(node);
        }
        if n > 1 {
            let _ = dag.add_edge(0, 1);
        }
        dag
    }

    #[test]
    fn test_cache_insert_and_get() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(3);

        let scores = AttentionScores::new(vec![0.5, 0.3, 0.2]);
        let expected = scores.scores.clone();
        cache.insert(&dag, "test_mechanism", scores);

        let hit = cache.get(&dag, "test_mechanism").unwrap();
        assert_eq!(hit.scores, expected);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        assert!(cache.get(&create_test_dag(3), "nonexistent").is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 2,
            ttl: None,
        });

        let small = create_test_dag(1);
        let medium = create_test_dag(2);
        let large = create_test_dag(3);

        cache.insert(&small, "mech", AttentionScores::new(vec![0.5]));
        cache.insert(&medium, "mech", AttentionScores::new(vec![0.3, 0.7]));
        cache.insert(&large, "mech", AttentionScores::new(vec![0.2, 0.3, 0.5]));

        // Capacity is 2, so the oldest entry (small) must have been evicted.
        assert!(cache.get(&small, "mech").is_none());
        assert!(cache.get(&medium, "mech").is_some());
        assert!(cache.get(&large, "mech").is_some());
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = AttentionCache::new(CacheConfig::default());
        let dag = create_test_dag(2);

        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));

        cache.get(&dag, "mech"); // hit
        cache.get(&dag, "nonexistent"); // miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_ttl_expiration() {
        let mut cache = AttentionCache::new(CacheConfig {
            capacity: 100,
            ttl: Some(Duration::from_millis(50)),
        });

        let dag = create_test_dag(2);
        cache.insert(&dag, "mech", AttentionScores::new(vec![0.5, 0.5]));

        // A fresh entry is served...
        assert!(cache.get(&dag, "mech").is_some());

        // ...but not after the TTL elapses.
        std::thread::sleep(Duration::from_millis(60));
        assert!(cache.get(&dag, "mech").is_none());
    }

    #[test]
    fn test_hash_consistency() {
        let dag = create_test_dag(3);
        assert_eq!(
            AttentionCache::hash_dag(&dag, "mechanism"),
            AttentionCache::hash_dag(&dag, "mechanism")
        );
    }

    #[test]
    fn test_hash_different_mechanisms() {
        let dag = create_test_dag(3);
        assert_ne!(
            AttentionCache::hash_dag(&dag, "mechanism1"),
            AttentionCache::hash_dag(&dag, "mechanism2")
        );
    }
}
|
||||
127
vendor/ruvector/crates/ruvector-dag/src/attention/causal_cone.rs
vendored
Normal file
127
vendor/ruvector/crates/ruvector-dag/src/attention/causal_cone.rs
vendored
Normal file
@@ -0,0 +1,127 @@
|
||||
//! Causal Cone Attention: Focuses on ancestors with temporal discount
|
||||
|
||||
use super::{AttentionError, AttentionScores, DagAttention};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Tuning knobs for `CausalConeAttention`.
#[derive(Debug, Clone)]
pub struct CausalConeConfig {
    /// Time window in milliseconds. NOTE(review): not read by `forward`
    /// in this file — presumably reserved for future temporal gating.
    pub time_window_ms: u64,
    /// Per-depth exponential discount applied to deeper nodes.
    pub future_discount: f32,
    /// Weight contributed by each ancestor when scoring causal influence.
    pub ancestor_weight: f32,
}

impl Default for CausalConeConfig {
    fn default() -> Self {
        CausalConeConfig {
            time_window_ms: 1000,
            future_discount: 0.8,
            ancestor_weight: 0.9,
        }
    }
}
|
||||
|
||||
pub struct CausalConeAttention {
|
||||
config: CausalConeConfig,
|
||||
}
|
||||
|
||||
impl CausalConeAttention {
|
||||
pub fn new(config: CausalConeConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(CausalConeConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttention for CausalConeAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.node_count() == 0 {
|
||||
return Err(AttentionError::EmptyDag);
|
||||
}
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
// For each node, compute attention based on:
|
||||
// 1. Number of ancestors (causal influence)
|
||||
// 2. Distance from node (temporal decay)
|
||||
let node_ids: Vec<usize> = (0..dag.node_count()).collect();
|
||||
for node_id in node_ids {
|
||||
if dag.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ancestors = dag.ancestors(node_id);
|
||||
let ancestor_count = ancestors.len();
|
||||
|
||||
// Base score is proportional to causal influence (number of ancestors)
|
||||
let mut score = 1.0 + (ancestor_count as f32 * self.config.ancestor_weight);
|
||||
|
||||
// Apply temporal discount based on depth
|
||||
let depths = dag.compute_depths();
|
||||
if let Some(&depth) = depths.get(&node_id) {
|
||||
score *= self.config.future_discount.powi(depth as i32);
|
||||
}
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
// Normalize to sum to 1
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
|
||||
fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {
|
||||
// Could update temporal discount based on actual execution times
|
||||
// For now, static configuration
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"causal_cone"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n^2)"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    #[test]
    fn test_causal_cone_attention() {
        // Two scans feeding a join feeding a projection.
        let mut dag = QueryDag::new();
        let scan_a = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let scan_b = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let join = dag.add_node(OperatorNode::hash_join(0, "id"));
        let project = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));

        dag.add_edge(scan_a, join).unwrap();
        dag.add_edge(scan_b, join).unwrap();
        dag.add_edge(join, project).unwrap();

        let scores = CausalConeAttention::with_defaults().forward(&dag).unwrap();

        // Scores form a probability distribution.
        let total: f32 = scores.values().sum();
        assert!((total - 1.0).abs() < 1e-5);

        // Every individual score lies in [0, 1].
        assert!(scores.values().all(|&s| (0.0..=1.0).contains(&s)));
    }
}
|
||||
198
vendor/ruvector/crates/ruvector-dag/src/attention/critical_path.rs
vendored
Normal file
198
vendor/ruvector/crates/ruvector-dag/src/attention/critical_path.rs
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
//! Critical Path Attention: Focuses on bottleneck nodes
|
||||
|
||||
use super::{AttentionError, AttentionScores, DagAttention};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Tuning knobs for `CriticalPathAttention`.
#[derive(Debug, Clone)]
pub struct CriticalPathConfig {
    /// Base weight for nodes on the critical path (off-path nodes get 1.0).
    pub path_weight: f32,
    /// Extra weight per additional child on fan-out nodes.
    pub branch_penalty: f32,
}

impl Default for CriticalPathConfig {
    fn default() -> Self {
        CriticalPathConfig {
            path_weight: 2.0,
            branch_penalty: 0.5,
        }
    }
}
|
||||
|
||||
pub struct CriticalPathAttention {
|
||||
config: CriticalPathConfig,
|
||||
critical_path: Vec<usize>,
|
||||
}
|
||||
|
||||
impl CriticalPathAttention {
|
||||
pub fn new(config: CriticalPathConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
critical_path: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(CriticalPathConfig::default())
|
||||
}
|
||||
|
||||
/// Compute the critical path (longest path by cost)
|
||||
fn compute_critical_path(&self, dag: &QueryDag) -> Vec<usize> {
|
||||
let mut longest_path: HashMap<usize, (f64, Vec<usize>)> = HashMap::new();
|
||||
|
||||
// Initialize leaves
|
||||
for &leaf in &dag.leaves() {
|
||||
if let Some(node) = dag.get_node(leaf) {
|
||||
longest_path.insert(leaf, (node.estimated_cost, vec![leaf]));
|
||||
}
|
||||
}
|
||||
|
||||
// Process nodes in reverse topological order
|
||||
if let Ok(topo_order) = dag.topological_sort() {
|
||||
for &node_id in topo_order.iter().rev() {
|
||||
let node = match dag.get_node(node_id) {
|
||||
Some(n) => n,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut max_cost = node.estimated_cost;
|
||||
let mut max_path = vec![node_id];
|
||||
|
||||
// Check all children
|
||||
for &child in dag.children(node_id) {
|
||||
if let Some(&(child_cost, ref child_path)) = longest_path.get(&child) {
|
||||
let total_cost = node.estimated_cost + child_cost;
|
||||
if total_cost > max_cost {
|
||||
max_cost = total_cost;
|
||||
max_path = vec![node_id];
|
||||
max_path.extend(child_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
longest_path.insert(node_id, (max_cost, max_path));
|
||||
}
|
||||
}
|
||||
|
||||
// Find the path with maximum cost
|
||||
longest_path
|
||||
.into_iter()
|
||||
.max_by(|a, b| {
|
||||
a.1 .0
|
||||
.partial_cmp(&b.1 .0)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(_, (_, path))| path)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttention for CriticalPathAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.node_count() == 0 {
|
||||
return Err(AttentionError::EmptyDag);
|
||||
}
|
||||
|
||||
let critical = self.compute_critical_path(dag);
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
// Assign higher attention to nodes on critical path
|
||||
let node_ids: Vec<usize> = (0..dag.node_count()).collect();
|
||||
for node_id in node_ids {
|
||||
if dag.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let is_on_critical_path = critical.contains(&node_id);
|
||||
let num_children = dag.children(node_id).len();
|
||||
|
||||
let mut score = if is_on_critical_path {
|
||||
self.config.path_weight
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
// Apply branch penalty for nodes with many children (potential bottlenecks)
|
||||
if num_children > 1 {
|
||||
score *= 1.0 + (num_children as f32 - 1.0) * self.config.branch_penalty;
|
||||
}
|
||||
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
// Normalize to sum to 1
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
|
||||
fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
|
||||
// Recompute critical path based on actual execution times
|
||||
// For now, we use the static cost-based approach
|
||||
self.critical_path = self.compute_critical_path(dag);
|
||||
|
||||
// Could adjust path_weight based on execution time variance
|
||||
if !execution_times.is_empty() {
|
||||
let max_time = execution_times.values().fold(0.0f64, |a, &b| a.max(b));
|
||||
let avg_time: f64 =
|
||||
execution_times.values().sum::<f64>() / execution_times.len() as f64;
|
||||
|
||||
if max_time > 0.0 && avg_time > 0.0 {
|
||||
// Increase path weight if there's high variance
|
||||
let variance_ratio = max_time / avg_time;
|
||||
if variance_ratio > 2.0 {
|
||||
self.config.path_weight = (self.config.path_weight * 1.1).min(5.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"critical_path"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n + e)"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    #[test]
    fn test_critical_path_attention() {
        let mut dag = QueryDag::new();

        // Costs chosen so the scan -> join path dominates.
        let scan =
            dag.add_node(OperatorNode::seq_scan(0, "large_table").with_estimates(10000.0, 10.0));
        let filter =
            dag.add_node(OperatorNode::filter(0, "status = 'active'").with_estimates(1000.0, 1.0));
        let join = dag.add_node(OperatorNode::hash_join(0, "user_id").with_estimates(5000.0, 5.0));

        dag.add_edge(scan, join).unwrap();
        dag.add_edge(filter, join).unwrap();

        let attention = CriticalPathAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Scores form a probability distribution.
        let total: f32 = scores.values().sum();
        assert!((total - 1.0).abs() < 1e-5);

        // Every critical-path node received positive attention.
        for node_id in attention.compute_critical_path(&dag) {
            assert!(*scores.get(&node_id).unwrap() > 0.0);
        }
    }
}
|
||||
288
vendor/ruvector/crates/ruvector-dag/src/attention/hierarchical_lorentz.rs
vendored
Normal file
288
vendor/ruvector/crates/ruvector-dag/src/attention/hierarchical_lorentz.rs
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
//! Hierarchical Lorentz Attention: Hyperbolic geometry for tree-like structures
|
||||
//!
|
||||
//! This mechanism embeds DAG nodes in hyperbolic space using the Lorentz (hyperboloid) model,
|
||||
//! where hierarchical relationships are naturally represented by distance from the origin.
|
||||
|
||||
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Tuning knobs for `HierarchicalLorentzAttention`.
#[derive(Debug, Clone)]
pub struct HierarchicalLorentzConfig {
    /// Hyperbolic curvature; -1.0 is the standard unit hyperboloid
    /// (Lorentz model).
    pub curvature: f32,
    /// Multiplier applied to the time-like coordinate after projection.
    pub time_scale: f32,
    /// Dimensionality of the Euclidean embedding before lifting.
    pub dim: usize,
    /// Softmax temperature; smaller values sharpen the distribution.
    pub temperature: f32,
}

impl Default for HierarchicalLorentzConfig {
    fn default() -> Self {
        HierarchicalLorentzConfig {
            curvature: -1.0,
            time_scale: 1.0,
            dim: 64,
            temperature: 0.1,
        }
    }
}
|
||||
|
||||
pub struct HierarchicalLorentzAttention {
|
||||
config: HierarchicalLorentzConfig,
|
||||
}
|
||||
|
||||
impl HierarchicalLorentzAttention {
|
||||
pub fn new(config: HierarchicalLorentzConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Lorentz inner product: -x0*y0 + x1*y1 + ... + xn*yn
|
||||
fn lorentz_inner(&self, x: &[f32], y: &[f32]) -> f32 {
|
||||
if x.is_empty() || y.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
-x[0] * y[0] + x[1..].iter().zip(&y[1..]).map(|(a, b)| a * b).sum::<f32>()
|
||||
}
|
||||
|
||||
/// Lorentz distance in hyperboloid model
|
||||
fn lorentz_distance(&self, x: &[f32], y: &[f32]) -> f32 {
|
||||
let inner = self.lorentz_inner(x, y);
|
||||
// Clamp to avoid numerical issues with acosh
|
||||
let clamped = (-inner).max(1.0);
|
||||
clamped.acosh() * self.config.curvature.abs()
|
||||
}
|
||||
|
||||
/// Project to hyperboloid: [sqrt(1 + ||x||^2), x1, x2, ..., xn]
|
||||
fn project_to_hyperboloid(&self, x: &[f32]) -> Vec<f32> {
|
||||
let spatial_norm_sq: f32 = x.iter().map(|v| v * v).sum();
|
||||
let time_coord = (1.0 + spatial_norm_sq).sqrt();
|
||||
|
||||
let mut result = Vec::with_capacity(x.len() + 1);
|
||||
result.push(time_coord * self.config.time_scale);
|
||||
result.extend_from_slice(x);
|
||||
result
|
||||
}
|
||||
|
||||
/// Compute hierarchical depth for each node
|
||||
fn compute_depths(&self, dag: &QueryDag) -> Vec<usize> {
|
||||
let n = dag.node_count();
|
||||
let mut depths = vec![0; n];
|
||||
let mut adj_list: HashMap<usize, Vec<usize>> = HashMap::new();
|
||||
|
||||
// Build adjacency list
|
||||
for node_id in dag.node_ids() {
|
||||
for &child in dag.children(node_id) {
|
||||
adj_list.entry(node_id).or_insert_with(Vec::new).push(child);
|
||||
}
|
||||
}
|
||||
|
||||
// Find root nodes (nodes with no incoming edges)
|
||||
let mut has_incoming = vec![false; n];
|
||||
for node_id in dag.node_ids() {
|
||||
for &child in dag.children(node_id) {
|
||||
if child < n {
|
||||
has_incoming[child] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BFS to compute depths
|
||||
let mut queue: Vec<usize> = (0..n).filter(|&i| !has_incoming[i]).collect();
|
||||
let mut visited = vec![false; n];
|
||||
|
||||
while let Some(node) = queue.pop() {
|
||||
if visited[node] {
|
||||
continue;
|
||||
}
|
||||
visited[node] = true;
|
||||
|
||||
if let Some(children) = adj_list.get(&node) {
|
||||
for &child in children {
|
||||
if child < n {
|
||||
depths[child] = depths[node] + 1;
|
||||
queue.push(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
depths
|
||||
}
|
||||
|
||||
/// Embed node in hyperbolic space based on depth and position
|
||||
fn embed_node(&self, node_id: usize, depth: usize, total_nodes: usize) -> Vec<f32> {
|
||||
let dim = self.config.dim;
|
||||
let mut embedding = vec![0.0; dim];
|
||||
|
||||
// Use depth to determine radial distance from origin
|
||||
let radial = (depth as f32 * 0.5).tanh();
|
||||
|
||||
// Angular position based on node ID
|
||||
let angle = 2.0 * std::f32::consts::PI * (node_id as f32) / (total_nodes as f32).max(1.0);
|
||||
|
||||
// Spherical coordinates in hyperbolic space
|
||||
embedding[0] = radial * angle.cos();
|
||||
if dim > 1 {
|
||||
embedding[1] = radial * angle.sin();
|
||||
}
|
||||
|
||||
// Add noise to remaining dimensions for better separation
|
||||
for i in 2..dim {
|
||||
embedding[i] = 0.1 * ((node_id + i) as f32).sin();
|
||||
}
|
||||
|
||||
embedding
|
||||
}
|
||||
|
||||
/// Compute attention using hyperbolic distances
|
||||
fn compute_attention_from_distances(&self, distances: &[f32]) -> Vec<f32> {
|
||||
if distances.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Convert distances to attention scores using softmax
|
||||
// Closer nodes (smaller distance) should have higher attention
|
||||
let neg_distances: Vec<f32> = distances
|
||||
.iter()
|
||||
.map(|&d| -d / self.config.temperature)
|
||||
.collect();
|
||||
|
||||
// Softmax
|
||||
let max_val = neg_distances
|
||||
.iter()
|
||||
.cloned()
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = neg_distances.iter().map(|&x| (x - max_val).exp()).sum();
|
||||
|
||||
if exp_sum == 0.0 {
|
||||
// Uniform distribution if all distances are too large
|
||||
return vec![1.0 / distances.len() as f32; distances.len()];
|
||||
}
|
||||
|
||||
neg_distances
|
||||
.iter()
|
||||
.map(|&x| (x - max_val).exp() / exp_sum)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttentionMechanism for HierarchicalLorentzAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.node_count() == 0 {
|
||||
return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
|
||||
}
|
||||
|
||||
let n = dag.node_count();
|
||||
|
||||
// Step 1: Compute hierarchical depths
|
||||
let depths = self.compute_depths(dag);
|
||||
|
||||
// Step 2: Embed each node in Euclidean space
|
||||
let euclidean_embeddings: Vec<Vec<f32>> =
|
||||
(0..n).map(|i| self.embed_node(i, depths[i], n)).collect();
|
||||
|
||||
// Step 3: Project to hyperboloid
|
||||
let hyperbolic_embeddings: Vec<Vec<f32>> = euclidean_embeddings
|
||||
.iter()
|
||||
.map(|emb| self.project_to_hyperboloid(emb))
|
||||
.collect();
|
||||
|
||||
// Step 4: Compute pairwise distances from a reference point (origin-like)
|
||||
let origin = self.project_to_hyperboloid(&vec![0.0; self.config.dim]);
|
||||
let distances: Vec<f32> = hyperbolic_embeddings
|
||||
.iter()
|
||||
.map(|emb| self.lorentz_distance(emb, &origin))
|
||||
.collect();
|
||||
|
||||
// Step 5: Convert distances to attention scores
|
||||
let scores = self.compute_attention_from_distances(&distances);
|
||||
|
||||
// Step 6: Compute edge weights (optional)
|
||||
let mut edge_weights = vec![vec![0.0; n]; n];
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
let dist =
|
||||
self.lorentz_distance(&hyperbolic_embeddings[i], &hyperbolic_embeddings[j]);
|
||||
edge_weights[i][j] = (-dist / self.config.temperature).exp();
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = AttentionScores::new(scores)
|
||||
.with_edge_weights(edge_weights)
|
||||
.with_metadata("mechanism".to_string(), "hierarchical_lorentz".to_string());
|
||||
|
||||
result.metadata.insert(
|
||||
"avg_depth".to_string(),
|
||||
format!("{:.2}", depths.iter().sum::<usize>() as f32 / n as f32),
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"hierarchical_lorentz"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n²·d)"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Mechanism under test with default configuration.
    fn mechanism() -> HierarchicalLorentzAttention {
        HierarchicalLorentzAttention::new(HierarchicalLorentzConfig::default())
    }

    #[test]
    fn test_lorentz_distance() {
        let attn = mechanism();

        let a = vec![1.0, 0.5, 0.3];
        let b = vec![1.2, 0.6, 0.4];

        assert!(
            attn.lorentz_distance(&a, &b) >= 0.0,
            "Distance should be non-negative"
        );
    }

    #[test]
    fn test_project_to_hyperboloid() {
        let attn = mechanism();

        let lifted = attn.project_to_hyperboloid(&[0.5, 0.3, 0.2]);

        // Lift adds one time-like coordinate in front.
        assert_eq!(lifted.len(), 4);
        assert!(lifted[0] > 0.0, "Time coordinate should be positive");
    }

    #[test]
    fn test_hierarchical_attention() {
        let attn = mechanism();

        // Two-node chain: scan -> filter.
        let mut dag = QueryDag::new();

        let mut scan = OperatorNode::new(0, OperatorType::Scan);
        scan.estimated_cost = 1.0;
        dag.add_node(scan);

        let mut filter = OperatorNode::new(
            1,
            OperatorType::Filter {
                predicate: "x > 0".to_string(),
            },
        );
        filter.estimated_cost = 2.0;
        dag.add_node(filter);

        dag.add_edge(0, 1).unwrap();

        let result = attn.forward(&dag).unwrap();
        assert_eq!(result.scores.len(), 2);
        assert!((result.scores.iter().sum::<f32>() - 1.0).abs() < 0.01);
    }
}
|
||||
253
vendor/ruvector/crates/ruvector-dag/src/attention/mincut_gated.rs
vendored
Normal file
253
vendor/ruvector/crates/ruvector-dag/src/attention/mincut_gated.rs
vendored
Normal file
@@ -0,0 +1,253 @@
|
||||
//! MinCut Gated Attention: Gates attention by graph cut criticality
|
||||
|
||||
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
|
||||
/// How edge capacities are derived for the max-flow / min-cut computation.
#[derive(Debug, Clone)]
pub enum FlowCapacity {
    /// Every edge has capacity 1 (pure connectivity cut).
    UnitCapacity,
    /// Edge capacity = estimated cost of the edge's source node.
    CostBased,
    /// Edge capacity = estimated row count of the edge's source node.
    RowBased,
}

/// Configuration for [`MinCutGatedAttention`].
#[derive(Debug, Clone)]
pub struct MinCutConfig {
    /// Attention assigned to nodes NOT on the min-cut (cut nodes get 1.0,
    /// before normalization).
    pub gate_threshold: f32,
    /// Capacity model used when computing the cut.
    pub flow_capacity: FlowCapacity,
}

impl Default for MinCutConfig {
    fn default() -> Self {
        Self {
            gate_threshold: 0.5,
            flow_capacity: FlowCapacity::UnitCapacity,
        }
    }
}

/// Attention mechanism that boosts nodes lying on a min-cut of the DAG:
/// such nodes are structural bottlenecks that all flow must pass through.
pub struct MinCutGatedAttention {
    config: MinCutConfig,
}
|
||||
|
||||
impl MinCutGatedAttention {
|
||||
pub fn new(config: MinCutConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(MinCutConfig::default())
|
||||
}
|
||||
|
||||
/// Compute min-cut between leaves and root using Ford-Fulkerson
|
||||
fn compute_min_cut(&self, dag: &QueryDag) -> HashSet<usize> {
|
||||
let mut cut_nodes = HashSet::new();
|
||||
|
||||
// Build capacity matrix from the DAG structure
|
||||
let mut capacity: HashMap<(usize, usize), f64> = HashMap::new();
|
||||
for node_id in 0..dag.node_count() {
|
||||
if dag.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
for &child in dag.children(node_id) {
|
||||
let cap = match self.config.flow_capacity {
|
||||
FlowCapacity::UnitCapacity => 1.0,
|
||||
FlowCapacity::CostBased => dag
|
||||
.get_node(node_id)
|
||||
.map(|n| n.estimated_cost)
|
||||
.unwrap_or(1.0),
|
||||
FlowCapacity::RowBased => dag
|
||||
.get_node(node_id)
|
||||
.map(|n| n.estimated_rows)
|
||||
.unwrap_or(1.0),
|
||||
};
|
||||
capacity.insert((node_id, child), cap);
|
||||
}
|
||||
}
|
||||
|
||||
// Find source (root) and sink (any leaf)
|
||||
let source = match dag.root() {
|
||||
Some(root) => root,
|
||||
None => return cut_nodes,
|
||||
};
|
||||
|
||||
let leaves = dag.leaves();
|
||||
if leaves.is_empty() {
|
||||
return cut_nodes;
|
||||
}
|
||||
|
||||
// Use first leaf as sink
|
||||
let sink = leaves[0];
|
||||
|
||||
// Ford-Fulkerson to find max flow
|
||||
let mut residual = capacity.clone();
|
||||
#[allow(unused_variables, unused_assignments)]
|
||||
let mut total_flow = 0.0;
|
||||
|
||||
loop {
|
||||
// BFS to find augmenting path
|
||||
let mut parent: HashMap<usize, usize> = HashMap::new();
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
queue.push_back(source);
|
||||
visited.insert(source);
|
||||
|
||||
while let Some(u) = queue.pop_front() {
|
||||
if u == sink {
|
||||
break;
|
||||
}
|
||||
|
||||
for v in dag.children(u) {
|
||||
if !visited.contains(v) && residual.get(&(u, *v)).copied().unwrap_or(0.0) > 0.0
|
||||
{
|
||||
visited.insert(*v);
|
||||
parent.insert(*v, u);
|
||||
queue.push_back(*v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No augmenting path found
|
||||
if !parent.contains_key(&sink) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Find minimum capacity along the path
|
||||
let mut path_flow = f64::INFINITY;
|
||||
let mut v = sink;
|
||||
while v != source {
|
||||
let u = parent[&v];
|
||||
path_flow = path_flow.min(residual.get(&(u, v)).copied().unwrap_or(0.0));
|
||||
v = u;
|
||||
}
|
||||
|
||||
// Update residual capacities
|
||||
v = sink;
|
||||
while v != source {
|
||||
let u = parent[&v];
|
||||
*residual.entry((u, v)).or_insert(0.0) -= path_flow;
|
||||
*residual.entry((v, u)).or_insert(0.0) += path_flow;
|
||||
v = u;
|
||||
}
|
||||
|
||||
total_flow += path_flow;
|
||||
}
|
||||
|
||||
// Find nodes reachable from source in residual graph
|
||||
let mut reachable = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
queue.push_back(source);
|
||||
reachable.insert(source);
|
||||
|
||||
while let Some(u) = queue.pop_front() {
|
||||
for &v in dag.children(u) {
|
||||
if !reachable.contains(&v) && residual.get(&(u, v)).copied().unwrap_or(0.0) > 0.0 {
|
||||
reachable.insert(v);
|
||||
queue.push_back(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Nodes in the cut are those with edges crossing from reachable to non-reachable
|
||||
for node_id in 0..dag.node_count() {
|
||||
if dag.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
for &child in dag.children(node_id) {
|
||||
if reachable.contains(&node_id) && !reachable.contains(&child) {
|
||||
cut_nodes.insert(node_id);
|
||||
cut_nodes.insert(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cut_nodes
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttentionMechanism for MinCutGatedAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.node_count() == 0 {
|
||||
return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
|
||||
}
|
||||
|
||||
let cut_nodes = self.compute_min_cut(dag);
|
||||
let n = dag.node_count();
|
||||
let mut score_vec = vec![0.0; n];
|
||||
let mut total = 0.0f32;
|
||||
|
||||
// Gate attention based on whether node is in cut
|
||||
for node_id in 0..n {
|
||||
if dag.get_node(node_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let is_in_cut = cut_nodes.contains(&node_id);
|
||||
|
||||
let score = if is_in_cut {
|
||||
// Nodes in the cut are critical bottlenecks
|
||||
1.0
|
||||
} else {
|
||||
// Other nodes get reduced attention
|
||||
self.config.gate_threshold
|
||||
};
|
||||
|
||||
score_vec[node_id] = score;
|
||||
total += score;
|
||||
}
|
||||
|
||||
// Normalize to sum to 1
|
||||
if total > 0.0 {
|
||||
for score in score_vec.iter_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AttentionScores::new(score_vec))
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"mincut_gated"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n * e^2)"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// Forward pass over a diamond-shaped DAG with a single bottleneck node:
    /// the result must be a normalized distribution with all scores in [0, 1].
    #[test]
    fn test_mincut_gated_attention() {
        let mut dag = QueryDag::new();

        // Create a simple bottleneck DAG.
        // NOTE(review): every node is constructed with first argument 0 —
        // presumably that argument is not the node id (add_node returns the
        // assigned id); confirm against OperatorNode's constructor API.
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let id2 = dag.add_node(OperatorNode::hash_join(0, "id"));
        let id3 = dag.add_node(OperatorNode::filter(0, "status = 'active'"));
        let id4 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));

        // Create bottleneck at node id2: both scans feed the join, which
        // feeds both downstream operators.
        dag.add_edge(id0, id2).unwrap();
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        dag.add_edge(id2, id4).unwrap();

        let attention = MinCutGatedAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Check normalization: scores sum to 1.
        let sum: f32 = scores.scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // All scores should be in [0, 1].
        for &score in &scores.scores {
            assert!(score >= 0.0 && score <= 1.0);
        }
    }
}
|
||||
38
vendor/ruvector/crates/ruvector-dag/src/attention/mod.rs
vendored
Normal file
38
vendor/ruvector/crates/ruvector-dag/src/attention/mod.rs
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
//! DAG Attention Mechanisms
|
||||
//!
|
||||
//! This module provides graph-topology-aware attention mechanisms for DAG-based
|
||||
//! query optimization. Unlike traditional neural attention, these mechanisms
|
||||
//! leverage the structural properties of the DAG (topology, paths, cuts) to
|
||||
//! compute attention scores.
|
||||
|
||||
// Team 2 (Agent #2) - Base attention mechanisms
|
||||
mod causal_cone;
|
||||
mod critical_path;
|
||||
mod mincut_gated;
|
||||
mod topological;
|
||||
mod traits;
|
||||
|
||||
// Team 2 (Agent #3) - Advanced attention mechanisms
|
||||
mod cache;
|
||||
mod hierarchical_lorentz;
|
||||
mod parallel_branch;
|
||||
mod selector;
|
||||
mod temporal_btsp;
|
||||
mod trait_def;
|
||||
|
||||
// Export base mechanisms
|
||||
pub use causal_cone::{CausalConeAttention, CausalConeConfig};
|
||||
pub use critical_path::{CriticalPathAttention, CriticalPathConfig};
|
||||
pub use mincut_gated::{FlowCapacity, MinCutConfig, MinCutGatedAttention};
|
||||
pub use topological::{TopologicalAttention, TopologicalConfig};
|
||||
pub use traits::{AttentionConfig, AttentionError, AttentionScores, DagAttention};
|
||||
|
||||
// Export advanced mechanisms
|
||||
pub use cache::{AttentionCache, CacheConfig, CacheStats};
|
||||
pub use hierarchical_lorentz::{HierarchicalLorentzAttention, HierarchicalLorentzConfig};
|
||||
pub use parallel_branch::{ParallelBranchAttention, ParallelBranchConfig};
|
||||
pub use selector::{AttentionSelector, MechanismStats, SelectorConfig};
|
||||
pub use temporal_btsp::{TemporalBTSPAttention, TemporalBTSPConfig};
|
||||
pub use trait_def::{
|
||||
AttentionError as AttentionErrorV2, AttentionScores as AttentionScoresV2, DagAttentionMechanism,
|
||||
};
|
||||
303
vendor/ruvector/crates/ruvector-dag/src/attention/parallel_branch.rs
vendored
Normal file
303
vendor/ruvector/crates/ruvector-dag/src/attention/parallel_branch.rs
vendored
Normal file
@@ -0,0 +1,303 @@
|
||||
//! Parallel Branch Attention: Coordinates attention across parallel execution branches
|
||||
//!
|
||||
//! This mechanism identifies parallel branches in the DAG and distributes attention
|
||||
//! to balance workload and minimize synchronization overhead.
|
||||
|
||||
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Configuration for [`ParallelBranchAttention`].
#[derive(Debug, Clone)]
pub struct ParallelBranchConfig {
    /// Maximum number of parallel branches to consider
    pub max_branches: usize,
    /// Penalty for synchronization between branches
    pub sync_penalty: f32,
    /// Weight for branch balance in attention computation
    pub balance_weight: f32,
    /// Temperature for softmax
    pub temperature: f32,
}

impl Default for ParallelBranchConfig {
    fn default() -> Self {
        Self {
            max_branches: 8,
            sync_penalty: 0.2,
            balance_weight: 0.5,
            temperature: 0.1,
        }
    }
}

/// Attention mechanism that detects parallel execution branches in the DAG
/// and distributes attention to balance workload across them.
pub struct ParallelBranchAttention {
    config: ParallelBranchConfig,
}
|
||||
|
||||
impl ParallelBranchAttention {
|
||||
pub fn new(config: ParallelBranchConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Detect parallel branches (nodes with same parent, no edges between them)
|
||||
fn detect_branches(&self, dag: &QueryDag) -> Vec<Vec<usize>> {
|
||||
let n = dag.node_count();
|
||||
let mut children_of: HashMap<usize, Vec<usize>> = HashMap::new();
|
||||
let mut parents_of: HashMap<usize, Vec<usize>> = HashMap::new();
|
||||
|
||||
// Build parent-child relationships from adjacency
|
||||
for node_id in dag.node_ids() {
|
||||
let children = dag.children(node_id);
|
||||
if !children.is_empty() {
|
||||
for &child in children {
|
||||
children_of
|
||||
.entry(node_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(child);
|
||||
parents_of
|
||||
.entry(child)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut branches = Vec::new();
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
// For each node, check if its children form parallel branches
|
||||
for node_id in 0..n {
|
||||
if let Some(children) = children_of.get(&node_id) {
|
||||
if children.len() > 1 {
|
||||
// Check if children are truly parallel (no edges between them)
|
||||
let mut parallel_group = Vec::new();
|
||||
|
||||
for &child in children {
|
||||
if !visited.contains(&child) {
|
||||
// Check if this child has edges to any siblings
|
||||
let child_children = dag.children(child);
|
||||
let has_sibling_edge = children
|
||||
.iter()
|
||||
.any(|&other| other != child && child_children.contains(&other));
|
||||
|
||||
if !has_sibling_edge {
|
||||
parallel_group.push(child);
|
||||
visited.insert(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if parallel_group.len() > 1 {
|
||||
branches.push(parallel_group);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
branches
|
||||
}
|
||||
|
||||
/// Compute branch balance score (lower is better balanced)
|
||||
fn branch_balance(&self, branches: &[Vec<usize>], dag: &QueryDag) -> f32 {
|
||||
if branches.is_empty() {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
let mut total_variance = 0.0;
|
||||
|
||||
for branch in branches {
|
||||
if branch.len() <= 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute costs for each node in the branch
|
||||
let costs: Vec<f64> = branch
|
||||
.iter()
|
||||
.filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
|
||||
.collect();
|
||||
|
||||
if costs.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute variance
|
||||
let mean = costs.iter().sum::<f64>() / costs.len() as f64;
|
||||
let variance =
|
||||
costs.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / costs.len() as f64;
|
||||
|
||||
total_variance += variance as f32;
|
||||
}
|
||||
|
||||
// Normalize by number of branches
|
||||
if branches.is_empty() {
|
||||
1.0
|
||||
} else {
|
||||
(total_variance / branches.len() as f32).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute criticality score for a branch
|
||||
fn branch_criticality(&self, branch: &[usize], dag: &QueryDag) -> f32 {
|
||||
if branch.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Sum of costs in the branch
|
||||
let total_cost: f64 = branch
|
||||
.iter()
|
||||
.filter_map(|&id| dag.get_node(id).map(|n| n.estimated_cost))
|
||||
.sum();
|
||||
|
||||
// Average rows (higher rows = more critical for filtering)
|
||||
let avg_rows: f64 = branch
|
||||
.iter()
|
||||
.filter_map(|&id| dag.get_node(id).map(|n| n.estimated_rows))
|
||||
.sum::<f64>()
|
||||
/ branch.len().max(1) as f64;
|
||||
|
||||
// Criticality is high cost + high row count
|
||||
(total_cost * (avg_rows / 1000.0).min(1.0)) as f32
|
||||
}
|
||||
|
||||
/// Compute attention scores based on parallel branch analysis
|
||||
fn compute_branch_attention(&self, dag: &QueryDag, branches: &[Vec<usize>]) -> Vec<f32> {
|
||||
let n = dag.node_count();
|
||||
let mut scores = vec![0.0; n];
|
||||
|
||||
// Base score for nodes not in any branch
|
||||
let base_score = 0.5;
|
||||
for i in 0..n {
|
||||
scores[i] = base_score;
|
||||
}
|
||||
|
||||
// Compute balance metric
|
||||
let balance_penalty = self.branch_balance(branches, dag);
|
||||
|
||||
// Assign scores based on branch criticality
|
||||
for branch in branches {
|
||||
let criticality = self.branch_criticality(branch, dag);
|
||||
|
||||
// Higher criticality = higher attention
|
||||
// Apply balance penalty
|
||||
let branch_score = criticality * (1.0 - self.config.balance_weight * balance_penalty);
|
||||
|
||||
for &node_id in branch {
|
||||
if node_id < n {
|
||||
scores[node_id] = branch_score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply sync penalty to nodes that synchronize branches
|
||||
for from in dag.node_ids() {
|
||||
for &to in dag.children(from) {
|
||||
if from < n && to < n {
|
||||
// Check if this edge connects different branches
|
||||
let from_branch = branches.iter().position(|b| b.iter().any(|&x| x == from));
|
||||
let to_branch = branches.iter().position(|b| b.iter().any(|&x| x == to));
|
||||
|
||||
if from_branch.is_some() && to_branch.is_some() && from_branch != to_branch {
|
||||
scores[to] *= 1.0 - self.config.sync_penalty;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize using softmax
|
||||
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = scores
|
||||
.iter()
|
||||
.map(|&s| ((s - max_score) / self.config.temperature).exp())
|
||||
.sum();
|
||||
|
||||
if exp_sum > 0.0 {
|
||||
for score in scores.iter_mut() {
|
||||
*score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
|
||||
}
|
||||
} else {
|
||||
// Uniform if all scores are too low
|
||||
let uniform = 1.0 / n as f32;
|
||||
scores.fill(uniform);
|
||||
}
|
||||
|
||||
scores
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttentionMechanism for ParallelBranchAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.node_count() == 0 {
|
||||
return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
|
||||
}
|
||||
|
||||
// Step 1: Detect parallel branches
|
||||
let branches = self.detect_branches(dag);
|
||||
|
||||
// Step 2: Compute attention based on branches
|
||||
let scores = self.compute_branch_attention(dag, &branches);
|
||||
|
||||
// Step 3: Build result
|
||||
let mut result = AttentionScores::new(scores)
|
||||
.with_metadata("mechanism".to_string(), "parallel_branch".to_string())
|
||||
.with_metadata("num_branches".to_string(), branches.len().to_string());
|
||||
|
||||
let balance = self.branch_balance(&branches, dag);
|
||||
result
|
||||
.metadata
|
||||
.insert("balance_score".to_string(), format!("{:.4}", balance));
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"parallel_branch"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n² + b·n)"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    /// A diamond DAG (0 fans out to 1 and 2, which join at 3) must yield at
    /// least one detected parallel group.
    #[test]
    fn test_detect_branches() {
        let config = ParallelBranchConfig::default();
        let attention = ParallelBranchAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..4 {
            dag.add_node(OperatorNode::new(i, OperatorType::Scan));
        }

        // Create parallel branches: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
        dag.add_edge(0, 1).unwrap();
        dag.add_edge(0, 2).unwrap();
        dag.add_edge(1, 3).unwrap();
        dag.add_edge(2, 3).unwrap();

        let branches = attention.detect_branches(&dag);
        assert!(!branches.is_empty());
    }

    /// End-to-end forward pass over a small fan-out DAG produces one score
    /// per node.
    #[test]
    fn test_parallel_attention() {
        let config = ParallelBranchConfig::default();
        let attention = ParallelBranchAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = (i + 1) as f64;
            dag.add_node(node);
        }
        dag.add_edge(0, 1).unwrap();
        dag.add_edge(0, 2).unwrap();

        let result = attention.forward(&dag).unwrap();
        assert_eq!(result.scores.len(), 3);
    }
}
|
||||
305
vendor/ruvector/crates/ruvector-dag/src/attention/selector.rs
vendored
Normal file
305
vendor/ruvector/crates/ruvector-dag/src/attention/selector.rs
vendored
Normal file
@@ -0,0 +1,305 @@
|
||||
//! Attention Selector: UCB Bandit for mechanism selection
|
||||
//!
|
||||
//! Implements Upper Confidence Bound (UCB1) algorithm to dynamically select
|
||||
//! the best attention mechanism based on observed performance.
|
||||
|
||||
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for [`AttentionSelector`]'s UCB1 bandit.
#[derive(Debug, Clone)]
pub struct SelectorConfig {
    /// UCB exploration constant (typically sqrt(2))
    pub exploration_factor: f32,
    /// Optimistic initialization value
    pub initial_value: f32,
    /// Minimum samples before exploitation
    pub min_samples: usize,
}

impl Default for SelectorConfig {
    fn default() -> Self {
        Self {
            exploration_factor: (2.0_f32).sqrt(),
            initial_value: 1.0,
            min_samples: 5,
        }
    }
}

/// UCB1 multi-armed bandit over a set of attention mechanisms: tracks
/// cumulative reward and selection count per mechanism and balances
/// exploration against exploitation when choosing which one to run.
pub struct AttentionSelector {
    config: SelectorConfig,
    mechanisms: Vec<Box<dyn DagAttentionMechanism>>,
    /// Cumulative rewards for each mechanism
    /// (seeded with `config.initial_value` — optimistic initialization).
    rewards: Vec<f32>,
    /// Number of times each mechanism was selected
    counts: Vec<usize>,
    /// Total number of selections
    total_count: usize,
}
|
||||
|
||||
impl AttentionSelector {
    /// Build a selector over `mechanisms`; each arm's reward starts at
    /// `config.initial_value` (optimistic initialization) with a count of 0.
    pub fn new(mechanisms: Vec<Box<dyn DagAttentionMechanism>>, config: SelectorConfig) -> Self {
        let n = mechanisms.len();
        let initial_value = config.initial_value;
        Self {
            config,
            mechanisms,
            rewards: vec![initial_value; n],
            counts: vec![0; n],
            total_count: 0,
        }
    }

    /// Select mechanism using UCB1 algorithm.
    ///
    /// Returns the index of the chosen mechanism (0 when the list is empty —
    /// callers must still check `get_mechanism`).
    pub fn select(&self) -> usize {
        if self.mechanisms.is_empty() {
            return 0;
        }

        // Forced exploration: if any mechanism hasn't been tried
        // min_samples times yet, pick the first such one.
        for (i, &count) in self.counts.iter().enumerate() {
            if count < self.config.min_samples {
                return i;
            }
        }

        // UCB1 selection: exploitation (average reward) + exploration bonus.
        // NOTE(review): clamping ln(total) to >= 1.0 deviates from textbook
        // UCB1 (which uses ln(t) directly); presumably intentional to keep
        // the exploration term positive early on — confirm.
        let ln_total = (self.total_count as f32).ln().max(1.0);

        let ucb_values: Vec<f32> = self
            .mechanisms
            .iter()
            .enumerate()
            .map(|(i, _)| {
                let count = self.counts[i] as f32;
                // Defensive: unreached when min_samples >= 1, since the
                // forced-exploration loop above returns first.
                if count == 0.0 {
                    return f32::INFINITY;
                }

                let exploitation = self.rewards[i] / count;
                let exploration = self.config.exploration_factor * (ln_total / count).sqrt();

                exploitation + exploration
            })
            .collect();

        // Highest UCB wins; NaN comparisons fall back to Equal, and ties
        // resolve to the later index (max_by keeps the last maximum).
        ucb_values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Record an observed reward for `mechanism_idx` after execution.
    /// Out-of-range indices are silently ignored.
    pub fn update(&mut self, mechanism_idx: usize, reward: f32) {
        if mechanism_idx < self.rewards.len() {
            self.rewards[mechanism_idx] += reward;
            self.counts[mechanism_idx] += 1;
            self.total_count += 1;
        }
    }

    /// Get the selected mechanism by index.
    pub fn get_mechanism(&self, idx: usize) -> Option<&dyn DagAttentionMechanism> {
        self.mechanisms.get(idx).map(|m| m.as_ref())
    }

    /// Get mutable reference to mechanism for updates.
    pub fn get_mechanism_mut(&mut self, idx: usize) -> Option<&mut Box<dyn DagAttentionMechanism>> {
        self.mechanisms.get_mut(idx)
    }

    /// Per-mechanism statistics keyed by mechanism name.
    /// Note: `total_reward` includes the optimistic `initial_value` seed.
    pub fn stats(&self) -> HashMap<&'static str, MechanismStats> {
        self.mechanisms
            .iter()
            .enumerate()
            .map(|(i, m)| {
                let stats = MechanismStats {
                    total_reward: self.rewards[i],
                    count: self.counts[i],
                    avg_reward: if self.counts[i] > 0 {
                        self.rewards[i] / self.counts[i] as f32
                    } else {
                        0.0
                    },
                };
                (m.name(), stats)
            })
            .collect()
    }

    /// Index of the best mechanism by average reward, considering only arms
    /// with at least `min_samples` observations; `None` if no arm qualifies.
    pub fn best_mechanism(&self) -> Option<usize> {
        self.mechanisms
            .iter()
            .enumerate()
            .filter(|(i, _)| self.counts[*i] >= self.config.min_samples)
            .max_by(|(i, _), (j, _)| {
                let avg_i = self.rewards[*i] / self.counts[*i] as f32;
                let avg_j = self.rewards[*j] / self.counts[*j] as f32;
                avg_i
                    .partial_cmp(&avg_j)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
    }

    /// Reset all statistics to their initial (optimistic) state.
    pub fn reset(&mut self) {
        for i in 0..self.rewards.len() {
            self.rewards[i] = self.config.initial_value;
            self.counts[i] = 0;
        }
        self.total_count = 0;
    }

    /// Forward pass using the UCB-selected mechanism; returns the scores and
    /// the chosen index. Does NOT update counts — the caller is expected to
    /// call `update` with the observed reward afterwards.
    pub fn forward(&mut self, dag: &QueryDag) -> Result<(AttentionScores, usize), AttentionError> {
        let selected = self.select();
        let mechanism = self
            .get_mechanism(selected)
            .ok_or_else(|| AttentionError::ConfigError("No mechanisms available".to_string()))?;

        let scores = mechanism.forward(dag)?;
        Ok((scores, selected))
    }
}
|
||||
|
||||
/// Snapshot of one mechanism's bandit statistics, as returned by
/// [`AttentionSelector::stats`].
#[derive(Debug, Clone)]
pub struct MechanismStats {
    /// Cumulative reward (includes the optimistic initialization seed).
    pub total_reward: f32,
    /// Number of recorded updates for this mechanism.
    pub count: usize,
    /// total_reward / count, or 0.0 when count is 0.
    pub avg_reward: f32,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType, QueryDag};

    // Mock mechanism for testing: returns a constant score per node.
    struct MockMechanism {
        name: &'static str,
        score_value: f32,
    }

    impl DagAttentionMechanism for MockMechanism {
        fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
            let scores = vec![self.score_value; dag.nodes.len()];
            Ok(AttentionScores::new(scores))
        }

        fn name(&self) -> &'static str {
            self.name
        }

        fn complexity(&self) -> &'static str {
            "O(1)"
        }
    }

    /// After enough rounds, forced exploration guarantees every arm has been
    /// selected at least once.
    #[test]
    fn test_ucb_selection() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![
            Box::new(MockMechanism {
                name: "mech1",
                score_value: 0.5,
            }),
            Box::new(MockMechanism {
                name: "mech2",
                score_value: 0.7,
            }),
            Box::new(MockMechanism {
                name: "mech3",
                score_value: 0.3,
            }),
        ];

        let mut selector = AttentionSelector::new(mechanisms, SelectorConfig::default());

        // First selections should explore all mechanisms
        // (15 rounds covers 3 arms * min_samples of 5).
        for _ in 0..15 {
            let selected = selector.select();
            selector.update(selected, 0.5);
        }

        assert!(selector.total_count > 0);
        assert!(selector.counts.iter().all(|&c| c > 0));
    }

    /// With min_samples = 2 satisfied for both arms, the arm with the higher
    /// average reward must be reported as best.
    #[test]
    fn test_best_mechanism() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![
            Box::new(MockMechanism {
                name: "poor",
                score_value: 0.3,
            }),
            Box::new(MockMechanism {
                name: "good",
                score_value: 0.8,
            }),
        ];

        let mut selector = AttentionSelector::new(
            mechanisms,
            SelectorConfig {
                min_samples: 2,
                ..Default::default()
            },
        );

        // Simulate different rewards
        selector.update(0, 0.3);
        selector.update(0, 0.4);
        selector.update(1, 0.8);
        selector.update(1, 0.9);

        let best = selector.best_mechanism().unwrap();
        assert_eq!(best, 1);
    }

    /// forward() on a single-mechanism selector returns that mechanism's
    /// scores and its index (0).
    #[test]
    fn test_selector_forward() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![Box::new(MockMechanism {
            name: "test",
            score_value: 0.5,
        })];

        let mut selector = AttentionSelector::new(mechanisms, SelectorConfig::default());

        let mut dag = QueryDag::new();
        let node = OperatorNode::new(0, OperatorType::Scan);
        dag.add_node(node);

        let (scores, idx) = selector.forward(&dag).unwrap();
        assert_eq!(scores.scores.len(), 1);
        assert_eq!(idx, 0);
    }

    /// stats() aggregates updates: count, total and average reward.
    #[test]
    fn test_stats() {
        let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![Box::new(MockMechanism {
            name: "mech1",
            score_value: 0.5,
        })];

        // Use initial_value = 0 so we can test pure update accumulation
        // (stats() otherwise includes the optimistic seed in total_reward).
        let config = SelectorConfig {
            initial_value: 0.0,
            ..Default::default()
        };
        let mut selector = AttentionSelector::new(mechanisms, config);
        selector.update(0, 1.0);
        selector.update(0, 2.0);

        let stats = selector.stats();
        let mech1_stats = stats.get("mech1").unwrap();

        assert_eq!(mech1_stats.count, 2);
        assert_eq!(mech1_stats.total_reward, 3.0);
        assert_eq!(mech1_stats.avg_reward, 1.5);
    }
}
|
||||
301
vendor/ruvector/crates/ruvector-dag/src/attention/temporal_btsp.rs
vendored
Normal file
301
vendor/ruvector/crates/ruvector-dag/src/attention/temporal_btsp.rs
vendored
Normal file
@@ -0,0 +1,301 @@
|
||||
//! Temporal BTSP Attention: Behavioral Timescale Synaptic Plasticity
|
||||
//!
|
||||
//! This mechanism implements a biologically-inspired attention mechanism based on
|
||||
//! eligibility traces and plateau potentials, allowing the system to learn from
|
||||
//! temporal patterns in query execution.
|
||||
|
||||
use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Configuration for [`TemporalBTSPAttention`].
#[derive(Debug, Clone)]
pub struct TemporalBTSPConfig {
    /// Duration of plateau state in milliseconds
    pub plateau_duration_ms: u64,
    /// Decay rate for eligibility traces (0.0 to 1.0)
    pub eligibility_decay: f32,
    /// Learning rate for trace updates
    pub learning_rate: f32,
    /// Temperature for softmax
    pub temperature: f32,
    /// Baseline attention for nodes without history
    pub baseline_attention: f32,
}

impl Default for TemporalBTSPConfig {
    fn default() -> Self {
        Self {
            plateau_duration_ms: 500,
            eligibility_decay: 0.95,
            learning_rate: 0.1,
            temperature: 0.1,
            baseline_attention: 0.5,
        }
    }
}

/// Stateful, biologically-inspired attention mechanism: per-node eligibility
/// traces decay over time and plateau events give a time-limited boost,
/// modulating a topology-derived base score.
pub struct TemporalBTSPAttention {
    config: TemporalBTSPConfig,
    /// Eligibility traces for each node
    eligibility_traces: HashMap<usize, f32>,
    /// Timestamp of last plateau for each node
    last_plateau: HashMap<usize, Instant>,
    /// Total updates counter
    update_count: usize,
}
|
||||
|
||||
impl TemporalBTSPAttention {
|
||||
    /// Create a mechanism with empty trace/plateau state.
    pub fn new(config: TemporalBTSPConfig) -> Self {
        Self {
            config,
            eligibility_traces: HashMap::new(),
            last_plateau: HashMap::new(),
            update_count: 0,
        }
    }

    /// Update eligibility trace for a node: exponential decay of the old
    /// trace plus a learning-rate-scaled contribution from `signal`.
    fn update_eligibility(&mut self, node_id: usize, signal: f32) {
        let trace = self.eligibility_traces.entry(node_id).or_insert(0.0);
        *trace = *trace * self.config.eligibility_decay + signal * self.config.learning_rate;

        // Clamp to [0, 1]
        *trace = trace.max(0.0).min(1.0);
    }

    /// Check if node is in plateau state, i.e. its last plateau was triggered
    /// less than `plateau_duration_ms` ago. Nodes with no recorded plateau
    /// are not in plateau.
    fn is_plateau(&self, node_id: usize) -> bool {
        self.last_plateau
            .get(&node_id)
            .map(|t| t.elapsed().as_millis() < self.config.plateau_duration_ms as u128)
            .unwrap_or(false)
    }

    /// Trigger plateau for a node (records the current wall-clock instant).
    fn trigger_plateau(&mut self, node_id: usize) {
        self.last_plateau.insert(node_id, Instant::now());
    }

    /// Compute base attention from topology: a 50/50 blend of a cost factor
    /// (normalized against 100, capped at 1) and a rows factor (normalized
    /// against 1000, capped at 1); nodes outside [0, n) keep the baseline.
    fn compute_topology_attention(&self, dag: &QueryDag) -> Vec<f32> {
        let n = dag.node_count();
        let mut scores = vec![self.config.baseline_attention; n];

        // Simple heuristic: nodes with higher cost get more attention
        for node in dag.nodes() {
            if node.id < n {
                let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
                let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
                scores[node.id] = 0.5 * cost_factor + 0.5 * rows_factor;
            }
        }

        scores
    }

    /// Apply eligibility trace modulation: scale each traced node's score by
    /// (1 + trace), so a full trace doubles the base attention.
    fn apply_eligibility_modulation(&self, base_scores: &mut [f32]) {
        for (node_id, &trace) in &self.eligibility_traces {
            if *node_id < base_scores.len() {
                base_scores[*node_id] *= 1.0 + trace;
            }
        }
    }

    /// Apply plateau boosting: a fixed 1.5x multiplier for every node whose
    /// plateau window is still open.
    fn apply_plateau_boost(&self, scores: &mut [f32]) {
        for (node_id, _) in &self.last_plateau {
            if *node_id < scores.len() && self.is_plateau(*node_id) {
                scores[*node_id] *= 1.5;
            }
        }
    }
|
||||
|
||||
/// Normalize scores using softmax
|
||||
fn normalize_scores(&self, scores: &mut [f32]) {
|
||||
if scores.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = scores
|
||||
.iter()
|
||||
.map(|&s| ((s - max_score) / self.config.temperature).exp())
|
||||
.sum();
|
||||
|
||||
if exp_sum > 0.0 {
|
||||
for score in scores.iter_mut() {
|
||||
*score = ((*score - max_score) / self.config.temperature).exp() / exp_sum;
|
||||
}
|
||||
} else {
|
||||
let uniform = 1.0 / scores.len() as f32;
|
||||
scores.fill(uniform);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttentionMechanism for TemporalBTSPAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.nodes.is_empty() {
|
||||
return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
|
||||
}
|
||||
|
||||
// Step 1: Compute base attention from topology
|
||||
let mut scores = self.compute_topology_attention(dag);
|
||||
|
||||
// Step 2: Modulate by eligibility traces
|
||||
self.apply_eligibility_modulation(&mut scores);
|
||||
|
||||
// Step 3: Apply plateau boosting for recently active nodes
|
||||
self.apply_plateau_boost(&mut scores);
|
||||
|
||||
// Step 4: Normalize
|
||||
self.normalize_scores(&mut scores);
|
||||
|
||||
// Build result with metadata
|
||||
let mut result = AttentionScores::new(scores)
|
||||
.with_metadata("mechanism".to_string(), "temporal_btsp".to_string())
|
||||
.with_metadata("update_count".to_string(), self.update_count.to_string());
|
||||
|
||||
let active_traces = self
|
||||
.eligibility_traces
|
||||
.values()
|
||||
.filter(|&&t| t > 0.01)
|
||||
.count();
|
||||
result
|
||||
.metadata
|
||||
.insert("active_traces".to_string(), active_traces.to_string());
|
||||
|
||||
let active_plateaus = self
|
||||
.last_plateau
|
||||
.keys()
|
||||
.filter(|k| self.is_plateau(**k))
|
||||
.count();
|
||||
result
|
||||
.metadata
|
||||
.insert("active_plateaus".to_string(), active_plateaus.to_string());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"temporal_btsp"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n + t)"
|
||||
}
|
||||
|
||||
fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
|
||||
self.update_count += 1;
|
||||
|
||||
// Update eligibility traces based on execution feedback
|
||||
for (node_id, &exec_time) in execution_times {
|
||||
let node = match dag.get_node(*node_id) {
|
||||
Some(n) => n,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let expected_time = node.estimated_cost;
|
||||
|
||||
// Compute reward signal: positive if faster than expected, negative if slower
|
||||
let time_ratio = exec_time / expected_time.max(0.001);
|
||||
let reward = if time_ratio < 1.0 {
|
||||
// Faster than expected - positive signal
|
||||
1.0 - time_ratio as f32
|
||||
} else {
|
||||
// Slower than expected - negative signal
|
||||
-(time_ratio as f32 - 1.0).min(1.0)
|
||||
};
|
||||
|
||||
// Update eligibility trace
|
||||
self.update_eligibility(*node_id, reward);
|
||||
|
||||
// Trigger plateau for nodes that significantly exceeded expectations
|
||||
if reward > 0.3 {
|
||||
self.trigger_plateau(*node_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Decay traces for nodes that weren't executed
|
||||
let executed_nodes: std::collections::HashSet<_> = execution_times.keys().collect();
|
||||
for node_id in 0..dag.node_count() {
|
||||
if !executed_nodes.contains(&node_id) {
|
||||
self.update_eligibility(node_id, 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.eligibility_traces.clear();
|
||||
self.last_plateau.clear();
|
||||
self.update_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    use std::thread::sleep;
    use std::time::Duration;

    #[test]
    fn test_eligibility_update() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);

        attention.update_eligibility(0, 0.5);
        let first = *attention.eligibility_traces.get(&0).unwrap();
        assert!(first > 0.0);

        // A second positive signal must strengthen the (decayed) trace —
        // the original test merely repeated the `> 0.0` assertion.
        attention.update_eligibility(0, 0.5);
        let second = *attention.eligibility_traces.get(&0).unwrap();
        assert!(second > first);
    }

    #[test]
    fn test_plateau_state() {
        // Struct-update syntax instead of post-hoc field mutation.
        let config = TemporalBTSPConfig {
            plateau_duration_ms: 100,
            ..TemporalBTSPConfig::default()
        };
        let mut attention = TemporalBTSPAttention::new(config);

        attention.trigger_plateau(0);
        assert!(attention.is_plateau(0));

        // The plateau must expire once its window has elapsed.
        sleep(Duration::from_millis(150));
        assert!(!attention.is_plateau(0));
    }

    #[test]
    fn test_temporal_attention() {
        let config = TemporalBTSPConfig::default();
        let mut attention = TemporalBTSPAttention::new(config);

        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = 10.0;
            dag.add_node(node);
        }

        // Initial forward pass over a feedback-free DAG
        let result1 = attention.forward(&dag).unwrap();
        assert_eq!(result1.scores.len(), 3);

        // Simulate execution feedback
        let mut exec_times = HashMap::new();
        exec_times.insert(0, 5.0); // Faster than expected
        exec_times.insert(1, 15.0); // Slower than expected

        attention.update(&dag, &exec_times);

        // Second forward pass should show different attention
        let result2 = attention.forward(&dag).unwrap();
        assert_eq!(result2.scores.len(), 3);

        // Node 0 keeps a positive trace after its positive feedback...
        assert!(attention.eligibility_traces.get(&0).unwrap() > &0.0);
        // ...and out-scores node 2, which received no feedback at all.
        assert!(result2.scores[0] > result2.scores[2]);
    }
}
|
||||
109
vendor/ruvector/crates/ruvector-dag/src/attention/topological.rs
vendored
Normal file
109
vendor/ruvector/crates/ruvector-dag/src/attention/topological.rs
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
//! Topological Attention: Respects DAG ordering with depth-based decay
|
||||
|
||||
use super::{AttentionError, AttentionScores, DagAttention};
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TopologicalConfig {
|
||||
pub decay_factor: f32, // 0.9 default
|
||||
pub max_depth: usize, // 10 default
|
||||
}
|
||||
|
||||
impl Default for TopologicalConfig {
    /// Default: 0.9 decay per normalized depth level, depth cap of 10.
    fn default() -> Self {
        Self {
            decay_factor: 0.9,
            max_depth: 10,
        }
    }
}
|
||||
|
||||
/// Static attention mechanism that scores nodes purely by their depth in the
/// DAG, with exponential decay toward shallower nodes.
pub struct TopologicalAttention {
    config: TopologicalConfig,
}
|
||||
|
||||
impl TopologicalAttention {
|
||||
pub fn new(config: TopologicalConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(TopologicalConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl DagAttention for TopologicalAttention {
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
|
||||
if dag.node_count() == 0 {
|
||||
return Err(AttentionError::EmptyDag);
|
||||
}
|
||||
|
||||
let depths = dag.compute_depths();
|
||||
let max_depth = depths.values().max().copied().unwrap_or(0);
|
||||
|
||||
let mut scores = HashMap::new();
|
||||
let mut total = 0.0f32;
|
||||
|
||||
for (&node_id, &depth) in &depths {
|
||||
// Higher attention for nodes closer to root (higher depth from leaves)
|
||||
let normalized_depth = depth as f32 / (max_depth.max(1) as f32);
|
||||
let score = self.config.decay_factor.powf(1.0 - normalized_depth);
|
||||
scores.insert(node_id, score);
|
||||
total += score;
|
||||
}
|
||||
|
||||
// Normalize to sum to 1
|
||||
if total > 0.0 {
|
||||
for score in scores.values_mut() {
|
||||
*score /= total;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scores)
|
||||
}
|
||||
|
||||
fn update(&mut self, _dag: &QueryDag, _times: &HashMap<usize, f64>) {
|
||||
// Topological attention is static, no updates needed
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"topological"
|
||||
}
|
||||
|
||||
fn complexity(&self) -> &'static str {
|
||||
"O(n)"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};

    #[test]
    fn test_topological_attention() {
        let mut dag = QueryDag::new();

        // Create a simple linear DAG: scan -> filter -> project
        let id0 = dag.add_node(OperatorNode::seq_scan(0, "users").with_estimates(100.0, 1.0));
        let id1 = dag.add_node(OperatorNode::filter(0, "age > 18").with_estimates(50.0, 1.0));
        let id2 = dag
            .add_node(OperatorNode::project(0, vec!["name".to_string()]).with_estimates(50.0, 1.0));

        dag.add_edge(id0, id1).unwrap();
        dag.add_edge(id1, id2).unwrap();

        let attention = TopologicalAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Normalized scores must sum to ~1.0 ...
        let sum: f32 = scores.values().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // ... and each individual score must lie in [0, 1]
        // (range-contains form, per clippy::manual_range_contains).
        for &score in scores.values() {
            assert!((0.0..=1.0).contains(&score));
        }
    }
}
|
||||
75
vendor/ruvector/crates/ruvector-dag/src/attention/trait_def.rs
vendored
Normal file
75
vendor/ruvector/crates/ruvector-dag/src/attention/trait_def.rs
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
//! DagAttention trait definition for pluggable attention mechanisms
|
||||
|
||||
use crate::dag::QueryDag;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Attention scores for each node in the DAG, plus optional pairwise weights
/// and free-form diagnostic metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionScores {
    /// Attention score for each node (0.0 to 1.0), indexed by node id
    pub scores: Vec<f32>,
    /// Optional attention weights between nodes (adjacency-like matrix)
    pub edge_weights: Option<Vec<Vec<f32>>>,
    /// Metadata for debugging (mechanism name, counters, etc.)
    pub metadata: HashMap<String, String>,
}
|
||||
|
||||
impl AttentionScores {
|
||||
pub fn new(scores: Vec<f32>) -> Self {
|
||||
Self {
|
||||
scores,
|
||||
edge_weights: None,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_edge_weights(mut self, weights: Vec<Vec<f32>>) -> Self {
|
||||
self.edge_weights = Some(weights);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, key: String, value: String) -> Self {
|
||||
self.metadata.insert(key, value);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur during attention computation
#[derive(Debug, Error)]
pub enum AttentionError {
    /// The DAG cannot be scored (e.g. empty or malformed structure).
    #[error("Invalid DAG structure: {0}")]
    InvalidDag(String),

    /// An input had a different size than the mechanism expected.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// The numeric computation itself failed.
    #[error("Computation failed: {0}")]
    ComputationFailed(String),

    /// The mechanism was configured with invalid parameters.
    #[error("Configuration error: {0}")]
    ConfigError(String),
}
|
||||
|
||||
/// Trait for DAG attention mechanisms.
///
/// Implementors compute per-node attention scores for a query DAG and may
/// optionally adapt to execution feedback. `Send + Sync` so mechanisms can
/// be shared across threads.
pub trait DagAttentionMechanism: Send + Sync {
    /// Compute attention scores for the given DAG
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;

    /// Get the mechanism name (stable identifier used in metadata/logging)
    fn name(&self) -> &'static str;

    /// Get computational complexity as a human-readable string
    fn complexity(&self) -> &'static str;

    /// Optional: Update internal state based on execution feedback.
    /// `_execution_times` maps node id to its observed execution time.
    fn update(&mut self, _dag: &QueryDag, _execution_times: &HashMap<usize, f64>) {
        // Default: no-op for stateless mechanisms
    }

    /// Optional: Reset internal state
    fn reset(&mut self) {
        // Default: no-op for stateless mechanisms
    }
}
|
||||
53
vendor/ruvector/crates/ruvector-dag/src/attention/traits.rs
vendored
Normal file
53
vendor/ruvector/crates/ruvector-dag/src/attention/traits.rs
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
//! Core traits and types for DAG attention mechanisms
|
||||
|
||||
use crate::dag::QueryDag;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Attention scores for DAG nodes, keyed by node id
pub type AttentionScores = HashMap<usize, f32>;
|
||||
|
||||
/// Configuration for attention mechanisms
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Whether scores should be normalized to sum to 1
    pub normalize: bool,
    /// Softmax temperature for score sharpening
    pub temperature: f32,
    /// Dropout probability applied to scores
    /// NOTE(review): none of these fields are read by the code visible in
    /// this file — confirm consumers elsewhere before relying on them.
    pub dropout: f32,
}
|
||||
|
||||
impl Default for AttentionConfig {
    /// Default: normalized scores, neutral temperature, no dropout.
    fn default() -> Self {
        Self {
            normalize: true,
            temperature: 1.0,
            dropout: 0.0,
        }
    }
}
|
||||
|
||||
/// Errors from attention computation
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
    /// The DAG contains no nodes to score.
    #[error("Empty DAG")]
    EmptyDag,
    /// The graph is not acyclic, so depth/order-based scoring is undefined.
    #[error("Cycle detected in DAG")]
    CycleDetected,
    /// A referenced node id does not exist in the DAG.
    #[error("Node {0} not found")]
    NodeNotFound(usize),
    /// The numeric computation itself failed.
    #[error("Computation failed: {0}")]
    ComputationFailed(String),
}
|
||||
|
||||
/// Trait for DAG attention mechanisms
|
||||
pub trait DagAttention: Send + Sync {
|
||||
/// Compute attention scores for all nodes
|
||||
fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
|
||||
|
||||
/// Update internal state after execution feedback
|
||||
fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>);
|
||||
|
||||
/// Get mechanism name
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Get computational complexity description
|
||||
fn complexity(&self) -> &'static str;
|
||||
}
|
||||
Reference in New Issue
Block a user