Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
//! Core DAG data structures and algorithms

// Internal submodules; their public items are re-exported below so callers
// never have to name the submodule paths.
mod operator_node;
mod query_dag;
mod serialization;
mod traversal;

// Flattened public API of the DAG module.
pub use operator_node::{OperatorNode, OperatorType};
pub use query_dag::{DagError, QueryDag};
pub use serialization::{DagDeserializer, DagSerializer};
pub use traversal::{BfsIterator, DfsIterator, TopologicalIterator};

View File

@@ -0,0 +1,294 @@
//! Operator node types and definitions for query DAG
use serde::{Deserialize, Serialize};
/// Types of operators in a query DAG
/// Types of operators in a query DAG
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OperatorType {
    // Scan operators
    /// Full sequential scan of a table.
    SeqScan {
        table: String,
    },
    /// Scan of `table` through the named secondary index.
    IndexScan {
        index: String,
        table: String,
    },
    /// HNSW vector-index scan; `ef_search` is the search-time candidate budget.
    HnswScan {
        index: String,
        ef_search: u32,
    },
    /// IVF-Flat vector-index scan; `nprobe` is the number of lists probed.
    IvfFlatScan {
        index: String,
        nprobe: u32,
    },
    // Join operators
    /// Join without a key (cartesian-style nested loops).
    NestedLoopJoin,
    /// Hash join keyed on `hash_key`.
    HashJoin {
        hash_key: String,
    },
    /// Merge join keyed on `merge_key`.
    MergeJoin {
        merge_key: String,
    },
    // Aggregation
    /// Aggregate using the named functions (e.g. stored as plain strings).
    Aggregate {
        functions: Vec<String>,
    },
    /// Group rows by the listed keys.
    GroupBy {
        keys: Vec<String>,
    },
    // Filter/Project
    /// Row filter; the predicate is kept as an opaque string.
    Filter {
        predicate: String,
    },
    /// Column projection.
    Project {
        columns: Vec<String>,
    },
    // Sort/Limit
    /// Sort on `keys`; `descending` holds one flag per key.
    Sort {
        keys: Vec<String>,
        descending: Vec<bool>,
    },
    /// Truncate output to `count` rows.
    Limit {
        count: usize,
    },
    // Vector operations
    /// Vector distance computation under the named metric.
    VectorDistance {
        metric: String,
    },
    /// Rerank results using the named model.
    Rerank {
        model: String,
    },
    // Utility
    /// Materialization point for intermediate results.
    Materialize,
    /// Final result sink.
    Result,
    // Backward compatibility variants (deprecated, use specific variants above)
    #[deprecated(note = "Use SeqScan instead")]
    Scan,
    #[deprecated(note = "Use HashJoin or NestedLoopJoin instead")]
    Join,
}
/// A node in the query DAG
/// A node in the query DAG
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperatorNode {
    /// Unique id within the owning DAG (reassigned by `QueryDag::add_node`).
    pub id: usize,
    /// What this operator does.
    pub op_type: OperatorType,
    /// Planner-estimated output row count.
    pub estimated_rows: f64,
    /// Planner-estimated cost.
    pub estimated_cost: f64,
    /// Observed output row count after execution, if recorded.
    pub actual_rows: Option<f64>,
    /// Observed execution time in milliseconds, if recorded.
    pub actual_time_ms: Option<f64>,
    /// Optional vector representation of this operator
    /// (consumer unspecified in this module).
    pub embedding: Option<Vec<f32>>,
}
impl OperatorNode {
    /// Build a node with the given id and operator type.
    ///
    /// All statistics start unset: zero estimates, no actuals, no embedding.
    pub fn new(id: usize, op_type: OperatorType) -> Self {
        Self {
            id,
            op_type,
            estimated_rows: 0.0,
            estimated_cost: 0.0,
            actual_rows: None,
            actual_time_ms: None,
            embedding: None,
        }
    }

    /// Sequential scan over `table`.
    pub fn seq_scan(id: usize, table: &str) -> Self {
        Self::new(id, OperatorType::SeqScan { table: table.to_owned() })
    }

    /// Index scan of `index` on `table`.
    pub fn index_scan(id: usize, index: &str, table: &str) -> Self {
        let op = OperatorType::IndexScan {
            index: index.to_owned(),
            table: table.to_owned(),
        };
        Self::new(id, op)
    }

    /// HNSW vector-index scan with the given `ef_search` candidate budget.
    pub fn hnsw_scan(id: usize, index: &str, ef_search: u32) -> Self {
        let op = OperatorType::HnswScan {
            index: index.to_owned(),
            ef_search,
        };
        Self::new(id, op)
    }

