Files
wifi-densepose/vendor/ruvector/examples/prime-radiant/src/causal/model.rs

1212 lines
37 KiB
Rust

//! Structural Causal Models (SCM) for causal reasoning
//!
//! This module implements the core causal model structure, including:
//! - Variables with types (continuous, discrete, binary)
//! - Structural equations defining causal mechanisms
//! - Intervention semantics (do-operator)
//! - Forward simulation
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use super::graph::{DirectedGraph, DAGValidationError, TopologicalOrder};
/// Unique identifier for a variable in the causal model.
///
/// The wrapped `u32` doubles as the node id of the variable inside the
/// model's `DirectedGraph` (see `CausalModel::add_variable`, which registers
/// the node under `id.0`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VariableId(pub u32);
impl From<u32> for VariableId {
fn from(id: u32) -> Self {
VariableId(id)
}
}
impl From<VariableId> for u32 {
fn from(id: VariableId) -> u32 {
id.0
}
}
/// Type of a causal variable.
///
/// This is the declared type stored on a `Variable`; the matching runtime
/// representation is the `Value` enum.
/// NOTE(review): nothing visible in this module checks that a `Value`
/// produced for a variable matches its declared `VariableType` — confirm
/// whether that is enforced elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableType {
/// Continuous real-valued variable
Continuous,
/// Discrete variable with finite domain
Discrete,
/// Binary variable (special case of discrete)
Binary,
/// Categorical variable with named levels
Categorical,
}
/// Value that a variable can take.
///
/// `Value::Missing` is the `Default` and is what the simulators substitute
/// for a parent whose value has not been computed. Equality is the derived
/// structural comparison, so `Continuous` values compare with exact `f64`
/// equality.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
/// Continuous value
Continuous(f64),
/// Discrete integer value
Discrete(i64),
/// Binary value
Binary(bool),
/// Categorical value (index into category list)
Categorical(usize),
/// Missing/unknown value
Missing,
}
impl Value {
    /// Numeric view of the value.
    ///
    /// Integers and category indices are cast with `as`, booleans map to
    /// `1.0` / `0.0`, and `Missing` yields `NaN`.
    pub fn as_f64(&self) -> f64 {
        match *self {
            Value::Missing => f64::NAN,
            Value::Continuous(real) => real,
            Value::Discrete(int) => int as f64,
            Value::Categorical(index) => index as f64,
            Value::Binary(true) => 1.0,
            Value::Binary(false) => 0.0,
        }
    }

    /// Truthiness of the value.
    ///
    /// Every non-missing variant converts: numbers and category indices are
    /// `true` when they differ from zero. `Missing` is the only input that
    /// produces `None`.
    pub fn as_bool(&self) -> Option<bool> {
        let truthy = match *self {
            Value::Missing => return None,
            Value::Binary(flag) => flag,
            Value::Discrete(int) => int != 0,
            Value::Continuous(real) => real != 0.0,
            Value::Categorical(index) => index != 0,
        };
        Some(truthy)
    }

    /// Returns `true` only for `Value::Missing`.
    pub fn is_missing(&self) -> bool {
        *self == Value::Missing
    }
}
impl Default for Value {
fn default() -> Self {
Value::Missing
}
}
/// A variable in the causal model.
///
/// `Variable::new` fills in the id, name and type and leaves every optional
/// field as `None`; the `with_*` builder methods populate them.
#[derive(Debug, Clone)]
pub struct Variable {
/// Unique identifier
pub id: VariableId,
/// Human-readable name
pub name: String,
/// Variable type
pub var_type: VariableType,
/// Domain constraints (min, max) for continuous
// Stored as given by `with_domain`; no ordering check (min <= max) is applied there.
pub domain: Option<(f64, f64)>,
/// Categories for categorical variables
pub categories: Option<Vec<String>>,
/// Description
pub description: Option<String>,
}
impl Variable {
/// Create a new variable
pub fn new(id: VariableId, name: &str, var_type: VariableType) -> Self {
Self {
id,
name: name.to_string(),
var_type,
domain: None,
categories: None,
description: None,
}
}
/// Set domain constraints
pub fn with_domain(mut self, min: f64, max: f64) -> Self {
self.domain = Some((min, max));
self
}
/// Set categories
pub fn with_categories(mut self, categories: Vec<String>) -> Self {
self.categories = Some(categories);
self
}
/// Set description
pub fn with_description(mut self, desc: &str) -> Self {
self.description = Some(desc.to_string());
self
}
}
/// Type alias for mechanism function.
///
/// The slice argument holds the parent values of the target variable; the
/// simulators in this module build it in the same order as
/// `StructuralEquation::parents`. The `Send + Sync` bounds are part of the
/// trait object type that `Mechanism` stores behind an `Arc`.
pub type MechanismFn = dyn Fn(&[Value]) -> Value + Send + Sync;
/// A mechanism (functional relationship) in a structural equation.
///
/// Cloning is cheap: the derived `Clone` copies the `Arc`, so clones share
/// the same underlying closure.
#[derive(Clone)]
pub struct Mechanism {
/// The function implementing the mechanism
func: Arc<MechanismFn>,
/// Optional noise distribution parameter
// `Mechanism::apply` does not read this field; it only forwards to `func`.
pub noise_scale: f64,
}
impl Mechanism {
/// Create a new mechanism from a function
pub fn new<F>(func: F) -> Self
where
F: Fn(&[Value]) -> Value + Send + Sync + 'static,
{
Self {
func: Arc::new(func),
noise_scale: 0.0,
}
}
/// Create a mechanism with noise
pub fn with_noise<F>(func: F, noise_scale: f64) -> Self
where
F: Fn(&[Value]) -> Value + Send + Sync + 'static,
{
Self {
func: Arc::new(func),
noise_scale,
}
}
/// Apply the mechanism to parent values
pub fn apply(&self, parents: &[Value]) -> Value {
(self.func)(parents)
}
}
/// Manual `Debug`: prints only `noise_scale`; the wrapped closure is left
/// out of the output.
impl std::fmt::Debug for Mechanism {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Mechanism");
        out.field("noise_scale", &self.noise_scale);
        out.finish()
    }
}
/// A structural equation: Y = f(Pa(Y), U_Y).
///
/// `parents` fixes the argument order of `mechanism`: the simulators pass
/// parent values to it in exactly this order.
#[derive(Clone)]
pub struct StructuralEquation {
/// Target variable this equation defines
// The `linear` / `with_noise` constructors leave this as the placeholder
// `VariableId(0)`; `CausalModel::set_structural_equation` overwrites it.
pub target: VariableId,
/// Parent variables (causes)
pub parents: Vec<VariableId>,
/// The functional mechanism
pub mechanism: Mechanism,
}
impl StructuralEquation {
    /// Creates a structural equation from its parts.
    pub fn new(target: VariableId, parents: Vec<VariableId>, mechanism: Mechanism) -> Self {
        Self {
            target,
            parents,
            mechanism,
        }
    }

