git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
100 lines
3.2 KiB
Rust
100 lines
3.2 KiB
Rust
//! Attention mechanism selection example
|
|
|
|
use ruvector_dag::attention::{
|
|
CausalConeAttention, CausalConeConfig, DagAttention, TopologicalAttention, TopologicalConfig,
|
|
};
|
|
use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};
|
|
|
|
fn main() {
|
|
println!("=== Attention Mechanism Selection ===\n");
|
|
|
|
// Create a sample DAG
|
|
let dag = create_vector_search_dag();
|
|
|
|
println!("Created vector search DAG:");
|
|
println!(" Nodes: {}", dag.node_count());
|
|
println!(" Edges: {}", dag.edge_count());
|
|
|
|
// Test Topological Attention
|
|
println!("\n--- Topological Attention ---");
|
|
println!("Emphasizes node depth in the DAG hierarchy");
|
|
|
|
let topo = TopologicalAttention::new(TopologicalConfig {
|
|
decay_factor: 0.9,
|
|
max_depth: 10,
|
|
});
|
|
|
|
let scores = topo.forward(&dag).unwrap();
|
|
println!("\nAttention scores:");
|
|
for (node_id, score) in &scores {
|
|
let node = dag.get_node(*node_id).unwrap();
|
|
println!(" Node {}: {:.4} - {:?}", node_id, score, node.op_type);
|
|
}
|
|
|
|
let sum: f32 = scores.values().sum();
|
|
println!("\nSum of scores: {:.4} (should be ~1.0)", sum);
|
|
|
|
// Test Causal Cone Attention
|
|
println!("\n--- Causal Cone Attention ---");
|
|
println!("Focuses on downstream dependencies");
|
|
|
|
let causal = CausalConeAttention::new(CausalConeConfig {
|
|
time_window_ms: 1000,
|
|
future_discount: 0.85,
|
|
ancestor_weight: 0.5,
|
|
});
|
|
|
|
let causal_scores = causal.forward(&dag).unwrap();
|
|
println!("\nCausal cone scores:");
|
|
for (node_id, score) in &causal_scores {
|
|
let node = dag.get_node(*node_id).unwrap();
|
|
println!(" Node {}: {:.4} - {:?}", node_id, score, node.op_type);
|
|
}
|
|
|
|
// Compare mechanisms
|
|
println!("\n--- Comparison ---");
|
|
println!("Node | Topological | Causal Cone | Difference");
|
|
println!("-----|-------------|-------------|------------");
|
|
for node_id in 0..dag.node_count() {
|
|
let topo_score = scores.get(&node_id).unwrap_or(&0.0);
|
|
let causal_score = causal_scores.get(&node_id).unwrap_or(&0.0);
|
|
let diff = (topo_score - causal_score).abs();
|
|
println!(
|
|
"{:4} | {:11.4} | {:11.4} | {:11.4}",
|
|
node_id, topo_score, causal_score, diff
|
|
);
|
|
}
|
|
|
|
println!("\n=== Example Complete ===");
|
|
}
|
|
|
|
fn create_vector_search_dag() -> QueryDag {
|
|
let mut dag = QueryDag::new();
|
|
|
|
// HNSW scan - the primary vector search
|
|
let hnsw = dag.add_node(OperatorNode::hnsw_scan(0, "embeddings_idx", 64));
|
|
|
|
// Metadata table scan
|
|
let meta = dag.add_node(OperatorNode::seq_scan(1, "metadata"));
|
|
|
|
// Join embeddings with metadata
|
|
let join = dag.add_node(OperatorNode::new(2, OperatorType::NestedLoopJoin));
|
|
|
|
dag.add_edge(hnsw, join).unwrap();
|
|
dag.add_edge(meta, join).unwrap();
|
|
|
|
// Filter by category
|
|
let filter = dag.add_node(OperatorNode::filter(3, "category = 'tech'"));
|
|
dag.add_edge(join, filter).unwrap();
|
|
|
|
// Limit results
|
|
let limit = dag.add_node(OperatorNode::limit(4, 10));
|
|
dag.add_edge(filter, limit).unwrap();
|
|
|
|
// Result node
|
|
let result = dag.add_node(OperatorNode::new(5, OperatorType::Result));
|
|
dag.add_edge(limit, result).unwrap();
|
|
|
|
dag
|
|
}
|