    /// IVF-Flat vector-index scan probing `nprobe` lists.
    pub fn ivf_flat_scan(id: usize, index: &str, nprobe: u32) -> Self {
        let op = OperatorType::IvfFlatScan {
            index: index.to_owned(),
            nprobe,
        };
        Self::new(id, op)
    }

    /// Nested-loop join (no join key).
    pub fn nested_loop_join(id: usize) -> Self {
        Self::new(id, OperatorType::NestedLoopJoin)
    }

    /// Hash join on `key`.
    pub fn hash_join(id: usize, key: &str) -> Self {
        Self::new(id, OperatorType::HashJoin { hash_key: key.to_owned() })
    }

    /// Merge join on `key`.
    pub fn merge_join(id: usize, key: &str) -> Self {
        Self::new(id, OperatorType::MergeJoin { merge_key: key.to_owned() })
    }

    /// Filter rows by `predicate`.
    pub fn filter(id: usize, predicate: &str) -> Self {
        Self::new(id, OperatorType::Filter { predicate: predicate.to_owned() })
    }

    /// Project to the listed `columns`.
    pub fn project(id: usize, columns: Vec<String>) -> Self {
        Self::new(id, OperatorType::Project { columns })
    }

    /// Ascending sort on `keys` (every descending flag is `false`).
    pub fn sort(id: usize, keys: Vec<String>) -> Self {
        let descending = vec![false; keys.len()];
        Self::new(id, OperatorType::Sort { keys, descending })
    }

    /// Sort on `keys` with explicit per-key descending flags.
    pub fn sort_with_order(id: usize, keys: Vec<String>, descending: Vec<bool>) -> Self {
        Self::new(id, OperatorType::Sort { keys, descending })
    }

    /// Limit output to `count` rows.
    pub fn limit(id: usize, count: usize) -> Self {
        Self::new(id, OperatorType::Limit { count })
    }

    /// Aggregate using the named `functions`.
    pub fn aggregate(id: usize, functions: Vec<String>) -> Self {
        Self::new(id, OperatorType::Aggregate { functions })
    }

    /// Group rows by `keys`.
    pub fn group_by(id: usize, keys: Vec<String>) -> Self {
        Self::new(id, OperatorType::GroupBy { keys })
    }

    /// Vector distance computation under `metric`.
    pub fn vector_distance(id: usize, metric: &str) -> Self {
        Self::new(id, OperatorType::VectorDistance { metric: metric.to_owned() })
    }

    /// Rerank results with `model`.
    pub fn rerank(id: usize, model: &str) -> Self {
        Self::new(id, OperatorType::Rerank { model: model.to_owned() })
    }

    /// Materialization point.
    pub fn materialize(id: usize) -> Self {
        Self::new(id, OperatorType::Materialize)
    }

    /// Final result sink.
    pub fn result(id: usize) -> Self {
        Self::new(id, OperatorType::Result)
    }

    /// Attach planner estimates (consuming builder).
    pub fn with_estimates(mut self, rows: f64, cost: f64) -> Self {
        self.estimated_rows = rows;
        self.estimated_cost = cost;
        self
    }

    /// Attach observed execution statistics (consuming builder).
    pub fn with_actuals(mut self, rows: f64, time_ms: f64) -> Self {
        self.actual_rows = Some(rows);
        self.actual_time_ms = Some(time_ms);
        self
    }

    /// Attach an embedding vector (consuming builder).
    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
        self.embedding = Some(embedding);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operator_node_creation() {
        // Constructor keeps the caller-supplied id and the requested variant.
        let node = OperatorNode::seq_scan(1, "users");
        assert_eq!(node.id, 1);
        assert!(matches!(node.op_type, OperatorType::SeqScan { .. }));
    }

    #[test]
    fn test_builder_pattern() {
        // Builder methods chain and each sets only its own fields.
        let node = OperatorNode::hash_join(2, "id")
            .with_estimates(1000.0, 50.0)
            .with_actuals(987.0, 45.2);
        assert_eq!(node.estimated_rows, 1000.0);
        assert_eq!(node.estimated_cost, 50.0);
        assert_eq!(node.actual_rows, Some(987.0));
        assert_eq!(node.actual_time_ms, Some(45.2));
    }

    #[test]
    fn test_serialization() {
        // JSON round-trip must preserve node identity.
        let node = OperatorNode::hnsw_scan(3, "embeddings_idx", 100);
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: OperatorNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node.id, deserialized.id);
    }
}