    /// Creates a linear structural equation: Y = sum(coefficients[i] * parents[i]).
    ///
    /// The returned equation carries the placeholder target `VariableId(0)`;
    /// pass it to `CausalModel::set_structural_equation`, which rewrites the
    /// target. Parent values and coefficients are paired positionally and
    /// any surplus on either side is ignored.
    pub fn linear(parents: &[VariableId], coefficients: Vec<f64>) -> Self {
        Self::new(
            VariableId(0), // Will be set when added to model
            parents.to_vec(),
            Self::weighted_sum_mechanism(coefficients, 0.0),
        )
    }

    /// Creates the same weighted-sum equation as [`StructuralEquation::linear`]
    /// but tags the mechanism with the default noise scale of `1.0`.
    ///
    /// The mechanism itself stays deterministic: the noise scale is only
    /// recorded on it, and `compute` does not add noise.
    pub fn with_noise(parents: &[VariableId], coefficients: Vec<f64>) -> Self {
        Self::new(
            VariableId(0), // Will be set when added to model
            parents.to_vec(),
            Self::weighted_sum_mechanism(coefficients, 1.0), // Default noise scale
        )
    }

    /// Computes the value of the target given parent values (in the order
    /// of `self.parents`).
    pub fn compute(&self, parent_values: &[Value]) -> Value {
        self.mechanism.apply(parent_values)
    }

