854 lines
27 KiB
Rust
854 lines
27 KiB
Rust
//! Causal Reasoning Benchmarks for Prime-Radiant
|
|
//!
|
|
//! Benchmarks for causal inference operations including:
|
|
//! - Intervention computation (do-calculus)
|
|
//! - Counterfactual queries
|
|
//! - Causal abstraction verification
|
|
//! - Structural causal model operations
|
|
//!
|
|
//! Target metrics:
|
|
//! - Intervention: < 1ms per intervention
|
|
//! - Counterfactual: < 5ms per query
|
|
//! - Abstraction verification: < 10ms for moderate models
|
|
|
|
use criterion::{
|
|
black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput,
|
|
};
|
|
use std::collections::{HashMap, HashSet, VecDeque};
|
|
|
|
// ============================================================================
|
|
// CAUSAL MODEL TYPES
|
|
// ============================================================================
|
|
|
|
/// Variable identifier
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
struct VariableId(usize);
|
|
|
|
/// Variable value
|
|
#[derive(Clone, Debug)]
|
|
enum Value {
|
|
Continuous(f64),
|
|
Discrete(i64),
|
|
Vector(Vec<f64>),
|
|
}
|
|
|
|
impl Value {
|
|
fn as_f64(&self) -> f64 {
|
|
match self {
|
|
Value::Continuous(v) => *v,
|
|
Value::Discrete(v) => *v as f64,
|
|
Value::Vector(v) => v.iter().sum(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Structural equation: V = f(Pa(V), U_V)
|
|
struct StructuralEquation {
|
|
variable: VariableId,
|
|
parents: Vec<VariableId>,
|
|
/// Function mapping parent values to variable value
|
|
function: Box<dyn Fn(&[Value]) -> Value + Send + Sync>,
|
|
}
|
|
|
|
/// Structural Causal Model
|
|
struct CausalModel {
|
|
variables: HashMap<VariableId, String>,
|
|
variable_ids: HashMap<String, VariableId>,
|
|
parents: HashMap<VariableId, Vec<VariableId>>,
|
|
children: HashMap<VariableId, Vec<VariableId>>,
|
|
equations: HashMap<VariableId, Box<dyn Fn(&[Value]) -> Value + Send + Sync>>,
|
|
exogenous: HashMap<VariableId, Value>,
|
|
next_id: usize,
|
|
}
|
|
|
|
impl CausalModel {
|
|
fn new() -> Self {
|
|
Self {
|
|
variables: HashMap::new(),
|
|
variable_ids: HashMap::new(),
|
|
parents: HashMap::new(),
|
|
children: HashMap::new(),
|
|
equations: HashMap::new(),
|
|
exogenous: HashMap::new(),
|
|
next_id: 0,
|
|
}
|
|
}
|
|
|
|
fn add_variable(&mut self, name: &str) -> VariableId {
|
|
let id = VariableId(self.next_id);
|
|
self.next_id += 1;
|
|
|
|
self.variables.insert(id, name.to_string());
|
|
self.variable_ids.insert(name.to_string(), id);
|
|
self.parents.insert(id, Vec::new());
|
|
self.children.insert(id, Vec::new());
|
|
|
|
// Default exogenous value
|
|
self.exogenous.insert(id, Value::Continuous(0.0));
|
|
|
|
id
|
|
}
|
|
|
|
fn add_edge(&mut self, from: VariableId, to: VariableId) {
|
|
self.parents.get_mut(&to).unwrap().push(from);
|
|
self.children.get_mut(&from).unwrap().push(to);
|
|
}
|
|
|
|
fn set_equation<F>(&mut self, var: VariableId, func: F)
|
|
where
|
|
F: Fn(&[Value]) -> Value + Send + Sync + 'static,
|
|
{
|
|
self.equations.insert(var, Box::new(func));
|
|
}
|
|
|
|
fn set_exogenous(&mut self, var: VariableId, value: Value) {
|
|
self.exogenous.insert(var, value);
|
|
}
|
|
|
|
fn topological_order(&self) -> Vec<VariableId> {
|
|
let mut order = Vec::new();
|
|
let mut visited = HashSet::new();
|
|
let mut temp_mark = HashSet::new();
|
|
|
|
fn visit(
|
|
id: VariableId,
|
|
parents: &HashMap<VariableId, Vec<VariableId>>,
|
|
visited: &mut HashSet<VariableId>,
|
|
temp_mark: &mut HashSet<VariableId>,
|
|
order: &mut Vec<VariableId>,
|
|
) {
|
|
if visited.contains(&id) {
|
|
return;
|
|
}
|
|
if temp_mark.contains(&id) {
|
|
return; // Cycle detected
|
|
}
|
|
|
|
temp_mark.insert(id);
|
|
|
|
for &parent in parents.get(&id).unwrap_or(&vec![]) {
|
|
visit(parent, parents, visited, temp_mark, order);
|
|
}
|
|
|
|
temp_mark.remove(&id);
|
|
visited.insert(id);
|
|
order.push(id);
|
|
}
|
|
|
|
for &id in self.variables.keys() {
|
|
visit(id, &self.parents, &mut visited, &mut temp_mark, &mut order);
|
|
}
|
|
|
|
order
|
|
}
|
|
|
|
/// Compute values given current exogenous variables
|
|
fn forward(&self) -> HashMap<VariableId, Value> {
|
|
let mut values = HashMap::new();
|
|
let order = self.topological_order();
|
|
|
|
for id in order {
|
|
let parent_ids = self.parents.get(&id).unwrap();
|
|
let parent_values: Vec<Value> = parent_ids
|
|
.iter()
|
|
.map(|&pid| values.get(&pid).cloned().unwrap_or(Value::Continuous(0.0)))
|
|
.collect();
|
|
|
|
let value = if let Some(func) = self.equations.get(&id) {
|
|
// Combine exogenous with structural equation
|
|
let exo = self.exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0));
|
|
let base = func(&parent_values);
|
|
Value::Continuous(base.as_f64() + exo.as_f64())
|
|
} else {
|
|
self.exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0))
|
|
};
|
|
|
|
values.insert(id, value);
|
|
}
|
|
|
|
values
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// INTERVENTION
|
|
// ============================================================================
|
|
|
|
/// Intervention: do(X = x)
|
|
#[derive(Clone)]
|
|
struct Intervention {
|
|
variable: VariableId,
|
|
value: Value,
|
|
}
|
|
|
|
impl Intervention {
|
|
fn new(variable: VariableId, value: Value) -> Self {
|
|
Self { variable, value }
|
|
}
|
|
}
|
|
|
|
/// Apply intervention and compute resulting distribution
|
|
fn apply_intervention(
|
|
model: &CausalModel,
|
|
intervention: &Intervention,
|
|
) -> HashMap<VariableId, Value> {
|
|
let mut values = HashMap::new();
|
|
let order = model.topological_order();
|
|
|
|
for id in order {
|
|
if id == intervention.variable {
|
|
// Override with intervention value
|
|
values.insert(id, intervention.value.clone());
|
|
} else {
|
|
let parent_ids = model.parents.get(&id).unwrap();
|
|
let parent_values: Vec<Value> = parent_ids
|
|
.iter()
|
|
.map(|&pid| values.get(&pid).cloned().unwrap_or(Value::Continuous(0.0)))
|
|
.collect();
|
|
|
|
let value = if let Some(func) = model.equations.get(&id) {
|
|
let exo = model.exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0));
|
|
let base = func(&parent_values);
|
|
Value::Continuous(base.as_f64() + exo.as_f64())
|
|
} else {
|
|
model.exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0))
|
|
};
|
|
|
|
values.insert(id, value);
|
|
}
|
|
}
|
|
|
|
values
|
|
}
|
|
|
|
/// Apply multiple interventions
|
|
fn apply_multi_intervention(
|
|
model: &CausalModel,
|
|
interventions: &[Intervention],
|
|
) -> HashMap<VariableId, Value> {
|
|
let intervention_map: HashMap<VariableId, Value> = interventions
|
|
.iter()
|
|
.map(|i| (i.variable, i.value.clone()))
|
|
.collect();
|
|
|
|
let mut values = HashMap::new();
|
|
let order = model.topological_order();
|
|
|
|
for id in order {
|
|
if let Some(value) = intervention_map.get(&id) {
|
|
values.insert(id, value.clone());
|
|
} else {
|
|
let parent_ids = model.parents.get(&id).unwrap();
|
|
let parent_values: Vec<Value> = parent_ids
|
|
.iter()
|
|
.map(|&pid| values.get(&pid).cloned().unwrap_or(Value::Continuous(0.0)))
|
|
.collect();
|
|
|
|
let value = if let Some(func) = model.equations.get(&id) {
|
|
let exo = model.exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0));
|
|
let base = func(&parent_values);
|
|
Value::Continuous(base.as_f64() + exo.as_f64())
|
|
} else {
|
|
model.exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0))
|
|
};
|
|
|
|
values.insert(id, value);
|
|
}
|
|
}
|
|
|
|
values
|
|
}
|
|
|
|
// ============================================================================
|
|
// COUNTERFACTUAL REASONING
|
|
// ============================================================================
|
|
|
|
/// Counterfactual query: Y_x(u) where we observed Y = y
|
|
struct CounterfactualQuery {
|
|
/// The variable we're asking about
|
|
target: VariableId,
|
|
/// The intervention
|
|
intervention: Intervention,
|
|
/// Observed facts
|
|
observations: HashMap<VariableId, Value>,
|
|
}
|
|
|
|
/// Compute counterfactual using abduction-action-prediction
|
|
fn compute_counterfactual(
|
|
model: &CausalModel,
|
|
query: &CounterfactualQuery,
|
|
) -> Option<Value> {
|
|
// Step 1: Abduction - infer exogenous variables from observations
|
|
let inferred_exogenous = abduct_exogenous(model, &query.observations)?;
|
|
|
|
// Step 2: Action - create modified model with intervention
|
|
// (We don't actually modify the model, we use the intervention directly)
|
|
|
|
// Step 3: Prediction - compute outcome under intervention with inferred exogenous
|
|
let mut values = HashMap::new();
|
|
let order = model.topological_order();
|
|
|
|
for id in order {
|
|
if id == query.intervention.variable {
|
|
values.insert(id, query.intervention.value.clone());
|
|
} else {
|
|
let parent_ids = model.parents.get(&id).unwrap();
|
|
let parent_values: Vec<Value> = parent_ids
|
|
.iter()
|
|
.map(|&pid| values.get(&pid).cloned().unwrap_or(Value::Continuous(0.0)))
|
|
.collect();
|
|
|
|
let value = if let Some(func) = model.equations.get(&id) {
|
|
let exo = inferred_exogenous
|
|
.get(&id)
|
|
.cloned()
|
|
.unwrap_or(Value::Continuous(0.0));
|
|
let base = func(&parent_values);
|
|
Value::Continuous(base.as_f64() + exo.as_f64())
|
|
} else {
|
|
inferred_exogenous
|
|
.get(&id)
|
|
.cloned()
|
|
.unwrap_or(Value::Continuous(0.0))
|
|
};
|
|
|
|
values.insert(id, value);
|
|
}
|
|
}
|
|
|
|
values.get(&query.target).cloned()
|
|
}
|
|
|
|
/// Abduct exogenous variables from observations
|
|
fn abduct_exogenous(
|
|
model: &CausalModel,
|
|
observations: &HashMap<VariableId, Value>,
|
|
) -> Option<HashMap<VariableId, Value>> {
|
|
let mut exogenous = model.exogenous.clone();
|
|
let order = model.topological_order();
|
|
|
|
// For each observed variable, infer the exogenous noise
|
|
let mut computed_values = HashMap::new();
|
|
|
|
for id in order {
|
|
let parent_ids = model.parents.get(&id).unwrap();
|
|
let parent_values: Vec<Value> = parent_ids
|
|
.iter()
|
|
.map(|&pid| {
|
|
computed_values
|
|
.get(&pid)
|
|
.cloned()
|
|
.unwrap_or(Value::Continuous(0.0))
|
|
})
|
|
.collect();
|
|
|
|
if let Some(observed) = observations.get(&id) {
|
|
// Infer exogenous: U = Y - f(Pa)
|
|
if let Some(func) = model.equations.get(&id) {
|
|
let structural_part = func(&parent_values).as_f64();
|
|
let inferred_exo = observed.as_f64() - structural_part;
|
|
exogenous.insert(id, Value::Continuous(inferred_exo));
|
|
}
|
|
computed_values.insert(id, observed.clone());
|
|
} else {
|
|
// Compute from parents
|
|
let value = if let Some(func) = model.equations.get(&id) {
|
|
let exo = exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0));
|
|
let base = func(&parent_values);
|
|
Value::Continuous(base.as_f64() + exo.as_f64())
|
|
} else {
|
|
exogenous.get(&id).cloned().unwrap_or(Value::Continuous(0.0))
|
|
};
|
|
computed_values.insert(id, value);
|
|
}
|
|
}
|
|
|
|
Some(exogenous)
|
|
}
|
|
|
|
// ============================================================================
|
|
// CAUSAL ABSTRACTION
|
|
// ============================================================================
|
|
|
|
/// Map between low-level and high-level causal models
|
|
struct CausalAbstraction {
|
|
/// Low-level model
|
|
low_level: CausalModel,
|
|
/// High-level model
|
|
high_level: CausalModel,
|
|
/// Variable mapping: high-level -> set of low-level variables
|
|
variable_map: HashMap<VariableId, Vec<VariableId>>,
|
|
/// Value mapping: how to aggregate low-level values
|
|
value_aggregator: Box<dyn Fn(&[Value]) -> Value + Send + Sync>,
|
|
}
|
|
|
|
impl CausalAbstraction {
|
|
fn new(low_level: CausalModel, high_level: CausalModel) -> Self {
|
|
Self {
|
|
low_level,
|
|
high_level,
|
|
variable_map: HashMap::new(),
|
|
value_aggregator: Box::new(|vals: &[Value]| {
|
|
let sum: f64 = vals.iter().map(|v| v.as_f64()).sum();
|
|
Value::Continuous(sum / vals.len().max(1) as f64)
|
|
}),
|
|
}
|
|
}
|
|
|
|
fn add_mapping(&mut self, high_var: VariableId, low_vars: Vec<VariableId>) {
|
|
self.variable_map.insert(high_var, low_vars);
|
|
}
|
|
|
|
/// Verify abstraction consistency: interventions commute
|
|
fn verify_consistency(&self, intervention: &Intervention) -> bool {
|
|
// High-level: intervene and compute
|
|
let high_values = apply_intervention(&self.high_level, intervention);
|
|
|
|
// Low-level: intervene on corresponding variables and aggregate
|
|
let low_vars = self.variable_map.get(&intervention.variable);
|
|
if low_vars.is_none() {
|
|
return false;
|
|
}
|
|
|
|
let low_interventions: Vec<Intervention> = low_vars
|
|
.unwrap()
|
|
.iter()
|
|
.map(|&v| Intervention::new(v, intervention.value.clone()))
|
|
.collect();
|
|
|
|
let low_values = apply_multi_intervention(&self.low_level, &low_interventions);
|
|
|
|
// Compare aggregated low-level values with high-level values
|
|
for (&high_var, low_vars) in &self.variable_map {
|
|
let high_val = high_values.get(&high_var).map(|v| v.as_f64()).unwrap_or(0.0);
|
|
|
|
let low_vals: Vec<Value> = low_vars
|
|
.iter()
|
|
.filter_map(|&lv| low_values.get(&lv).cloned())
|
|
.collect();
|
|
|
|
let aggregated = (self.value_aggregator)(&low_vals).as_f64();
|
|
|
|
if (high_val - aggregated).abs() > 1e-6 {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
true
|
|
}
|
|
|
|
/// Compute abstraction error
|
|
fn compute_abstraction_error(&self, num_samples: usize) -> f64 {
|
|
let mut total_error = 0.0;
|
|
|
|
for i in 0..num_samples {
|
|
// Random intervention value
|
|
let value = Value::Continuous((i as f64 * 0.1).sin() * 10.0);
|
|
|
|
// Pick a random variable to intervene on
|
|
let high_vars: Vec<_> = self.high_level.variables.keys().copied().collect();
|
|
if high_vars.is_empty() {
|
|
continue;
|
|
}
|
|
let var_idx = i % high_vars.len();
|
|
let intervention = Intervention::new(high_vars[var_idx], value);
|
|
|
|
// Compute values
|
|
let high_values = apply_intervention(&self.high_level, &intervention);
|
|
|
|
let low_vars = self.variable_map.get(&intervention.variable);
|
|
if low_vars.is_none() {
|
|
continue;
|
|
}
|
|
|
|
let low_interventions: Vec<Intervention> = low_vars
|
|
.unwrap()
|
|
.iter()
|
|
.map(|&v| Intervention::new(v, intervention.value.clone()))
|
|
.collect();
|
|
|
|
let low_values = apply_multi_intervention(&self.low_level, &low_interventions);
|
|
|
|
// Compute error
|
|
for (&high_var, low_vars) in &self.variable_map {
|
|
let high_val = high_values.get(&high_var).map(|v| v.as_f64()).unwrap_or(0.0);
|
|
|
|
let low_vals: Vec<Value> = low_vars
|
|
.iter()
|
|
.filter_map(|&lv| low_values.get(&lv).cloned())
|
|
.collect();
|
|
|
|
let aggregated = (self.value_aggregator)(&low_vals).as_f64();
|
|
total_error += (high_val - aggregated).powi(2);
|
|
}
|
|
}
|
|
|
|
(total_error / num_samples.max(1) as f64).sqrt()
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// CAUSAL EFFECT ESTIMATION
|
|
// ============================================================================
|
|
|
|
/// Average Treatment Effect
|
|
fn compute_ate(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
treatment_values: (f64, f64), // (control, treated)
|
|
) -> f64 {
|
|
// E[Y | do(X = treated)] - E[Y | do(X = control)]
|
|
let intervention_treated = Intervention::new(treatment, Value::Continuous(treatment_values.1));
|
|
let intervention_control = Intervention::new(treatment, Value::Continuous(treatment_values.0));
|
|
|
|
let values_treated = apply_intervention(model, &intervention_treated);
|
|
let values_control = apply_intervention(model, &intervention_control);
|
|
|
|
let y_treated = values_treated.get(&outcome).map(|v| v.as_f64()).unwrap_or(0.0);
|
|
let y_control = values_control.get(&outcome).map(|v| v.as_f64()).unwrap_or(0.0);
|
|
|
|
y_treated - y_control
|
|
}
|
|
|
|
// ============================================================================
|
|
// BENCHMARK DATA GENERATORS
|
|
// ============================================================================
|
|
|
|
fn create_chain_model(length: usize) -> CausalModel {
|
|
let mut model = CausalModel::new();
|
|
let mut vars = Vec::new();
|
|
|
|
for i in 0..length {
|
|
let var = model.add_variable(&format!("V{}", i));
|
|
vars.push(var);
|
|
|
|
if i > 0 {
|
|
model.add_edge(vars[i - 1], var);
|
|
|
|
let parent_var = vars[i - 1];
|
|
model.set_equation(var, move |parents| {
|
|
if parents.is_empty() {
|
|
Value::Continuous(0.0)
|
|
} else {
|
|
Value::Continuous(parents[0].as_f64() * 0.8 + 0.5)
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
model
|
|
}
|
|
|
|
fn create_diamond_model(num_layers: usize, width: usize) -> CausalModel {
|
|
let mut model = CausalModel::new();
|
|
let mut layers: Vec<Vec<VariableId>> = Vec::new();
|
|
|
|
// Create layers
|
|
for layer in 0..num_layers {
|
|
let layer_width = if layer == 0 || layer == num_layers - 1 {
|
|
1
|
|
} else {
|
|
width
|
|
};
|
|
|
|
let mut layer_vars = Vec::new();
|
|
for i in 0..layer_width {
|
|
let var = model.add_variable(&format!("L{}_{}", layer, i));
|
|
layer_vars.push(var);
|
|
|
|
// Connect to previous layer
|
|
if layer > 0 {
|
|
for &parent in &layers[layer - 1] {
|
|
model.add_edge(parent, var);
|
|
}
|
|
|
|
model.set_equation(var, |parents| {
|
|
let sum: f64 = parents.iter().map(|p| p.as_f64()).sum();
|
|
Value::Continuous(sum / parents.len().max(1) as f64 + 0.1)
|
|
});
|
|
}
|
|
}
|
|
|
|
layers.push(layer_vars);
|
|
}
|
|
|
|
model
|
|
}
|
|
|
|
fn create_dense_model(num_vars: usize, density: f64, seed: u64) -> CausalModel {
|
|
let mut model = CausalModel::new();
|
|
let mut vars = Vec::new();
|
|
|
|
// Create variables
|
|
for i in 0..num_vars {
|
|
let var = model.add_variable(&format!("V{}", i));
|
|
vars.push(var);
|
|
}
|
|
|
|
// Add edges (respecting DAG structure: only forward edges)
|
|
let mut rng_state = seed;
|
|
for i in 0..num_vars {
|
|
for j in (i + 1)..num_vars {
|
|
rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
|
|
let random = (rng_state >> 33) as f64 / (u32::MAX as f64);
|
|
|
|
if random < density {
|
|
model.add_edge(vars[i], vars[j]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set equations
|
|
for i in 1..num_vars {
|
|
model.set_equation(vars[i], |parents| {
|
|
let sum: f64 = parents.iter().map(|p| p.as_f64()).sum();
|
|
Value::Continuous(sum * 0.5 + 0.1)
|
|
});
|
|
}
|
|
|
|
model
|
|
}
|
|
|
|
// ============================================================================
|
|
// BENCHMARKS
|
|
// ============================================================================
|
|
|
|
fn bench_intervention(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/intervention");
|
|
group.sample_size(100);
|
|
|
|
for &size in &[10, 50, 100, 200] {
|
|
let model = create_chain_model(size);
|
|
let var = VariableId(size / 2); // Intervene in middle
|
|
let intervention = Intervention::new(var, Value::Continuous(1.0));
|
|
|
|
group.throughput(Throughput::Elements(size as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("chain", size),
|
|
&(&model, &intervention),
|
|
|b, (model, intervention)| {
|
|
b.iter(|| black_box(apply_intervention(black_box(model), black_box(intervention))))
|
|
},
|
|
);
|
|
}
|
|
|
|
for &size in &[10, 25, 50] {
|
|
let model = create_diamond_model(4, size);
|
|
let var = VariableId(0);
|
|
let intervention = Intervention::new(var, Value::Continuous(1.0));
|
|
|
|
let total_vars = 2 + 2 * size; // 1 + size + size + 1
|
|
group.throughput(Throughput::Elements(total_vars as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("diamond", size),
|
|
&(&model, &intervention),
|
|
|b, (model, intervention)| {
|
|
b.iter(|| black_box(apply_intervention(black_box(model), black_box(intervention))))
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_multi_intervention(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/multi_intervention");
|
|
group.sample_size(50);
|
|
|
|
for &num_interventions in &[1, 5, 10, 20] {
|
|
let model = create_dense_model(100, 0.1, 42);
|
|
let interventions: Vec<Intervention> = (0..num_interventions)
|
|
.map(|i| Intervention::new(VariableId(i * 5), Value::Continuous(1.0)))
|
|
.collect();
|
|
|
|
group.throughput(Throughput::Elements(num_interventions as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("dense_100", num_interventions),
|
|
&(&model, &interventions),
|
|
|b, (model, interventions)| {
|
|
b.iter(|| black_box(apply_multi_intervention(black_box(model), black_box(interventions))))
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_counterfactual(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/counterfactual");
|
|
group.sample_size(50);
|
|
|
|
for &size in &[10, 25, 50, 100] {
|
|
let model = create_chain_model(size);
|
|
|
|
// Observe last variable
|
|
let mut observations = HashMap::new();
|
|
observations.insert(VariableId(size - 1), Value::Continuous(5.0));
|
|
|
|
let query = CounterfactualQuery {
|
|
target: VariableId(size - 1),
|
|
intervention: Intervention::new(VariableId(0), Value::Continuous(2.0)),
|
|
observations,
|
|
};
|
|
|
|
group.throughput(Throughput::Elements(size as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("chain", size),
|
|
&(&model, &query),
|
|
|b, (model, query)| {
|
|
b.iter(|| black_box(compute_counterfactual(black_box(model), black_box(query))))
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_abstraction_verification(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/abstraction");
|
|
group.sample_size(30);
|
|
|
|
for &low_size in &[20, 50, 100] {
|
|
let high_size = low_size / 5;
|
|
|
|
let low_model = create_chain_model(low_size);
|
|
let high_model = create_chain_model(high_size);
|
|
|
|
let mut abstraction = CausalAbstraction::new(low_model, high_model);
|
|
|
|
// Map high-level vars to groups of low-level vars
|
|
for i in 0..high_size {
|
|
let low_vars: Vec<VariableId> = (0..5)
|
|
.map(|j| VariableId(i * 5 + j))
|
|
.collect();
|
|
abstraction.add_mapping(VariableId(i), low_vars);
|
|
}
|
|
|
|
let intervention = Intervention::new(VariableId(0), Value::Continuous(1.0));
|
|
|
|
group.throughput(Throughput::Elements(low_size as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("verify_single", low_size),
|
|
&(&abstraction, &intervention),
|
|
|b, (abstraction, intervention)| {
|
|
b.iter(|| black_box(abstraction.verify_consistency(black_box(intervention))))
|
|
},
|
|
);
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("compute_error", low_size),
|
|
&abstraction,
|
|
|b, abstraction| {
|
|
b.iter(|| black_box(abstraction.compute_abstraction_error(10)))
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_ate(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/ate");
|
|
group.sample_size(100);
|
|
|
|
for &size in &[10, 50, 100] {
|
|
let model = create_dense_model(size, 0.15, 42);
|
|
let treatment = VariableId(0);
|
|
let outcome = VariableId(size - 1);
|
|
|
|
group.throughput(Throughput::Elements(size as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("dense", size),
|
|
&(&model, treatment, outcome),
|
|
|b, (model, treatment, outcome)| {
|
|
b.iter(|| {
|
|
black_box(compute_ate(
|
|
black_box(model),
|
|
*treatment,
|
|
*outcome,
|
|
(0.0, 1.0),
|
|
))
|
|
})
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_topological_sort(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/topological_sort");
|
|
group.sample_size(100);
|
|
|
|
for &size in &[50, 100, 200, 500] {
|
|
let model = create_dense_model(size, 0.1, 42);
|
|
|
|
group.throughput(Throughput::Elements(size as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("dense", size),
|
|
&model,
|
|
|b, model| {
|
|
b.iter(|| black_box(model.topological_order()))
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
fn bench_forward_propagation(c: &mut Criterion) {
|
|
let mut group = c.benchmark_group("causal/forward");
|
|
group.sample_size(50);
|
|
|
|
for &size in &[50, 100, 200] {
|
|
let model = create_dense_model(size, 0.1, 42);
|
|
|
|
group.throughput(Throughput::Elements(size as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new("dense", size),
|
|
&model,
|
|
|b, model| {
|
|
b.iter(|| black_box(model.forward()))
|
|
},
|
|
);
|
|
}
|
|
|
|
for &(layers, width) in &[(3, 10), (5, 10), (5, 20)] {
|
|
let model = create_diamond_model(layers, width);
|
|
let total_vars = 2 + (layers - 2) * width;
|
|
|
|
group.throughput(Throughput::Elements(total_vars as u64));
|
|
|
|
group.bench_with_input(
|
|
BenchmarkId::new(format!("diamond_{}x{}", layers, width), total_vars),
|
|
&model,
|
|
|b, model| {
|
|
b.iter(|| black_box(model.forward()))
|
|
},
|
|
);
|
|
}
|
|
|
|
group.finish();
|
|
}
|
|
|
|
criterion_group!(
|
|
benches,
|
|
bench_intervention,
|
|
bench_multi_intervention,
|
|
bench_counterfactual,
|
|
bench_abstraction_verification,
|
|
bench_ate,
|
|
bench_topological_sort,
|
|
bench_forward_propagation,
|
|
);
|
|
criterion_main!(benches);
|