View File

@@ -0,0 +1,452 @@
//! Core query DAG data structure
use std::collections::{HashMap, HashSet, VecDeque};
use super::operator_node::OperatorNode;
/// Error types for DAG operations
/// Error types for DAG operations
#[derive(Debug, thiserror::Error)]
pub enum DagError {
    /// A node id was referenced that is not present in the DAG.
    #[error("Node {0} not found")]
    NodeNotFound(usize),
    /// Adding the requested edge would make the graph cyclic.
    #[error("Adding edge would create cycle")]
    CycleDetected,
    /// Catch-all for invalid operations (e.g. malformed serialized input).
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),
    /// The graph already contains a cycle, so no topological order exists.
    #[error("DAG has cycles, cannot perform topological sort")]
    HasCycles,
}
/// A Directed Acyclic Graph representing a query plan
/// A Directed Acyclic Graph representing a query plan
#[derive(Debug, Clone)]
pub struct QueryDag {
    /// Node storage keyed by id.
    pub(crate) nodes: HashMap<usize, OperatorNode>,
    /// Forward adjacency lists.
    pub(crate) edges: HashMap<usize, Vec<usize>>, // parent -> children
    /// Backward adjacency lists, kept in sync with `edges`.
    pub(crate) reverse_edges: HashMap<usize, Vec<usize>>, // child -> parents
    /// Designated root (a parentless node), if any.
    pub(crate) root: Option<usize>,
    /// Next id handed out by `add_node`; ids are never reused after removal.
    next_id: usize,
}
impl QueryDag {
    /// Create a new empty DAG
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: HashMap::new(),
            reverse_edges: HashMap::new(),
            root: None,
            next_id: 0,
        }
    }

    /// Add a node to the DAG, returns the node ID
    ///
    /// The node's `id` field is overwritten with a freshly allocated,
    /// monotonically increasing id; any caller-supplied id is ignored.
    pub fn add_node(&mut self, mut node: OperatorNode) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        node.id = id;
        self.nodes.insert(id, node);
        // Pre-seed both adjacency maps so later indexing by id never misses.
        self.edges.insert(id, Vec::new());
        self.reverse_edges.insert(id, Vec::new());
        // If this is the first node, set it as root
        if self.nodes.len() == 1 {
            self.root = Some(id);
        }
        id
    }

    /// Add an edge from parent to child
    ///
    /// # Errors
    /// - [`DagError::NodeNotFound`] if either endpoint is unknown.
    /// - [`DagError::CycleDetected`] if the edge would close a cycle
    ///   (including a self-loop, since `can_reach(x, x)` is true).
    ///
    /// NOTE(review): duplicate (parent, child) pairs are not rejected; adding
    /// the same edge twice stores it twice and inflates `edge_count`.
    pub fn add_edge(&mut self, parent: usize, child: usize) -> Result<(), DagError> {
        // Check both nodes exist
        if !self.nodes.contains_key(&parent) {
            return Err(DagError::NodeNotFound(parent));
        }
        if !self.nodes.contains_key(&child) {
            return Err(DagError::NodeNotFound(child));
        }
        // Check if adding this edge would create a cycle
        if self.would_create_cycle(parent, child) {
            return Err(DagError::CycleDetected);
        }
        // Add edge. The unwraps are safe: both ids were verified above and
        // `add_node` pre-seeds the adjacency maps.
        self.edges.get_mut(&parent).unwrap().push(child);
        self.reverse_edges.get_mut(&child).unwrap().push(parent);
        // Update root if child was previously root and now has parents
        if self.root == Some(child) && !self.reverse_edges[&child].is_empty() {
            // Find new root (node with no parents). If several nodes are
            // parentless, which one wins follows HashMap iteration order,
            // which is unspecified.
            self.root = self
                .nodes
                .keys()
                .find(|&&id| self.reverse_edges[&id].is_empty())
                .copied();
        }
        Ok(())
    }

    /// Remove a node from the DAG
    ///
    /// Returns the removed node, or `None` if the id is unknown. All edges
    /// touching the node are dropped; orphaned children are not re-parented.
    pub fn remove_node(&mut self, id: usize) -> Option<OperatorNode> {
        let node = self.nodes.remove(&id)?;
        // Remove all edges involving this node
        if let Some(children) = self.edges.remove(&id) {
            for child in children {
                if let Some(parents) = self.reverse_edges.get_mut(&child) {
                    parents.retain(|&p| p != id);
                }
            }
        }
        if let Some(parents) = self.reverse_edges.remove(&id) {
            for parent in parents {
                if let Some(children) = self.edges.get_mut(&parent) {
                    children.retain(|&c| c != id);
                }
            }
        }
        // Update root if necessary: pick any remaining parentless node.
        if self.root == Some(id) {
            self.root = self
                .nodes
                .keys()
                .find(|&&nid| self.reverse_edges[&nid].is_empty())
                .copied();
        }
        Some(node)
    }

    /// Get a reference to a node
    pub fn get_node(&self, id: usize) -> Option<&OperatorNode> {
        self.nodes.get(&id)
    }

    /// Get a mutable reference to a node
    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut OperatorNode> {
        self.nodes.get_mut(&id)
    }

    /// Get children of a node (empty slice for unknown ids)
    pub fn children(&self, id: usize) -> &[usize] {
        self.edges.get(&id).map(|v| v.as_slice()).unwrap_or(&[])
    }

    /// Get parents of a node (empty slice for unknown ids)
    pub fn parents(&self, id: usize) -> &[usize] {
        self.reverse_edges
            .get(&id)
            .map(|v| v.as_slice())
            .unwrap_or(&[])
    }

    /// Get the root node ID
    pub fn root(&self) -> Option<usize> {
        self.root
    }

    /// Get all leaf nodes (nodes with no children)
    pub fn leaves(&self) -> Vec<usize> {
        self.nodes
            .keys()
            .filter(|&&id| self.edges[&id].is_empty())
            .copied()
            .collect()
    }

    /// Get the number of nodes
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Get the number of edges (duplicates, if any, are counted)
    pub fn edge_count(&self) -> usize {
        self.edges.values().map(|v| v.len()).sum()
    }

    /// Get iterator over node IDs (unspecified order)
    pub fn node_ids(&self) -> impl Iterator<Item = usize> + '_ {
        self.nodes.keys().copied()
    }

    /// Get iterator over all nodes (unspecified order)
    pub fn nodes(&self) -> impl Iterator<Item = &OperatorNode> + '_ {
        self.nodes.values()
    }

    /// Check if adding an edge would create a cycle
    fn would_create_cycle(&self, from: usize, to: usize) -> bool {
        // If 'to' can reach 'from', adding edge from->to would create cycle
        self.can_reach(to, from)
    }

    /// Check if 'from' can reach 'to' through existing edges
    ///
    /// Plain BFS over the forward edges; a node trivially reaches itself.
    fn can_reach(&self, from: usize, to: usize) -> bool {
        if from == to {
            return true;
        }
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        queue.push_back(from);
        visited.insert(from);
        while let Some(current) = queue.pop_front() {
            if current == to {
                return true;
            }
            if let Some(children) = self.edges.get(&current) {
                for &child in children {
                    // `insert` returns true only on first sight, so each node
                    // is enqueued at most once.
                    if visited.insert(child) {
                        queue.push_back(child);
                    }
                }
            }
        }
        false
    }

    /// Compute depth of each node from leaves (leaves have depth 0)
    ///
    /// Intended as longest-path depth: when a later traversal reaches an
    /// already-seen node via a longer path, its depth is raised and it is
    /// re-queued so the larger value propagates to its ancestors.
    pub fn compute_depths(&self) -> HashMap<usize, usize> {
        let mut depths = HashMap::new();
        let mut visited = HashSet::new();
        // Start from leaves
        let leaves = self.leaves();
        let mut queue: VecDeque<(usize, usize)> = leaves.iter().map(|&id| (id, 0)).collect();
        for &leaf in &leaves {
            visited.insert(leaf);
            depths.insert(leaf, 0);
        }
        while let Some((node, depth)) = queue.pop_front() {
            depths.insert(node, depth);
            // Process parents
            if let Some(parents) = self.reverse_edges.get(&node) {
                for &parent in parents {
                    if visited.insert(parent) {
                        // First time we see this parent.
                        queue.push_back((parent, depth + 1));
                    } else {
                        // Update depth if we found a longer path
                        let current_depth = depths.get(&parent).copied().unwrap_or(0);
                        if depth + 1 > current_depth {
                            depths.insert(parent, depth + 1);
                            // Re-queue so ancestors see the larger depth too.
                            queue.push_back((parent, depth + 1));
                        }
                    }
                }
            }
        }
        depths
    }

    /// Get all ancestors of a node
    ///
    /// BFS over reverse edges; the node itself is not included.
    pub fn ancestors(&self, id: usize) -> HashSet<usize> {
        let mut result = HashSet::new();
        let mut queue = VecDeque::new();
        if let Some(parents) = self.reverse_edges.get(&id) {
            for &parent in parents {
                queue.push_back(parent);
                result.insert(parent);
            }
        }
        while let Some(node) = queue.pop_front() {
            if let Some(parents) = self.reverse_edges.get(&node) {
                for &parent in parents {
                    if result.insert(parent) {
                        queue.push_back(parent);
                    }
                }
            }
        }
        result
    }

    /// Get all descendants of a node
    ///
    /// BFS over forward edges; the node itself is not included.
    pub fn descendants(&self, id: usize) -> HashSet<usize> {
        let mut result = HashSet::new();
        let mut queue = VecDeque::new();
        if let Some(children) = self.edges.get(&id) {
            for &child in children {
                queue.push_back(child);
                result.insert(child);
            }
        }
        while let Some(node) = queue.pop_front() {
            if let Some(children) = self.edges.get(&node) {
                for &child in children {
                    if result.insert(child) {
                        queue.push_back(child);
                    }
                }
            }
        }
        result
    }

    /// Return nodes in topological order as Vec (dependencies first)
    ///
    /// Kahn's algorithm: repeatedly emit nodes whose in-degree has dropped
    /// to zero. If not every node gets emitted, the graph has a cycle.
    ///
    /// # Errors
    /// [`DagError::HasCycles`] when a cycle prevents a complete ordering.
    pub fn topological_sort(&self) -> Result<Vec<usize>, DagError> {
        let mut result = Vec::new();
        let mut in_degree: HashMap<usize, usize> = self
            .nodes
            .keys()
            .map(|&id| (id, self.reverse_edges[&id].len()))
            .collect();
        // Seed with all parentless nodes.
        let mut queue: VecDeque<usize> = in_degree
            .iter()
            .filter(|(_, &degree)| degree == 0)
            .map(|(&id, _)| id)
            .collect();
        while let Some(node) = queue.pop_front() {
            result.push(node);
            if let Some(children) = self.edges.get(&node) {
                for &child in children {
                    let degree = in_degree.get_mut(&child).unwrap();
                    *degree -= 1;
                    if *degree == 0 {
                        queue.push_back(child);
                    }
                }
            }
        }
        if result.len() != self.nodes.len() {
            return Err(DagError::HasCycles);
        }
        Ok(result)
    }
}
impl Default for QueryDag {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    #[test]
    fn test_new_dag() {
        // A fresh DAG starts completely empty.
        let dag = QueryDag::new();
        assert_eq!(dag.node_count(), 0);
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_add_nodes() {
        // The id passed to the constructors (0) is overwritten by add_node.
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        assert_eq!(dag.node_count(), 2);
        assert!(dag.get_node(id1).is_some());
        assert!(dag.get_node(id2).is_some());
    }

    #[test]
    fn test_add_edges() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        assert!(dag.add_edge(id1, id2).is_ok());
        assert_eq!(dag.edge_count(), 1);
        // Both adjacency directions must be maintained.
        assert_eq!(dag.children(id1), &[id2]);
        assert_eq!(dag.parents(id2), &[id1]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        // This would create a cycle
        assert!(matches!(
            dag.add_edge(id3, id1),
            Err(DagError::CycleDetected)
        ));
    }

    #[test]
    fn test_topological_sort() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);
        // id1 should come before id2, id2 before id3
        let pos1 = sorted.iter().position(|&x| x == id1).unwrap();
        let pos2 = sorted.iter().position(|&x| x == id2).unwrap();
        let pos3 = sorted.iter().position(|&x| x == id3).unwrap();
        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }

    #[test]
    fn test_remove_node() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        dag.add_edge(id1, id2).unwrap();
        // Removing a node must also drop its incident edges.
        let removed = dag.remove_node(id1);
        assert!(removed.is_some());
        assert_eq!(dag.node_count(), 1);
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_ancestors_descendants() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        // Ancestors/descendants are transitive, excluding the node itself.
        let ancestors = dag.ancestors(id3);
        assert!(ancestors.contains(&id1));
        assert!(ancestors.contains(&id2));
        let descendants = dag.descendants(id1);
        assert!(descendants.contains(&id2));
        assert!(descendants.contains(&id3));
    }
}