    /// Shared builder for the weighted-sum mechanisms: takes ownership of the
    /// coefficients (no extra copy) and moves them into the closure.
    fn weighted_sum_mechanism(coefficients: Vec<f64>, noise_scale: f64) -> Mechanism {
        Mechanism::with_noise(
            move |parent_values: &[Value]| {
                let total: f64 = parent_values
                    .iter()
                    .zip(&coefficients)
                    .map(|(value, weight)| value.as_f64() * weight)
                    .sum();
                Value::Continuous(total)
            },
            noise_scale,
        )
    }
}
/// Manual `Debug`: shows the target and parents and leaves the mechanism out
/// of the output.
impl std::fmt::Debug for StructuralEquation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StructuralEquation");
        out.field("target", &self.target);
        out.field("parents", &self.parents);
        out.finish()
    }
}
/// An intervention: do(X = x).
///
/// A plain data carrier; it is consumed by `CausalModel::intervene_with`,
/// which copies `target` / `value` pairs into an intervention map.
#[derive(Debug, Clone)]
pub struct Intervention {
/// Variable being intervened on
pub target: VariableId,
/// Value to set
pub value: Value,
}
impl Intervention {
/// Create a new intervention
pub fn new(target: VariableId, value: Value) -> Self {
Self { target, value }
}
/// Create from variable name (requires model lookup)
pub fn from_name(model: &CausalModel, name: &str, value: Value) -> Option<Self> {
model.get_variable_id(name).map(|id| Self::new(id, value))
}
}
/// Error types for causal model operations.
///
/// `Display` text comes from the `#[error(...)]` attributes (via
/// `thiserror`). `GraphError` is produced automatically by `?` on graph
/// operations thanks to `#[from]`.
#[derive(Debug, Clone, Error)]
pub enum CausalModelError {
/// Variable not found
// Raised by the name-based lookups (e.g. `add_equation_by_name`).
#[error("Variable '{0}' not found")]
VariableNotFound(String),
/// Variable ID not found
#[error("Variable ID {0:?} not found")]
VariableIdNotFound(VariableId),
/// Duplicate variable name
#[error("Variable '{0}' already exists")]
DuplicateVariable(String),
/// DAG validation error
#[error("Graph error: {0}")]
GraphError(#[from] DAGValidationError),
/// Missing structural equation
// Raised by `forward_simulate` and `validate`.
#[error("No structural equation for variable {0:?}")]
MissingEquation(VariableId),
/// Invalid parent reference
// Raised by `add_structural_equation` for a parent id unknown to the model.
#[error("Invalid parent reference: {0:?}")]
InvalidParent(VariableId),
/// Type mismatch
// NOTE(review): no code visible in this module constructs this variant — confirm it is used elsewhere.
#[error("Type mismatch for variable {0}: expected {1:?}, got {2:?}")]
TypeMismatch(String, VariableType, Value),
/// Computation error
#[error("Computation error: {0}")]
ComputationError(String),
}
/// A Structural Causal Model (SCM).
///
/// Holds the variables, their structural equations and the directed graph
/// connecting them. `variables` and `name_to_id` are two indexes over the
/// same set of variables and are updated together by the `add_variable*`
/// methods.
#[derive(Debug, Clone)]
pub struct CausalModel {
/// Variables in the model
variables: HashMap<VariableId, Variable>,
/// Name to ID mapping
name_to_id: HashMap<String, VariableId>,
/// Structural equations
// Keyed by the target variable; at most one equation per variable.
equations: HashMap<VariableId, StructuralEquation>,
/// Underlying DAG structure
// Node ids are the raw `u32` inside each `VariableId`.
graph: DirectedGraph,
/// Next variable ID
// Handed out by `add_variable` and bumped past any explicit id given to
// `add_variable_with_config`.
next_id: u32,
/// Model name
pub name: Option<String>,
/// Model description
pub description: Option<String>,
/// Latent confounders (unobserved common causes)
// Unordered pairs: lookups check both (a, b) and (b, a).
latent_confounders: Vec<(VariableId, VariableId)>,
/// Intervention values (for mutilated models)
// Empty for a freshly built model; filled by `intervene` / `multi_intervene`
// on the cloned, mutilated copy.
intervention_values: HashMap<VariableId, Value>,
}
impl CausalModel {
    /// Creates an empty causal model: no variables, equations, edges,
    /// latent confounders or interventions.
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            name_to_id: HashMap::new(),
            equations: HashMap::new(),
            graph: DirectedGraph::new(),
            next_id: 0,
            name: None,
            description: None,
            latent_confounders: Vec::new(),
            intervention_values: HashMap::new(),
        }
    }

    /// Creates an empty model labelled with `name`.
    pub fn with_name(name: &str) -> Self {
        let mut model = Self::new();
        model.name = Some(name.to_string());
        model
    }

    /// Adds a variable and returns its freshly allocated ID.
    ///
    /// # Errors
    /// `DuplicateVariable` if a variable called `name` already exists.
    pub fn add_variable(
        &mut self,
        name: &str,
        var_type: VariableType,
    ) -> Result<VariableId, CausalModelError> {
        if self.name_to_id.contains_key(name) {
            return Err(CausalModelError::DuplicateVariable(name.to_string()));
        }
        let id = VariableId(self.next_id);
        self.next_id += 1;
        self.variables.insert(id, Variable::new(id, name, var_type));
        self.name_to_id.insert(name.to_string(), id);
        self.graph.add_node_with_label(id.0, name);
        Ok(id)
    }

    /// Adds a fully configured variable, keeping the ID it carries.
    ///
    /// # Errors
    /// `DuplicateVariable` if the name is already taken, or if the ID is
    /// already used by another variable (the error then names the variable
    /// that currently owns the ID). Rejecting the second case keeps the
    /// name index and the variable table consistent; overwriting would leave
    /// a stale name pointing at the replaced variable.
    pub fn add_variable_with_config(
        &mut self,
        variable: Variable,
    ) -> Result<VariableId, CausalModelError> {
        if self.name_to_id.contains_key(&variable.name) {
            return Err(CausalModelError::DuplicateVariable(variable.name.clone()));
        }
        if let Some(existing) = self.variables.get(&variable.id) {
            return Err(CausalModelError::DuplicateVariable(existing.name.clone()));
        }
        let id = variable.id;
        self.name_to_id.insert(variable.name.clone(), id);
        self.graph.add_node_with_label(id.0, &variable.name);
        self.variables.insert(id, variable);
        // Keep automatically allocated IDs above every explicit one.
        if id.0 >= self.next_id {
            self.next_id = id.0 + 1;
        }
        Ok(id)
    }

    /// Adds a causal edge from `parent` to `child`.
    ///
    /// # Errors
    /// `VariableIdNotFound` if either endpoint is unknown; `GraphError` if
    /// the graph rejects the edge.
    pub fn add_edge(
        &mut self,
        parent: VariableId,
        child: VariableId,
    ) -> Result<(), CausalModelError> {
        for endpoint in [parent, child] {
            if !self.variables.contains_key(&endpoint) {
                return Err(CausalModelError::VariableIdNotFound(endpoint));
            }
        }
        self.graph.add_edge(parent.0, child.0)?;
        Ok(())
    }

    /// Adds (or replaces) the structural equation of `target` and adds an
    /// edge from every parent to `target`.
    ///
    /// All IDs are validated before the graph is touched, so an unknown
    /// parent no longer leaves edges from the earlier parents behind.
    ///
    /// # Errors
    /// `VariableIdNotFound` for an unknown target, `InvalidParent` for an
    /// unknown parent, `GraphError` if the graph rejects an edge. In the
    /// last case edges added earlier in the same call remain in the graph.
    pub fn add_structural_equation(
        &mut self,
        target: VariableId,
        parents: &[VariableId],
        mechanism: Mechanism,
    ) -> Result<(), CausalModelError> {
        if !self.variables.contains_key(&target) {
            return Err(CausalModelError::VariableIdNotFound(target));
        }
        if let Some(&unknown) = parents
            .iter()
            .find(|&&parent| !self.variables.contains_key(&parent))
        {
            return Err(CausalModelError::InvalidParent(unknown));
        }
        for parent in parents {
            self.graph.add_edge(parent.0, target.0)?;
        }
        self.equations.insert(
            target,
            StructuralEquation::new(target, parents.to_vec(), mechanism),
        );
        Ok(())
    }

    /// Adds a structural equation, addressing the target and parents by name.
    ///
    /// # Errors
    /// `VariableNotFound` for the first unknown name, otherwise whatever
    /// `add_structural_equation` reports.
    pub fn add_equation_by_name<F>(
        &mut self,
        target_name: &str,
        parent_names: &[&str],
        func: F,
    ) -> Result<(), CausalModelError>
    where
        F: Fn(&[Value]) -> Value + Send + Sync + 'static,
    {
        let target = self
            .get_variable_id(target_name)
            .ok_or_else(|| CausalModelError::VariableNotFound(target_name.to_string()))?;
        let parents = parent_names
            .iter()
            .map(|&name| {
                self.get_variable_id(name)
                    .ok_or_else(|| CausalModelError::VariableNotFound(name.to_string()))
            })
            .collect::<Result<Vec<_>, _>>()?;
        self.add_structural_equation(target, &parents, Mechanism::new(func))
    }

    /// Looks up a variable ID by name.
    pub fn get_variable_id(&self, name: &str) -> Option<VariableId> {
        self.name_to_id.get(name).copied()
    }

    /// Returns an owned copy of the name of the variable with this ID.
    pub fn get_variable_name(&self, id: &VariableId) -> Option<String> {
        self.variables.get(id).map(|v| v.name.clone())
    }

    /// Looks up a variable by ID.
    pub fn get_variable(&self, id: &VariableId) -> Option<&Variable> {
        self.variables.get(id)
    }

    /// Iterates over all variables (in the map's arbitrary order).
    pub fn variables(&self) -> impl Iterator<Item = &Variable> {
        self.variables.values()
    }

    /// Number of variables in the model.
    pub fn num_variables(&self) -> usize {
        self.variables.len()
    }

    /// Alias for `num_variables` (for API compatibility).
    pub fn variable_count(&self) -> usize {
        self.num_variables()
    }

    /// Returns `true` when a topological order of the graph exists.
    pub fn is_dag(&self) -> bool {
        self.topological_order_ids().is_ok()
    }

    /// Installs `equation` as the equation of `target` (convenience method).
    ///
    /// The equation's own `target` field is ignored and replaced by
    /// `target`, which is how the placeholder target produced by
    /// `StructuralEquation::linear` gets fixed up. Edge insertion is
    /// best-effort: this method has no error channel, so an edge rejected by
    /// the graph is skipped while the equation is still stored.
    pub fn set_structural_equation(&mut self, target: VariableId, equation: StructuralEquation) {
        for parent in &equation.parents {
            let _ = self.graph.add_edge(parent.0, target.0);
        }
        self.equations.insert(
            target,
            StructuralEquation {
                target,
                parents: equation.parents,
                mechanism: equation.mechanism,
            },
        );
    }

    /// Records latent confounding between two variables.
    pub fn add_latent_confounding(&mut self, var1: VariableId, var2: VariableId) {
        self.latent_confounders.push((var1, var2));
    }

    /// Returns `true` when no latent confounder was recorded for the pair
    /// (in either order).
    pub fn is_unconfounded(&self, var1: VariableId, var2: VariableId) -> bool {
        !self
            .latent_confounders
            .iter()
            .any(|&pair| pair == (var1, var2) || pair == (var2, var1))
    }

    /// Returns `true` when `var` takes part in any recorded latent confounding.
    pub fn has_latent_confounding(&self, var: VariableId) -> bool {
        self.latent_confounders
            .iter()
            .any(|&(a, b)| a == var || b == var)
    }

    /// Children of a variable, or `None` when the graph has no entry for it.
    pub fn children(&self, id: &VariableId) -> Option<Vec<VariableId>> {
        let children = self.graph.children_of(id.0)?;
        Some(children.iter().map(|&child| VariableId(child)).collect())
    }

    /// Parents of a variable, or `None` when the graph has no entry for it.
    pub fn parents(&self, id: &VariableId) -> Option<Vec<VariableId>> {
        let parents = self.graph.parents_of(id.0)?;
        Some(parents.iter().map(|&parent| VariableId(parent)).collect())
    }

    /// Topological ordering as variable names.
    ///
    /// Graph nodes without a matching variable entry are skipped.
    ///
    /// # Errors
    /// `GraphError` if the graph cannot be ordered.
    pub fn topological_order(&self) -> Result<Vec<String>, CausalModelError> {
        let ids = self.topological_order_ids()?;
        Ok(ids
            .iter()
            .filter_map(|id| self.get_variable_name(id))
            .collect())
    }

    /// Topological ordering as variable IDs, computed on a clone of the graph.
    ///
    /// # Errors
    /// `GraphError` if the graph cannot be ordered.
    pub fn topological_order_ids(&self) -> Result<Vec<VariableId>, CausalModelError> {
        let mut graph = self.graph.clone();
        let order = graph.topological_order()?;
        Ok(order.iter().map(|&id| VariableId(id)).collect())
    }

    /// Applies the do-operator `do(target = value)`.
    ///
    /// Equivalent to `multi_intervene` with a single pair: the returned
    /// model is a clone in which the edges into `target` are removed and its
    /// equation is replaced by a constant.
    ///
    /// # Errors
    /// `VariableIdNotFound` if `target` is unknown.
    pub fn intervene(
        &self,
        target: VariableId,
        value: Value,
    ) -> Result<MutilatedModel, CausalModelError> {
        self.multi_intervene(&[(target, value)])
    }

    /// Wraps this model in a borrowing view that applies `interventions`
    /// during simulation, without cloning or mutating the graph.
    ///
    /// If the same target appears more than once, the last entry wins.
    ///
    /// # Errors
    /// `VariableIdNotFound` if an intervention targets an unknown variable,
    /// consistent with `intervene` and `multi_intervene`.
    pub fn intervene_with(
        &self,
        interventions: &[Intervention],
    ) -> Result<IntervenedModel, CausalModelError> {
        if let Some(unknown) = interventions
            .iter()
            .map(|intervention| intervention.target)
            .find(|target| !self.variables.contains_key(target))
        {
            return Err(CausalModelError::VariableIdNotFound(unknown));
        }
        let intervention_map: HashMap<VariableId, Value> = interventions
            .iter()
            .map(|intervention| (intervention.target, intervention.value.clone()))
            .collect();
        Ok(IntervenedModel {
            base_model: self,
            interventions: intervention_map,
        })
    }

    /// Applies several interventions simultaneously.
    ///
    /// For every target the incoming edges are removed from the cloned
    /// graph, the equation is replaced by a constant one, and the value is
    /// recorded in the clone's intervention table. Targets are validated
    /// before the model is cloned, so invalid input fails cheaply.
    ///
    /// # Errors
    /// `VariableIdNotFound` for the first unknown target.
    pub fn multi_intervene(
        &self,
        interventions: &[(VariableId, Value)],
    ) -> Result<MutilatedModel, CausalModelError> {
        if let Some((unknown, _)) = interventions
            .iter()
            .find(|(target, _)| !self.variables.contains_key(target))
        {
            return Err(CausalModelError::VariableIdNotFound(*unknown));
        }

        let mut mutilated = self.clone();
        for (target, value) in interventions {
            // Parents are read from the original graph; removal failures are
            // ignored on purpose (best-effort, as before).
            if let Some(parents) = self.graph.parents_of(target.0).cloned() {
                for parent in parents {
                    mutilated.graph.remove_edge(parent, target.0).ok();
                }
            }
            mutilated
                .equations
                .insert(*target, Self::constant_equation(*target, value.clone()));
            mutilated.intervention_values.insert(*target, value.clone());
        }
        Ok(MutilatedModel { model: mutilated })
    }

    /// Forward simulation: evaluates every variable in topological order.
    ///
    /// Variables present in `exogenous` keep the supplied value and their
    /// equation is not evaluated. A parent without a value is passed to the
    /// mechanism as `Value::Missing`.
    ///
    /// # Errors
    /// `GraphError` if the graph cannot be ordered; `MissingEquation` for a
    /// variable that has neither a supplied value nor an equation.
    pub fn forward_simulate(
        &self,
        exogenous: &HashMap<VariableId, Value>,
    ) -> Result<HashMap<VariableId, Value>, CausalModelError> {
        let order = self.topological_order_ids()?;
        let mut values = exogenous.clone();
        for var_id in order {
            if values.contains_key(&var_id) {
                continue; // Already set (exogenous or intervened)
            }
            // `values` started as a copy of `exogenous`, so reaching this
            // point already implies the variable was not supplied.
            let equation = self
                .equations
                .get(&var_id)
                .ok_or(CausalModelError::MissingEquation(var_id))?;
            let parent_values: Vec<Value> = equation
                .parents
                .iter()
                .map(|parent| values.get(parent).cloned().unwrap_or(Value::Missing))
                .collect();
            values.insert(var_id, equation.compute(&parent_values));
        }
        Ok(values)
    }

    /// Checks whether `x` and `y` are d-separated given the conditioning set `z`.
    pub fn d_separated(&self, x: VariableId, y: VariableId, z: &[VariableId]) -> bool {
        let x_set = [x.0].into_iter().collect();
        let y_set = [y.0].into_iter().collect();
        let z_set: std::collections::HashSet<_> = z.iter().map(|id| id.0).collect();
        self.graph.d_separated(&x_set, &y_set, &z_set)
    }

    /// Returns the structural equation of a variable, if one is set.
    pub fn get_equation(&self, id: &VariableId) -> Option<&StructuralEquation> {
        self.equations.get(id)
    }

    /// Borrows the underlying graph.
    pub fn graph(&self) -> &DirectedGraph {
        &self.graph
    }

    /// Validates the model: the graph must be orderable and every variable
    /// that has parents must have a structural equation.
    ///
    /// # Errors
    /// `GraphError` for an unorderable graph; `MissingEquation` for one of
    /// the offending variables (which one is reported depends on map order).
    pub fn validate(&self) -> Result<(), CausalModelError> {
        self.topological_order_ids()?;
        for id in self.variables.keys() {
            let has_parents = self
                .graph
                .parents_of(id.0)
                .map_or(false, |parents| !parents.is_empty());
            if has_parents && !self.equations.contains_key(id) {
                return Err(CausalModelError::MissingEquation(*id));
            }
        }
        Ok(())
    }

    /// Computes a point estimate for `target_name` given `observation`.
    ///
    /// Observed variables are pinned to their observed values and the model
    /// is simulated forward; ancestors of observed variables are not updated
    /// from the evidence. Observation entries whose name is unknown to the
    /// model are ignored. The result is a point mass; a target that ends up
    /// without a value is reported as `Value::Missing`.
    ///
    /// # Errors
    /// `VariableNotFound` for an unknown target, or any error from
    /// `forward_simulate`.
    pub fn conditional_distribution(
        &self,
        observation: &Observation,
        target_name: &str,
    ) -> Result<Distribution, CausalModelError> {
        let target_id = self
            .get_variable_id(target_name)
            .ok_or_else(|| CausalModelError::VariableNotFound(target_name.to_string()))?;
        let exogenous: HashMap<VariableId, Value> = observation
            .values
            .iter()
            .filter_map(|(name, value)| {
                self.get_variable_id(name).map(|id| (id, value.clone()))
            })
            .collect();
        let result = self.forward_simulate(&exogenous)?;
        let value = result.get(&target_id).cloned().unwrap_or(Value::Missing);
        Ok(Distribution::point(target_id, value))
    }

    /// Computes a point estimate for `target_name` with no observations.
    ///
    /// The simulation is seeded only with this model's recorded intervention
    /// values (empty unless the model is a mutilated copy); no default values
    /// are invented, so a variable without an equation makes the call fail.
    ///
    /// # Errors
    /// `VariableNotFound` for an unknown target, or any error from
    /// `forward_simulate`.
    pub fn marginal_distribution(
        &self,
        target_name: &str,
    ) -> Result<Distribution, CausalModelError> {
        let target_id = self
            .get_variable_id(target_name)
            .ok_or_else(|| CausalModelError::VariableNotFound(target_name.to_string()))?;
        let result = self.forward_simulate(&self.intervention_values)?;
        let value = result.get(&target_id).cloned().unwrap_or(Value::Missing);
        Ok(Distribution::point(target_id, value))
    }

    /// Builds the parentless equation that pins `target` to `value`; used to
    /// replace the equation of an intervened variable.
    fn constant_equation(target: VariableId, value: Value) -> StructuralEquation {
        StructuralEquation::new(target, Vec::new(), Mechanism::new(move |_| value.clone()))
    }
}
/// An observation of variable values (for conditioning).
///
/// Values are keyed by variable *name*; `CausalModel::conditional_distribution`
/// resolves the names against the model and skips names it does not know.
#[derive(Debug, Clone)]
pub struct Observation {
/// Observed variable values by name
pub values: HashMap<String, Value>,
}
impl Observation {
    /// Builds an observation from `(name, value)` pairs.
    ///
    /// When a name occurs more than once, the last pair wins.
    pub fn new(values: &[(&str, Value)]) -> Self {
        let mut observed = HashMap::with_capacity(values.len());
        for (name, value) in values {
            observed.insert((*name).to_owned(), value.clone());
        }
        Self { values: observed }
    }

