Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
211
crates/ruvector-dag/tests/fixtures/dag_generator.rs
vendored
Normal file
211
crates/ruvector-dag/tests/fixtures/dag_generator.rs
vendored
Normal file
@@ -0,0 +1,211 @@
|
||||
//! DAG Generator for testing
|
||||
|
||||
use ruvector_dag::dag::{QueryDag, OperatorNode, OperatorType};
|
||||
use rand::Rng;
|
||||
|
||||
/// Complexity presets understood by `DagGenerator::generate`.
#[derive(Debug, Clone, Copy)]
pub enum DagComplexity {
    Simple, // 3-5 nodes, linear (generator currently emits exactly 3 — TODO confirm intended range)
    Medium, // 10-20 nodes, some branches (generator currently emits 8 — TODO confirm intended range)
    Complex, // 50-100 nodes, many branches
    VectorQuery, // Typical vector search pattern
}
|
||||
|
||||
/// Produces randomized `QueryDag` fixtures for tests.
pub struct DagGenerator {
    // Thread-local RNG; generated DAGs are non-deterministic across runs.
    rng: rand::rngs::ThreadRng,
}
|
||||
|
||||
impl DagGenerator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
rng: rand::thread_rng(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate(&mut self, complexity: DagComplexity) -> QueryDag {
|
||||
match complexity {
|
||||
DagComplexity::Simple => self.generate_simple(),
|
||||
DagComplexity::Medium => self.generate_medium(),
|
||||
DagComplexity::Complex => self.generate_complex(),
|
||||
DagComplexity::VectorQuery => self.generate_vector_query(),
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_simple(&mut self) -> QueryDag {
|
||||
let mut dag = QueryDag::new();
|
||||
|
||||
// Simple: Scan -> Filter -> Result
|
||||
let scan = dag.add_node(OperatorNode::seq_scan(0, "users"));
|
||||
let filter = dag.add_node(OperatorNode::filter(1, "id > 0"));
|
||||
let result = dag.add_node(OperatorNode::new(2, OperatorType::Result));
|
||||
|
||||
dag.add_edge(scan, filter).unwrap();
|
||||
dag.add_edge(filter, result).unwrap();
|
||||
|
||||
dag
|
||||
}
|
||||
|
||||
fn generate_medium(&mut self) -> QueryDag {
|
||||
let mut dag = QueryDag::new();
|
||||
let mut id = 0;
|
||||
|
||||
// Two table join with aggregation
|
||||
let scan1 = dag.add_node(OperatorNode::seq_scan(id, "orders")); id += 1;
|
||||
let scan2 = dag.add_node(OperatorNode::seq_scan(id, "products")); id += 1;
|
||||
|
||||
let join = dag.add_node(OperatorNode::hash_join(id, "product_id")); id += 1;
|
||||
dag.add_edge(scan1, join).unwrap();
|
||||
dag.add_edge(scan2, join).unwrap();
|
||||
|
||||
let filter = dag.add_node(OperatorNode::filter(id, "amount > 100")); id += 1;
|
||||
dag.add_edge(join, filter).unwrap();
|
||||
|
||||
let agg = dag.add_node(OperatorNode::new(id, OperatorType::Aggregate {
|
||||
functions: vec!["SUM(amount)".to_string()],
|
||||
})); id += 1;
|
||||
dag.add_edge(filter, agg).unwrap();
|
||||
|
||||
let sort = dag.add_node(OperatorNode::sort(id, vec!["total".to_string()])); id += 1;
|
||||
dag.add_edge(agg, sort).unwrap();
|
||||
|
||||
let limit = dag.add_node(OperatorNode::limit(id, 10)); id += 1;
|
||||
dag.add_edge(sort, limit).unwrap();
|
||||
|
||||
let result = dag.add_node(OperatorNode::new(id, OperatorType::Result));
|
||||
dag.add_edge(limit, result).unwrap();
|
||||
|
||||
dag
|
||||
}
|
||||
|
||||
fn generate_complex(&mut self) -> QueryDag {
|
||||
let mut dag = QueryDag::new();
|
||||
let node_count = self.rng.gen_range(50..100);
|
||||
|
||||
// Generate nodes
|
||||
for i in 0..node_count {
|
||||
let op_type = self.random_operator_type(i);
|
||||
let mut node = OperatorNode::new(i, op_type);
|
||||
node.estimated_cost = self.rng.gen_range(1.0..1000.0);
|
||||
node.estimated_rows = self.rng.gen_range(1.0..100000.0);
|
||||
dag.add_node(node);
|
||||
}
|
||||
|
||||
// Generate edges (ensuring DAG property)
|
||||
for i in 1..node_count {
|
||||
let parent_count = self.rng.gen_range(1..=2.min(i));
|
||||
for _ in 0..parent_count {
|
||||
let parent = self.rng.gen_range(0..i);
|
||||
let _ = dag.add_edge(parent, i);
|
||||
}
|
||||
}
|
||||
|
||||
dag
|
||||
}
|
||||
|
||||
fn generate_vector_query(&mut self) -> QueryDag {
|
||||
let mut dag = QueryDag::new();
|
||||
let mut id = 0;
|
||||
|
||||
// Vector search with join to metadata
|
||||
let hnsw = dag.add_node(OperatorNode::hnsw_scan(id, "vectors_idx", 64)); id += 1;
|
||||
let meta_scan = dag.add_node(OperatorNode::seq_scan(id, "metadata")); id += 1;
|
||||
|
||||
let join = dag.add_node(OperatorNode::new(id, OperatorType::NestedLoopJoin)); id += 1;
|
||||
dag.add_edge(hnsw, join).unwrap();
|
||||
dag.add_edge(meta_scan, join).unwrap();
|
||||
|
||||
let rerank = dag.add_node(OperatorNode::new(id, OperatorType::Rerank {
|
||||
model: "cross-encoder".to_string(),
|
||||
})); id += 1;
|
||||
dag.add_edge(join, rerank).unwrap();
|
||||
|
||||
let limit = dag.add_node(OperatorNode::limit(id, 10)); id += 1;
|
||||
dag.add_edge(rerank, limit).unwrap();
|
||||
|
||||
let result = dag.add_node(OperatorNode::new(id, OperatorType::Result));
|
||||
dag.add_edge(limit, result).unwrap();
|
||||
|
||||
dag
|
||||
}
|
||||
|
||||
fn random_operator_type(&mut self, id: usize) -> OperatorType {
|
||||
match self.rng.gen_range(0..10) {
|
||||
0 => OperatorType::SeqScan { table: format!("table_{}", id) },
|
||||
1 => OperatorType::IndexScan {
|
||||
index: format!("idx_{}", id),
|
||||
table: format!("table_{}", id)
|
||||
},
|
||||
2 => OperatorType::HnswScan {
|
||||
index: format!("hnsw_{}", id),
|
||||
ef_search: 64
|
||||
},
|
||||
3 => OperatorType::HashJoin {
|
||||
hash_key: "id".to_string()
|
||||
},
|
||||
4 => OperatorType::Filter {
|
||||
predicate: "x > 0".to_string()
|
||||
},
|
||||
5 => OperatorType::Sort {
|
||||
keys: vec!["col1".to_string()],
|
||||
descending: vec![false]
|
||||
},
|
||||
6 => OperatorType::Limit { count: 100 },
|
||||
7 => OperatorType::Aggregate {
|
||||
functions: vec!["COUNT(*)".to_string()]
|
||||
},
|
||||
8 => OperatorType::Project {
|
||||
columns: vec!["a".to_string(), "b".to_string()]
|
||||
},
|
||||
_ => OperatorType::Result,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DagGenerator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a batch of DAGs
|
||||
pub fn generate_dag_batch(count: usize, complexity: DagComplexity) -> Vec<QueryDag> {
|
||||
let mut gen = DagGenerator::new();
|
||||
(0..count).map(|_| gen.generate(complexity)).collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_simple() {
        let mut gen = DagGenerator::new();
        let dag = gen.generate(DagComplexity::Simple);
        // Simple plans are always exactly Scan -> Filter -> Result.
        assert_eq!(dag.nodes.len(), 3);
        assert_eq!(dag.edges.len(), 2);
    }

    #[test]
    fn test_generate_medium() {
        let mut gen = DagGenerator::new();
        let dag = gen.generate(DagComplexity::Medium);
        // Medium currently emits 8 nodes; bounds are deliberately loose.
        assert!(dag.nodes.len() >= 5);
        assert!(dag.nodes.len() <= 20);
    }

    #[test]
    fn test_generate_vector_query() {
        let mut gen = DagGenerator::new();
        let dag = gen.generate(DagComplexity::VectorQuery);

        // Should have HNSW scan node
        let has_hnsw = dag.nodes.iter().any(|n| matches!(n.op_type, OperatorType::HnswScan { .. }));
        assert!(has_hnsw);
    }

    #[test]
    fn test_generate_batch() {
        let dags = generate_dag_batch(10, DagComplexity::Simple);
        assert_eq!(dags.len(), 10);
    }
}
|
||||
195
crates/ruvector-dag/tests/fixtures/mock_qudag.rs
vendored
Normal file
195
crates/ruvector-dag/tests/fixtures/mock_qudag.rs
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Mock QuDAG Server for testing
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// In-process stand-in for a QuDAG node: tracks proposals, shared patterns,
/// and token balances behind `Arc<Mutex<..>>` so clones/threads can share it.
pub struct MockQuDagServer {
    proposals: Arc<Mutex<HashMap<String, MockProposal>>>, // keyed by proposal id
    patterns: Arc<Mutex<Vec<MockPattern>>>,
    balances: Arc<Mutex<HashMap<String, f64>>>, // node id -> token balance
}
|
||||
|
||||
/// Snapshot of one proposal tracked by `MockQuDagServer`.
#[derive(Debug, Clone)]
pub struct MockProposal {
    pub id: String,     // "prop_<random u64>"
    pub status: String, // "pending" | "accepted" | "rejected"
    pub votes_for: u64,
    pub votes_against: u64,
    pub finalized: bool,
}
|
||||
|
||||
/// A stored pattern vector tagged with the round it was added in.
#[derive(Debug, Clone)]
pub struct MockPattern {
    pub id: String, // "pat_<random u64>"
    pub vector: Vec<f32>,
    pub round: u64,
}
|
||||
|
||||
impl MockQuDagServer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
proposals: Arc::new(Mutex::new(HashMap::new())),
|
||||
patterns: Arc::new(Mutex::new(Vec::new())),
|
||||
balances: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> String {
|
||||
"mock://localhost:8443".to_string()
|
||||
}
|
||||
|
||||
pub fn submit_proposal(&self, vector: Vec<f32>) -> String {
|
||||
let id = format!("prop_{}", rand::random::<u64>());
|
||||
|
||||
let proposal = MockProposal {
|
||||
id: id.clone(),
|
||||
status: "pending".to_string(),
|
||||
votes_for: 0,
|
||||
votes_against: 0,
|
||||
finalized: false,
|
||||
};
|
||||
|
||||
self.proposals.lock().unwrap().insert(id.clone(), proposal);
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_proposal(&self, id: &str) -> Option<MockProposal> {
|
||||
self.proposals.lock().unwrap().get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn finalize_proposal(&self, id: &str, accept: bool) {
|
||||
if let Some(proposal) = self.proposals.lock().unwrap().get_mut(id) {
|
||||
proposal.status = if accept { "accepted" } else { "rejected" }.to_string();
|
||||
proposal.finalized = true;
|
||||
proposal.votes_for = if accept { 100 } else { 30 };
|
||||
proposal.votes_against = if accept { 20 } else { 70 };
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_pattern(&self, vector: Vec<f32>, round: u64) -> String {
|
||||
let id = format!("pat_{}", rand::random::<u64>());
|
||||
|
||||
self.patterns.lock().unwrap().push(MockPattern {
|
||||
id: id.clone(),
|
||||
vector,
|
||||
round,
|
||||
});
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_patterns_since(&self, round: u64) -> Vec<MockPattern> {
|
||||
self.patterns.lock().unwrap()
|
||||
.iter()
|
||||
.filter(|p| p.round >= round)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn set_balance(&self, node_id: &str, balance: f64) {
|
||||
self.balances.lock().unwrap().insert(node_id.to_string(), balance);
|
||||
}
|
||||
|
||||
pub fn get_balance(&self, node_id: &str) -> f64 {
|
||||
self.balances.lock().unwrap().get(node_id).copied().unwrap_or(0.0)
|
||||
}
|
||||
|
||||
pub fn stake(&self, node_id: &str, amount: f64) -> Result<(), String> {
|
||||
let mut balances = self.balances.lock().unwrap();
|
||||
let balance = balances.get(node_id).copied().unwrap_or(0.0);
|
||||
|
||||
if balance < amount {
|
||||
return Err("Insufficient balance".to_string());
|
||||
}
|
||||
|
||||
balances.insert(node_id.to_string(), balance - amount);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MockQuDagServer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a pre-populated mock server for testing
|
||||
pub fn create_test_server() -> MockQuDagServer {
|
||||
let server = MockQuDagServer::new();
|
||||
|
||||
// Add some patterns
|
||||
for round in 0..10 {
|
||||
let vector: Vec<f32> = (0..256).map(|i| (i as f32 / 256.0).sin()).collect();
|
||||
server.add_pattern(vector, round);
|
||||
}
|
||||
|
||||
// Set up balances
|
||||
server.set_balance("test_node", 1000.0);
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_submit_proposal() {
        let server = MockQuDagServer::new();
        let vector = vec![0.1; 256];

        let id = server.submit_proposal(vector);
        assert!(id.starts_with("prop_"));

        // Fresh proposals start pending with no votes cast.
        let proposal = server.get_proposal(&id).unwrap();
        assert_eq!(proposal.status, "pending");
        assert_eq!(proposal.votes_for, 0);
    }

    #[test]
    fn test_finalize_proposal() {
        let server = MockQuDagServer::new();
        let id = server.submit_proposal(vec![0.1; 256]);

        server.finalize_proposal(&id, true);

        // Accepting applies the canned tallies (100 for / 20 against).
        let proposal = server.get_proposal(&id).unwrap();
        assert_eq!(proposal.status, "accepted");
        assert!(proposal.finalized);
        assert!(proposal.votes_for > proposal.votes_against);
    }

    #[test]
    fn test_add_pattern() {
        let server = MockQuDagServer::new();
        let vector = vec![0.2; 128];

        let id = server.add_pattern(vector.clone(), 5);
        assert!(id.starts_with("pat_"));

        // get_patterns_since is inclusive of the given round.
        let patterns = server.get_patterns_since(5);
        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].round, 5);
    }

    #[test]
    fn test_stake() {
        let server = MockQuDagServer::new();
        server.set_balance("node1", 1000.0);

        assert!(server.stake("node1", 100.0).is_ok());
        assert_eq!(server.get_balance("node1"), 900.0);

        // Over-staking must fail (balance is only 900 at this point).
        assert!(server.stake("node1", 2000.0).is_err());
    }

    #[test]
    fn test_create_test_server() {
        let server = create_test_server();

        // Pre-seeded with one pattern per round 0..10.
        let patterns = server.get_patterns_since(0);
        assert_eq!(patterns.len(), 10);

        assert_eq!(server.get_balance("test_node"), 1000.0);
    }
}
|
||||
11
crates/ruvector-dag/tests/fixtures/mod.rs
vendored
Normal file
11
crates/ruvector-dag/tests/fixtures/mod.rs
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
//! Test fixtures and generators

// Fixture submodules used by the integration tests.
pub mod dag_generator;
pub mod pattern_generator;
pub mod trajectory_generator;
pub mod mock_qudag;

// Flatten everything so tests can `use fixtures::*;`.
pub use dag_generator::*;
pub use pattern_generator::*;
pub use trajectory_generator::*;
pub use mock_qudag::*;
|
||||
165
crates/ruvector-dag/tests/fixtures/pattern_generator.rs
vendored
Normal file
165
crates/ruvector-dag/tests/fixtures/pattern_generator.rs
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
//! Pattern Generator for testing
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
/// A synthetic pattern vector plus metadata, produced by `PatternGenerator`.
#[derive(Debug, Clone)]
pub struct GeneratedPattern {
    pub vector: Vec<f32>,   // noisy category prototype; length == generator dim
    pub quality_score: f64, // drawn uniformly from [0.5, 1.0)
    pub category: PatternCategory,
}
|
||||
|
||||
/// Operator families a generated pattern can imitate; each gets a distinct
/// base waveform in `PatternGenerator::category_base_vector`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PatternCategory {
    Scan,
    Join,
    Aggregate,
    Sort,
    Vector,
}
|
||||
|
||||
/// Builds category-keyed sinusoidal vectors with additive uniform noise.
pub struct PatternGenerator {
    dim: usize,                 // dimensionality of every emitted vector
    rng: rand::rngs::ThreadRng, // noise + quality-score source
}
|
||||
|
||||
impl PatternGenerator {
|
||||
pub fn new(dim: usize) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
rng: rand::thread_rng(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate(&mut self, category: PatternCategory) -> GeneratedPattern {
|
||||
let base = self.category_base_vector(category);
|
||||
let vector = self.add_noise(&base, 0.1);
|
||||
let quality_score = 0.5 + self.rng.gen::<f64>() * 0.5;
|
||||
|
||||
GeneratedPattern {
|
||||
vector,
|
||||
quality_score,
|
||||
category,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_batch(&mut self, count: usize) -> Vec<GeneratedPattern> {
|
||||
let categories = [
|
||||
PatternCategory::Scan,
|
||||
PatternCategory::Join,
|
||||
PatternCategory::Aggregate,
|
||||
PatternCategory::Sort,
|
||||
PatternCategory::Vector,
|
||||
];
|
||||
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let cat = categories[i % categories.len()];
|
||||
self.generate(cat)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn category_base_vector(&mut self, category: PatternCategory) -> Vec<f32> {
|
||||
// Each category has a distinct base pattern
|
||||
let seed = match category {
|
||||
PatternCategory::Scan => 1.0,
|
||||
PatternCategory::Join => 2.0,
|
||||
PatternCategory::Aggregate => 3.0,
|
||||
PatternCategory::Sort => 4.0,
|
||||
PatternCategory::Vector => 5.0,
|
||||
};
|
||||
|
||||
(0..self.dim)
|
||||
.map(|i| {
|
||||
let x = (i as f32 + seed) / self.dim as f32;
|
||||
(x * std::f32::consts::PI * seed).sin()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn add_noise(&mut self, base: &[f32], noise_level: f32) -> Vec<f32> {
|
||||
base.iter()
|
||||
.map(|&v| v + (self.rng.gen::<f32>() - 0.5) * 2.0 * noise_level)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PatternGenerator {
|
||||
fn default() -> Self {
|
||||
Self::new(256)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate clustered patterns for testing ReasoningBank
|
||||
pub fn generate_clustered_patterns(
|
||||
clusters: usize,
|
||||
patterns_per_cluster: usize,
|
||||
dim: usize,
|
||||
) -> Vec<GeneratedPattern> {
|
||||
let mut gen = PatternGenerator::new(dim);
|
||||
let mut patterns = Vec::new();
|
||||
|
||||
let categories = [
|
||||
PatternCategory::Scan,
|
||||
PatternCategory::Join,
|
||||
PatternCategory::Aggregate,
|
||||
PatternCategory::Sort,
|
||||
PatternCategory::Vector,
|
||||
];
|
||||
|
||||
for c in 0..clusters {
|
||||
let category = categories[c % categories.len()];
|
||||
for _ in 0..patterns_per_cluster {
|
||||
patterns.push(gen.generate(category));
|
||||
}
|
||||
}
|
||||
|
||||
patterns
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_pattern() {
        let mut gen = PatternGenerator::new(128);
        let pattern = gen.generate(PatternCategory::Scan);

        assert_eq!(pattern.vector.len(), 128);
        // quality_score = 0.5 + U[0,1) * 0.5, so it lives in [0.5, 1.0).
        assert!(pattern.quality_score >= 0.5 && pattern.quality_score <= 1.0);
        assert_eq!(pattern.category, PatternCategory::Scan);
    }

    #[test]
    fn test_generate_batch() {
        let mut gen = PatternGenerator::new(64);
        let patterns = gen.generate_batch(10);

        assert_eq!(patterns.len(), 10);
        assert!(patterns.iter().all(|p| p.vector.len() == 64));
    }

    #[test]
    fn test_clustered_patterns() {
        let patterns = generate_clustered_patterns(3, 5, 128);
        assert_eq!(patterns.len(), 15);

        // Check that patterns are distributed across categories
        let scan_count = patterns.iter().filter(|p| p.category == PatternCategory::Scan).count();
        assert!(scan_count > 0);
    }

    #[test]
    fn test_category_distinctness() {
        let mut gen = PatternGenerator::new(64);

        let scan = gen.generate(PatternCategory::Scan);
        let join = gen.generate(PatternCategory::Join);

        // Loose distinctness check on the raw dot product.
        // NOTE(review): vectors are not unit-normalized, so this is not a true
        // cosine similarity — confirm the intended threshold.
        let dot: f32 = scan.vector.iter().zip(&join.vector).map(|(a, b)| a * b).sum();
        assert!(dot.abs() < 0.99);
    }
}
|
||||
135
crates/ruvector-dag/tests/fixtures/trajectory_generator.rs
vendored
Normal file
135
crates/ruvector-dag/tests/fixtures/trajectory_generator.rs
vendored
Normal file
@@ -0,0 +1,135 @@
|
||||
//! Trajectory Generator for testing
|
||||
|
||||
use ruvector_dag::sona::DagTrajectory;
|
||||
use rand::Rng;
|
||||
|
||||
/// Produces random `DagTrajectory` fixtures with normalized embeddings.
pub struct TrajectoryGenerator {
    rng: rand::rngs::ThreadRng,
    embedding_dim: usize, // length of each generated DAG embedding
}
|
||||
|
||||
impl TrajectoryGenerator {
|
||||
pub fn new(embedding_dim: usize) -> Self {
|
||||
Self {
|
||||
rng: rand::thread_rng(),
|
||||
embedding_dim,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate(&mut self, mechanism: &str) -> DagTrajectory {
|
||||
let query_hash = self.rng.gen::<u64>();
|
||||
let dag_embedding = self.random_embedding();
|
||||
let execution_time_ms = 10.0 + self.rng.gen::<f64>() * 990.0;
|
||||
let baseline_time_ms = execution_time_ms * (1.0 + self.rng.gen::<f64>() * 0.5);
|
||||
|
||||
DagTrajectory::new(
|
||||
query_hash,
|
||||
dag_embedding,
|
||||
mechanism.to_string(),
|
||||
execution_time_ms,
|
||||
baseline_time_ms,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_batch(&mut self, count: usize) -> Vec<DagTrajectory> {
|
||||
let mechanisms = ["topological", "causal_cone", "critical_path", "mincut_gated"];
|
||||
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let mech = mechanisms[i % mechanisms.len()];
|
||||
self.generate(mech)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn generate_improving_batch(&mut self, count: usize) -> Vec<DagTrajectory> {
|
||||
// Generate trajectories with improving quality
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let improvement = i as f64 / count as f64;
|
||||
let execution_time = 100.0 * (1.0 - improvement * 0.5);
|
||||
let baseline = 100.0;
|
||||
|
||||
DagTrajectory::new(
|
||||
self.rng.gen(),
|
||||
self.random_embedding(),
|
||||
"auto".to_string(),
|
||||
execution_time,
|
||||
baseline,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn random_embedding(&mut self) -> Vec<f32> {
|
||||
let mut embedding: Vec<f32> = (0..self.embedding_dim)
|
||||
.map(|_| self.rng.gen::<f32>() * 2.0 - 1.0)
|
||||
.collect();
|
||||
|
||||
// Normalize
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
embedding.iter_mut().for_each(|x| *x /= norm);
|
||||
}
|
||||
|
||||
embedding
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TrajectoryGenerator {
|
||||
fn default() -> Self {
|
||||
Self::new(256)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_trajectory() {
        let mut gen = TrajectoryGenerator::new(128);
        let traj = gen.generate("topological");

        assert_eq!(traj.dag_embedding.len(), 128);
        assert_eq!(traj.mechanism, "topological");
        // Both times are generated strictly positive.
        assert!(traj.execution_time_ms > 0.0);
        assert!(traj.baseline_time_ms > 0.0);
    }

    #[test]
    fn test_generate_batch() {
        let mut gen = TrajectoryGenerator::new(64);
        let trajectories = gen.generate_batch(20);

        assert_eq!(trajectories.len(), 20);

        // Check mechanism distribution (batch cycles through all mechanisms).
        let mechanisms: Vec<_> = trajectories.iter().map(|t| &t.mechanism).collect();
        assert!(mechanisms.contains(&&"topological".to_string()));
        assert!(mechanisms.contains(&&"causal_cone".to_string()));
    }

    #[test]
    fn test_improving_batch() {
        let mut gen = TrajectoryGenerator::new(128);
        let trajectories = gen.generate_improving_batch(10);

        assert_eq!(trajectories.len(), 10);

        // Check that execution times are decreasing (improvement)
        for i in 0..trajectories.len() - 1 {
            assert!(trajectories[i].execution_time_ms >= trajectories[i + 1].execution_time_ms);
        }
    }

    #[test]
    fn test_normalized_embeddings() {
        let mut gen = TrajectoryGenerator::new(64);
        let traj = gen.generate("test");

        // Check that embedding is normalized (unit L2 norm within tolerance).
        let norm: f32 = traj.dag_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }
}
|
||||
Reference in New Issue
Block a user