View File

@@ -0,0 +1,184 @@
//! DAG serialization and deserialization
use serde::{Deserialize, Serialize};
use super::operator_node::OperatorNode;
use super::query_dag::{DagError, QueryDag};
/// Serializable representation of a DAG
/// Serializable representation of a DAG
///
/// Flattened serde intermediate: the adjacency maps become an explicit
/// edge list so the on-disk format is independent of HashMap layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableDag {
    /// All nodes, in unspecified order (collected from a HashMap).
    nodes: Vec<OperatorNode>,
    edges: Vec<(usize, usize)>, // (parent, child) pairs
    /// Original root id, remapped on deserialization.
    root: Option<usize>,
}
/// Trait for DAG serialization
/// Trait for DAG serialization
pub trait DagSerializer {
    /// Serialize to JSON string
    fn to_json(&self) -> Result<String, serde_json::Error>;
    /// Serialize to bytes (using bincode-like format via JSON for now)
    ///
    /// NOTE(review): infallible signature — the `QueryDag` impl swallows
    /// JSON errors and returns an empty vector on failure.
    fn to_bytes(&self) -> Vec<u8>;
}
/// Trait for DAG deserialization
/// Trait for DAG deserialization
pub trait DagDeserializer {
    /// Deserialize from JSON string
    fn from_json(json: &str) -> Result<Self, serde_json::Error>
    where
        Self: Sized;
    /// Deserialize from bytes
    ///
    /// Errors (invalid UTF-8, malformed JSON) are wrapped in
    /// `DagError::InvalidOperation`.
    fn from_bytes(bytes: &[u8]) -> Result<Self, DagError>
    where
        Self: Sized;
}
impl DagSerializer for QueryDag {
    /// Snapshot the DAG into [`SerializableDag`] form and pretty-print it.
    fn to_json(&self) -> Result<String, serde_json::Error> {
        // Flatten adjacency lists into explicit (parent, child) pairs.
        let edge_pairs: Vec<(usize, usize)> = self
            .edges
            .iter()
            .flat_map(|(&parent, children)| {
                children.iter().map(move |&child| (parent, child))
            })
            .collect();
        let snapshot = SerializableDag {
            nodes: self.nodes.values().cloned().collect(),
            edges: edge_pairs,
            root: self.root,
        };
        serde_json::to_string_pretty(&snapshot)
    }