    /// Builds an observation with no observed variables.
    pub fn empty() -> Self {
        Self::new(&[])
    }

    /// Records (or overwrites) the observed value of `var`.
    pub fn observe(&mut self, var: &str, value: Value) {
        self.values.insert(var.to_owned(), value);
    }

    /// Returns the observed value of `var`, if any.
    pub fn get(&self, var: &str) -> Option<&Value> {
        self.values.get(var)
    }

    /// Returns `true` when `var` has an observed value.
    pub fn is_observed(&self, var: &str) -> bool {
        self.get(var).is_some()
    }
}
impl Default for CausalModel {
fn default() -> Self {
Self::new()
}
}
/// A causal model with interventions applied.
///
/// A lightweight view: it borrows the base model instead of cloning it and
/// keeps the intervention values in its own map. The base model's graph and
/// equations are left untouched.
pub struct IntervenedModel<'a> {
// Model whose ordering and equations are used during simulation.
base_model: &'a CausalModel,
// Values forced onto the intervened variables, keyed by variable id.
interventions: HashMap<VariableId, Value>,
}
impl<'a> IntervenedModel<'a> {
    /// Simulates the base model with the interventions forced.
    ///
    /// The starting values are `exogenous` overlaid with the interventions,
    /// so an intervention always wins over an exogenous value for the same
    /// variable. Remaining variables are evaluated in the base model's
    /// topological order; a parent without a value is passed to the
    /// mechanism as `Value::Missing`. A variable that has neither a starting
    /// value nor an equation is simply absent from the result (unlike
    /// `CausalModel::forward_simulate`, no error is raised for it).
    ///
    /// # Errors
    /// `GraphError` if the base model's graph cannot be ordered.
    pub fn simulate(
        &self,
        exogenous: &HashMap<VariableId, Value>,
    ) -> Result<HashMap<VariableId, Value>, CausalModelError> {
        let order = self.base_model.topological_order_ids()?;

        let mut values = exogenous.clone();
        values.extend(
            self.interventions
                .iter()
                .map(|(id, value)| (*id, value.clone())),
        );

        for var_id in order {
            // Covers exogenous inputs and every intervened variable, which is
            // why no separate intervention lookup is needed in this loop.
            if values.contains_key(&var_id) {
                continue;
            }
            if let Some(equation) = self.base_model.equations.get(&var_id) {
                let parent_values: Vec<Value> = equation
                    .parents
                    .iter()
                    .map(|parent| values.get(parent).cloned().unwrap_or(Value::Missing))
                    .collect();
                values.insert(var_id, equation.compute(&parent_values));
            }
        }
        Ok(values)
    }

