Files
wifi-densepose/crates/ruvector-dag/tests/fixtures/dag_generator.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

212 lines
6.7 KiB
Rust

//! DAG Generator for testing
use ruvector_dag::dag::{QueryDag, OperatorNode, OperatorType};
use rand::Rng;
/// Shape classes of synthetic query DAGs that the generator can produce.
///
/// `PartialEq`, `Eq` and `Hash` are derived so that a complexity value can be
/// compared in assertions and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DagComplexity {
    /// A short linear pipeline (currently 3 nodes: scan -> filter -> result).
    Simple,
    /// A two-table join feeding a filter/aggregate/sort/limit chain
    /// (currently 8 nodes, one branch point at the join).
    Medium,
    /// A randomly wired graph of 50-100 nodes with many branches.
    Complex,
    /// Typical vector search pattern: HNSW scan joined with metadata, then
    /// reranked and limited.
    VectorQuery,
}
/// Builds synthetic `QueryDag` fixtures for tests.
///
/// The generator owns a random number generator so that repeated calls can
/// produce differently shaped randomized DAGs.
pub struct DagGenerator {
    /// Randomness source used for node counts, cost/row estimates, operator
    /// selection and edge wiring.
    rng: rand::rngs::ThreadRng,
}
impl DagGenerator {
    /// Creates a generator backed by the thread-local random number generator.
    pub fn new() -> Self {
        Self {
            rng: rand::thread_rng(),
        }
    }

    /// Builds a single DAG whose shape is selected by `complexity`.
    pub fn generate(&mut self, complexity: DagComplexity) -> QueryDag {
        match complexity {
            DagComplexity::Simple => self.generate_simple(),
            DagComplexity::Medium => self.generate_medium(),
            DagComplexity::Complex => self.generate_complex(),
            DagComplexity::VectorQuery => self.generate_vector_query(),
        }
    }

    /// Builds the fixed linear pipeline `SeqScan(users) -> Filter -> Result`.
    fn generate_simple(&mut self) -> QueryDag {
        let mut dag = QueryDag::new();

        let scan = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let filter = dag.add_node(OperatorNode::filter(1, "id > 0"));
        let result = dag.add_node(OperatorNode::new(2, OperatorType::Result));

        dag.add_edge(scan, filter).unwrap();
        dag.add_edge(filter, result).unwrap();

        dag
    }

    /// Builds a fixed two-table join plan:
    /// two scans -> hash join -> filter -> aggregate -> sort -> limit -> result.
    fn generate_medium(&mut self) -> QueryDag {
        let mut dag = QueryDag::new();
        // Running id handed to each operator constructor, bumped after every node.
        let mut id = 0;

        // Two scans feeding one hash join.
        let scan1 = dag.add_node(OperatorNode::seq_scan(id, "orders"));
        id += 1;
        let scan2 = dag.add_node(OperatorNode::seq_scan(id, "products"));
        id += 1;
        let join = dag.add_node(OperatorNode::hash_join(id, "product_id"));
        id += 1;
        dag.add_edge(scan1, join).unwrap();
        dag.add_edge(scan2, join).unwrap();

        let filter = dag.add_node(OperatorNode::filter(id, "amount > 100"));
        id += 1;
        dag.add_edge(join, filter).unwrap();

        let agg = dag.add_node(OperatorNode::new(
            id,
            OperatorType::Aggregate {
                functions: vec!["SUM(amount)".to_string()],
            },
        ));
        id += 1;
        dag.add_edge(filter, agg).unwrap();

        let sort = dag.add_node(OperatorNode::sort(id, vec!["total".to_string()]));
        id += 1;
        dag.add_edge(agg, sort).unwrap();

        let limit = dag.add_node(OperatorNode::limit(id, 10));
        id += 1;
        dag.add_edge(sort, limit).unwrap();

        let result = dag.add_node(OperatorNode::new(id, OperatorType::Result));
        dag.add_edge(limit, result).unwrap();

        dag
    }

    /// Builds a randomly wired graph of 50-100 nodes.
    ///
    /// Every node gets a random operator type plus random cost and row
    /// estimates. Each node after the first is connected to one or two
    /// randomly chosen earlier nodes; because a parent is always drawn from
    /// the nodes inserted before its child, the wiring cannot form a cycle.
    fn generate_complex(&mut self) -> QueryDag {
        let mut dag = QueryDag::new();
        // Inclusive upper bound so that the documented maximum of 100 nodes is
        // actually reachable.
        let node_count = self.rng.gen_range(50..=100);

        // Remember the id the DAG hands back for each node, so the edges below
        // refer to real node ids rather than assuming they equal the loop index
        // (this mirrors how the fixed-shape generators wire their edges).
        let mut node_ids = Vec::with_capacity(node_count);
        for i in 0..node_count {
            let op_type = self.random_operator_type(i);
            let mut node = OperatorNode::new(i, op_type);
            node.estimated_cost = self.rng.gen_range(1.0..1000.0);
            node.estimated_rows = self.rng.gen_range(1.0..100000.0);
            node_ids.push(dag.add_node(node));
        }

        // Wire edges from earlier nodes to later ones only.
        for child in 1..node_count {
            // The second node has a single possible parent; later ones get 1 or 2.
            let parent_count = self.rng.gen_range(1..=child.min(2));
            for _ in 0..parent_count {
                let parent = self.rng.gen_range(0..child);
                // Best-effort on purpose: the same parent can be drawn twice,
                // and whatever `add_edge` reports in that case is ignored.
                let _ = dag.add_edge(node_ids[parent], node_ids[child]);
            }
        }

        dag
    }