    /// JSON-as-bytes placeholder until a binary codec is adopted.
    ///
    /// A serialization failure degrades to an empty byte vector.
    fn to_bytes(&self) -> Vec<u8> {
        self.to_json().unwrap_or_default().into_bytes()
    }
}
impl DagDeserializer for QueryDag {
    /// Rebuild a DAG from its JSON snapshot.
    ///
    /// Node ids are reassigned by `add_node`, so an old-id -> new-id map is
    /// kept to translate the stored edges and root.
    fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        let parsed: SerializableDag = serde_json::from_str(json)?;
        let mut dag = QueryDag::new();
        let mut remap = std::collections::HashMap::new();
        for node in parsed.nodes {
            let previous_id = node.id;
            remap.insert(previous_id, dag.add_node(node));
        }
        for (parent, child) in parsed.edges {
            if let (Some(&p), Some(&c)) = (remap.get(&parent), remap.get(&child)) {
                // Best-effort: edge failures on corrupt input are ignored.
                let _ = dag.add_edge(p, c);
            }
        }
        // Restore the recorded root through the id map.
        if let Some(old_root) = parsed.root {
            dag.root = remap.get(&old_root).copied();
        }
        Ok(dag)
    }

    /// Decode UTF-8 bytes and delegate to [`Self::from_json`], mapping both
    /// failure modes into `DagError::InvalidOperation`.
    fn from_bytes(bytes: &[u8]) -> Result<Self, DagError> {
        let text = std::str::from_utf8(bytes)
            .map_err(|e| DagError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
        Self::from_json(text)
            .map_err(|e| DagError::InvalidOperation(format!("Deserialization failed: {}", e)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    #[test]
    fn test_json_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        // Serialize
        let json = dag.to_json().unwrap();
        assert!(!json.is_empty());
        // Deserialize: structure (counts) must survive the round trip.
        let deserialized = QueryDag::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), 3);
        assert_eq!(deserialized.edge_count(), 2);
    }

    #[test]
    fn test_bytes_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        dag.add_edge(id1, id2).unwrap();
        // Serialize to bytes (currently JSON-encoded UTF-8).
        let bytes = dag.to_bytes();
        assert!(!bytes.is_empty());
        // Deserialize from bytes
        let deserialized = QueryDag::from_bytes(&bytes).unwrap();
        assert_eq!(deserialized.node_count(), 2);
        assert_eq!(deserialized.edge_count(), 1);
    }

    #[test]
    fn test_complex_dag_roundtrip() {
        let mut dag = QueryDag::new();
        // Create a more complex DAG: two scans feeding a join pipeline.
        let scan1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let scan2 = dag.add_node(OperatorNode::seq_scan(0, "orders"));
        let join = dag.add_node(OperatorNode::hash_join(0, "user_id"));
        let filter = dag.add_node(OperatorNode::filter(0, "total > 100"));
        let sort = dag.add_node(OperatorNode::sort(0, vec!["date".to_string()]));
        let limit = dag.add_node(OperatorNode::limit(0, 10));
        dag.add_edge(scan1, join).unwrap();
        dag.add_edge(scan2, join).unwrap();
        dag.add_edge(join, filter).unwrap();
        dag.add_edge(filter, sort).unwrap();
        dag.add_edge(sort, limit).unwrap();
        // Round trip
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();
        assert_eq!(restored.node_count(), dag.node_count());
        assert_eq!(restored.edge_count(), dag.edge_count());
    }

    #[test]
    fn test_empty_dag_serialization() {
        // An empty DAG must round-trip to an empty DAG, not an error.
        let dag = QueryDag::new();
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();
        assert_eq!(restored.node_count(), 0);
        assert_eq!(restored.edge_count(), 0);
    }
}