    /// Returns `true` when `var` is one of the intervened variables.
    pub fn is_intervened(&self, var: VariableId) -> bool {
        self.interventions.contains_key(&var)
    }

    /// Returns the value forced onto `var`, if it is intervened.
    pub fn intervention_value(&self, var: VariableId) -> Option<&Value> {
        self.interventions.get(&var)
    }
}
/// A mutilated causal model (with interventions applied)
///
/// This is a complete copy of the model with incoming edges to intervened
/// variables removed, representing the do-operator graph transformation.
/// It is produced by `CausalModel::intervene` / `CausalModel::multi_intervene`,
/// which also replace the equations of the intervened variables with
/// constants and record the forced values in the copy.
#[derive(Debug, Clone)]
pub struct MutilatedModel {
/// The mutilated model
pub model: CausalModel,
}
impl MutilatedModel {
    /// Parents of a variable in the mutilated graph.
    ///
    /// # Errors
    /// `VariableIdNotFound` when the underlying model reports no parent
    /// entry for `id`.
    pub fn parents(&self, id: &VariableId) -> Result<Vec<VariableId>, CausalModelError> {
        self.model
            .parents(id)
            .ok_or(CausalModelError::VariableIdNotFound(*id))
    }

    /// Computes the value of the variable called `var_name`, simulating the
    /// mutilated model with only the intervention values as inputs.
    ///
    /// # Errors
    /// `VariableNotFound` for an unknown name, any error from the forward
    /// simulation, or `ComputationError` if the simulation produced no value
    /// for the variable.
    pub fn compute(&self, var_name: &str) -> Result<Value, CausalModelError> {
        let var_id = self.resolve(var_name)?;
        self.compute_by_id(var_id, var_name)
    }

