git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
212 lines
6.7 KiB
Rust
212 lines
6.7 KiB
Rust
//! DAG Generator for testing
|
|
|
|
use ruvector_dag::dag::{QueryDag, OperatorNode, OperatorType};
|
|
use rand::Rng;
|
|
|
|
#[derive(Debug, Clone, Copy)]
|
|
pub enum DagComplexity {
|
|
Simple, // 3-5 nodes, linear
|
|
Medium, // 10-20 nodes, some branches
|
|
Complex, // 50-100 nodes, many branches
|
|
VectorQuery, // Typical vector search pattern
|
|
}
|
|
|
|
pub struct DagGenerator {
|
|
rng: rand::rngs::ThreadRng,
|
|
}
|
|
|
|
impl DagGenerator {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
rng: rand::thread_rng(),
|
|
}
|
|
}
|
|
|
|
pub fn generate(&mut self, complexity: DagComplexity) -> QueryDag {
|
|
match complexity {
|
|
DagComplexity::Simple => self.generate_simple(),
|
|
DagComplexity::Medium => self.generate_medium(),
|
|
DagComplexity::Complex => self.generate_complex(),
|
|
DagComplexity::VectorQuery => self.generate_vector_query(),
|
|
}
|
|
}
|
|
|
|
fn generate_simple(&mut self) -> QueryDag {
|
|
let mut dag = QueryDag::new();
|
|
|
|
// Simple: Scan -> Filter -> Result
|
|
let scan = dag.add_node(OperatorNode::seq_scan(0, "users"));
|
|
let filter = dag.add_node(OperatorNode::filter(1, "id > 0"));
|
|
let result = dag.add_node(OperatorNode::new(2, OperatorType::Result));
|
|
|
|
dag.add_edge(scan, filter).unwrap();
|
|
dag.add_edge(filter, result).unwrap();
|
|
|
|
dag
|
|
}
|
|
|
|
fn generate_medium(&mut self) -> QueryDag {
|
|
let mut dag = QueryDag::new();
|
|
let mut id = 0;
|
|
|
|
// Two table join with aggregation
|
|
let scan1 = dag.add_node(OperatorNode::seq_scan(id, "orders")); id += 1;
|
|
let scan2 = dag.add_node(OperatorNode::seq_scan(id, "products")); id += 1;
|
|
|
|
let join = dag.add_node(OperatorNode::hash_join(id, "product_id")); id += 1;
|
|
dag.add_edge(scan1, join).unwrap();
|
|
dag.add_edge(scan2, join).unwrap();
|
|
|
|
let filter = dag.add_node(OperatorNode::filter(id, "amount > 100")); id += 1;
|
|
dag.add_edge(join, filter).unwrap();
|
|
|
|
let agg = dag.add_node(OperatorNode::new(id, OperatorType::Aggregate {
|
|
functions: vec!["SUM(amount)".to_string()],
|
|
})); id += 1;
|
|
dag.add_edge(filter, agg).unwrap();
|
|
|
|
let sort = dag.add_node(OperatorNode::sort(id, vec!["total".to_string()])); id += 1;
|
|
dag.add_edge(agg, sort).unwrap();
|
|
|
|
let limit = dag.add_node(OperatorNode::limit(id, 10)); id += 1;
|
|
dag.add_edge(sort, limit).unwrap();
|
|
|
|
let result = dag.add_node(OperatorNode::new(id, OperatorType::Result));
|
|
dag.add_edge(limit, result).unwrap();
|
|
|
|
dag
|
|
}
|
|
|
|
fn generate_complex(&mut self) -> QueryDag {
|
|
let mut dag = QueryDag::new();
|
|
let node_count = self.rng.gen_range(50..100);
|
|
|
|
// Generate nodes
|
|
for i in 0..node_count {
|
|
let op_type = self.random_operator_type(i);
|
|
let mut node = OperatorNode::new(i, op_type);
|
|
node.estimated_cost = self.rng.gen_range(1.0..1000.0);
|
|
node.estimated_rows = self.rng.gen_range(1.0..100000.0);
|
|
dag.add_node(node);
|
|
}
|
|
|
|
// Generate edges (ensuring DAG property)
|
|
for i in 1..node_count {
|
|
let parent_count = self.rng.gen_range(1..=2.min(i));
|
|
for _ in 0..parent_count {
|
|
let parent = self.rng.gen_range(0..i);
|
|
let _ = dag.add_edge(parent, i);
|
|
}
|
|
}
|
|
|
|
dag
|
|
}
|
|
|
|
fn generate_vector_query(&mut self) -> QueryDag {
|
|
let mut dag = QueryDag::new();
|
|
let mut id = 0;
|
|
|
|
// Vector search with join to metadata
|
|
let hnsw = dag.add_node(OperatorNode::hnsw_scan(id, "vectors_idx", 64)); id += 1;
|
|
let meta_scan = dag.add_node(OperatorNode::seq_scan(id, "metadata")); id += 1;
|
|
|
|
let join = dag.add_node(OperatorNode::new(id, OperatorType::NestedLoopJoin)); id += 1;
|
|
dag.add_edge(hnsw, join).unwrap();
|
|
dag.add_edge(meta_scan, join).unwrap();
|
|
|
|
let rerank = dag.add_node(OperatorNode::new(id, OperatorType::Rerank {
|
|
model: "cross-encoder".to_string(),
|
|
})); id += 1;
|
|
dag.add_edge(join, rerank).unwrap();
|
|
|
|
let limit = dag.add_node(OperatorNode::limit(id, 10)); id += 1;
|
|
dag.add_edge(rerank, limit).unwrap();
|
|
|
|
let result = dag.add_node(OperatorNode::new(id, OperatorType::Result));
|
|
dag.add_edge(limit, result).unwrap();
|
|
|
|
dag
|
|
}
|
|
|
|
fn random_operator_type(&mut self, id: usize) -> OperatorType {
|
|
match self.rng.gen_range(0..10) {
|
|
0 => OperatorType::SeqScan { table: format!("table_{}", id) },
|
|
1 => OperatorType::IndexScan {
|
|
index: format!("idx_{}", id),
|
|
table: format!("table_{}", id)
|
|
},
|
|
2 => OperatorType::HnswScan {
|
|
index: format!("hnsw_{}", id),
|
|
ef_search: 64
|
|
},
|
|
3 => OperatorType::HashJoin {
|
|
hash_key: "id".to_string()
|
|
},
|
|
4 => OperatorType::Filter {
|
|
predicate: "x > 0".to_string()
|
|
},
|
|
5 => OperatorType::Sort {
|
|
keys: vec!["col1".to_string()],
|
|
descending: vec![false]
|
|
},
|
|
6 => OperatorType::Limit { count: 100 },
|
|
7 => OperatorType::Aggregate {
|
|
functions: vec!["COUNT(*)".to_string()]
|
|
},
|
|
8 => OperatorType::Project {
|
|
columns: vec!["a".to_string(), "b".to_string()]
|
|
},
|
|
_ => OperatorType::Result,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for DagGenerator {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Generate a batch of DAGs
|
|
pub fn generate_dag_batch(count: usize, complexity: DagComplexity) -> Vec<QueryDag> {
|
|
let mut gen = DagGenerator::new();
|
|
(0..count).map(|_| gen.generate(complexity)).collect()
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_generate_simple() {
|
|
let mut gen = DagGenerator::new();
|
|
let dag = gen.generate(DagComplexity::Simple);
|
|
assert_eq!(dag.nodes.len(), 3);
|
|
assert_eq!(dag.edges.len(), 2);
|
|
}
|
|
|
|
#[test]
|
|
fn test_generate_medium() {
|
|
let mut gen = DagGenerator::new();
|
|
let dag = gen.generate(DagComplexity::Medium);
|
|
assert!(dag.nodes.len() >= 5);
|
|
assert!(dag.nodes.len() <= 20);
|
|
}
|
|
|
|
#[test]
|
|
fn test_generate_vector_query() {
|
|
let mut gen = DagGenerator::new();
|
|
let dag = gen.generate(DagComplexity::VectorQuery);
|
|
|
|
// Should have HNSW scan node
|
|
let has_hnsw = dag.nodes.iter().any(|n| matches!(n.op_type, OperatorType::HnswScan { .. }));
|
|
assert!(has_hnsw);
|
|
}
|
|
|
|
#[test]
|
|
fn test_generate_batch() {
|
|
let dags = generate_dag_batch(10, DagComplexity::Simple);
|
|
assert_eq!(dags.len(), 10);
|
|
}
|
|
}
|