View File

@@ -0,0 +1,228 @@
//! DAG traversal algorithms and iterators
use std::collections::{HashSet, VecDeque};
use super::query_dag::{DagError, QueryDag};
/// Iterator for topological order traversal (dependencies first)
/// Iterator for topological order traversal (dependencies first)
pub struct TopologicalIterator<'a> {
    // Kept for API symmetry with the other iterators; the order is fully
    // precomputed, so the DAG is not consulted during iteration.
    #[allow(dead_code)]
    dag: &'a QueryDag,
    /// Precomputed topological order of node ids.
    sorted: Vec<usize>,
    /// Cursor into `sorted`.
    index: usize,
}
impl<'a> TopologicalIterator<'a> {
    /// Build the iterator by materializing the full topological order
    /// up front.
    ///
    /// Fails with [`DagError::HasCycles`] when the graph is not acyclic.
    pub(crate) fn new(dag: &'a QueryDag) -> Result<Self, DagError> {
        dag.topological_sort().map(|sorted| Self {
            dag,
            sorted,
            index: 0,
        })
    }
}
impl<'a> Iterator for TopologicalIterator<'a> {
    type Item = usize;

    /// Yield node ids in the precomputed topological order.
    fn next(&mut self) -> Option<Self::Item> {
        let id = self.sorted.get(self.index).copied()?;
        self.index += 1;
        Some(id)
    }
}
/// Iterator for depth-first search traversal
/// Iterator for depth-first search traversal
pub struct DfsIterator<'a> {
    dag: &'a QueryDag,
    /// Nodes pending a visit (LIFO).
    stack: Vec<usize>,
    /// Nodes already yielded.
    visited: HashSet<usize>,
}
impl<'a> DfsIterator<'a> {
    /// Start a depth-first traversal at `start`; an unknown id produces an
    /// immediately-empty iterator.
    pub(crate) fn new(dag: &'a QueryDag, start: usize) -> Self {
        let stack = if dag.get_node(start).is_some() {
            vec![start]
        } else {
            Vec::new()
        };
        Self {
            dag,
            stack,
            visited: HashSet::new(),
        }
    }
}
impl<'a> Iterator for DfsIterator<'a> {
    type Item = usize;