    /// Marginal distribution of a variable (a point mass, since the
    /// simulation is deterministic).
    ///
    /// # Errors
    /// Same as [`MutilatedModel::compute`].
    pub fn marginal_distribution(&self, var_name: &str) -> Result<Distribution, CausalModelError> {
        let var_id = self.resolve(var_name)?;
        let value = self.compute_by_id(var_id, var_name)?;
        Ok(Distribution::point(var_id, value))
    }

    /// Simulates the mutilated model with additional exogenous inputs.
    ///
    /// Intervention values take precedence: if `exogenous` also supplies a
    /// value for an intervened variable, the do-operator value is used. This
    /// matches `IntervenedModel::simulate`; previously the exogenous value
    /// silently overrode the intervention.
    ///
    /// # Errors
    /// Any error from the forward simulation.
    pub fn simulate(
        &self,
        exogenous: &HashMap<VariableId, Value>,
    ) -> Result<HashMap<VariableId, Value>, CausalModelError> {
        let mut inputs = exogenous.clone();
        // Extending last lets interventions overwrite clashing exogenous entries.
        inputs.extend(
            self.model
                .intervention_values
                .iter()
                .map(|(id, value)| (*id, value.clone())),
        );
        self.model.forward_simulate(&inputs)
    }

    /// Returns `true` when `var` is one of the intervened variables.
    pub fn is_intervened(&self, var: &VariableId) -> bool {
        self.model.intervention_values.contains_key(var)
    }

