Files
wifi-densepose/crates/ruvector-dag/tests/integration/attention_tests.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

212 lines
5.4 KiB
Rust

//! Attention mechanism integration tests
use ruvector_dag::attention::*;
use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};
/// Builds the fixture shared by the tests in this file: five `SeqScan`
/// operators with ids 0 through 4 (tables `t0`..`t4`), joined by the edges
/// `(0, 1)`, `(1, 2)`, `(2, 3)` and `(3, 4)`.
fn create_test_dag() -> QueryDag {
    let mut chain = QueryDag::new();

    // One sequential-scan operator per id, each reading its own table.
    (0..5).for_each(|id| {
        let scan = OperatorType::SeqScan {
            table: format!("t{}", id),
        };
        chain.add_node(OperatorNode::new(id, scan));
    });

    // Connect every operator to the one with the next id; a rejected edge
    // aborts the fixture immediately.
    (0..4).for_each(|id| {
        chain.add_edge(id, id + 1).unwrap();
    });

    chain
}
/// Topological attention over the linear fixture must produce a normalized
/// distribution: the scores sum to 1.0 and every score lies in `[0, 1]`.
#[test]
fn test_topological_attention() {
    let dag = create_test_dag();
    let attention = TopologicalAttention::new(TopologicalConfig::default());
    let scores = attention.forward(&dag).unwrap();

    // Verify normalization; the tolerance absorbs f32 rounding error.
    let sum: f32 = scores.values().sum();
    assert!(
        (sum - 1.0).abs() < 0.001,
        "Attention scores should sum to 1.0"
    );

    // Verify all scores in [0, 1]. A NaN score is outside the range and
    // therefore fails this check as well.
    assert!(scores.values().all(|s| (0.0..=1.0).contains(s)));
}
/// Mock mechanism for testing the selector through the `DagAttentionMechanism` trait.
///
/// It carries no real attention logic: its trait impl in this file returns
/// the stored values verbatim.
struct MockMechanism {
/// Label handed back by `name()`.
name: &'static str,
/// Constant score that `forward()` assigns to every node of the DAG.
score_value: f32,
}
// Trait implementation for the mock: it only looks at the node count of the
// DAG and hands every node the same score.
impl DagAttentionMechanism for MockMechanism {
    /// Produces one `score_value` entry per node reported by `dag.node_count()`.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScoresV2, AttentionErrorV2> {
        let uniform: Vec<f32> = std::iter::repeat(self.score_value)
            .take(dag.node_count())
            .collect();
        Ok(AttentionScoresV2::new(uniform))
    }

    /// Returns the label stored in the `name` field.
    fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the fixed complexity label used by the mock.
    fn complexity(&self) -> &'static str {
        "O(1)"
    }
}
/// Drives a selector that owns a single mechanism through 100 select/update
/// rounds and checks that every round was recorded and that all of them were
/// attributed to that one mechanism.
#[test]
fn test_attention_selector_convergence() {
    let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![Box::new(MockMechanism {
        name: "test_mech",
        score_value: 0.5,
    })];
    let mut selector = AttentionSelector::new(mechanisms, SelectorConfig::default());

    // Run selection multiple times, feeding back 0.5 plus a random share of
    // another 0.5 as the reward for each pick.
    let mut selection_counts = std::collections::HashMap::new();
    for _ in 0..100 {
        let idx = selector.select();
        *selection_counts.entry(idx).or_insert(0_usize) += 1;
        selector.update(idx, 0.5 + rand::random::<f32>() * 0.5);
    }

    // Every round must have been counted exactly once.
    assert_eq!(
        selection_counts.values().sum::<usize>(),
        100,
        "Each of the 100 rounds should record one selection"
    );

    // Only one mechanism is registered, so the selector has nothing else to
    // pick: all rounds must share a single index.
    // NOTE(review): assumes `select()` returns an index into the mechanism
    // list passed to `AttentionSelector::new` — confirm against the selector.
    assert_eq!(
        selection_counts.len(),
        1,
        "A single-mechanism selector should always pick the same index"
    );
}
/// A cache lookup misses before anything is stored for a DAG/mechanism pair
/// and hits once scores have been inserted for it.
#[test]
fn test_attention_cache() {
    // Key under which the scores are stored and looked up.
    const MECHANISM: &str = "topological";

    let mut cache = AttentionCache::new(CacheConfig {
        capacity: 100,
        ttl: None,
    });
    let dag = create_test_dag();

    // Nothing stored yet: the lookup must miss.
    assert!(cache.get(&dag, MECHANISM).is_none());

    // Store one score per fixture node (five nodes, 0.2 each).
    cache.insert(&dag, MECHANISM, AttentionScoresV2::new(vec![0.2; 5]));

    // The same DAG and key must now hit.
    assert!(cache.get(&dag, MECHANISM).is_some());
}
/// Topological attention must stay normalized (scores summing to 1.0) for
/// both a low and a high decay factor.
#[test]
fn test_attention_decay_factor() {
    /// Runs topological attention configured by `config` over `dag` and
    /// returns the total of all scores.
    fn total_score(config: TopologicalConfig, dag: &QueryDag) -> f32 {
        let scores = TopologicalAttention::new(config).forward(dag).unwrap();
        let total: f32 = scores.values().sum();
        total
    }

    let dag = create_test_dag();

    // Low decay factor (sharper distribution)
    let total_low = total_score(
        TopologicalConfig {
            decay_factor: 0.5,
            max_depth: 10,
        },
        &dag,
    );

    // High decay factor (smoother distribution)
    let total_high = total_score(
        TopologicalConfig {
            decay_factor: 0.99,
            max_depth: 10,
        },
        &dag,
    );

    // Both should be normalized
    assert!((total_low - 1.0).abs() < 0.001);
    assert!((total_high - 1.0).abs() < 0.001);
}
/// Running topological attention over a DAG with no nodes must yield an
/// error rather than a score map.
#[test]
fn test_attention_empty_dag() {
    let empty = QueryDag::new();
    let outcome = TopologicalAttention::new(TopologicalConfig::default()).forward(&empty);

    // Empty DAG returns error
    assert!(outcome.is_err());
}
/// With exactly one node in the DAG, that node must receive the whole
/// attention mass.
#[test]
fn test_attention_single_node() {
    let mut dag = QueryDag::new();
    dag.add_node(OperatorNode::new(0, OperatorType::Result));

    let scores = TopologicalAttention::new(TopologicalConfig::default())
        .forward(&dag)
        .unwrap();

    // Exactly one entry, keyed by the node id, with a score of 1.0
    // (within f32 tolerance).
    assert_eq!(scores.len(), 1);
    let only_score = scores[&0];
    assert!((only_score - 1.0).abs() < 0.001);
}
/// Inserting more entries than the configured capacity must not let the
/// cache grow past that capacity.
#[test]
fn test_attention_cache_eviction() {
    let mut cache = AttentionCache::new(CacheConfig {
        capacity: 2,
        ttl: None,
    });

    // Insert five distinct single-node DAGs into a cache that holds two.
    for node_id in 0..5 {
        let mut single = QueryDag::new();
        single.add_node(OperatorNode::new(node_id, OperatorType::Result));
        cache.insert(&single, "test", AttentionScoresV2::new(vec![1.0]));
    }

    // The reported size must respect the capacity limit.
    assert!(cache.stats().size <= 2);
}
/// Runs a two-mechanism selector for 50 select/update rounds with a constant
/// reward and checks that the set of selected indices is non-empty and no
/// larger than the number of registered mechanisms.
#[test]
fn test_multi_mechanism_selector() {
    let mechanisms: Vec<Box<dyn DagAttentionMechanism>> = vec![
        Box::new(MockMechanism {
            name: "mech1",
            score_value: 0.5,
        }),
        Box::new(MockMechanism {
            name: "mech2",
            score_value: 0.7,
        }),
    ];
    let mut selector = AttentionSelector::new(
        mechanisms,
        SelectorConfig {
            exploration_factor: 0.1,
            initial_value: 1.0,
            min_samples: 3,
        },
    );

    // Record every distinct index the selector hands out.
    let mut used = std::collections::HashSet::new();
    for _ in 0..50 {
        let idx = selector.select();
        used.insert(idx);
        selector.update(idx, 0.5);
    }

    // This test deliberately requires only that *some* mechanism was chosen;
    // it does not require both to be explored.
    // NOTE(review): if the selector guarantees `min_samples` picks per
    // mechanism, this could be tightened to `used.len() == 2` — confirm.
    assert!(!used.is_empty(), "At least one mechanism should be selected");

    // Two mechanisms are registered, so more than two distinct indices means
    // the selector returned something that maps to no mechanism.
    // NOTE(review): assumes `select()` returns an index into the mechanism
    // list — confirm against the selector.
    assert!(
        used.len() <= 2,
        "Selector returned more distinct indices than registered mechanisms"
    );
}