    /// Builds a fixed vector-search plan:
    /// HNSW scan + metadata scan -> nested-loop join -> rerank -> limit -> result.
    fn generate_vector_query(&mut self) -> QueryDag {
        let mut dag = QueryDag::new();
        // Running id handed to each operator constructor, bumped after every node.
        let mut id = 0;

        // Vector search joined with its metadata table.
        let hnsw = dag.add_node(OperatorNode::hnsw_scan(id, "vectors_idx", 64));
        id += 1;
        let meta_scan = dag.add_node(OperatorNode::seq_scan(id, "metadata"));
        id += 1;
        let join = dag.add_node(OperatorNode::new(id, OperatorType::NestedLoopJoin));
        id += 1;
        dag.add_edge(hnsw, join).unwrap();
        dag.add_edge(meta_scan, join).unwrap();

        let rerank = dag.add_node(OperatorNode::new(
            id,
            OperatorType::Rerank {
                model: "cross-encoder".to_string(),
            },
        ));
        id += 1;
        dag.add_edge(join, rerank).unwrap();

        let limit = dag.add_node(OperatorNode::limit(id, 10));
        id += 1;
        dag.add_edge(rerank, limit).unwrap();

        let result = dag.add_node(OperatorNode::new(id, OperatorType::Result));
        dag.add_edge(limit, result).unwrap();

        dag
    }

    /// Picks one of ten operator kinds uniformly at random.
    ///
    /// `id` is only used to give scan operators distinct table/index names.
    fn random_operator_type(&mut self, id: usize) -> OperatorType {
        match self.rng.gen_range(0..10) {
            0 => OperatorType::SeqScan {
                table: format!("table_{}", id),
            },
            1 => OperatorType::IndexScan {
                index: format!("idx_{}", id),
                table: format!("table_{}", id),
            },
            2 => OperatorType::HnswScan {
                index: format!("hnsw_{}", id),
                ef_search: 64,
            },
            3 => OperatorType::HashJoin {
                hash_key: "id".to_string(),
            },
            4 => OperatorType::Filter {
                predicate: "x > 0".to_string(),
            },
            5 => OperatorType::Sort {
                keys: vec!["col1".to_string()],
                descending: vec![false],
            },
            6 => OperatorType::Limit { count: 100 },
            7 => OperatorType::Aggregate {
                functions: vec!["COUNT(*)".to_string()],
            },
            8 => OperatorType::Project {
                columns: vec!["a".to_string(), "b".to_string()],
            },
            // The only remaining draw is 9.
            _ => OperatorType::Result,
        }
    }
}
impl Default for DagGenerator {
fn default() -> Self {
Self::new()
}
}
/// Generates `count` DAGs of the given `complexity`.
///
/// All DAGs in the batch are produced by one generator instance, so the
/// random shapes share a single random number generator. Returns an empty
/// vector when `count` is zero.
pub fn generate_dag_batch(count: usize, complexity: DagComplexity) -> Vec<QueryDag> {
    // Not named `gen`: that identifier is a reserved keyword from the Rust
    // 2024 edition onwards.
    let mut generator = DagGenerator::new();
    (0..count)
        .map(|_| generator.generate(complexity))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Local bindings are called `generator` rather than `gen`, because `gen`
    // is a reserved keyword from the Rust 2024 edition onwards.

    /// The simple shape is a fixed pipeline of three nodes and two edges.
    #[test]
    fn test_generate_simple() {
        let mut generator = DagGenerator::new();
        let dag = generator.generate(DagComplexity::Simple);
        assert_eq!(dag.nodes.len(), 3);
        assert_eq!(dag.edges.len(), 2);
    }

    /// The medium shape stays within a loose 5..=20 node budget.
    #[test]
    fn test_generate_medium() {
        let mut generator = DagGenerator::new();
        let dag = generator.generate(DagComplexity::Medium);
        assert!(dag.nodes.len() >= 5);
        assert!(dag.nodes.len() <= 20);
    }

    /// The vector-query shape must contain an HNSW scan node.
    #[test]
    fn test_generate_vector_query() {
        let mut generator = DagGenerator::new();
        let dag = generator.generate(DagComplexity::VectorQuery);
        // Should have HNSW scan node
        let has_hnsw = dag
            .nodes
            .iter()
            .any(|n| matches!(n.op_type, OperatorType::HnswScan { .. }));
        assert!(has_hnsw);
    }

    /// A batch contains exactly as many DAGs as requested.
    #[test]
    fn test_generate_batch() {
        let dags = generate_dag_batch(10, DagComplexity::Simple);
        assert_eq!(dags.len(), 10);
    }
}