    /// Returns `true` when the mutilated graph is still a DAG.
    pub fn is_dag(&self) -> bool {
        self.model.is_dag()
    }

    /// Borrows the underlying (mutilated) model.
    pub fn inner(&self) -> &CausalModel {
        &self.model
    }

    /// Resolves a variable name to its ID.
    fn resolve(&self, var_name: &str) -> Result<VariableId, CausalModelError> {
        self.model
            .get_variable_id(var_name)
            .ok_or_else(|| CausalModelError::VariableNotFound(var_name.to_string()))
    }

    /// Runs the intervention-only simulation and extracts one variable;
    /// `var_name` is used only for the error message.
    fn compute_by_id(&self, var_id: VariableId, var_name: &str) -> Result<Value, CausalModelError> {
        let result = self
            .model
            .forward_simulate(&self.model.intervention_values)?;
        result.get(&var_id).cloned().ok_or_else(|| {
            CausalModelError::ComputationError(format!("Variable {} not computed", var_name))
        })
    }
}
/// A simple probability distribution representation.
///
/// Only point masses are represented: a single value together with its
/// probability. `Distribution::point` is the only constructor in this
/// module and always sets the probability to `1.0`.
#[derive(Debug, Clone)]
pub struct Distribution {
/// Variable ID
pub variable: VariableId,
/// Value (point mass for deterministic)
pub value: Value,
/// Probability mass
pub probability: f64,
}
impl Distribution {
/// Create a point mass distribution
pub fn point(variable: VariableId, value: Value) -> Self {
Self {
variable,
value,
probability: 1.0,
}
}
/// Get the expected value (for continuous)
pub fn expected_value(&self) -> f64 {
self.value.as_f64()
}
}
/// Approximate equality for distributions.
///
/// Two distributions are equal when they describe the same variable, their
/// probabilities differ by less than `1e-10`, and their values match:
/// continuous values within the same tolerance (or exactly, which also
/// covers equal infinities), all other variants exactly. Values of
/// different variants are never equal. Categorical and missing values are
/// compared too, so a distribution over such a value equals itself.
impl PartialEq for Distribution {
    fn eq(&self, other: &Self) -> bool {
        const TOLERANCE: f64 = 1e-10;

        let same_value = match (&self.value, &other.value) {
            (Value::Continuous(a), Value::Continuous(b)) => {
                // The exact check handles +/-infinity, where `a - b` is NaN.
                a == b || (a - b).abs() < TOLERANCE
            }
            (Value::Discrete(a), Value::Discrete(b)) => a == b,
            (Value::Binary(a), Value::Binary(b)) => a == b,
            (Value::Categorical(a), Value::Categorical(b)) => a == b,
            (Value::Missing, Value::Missing) => true,
            _ => false,
        };

        self.variable == other.variable
            && (self.probability - other.probability).abs() < TOLERANCE
            && same_value
    }
}
/// Builder for creating causal models fluently.
///
/// Every builder method consumes and returns the builder. Failures of the
/// underlying model operations (duplicate names, unknown variables, rejected
/// edges) are discarded with `.ok()`, so a chain never aborts; the model
/// returned by `build` simply lacks whatever could not be added.
pub struct CausalModelBuilder {
// Model under construction; handed out by `build`.
model: CausalModel,
}
impl CausalModelBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
model: CausalModel::new(),
}
}
/// Create a builder with a model name
pub fn with_name(name: &str) -> Self {
Self {
model: CausalModel::with_name(name),
}
}
/// Add a continuous variable
pub fn add_continuous(mut self, name: &str) -> Self {
self.model.add_variable(name, VariableType::Continuous).ok();
self
}
/// Add a binary variable
pub fn add_binary(mut self, name: &str) -> Self {
self.model.add_variable(name, VariableType::Binary).ok();
self
}
/// Add a discrete variable
pub fn add_discrete(mut self, name: &str) -> Self {
self.model.add_variable(name, VariableType::Discrete).ok();
self
}
/// Add a causal relationship
pub fn add_cause(mut self, cause: &str, effect: &str) -> Self {
if let (Some(c), Some(e)) = (
self.model.get_variable_id(cause),
self.model.get_variable_id(effect),
) {
self.model.add_edge(c, e).ok();
}
self
}
/// Add a structural equation
pub fn with_equation<F>(mut self, target: &str, parents: &[&str], func: F) -> Self
where
F: Fn(&[Value]) -> Value + Send + Sync + 'static,
{
self.model.add_equation_by_name(target, parents, func).ok();
self
}
/// Build the model
pub fn build(self) -> CausalModel {
self.model
}
}
impl Default for CausalModelBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Variables get distinct IDs that can be looked up by name.
    #[test]
    fn test_create_model() {
        let mut model = CausalModel::new();
        let x = model.add_variable("X", VariableType::Continuous).unwrap();
        let y = model.add_variable("Y", VariableType::Continuous).unwrap();
        assert_eq!(model.num_variables(), 2);
        assert_eq!(model.get_variable_id("X"), Some(x));
        assert_eq!(model.get_variable_id("Y"), Some(y));
    }

    /// An added edge shows up from both endpoints.
    #[test]
    fn test_add_edges() {
        let mut model = CausalModel::new();
        let x = model.add_variable("X", VariableType::Continuous).unwrap();
        let y = model.add_variable("Y", VariableType::Continuous).unwrap();
        model.add_edge(x, y).unwrap();
        assert_eq!(model.children(&x), Some(vec![y]));
        assert_eq!(model.parents(&y), Some(vec![x]));
    }

    /// Forward simulation evaluates a structural equation on its parent.
    #[test]
    fn test_structural_equation() {
        let mut model = CausalModel::new();
        let x = model.add_variable("X", VariableType::Continuous).unwrap();
        let y = model.add_variable("Y", VariableType::Continuous).unwrap();
        // Y = 2*X + 1
        let mechanism = Mechanism::new(|parents| {
            let x_val = parents[0].as_f64();
            Value::Continuous(2.0 * x_val + 1.0)
        });
        model.add_structural_equation(y, &[x], mechanism).unwrap();
        // Simulate
        let mut exogenous = HashMap::new();
        exogenous.insert(x, Value::Continuous(3.0));
        let result = model.forward_simulate(&exogenous).unwrap();
        assert_eq!(result.get(&y), Some(&Value::Continuous(7.0)));
    }

    /// do(Y = 5) cuts Y off from X and propagates to Z.
    #[test]
    fn test_intervention() {
        let mut model = CausalModel::new();
        let x = model.add_variable("X", VariableType::Continuous).unwrap();
        let y = model.add_variable("Y", VariableType::Continuous).unwrap();
        let z = model.add_variable("Z", VariableType::Continuous).unwrap();
        // Y = X, Z = Y
        model
            .add_structural_equation(y, &[x], Mechanism::new(|p| p[0].clone()))
            .unwrap();
        model
            .add_structural_equation(z, &[y], Mechanism::new(|p| p[0].clone()))
            .unwrap();
        // Intervene: do(Y = 5). `intervene_with` is the entry point that
        // accepts a slice of `Intervention`s; `intervene` takes a single
        // (id, value) pair.
        let intervention = Intervention::new(y, Value::Continuous(5.0));
        let intervened = model.intervene_with(&[intervention]).unwrap();
        let mut exogenous = HashMap::new();
        exogenous.insert(x, Value::Continuous(10.0)); // X = 10
        let result = intervened.simulate(&exogenous).unwrap();
        // X should still be 10
        assert_eq!(result.get(&x).unwrap().as_f64(), 10.0);
        // Y should be 5 (intervened)
        assert_eq!(result.get(&y).unwrap().as_f64(), 5.0);
        // Z should be 5 (from Y)
        assert_eq!(result.get(&z).unwrap().as_f64(), 5.0);
    }

    /// The fluent builder declares variables and edges.
    #[test]
    fn test_builder() {
        let model = CausalModelBuilder::new()
            .add_continuous("Age")
            .add_continuous("Income")
            .add_binary("Employed")
            .add_cause("Age", "Income")
            .add_cause("Employed", "Income")
            .build();
        assert_eq!(model.num_variables(), 3);
        let age = model.get_variable_id("Age").unwrap();
        let income = model.get_variable_id("Income").unwrap();
        assert_eq!(model.children(&age), Some(vec![income]));
    }

    /// In a chain X -> Z -> Y, conditioning on Z blocks the path.
    #[test]
    fn test_d_separation() {
        let mut model = CausalModel::new();
        // Chain: X -> Z -> Y
        let x = model.add_variable("X", VariableType::Continuous).unwrap();
        let z = model.add_variable("Z", VariableType::Continuous).unwrap();
        let y = model.add_variable("Y", VariableType::Continuous).unwrap();
        model.add_edge(x, z).unwrap();
        model.add_edge(z, y).unwrap();
        // X and Y are NOT d-separated given empty set
        assert!(!model.d_separated(x, y, &[]));
        // X and Y ARE d-separated given Z
        assert!(model.d_separated(x, y, &[z]));
    }

    /// Topological order respects edge direction.
    #[test]
    fn test_topological_order() {
        let mut model = CausalModel::new();
        let a = model.add_variable("A", VariableType::Continuous).unwrap();
        let b = model.add_variable("B", VariableType::Continuous).unwrap();
        let c = model.add_variable("C", VariableType::Continuous).unwrap();
        model.add_edge(a, b).unwrap();
        model.add_edge(b, c).unwrap();
        let order = model.topological_order().unwrap();
        let pos_a = order.iter().position(|n| n == "A").unwrap();
        let pos_b = order.iter().position(|n| n == "B").unwrap();
        let pos_c = order.iter().position(|n| n == "C").unwrap();
        assert!(pos_a < pos_b);
        assert!(pos_b < pos_c);
    }

    /// Numeric and boolean views of `Value`.
    #[test]
    fn test_value_conversions() {
        let continuous = Value::Continuous(3.14);
        assert!((continuous.as_f64() - 3.14).abs() < 1e-10);
        let binary = Value::Binary(true);
        assert_eq!(binary.as_bool(), Some(true));
        assert!((binary.as_f64() - 1.0).abs() < 1e-10);
        let discrete = Value::Discrete(42);
        assert!((discrete.as_f64() - 42.0).abs() < 1e-10);
        let missing = Value::Missing;
        assert!(missing.is_missing());
        assert!(missing.as_f64().is_nan());
    }

    /// Re-adding a name is rejected.
    #[test]
    fn test_duplicate_variable() {
        let mut model = CausalModel::new();
        model.add_variable("X", VariableType::Continuous).unwrap();
        let result = model.add_variable("X", VariableType::Continuous);
        assert!(matches!(result, Err(CausalModelError::DuplicateVariable(_))));
    }

    /// A variable with parents needs an equation to validate.
    #[test]
    fn test_model_validation() {
        let mut model = CausalModel::new();
        let x = model.add_variable("X", VariableType::Continuous).unwrap();
        let y = model.add_variable("Y", VariableType::Continuous).unwrap();
        model.add_edge(x, y).unwrap();
        // Should fail - Y has parents but no equation
        let result = model.validate();
        assert!(matches!(result, Err(CausalModelError::MissingEquation(_))));
        // Add equation
        model
            .add_structural_equation(y, &[x], Mechanism::new(|p| p[0].clone()))
            .unwrap();
        // Should pass now
        model.validate().unwrap();
    }
}