    /// Yield the next unvisited node in depth-first (pre-order) order.
    fn next(&mut self) -> Option<Self::Item> {
        // Pop until an unvisited node is found; `?` ends on empty stack.
        loop {
            let node = self.stack.pop()?;
            if !self.visited.insert(node) {
                continue;
            }
            // Push children in reverse so the first child is popped first.
            if let Some(children) = self.dag.edges.get(&node) {
                for &child in children.iter().rev() {
                    if !self.visited.contains(&child) {
                        self.stack.push(child);
                    }
                }
            }
            return Some(node);
        }
    }
}
/// Iterator for breadth-first search traversal
/// Iterator for breadth-first search traversal
pub struct BfsIterator<'a> {
    dag: &'a QueryDag,
    /// Nodes pending a visit (FIFO).
    queue: VecDeque<usize>,
    /// Nodes already yielded.
    visited: HashSet<usize>,
}
impl<'a> BfsIterator<'a> {
    /// Start a breadth-first traversal at `start`; an unknown id produces an
    /// immediately-empty iterator.
    pub(crate) fn new(dag: &'a QueryDag, start: usize) -> Self {
        let mut queue = VecDeque::new();
        if dag.get_node(start).is_some() {
            queue.push_back(start);
        }
        Self {
            dag,
            queue,
            visited: HashSet::new(),
        }
    }
}
impl<'a> Iterator for BfsIterator<'a> {
    type Item = usize;

