git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
806 lines
26 KiB
Rust
806 lines
26 KiB
Rust
//! Counterfactual Reasoning
//!
//! This module implements counterfactual inference based on Pearl's three-step
//! procedure: Abduction, Action, Prediction.
//!
//! ## Counterfactual Semantics
//!
//! Given a structural causal model M = (U, V, F), a counterfactual query asks:
//! "What would Y have been if X had been x, given that we observed E = e?"
//!
//! Written as: P(Y_x | E = e) or Y_{X=x}(u) where u is the exogenous state.
//!
//! ## Three-Step Procedure
//!
//! 1. **Abduction**: Update P(U) given evidence E = e to get P(U | E = e)
//! 2. **Action**: Modify the model by intervention do(X = x)
//! 3. **Prediction**: Compute Y in the modified model using updated U
//!
//! ## References
//!
//! - Pearl (2009): "Causality" Chapter 7
//! - Halpern (2016): "Actual Causality"
|
|
|
|
use std::collections::HashMap;
|
|
use thiserror::Error;
|
|
|
|
use super::model::{
|
|
CausalModel, CausalModelError, Intervention, Value, VariableId, Mechanism,
|
|
Observation,
|
|
};
|
|
|
|
/// Error types for counterfactual reasoning
|
|
#[derive(Debug, Clone, Error)]
|
|
pub enum CounterfactualError {
|
|
/// Model error
|
|
#[error("Model error: {0}")]
|
|
ModelError(#[from] CausalModelError),
|
|
|
|
/// Invalid observation
|
|
#[error("Invalid observation: variable '{0}' not in model")]
|
|
InvalidObservation(String),
|
|
|
|
/// Abduction failed
|
|
#[error("Abduction failed: {0}")]
|
|
AbductionFailed(String),
|
|
|
|
/// Counterfactual not well-defined
|
|
#[error("Counterfactual not well-defined: {0}")]
|
|
NotWellDefined(String),
|
|
}
|
|
|
|
/// Extended Distribution type for counterfactual results
|
|
#[derive(Debug, Clone)]
|
|
pub struct CounterfactualDistribution {
|
|
/// Point estimate values (for deterministic models)
|
|
pub values: HashMap<VariableId, Value>,
|
|
/// Probability mass (for discrete) or density (for continuous)
|
|
pub probability: f64,
|
|
}
|
|
|
|
impl CounterfactualDistribution {
|
|
/// Create a point mass distribution
|
|
pub fn point_mass(values: HashMap<VariableId, Value>) -> Self {
|
|
Self {
|
|
values,
|
|
probability: 1.0,
|
|
}
|
|
}
|
|
|
|
/// Create from a simulation result
|
|
pub fn from_simulation(values: HashMap<VariableId, Value>) -> Self {
|
|
Self::point_mass(values)
|
|
}
|
|
|
|
/// Get value for a variable
|
|
pub fn get(&self, var: VariableId) -> Option<&Value> {
|
|
self.values.get(&var)
|
|
}
|
|
|
|
/// Get mean value (for continuous distributions)
|
|
pub fn mean(&self, var: VariableId) -> f64 {
|
|
self.values.get(&var)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0)
|
|
}
|
|
}
|
|
|
|
/// A counterfactual query
|
|
#[derive(Debug, Clone)]
|
|
pub struct CounterfactualQuery {
|
|
/// Target variable (what we want to know)
|
|
pub target: String,
|
|
/// Interventions (what we're hypothetically changing)
|
|
pub interventions: Vec<(String, Value)>,
|
|
/// Evidence (what we observed)
|
|
pub evidence: Observation,
|
|
}
|
|
|
|
impl CounterfactualQuery {
|
|
/// Create a new counterfactual query
|
|
///
|
|
/// Asking: "What would `target` have been if we had done `interventions`,
|
|
/// given that we observed `evidence`?"
|
|
pub fn new(target: &str, interventions: Vec<(&str, Value)>, evidence: Observation) -> Self {
|
|
Self {
|
|
target: target.to_string(),
|
|
interventions: interventions.into_iter()
|
|
.map(|(k, v)| (k.to_string(), v))
|
|
.collect(),
|
|
evidence,
|
|
}
|
|
}
|
|
|
|
/// Simple counterfactual: "What would Y have been if X had been x?"
|
|
pub fn simple(target: &str, intervention_var: &str, intervention_val: Value) -> Self {
|
|
Self {
|
|
target: target.to_string(),
|
|
interventions: vec![(intervention_var.to_string(), intervention_val)],
|
|
evidence: Observation::empty(),
|
|
}
|
|
}
|
|
|
|
/// Add evidence to the query
|
|
pub fn given(mut self, var: &str, value: Value) -> Self {
|
|
self.evidence.observe(var, value);
|
|
self
|
|
}
|
|
}
|
|
|
|
/// Result of a counterfactual query
|
|
#[derive(Debug, Clone)]
|
|
pub struct CounterfactualResult {
|
|
/// The query that was answered
|
|
pub query: CounterfactualQuery,
|
|
/// The counterfactual distribution
|
|
pub distribution: CounterfactualDistribution,
|
|
/// The abduced exogenous values
|
|
pub exogenous: HashMap<VariableId, Value>,
|
|
/// Explanation of the reasoning
|
|
pub explanation: String,
|
|
}
|
|
|
|
/// Average Treatment Effect computation
|
|
#[derive(Debug, Clone)]
|
|
pub struct AverageTreatmentEffect {
|
|
/// Treatment variable
|
|
pub treatment: String,
|
|
/// Outcome variable
|
|
pub outcome: String,
|
|
/// Treatment value
|
|
pub treatment_value: Value,
|
|
/// Control value
|
|
pub control_value: Value,
|
|
/// Estimated ATE
|
|
pub ate: f64,
|
|
/// Standard error (if available)
|
|
pub standard_error: Option<f64>,
|
|
}
|
|
|
|
impl AverageTreatmentEffect {
|
|
/// Create a new ATE result
|
|
pub fn new(
|
|
treatment: &str,
|
|
outcome: &str,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
ate: f64,
|
|
) -> Self {
|
|
Self {
|
|
treatment: treatment.to_string(),
|
|
outcome: outcome.to_string(),
|
|
treatment_value,
|
|
control_value,
|
|
ate,
|
|
standard_error: None,
|
|
}
|
|
}
|
|
|
|
/// Set standard error
|
|
pub fn with_standard_error(mut self, se: f64) -> Self {
|
|
self.standard_error = Some(se);
|
|
self
|
|
}
|
|
}
|
|
|
|
/// Compute a counterfactual: "What would Y have been if X had been x, given observation?"
|
|
///
|
|
/// This implements Pearl's three-step procedure:
|
|
/// 1. Abduction: Infer exogenous variables from observation
|
|
/// 2. Action: Apply intervention do(X = x)
|
|
/// 3. Prediction: Compute Y under intervention with abduced exogenous values
|
|
///
|
|
/// # Arguments
|
|
/// * `model` - The causal model
|
|
/// * `observation` - The observed evidence
|
|
/// * `intervention_var` - The variable to intervene on
|
|
/// * `intervention_value` - The counterfactual value for the intervention
|
|
/// * `target_name` - The target variable to compute the counterfactual for
|
|
pub fn counterfactual(
|
|
model: &CausalModel,
|
|
observation: &Observation,
|
|
intervention_var: VariableId,
|
|
intervention_value: Value,
|
|
target_name: &str,
|
|
) -> Result<Value, CounterfactualError> {
|
|
// Step 1: Abduction - infer exogenous variables
|
|
let exogenous = abduce(model, observation)?;
|
|
|
|
// Step 2: Action - create intervened model
|
|
let intervention = Intervention::new(intervention_var, intervention_value);
|
|
let intervened = model.intervene_with(&[intervention])?;
|
|
|
|
// Step 3: Prediction - simulate with abduced exogenous and intervention
|
|
let result = intervened.simulate(&exogenous)?;
|
|
|
|
// Get the target variable
|
|
let target_id = model.get_variable_id(target_name)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(target_name.to_string()))?;
|
|
|
|
result.get(&target_id)
|
|
.cloned()
|
|
.ok_or_else(|| CounterfactualError::AbductionFailed(
|
|
format!("Target variable {} not computed", target_name)
|
|
))
|
|
}
|
|
|
|
/// Compute a counterfactual with an Intervention struct (alternative API)
|
|
pub fn counterfactual_with_intervention(
|
|
model: &CausalModel,
|
|
observation: &Observation,
|
|
intervention: &Intervention,
|
|
) -> Result<CounterfactualDistribution, CounterfactualError> {
|
|
// Step 1: Abduction - infer exogenous variables
|
|
let exogenous = abduce(model, observation)?;
|
|
|
|
// Step 2: Action - create intervened model
|
|
let intervened = model.intervene_with(&[intervention.clone()])?;
|
|
|
|
// Step 3: Prediction - simulate with abduced exogenous and intervention
|
|
let result = intervened.simulate(&exogenous)?;
|
|
|
|
Ok(CounterfactualDistribution::from_simulation(result))
|
|
}
|
|
|
|
/// Compute counterfactual from a query object
|
|
pub fn counterfactual_query(
|
|
model: &CausalModel,
|
|
query: &CounterfactualQuery,
|
|
) -> Result<CounterfactualResult, CounterfactualError> {
|
|
// Convert interventions
|
|
let interventions: Result<Vec<Intervention>, _> = query.interventions.iter()
|
|
.map(|(var, val)| {
|
|
model.get_variable_id(var)
|
|
.map(|id| Intervention::new(id, val.clone()))
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(var.clone()))
|
|
})
|
|
.collect();
|
|
let interventions = interventions?;
|
|
|
|
// Step 1: Abduction
|
|
let exogenous = abduce(model, &query.evidence)?;
|
|
|
|
// Step 2 & 3: Action and Prediction
|
|
let intervened = model.intervene_with(&interventions)?;
|
|
let result = intervened.simulate(&exogenous)?;
|
|
|
|
let target_id = model.get_variable_id(&query.target)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(query.target.clone()))?;
|
|
|
|
let explanation = format!(
|
|
"Counterfactual computed via three-step procedure:\n\
|
|
1. Abduced exogenous values from evidence\n\
|
|
2. Applied intervention(s): {}\n\
|
|
3. Predicted {} under intervention",
|
|
query.interventions.iter()
|
|
.map(|(v, val)| format!("do({}={:?})", v, val))
|
|
.collect::<Vec<_>>()
|
|
.join(", "),
|
|
query.target
|
|
);
|
|
|
|
Ok(CounterfactualResult {
|
|
query: query.clone(),
|
|
distribution: CounterfactualDistribution::from_simulation(result),
|
|
exogenous,
|
|
explanation,
|
|
})
|
|
}
|
|
|
|
/// Abduction: Infer exogenous variable values from observations
|
|
///
|
|
/// For deterministic models, this inverts the structural equations
|
|
fn abduce(
|
|
model: &CausalModel,
|
|
observation: &Observation,
|
|
) -> Result<HashMap<VariableId, Value>, CounterfactualError> {
|
|
let mut exogenous = HashMap::new();
|
|
|
|
// Get topological order
|
|
let topo_order = model.topological_order_ids()?;
|
|
|
|
// For each variable in topological order
|
|
for var_id in topo_order {
|
|
let var = model.get_variable(var_id.as_ref())
|
|
.ok_or_else(|| CounterfactualError::AbductionFailed(
|
|
format!("Variable {:?} not found", var_id)
|
|
))?;
|
|
|
|
// Check if this variable is observed
|
|
if let Some(observed_value) = observation.values.get(&var.name) {
|
|
// If this is a root variable (no parents), it's exogenous
|
|
if model.parents(&var_id).map(|p| p.is_empty()).unwrap_or(true) {
|
|
exogenous.insert(var_id, observed_value.clone());
|
|
} else {
|
|
// For endogenous variables, we might need to compute the residual
|
|
// For now, we use the observed value as the exogenous noise
|
|
exogenous.insert(var_id, observed_value.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(exogenous)
|
|
}
|
|
|
|
/// Compute the Average Treatment Effect (ATE)
|
|
///
|
|
/// ATE = E[Y | do(X=treatment_value)] - E[Y | do(X=control_value)]
|
|
///
|
|
/// # Arguments
|
|
/// * `model` - The causal model
|
|
/// * `treatment` - Treatment variable ID
|
|
/// * `outcome` - Outcome variable ID
|
|
/// * `treatment_value` - The treatment value
|
|
/// * `control_value` - The control value
|
|
pub fn causal_effect(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
causal_effect_at_values(
|
|
model,
|
|
treatment,
|
|
outcome,
|
|
treatment_value,
|
|
control_value,
|
|
)
|
|
}
|
|
|
|
/// Compute the Average Treatment Effect with default binary values
|
|
///
|
|
/// ATE = E[Y | do(X=1)] - E[Y | do(X=0)]
|
|
pub fn causal_effect_binary(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
) -> Result<f64, CounterfactualError> {
|
|
causal_effect_at_values(
|
|
model,
|
|
treatment,
|
|
outcome,
|
|
Value::Continuous(1.0),
|
|
Value::Continuous(0.0),
|
|
)
|
|
}
|
|
|
|
/// Compute causal effect at specific treatment values
|
|
pub fn causal_effect_at_values(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
// E[Y | do(X = treatment)]
|
|
let intervention_treat = Intervention::new(treatment, treatment_value);
|
|
let intervened_treat = model.intervene_with(&[intervention_treat])?;
|
|
let result_treat = intervened_treat.simulate(&HashMap::new())?;
|
|
let y_treat = result_treat.get(&outcome)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0);
|
|
|
|
// E[Y | do(X = control)]
|
|
let intervention_ctrl = Intervention::new(treatment, control_value);
|
|
let intervened_ctrl = model.intervene_with(&[intervention_ctrl])?;
|
|
let result_ctrl = intervened_ctrl.simulate(&HashMap::new())?;
|
|
let y_ctrl = result_ctrl.get(&outcome)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0);
|
|
|
|
Ok(y_treat - y_ctrl)
|
|
}
|
|
|
|
/// Compute ATE with full result structure
|
|
pub fn average_treatment_effect(
|
|
model: &CausalModel,
|
|
treatment_name: &str,
|
|
outcome_name: &str,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
) -> Result<AverageTreatmentEffect, CounterfactualError> {
|
|
let treatment_id = model.get_variable_id(treatment_name)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(treatment_name.to_string()))?;
|
|
let outcome_id = model.get_variable_id(outcome_name)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(outcome_name.to_string()))?;
|
|
|
|
let ate = causal_effect_at_values(
|
|
model,
|
|
treatment_id,
|
|
outcome_id,
|
|
treatment_value.clone(),
|
|
control_value.clone(),
|
|
)?;
|
|
|
|
Ok(AverageTreatmentEffect::new(
|
|
treatment_name,
|
|
outcome_name,
|
|
treatment_value,
|
|
control_value,
|
|
ate,
|
|
))
|
|
}
|
|
|
|
/// Compute Individual Treatment Effect (ITE) for a specific unit
|
|
///
|
|
/// ITE_i = Y_i(X=1) - Y_i(X=0)
|
|
///
|
|
/// This is a counterfactual quantity: what would have happened to unit i
|
|
/// under different treatment assignments.
|
|
pub fn individual_treatment_effect(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
unit_observation: &Observation,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
// Get the outcome variable name
|
|
let outcome_name = model.get_variable_name(&outcome)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
|
|
|
|
// Counterfactual: Y(X=treatment) for this unit
|
|
let y_treat = counterfactual(model, unit_observation, treatment, treatment_value, &outcome_name)?;
|
|
let y_treat_val = y_treat.as_f64();
|
|
|
|
// Counterfactual: Y(X=control) for this unit
|
|
let y_ctrl = counterfactual(model, unit_observation, treatment, control_value, &outcome_name)?;
|
|
let y_ctrl_val = y_ctrl.as_f64();
|
|
|
|
Ok(y_treat_val - y_ctrl_val)
|
|
}
|
|
|
|
/// Natural Direct Effect (NDE)
|
|
///
|
|
/// NDE = E[Y(x, M(x'))] - E[Y(x', M(x'))]
|
|
///
|
|
/// The effect of X on Y that would remain if the mediator were held at the
|
|
/// value it would have taken under X = x'.
|
|
pub fn natural_direct_effect(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
mediator: VariableId,
|
|
outcome: VariableId,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
// E[Y(x', M(x'))] - baseline
|
|
let ctrl_intervention = Intervention::new(treatment, control_value.clone());
|
|
let intervened = model.intervene_with(&[ctrl_intervention.clone()])?;
|
|
let baseline_result = intervened.simulate(&HashMap::new())?;
|
|
let m_ctrl = baseline_result.get(&mediator).cloned().unwrap_or(Value::Missing);
|
|
let y_baseline = baseline_result.get(&outcome)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0);
|
|
|
|
// E[Y(x, M(x'))] - intervene on X but keep M at control level
|
|
let treat_intervention = Intervention::new(treatment, treatment_value);
|
|
let m_intervention = Intervention::new(mediator, m_ctrl);
|
|
let intervened = model.intervene_with(&[treat_intervention, m_intervention])?;
|
|
let nde_result = intervened.simulate(&HashMap::new())?;
|
|
let y_nde = nde_result.get(&outcome)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0);
|
|
|
|
Ok(y_nde - y_baseline)
|
|
}
|
|
|
|
/// Natural Indirect Effect (NIE)
|
|
///
|
|
/// NIE = E[Y(x, M(x))] - E[Y(x, M(x'))]
|
|
///
|
|
/// The effect of X on Y that is mediated through M.
|
|
pub fn natural_indirect_effect(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
mediator: VariableId,
|
|
outcome: VariableId,
|
|
treatment_value: Value,
|
|
control_value: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
// E[Y(x, M(x))] - full treatment effect
|
|
let treat_intervention = Intervention::new(treatment, treatment_value.clone());
|
|
let intervened = model.intervene_with(&[treat_intervention.clone()])?;
|
|
let full_result = intervened.simulate(&HashMap::new())?;
|
|
let y_full = full_result.get(&outcome)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0);
|
|
|
|
// E[Y(x, M(x'))] - treatment but mediator at control level
|
|
let ctrl_intervention = Intervention::new(treatment, control_value);
|
|
let ctrl_intervened = model.intervene_with(&[ctrl_intervention])?;
|
|
let ctrl_result = ctrl_intervened.simulate(&HashMap::new())?;
|
|
let m_ctrl = ctrl_result.get(&mediator).cloned().unwrap_or(Value::Missing);
|
|
|
|
let m_intervention = Intervention::new(mediator, m_ctrl);
|
|
let intervened = model.intervene_with(&[treat_intervention, m_intervention])?;
|
|
let indirect_result = intervened.simulate(&HashMap::new())?;
|
|
let y_indirect = indirect_result.get(&outcome)
|
|
.map(|v| v.as_f64())
|
|
.unwrap_or(0.0);
|
|
|
|
Ok(y_full - y_indirect)
|
|
}
|
|
|
|
/// Probability of Necessity (PN)
|
|
///
|
|
/// PN = P(Y_x' = 0 | X = x, Y = 1)
|
|
///
|
|
/// Given that X=x and Y=1 occurred, what is the probability that Y would
|
|
/// have been 0 if X had been x' instead?
|
|
pub fn probability_of_necessity(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
observation: &Observation,
|
|
counterfactual_treatment: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
// Get outcome variable name
|
|
let outcome_name = model.get_variable_name(&outcome)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
|
|
|
|
// Compute counterfactual outcome
|
|
let cf_value = counterfactual(model, observation, treatment, counterfactual_treatment, &outcome_name)?;
|
|
|
|
let cf_outcome = cf_value.as_f64();
|
|
|
|
// PN is probability that outcome would be 0 (negative)
|
|
// For continuous outcomes, we check if it crosses the threshold
|
|
let observed_outcome = observation.values.iter()
|
|
.find_map(|(name, val)| {
|
|
model.get_variable_id(name)
|
|
.filter(|id| *id == outcome)
|
|
.map(|_| val.as_f64())
|
|
})
|
|
.unwrap_or(0.0);
|
|
|
|
// Simple heuristic: if counterfactual outcome is significantly different
|
|
if observed_outcome > 0.0 && cf_outcome <= 0.0 {
|
|
Ok(1.0) // Necessary
|
|
} else if (observed_outcome - cf_outcome).abs() < 1e-6 {
|
|
Ok(0.0) // Not necessary
|
|
} else {
|
|
Ok(0.5) // Uncertain
|
|
}
|
|
}
|
|
|
|
/// Probability of Sufficiency (PS)
|
|
///
|
|
/// PS = P(Y_x = 1 | X = x', Y = 0)
|
|
///
|
|
/// Given that X=x' and Y=0 occurred, what is the probability that Y would
|
|
/// have been 1 if X had been x instead?
|
|
pub fn probability_of_sufficiency(
|
|
model: &CausalModel,
|
|
treatment: VariableId,
|
|
outcome: VariableId,
|
|
observation: &Observation,
|
|
counterfactual_treatment: Value,
|
|
) -> Result<f64, CounterfactualError> {
|
|
// Get outcome variable name
|
|
let outcome_name = model.get_variable_name(&outcome)
|
|
.ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
|
|
|
|
// Compute counterfactual outcome
|
|
let cf_value = counterfactual(model, observation, treatment, counterfactual_treatment, &outcome_name)?;
|
|
|
|
let cf_outcome = cf_value.as_f64();
|
|
|
|
let observed_outcome = observation.values.iter()
|
|
.find_map(|(name, val)| {
|
|
model.get_variable_id(name)
|
|
.filter(|id| *id == outcome)
|
|
.map(|_| val.as_f64())
|
|
})
|
|
.unwrap_or(1.0);
|
|
|
|
// PS: would the outcome have been positive if treatment were different?
|
|
if observed_outcome <= 0.0 && cf_outcome > 0.0 {
|
|
Ok(1.0) // Sufficient
|
|
} else if (observed_outcome - cf_outcome).abs() < 1e-6 {
|
|
Ok(0.0) // Not sufficient
|
|
} else {
|
|
Ok(0.5) // Uncertain
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::model::{Mechanism, VariableType};

    /// Two-variable model with the structural equation Y = 2*X + 1.
    fn create_simple_model() -> CausalModel {
        let mut model = CausalModel::with_name("Simple");

        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // Y = 2*X + 1
        model
            .add_structural_equation(
                y,
                &[x],
                Mechanism::new(|p| Value::Continuous(2.0 * p[0].as_f64() + 1.0)),
            )
            .unwrap();

        model
    }

    /// Mediation model X -> M -> Y with a direct X -> Y edge:
    /// M = X and Y = M + 0.5*X.
    fn create_mediation_model() -> CausalModel {
        let mut model = CausalModel::with_name("Mediation");

        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("M", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // M = X
        model
            .add_structural_equation(m, &[x], Mechanism::new(|p| p[0].clone()))
            .unwrap();

        // Y = M + 0.5*X
        model
            .add_structural_equation(
                y,
                &[m, x],
                Mechanism::new(|p| Value::Continuous(p[0].as_f64() + 0.5 * p[1].as_f64())),
            )
            .unwrap();

        model
    }

    #[test]
    fn test_observation() {
        let mut obs = Observation::new(&[("X", Value::Continuous(5.0))]);
        obs.observe("Y", Value::Continuous(11.0));

        assert!(obs.is_observed("X"));
        assert!(obs.is_observed("Y"));
        assert!(!obs.is_observed("Z"));
    }

    #[test]
    fn test_counterfactual_simple() {
        let model = create_simple_model();

        let x_id = model.get_variable_id("X").unwrap();

        // Observation: X=3, Y=7 (since Y = 2*3 + 1)
        let observation = Observation::new(&[
            ("X", Value::Continuous(3.0)),
            ("Y", Value::Continuous(7.0)),
        ]);

        // Counterfactual: What would Y have been if X had been 5?
        let intervention = Intervention::new(x_id, Value::Continuous(5.0));

        // The Intervention-based entry point returns the full distribution,
        // which is what the lookup by id below needs.
        let result =
            counterfactual_with_intervention(&model, &observation, &intervention).unwrap();

        // Y should be 2*5 + 1 = 11
        let y_id = model.get_variable_id("Y").unwrap();
        let y_value = result.get(y_id).unwrap().as_f64();

        assert!((y_value - 11.0).abs() < 1e-10);
    }

    #[test]
    fn test_causal_effect() {
        let model = create_simple_model();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // ATE = E[Y|do(X=1)] - E[Y|do(X=0)]
        //     = (2*1 + 1) - (2*0 + 1) = 3 - 1 = 2
        let ate = causal_effect_binary(&model, x, y).unwrap();

        assert!((ate - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_average_treatment_effect() {
        let model = create_simple_model();

        let ate_result = average_treatment_effect(
            &model,
            "X",
            "Y",
            Value::Continuous(1.0),
            Value::Continuous(0.0),
        )
        .unwrap();

        assert_eq!(ate_result.treatment, "X");
        assert_eq!(ate_result.outcome, "Y");
        assert!((ate_result.ate - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_mediation_effects() {
        let model = create_mediation_model();

        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        let treat = Value::Continuous(1.0);
        let ctrl = Value::Continuous(0.0);

        // Total effect should be:
        // E[Y|do(X=1)] - E[Y|do(X=0)]
        // = (M(1) + 0.5*1) - (M(0) + 0.5*0)
        // = (1 + 0.5) - (0 + 0) = 1.5
        let total = causal_effect_at_values(&model, x, y, treat.clone(), ctrl.clone()).unwrap();
        assert!((total - 1.5).abs() < 1e-10);

        // NDE should be the direct effect = 0.5 (coefficient of X in Y equation)
        let nde = natural_direct_effect(&model, x, m, y, treat.clone(), ctrl.clone()).unwrap();
        assert!((nde - 0.5).abs() < 1e-10);

        // NIE should be the indirect effect = 1.0 (coefficient of M in Y, times effect of X on M)
        let nie = natural_indirect_effect(&model, x, m, y, treat, ctrl).unwrap();
        assert!((nie - 1.0).abs() < 1e-10);

        // NDE + NIE should equal total effect
        assert!((nde + nie - total).abs() < 1e-10);
    }

    #[test]
    fn test_counterfactual_query() {
        let model = create_simple_model();

        let query = CounterfactualQuery::new(
            "Y",
            vec![("X", Value::Continuous(10.0))],
            Observation::new(&[("X", Value::Continuous(3.0))]),
        );

        let result = counterfactual_query(&model, &query).unwrap();

        // Y = 2*10 + 1 = 21
        let y_id = model.get_variable_id("Y").unwrap();
        assert!((result.distribution.mean(y_id) - 21.0).abs() < 1e-10);
    }

    #[test]
    fn test_distribution() {
        let mut values = HashMap::new();
        let x_id = VariableId(0);
        let y_id = VariableId(1);

        values.insert(x_id, Value::Continuous(5.0));
        values.insert(y_id, Value::Continuous(10.0));

        let dist = CounterfactualDistribution::point_mass(values);

        assert_eq!(dist.mean(x_id), 5.0);
        assert_eq!(dist.mean(y_id), 10.0);
        assert_eq!(dist.probability, 1.0);
    }

    #[test]
    fn test_individual_treatment_effect() {
        let model = create_simple_model();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // Unit-specific observation
        let unit_obs = Observation::new(&[
            ("X", Value::Continuous(3.0)),
            ("Y", Value::Continuous(7.0)),
        ]);

        let ite = individual_treatment_effect(
            &model,
            x,
            y,
            &unit_obs,
            Value::Continuous(5.0),
            Value::Continuous(3.0),
        )
        .unwrap();

        // ITE = Y(X=5) - Y(X=3) = 11 - 7 = 4
        assert!((ite - 4.0).abs() < 1e-10);
    }
}
|