    /// Yield the next unvisited node in breadth-first order.
    fn next(&mut self) -> Option<Self::Item> {
        // Dequeue until an unvisited node is found; `?` ends on empty queue.
        loop {
            let node = self.queue.pop_front()?;
            if !self.visited.insert(node) {
                continue;
            }
            // Enqueue children for the next level.
            if let Some(children) = self.dag.edges.get(&node) {
                for &child in children {
                    if !self.visited.contains(&child) {
                        self.queue.push_back(child);
                    }
                }
            }
            return Some(node);
        }
    }
}
impl QueryDag {
    /// Create an iterator for topological order traversal
    ///
    /// # Errors
    /// Returns [`DagError::HasCycles`] if the graph contains a cycle.
    pub fn topological_iter(&self) -> Result<TopologicalIterator<'_>, DagError> {
        TopologicalIterator::new(self)
    }

    /// Create an iterator for depth-first search starting from a node
    /// (empty if `start` is unknown)
    pub fn dfs_iter(&self, start: usize) -> DfsIterator<'_> {
        DfsIterator::new(self, start)
    }

    /// Create an iterator for breadth-first search starting from a node
    /// (empty if `start` is unknown)
    pub fn bfs_iter(&self, start: usize) -> BfsIterator<'_> {
        BfsIterator::new(self, start)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    // Linear four-node pipeline: 0 -> 1 -> 2 -> 3.
    fn create_test_dag() -> QueryDag {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));
        let id4 = dag.add_node(OperatorNode::limit(0, 10));
        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();
        dag.add_edge(id3, id4).unwrap();
        dag
    }

    #[test]
    fn test_topological_iterator() {
        let dag = create_test_dag();
        let nodes: Vec<usize> = dag.topological_iter().unwrap().collect();
        assert_eq!(nodes.len(), 4);
        // Check ordering constraints (ids are 0..3 for a linear chain).
        let pos: Vec<usize> = (0..4)
            .map(|i| nodes.iter().position(|&x| x == i).unwrap())
            .collect();
        assert!(pos[0] < pos[1]); // 0 before 1
        assert!(pos[1] < pos[2]); // 1 before 2
        assert!(pos[2] < pos[3]); // 2 before 3
    }

    #[test]
    fn test_dfs_iterator() {
        let dag = create_test_dag();
        let nodes: Vec<usize> = dag.dfs_iter(0).collect();
        assert_eq!(nodes.len(), 4);
        assert_eq!(nodes[0], 0); // Should start from node 0
    }

    #[test]
    fn test_bfs_iterator() {
        let dag = create_test_dag();
        let nodes: Vec<usize> = dag.bfs_iter(0).collect();
        assert_eq!(nodes.len(), 4);
        assert_eq!(nodes[0], 0); // Should start from node 0
    }

    #[test]
    fn test_branching_dag() {
        // Diamond shape: root branches left (two nodes) and right, rejoining
        // at a hash join.
        let mut dag = QueryDag::new();
        let root = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let left1 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let left2 = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));
        let right1 = dag.add_node(OperatorNode::filter(0, "active = true"));
        let join = dag.add_node(OperatorNode::hash_join(0, "id"));
        dag.add_edge(root, left1).unwrap();
        dag.add_edge(left1, left2).unwrap();
        dag.add_edge(root, right1).unwrap();
        dag.add_edge(left2, join).unwrap();
        dag.add_edge(right1, join).unwrap();
        // BFS should visit level by level
        let bfs_nodes: Vec<usize> = dag.bfs_iter(root).collect();
        assert_eq!(bfs_nodes.len(), 5);
        // Topological sort should respect dependencies
        let topo_nodes = dag.topological_sort().unwrap();
        assert_eq!(topo_nodes.len(), 5);
        let pos_root = topo_nodes.iter().position(|&x| x == root).unwrap();
        let pos_join = topo_nodes.iter().position(|&x| x == join).unwrap();
        assert!(pos_root < pos_join);
    }
}
}