Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,820 @@
//! Causal Abstraction Layer
//!
//! This module implements causal abstraction theory, which formalizes the
//! relationship between detailed (low-level) and simplified (high-level)
//! causal models. The key insight is that a high-level model is a valid
//! abstraction if interventions on the low-level model can be "lifted" to
//! corresponding interventions on the high-level model while preserving
//! distributional semantics.
//!
//! ## Theory
//!
//! A causal abstraction consists of:
//! - A low-level model M_L with variables V_L
//! - A high-level model M_H with variables V_H
//! - A surjective mapping τ: V_L → V_H
//!
//! The abstraction is **consistent** if for all interventions I on M_H:
//! τ(M_L(τ^{-1}(I))) = M_H(I)
//!
//! ## References
//!
//! - Beckers & Halpern (2019): "Abstracting Causal Models"
//! - Rubenstein et al. (2017): "Causal Consistency of Structural Equation Models"
use std::collections::{HashMap, HashSet};
use thiserror::Error;
use super::model::{CausalModel, CausalModelError, Intervention, Value, VariableId, Distribution};
use super::counterfactual::CounterfactualDistribution;
/// Error types for abstraction operations.
///
/// Covers construction-time failures (non-surjective or invalid maps),
/// runtime consistency failures, and wrapped errors from the underlying
/// causal model.
#[derive(Debug, Clone, Error)]
pub enum AbstractionError {
    /// Abstraction map is not surjective: the named high-level variable has
    /// no low-level preimage under τ.
    #[error("Abstraction map is not surjective: high-level variable {0:?} has no preimage")]
    NotSurjective(VariableId),
    /// Abstraction is not consistent under intervention
    #[error("Abstraction is not consistent: intervention {0:?} yields different results")]
    InconsistentIntervention(String),
    /// Invalid variable mapping: the map references a low-level variable
    /// that the low-level model does not contain.
    #[error("Invalid mapping: low-level variable {0:?} not in model")]
    InvalidMapping(VariableId),
    /// Models have incompatible structure (e.g. a builder missing a model)
    #[error("Incompatible model structure: {0}")]
    IncompatibleStructure(String),
    /// Underlying model error, converted via `From`
    #[error("Model error: {0}")]
    ModelError(#[from] CausalModelError),
}
/// Mapping from low-level to high-level variables.
///
/// Stores both directions of the surjection τ: V_L → V_H plus, for each
/// high-level variable, the [`Aggregator`] used to collapse the values of
/// its preimage into a single high-level value.
///
/// Invariant maintained by `add_mapping`: `low_to_high` is the pointwise
/// inverse of `high_to_low` (assuming each low-level variable appears in at
/// most one mapping; a later mapping overwrites the reverse entry).
#[derive(Debug, Clone)]
pub struct AbstractionMap {
    /// Maps high-level variable to set of low-level variables (its preimage)
    high_to_low: HashMap<VariableId, HashSet<VariableId>>,
    /// Maps low-level variable to high-level variable (τ itself)
    low_to_high: HashMap<VariableId, VariableId>,
    /// Value aggregation functions (how to combine low-level values),
    /// keyed by high-level variable
    aggregators: HashMap<VariableId, Aggregator>,
}
/// How to aggregate low-level values into a high-level value
#[derive(Debug, Clone)]
pub enum Aggregator {
    /// Take first value (for 1-to-1 mappings)
    First,
    /// Sum of values
    Sum,
    /// Mean of values
    Mean,
    /// Max of values
    Max,
    /// Min of values
    Min,
    /// Majority vote (for discrete/binary)
    Majority,
    /// Weighted combination
    Weighted(Vec<f64>),
    /// Custom function (represented as string for debug)
    Custom(String),
}

impl Aggregator {
    /// Collapse a slice of low-level values into one high-level value.
    ///
    /// An empty input always yields `Value::Missing`. All arithmetic
    /// aggregators operate on the `f64` view of each value.
    pub fn apply(&self, values: &[Value]) -> Value {
        if values.is_empty() {
            return Value::Missing;
        }
        // Fresh numeric view of the inputs, shared by the arithmetic arms.
        let nums = || values.iter().map(|v| v.as_f64());
        match self {
            Aggregator::First => values[0].clone(),
            Aggregator::Sum => Value::Continuous(nums().sum()),
            // A `Custom` aggregator carries no executable body here, so it
            // falls back to the mean — identical to the `Mean` arm.
            Aggregator::Mean | Aggregator::Custom(_) => {
                let total: f64 = nums().sum();
                Value::Continuous(total / values.len() as f64)
            }
            Aggregator::Max => Value::Continuous(nums().fold(f64::NEG_INFINITY, f64::max)),
            Aggregator::Min => Value::Continuous(nums().fold(f64::INFINITY, f64::min)),
            Aggregator::Majority => {
                // Bucket by the integer cast of each value and pick the
                // most frequent bucket (ties broken arbitrarily).
                let mut tally: HashMap<i64, usize> = HashMap::new();
                for n in nums() {
                    *tally.entry(n as i64).or_default() += 1;
                }
                let winner = tally
                    .into_iter()
                    .max_by_key(|&(_, votes)| votes)
                    .map(|(val, _)| val)
                    .unwrap_or(0);
                Value::Discrete(winner)
            }
            Aggregator::Weighted(weights) => {
                // Dot product; zip truncates to the shorter of the two lists.
                let dot: f64 = nums().zip(weights.iter()).map(|(v, w)| v * w).sum();
                Value::Continuous(dot)
            }
        }
    }
}
impl AbstractionMap {
    /// Create a new empty abstraction map
    pub fn new() -> Self {
        Self {
            high_to_low: HashMap::new(),
            low_to_high: HashMap::new(),
            aggregators: HashMap::new(),
        }
    }

    /// Add a mapping from one high-level variable to a set of low-level
    /// variables, with the aggregator used to combine their values.
    pub fn add_mapping(
        &mut self,
        high: VariableId,
        low_vars: HashSet<VariableId>,
        aggregator: Aggregator,
    ) {
        // Record the reverse direction for every member of the preimage.
        self.low_to_high
            .extend(low_vars.iter().map(|&low| (low, high)));
        self.aggregators.insert(high, aggregator);
        self.high_to_low.insert(high, low_vars);
    }

    /// Add a 1-to-1 mapping (singleton preimage, `First` aggregator).
    pub fn add_identity_mapping(&mut self, high: VariableId, low: VariableId) {
        self.add_mapping(high, std::iter::once(low).collect(), Aggregator::First);
    }

    /// Get the high-level variable for a low-level variable
    pub fn lift_variable(&self, low: VariableId) -> Option<VariableId> {
        self.low_to_high.get(&low).copied()
    }

    /// Get the low-level variables for a high-level variable
    pub fn project_variable(&self, high: VariableId) -> Option<&HashSet<VariableId>> {
        self.high_to_low.get(&high)
    }

    /// Lift a value from low-level to high-level by aggregating whatever
    /// preimage values are present in `low_values`. Unmapped variables and
    /// empty preimages yield `Value::Missing`.
    pub fn lift_value(&self, high: VariableId, low_values: &HashMap<VariableId, Value>) -> Value {
        match self.high_to_low.get(&high) {
            None => Value::Missing,
            Some(low_vars) => {
                let collected: Vec<Value> = low_vars
                    .iter()
                    .filter_map(|v| low_values.get(v).cloned())
                    .collect();
                self.aggregators
                    .get(&high)
                    .unwrap_or(&Aggregator::First)
                    .apply(&collected)
            }
        }
    }

    /// Check if the mapping is surjective (every high-level var has a preimage)
    pub fn is_surjective(&self, high_level: &CausalModel) -> bool {
        high_level
            .variables()
            .all(|var| self.high_to_low.contains_key(&var.id))
    }
}
impl Default for AbstractionMap {
    /// Equivalent to [`AbstractionMap::new`]: an empty mapping.
    fn default() -> Self {
        Self::new()
    }
}
/// Result of consistency checking.
///
/// Produced by `CausalAbstraction::check_consistency`; `is_consistent`
/// holds exactly when `violations` is empty.
#[derive(Debug, Clone)]
pub struct ConsistencyResult {
    /// Whether the abstraction is consistent (no violations found)
    pub is_consistent: bool,
    /// Violations found (if any)
    pub violations: Vec<ConsistencyViolation>,
    /// Total number of interventions tested
    pub interventions_tested: usize,
    /// Maximum divergence observed across all violations (0.0 if none)
    pub max_divergence: f64,
}
/// A violation of causal abstraction consistency
#[derive(Debug, Clone)]
pub struct ConsistencyViolation {
    /// Human-readable description of the intervention that caused the violation
    pub intervention: String,
    /// Expected high-level outcome, keyed by variable name
    pub expected: HashMap<String, f64>,
    /// Actual (projected from low-level) outcome, keyed by variable name
    pub actual: HashMap<String, f64>,
    /// Divergence measure (L2 distance between expected and actual)
    pub divergence: f64,
}
/// Causal Abstraction between two causal models.
///
/// Borrows both models for lifetime `'a`. Construct via
/// [`CausalAbstraction::new`] (which validates the map) or via
/// [`AbstractionBuilder`].
pub struct CausalAbstraction<'a> {
    /// The low-level (detailed) model
    pub low_level: &'a CausalModel,
    /// The high-level (abstract) model
    pub high_level: &'a CausalModel,
    /// The abstraction mapping τ between the two
    pub abstraction_map: AbstractionMap,
    /// Tolerance for numerical consistency checks (default 1e-6)
    pub tolerance: f64,
}
impl<'a> CausalAbstraction<'a> {
/// Create a new causal abstraction
pub fn new(
low_level: &'a CausalModel,
high_level: &'a CausalModel,
abstraction_map: AbstractionMap,
) -> Result<Self, AbstractionError> {
let abstraction = Self {
low_level,
high_level,
abstraction_map,
tolerance: 1e-6,
};
abstraction.validate_structure()?;
Ok(abstraction)
}
/// Set tolerance for numerical comparisons
pub fn with_tolerance(mut self, tol: f64) -> Self {
self.tolerance = tol;
self
}
/// Validate that the abstraction structure is valid
fn validate_structure(&self) -> Result<(), AbstractionError> {
// Check surjectivity
for var in self.high_level.variables() {
if self.abstraction_map.high_to_low.get(&var.id).is_none() {
return Err(AbstractionError::NotSurjective(var.id));
}
}
// Check that all low-level variables in the map exist
for low_vars in self.abstraction_map.high_to_low.values() {
for &low_var in low_vars {
if self.low_level.get_variable(&low_var).is_none() {
return Err(AbstractionError::InvalidMapping(low_var));
}
}
}
Ok(())
}
/// Check if the abstraction is consistent under a set of interventions
pub fn is_consistent(&self) -> bool {
self.check_consistency().is_consistent
}
/// Perform detailed consistency check
pub fn check_consistency(&self) -> ConsistencyResult {
let mut violations = Vec::new();
let mut max_divergence = 0.0;
let mut interventions_tested = 0;
// Test consistency for single-variable interventions on high-level model
for high_var in self.high_level.variables() {
// Test a few intervention values
for intervention_value in [0.0, 1.0, -1.0, 0.5] {
interventions_tested += 1;
let high_intervention = Intervention::new(
high_var.id,
Value::Continuous(intervention_value),
);
// Check consistency for this intervention
if let Some(violation) = self.check_single_intervention(&high_intervention) {
max_divergence = max_divergence.max(violation.divergence);
violations.push(violation);
}
}
}
ConsistencyResult {
is_consistent: violations.is_empty(),
violations,
interventions_tested,
max_divergence,
}
}
/// Check consistency for a single intervention
fn check_single_intervention(&self, high_intervention: &Intervention) -> Option<ConsistencyViolation> {
// Lift the intervention to low-level
let low_interventions = self.lift_intervention(high_intervention);
// Simulate high-level model with intervention
let high_result = self.high_level.intervene(&[high_intervention.clone()]);
let high_values = match high_result {
Ok(model) => model.simulate(&HashMap::new()).ok(),
Err(_) => None,
};
// Simulate low-level model with lifted interventions
let low_result = self.low_level.intervene(&low_interventions);
let low_values = match low_result {
Ok(model) => model.simulate(&HashMap::new()).ok(),
Err(_) => None,
};
// Project low-level results to high-level
let (high_values, low_values) = match (high_values, low_values) {
(Some(h), Some(l)) => (h, l),
_ => return None, // Can't compare if simulation failed
};
let projected = self.project_distribution(&low_values);
// Compare high-level result with projected result
let mut divergence = 0.0;
let mut expected = HashMap::new();
let mut actual = HashMap::new();
for high_var in self.high_level.variables() {
let high_val = high_values.get(&high_var.id)
.map(|v| v.as_f64())
.unwrap_or(0.0);
let proj_val = projected.get(&high_var.id)
.map(|v| v.as_f64())
.unwrap_or(0.0);
let diff = (high_val - proj_val).abs();
divergence += diff * diff;
expected.insert(high_var.name.clone(), high_val);
actual.insert(high_var.name.clone(), proj_val);
}
divergence = divergence.sqrt();
if divergence > self.tolerance {
Some(ConsistencyViolation {
intervention: format!("do({:?} = {:?})", high_intervention.target, high_intervention.value),
expected,
actual,
divergence,
})
} else {
None
}
}
/// Lift a high-level intervention to low-level interventions
pub fn lift_intervention(&self, high: &Intervention) -> Vec<Intervention> {
let low_vars = match self.abstraction_map.project_variable(high.target) {
Some(vars) => vars,
None => return vec![],
};
// Simple strategy: apply same value to all corresponding low-level variables
// More sophisticated approaches could distribute the intervention differently
low_vars.iter()
.map(|&low_var| Intervention::new(low_var, high.value.clone()))
.collect()
}
/// Project a low-level distribution to high-level
pub fn project_distribution(&self, low_dist: &HashMap<VariableId, Value>) -> HashMap<VariableId, Value> {
let mut high_dist = HashMap::new();
for high_var in self.high_level.variables() {
let projected_value = self.abstraction_map.lift_value(high_var.id, low_dist);
high_dist.insert(high_var.id, projected_value);
}
high_dist
}
/// Project a CounterfactualDistribution object
pub fn project_distribution_obj(&self, low_dist: &CounterfactualDistribution) -> CounterfactualDistribution {
let high_values = self.project_distribution(&low_dist.values);
CounterfactualDistribution {
values: high_values,
probability: low_dist.probability,
}
}
/// Get the coarsening factor (how much the abstraction simplifies)
pub fn coarsening_factor(&self) -> f64 {
let low_count = self.low_level.num_variables() as f64;
let high_count = self.high_level.num_variables() as f64;
if high_count > 0.0 {
low_count / high_count
} else {
f64::INFINITY
}
}
/// Check if a low-level variable is "hidden" (not directly represented in high-level)
pub fn is_hidden(&self, low_var: VariableId) -> bool {
self.abstraction_map.lift_variable(low_var).is_none()
}
/// Get all hidden variables
pub fn hidden_variables(&self) -> Vec<VariableId> {
self.low_level.variables()
.filter(|v| self.is_hidden(v.id))
.map(|v| v.id)
.collect()
}
}
/// Builder for creating causal abstractions.
///
/// Both models must be set *before* the mapping calls, since variable names
/// are resolved against them eagerly.
pub struct AbstractionBuilder<'a> {
    // Low-level (detailed) model; required before `build`.
    low_level: Option<&'a CausalModel>,
    // High-level (abstract) model; required before `build`.
    high_level: Option<&'a CausalModel>,
    // Accumulated variable mapping.
    map: AbstractionMap,
}
impl<'a> AbstractionBuilder<'a> {
    /// Create a new builder with no models and an empty map.
    pub fn new() -> Self {
        Self {
            low_level: None,
            high_level: None,
            map: AbstractionMap::new(),
        }
    }

    /// Set the low-level model
    pub fn low_level(mut self, model: &'a CausalModel) -> Self {
        self.low_level = Some(model);
        self
    }

    /// Set the high-level model
    pub fn high_level(mut self, model: &'a CausalModel) -> Self {
        self.high_level = Some(model);
        self
    }

    /// Add a variable mapping by name.
    ///
    /// Best-effort: silently ignored unless both models are set, `high_name`
    /// resolves in the high-level model, and at least one of `low_names`
    /// resolves in the low-level model.
    pub fn map_variable(
        mut self,
        high_name: &str,
        low_names: &[&str],
        aggregator: Aggregator,
    ) -> Self {
        let resolved = self
            .low_level
            .zip(self.high_level)
            .and_then(|(low, high)| {
                let high_id = high.get_variable_id(high_name)?;
                let low_ids: HashSet<_> = low_names
                    .iter()
                    .filter_map(|&name| low.get_variable_id(name))
                    .collect();
                (!low_ids.is_empty()).then(|| (high_id, low_ids))
            });
        if let Some((high_id, low_ids)) = resolved {
            self.map.add_mapping(high_id, low_ids, aggregator);
        }
        self
    }

    /// Add an identity (1-to-1) mapping by name; best-effort like
    /// `map_variable`.
    pub fn map_identity(mut self, high_name: &str, low_name: &str) -> Self {
        let pair = self.low_level.zip(self.high_level).and_then(|(low, high)| {
            high.get_variable_id(high_name)
                .zip(low.get_variable_id(low_name))
        });
        if let Some((high_id, low_id)) = pair {
            self.map.add_identity_mapping(high_id, low_id);
        }
        self
    }

    /// Build the abstraction; fails if either model is missing or the map
    /// fails validation.
    pub fn build(self) -> Result<CausalAbstraction<'a>, AbstractionError> {
        let (low, high) = match (self.low_level, self.high_level) {
            (Some(l), Some(h)) => (l, h),
            (None, _) => {
                return Err(AbstractionError::IncompatibleStructure(
                    "No low-level model provided".to_string(),
                ))
            }
            (_, None) => {
                return Err(AbstractionError::IncompatibleStructure(
                    "No high-level model provided".to_string(),
                ))
            }
        };
        CausalAbstraction::new(low, high, self.map)
    }
}
impl<'a> Default for AbstractionBuilder<'a> {
    /// Equivalent to [`AbstractionBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::model::{CausalModelBuilder, Mechanism, VariableType};

    /// Detailed five-variable fixture:
    /// (Age, Education) -> Experience -> Salary -> Savings.
    fn create_low_level_model() -> CausalModel {
        let mut model = CausalModel::with_name("Low-Level");
        // Detailed model with separate variables
        model.add_variable("Age", VariableType::Continuous).unwrap();
        model.add_variable("Education", VariableType::Continuous).unwrap();
        model.add_variable("Experience", VariableType::Continuous).unwrap();
        model.add_variable("Salary", VariableType::Continuous).unwrap();
        model.add_variable("Savings", VariableType::Continuous).unwrap();
        let age = model.get_variable_id("Age").unwrap();
        let edu = model.get_variable_id("Education").unwrap();
        let exp = model.get_variable_id("Experience").unwrap();
        let salary = model.get_variable_id("Salary").unwrap();
        let savings = model.get_variable_id("Savings").unwrap();
        // Experience = f(Age, Education)
        model.add_structural_equation(exp, &[age, edu], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 0.5 + p[1].as_f64() * 0.3)
        })).unwrap();
        // Salary = f(Education, Experience)
        model.add_structural_equation(salary, &[edu, exp], Mechanism::new(|p| {
            Value::Continuous(30000.0 + p[0].as_f64() * 5000.0 + p[1].as_f64() * 2000.0)
        })).unwrap();
        // Savings = f(Salary)
        model.add_structural_equation(savings, &[salary], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 0.2)
        })).unwrap();
        model
    }

    /// Abstract two-variable fixture: HumanCapital -> Wealth.
    fn create_high_level_model() -> CausalModel {
        let mut model = CausalModel::with_name("High-Level");
        // Simplified model with aggregated variables
        model.add_variable("HumanCapital", VariableType::Continuous).unwrap();
        model.add_variable("Wealth", VariableType::Continuous).unwrap();
        let hc = model.get_variable_id("HumanCapital").unwrap();
        let wealth = model.get_variable_id("Wealth").unwrap();
        // Wealth = f(HumanCapital)
        model.add_structural_equation(wealth, &[hc], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 10000.0)
        })).unwrap();
        model
    }

    // A map that covers both high-level variables should be surjective.
    #[test]
    fn test_abstraction_map() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let mut map = AbstractionMap::new();
        // HumanCapital = mean(Education, Experience)
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        let exp_id = low.get_variable_id("Experience").unwrap();
        let mut low_vars = HashSet::new();
        low_vars.insert(edu_id);
        low_vars.insert(exp_id);
        map.add_mapping(hc_id, low_vars, Aggregator::Mean);
        // Wealth = sum(Salary, Savings)
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        let savings_id = low.get_variable_id("Savings").unwrap();
        let mut wealth_vars = HashSet::new();
        wealth_vars.insert(salary_id);
        wealth_vars.insert(savings_id);
        map.add_mapping(wealth_id, wealth_vars, Aggregator::Sum);
        assert!(map.is_surjective(&high));
    }

    // Exercise the arithmetic aggregators over [1, 2, 3].
    #[test]
    fn test_aggregators() {
        let values = vec![
            Value::Continuous(1.0),
            Value::Continuous(2.0),
            Value::Continuous(3.0),
        ];
        assert_eq!(Aggregator::First.apply(&values).as_f64(), 1.0);
        assert_eq!(Aggregator::Sum.apply(&values).as_f64(), 6.0);
        assert_eq!(Aggregator::Mean.apply(&values).as_f64(), 2.0);
        assert_eq!(Aggregator::Max.apply(&values).as_f64(), 3.0);
        assert_eq!(Aggregator::Min.apply(&values).as_f64(), 1.0);
    }

    // Lifting an intervention on a 2-variable preimage yields one low-level
    // intervention per preimage member.
    #[test]
    fn test_lift_intervention() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let mut map = AbstractionMap::new();
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        let exp_id = low.get_variable_id("Experience").unwrap();
        let mut low_vars = HashSet::new();
        low_vars.insert(edu_id);
        low_vars.insert(exp_id);
        map.add_mapping(hc_id, low_vars, Aggregator::Mean);
        // Add wealth mapping
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        let mut wealth_vars = HashSet::new();
        wealth_vars.insert(salary_id);
        map.add_mapping(wealth_id, wealth_vars, Aggregator::First);
        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();
        let high_intervention = Intervention::new(hc_id, Value::Continuous(10.0));
        let low_interventions = abstraction.lift_intervention(&high_intervention);
        // Should lift to interventions on Education and Experience
        assert_eq!(low_interventions.len(), 2);
        assert!(low_interventions.iter().any(|i| i.target == edu_id));
        assert!(low_interventions.iter().any(|i| i.target == exp_id));
    }

    // Projection with `First` aggregators passes values through unchanged.
    #[test]
    fn test_project_distribution() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let mut map = AbstractionMap::new();
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        let mut low_vars = HashSet::new();
        low_vars.insert(edu_id);
        map.add_mapping(hc_id, low_vars, Aggregator::First);
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        let mut wealth_vars = HashSet::new();
        wealth_vars.insert(salary_id);
        map.add_mapping(wealth_id, wealth_vars, Aggregator::First);
        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();
        let mut low_dist = HashMap::new();
        low_dist.insert(edu_id, Value::Continuous(16.0));
        low_dist.insert(salary_id, Value::Continuous(80000.0));
        let high_dist = abstraction.project_distribution(&low_dist);
        assert_eq!(high_dist.get(&hc_id).unwrap().as_f64(), 16.0);
        assert_eq!(high_dist.get(&wealth_id).unwrap().as_f64(), 80000.0);
    }

    #[test]
    fn test_coarsening_factor() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let mut map = AbstractionMap::new();
        // Simple identity mappings for this test
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        map.add_identity_mapping(hc_id, edu_id);
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        map.add_identity_mapping(wealth_id, salary_id);
        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();
        // 5 low-level vars / 2 high-level vars = 2.5
        assert!((abstraction.coarsening_factor() - 2.5).abs() < 1e-10);
    }

    // Variables never mentioned in the map are reported as hidden.
    #[test]
    fn test_hidden_variables() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let mut map = AbstractionMap::new();
        // Only map Education to HumanCapital
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        map.add_identity_mapping(hc_id, edu_id);
        // Only map Salary to Wealth
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        map.add_identity_mapping(wealth_id, salary_id);
        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();
        let hidden = abstraction.hidden_variables();
        // Age, Experience, Savings should be hidden
        let age_id = low.get_variable_id("Age").unwrap();
        let exp_id = low.get_variable_id("Experience").unwrap();
        let savings_id = low.get_variable_id("Savings").unwrap();
        assert!(hidden.contains(&age_id));
        assert!(hidden.contains(&exp_id));
        assert!(hidden.contains(&savings_id));
        assert!(!hidden.contains(&edu_id));
        assert!(!hidden.contains(&salary_id));
    }

    #[test]
    fn test_builder() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let abstraction = AbstractionBuilder::new()
            .low_level(&low)
            .high_level(&high)
            .map_identity("HumanCapital", "Education")
            .map_identity("Wealth", "Salary")
            .build()
            .unwrap();
        assert_eq!(abstraction.coarsening_factor(), 2.5);
    }

    #[test]
    fn test_consistency_check() {
        let low = create_low_level_model();
        let high = create_high_level_model();
        let mut map = AbstractionMap::new();
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        map.add_identity_mapping(hc_id, edu_id);
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        map.add_identity_mapping(wealth_id, salary_id);
        let abstraction = CausalAbstraction::new(&low, &high, map)
            .unwrap()
            .with_tolerance(1000.0); // High tolerance for this test
        let result = abstraction.check_consistency();
        // The abstraction may or may not be consistent depending on mechanisms
        // This test just verifies the check runs
        assert!(result.interventions_tested > 0);
    }
}

View File

@@ -0,0 +1,973 @@
//! Causal Coherence Checking
//!
//! This module provides tools for verifying that beliefs and data are
//! consistent with a causal model. Key capabilities:
//!
//! - Detecting spurious correlations (associations not explained by causation)
//! - Checking if beliefs satisfy causal constraints
//! - Answering causal queries using do-calculus
//! - Computing coherence energy for integration with Prime-Radiant
use std::collections::{HashMap, HashSet};
use thiserror::Error;
use super::model::{CausalModel, CausalModelError, Value, VariableId, VariableType, Mechanism};
use super::graph::DAGValidationError;
/// Error types for coherence operations
#[derive(Debug, Clone, Error)]
pub enum CoherenceError {
    /// Wrapped error from the underlying causal model
    #[error("Model error: {0}")]
    ModelError(#[from] CausalModelError),
    /// Wrapped error from DAG validation
    #[error("Graph error: {0}")]
    GraphError(#[from] DAGValidationError),
    /// A belief contradicts the causal model
    #[error("Inconsistent belief: {0}")]
    InconsistentBelief(String),
    /// A causal query is malformed or unanswerable
    #[error("Invalid query: {0}")]
    InvalidQuery(String),
}
/// A belief about the relationship between variables.
///
/// Beliefs are stated over variable *names* and resolved against a causal
/// model at check time.
#[derive(Debug, Clone)]
pub struct Belief {
    /// Subject variable (the "X" of the belief)
    pub subject: String,
    /// Object variable (the "Y" of the belief)
    pub object: String,
    /// Type of belief
    pub belief_type: BeliefType,
    /// Confidence in the belief (0.0 to 1.0)
    pub confidence: f64,
    /// Free-text evidence supporting the belief, if any
    pub evidence: Option<String>,
}
/// Types of causal beliefs, from directed causation to (conditional)
/// independence claims.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BeliefType {
    /// X causes Y
    Causes,
    /// X is correlated with Y (may or may not be causal)
    CorrelatedWith,
    /// X is independent of Y
    IndependentOf,
    /// X is independent of Y given Z (the conditioning set, by name)
    ConditionallyIndependent { given: Vec<String> },
    /// X and Y have a common cause
    CommonCause,
    /// Changing X would change Y (interventional)
    WouldChange,
}
impl Belief {
/// Create a causal belief: X causes Y
pub fn causes(x: &str, y: &str) -> Self {
Self {
subject: x.to_string(),
object: y.to_string(),
belief_type: BeliefType::Causes,
confidence: 1.0,
evidence: None,
}
}
/// Create a correlation belief
pub fn correlated(x: &str, y: &str) -> Self {
Self {
subject: x.to_string(),
object: y.to_string(),
belief_type: BeliefType::CorrelatedWith,
confidence: 1.0,
evidence: None,
}
}
/// Create an independence belief
pub fn independent(x: &str, y: &str) -> Self {
Self {
subject: x.to_string(),
object: y.to_string(),
belief_type: BeliefType::IndependentOf,
confidence: 1.0,
evidence: None,
}
}
/// Create a conditional independence belief
pub fn conditionally_independent(x: &str, y: &str, given: &[&str]) -> Self {
Self {
subject: x.to_string(),
object: y.to_string(),
belief_type: BeliefType::ConditionallyIndependent {
given: given.iter().map(|s| s.to_string()).collect(),
},
confidence: 1.0,
evidence: None,
}
}
/// Set confidence level
pub fn with_confidence(mut self, confidence: f64) -> Self {
self.confidence = confidence.clamp(0.0, 1.0);
self
}
/// Set evidence
pub fn with_evidence(mut self, evidence: &str) -> Self {
self.evidence = Some(evidence.to_string());
self
}
}
/// Result of causal consistency checking.
///
/// `score` is the fraction of checked beliefs that were consistent
/// (1.0 when no beliefs were checked).
#[derive(Debug, Clone)]
pub struct CausalConsistency {
    /// Overall consistency score (0.0 to 1.0)
    pub score: f64,
    /// Number of beliefs checked
    pub beliefs_checked: usize,
    /// Number of consistent beliefs
    pub consistent_beliefs: usize,
    /// Number of inconsistent beliefs
    pub inconsistent_beliefs: usize,
    /// Details of inconsistencies
    pub inconsistencies: Vec<Inconsistency>,
    /// Suggested model modifications
    pub suggestions: Vec<String>,
}
impl CausalConsistency {
    /// Build the result for a check in which every belief passed.
    pub fn fully_consistent(beliefs_checked: usize) -> Self {
        Self {
            score: 1.0,
            beliefs_checked,
            consistent_beliefs: beliefs_checked,
            inconsistent_beliefs: 0,
            inconsistencies: Vec::new(),
            suggestions: Vec::new(),
        }
    }

    /// True when the score equals 1.0 up to floating-point noise.
    pub fn is_consistent(&self) -> bool {
        self.score >= 1.0 - 1e-10
    }
}
/// Details of a causal inconsistency
#[derive(Debug, Clone)]
pub struct Inconsistency {
    /// The belief that is inconsistent with the model
    pub belief: Belief,
    /// Human-readable reason why it's inconsistent
    pub reason: String,
    /// Severity (0.0 to 1.0)
    pub severity: f64,
}
/// A detected spurious correlation: an association between two variables
/// that is explained by one or more common causes rather than a direct
/// causal link.
#[derive(Debug, Clone)]
pub struct SpuriousCorrelation {
    /// First variable (by name)
    pub var_a: String,
    /// Second variable (by name)
    pub var_b: String,
    /// The common cause(s) explaining the correlation
    pub confounders: Vec<String>,
    /// Strength of the spurious correlation
    pub strength: f64,
    /// Human-readable explanation
    pub explanation: String,
}
/// A causal query over a model, combining a target variable with
/// interventions (do-operations) and observational conditions.
#[derive(Debug, Clone)]
pub struct CausalQuery {
    /// The variable we're asking about (by name)
    pub target: String,
    /// Variables we're intervening on, as (name, value) pairs
    pub interventions: Vec<(String, Value)>,
    /// Variables we're conditioning on, as (name, value) pairs
    pub conditions: Vec<(String, Value)>,
    /// Query type (interventional, observational, ...)
    pub query_type: QueryType,
}
/// Types of causal queries, in Pearl's notation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryType {
    /// P(Y | do(X=x)) - interventional query
    Interventional,
    /// P(Y | X=x) - observational query
    Observational,
    /// P(Y_x | X=x') - counterfactual query
    Counterfactual,
    /// P(Y | do(X=x), Z=z) - conditional interventional
    ConditionalInterventional,
}
impl CausalQuery {
    /// Create an interventional query: P(target | do(intervention))
    pub fn interventional(target: &str, intervention_var: &str, intervention_val: Value) -> Self {
        Self {
            query_type: QueryType::Interventional,
            target: target.to_string(),
            interventions: vec![(intervention_var.to_string(), intervention_val)],
            conditions: Vec::new(),
        }
    }

    /// Create an observational query: P(target | condition)
    pub fn observational(target: &str, condition_var: &str, condition_val: Value) -> Self {
        Self {
            query_type: QueryType::Observational,
            target: target.to_string(),
            interventions: Vec::new(),
            conditions: vec![(condition_var.to_string(), condition_val)],
        }
    }

    /// Append an additional condition (builder-style).
    pub fn given(mut self, var: &str, val: Value) -> Self {
        self.conditions.push((var.to_string(), val));
        self
    }
}
/// Answer to a causal query
#[derive(Debug, Clone)]
pub struct CausalAnswer {
    /// The query that was answered
    pub query: CausalQuery,
    /// The estimated value/distribution
    pub estimate: Value,
    /// Confidence interval (if applicable)
    pub confidence_interval: Option<(f64, f64)>,
    /// Whether the query is identifiable from observational data alone
    pub is_identifiable: bool,
    /// Human-readable explanation of the answer
    pub explanation: String,
}
/// Combined coherence energy for integration with Prime-Radiant.
///
/// Energy is additive over three sources of incoherence; zero total energy
/// means a perfectly coherent system.
#[derive(Debug, Clone)]
pub struct CoherenceEnergy {
    /// Total energy (lower is more coherent)
    pub total: f64,
    /// Structural component (from sheaf consistency)
    pub structural_component: f64,
    /// Causal component (from causal consistency)
    pub causal_component: f64,
    /// Intervention component (from intervention consistency)
    pub intervention_component: f64,
    /// Whether the system is coherent (energy below threshold)
    pub is_coherent: bool,
}

impl CoherenceEnergy {
    /// Zero-energy, fully coherent state.
    pub fn coherent() -> Self {
        Self {
            total: 0.0,
            structural_component: 0.0,
            causal_component: 0.0,
            intervention_component: 0.0,
            is_coherent: true,
        }
    }

    /// Combine the three components; the system counts as coherent when the
    /// summed energy falls below 1e-6.
    pub fn from_components(structural: f64, causal: f64, intervention: f64) -> Self {
        let total = structural + causal + intervention;
        Self {
            structural_component: structural,
            causal_component: causal,
            intervention_component: intervention,
            is_coherent: total < 1e-6,
            total,
        }
    }
}
/// Dataset for spurious correlation detection.
///
/// A minimal in-memory table: named columns over `f64` rows.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Column names
    pub columns: Vec<String>,
    /// Data rows (each row is a vector of values, one per column)
    pub rows: Vec<Vec<f64>>,
}

impl Dataset {
    /// Create an empty dataset with the given column names.
    pub fn new(columns: Vec<String>) -> Self {
        Self {
            columns,
            rows: Vec::new(),
        }
    }

    /// Append a row. Rows whose width does not match the column count are
    /// silently discarded (best-effort ingestion).
    pub fn add_row(&mut self, row: Vec<f64>) {
        if row.len() == self.columns.len() {
            self.rows.push(row);
        }
    }

    /// Index of a named column, if present.
    pub fn column_index(&self, name: &str) -> Option<usize> {
        self.columns.iter().position(|c| c == name)
    }

    /// Materialize a named column as a vector of values.
    pub fn column(&self, name: &str) -> Option<Vec<f64>> {
        let idx = self.column_index(name)?;
        Some(self.rows.iter().map(|row| row[idx]).collect())
    }

    /// Pearson correlation between two named columns.
    ///
    /// Returns `None` when either column is missing or there is no data,
    /// and `Some(0.0)` when either column is numerically constant.
    pub fn correlation(&self, col_a: &str, col_b: &str) -> Option<f64> {
        let a = self.column(col_a)?;
        let b = self.column(col_b)?;
        if a.len() != b.len() || a.is_empty() {
            return None;
        }
        let n = a.len() as f64;
        let mean_a: f64 = a.iter().sum::<f64>() / n;
        let mean_b: f64 = b.iter().sum::<f64>() / n;
        // One pass over centered pairs: covariance plus both variances.
        let (mut cov, mut var_a, mut var_b) = (0.0, 0.0, 0.0);
        for (&x, &y) in a.iter().zip(b.iter()) {
            let (dx, dy) = (x - mean_a, y - mean_b);
            cov += dx * dy;
            var_a += dx * dx;
            var_b += dy * dy;
        }
        let denom = (var_a * var_b).sqrt();
        if denom < 1e-10 {
            Some(0.0)
        } else {
            Some(cov / denom)
        }
    }
}
/// Causal coherence checker.
///
/// Borrows a causal model and verifies beliefs and data against it.
pub struct CausalCoherenceChecker<'a> {
    /// The causal model to check against
    model: &'a CausalModel,
    /// Correlation threshold for "significant" correlation (default 0.1)
    correlation_threshold: f64,
}
impl<'a> CausalCoherenceChecker<'a> {
/// Create a new checker
pub fn new(model: &'a CausalModel) -> Self {
Self {
model,
correlation_threshold: 0.1,
}
}
    /// Set correlation threshold (builder style). Column pairs whose
    /// absolute Pearson correlation exceeds this value are treated as
    /// significantly correlated by `detect_spurious_correlations`.
    pub fn with_correlation_threshold(mut self, threshold: f64) -> Self {
        self.correlation_threshold = threshold;
        self
    }
/// Check if a set of beliefs is consistent with the causal model
pub fn check_causal_consistency(&self, beliefs: &[Belief]) -> CausalConsistency {
let mut consistent_count = 0;
let mut inconsistencies = Vec::new();
let mut suggestions = Vec::new();
for belief in beliefs {
match self.check_single_belief(belief) {
Ok(()) => consistent_count += 1,
Err(reason) => {
inconsistencies.push(Inconsistency {
belief: belief.clone(),
reason: reason.clone(),
severity: 1.0 - belief.confidence,
});
// Generate suggestion
if let Some(suggestion) = self.generate_suggestion(belief, &reason) {
suggestions.push(suggestion);
}
}
}
}
let beliefs_checked = beliefs.len();
let score = if beliefs_checked > 0 {
consistent_count as f64 / beliefs_checked as f64
} else {
1.0
};
CausalConsistency {
score,
beliefs_checked,
consistent_beliefs: consistent_count,
inconsistent_beliefs: beliefs_checked - consistent_count,
inconsistencies,
suggestions,
}
}
    /// Check a single belief against the model
    ///
    /// Returns `Ok(())` when the graph structure supports the belief, or
    /// `Err(reason)` with a human-readable explanation otherwise.
    ///
    /// Only structural checks are performed: `CorrelatedWith` and
    /// `WouldChange` beliefs are accepted unconditionally because they
    /// cannot be decided from the graph alone.
    fn check_single_belief(&self, belief: &Belief) -> Result<(), String> {
        // Both endpoints must exist in the model before any graph query.
        let subject_id = self.model.get_variable_id(&belief.subject)
            .ok_or_else(|| format!("Variable '{}' not in model", belief.subject))?;
        let object_id = self.model.get_variable_id(&belief.object)
            .ok_or_else(|| format!("Variable '{}' not in model", belief.object))?;
        match &belief.belief_type {
            BeliefType::Causes => {
                // Check if there's a directed path from subject to object
                // (object must be among subject's graph descendants).
                let descendants = self.model.graph().descendants(subject_id.0);
                if !descendants.contains(&object_id.0) {
                    return Err(format!(
                        "No causal path from {} to {} in model",
                        belief.subject, belief.object
                    ));
                }
            }
            BeliefType::IndependentOf => {
                // Check if they're d-separated given empty set
                if !self.model.d_separated(subject_id, object_id, &[]) {
                    return Err(format!(
                        "{} and {} are not independent according to model",
                        belief.subject, belief.object
                    ));
                }
            }
            BeliefType::ConditionallyIndependent { given } => {
                // Resolve every conditioning variable, failing fast on
                // any name not present in the model.
                let given_ids: Result<Vec<VariableId>, _> = given.iter()
                    .map(|name| {
                        self.model.get_variable_id(name)
                            .ok_or_else(|| format!("Variable '{}' not in model", name))
                    })
                    .collect();
                let given_ids = given_ids?;
                if !self.model.d_separated(subject_id, object_id, &given_ids) {
                    return Err(format!(
                        "{} and {} are not conditionally independent given {:?}",
                        belief.subject, belief.object, given
                    ));
                }
            }
            BeliefType::CommonCause => {
                // Check if they share a common ancestor in the graph.
                let ancestors_a = self.model.graph().ancestors(subject_id.0);
                let ancestors_b = self.model.graph().ancestors(object_id.0);
                let common: HashSet<_> = ancestors_a.intersection(&ancestors_b).collect();
                if common.is_empty() {
                    return Err(format!(
                        "No common cause found for {} and {}",
                        belief.subject, belief.object
                    ));
                }
            }
            BeliefType::CorrelatedWith | BeliefType::WouldChange => {
                // These are not directly checkable against model structure alone
                // They would need data or simulation
            }
        }
        Ok(())
    }
/// Generate a suggestion for fixing an inconsistency
fn generate_suggestion(&self, belief: &Belief, _reason: &str) -> Option<String> {
match &belief.belief_type {
BeliefType::Causes => {
Some(format!(
"Consider adding edge {} -> {} to the model, or revising the belief",
belief.subject, belief.object
))
}
BeliefType::IndependentOf => {
Some(format!(
"Consider conditioning on a confounding variable, or revising the model structure"
))
}
BeliefType::ConditionallyIndependent { given } => {
Some(format!(
"The conditioning set {:?} may be insufficient; consider additional variables",
given
))
}
_ => None,
}
}
    /// Detect spurious correlations in data given the causal model
    ///
    /// A pair of columns is flagged when (a) its empirical Pearson
    /// correlation exceeds the configured threshold, (b) neither variable
    /// is a graph descendant of the other (no directed causal path in
    /// either direction), and (c) the pair shares at least one common
    /// ancestor that can explain the association as confounding.
    /// Columns not present in the model are ignored.
    pub fn detect_spurious_correlations(&self, data: &Dataset) -> Vec<SpuriousCorrelation> {
        let mut spurious = Vec::new();
        // Check all unordered pairs of columns.
        for i in 0..data.columns.len() {
            for j in (i + 1)..data.columns.len() {
                let col_a = &data.columns[i];
                let col_b = &data.columns[j];
                // Get correlation from data; skip pairs it cannot be
                // computed for (missing column / empty data).
                let correlation = match data.correlation(col_a, col_b) {
                    Some(c) => c,
                    None => continue,
                };
                // If significantly correlated
                if correlation.abs() > self.correlation_threshold {
                    // Check if causally linked (both must be model variables)
                    if let (Some(id_a), Some(id_b)) = (
                        self.model.get_variable_id(col_a),
                        self.model.get_variable_id(col_b),
                    ) {
                        // Check if there's a direct causal path either way
                        let a_causes_b = self.model.graph().descendants(id_a.0).contains(&id_b.0);
                        let b_causes_a = self.model.graph().descendants(id_b.0).contains(&id_a.0);
                        if !a_causes_b && !b_causes_a {
                            // Correlation without direct causation - find confounders
                            let confounders = self.find_confounders(id_a, id_b);
                            if !confounders.is_empty() {
                                spurious.push(SpuriousCorrelation {
                                    var_a: col_a.clone(),
                                    var_b: col_b.clone(),
                                    confounders: confounders.clone(),
                                    strength: correlation.abs(),
                                    explanation: format!(
                                        "Correlation (r={:.3}) between {} and {} is explained by common cause(s): {}",
                                        correlation, col_a, col_b, confounders.join(", ")
                                    ),
                                });
                            }
                        }
                    }
                }
            }
        }
        spurious
    }
/// Find common causes (confounders) of two variables
fn find_confounders(&self, a: VariableId, b: VariableId) -> Vec<String> {
let ancestors_a = self.model.graph().ancestors(a.0);
let ancestors_b = self.model.graph().ancestors(b.0);
let common: Vec<_> = ancestors_a.intersection(&ancestors_b)
.filter_map(|&id| self.model.get_variable_name(&VariableId(id)))
.collect();
common
}
/// Answer a causal query using do-calculus
pub fn enforce_do_calculus(&self, query: &CausalQuery) -> Result<CausalAnswer, CoherenceError> {
// Get target variable
let target_id = self.model.get_variable_id(&query.target)
.ok_or_else(|| CoherenceError::InvalidQuery(
format!("Target variable '{}' not in model", query.target)
))?;
match query.query_type {
QueryType::Interventional => {
self.answer_interventional_query(query, target_id)
}
QueryType::Observational => {
self.answer_observational_query(query, target_id)
}
QueryType::Counterfactual => {
self.answer_counterfactual_query(query, target_id)
}
QueryType::ConditionalInterventional => {
self.answer_conditional_interventional_query(query, target_id)
}
}
}
    /// Answer P(target | do(interventions)) by simulating the mutilated
    /// (intervened) model.
    ///
    /// # Errors
    /// `CoherenceError::InvalidQuery` when an intervention variable is not
    /// in the model; intervention/simulation failures propagate via `?`.
    fn answer_interventional_query(
        &self,
        query: &CausalQuery,
        target_id: VariableId,
    ) -> Result<CausalAnswer, CoherenceError> {
        // Convert intervention specification to Intervention objects,
        // failing fast on any unknown variable name.
        let interventions: Result<Vec<_>, _> = query.interventions.iter()
            .map(|(var, val)| {
                self.model.get_variable_id(var)
                    .map(|id| super::model::Intervention::new(id, val.clone()))
                    .ok_or_else(|| CoherenceError::InvalidQuery(
                        format!("Intervention variable '{}' not in model", var)
                    ))
            })
            .collect();
        let interventions = interventions?;
        // Perform intervention
        let intervened = self.model.intervene_with(&interventions)?;
        // Simulate to get the target value (no additional exogenous input;
        // a missing result maps to Value::Missing rather than an error).
        let values = intervened.simulate(&HashMap::new())?;
        let estimate = values.get(&target_id).cloned().unwrap_or(Value::Missing);
        // Check identifiability
        let is_identifiable = self.check_identifiability(query);
        Ok(CausalAnswer {
            query: query.clone(),
            estimate,
            confidence_interval: None,
            is_identifiable,
            explanation: format!(
                "Computed P({} | do({})) by intervention simulation",
                query.target,
                query.interventions.iter()
                    .map(|(v, val)| format!("{}={:?}", v, val))
                    .collect::<Vec<_>>()
                    .join(", ")
            ),
        })
    }
fn answer_observational_query(
&self,
query: &CausalQuery,
target_id: VariableId,
) -> Result<CausalAnswer, CoherenceError> {
// For observational queries, we need to condition
// This requires probabilistic reasoning which we approximate
let explanation = format!(
"Observational query P({} | {}) - requires probabilistic inference",
query.target,
query.conditions.iter()
.map(|(v, val)| format!("{}={:?}", v, val))
.collect::<Vec<_>>()
.join(", ")
);
Ok(CausalAnswer {
query: query.clone(),
estimate: Value::Missing, // Would need actual probabilistic computation
confidence_interval: None,
is_identifiable: true, // Observational queries are always identifiable
explanation,
})
}
fn answer_counterfactual_query(
&self,
query: &CausalQuery,
_target_id: VariableId,
) -> Result<CausalAnswer, CoherenceError> {
// Counterfactual queries require abduction-action-prediction
let explanation = format!(
"Counterfactual query for {} - requires three-step process: abduction, action, prediction",
query.target
);
Ok(CausalAnswer {
query: query.clone(),
estimate: Value::Missing,
confidence_interval: None,
is_identifiable: false, // Counterfactuals often not identifiable
explanation,
})
}
fn answer_conditional_interventional_query(
&self,
query: &CausalQuery,
target_id: VariableId,
) -> Result<CausalAnswer, CoherenceError> {
// Combines intervention with conditioning
let explanation = format!(
"Conditional interventional query P({} | do({}), {}) - may require adjustment formula",
query.target,
query.interventions.iter()
.map(|(v, val)| format!("{}={:?}", v, val))
.collect::<Vec<_>>()
.join(", "),
query.conditions.iter()
.map(|(v, val)| format!("{}={:?}", v, val))
.collect::<Vec<_>>()
.join(", ")
);
Ok(CausalAnswer {
query: query.clone(),
estimate: Value::Missing,
confidence_interval: None,
is_identifiable: self.check_identifiability(query),
explanation,
})
}
/// Check if a causal query is identifiable from observational data
fn check_identifiability(&self, query: &CausalQuery) -> bool {
// Simplified identifiability check
// Full implementation would use do-calculus rules
if query.interventions.is_empty() {
return true; // Observational queries are identifiable
}
// Check if intervention variables have unobserved confounders with target
for (var, _) in &query.interventions {
if let (Some(var_id), Some(target_id)) = (
self.model.get_variable_id(var),
self.model.get_variable_id(&query.target),
) {
// If there's a backdoor path that can't be blocked, not identifiable
// This is a simplified check
let var_ancestors = self.model.graph().ancestors(var_id.0);
let target_ancestors = self.model.graph().ancestors(target_id.0);
// If they share unobserved common ancestors, might not be identifiable
let common = var_ancestors.intersection(&target_ancestors).count();
if common > 0 && !self.has_valid_adjustment_set(var_id, target_id) {
return false;
}
}
}
true
}
/// Check if there's a valid adjustment set for identifying causal effect
fn has_valid_adjustment_set(&self, treatment: VariableId, outcome: VariableId) -> bool {
// Check backdoor criterion
// A set Z satisfies backdoor criterion if:
// 1. No node in Z is a descendant of X
// 2. Z blocks every path from X to Y that contains an arrow into X
let descendants = self.model.graph().descendants(treatment.0);
// Try the set of all non-descendants as potential adjustment set
let all_vars: Vec<_> = self.model.variables()
.filter(|v| v.id != treatment && v.id != outcome)
.filter(|v| !descendants.contains(&v.id.0))
.map(|v| v.id)
.collect();
// Check if conditioning on all non-descendants blocks backdoor paths
self.model.d_separated(treatment, outcome, &all_vars)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `Mechanism`, `CausalModel`, `Value` etc. are pulled in
    // via `use super::*` — confirm the parent module imports/re-exports
    // `Mechanism`. `CausalModelBuilder` below appears unused here.
    use crate::causal::model::{CausalModelBuilder, VariableType};
    /// Shared fixture: Age -> Education -> Income and Age -> Health.
    /// Education and Health are confounded by Age but causally unlinked.
    fn create_test_model() -> CausalModel {
        let mut model = CausalModel::with_name("Test");
        model.add_variable("Age", VariableType::Continuous).unwrap();
        model.add_variable("Education", VariableType::Continuous).unwrap();
        model.add_variable("Income", VariableType::Continuous).unwrap();
        model.add_variable("Health", VariableType::Continuous).unwrap();
        let age = model.get_variable_id("Age").unwrap();
        let edu = model.get_variable_id("Education").unwrap();
        let income = model.get_variable_id("Income").unwrap();
        let health = model.get_variable_id("Health").unwrap();
        // Age -> Education, Age -> Health
        model.add_edge(age, edu).unwrap();
        model.add_edge(age, health).unwrap();
        // Education -> Income
        model.add_edge(edu, income).unwrap();
        // Add equations (simple linear mechanisms)
        model.add_structural_equation(edu, &[age], Mechanism::new(|p| {
            Value::Continuous(12.0 + p[0].as_f64() * 0.1)
        })).unwrap();
        model.add_structural_equation(income, &[edu], Mechanism::new(|p| {
            Value::Continuous(30000.0 + p[0].as_f64() * 5000.0)
        })).unwrap();
        model.add_structural_equation(health, &[age], Mechanism::new(|p| {
            Value::Continuous(100.0 - p[0].as_f64() * 0.5)
        })).unwrap();
        model
    }
    #[test]
    fn test_belief_creation() {
        let belief = Belief::causes("Age", "Education").with_confidence(0.9);
        assert_eq!(belief.subject, "Age");
        assert_eq!(belief.object, "Education");
        assert_eq!(belief.confidence, 0.9);
    }
    #[test]
    fn test_causal_consistency() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);
        // Both beliefs match directed paths in the fixture graph.
        let beliefs = vec![
            Belief::causes("Age", "Education"),
            Belief::causes("Education", "Income"),
        ];
        let result = checker.check_causal_consistency(&beliefs);
        assert!(result.is_consistent());
        assert_eq!(result.consistent_beliefs, 2);
    }
    #[test]
    fn test_inconsistent_belief() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);
        let beliefs = vec![
            Belief::causes("Income", "Age"), // Wrong direction
        ];
        let result = checker.check_causal_consistency(&beliefs);
        assert!(!result.is_consistent());
        assert_eq!(result.inconsistent_beliefs, 1);
    }
    #[test]
    fn test_conditional_independence() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);
        // Education and Health should be independent given Age
        // (Age is their only common parent in the fixture).
        let beliefs = vec![
            Belief::conditionally_independent("Education", "Health", &["Age"]),
        ];
        let result = checker.check_causal_consistency(&beliefs);
        assert!(result.is_consistent());
    }
    #[test]
    fn test_spurious_correlation_detection() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model).with_correlation_threshold(0.1);
        // Create dataset with Education-Health correlation (spurious via Age)
        let mut data = Dataset::new(vec![
            "Age".to_string(),
            "Education".to_string(),
            "Health".to_string(),
        ]);
        // Add correlated data (both Education and Health driven by Age)
        for i in 0..100 {
            let age = 20.0 + i as f64 * 0.5;
            let edu = 12.0 + age * 0.1 + (i as f64 * 0.1).sin();
            let health = 100.0 - age * 0.5 + (i as f64 * 0.2).cos();
            data.add_row(vec![age, edu, health]);
        }
        let spurious = checker.detect_spurious_correlations(&data);
        // Should detect Education-Health as spurious (both caused by Age)
        let edu_health = spurious.iter()
            .find(|s| (s.var_a == "Education" && s.var_b == "Health") ||
                      (s.var_a == "Health" && s.var_b == "Education"));
        assert!(edu_health.is_some());
        if let Some(s) = edu_health {
            assert!(s.confounders.contains(&"Age".to_string()));
        }
    }
    #[test]
    fn test_interventional_query() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);
        let query = CausalQuery::interventional(
            "Income",
            "Education",
            Value::Continuous(16.0),
        );
        let answer = checker.enforce_do_calculus(&query).unwrap();
        assert!(answer.is_identifiable);
        assert!(matches!(answer.query.query_type, QueryType::Interventional));
    }
    #[test]
    fn test_coherence_energy() {
        let energy = CoherenceEnergy::from_components(0.1, 0.2, 0.05);
        assert!((energy.total - 0.35).abs() < 1e-10);
        assert!(!energy.is_coherent);
        let coherent = CoherenceEnergy::coherent();
        assert!(coherent.is_coherent);
    }
    #[test]
    fn test_dataset_correlation() {
        let mut data = Dataset::new(vec!["X".to_string(), "Y".to_string()]);
        // Perfect positive correlation
        for i in 0..10 {
            data.add_row(vec![i as f64, i as f64]);
        }
        let corr = data.correlation("X", "Y").unwrap();
        assert!((corr - 1.0).abs() < 1e-10);
        // Add negatively correlated data
        let mut data2 = Dataset::new(vec!["A".to_string(), "B".to_string()]);
        for i in 0..10 {
            data2.add_row(vec![i as f64, (10 - i) as f64]);
        }
        let corr2 = data2.correlation("A", "B").unwrap();
        assert!((corr2 + 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_causal_query_builder() {
        let query = CausalQuery::interventional("Y", "X", Value::Continuous(1.0))
            .given("Z", Value::Continuous(2.0));
        assert_eq!(query.target, "Y");
        assert_eq!(query.interventions.len(), 1);
        assert_eq!(query.conditions.len(), 1);
    }
}

View File

@@ -0,0 +1,805 @@
//! Counterfactual Reasoning
//!
//! This module implements counterfactual inference based on Pearl's three-step
//! procedure: Abduction, Action, Prediction.
//!
//! ## Counterfactual Semantics
//!
//! Given a structural causal model M = (U, V, F), a counterfactual query asks:
//! "What would Y have been if X had been x, given that we observed E = e?"
//!
//! Written as: P(Y_x | E = e) or Y_{X=x}(u) where u is the exogenous state.
//!
//! ## Three-Step Procedure
//!
//! 1. **Abduction**: Update P(U) given evidence E = e to get P(U | E = e)
//! 2. **Action**: Modify the model by intervention do(X = x)
//! 3. **Prediction**: Compute Y in the modified model using updated U
//!
//! ## References
//!
//! - Pearl (2009): "Causality" Chapter 7
//! - Halpern (2016): "Actual Causality"
use std::collections::HashMap;
use thiserror::Error;
use super::model::{
CausalModel, CausalModelError, Intervention, Value, VariableId, Mechanism,
Observation,
};
/// Error types for counterfactual reasoning
#[derive(Debug, Clone, Error)]
pub enum CounterfactualError {
    /// Underlying model error (intervention/simulation), converted via `From`
    #[error("Model error: {0}")]
    ModelError(#[from] CausalModelError),
    /// A referenced variable name is not part of the model
    #[error("Invalid observation: variable '{0}' not in model")]
    InvalidObservation(String),
    /// The abduction step could not recover an exogenous state / result
    #[error("Abduction failed: {0}")]
    AbductionFailed(String),
    /// The requested counterfactual has no well-defined answer
    #[error("Counterfactual not well-defined: {0}")]
    NotWellDefined(String),
}
/// Extended Distribution type for counterfactual results
#[derive(Debug, Clone)]
pub struct CounterfactualDistribution {
    /// Point estimate values (for deterministic models)
    pub values: HashMap<VariableId, Value>,
    /// Probability mass (for discrete) or density (for continuous);
    /// 1.0 for point-mass results produced by deterministic simulation
    pub probability: f64,
}
impl CounterfactualDistribution {
    /// Build a degenerate (probability-one) distribution over `values`.
    pub fn point_mass(values: HashMap<VariableId, Value>) -> Self {
        Self {
            values,
            probability: 1.0,
        }
    }

    /// Wrap a deterministic simulation result as a point-mass distribution.
    pub fn from_simulation(values: HashMap<VariableId, Value>) -> Self {
        Self::point_mass(values)
    }

    /// Look up the value assigned to `var`, if any.
    pub fn get(&self, var: VariableId) -> Option<&Value> {
        self.values.get(&var)
    }

    /// Numeric value of `var`, defaulting to 0.0 when the variable is absent.
    pub fn mean(&self, var: VariableId) -> f64 {
        match self.values.get(&var) {
            Some(v) => v.as_f64(),
            None => 0.0,
        }
    }
}
/// A counterfactual query
///
/// Encodes "what would `target` have been if `interventions` had been
/// applied, given that `evidence` was observed?"
#[derive(Debug, Clone)]
pub struct CounterfactualQuery {
    /// Target variable (what we want to know)
    pub target: String,
    /// Interventions (what we're hypothetically changing), as (name, value) pairs
    pub interventions: Vec<(String, Value)>,
    /// Evidence (what we observed in the actual world)
    pub evidence: Observation,
}
impl CounterfactualQuery {
    /// Create a new counterfactual query
    ///
    /// Asking: "What would `target` have been if we had done `interventions`,
    /// given that we observed `evidence`?"
    pub fn new(target: &str, interventions: Vec<(&str, Value)>, evidence: Observation) -> Self {
        let interventions = interventions
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect();
        Self {
            target: target.to_string(),
            interventions,
            evidence,
        }
    }

    /// Simple counterfactual with no evidence: "What would `target` have
    /// been if `intervention_var` had been `intervention_val`?"
    pub fn simple(target: &str, intervention_var: &str, intervention_val: Value) -> Self {
        Self::new(
            target,
            vec![(intervention_var, intervention_val)],
            Observation::empty(),
        )
    }

    /// Builder-style method adding an observed value to the evidence.
    pub fn given(mut self, var: &str, value: Value) -> Self {
        self.evidence.observe(var, value);
        self
    }
}
/// Result of a counterfactual query
#[derive(Debug, Clone)]
pub struct CounterfactualResult {
    /// The query that was answered
    pub query: CounterfactualQuery,
    /// The counterfactual distribution over all model variables
    pub distribution: CounterfactualDistribution,
    /// The exogenous values recovered during the abduction step
    pub exogenous: HashMap<VariableId, Value>,
    /// Human-readable explanation of the three-step reasoning
    pub explanation: String,
}
/// Average Treatment Effect computation
///
/// ATE = E[Y | do(X=treatment_value)] - E[Y | do(X=control_value)]
#[derive(Debug, Clone)]
pub struct AverageTreatmentEffect {
    /// Treatment variable name
    pub treatment: String,
    /// Outcome variable name
    pub outcome: String,
    /// Treatment value
    pub treatment_value: Value,
    /// Control value
    pub control_value: Value,
    /// Estimated ATE
    pub ate: f64,
    /// Standard error (if available; set via `with_standard_error`)
    pub standard_error: Option<f64>,
}
impl AverageTreatmentEffect {
    /// Construct an ATE result with no standard error attached.
    pub fn new(
        treatment: &str,
        outcome: &str,
        treatment_value: Value,
        control_value: Value,
        ate: f64,
    ) -> Self {
        Self {
            treatment: treatment.to_owned(),
            outcome: outcome.to_owned(),
            treatment_value,
            control_value,
            ate,
            standard_error: None,
        }
    }

    /// Builder-style setter for the standard error.
    pub fn with_standard_error(mut self, se: f64) -> Self {
        self.standard_error = Some(se);
        self
    }
}
/// Compute a counterfactual: "What would Y have been if X had been x, given observation?"
///
/// This implements Pearl's three-step procedure:
/// 1. Abduction: Infer exogenous variables from observation
/// 2. Action: Apply intervention do(X = x)
/// 3. Prediction: Compute Y under intervention with abduced exogenous values
///
/// # Arguments
/// * `model` - The causal model
/// * `observation` - The observed evidence
/// * `intervention_var` - The variable to intervene on
/// * `intervention_value` - The counterfactual value for the intervention
/// * `target_name` - The target variable to compute the counterfactual for
///
/// # Errors
/// * [`CounterfactualError::InvalidObservation`] if `target_name` is unknown
/// * [`CounterfactualError::AbductionFailed`] if simulation produced no value
///   for the target
/// * model errors from intervention/simulation, via `From<CausalModelError>`
pub fn counterfactual(
    model: &CausalModel,
    observation: &Observation,
    intervention_var: VariableId,
    intervention_value: Value,
    target_name: &str,
) -> Result<Value, CounterfactualError> {
    // Step 1: Abduction - infer exogenous variables
    let exogenous = abduce(model, observation)?;
    // Step 2: Action - create intervened model
    let intervention = Intervention::new(intervention_var, intervention_value);
    let intervened = model.intervene_with(&[intervention])?;
    // Step 3: Prediction - simulate with abduced exogenous and intervention
    let result = intervened.simulate(&exogenous)?;
    // Get the target variable (validated only after simulation succeeds)
    let target_id = model.get_variable_id(target_name)
        .ok_or_else(|| CounterfactualError::InvalidObservation(target_name.to_string()))?;
    result.get(&target_id)
        .cloned()
        .ok_or_else(|| CounterfactualError::AbductionFailed(
            format!("Target variable {} not computed", target_name)
        ))
}
/// Compute a counterfactual with an [`Intervention`] struct (alternative API).
///
/// Runs the same abduction/action/prediction pipeline as [`counterfactual`]
/// but returns the full distribution over all variables instead of a single
/// target value.
pub fn counterfactual_with_intervention(
    model: &CausalModel,
    observation: &Observation,
    intervention: &Intervention,
) -> Result<CounterfactualDistribution, CounterfactualError> {
    // Abduction: recover exogenous state consistent with the observation.
    let exogenous = abduce(model, observation)?;
    // Action: apply the intervention to obtain the mutilated model.
    let modified = model.intervene_with(&[intervention.clone()])?;
    // Prediction: simulate the modified model under the abduced noise.
    let outcome = modified.simulate(&exogenous)?;
    Ok(CounterfactualDistribution::from_simulation(outcome))
}
/// Compute counterfactual from a query object
///
/// Resolves every intervention variable by name, runs the three-step
/// abduction/action/prediction procedure, and returns the full result:
/// the distribution over all variables, the abduced exogenous state, and
/// a textual explanation.
///
/// # Errors
/// [`CounterfactualError::InvalidObservation`] for any unknown intervention
/// or target variable name; model errors propagate via `From`.
pub fn counterfactual_query(
    model: &CausalModel,
    query: &CounterfactualQuery,
) -> Result<CounterfactualResult, CounterfactualError> {
    // Convert interventions, failing fast on any unknown variable name.
    let interventions: Result<Vec<Intervention>, _> = query.interventions.iter()
        .map(|(var, val)| {
            model.get_variable_id(var)
                .map(|id| Intervention::new(id, val.clone()))
                .ok_or_else(|| CounterfactualError::InvalidObservation(var.clone()))
        })
        .collect();
    let interventions = interventions?;
    // Step 1: Abduction
    let exogenous = abduce(model, &query.evidence)?;
    // Step 2 & 3: Action and Prediction
    let intervened = model.intervene_with(&interventions)?;
    let result = intervened.simulate(&exogenous)?;
    // Validate the target exists. Fixed: the binding was previously named
    // `target_id` but never read, triggering an unused-variable warning;
    // the leading underscore keeps the validation while silencing it.
    let _target_id = model.get_variable_id(&query.target)
        .ok_or_else(|| CounterfactualError::InvalidObservation(query.target.clone()))?;
    let explanation = format!(
        "Counterfactual computed via three-step procedure:\n\
         1. Abduced exogenous values from evidence\n\
         2. Applied intervention(s): {}\n\
         3. Predicted {} under intervention",
        query.interventions.iter()
            .map(|(v, val)| format!("do({}={:?})", v, val))
            .collect::<Vec<_>>()
            .join(", "),
        query.target
    );
    Ok(CounterfactualResult {
        query: query.clone(),
        distribution: CounterfactualDistribution::from_simulation(result),
        exogenous,
        explanation,
    })
}
/// Abduction: Infer exogenous variable values from observations
///
/// For deterministic models, this inverts the structural equations.
///
/// NOTE(review): the current implementation copies each observed value
/// into the exogenous map as-is — for endogenous (non-root) variables it
/// uses the observation itself as the noise term instead of computing a
/// true residual. Confirm `simulate` interprets these entries accordingly.
fn abduce(
    model: &CausalModel,
    observation: &Observation,
) -> Result<HashMap<VariableId, Value>, CounterfactualError> {
    let mut exogenous = HashMap::new();
    // Get topological order so parents are visited before children.
    let topo_order = model.topological_order_ids()?;
    // For each variable in topological order
    for var_id in topo_order {
        let var = model.get_variable(var_id.as_ref())
            .ok_or_else(|| CounterfactualError::AbductionFailed(
                format!("Variable {:?} not found", var_id)
            ))?;
        // Check if this variable is observed
        if let Some(observed_value) = observation.values.get(&var.name) {
            // If this is a root variable (no parents), it's exogenous
            if model.parents(&var_id).map(|p| p.is_empty()).unwrap_or(true) {
                exogenous.insert(var_id, observed_value.clone());
            } else {
                // For endogenous variables, we might need to compute the residual
                // For now, we use the observed value as the exogenous noise
                exogenous.insert(var_id, observed_value.clone());
            }
        }
    }
    Ok(exogenous)
}
/// Compute the Average Treatment Effect (ATE)
///
/// ATE = E[Y | do(X=treatment_value)] - E[Y | do(X=control_value)]
///
/// Thin alias for [`causal_effect_at_values`], kept for API symmetry with
/// [`causal_effect_binary`].
///
/// # Arguments
/// * `model` - The causal model
/// * `treatment` - Treatment variable ID
/// * `outcome` - Outcome variable ID
/// * `treatment_value` - The treatment value
/// * `control_value` - The control value
pub fn causal_effect(
    model: &CausalModel,
    treatment: VariableId,
    outcome: VariableId,
    treatment_value: Value,
    control_value: Value,
) -> Result<f64, CounterfactualError> {
    causal_effect_at_values(model, treatment, outcome, treatment_value, control_value)
}
/// Compute the Average Treatment Effect with default binary values
///
/// ATE = E[Y | do(X=1)] - E[Y | do(X=0)]
pub fn causal_effect_binary(
    model: &CausalModel,
    treatment: VariableId,
    outcome: VariableId,
) -> Result<f64, CounterfactualError> {
    // Treatment "on" is encoded as 1.0, "off" as 0.0.
    let (on, off) = (Value::Continuous(1.0), Value::Continuous(0.0));
    causal_effect_at_values(model, treatment, outcome, on, off)
}
/// Compute causal effect at specific treatment values
///
/// Simulates the model twice — once under do(treatment = treatment_value)
/// and once under do(treatment = control_value) — and returns the
/// difference of the outcome's numeric values. A missing outcome in
/// either simulation contributes 0.0.
pub fn causal_effect_at_values(
    model: &CausalModel,
    treatment: VariableId,
    outcome: VariableId,
    treatment_value: Value,
    control_value: Value,
) -> Result<f64, CounterfactualError> {
    // Expected outcome under do(treatment = value) in the deterministic model.
    let outcome_under = |value: Value| -> Result<f64, CounterfactualError> {
        let modified = model.intervene_with(&[Intervention::new(treatment, value)])?;
        let result = modified.simulate(&HashMap::new())?;
        Ok(result.get(&outcome).map(|v| v.as_f64()).unwrap_or(0.0))
    };
    // Treatment run is evaluated first, matching the original order.
    Ok(outcome_under(treatment_value)? - outcome_under(control_value)?)
}
/// Compute ATE with full result structure
///
/// Name-based convenience wrapper around [`causal_effect_at_values`] that
/// packages the estimate into an [`AverageTreatmentEffect`].
///
/// # Errors
/// [`CounterfactualError::InvalidObservation`] if either variable name is
/// unknown to the model.
pub fn average_treatment_effect(
    model: &CausalModel,
    treatment_name: &str,
    outcome_name: &str,
    treatment_value: Value,
    control_value: Value,
) -> Result<AverageTreatmentEffect, CounterfactualError> {
    let resolve = |name: &str| {
        model
            .get_variable_id(name)
            .ok_or_else(|| CounterfactualError::InvalidObservation(name.to_string()))
    };
    let treatment_id = resolve(treatment_name)?;
    let outcome_id = resolve(outcome_name)?;
    let ate = causal_effect_at_values(
        model,
        treatment_id,
        outcome_id,
        treatment_value.clone(),
        control_value.clone(),
    )?;
    Ok(AverageTreatmentEffect::new(
        treatment_name,
        outcome_name,
        treatment_value,
        control_value,
        ate,
    ))
}
/// Compute Individual Treatment Effect (ITE) for a specific unit
///
/// ITE_i = Y_i(X=1) - Y_i(X=0)
///
/// This is a counterfactual quantity: what would have happened to unit i
/// under different treatment assignments. Both potential outcomes are
/// computed for the SAME unit — the exogenous state abduced from
/// `unit_observation` is shared by the two counterfactual runs.
pub fn individual_treatment_effect(
    model: &CausalModel,
    treatment: VariableId,
    outcome: VariableId,
    unit_observation: &Observation,
    treatment_value: Value,
    control_value: Value,
) -> Result<f64, CounterfactualError> {
    // Resolve the outcome's name once; `counterfactual` works by name.
    let outcome_name = model.get_variable_name(&outcome).ok_or_else(|| {
        CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome))
    })?;
    // Counterfactual outcomes under treatment and under control.
    let y_treated =
        counterfactual(model, unit_observation, treatment, treatment_value, &outcome_name)?
            .as_f64();
    let y_control =
        counterfactual(model, unit_observation, treatment, control_value, &outcome_name)?
            .as_f64();
    Ok(y_treated - y_control)
}
/// Natural Direct Effect (NDE)
///
/// NDE = E[Y(x, M(x'))] - E[Y(x', M(x'))]
///
/// The effect of X on Y that would remain if the mediator were held at the
/// value it would have taken under X = x'.
///
/// # Errors
/// Model errors from intervention/simulation propagate via `From`.
pub fn natural_direct_effect(
    model: &CausalModel,
    treatment: VariableId,
    mediator: VariableId,
    outcome: VariableId,
    treatment_value: Value,
    control_value: Value,
) -> Result<f64, CounterfactualError> {
    // Baseline world E[Y(x', M(x'))]: X held at control. Record both the
    // outcome and the mediator value M(x') it induces.
    // Fixed: `control_value` and `ctrl_intervention` were cloned although
    // neither is used again (clippy::redundant_clone) — both are now moved.
    let ctrl_intervention = Intervention::new(treatment, control_value);
    let intervened = model.intervene_with(&[ctrl_intervention])?;
    let baseline_result = intervened.simulate(&HashMap::new())?;
    let m_ctrl = baseline_result.get(&mediator).cloned().unwrap_or(Value::Missing);
    let y_baseline = baseline_result.get(&outcome)
        .map(|v| v.as_f64())
        .unwrap_or(0.0);
    // Direct-effect world E[Y(x, M(x'))]: X at treatment while the mediator
    // is pinned to its control-world value.
    let treat_intervention = Intervention::new(treatment, treatment_value);
    let m_intervention = Intervention::new(mediator, m_ctrl);
    let intervened = model.intervene_with(&[treat_intervention, m_intervention])?;
    let nde_result = intervened.simulate(&HashMap::new())?;
    let y_nde = nde_result.get(&outcome)
        .map(|v| v.as_f64())
        .unwrap_or(0.0);
    Ok(y_nde - y_baseline)
}
/// Natural Indirect Effect (NIE)
///
/// NIE = E[Y(x, M(x))] - E[Y(x, M(x'))]
///
/// The effect of X on Y that is mediated through M.
///
/// # Errors
/// Model errors from intervention/simulation propagate via `From`.
pub fn natural_indirect_effect(
    model: &CausalModel,
    treatment: VariableId,
    mediator: VariableId,
    outcome: VariableId,
    treatment_value: Value,
    control_value: Value,
) -> Result<f64, CounterfactualError> {
    // Full-treatment world E[Y(x, M(x))].
    // Fixed: `treatment_value` was cloned although never used again
    // (clippy::redundant_clone); it is now moved. The clone of
    // `treat_intervention` below IS required — it is reused for the
    // mixed world at the end.
    let treat_intervention = Intervention::new(treatment, treatment_value);
    let intervened = model.intervene_with(&[treat_intervention.clone()])?;
    let full_result = intervened.simulate(&HashMap::new())?;
    let y_full = full_result.get(&outcome)
        .map(|v| v.as_f64())
        .unwrap_or(0.0);
    // Control world, used only to read off the mediator value M(x').
    let ctrl_intervention = Intervention::new(treatment, control_value);
    let ctrl_intervened = model.intervene_with(&[ctrl_intervention])?;
    let ctrl_result = ctrl_intervened.simulate(&HashMap::new())?;
    let m_ctrl = ctrl_result.get(&mediator).cloned().unwrap_or(Value::Missing);
    // Mixed world E[Y(x, M(x'))]: X at treatment, mediator pinned to M(x').
    let m_intervention = Intervention::new(mediator, m_ctrl);
    let intervened = model.intervene_with(&[treat_intervention, m_intervention])?;
    let indirect_result = intervened.simulate(&HashMap::new())?;
    let y_indirect = indirect_result.get(&outcome)
        .map(|v| v.as_f64())
        .unwrap_or(0.0);
    Ok(y_full - y_indirect)
}
/// Probability of Necessity (PN)
///
/// PN = P(Y_x' = 0 | X = x, Y = 1)
///
/// Given that X=x and Y=1 occurred, what is the probability that Y would
/// have been 0 if X had been x' instead?
///
/// NOTE(review): for this deterministic model the result is a coarse
/// {0.0, 0.5, 1.0} heuristic rather than a true probability; outcomes are
/// treated as "positive" when their numeric value is > 0.
pub fn probability_of_necessity(
    model: &CausalModel,
    treatment: VariableId,
    outcome: VariableId,
    observation: &Observation,
    counterfactual_treatment: Value,
) -> Result<f64, CounterfactualError> {
    // Get outcome variable name
    let outcome_name = model.get_variable_name(&outcome)
        .ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
    // Compute counterfactual outcome under the alternative treatment x'
    let cf_value = counterfactual(model, observation, treatment, counterfactual_treatment, &outcome_name)?;
    let cf_outcome = cf_value.as_f64();
    // PN is probability that outcome would be 0 (negative)
    // For continuous outcomes, we check if it crosses the threshold
    let observed_outcome = observation.values.iter()
        .find_map(|(name, val)| {
            model.get_variable_id(name)
                .filter(|id| *id == outcome)
                .map(|_| val.as_f64())
        })
        .unwrap_or(0.0); // default when the outcome was not actually observed
    // Simple heuristic: if counterfactual outcome is significantly different
    if observed_outcome > 0.0 && cf_outcome <= 0.0 {
        Ok(1.0) // Necessary: flipping treatment flips the outcome
    } else if (observed_outcome - cf_outcome).abs() < 1e-6 {
        Ok(0.0) // Not necessary: outcome unchanged
    } else {
        Ok(0.5) // Uncertain
    }
}
/// Probability of Sufficiency (PS)
///
/// PS = P(Y_x = 1 | X = x', Y = 0)
///
/// Given that X=x' and Y=0 occurred, what is the probability that Y would
/// have been 1 if X had been x instead?
pub fn probability_of_sufficiency(
    model: &CausalModel,
    treatment: VariableId,
    outcome: VariableId,
    observation: &Observation,
    counterfactual_treatment: Value,
) -> Result<f64, CounterfactualError> {
    // Resolve the outcome's name; the counterfactual engine is keyed by name.
    let outcome_name = model
        .get_variable_name(&outcome)
        .ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
    // Counterfactual outcome under the alternative treatment x.
    let cf_outcome =
        counterfactual(model, observation, treatment, counterfactual_treatment, &outcome_name)?
            .as_f64();
    // Factual outcome as recorded in the observation (defaults to 1.0 when
    // the outcome variable was not observed).
    let observed_outcome = observation.values.iter()
        .filter(|(name, _)| model.get_variable_id(name) == Some(outcome))
        .map(|(_, val)| val.as_f64())
        .next()
        .unwrap_or(1.0);
    // Heuristic mirror of PN: would the outcome have turned positive had the
    // treatment been different?
    let ps = if observed_outcome <= 0.0 && cf_outcome > 0.0 {
        1.0 // Sufficient
    } else if (observed_outcome - cf_outcome).abs() < 1e-6 {
        0.0 // Not sufficient
    } else {
        0.5 // Uncertain
    };
    Ok(ps)
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `CausalModelBuilder` does not appear to be used anywhere
    // in this module — confirm and drop it from the import if so.
    use crate::causal::model::{CausalModelBuilder, VariableType, Mechanism};
    /// X -> Y with the deterministic mechanism Y = 2*X + 1.
    fn create_simple_model() -> CausalModel {
        let mut model = CausalModel::with_name("Simple");
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        // Y = 2*X + 1
        model.add_structural_equation(y, &[x], Mechanism::new(|p| {
            Value::Continuous(2.0 * p[0].as_f64() + 1.0)
        })).unwrap();
        model
    }
    /// Mediation chain X -> M -> Y plus direct edge X -> Y:
    /// M = X and Y = M + 0.5*X.
    fn create_mediation_model() -> CausalModel {
        let mut model = CausalModel::with_name("Mediation");
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("M", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        // M = X
        model.add_structural_equation(m, &[x], Mechanism::new(|p| {
            p[0].clone()
        })).unwrap();
        // Y = M + 0.5*X
        model.add_structural_equation(y, &[m, x], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() + 0.5 * p[1].as_f64())
        })).unwrap();
        model
    }
    // Observation construction and membership queries.
    #[test]
    fn test_observation() {
        let mut obs = Observation::new(&[("X", Value::Continuous(5.0))]);
        obs.observe("Y", Value::Continuous(11.0));
        assert!(obs.is_observed("X"));
        assert!(obs.is_observed("Y"));
        assert!(!obs.is_observed("Z"));
    }
    // Deterministic counterfactual on the simple linear model.
    //
    // NOTE(review): this test calls `counterfactual` with 3 arguments and
    // treats the result as a map of values, while `probability_of_necessity`
    // / `probability_of_sufficiency` above call it with 5 arguments and treat
    // the result as a single `Value`. Rust has no overloading, so one of the
    // two call shapes cannot compile — confirm the canonical signature of
    // `counterfactual` and reconcile.
    #[test]
    fn test_counterfactual_simple() {
        let model = create_simple_model();
        let x_id = model.get_variable_id("X").unwrap();
        // Observation: X=3, Y=7 (since Y = 2*3 + 1)
        let observation = Observation::new(&[
            ("X", Value::Continuous(3.0)),
            ("Y", Value::Continuous(7.0)),
        ]);
        // Counterfactual: What would Y have been if X had been 5?
        let intervention = Intervention::new(x_id, Value::Continuous(5.0));
        let result = counterfactual(&model, &observation, &intervention).unwrap();
        // Y should be 2*5 + 1 = 11
        let y_id = model.get_variable_id("Y").unwrap();
        let y_value = result.get(y_id).unwrap().as_f64();
        assert!((y_value - 11.0).abs() < 1e-10);
    }
    // ATE via the ID-based `causal_effect` convenience (implicit 1-vs-0 contrast).
    #[test]
    fn test_causal_effect() {
        let model = create_simple_model();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        // ATE = E[Y|do(X=1)] - E[Y|do(X=0)]
        // = (2*1 + 1) - (2*0 + 1) = 3 - 1 = 2
        let ate = causal_effect(&model, x, y).unwrap();
        assert!((ate - 2.0).abs() < 1e-10);
    }
    // ATE via the name-based API, checking the result's metadata too.
    #[test]
    fn test_average_treatment_effect() {
        let model = create_simple_model();
        let ate_result = average_treatment_effect(
            &model,
            "X", "Y",
            Value::Continuous(1.0),
            Value::Continuous(0.0),
        ).unwrap();
        assert_eq!(ate_result.treatment, "X");
        assert_eq!(ate_result.outcome, "Y");
        assert!((ate_result.ate - 2.0).abs() < 1e-10);
    }
    // Total effect decomposes into NDE + NIE on the mediation model.
    #[test]
    fn test_mediation_effects() {
        let model = create_mediation_model();
        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        let treat = Value::Continuous(1.0);
        let ctrl = Value::Continuous(0.0);
        // Total effect should be:
        // E[Y|do(X=1)] - E[Y|do(X=0)]
        // = (M(1) + 0.5*1) - (M(0) + 0.5*0)
        // = (1 + 0.5) - (0 + 0) = 1.5
        let total = causal_effect_at_values(&model, x, y, treat.clone(), ctrl.clone()).unwrap();
        assert!((total - 1.5).abs() < 1e-10);
        // NDE should be the direct effect = 0.5 (coefficient of X in Y equation)
        let nde = natural_direct_effect(&model, x, m, y, treat.clone(), ctrl.clone()).unwrap();
        assert!((nde - 0.5).abs() < 1e-10);
        // NIE should be the indirect effect = 1.0 (coefficient of M in Y, times effect of X on M)
        let nie = natural_indirect_effect(&model, x, m, y, treat, ctrl).unwrap();
        assert!((nie - 1.0).abs() < 1e-10);
        // NDE + NIE should equal total effect
        assert!((nde + nie - total).abs() < 1e-10);
    }
    // Structured query API: intervene do(X=10) given an observed X=3.
    #[test]
    fn test_counterfactual_query() {
        let model = create_simple_model();
        let query = CounterfactualQuery::new(
            "Y",
            vec![("X", Value::Continuous(10.0))],
            Observation::new(&[("X", Value::Continuous(3.0))]),
        );
        let result = counterfactual_query(&model, &query).unwrap();
        // Y = 2*10 + 1 = 21
        let y_id = model.get_variable_id("Y").unwrap();
        assert!((result.distribution.mean(y_id) - 21.0).abs() < 1e-10);
    }
    // A point-mass distribution reports its values as means with probability 1.
    #[test]
    fn test_distribution() {
        let mut values = HashMap::new();
        let x_id = VariableId(0);
        let y_id = VariableId(1);
        values.insert(x_id, Value::Continuous(5.0));
        values.insert(y_id, Value::Continuous(10.0));
        let dist = CounterfactualDistribution::point_mass(values);
        assert_eq!(dist.mean(x_id), 5.0);
        assert_eq!(dist.mean(y_id), 10.0);
        assert_eq!(dist.probability, 1.0);
    }
    // ITE for a specific unit: contrast of two counterfactual outcomes.
    #[test]
    fn test_individual_treatment_effect() {
        let model = create_simple_model();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        // Unit-specific observation
        let unit_obs = Observation::new(&[
            ("X", Value::Continuous(3.0)),
            ("Y", Value::Continuous(7.0)),
        ]);
        let ite = individual_treatment_effect(
            &model,
            x, y,
            &unit_obs,
            Value::Continuous(5.0),
            Value::Continuous(3.0),
        ).unwrap();
        // ITE = Y(X=5) - Y(X=3) = 11 - 7 = 4
        assert!((ite - 4.0).abs() < 1e-10);
    }
}

View File

@@ -0,0 +1,920 @@
//! Do-Calculus Implementation
//!
//! This module implements Pearl's do-calculus, a complete set of inference rules
//! for computing causal effects from observational data when possible.
//!
//! ## The Three Rules of Do-Calculus
//!
//! Given a causal DAG G, the following rules hold:
//!
//! **Rule 1 (Insertion/deletion of observations):**
//! P(y | do(x), z, w) = P(y | do(x), w) if (Y ⊥ Z | X, W)_{G_{\overline{X}}}
//!
//! **Rule 2 (Action/observation exchange):**
//! P(y | do(x), do(z), w) = P(y | do(x), z, w) if (Y ⊥ Z | X, W)_{G_{\overline{X}\underline{Z}}}
//!
//! **Rule 3 (Insertion/deletion of actions):**
//! P(y | do(x), do(z), w) = P(y | do(x), w) if (Y ⊥ Z | X, W)_{G_{\overline{X}\overline{Z(W)}}}
//!
//! where:
//! - G_{\overline{X}} is G with incoming edges to X deleted
//! - G_{\underline{Z}} is G with outgoing edges from Z deleted
//! - Z(W) is Z without ancestors of W in G_{\overline{X}}
//!
//! ## References
//!
//! - Pearl (1995): "Causal diagrams for empirical research"
//! - Shpitser & Pearl (2006): "Identification of Joint Interventional Distributions"
use std::collections::{HashMap, HashSet};
use thiserror::Error;
use super::model::{CausalModel, VariableId};
use super::graph::{DirectedGraph, DAGValidationError};
/// Error types for do-calculus operations
///
/// Returned by the [`DoCalculus`] identification and estimation entry points.
#[derive(Debug, Clone, Error)]
pub enum IdentificationError {
    /// Query is not identifiable
    #[error("Query is not identifiable: {0}")]
    NotIdentifiable(String),
    /// Invalid query specification
    #[error("Invalid query: {0}")]
    InvalidQuery(String),
    /// Graph manipulation error
    /// (converted automatically from [`DAGValidationError`] via `#[from]`)
    #[error("Graph error: {0}")]
    GraphError(#[from] DAGValidationError),
    /// Variable not found
    #[error("Variable not found: {0}")]
    VariableNotFound(String),
}
/// The three rules of do-calculus
///
/// Each variant corresponds to one of Pearl's inference rules; `name` and
/// `description` provide human-readable summaries for reporting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Rule {
    /// Rule 1: Insertion/deletion of observations
    Rule1,
    /// Rule 2: Action/observation exchange
    Rule2,
    /// Rule 3: Insertion/deletion of actions
    Rule3,
}
impl Rule {
    /// Short human-readable name of the rule.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Rule1 => "Insertion/deletion of observations",
            Self::Rule2 => "Action/observation exchange",
            Self::Rule3 => "Insertion/deletion of actions",
        }
    }
    /// One-sentence summary of when the rule may be applied.
    pub fn description(&self) -> &'static str {
        match self {
            Self::Rule1 => "Allows adding/removing observations that are d-separated from Y given do(X)",
            Self::Rule2 => "Allows exchanging do(Z) with Z under d-separation conditions",
            Self::Rule3 => "Allows removing interventions that have no effect on Y",
        }
    }
}
/// Result of an identification attempt (enum for pattern matching)
#[derive(Debug, Clone)]
pub enum Identification {
    /// Effect is identifiable
    Identified(IdentificationResult),
    /// Effect is not identifiable
    NotIdentified(String),
}
impl Identification {
    /// Returns `true` when the causal effect was successfully identified.
    pub fn is_identified(&self) -> bool {
        match self {
            Identification::Identified(_) => true,
            Identification::NotIdentified(_) => false,
        }
    }
    /// Borrow the detailed result, or `None` when the effect was not identified.
    pub fn result(&self) -> Option<&IdentificationResult> {
        if let Identification::Identified(r) = self {
            Some(r)
        } else {
            None
        }
    }
}
/// Detailed result of successful identification
///
/// Payload of [`Identification::Identified`].
#[derive(Debug, Clone)]
pub struct IdentificationResult {
    /// The sequence of rules applied
    pub rules_applied: Vec<RuleApplication>,
    /// The final expression
    pub expression: String,
    /// Adjustment set (if using backdoor criterion)
    pub adjustment_set: Option<Vec<String>>,
    /// Front-door set (if using front-door criterion)
    pub front_door_set: Option<Vec<String>>,
}
/// Legacy result format for compatibility
///
/// Flat variant of [`Identification`]: an `identifiable` flag plus optional
/// fields instead of an enum. Produced by `DoCalculus::identify_by_name`.
#[derive(Debug, Clone)]
pub struct IdentificationLegacy {
    /// Whether the query is identifiable
    pub identifiable: bool,
    /// The sequence of rules applied
    pub rules_applied: Vec<RuleApplication>,
    /// The final expression (if identifiable)
    pub expression: Option<String>,
    /// Adjustment set (if using backdoor criterion)
    pub adjustment_set: Option<Vec<String>>,
    /// Front-door set (if using front-door criterion)
    pub front_door_set: Option<Vec<String>>,
}
/// Application of a do-calculus rule
///
/// One audit-trail entry recording a single rewrite step performed during
/// identification.
#[derive(Debug, Clone)]
pub struct RuleApplication {
    /// Which rule was applied
    pub rule: Rule,
    /// Variables involved
    pub variables: Vec<String>,
    /// Before expression
    pub before: String,
    /// After expression
    pub after: String,
    /// Graph modification used
    pub graph_modification: String,
}
/// Do-Calculus engine for causal identification
///
/// Borrows the model it queries; graph modifications are performed on clones
/// of the model's graph, so the model itself is never mutated.
pub struct DoCalculus<'a> {
    model: &'a CausalModel,
}
impl<'a> DoCalculus<'a> {
    /// Create a new do-calculus engine
    ///
    /// The engine only borrows `model`; it performs no mutation on it.
    pub fn new(model: &'a CausalModel) -> Self {
        Self { model }
    }
    /// Identify P(Y | do(X)) - simplified API for single outcome and treatment set
    ///
    /// Returns Identification enum for pattern matching
    ///
    /// Strategy: on the *first* treatment variable that is confounded with
    /// the outcome, try (1) a backdoor adjustment set, then (2) — for a
    /// single treatment only — a front-door mediator; if neither applies the
    /// query is not identifiable. If no treatment is confounded with the
    /// outcome, the effect is identified directly as P(Y | X).
    pub fn identify(&self, outcome: VariableId, treatment_set: &HashSet<VariableId>) -> Identification {
        // Check if model has latent confounding affecting treatment-outcome
        for &t in treatment_set {
            if !self.model.is_unconfounded(t, outcome) {
                // There is latent confounding
                // Check if backdoor criterion can be satisfied
                let treatment_vec: Vec<_> = treatment_set.iter().copied().collect();
                if let Some(adjustment) = self.find_backdoor_adjustment(&treatment_vec, &[outcome]) {
                    // Report the adjustment set by variable name.
                    let adjustment_names: Vec<String> = adjustment.iter()
                        .filter_map(|id| self.model.get_variable_name(id))
                        .collect();
                    return Identification::Identified(IdentificationResult {
                        rules_applied: vec![RuleApplication {
                            rule: Rule::Rule2,
                            variables: vec![format!("{:?}", outcome)],
                            before: format!("P(Y | do(X))"),
                            after: format!("Backdoor adjustment via {:?}", adjustment_names),
                            graph_modification: "Backdoor criterion".to_string(),
                        }],
                        expression: format!("Σ P(Y | X, Z) P(Z)"),
                        adjustment_set: Some(adjustment_names),
                        front_door_set: None,
                    });
                }
                // Check front-door
                // (only defined here for a single treatment variable)
                if treatment_set.len() == 1 {
                    let treatment = *treatment_set.iter().next().unwrap();
                    if let Some(mediators) = self.find_front_door_set(treatment, outcome) {
                        let mediator_names: Vec<String> = mediators.iter()
                            .filter_map(|id| self.model.get_variable_name(id))
                            .collect();
                        return Identification::Identified(IdentificationResult {
                            rules_applied: vec![RuleApplication {
                                rule: Rule::Rule2,
                                variables: vec![format!("{:?}", outcome)],
                                before: format!("P(Y | do(X))"),
                                after: format!("Front-door via {:?}", mediator_names),
                                graph_modification: "Front-door criterion".to_string(),
                            }],
                            expression: format!("Front-door formula"),
                            adjustment_set: None,
                            front_door_set: Some(mediator_names),
                        });
                    }
                }
                return Identification::NotIdentified(
                    "Effect not identifiable due to latent confounding".to_string()
                );
            }
        }
        // No latent confounding - directly identifiable
        Identification::Identified(IdentificationResult {
            rules_applied: vec![RuleApplication {
                rule: Rule::Rule3,
                variables: vec![format!("{:?}", outcome)],
                before: format!("P(Y | do(X))"),
                after: format!("P(Y | X)"),
                graph_modification: "Direct identification".to_string(),
            }],
            expression: format!("P(Y | X)"),
            adjustment_set: Some(vec![]),
            front_door_set: None,
        })
    }
    /// Identify using string names (legacy API)
    ///
    /// Tries, in order: backdoor adjustment, front-door adjustment (single
    /// treatment/outcome only), and direct identification. A result with
    /// `identifiable == false` and no rules applied means no supported
    /// strategy matched.
    ///
    /// # Errors
    /// [`IdentificationError::VariableNotFound`] when any treatment or
    /// outcome name is not present in the model.
    pub fn identify_by_name(
        &self,
        treatment: &[&str],
        outcome: &[&str],
    ) -> Result<IdentificationLegacy, IdentificationError> {
        // Convert names to IDs
        let treatment_ids: Result<Vec<_>, _> = treatment.iter()
            .map(|&name| {
                self.model.get_variable_id(name)
                    .ok_or_else(|| IdentificationError::VariableNotFound(name.to_string()))
            })
            .collect();
        let treatment_ids = treatment_ids?;
        let outcome_ids: Result<Vec<_>, _> = outcome.iter()
            .map(|&name| {
                self.model.get_variable_id(name)
                    .ok_or_else(|| IdentificationError::VariableNotFound(name.to_string()))
            })
            .collect();
        let outcome_ids = outcome_ids?;
        // Try different identification strategies
        let mut rules_applied = Vec::new();
        // Strategy 1: Check backdoor criterion
        if let Some(adjustment) = self.find_backdoor_adjustment(&treatment_ids, &outcome_ids) {
            let adjustment_names: Vec<String> = adjustment.iter()
                .filter_map(|id| self.model.get_variable_name(id))
                .collect();
            rules_applied.push(RuleApplication {
                rule: Rule::Rule2,
                variables: treatment.iter().map(|s| s.to_string()).collect(),
                before: format!("P({} | do({}))",
                    outcome.join(", "), treatment.join(", ")),
                after: format!("Σ_{{{}}} P({} | {}, {}) P({})",
                    adjustment_names.join(", "),
                    outcome.join(", "),
                    treatment.join(", "),
                    adjustment_names.join(", "),
                    adjustment_names.join(", ")),
                graph_modification: "Backdoor criterion satisfied".to_string(),
            });
            // Adjustment formula: Σ_z P(y | x, z) P(z)
            return Ok(IdentificationLegacy {
                identifiable: true,
                rules_applied,
                expression: Some(format!(
                    "Σ_{{{}}} P({} | {}, {}) P({})",
                    adjustment_names.join(", "),
                    outcome.join(", "),
                    treatment.join(", "),
                    adjustment_names.join(", "),
                    adjustment_names.join(", ")
                )),
                adjustment_set: Some(adjustment_names),
                front_door_set: None,
            });
        }
        // Strategy 2: Check front-door criterion
        // (only defined for a single treatment and a single outcome)
        if treatment_ids.len() == 1 && outcome_ids.len() == 1 {
            if let Some(mediators) = self.find_front_door_set(treatment_ids[0], outcome_ids[0]) {
                let mediator_names: Vec<String> = mediators.iter()
                    .filter_map(|id| self.model.get_variable_name(id))
                    .collect();
                rules_applied.push(RuleApplication {
                    rule: Rule::Rule2,
                    variables: vec![treatment[0].to_string()],
                    before: format!("P({} | do({}))", outcome[0], treatment[0]),
                    after: format!("Front-door adjustment via {}", mediator_names.join(", ")),
                    graph_modification: "Front-door criterion satisfied".to_string(),
                });
                // Front-door formula: Σ_m P(m | x) Σ_x' P(y | m, x') P(x')
                return Ok(IdentificationLegacy {
                    identifiable: true,
                    rules_applied,
                    expression: Some(format!(
                        "Σ_{{{}}} P({} | {}) Σ_{{{}}} P({} | {}, {}) P({})",
                        mediator_names.join(", "),
                        mediator_names.join(", "),
                        treatment[0],
                        treatment[0],
                        outcome[0],
                        mediator_names.join(", "),
                        treatment[0],
                        treatment[0]
                    )),
                    adjustment_set: None,
                    front_door_set: Some(mediator_names),
                });
            }
        }
        // Strategy 3: Check direct identifiability (no confounders)
        if self.is_directly_identifiable(&treatment_ids, &outcome_ids) {
            rules_applied.push(RuleApplication {
                rule: Rule::Rule3,
                variables: treatment.iter().map(|s| s.to_string()).collect(),
                before: format!("P({} | do({}))", outcome.join(", "), treatment.join(", ")),
                after: format!("P({} | {})", outcome.join(", "), treatment.join(", ")),
                graph_modification: "No confounders; direct identification".to_string(),
            });
            return Ok(IdentificationLegacy {
                identifiable: true,
                rules_applied,
                expression: Some(format!("P({} | {})", outcome.join(", "), treatment.join(", "))),
                adjustment_set: Some(vec![]),
                front_door_set: None,
            });
        }
        // Not identifiable
        Ok(IdentificationLegacy {
            identifiable: false,
            rules_applied: vec![],
            expression: None,
            adjustment_set: None,
            front_door_set: None,
        })
    }
/// Check Rule 1: Can we add/remove observation Z?
///
/// P(y | do(x), z, w) = P(y | do(x), w) if (Y ⊥ Z | X, W) in G_{\overline{X}}
pub fn can_apply_rule1(
&self,
y: &[VariableId],
x: &[VariableId],
z: &[VariableId],
w: &[VariableId],
) -> bool {
// Build G_{\overline{X}}: delete incoming edges to X
let modified_graph = self.graph_delete_incoming(x);
// Check d-separation of Y and Z given X W in modified graph
let y_set: HashSet<_> = y.iter().map(|id| id.0).collect();
let z_set: HashSet<_> = z.iter().map(|id| id.0).collect();
let mut conditioning: HashSet<_> = x.iter().map(|id| id.0).collect();
conditioning.extend(w.iter().map(|id| id.0));
modified_graph.d_separated(&y_set, &z_set, &conditioning)
}
/// Check Rule 2: Can we exchange do(Z) with observation Z?
///
/// P(y | do(x), do(z), w) = P(y | do(x), z, w) if (Y ⊥ Z | X, W) in G_{\overline{X}\underline{Z}}
pub fn can_apply_rule2(
&self,
y: &[VariableId],
x: &[VariableId],
z: &[VariableId],
w: &[VariableId],
) -> bool {
// Build G_{\overline{X}\underline{Z}}: delete incoming edges to X and outgoing from Z
let modified_graph = self.graph_delete_incoming_and_outgoing(x, z);
// Check d-separation
let y_set: HashSet<_> = y.iter().map(|id| id.0).collect();
let z_set: HashSet<_> = z.iter().map(|id| id.0).collect();
let mut conditioning: HashSet<_> = x.iter().map(|id| id.0).collect();
conditioning.extend(w.iter().map(|id| id.0));
modified_graph.d_separated(&y_set, &z_set, &conditioning)
}
    /// Check Rule 3: Can we remove do(Z)?
    ///
    /// P(y | do(x), do(z), w) = P(y | do(x), w) if (Y ⊥ Z | X, W) in G_{\overline{X}\overline{Z(W)}}
    ///
    /// Z(W) is the subset of Z whose members are not ancestors of any W-node
    /// in G_{\overline{X}}; edges into X and into Z(W) are removed before the
    /// d-separation test.
    pub fn can_apply_rule3(
        &self,
        y: &[VariableId],
        x: &[VariableId],
        z: &[VariableId],
        w: &[VariableId],
    ) -> bool {
        // Build G_{\overline{X}\overline{Z(W)}}: more complex modification
        // Z(W) = Z \ ancestors of W in G_{\overline{X}}
        // First get G_{\overline{X}}
        let g_no_x = self.graph_delete_incoming(x);
        // Find ancestors of W in G_{\overline{X}}
        let w_ancestors: HashSet<_> = w.iter()
            .flat_map(|wv| g_no_x.ancestors(wv.0))
            .collect();
        // Z(W) = Z without W's ancestors
        let z_without_w_ancestors: Vec<_> = z.iter()
            .filter(|zv| !w_ancestors.contains(&zv.0))
            .copied()
            .collect();
        // Build G_{\overline{X}\overline{Z(W)}}
        // (concatenates X and Z(W) and strips all their incoming edges)
        let modified_graph = self.graph_delete_incoming_multiple(
            &[x, &z_without_w_ancestors].concat()
        );
        // Check d-separation
        let y_set: HashSet<_> = y.iter().map(|id| id.0).collect();
        let z_set: HashSet<_> = z.iter().map(|id| id.0).collect();
        let mut conditioning: HashSet<_> = x.iter().map(|id| id.0).collect();
        conditioning.extend(w.iter().map(|id| id.0));
        modified_graph.d_separated(&y_set, &z_set, &conditioning)
    }
/// Check Rule 1 with HashSet API (for test compatibility)
///
/// P(y | do(x), z) = P(y | do(x)) if (Y ⊥ Z | X) in G_{\overline{X}}
pub fn can_apply_rule1_sets(
&self,
y: &HashSet<VariableId>,
x: &HashSet<VariableId>,
z: &HashSet<VariableId>,
) -> bool {
let y_vec: Vec<_> = y.iter().copied().collect();
let x_vec: Vec<_> = x.iter().copied().collect();
let z_vec: Vec<_> = z.iter().copied().collect();
self.can_apply_rule1(&y_vec, &x_vec, &z_vec, &[])
}
/// Check Rule 2 with HashSet API (for test compatibility)
pub fn can_apply_rule2_sets(
&self,
y: &HashSet<VariableId>,
x: &HashSet<VariableId>,
z: &HashSet<VariableId>,
) -> bool {
let y_vec: Vec<_> = y.iter().copied().collect();
let x_vec: Vec<_> = x.iter().copied().collect();
let z_vec: Vec<_> = z.iter().copied().collect();
self.can_apply_rule2(&y_vec, &x_vec, &z_vec, &[])
}
/// Check Rule 3 with HashSet API (for test compatibility)
pub fn can_apply_rule3_sets(
&self,
y: &HashSet<VariableId>,
x: &HashSet<VariableId>,
z: &HashSet<VariableId>,
) -> bool {
let y_vec: Vec<_> = y.iter().copied().collect();
let x_vec: Vec<_> = x.iter().copied().collect();
let z_vec: Vec<_> = z.iter().copied().collect();
self.can_apply_rule3(&y_vec, &x_vec, &z_vec, &[])
}
    /// Find a valid backdoor adjustment set
    ///
    /// Candidates are all variables that are neither treatments, outcomes,
    /// nor descendants of any treatment (backdoor condition 1). Search order:
    /// the full candidate set, then singletons, then pairs — so the returned
    /// set is valid but not necessarily minimal. Returns `None` if no subset
    /// of size 0, 1, 2, or the full candidate set satisfies the criterion.
    fn find_backdoor_adjustment(
        &self,
        treatment: &[VariableId],
        outcome: &[VariableId],
    ) -> Option<Vec<VariableId>> {
        // Get all potential adjustment variables (not descendants of treatment)
        let treatment_descendants: HashSet<_> = treatment.iter()
            .flat_map(|t| self.model.graph().descendants(t.0))
            .collect();
        let potential_adjusters: Vec<_> = self.model.variables()
            .filter(|v| !treatment.contains(&v.id))
            .filter(|v| !outcome.contains(&v.id))
            .filter(|v| !treatment_descendants.contains(&v.id.0))
            .map(|v| v.id)
            .collect();
        // Try the full set first
        if self.satisfies_backdoor_criterion(treatment, outcome, &potential_adjusters) {
            return Some(potential_adjusters);
        }
        // Try minimal subsets
        // NOTE(review): this branch is redundant — when `potential_adjusters`
        // is empty, the full-set check above already tested the empty set.
        if potential_adjusters.is_empty() {
            if self.satisfies_backdoor_criterion(treatment, outcome, &[]) {
                return Some(vec![]);
            }
        }
        // Try single-variable adjustments
        for &adjuster in &potential_adjusters {
            if self.satisfies_backdoor_criterion(treatment, outcome, &[adjuster]) {
                return Some(vec![adjuster]);
            }
        }
        // Try pairs
        for i in 0..potential_adjusters.len() {
            for j in (i + 1)..potential_adjusters.len() {
                let pair = vec![potential_adjusters[i], potential_adjusters[j]];
                if self.satisfies_backdoor_criterion(treatment, outcome, &pair) {
                    return Some(pair);
                }
            }
        }
        None
    }
/// Check if a set satisfies the backdoor criterion
fn satisfies_backdoor_criterion(
&self,
treatment: &[VariableId],
outcome: &[VariableId],
adjustment: &[VariableId],
) -> bool {
// Backdoor criterion:
// 1. No node in Z is a descendant of X
// 2. Z blocks all backdoor paths from X to Y
// Condition 1: already ensured by caller
// Condition 2: Check d-separation in G_{\overline{X}}
let g_no_x = self.graph_delete_incoming(treatment);
for &x in treatment {
for &y in outcome {
let x_set: HashSet<_> = [x.0].into_iter().collect();
let y_set: HashSet<_> = [y.0].into_iter().collect();
let z_set: HashSet<_> = adjustment.iter().map(|v| v.0).collect();
if !g_no_x.d_separated(&x_set, &y_set, &z_set) {
return false;
}
}
}
true
}
    /// Find a front-door adjustment set (for single treatment/outcome)
    ///
    /// Searches for a single mediator M lying on a causal path from X to Y.
    ///
    /// NOTE(review): the full front-door criterion is only partially checked.
    /// Condition 1 (M intercepts *all* directed X→Y paths) is not verified —
    /// M is merely required to lie on some X→Y path. The condition-2 check
    /// below only skips candidates that are d-separated from X given the
    /// empty set (no connection at all); it does not specifically test for
    /// unblocked backdoor paths from X to M. Only condition 3 (backdoor
    /// paths from M to Y blocked by X) is tested as stated, via d-separation
    /// in G with M's outgoing edges removed. Confirm against the formal
    /// criterion before relying on this in adversarial graphs.
    fn find_front_door_set(
        &self,
        treatment: VariableId,
        outcome: VariableId,
    ) -> Option<Vec<VariableId>> {
        // Front-door criterion:
        // 1. M intercepts all directed paths from X to Y
        // 2. There is no unblocked backdoor path from X to M
        // 3. All backdoor paths from M to Y are blocked by X
        let descendants_of_x = self.model.graph().descendants(treatment.0);
        let ancestors_of_y = self.model.graph().ancestors(outcome.0);
        // M must be on path from X to Y
        let candidates: Vec<_> = descendants_of_x.intersection(&ancestors_of_y)
            .filter(|&&m| m != treatment.0 && m != outcome.0)
            .map(|&m| VariableId(m))
            .collect();
        if candidates.is_empty() {
            return None;
        }
        // Check each candidate
        for &m in &candidates {
            // Check condition 2: no backdoor from X to M
            let x_set: HashSet<_> = [treatment.0].into_iter().collect();
            let m_set: HashSet<_> = [m.0].into_iter().collect();
            if self.model.graph().d_separated(&x_set, &m_set, &HashSet::new()) {
                continue; // X and M are d-separated (no path at all)
            }
            // Check condition 3: backdoor from M to Y blocked by X
            // (tested in G_{\underline{M}}, i.e. with M's outgoing edges cut)
            let y_set: HashSet<_> = [outcome.0].into_iter().collect();
            let g_underline_m = self.graph_delete_outgoing(&[m]);
            if g_underline_m.d_separated(&m_set, &y_set, &x_set) {
                return Some(vec![m]);
            }
        }
        None
    }
/// Check if effect is directly identifiable (no confounders)
fn is_directly_identifiable(
&self,
treatment: &[VariableId],
outcome: &[VariableId],
) -> bool {
// Check if there are any backdoor paths
for &x in treatment {
for &y in outcome {
let x_ancestors = self.model.graph().ancestors(x.0);
let y_ancestors = self.model.graph().ancestors(y.0);
// If they share common ancestors, there might be confounding
if !x_ancestors.is_disjoint(&y_ancestors) {
return false;
}
}
}
true
}
// Graph manipulation helpers
/// Create G_{\overline{X}}: delete incoming edges to X
fn graph_delete_incoming(&self, x: &[VariableId]) -> DirectedGraph {
let mut modified = self.model.graph().clone();
for &xi in x {
if let Some(parents) = self.model.parents(&xi) {
for parent in parents {
modified.remove_edge(parent.0, xi.0).ok();
}
}
}
modified
}
/// Create G_{\underline{Z}}: delete outgoing edges from Z
fn graph_delete_outgoing(&self, z: &[VariableId]) -> DirectedGraph {
let mut modified = self.model.graph().clone();
for &zi in z {
if let Some(children) = self.model.children(&zi) {
for child in children {
modified.remove_edge(zi.0, child.0).ok();
}
}
}
modified
}
/// Create G_{\overline{X}\underline{Z}}
fn graph_delete_incoming_and_outgoing(
&self,
x: &[VariableId],
z: &[VariableId],
) -> DirectedGraph {
let mut modified = self.graph_delete_incoming(x);
for &zi in z {
if let Some(children) = self.model.children(&zi) {
for child in children {
modified.remove_edge(zi.0, child.0).ok();
}
}
}
modified
}
    /// Delete incoming edges to multiple variable sets
    ///
    /// Thin alias of `graph_delete_incoming`; kept for readability at call
    /// sites that pass a concatenation of several variable sets.
    fn graph_delete_incoming_multiple(&self, vars: &[VariableId]) -> DirectedGraph {
        self.graph_delete_incoming(vars)
    }
/// Compute the causal effect using the identified formula
pub fn compute_effect(
&self,
identification: &Identification,
data: &HashMap<String, Vec<f64>>,
) -> Result<f64, IdentificationError> {
if !identification.identifiable {
return Err(IdentificationError::NotIdentifiable(
"Cannot compute unidentifiable effect".to_string()
));
}
// Simple implementation: use adjustment formula if available
if let Some(ref adjustment_names) = identification.adjustment_set {
if adjustment_names.is_empty() {
// Direct effect - compute from data
return self.compute_direct_effect(data);
}
// Adjusted effect
return self.compute_adjusted_effect(data, adjustment_names);
}
// Front-door adjustment
if identification.front_door_set.is_some() {
return self.compute_frontdoor_effect(data, identification);
}
Err(IdentificationError::NotIdentifiable(
"No valid estimation strategy".to_string()
))
}
fn compute_direct_effect(&self, data: &HashMap<String, Vec<f64>>) -> Result<f64, IdentificationError> {
// Simple regression coefficient as effect estimate
// This is a placeholder - real implementation would use proper estimation
Ok(0.0)
}
    /// Estimate the effect with backdoor adjustment over the given covariates.
    ///
    /// Placeholder implementation: always returns 0.0.
    fn compute_adjusted_effect(
        &self,
        _data: &HashMap<String, Vec<f64>>,
        _adjustment: &[String],
    ) -> Result<f64, IdentificationError> {
        // Adjusted regression or inverse probability weighting
        // Placeholder implementation
        Ok(0.0)
    }
    /// Estimate the effect with the front-door formula.
    ///
    /// Placeholder implementation: always returns 0.0.
    fn compute_frontdoor_effect(
        &self,
        _data: &HashMap<String, Vec<f64>>,
        _identification: &Identification,
    ) -> Result<f64, IdentificationError> {
        // Front-door formula computation
        // Placeholder implementation
        Ok(0.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::model::{CausalModel, VariableType, Mechanism, Value};
    /// X -> Y confounded by an observed common cause U (U -> X, U -> Y).
    fn create_confounded_model() -> CausalModel {
        // X -> Y with unobserved confounder U
        // U -> X, U -> Y
        let mut model = CausalModel::with_name("Confounded");
        model.add_variable("U", VariableType::Continuous).unwrap();
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let u = model.get_variable_id("U").unwrap();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        model.add_edge(u, x).unwrap();
        model.add_edge(u, y).unwrap();
        model.add_edge(x, y).unwrap();
        model.add_structural_equation(x, &[u], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() + 1.0)
        })).unwrap();
        model.add_structural_equation(y, &[x, u], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 2.0 + p[1].as_f64())
        })).unwrap();
        model
    }
    /// Front-door layout: X -> M -> Y with U confounding X and Y.
    fn create_frontdoor_model() -> CausalModel {
        // X -> M -> Y with X-Y confounded
        let mut model = CausalModel::with_name("FrontDoor");
        model.add_variable("U", VariableType::Continuous).unwrap(); // Unobserved confounder
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("M", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let u = model.get_variable_id("U").unwrap();
        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        model.add_edge(u, x).unwrap();
        model.add_edge(u, y).unwrap();
        model.add_edge(x, m).unwrap();
        model.add_edge(m, y).unwrap();
        model.add_structural_equation(x, &[u], Mechanism::new(|p| {
            p[0].clone()
        })).unwrap();
        model.add_structural_equation(m, &[x], Mechanism::new(|p| {
            p[0].clone()
        })).unwrap();
        model.add_structural_equation(y, &[m, u], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() + p[1].as_f64())
        })).unwrap();
        model
    }
    /// Plain X -> Y chain with no confounding at all.
    fn create_unconfounded_model() -> CausalModel {
        // Simple X -> Y without confounding
        let mut model = CausalModel::with_name("Unconfounded");
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        model.add_edge(x, y).unwrap();
        model.add_structural_equation(y, &[x], Mechanism::new(|p| {
            Value::Continuous(2.0 * p[0].as_f64())
        })).unwrap();
        model
    }
    // These tests exercise the string-keyed `identify_by_name` API, which
    // returns `Result<IdentificationLegacy, _>` with the `identifiable` flag.
    // (They previously called `identify`, whose signature is
    // `(VariableId, &HashSet<VariableId>) -> Identification` and therefore
    // supports neither `&["X"]` arguments, `.unwrap()`, nor `.identifiable`.)
    #[test]
    fn test_unconfounded_identifiable() {
        let model = create_unconfounded_model();
        let calc = DoCalculus::new(&model);
        let result = calc.identify_by_name(&["X"], &["Y"]).unwrap();
        assert!(result.identifiable);
    }
    #[test]
    fn test_confounded_with_adjustment() {
        let model = create_confounded_model();
        let calc = DoCalculus::new(&model);
        // With U observed, we can adjust for it
        let result = calc.identify_by_name(&["X"], &["Y"]).unwrap();
        // Should be identifiable by adjusting for U
        assert!(result.identifiable);
        assert!(result.adjustment_set.is_some());
    }
    #[test]
    fn test_frontdoor_identification() {
        let model = create_frontdoor_model();
        let calc = DoCalculus::new(&model);
        let result = calc.identify_by_name(&["X"], &["Y"]).unwrap();
        // Should be identifiable via front-door criterion
        assert!(result.identifiable);
    }
    #[test]
    fn test_rule1_application() {
        let model = create_unconfounded_model();
        let calc = DoCalculus::new(&model);
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        // In a simple X -> Y model, Y is not independent of X
        let can_remove_x = calc.can_apply_rule1(&[y], &[], &[x], &[]);
        assert!(!can_remove_x); // Cannot remove X observation
    }
    #[test]
    fn test_rule2_application() {
        let model = create_unconfounded_model();
        let calc = DoCalculus::new(&model);
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        // Can we exchange do(X) with observation X?
        let can_exchange = calc.can_apply_rule2(&[y], &[], &[x], &[]);
        // In simple X -> Y, deleting outgoing from X blocks the path
        assert!(can_exchange);
    }
    #[test]
    fn test_rule_descriptions() {
        assert!(Rule::Rule1.name().contains("observation"));
        assert!(Rule::Rule2.name().contains("exchange"));
        assert!(Rule::Rule3.name().contains("deletion"));
    }
    #[test]
    fn test_identification_result() {
        let model = create_unconfounded_model();
        let calc = DoCalculus::new(&model);
        let result = calc.identify_by_name(&["X"], &["Y"]).unwrap();
        assert!(result.identifiable);
        assert!(result.expression.is_some());
    }
}

View File

@@ -0,0 +1,846 @@
//! Directed Acyclic Graph (DAG) implementation for causal models
//!
//! This module provides a validated DAG structure that ensures:
//! - No cycles (acyclicity constraint)
//! - Efficient topological ordering
//! - Parent/child relationship queries
//! - D-separation computations
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;
/// Error types for DAG operations
///
/// Variants carry the raw `u32` node ids involved, which callers can map
/// back to human-readable names via `DirectedGraph::get_label`.
#[derive(Debug, Clone, Error)]
pub enum DAGValidationError {
    /// Cycle detected in graph; the payload lists nodes on the offending walk
    #[error("Cycle detected involving nodes: {0:?}")]
    CycleDetected(Vec<u32>),
    /// Node not found
    #[error("Node {0} not found in graph")]
    NodeNotFound(u32),
    /// Edge already exists (duplicate edges are rejected)
    #[error("Edge from {0} to {1} already exists")]
    EdgeExists(u32, u32),
    /// Self-loop detected (edges from a node to itself are never allowed)
    #[error("Self-loop detected at node {0}")]
    SelfLoop(u32),
    /// Invalid operation on empty graph (e.g. requesting a topological order)
    #[error("Graph is empty")]
    EmptyGraph,
}
/// A directed acyclic graph for causal relationships
///
/// Invariants: `children` and `parents` always hold mirrored views of the
/// same edge set, and every structural mutation clears `cached_topo_order`.
#[derive(Debug, Clone)]
pub struct DirectedGraph {
    /// Number of nodes (maintained incrementally by `add_node`)
    num_nodes: usize,
    /// Adjacency list: node -> children
    children: HashMap<u32, HashSet<u32>>,
    /// Reverse adjacency: node -> parents (mirror of `children`)
    parents: HashMap<u32, HashSet<u32>>,
    /// Node labels (optional)
    labels: HashMap<u32, String>,
    /// Cached topological order (invalidated on structural changes)
    cached_topo_order: Option<Vec<u32>>,
}
impl DirectedGraph {
/// Create a new empty directed graph
pub fn new() -> Self {
Self {
num_nodes: 0,
children: HashMap::new(),
parents: HashMap::new(),
labels: HashMap::new(),
cached_topo_order: None,
}
}
/// Create a graph with pre-allocated capacity
pub fn with_capacity(nodes: usize) -> Self {
Self {
num_nodes: 0,
children: HashMap::with_capacity(nodes),
parents: HashMap::with_capacity(nodes),
labels: HashMap::with_capacity(nodes),
cached_topo_order: None,
}
}
/// Add a node to the graph (a no-op if it already exists); returns the id.
pub fn add_node(&mut self, id: u32) -> u32 {
    use std::collections::hash_map::Entry;
    if let Entry::Vacant(slot) = self.children.entry(id) {
        slot.insert(HashSet::new());
        self.parents.insert(id, HashSet::new());
        self.num_nodes += 1;
        // Structural change: any cached ordering is stale.
        self.cached_topo_order = None;
    }
    id
}
/// Add a node and attach a human-readable label to it; returns the id.
pub fn add_node_with_label(&mut self, id: u32, label: &str) -> u32 {
    self.labels.insert(id, String::from(label));
    self.add_node(id)
}
/// Add a directed edge from `from` to `to`
///
/// Both endpoints are created if they do not yet exist. On success the
/// cached topological order is invalidated.
///
/// # Errors
/// `SelfLoop` if `from == to`, `EdgeExists` for a duplicate edge, and
/// `CycleDetected` if the edge would make the graph cyclic (the edge is
/// rolled back in that case).
pub fn add_edge(&mut self, from: u32, to: u32) -> Result<(), DAGValidationError> {
    // Check for self-loop
    if from == to {
        return Err(DAGValidationError::SelfLoop(from));
    }
    // Ensure nodes exist
    self.add_node(from);
    self.add_node(to);
    // Check if edge already exists
    if self.children.get(&from).map(|c| c.contains(&to)).unwrap_or(false) {
        return Err(DAGValidationError::EdgeExists(from, to));
    }
    // Temporarily add edge (to both mirrored maps) and check for cycles
    self.children.get_mut(&from).unwrap().insert(to);
    self.parents.get_mut(&to).unwrap().insert(from);
    if self.has_cycle() {
        // Remove edge if cycle detected, restoring the previous state
        self.children.get_mut(&from).unwrap().remove(&to);
        self.parents.get_mut(&to).unwrap().remove(&from);
        return Err(DAGValidationError::CycleDetected(self.find_cycle_nodes(from, to)));
    }
    self.cached_topo_order = None;
    Ok(())
}
/// Remove the directed edge `from -> to`.
///
/// Errors with `NodeNotFound` when either endpoint is missing; removing a
/// non-existent edge between existing nodes is a silent no-op.
pub fn remove_edge(&mut self, from: u32, to: u32) -> Result<(), DAGValidationError> {
    for node in [from, to] {
        if !self.children.contains_key(&node) {
            return Err(DAGValidationError::NodeNotFound(node));
        }
    }
    self.children.get_mut(&from).unwrap().remove(&to);
    self.parents.get_mut(&to).unwrap().remove(&from);
    // Structural change: any cached ordering is stale.
    self.cached_topo_order = None;
    Ok(())
}
/// Check whether the graph currently contains a directed cycle (DFS from
/// every node, sharing the visited set across starts).
fn has_cycle(&self) -> bool {
    let mut visited = HashSet::new();
    let mut rec_stack = HashSet::new();
    self.children
        .keys()
        .any(|&node| self.has_cycle_util(node, &mut visited, &mut rec_stack))
}
/// DFS helper for cycle detection: returns true iff a back edge (an edge to
/// a node on the current DFS path) is reachable from `node`.
fn has_cycle_util(
    &self,
    node: u32,
    visited: &mut HashSet<u32>,
    rec_stack: &mut HashSet<u32>,
) -> bool {
    if rec_stack.contains(&node) {
        // `node` is on the active path: back edge => cycle.
        return true;
    }
    if !visited.insert(node) {
        // Already fully explored from an earlier start node.
        return false;
    }
    rec_stack.insert(node);
    let found = self.children.get(&node).map_or(false, |kids| {
        kids.iter()
            .any(|&child| self.has_cycle_util(child, visited, rec_stack))
    });
    rec_stack.remove(&node);
    found
}
/// Find nodes involved in a potential cycle
///
/// Called by `add_edge` after tentatively inserting `from -> to` and
/// detecting a cycle: searches for a directed path `to -> ... -> from` and,
/// if found, returns it with `to` appended — i.e. the closed walk
/// `[to, ..., from, to]`. Returns an empty Vec when no such path exists.
fn find_cycle_nodes(&self, from: u32, to: u32) -> Vec<u32> {
    // Find path from `to` back to `from`
    let mut path = Vec::new();
    let mut visited = HashSet::new();
    // Nested fn (captures nothing) so recursion avoids borrow conflicts.
    fn dfs(
        graph: &DirectedGraph,
        current: u32,
        target: u32,
        visited: &mut HashSet<u32>,
        path: &mut Vec<u32>,
    ) -> bool {
        if current == target {
            path.push(current);
            return true;
        }
        if visited.contains(&current) {
            return false;
        }
        visited.insert(current);
        path.push(current);
        if let Some(children) = graph.children.get(&current) {
            for &child in children {
                if dfs(graph, child, target, visited, path) {
                    return true;
                }
            }
        }
        // Dead end: backtrack.
        path.pop();
        false
    }
    if dfs(self, to, from, &mut visited, &mut path) {
        // Close the walk by repeating the starting node.
        path.push(to);
    }
    path
}
/// Get children of a node
///
/// Returns `None` when `node` is not in the graph; an existing node with no
/// outgoing edges yields `Some` of an empty set.
pub fn children_of(&self, node: u32) -> Option<&HashSet<u32>> {
    self.children.get(&node)
}
/// Get parents of a node
///
/// Returns `None` when `node` is not in the graph; an existing node with no
/// incoming edges yields `Some` of an empty set.
pub fn parents_of(&self, node: u32) -> Option<&HashSet<u32>> {
    self.parents.get(&node)
}
/// True if `node` has been added to the graph.
pub fn contains_node(&self, node: u32) -> bool {
    self.children.get(&node).is_some()
}
/// True if the directed edge `from -> to` is present.
pub fn contains_edge(&self, from: u32, to: u32) -> bool {
    match self.children.get(&from) {
        Some(kids) => kids.contains(&to),
        None => false,
    }
}
/// Get number of nodes
///
/// Maintained incrementally by `add_node`; this impl offers no node
/// removal, so it equals the number of distinct ids ever added.
pub fn node_count(&self) -> usize {
    self.num_nodes
}
/// Total number of directed edges in the graph.
pub fn edge_count(&self) -> usize {
    self.children
        .values()
        .fold(0, |acc, kids| acc + kids.len())
}
/// Iterator over all node ids (in arbitrary hash-map order).
pub fn nodes(&self) -> impl Iterator<Item = u32> + '_ {
    self.children.keys().map(|&id| id)
}
/// Iterator over all edges as `(from, to)` pairs, in arbitrary order.
pub fn edges(&self) -> impl Iterator<Item = (u32, u32)> + '_ {
    self.children
        .iter()
        .flat_map(|(&src, dsts)| dsts.iter().map(move |&dst| (src, dst)))
}
/// Human-readable label for `node`, if one was registered.
pub fn get_label(&self, node: u32) -> Option<&str> {
    self.labels.get(&node).map(String::as_str)
}
/// Reverse lookup: id of the node carrying `label` (linear scan; arbitrary
/// choice if several nodes share the label).
pub fn find_node_by_label(&self, label: &str) -> Option<u32> {
    for (&id, l) in &self.labels {
        if l == label {
            return Some(id);
        }
    }
    None
}
/// Compute topological ordering using Kahn's algorithm
///
/// Runs in O(V + E). The result is cached on `self` and reused until the
/// next structural change (`add_node`/`add_edge`/`remove_edge` clear it).
///
/// # Errors
/// - `EmptyGraph` when no nodes have been added.
/// - `CycleDetected` when some nodes never reach in-degree 0 (defensive:
///   `add_edge` should already prevent cycles from being created).
pub fn topological_order(&mut self) -> Result<TopologicalOrder, DAGValidationError> {
    if self.num_nodes == 0 {
        return Err(DAGValidationError::EmptyGraph);
    }
    // Use cached order if available
    if let Some(ref order) = self.cached_topo_order {
        return Ok(TopologicalOrder { order: order.clone() });
    }
    // Compute in-degrees
    let mut in_degree: HashMap<u32, usize> = HashMap::new();
    for &node in self.children.keys() {
        in_degree.insert(node, 0);
    }
    for children in self.children.values() {
        for &child in children {
            *in_degree.entry(child).or_insert(0) += 1;
        }
    }
    // Initialize queue with nodes having in-degree 0 (the roots)
    let mut queue: VecDeque<u32> = in_degree
        .iter()
        .filter(|&(_, &deg)| deg == 0)
        .map(|(&node, _)| node)
        .collect();
    let mut order = Vec::with_capacity(self.num_nodes);
    while let Some(node) = queue.pop_front() {
        order.push(node);
        // Emitting `node` decrements each child's remaining in-degree;
        // children that reach 0 become ready to emit.
        if let Some(children) = self.children.get(&node) {
            for &child in children {
                if let Some(deg) = in_degree.get_mut(&child) {
                    *deg -= 1;
                    if *deg == 0 {
                        queue.push_back(child);
                    }
                }
            }
        }
    }
    // If not every node was emitted, the leftovers with positive in-degree
    // lie on (or downstream of) a cycle.
    if order.len() != self.num_nodes {
        return Err(DAGValidationError::CycleDetected(
            in_degree.iter()
                .filter(|&(_, &deg)| deg > 0)
                .map(|(&node, _)| node)
                .collect()
        ));
    }
    self.cached_topo_order = Some(order.clone());
    Ok(TopologicalOrder { order })
}
/// Get ancestors of a node: every node with a directed path into `node`
/// (BFS over reverse edges; `node` itself is excluded).
pub fn ancestors(&self, node: u32) -> HashSet<u32> {
    let mut seen = HashSet::new();
    // Seed the frontier with the node's direct parents.
    let mut frontier: VecDeque<u32> = self
        .parents
        .get(&node)
        .map(|ps| ps.iter().copied().collect())
        .unwrap_or_default();
    while let Some(current) = frontier.pop_front() {
        if seen.insert(current) {
            if let Some(ps) = self.parents.get(&current) {
                frontier.extend(ps.iter().copied().filter(|p| !seen.contains(p)));
            }
        }
    }
    seen
}
/// Get descendants of a node: every node reachable from `node` along
/// directed edges (BFS; `node` itself is excluded).
pub fn descendants(&self, node: u32) -> HashSet<u32> {
    let mut seen = HashSet::new();
    // Seed the frontier with the node's direct children.
    let mut frontier: VecDeque<u32> = self
        .children
        .get(&node)
        .map(|cs| cs.iter().copied().collect())
        .unwrap_or_default();
    while let Some(current) = frontier.pop_front() {
        if seen.insert(current) {
            if let Some(cs) = self.children.get(&current) {
                frontier.extend(cs.iter().copied().filter(|c| !seen.contains(c)));
            }
        }
    }
    seen
}
/// Check d-separation between X and Y given conditioning set Z
///
/// X and Y are d-separated by Z iff every path between them is blocked by
/// Z, i.e. no node of Y is Bayes-Ball-reachable from X given evidence Z.
pub fn d_separated(
    &self,
    x: &HashSet<u32>,
    y: &HashSet<u32>,
    z: &HashSet<u32>,
) -> bool {
    let reachable = self.bayes_ball_reachable(x, z);
    !y.iter().any(|node| reachable.contains(node))
}
/// Bayes-Ball algorithm (Shachter) to find d-connected nodes
///
/// Returns the set of nodes reachable from `source` along paths that are
/// active given the conditioning set `evidence`. A queue entry
/// `(node, true)` means the ball arrived at `node` from one of its children
/// (travelling up); `(node, false)` means it arrived from a parent
/// (travelling down). The traversal rules are:
/// - from child, unobserved: pass to parents and children
/// - from child, observed: blocked
/// - from parent, unobserved: pass to children
/// - from parent, observed: bounce back up to parents (opened collider)
fn bayes_ball_reachable(&self, source: &HashSet<u32>, evidence: &HashSet<u32>) -> HashSet<u32> {
    let mut visited_up: HashSet<u32> = HashSet::new();
    let mut visited_down: HashSet<u32> = HashSet::new();
    let mut reachable: HashSet<u32> = HashSet::new();
    // Queue entries: (node, direction_is_up)
    let mut queue: VecDeque<(u32, bool)> = VecDeque::new();
    // Start the ball at each source node travelling in both directions.
    for &node in source {
        queue.push_back((node, true));
        queue.push_back((node, false));
    }
    while let Some((node, going_up)) = queue.pop_front() {
        // Process each (node, direction) pair at most once.
        let seen = if going_up { &mut visited_up } else { &mut visited_down };
        if !seen.insert(node) {
            continue;
        }
        let is_evidence = evidence.contains(&node);
        if going_up && !is_evidence {
            // From child at an unobserved node: the path is active as a
            // chain (continue up to parents) and as a fork (turn around to
            // the other children).
            reachable.insert(node);
            if let Some(parents) = self.parents.get(&node) {
                for &parent in parents {
                    queue.push_back((parent, true));
                }
            }
            if let Some(children) = self.children.get(&node) {
                for &child in children {
                    queue.push_back((child, false));
                }
            }
        } else if going_up && is_evidence {
            // From child at an observed node: blocked. Conditioning closes
            // both the chain (parent -> node -> child) and the fork
            // (child <- node -> other child) through this node.
            //
            // BUGFIX: the previous implementation passed the ball on to the
            // node's parents here, which falsely reported d-connection
            // through an observed chain/fork node (e.g. B <- Z <- A -> C
            // given {Z} was reported as connecting B and C).
        } else if !going_up && !is_evidence {
            // From parent at an unobserved node: continue down to children
            // only (an unobserved collider stays closed).
            reachable.insert(node);
            if let Some(children) = self.children.get(&node) {
                for &child in children {
                    queue.push_back((child, false));
                }
            }
        } else {
            // From parent at an observed node: conditioning on a collider
            // opens the path through its parents, so bounce back up.
            reachable.insert(node);
            if let Some(parents) = self.parents.get(&node) {
                for &parent in parents {
                    queue.push_back((parent, true));
                }
            }
        }
    }
    reachable
}
/// All simple directed paths `from ~> to` of at most `max_length` nodes.
pub fn find_all_paths(&self, from: u32, to: u32, max_length: usize) -> Vec<Vec<u32>> {
    let mut found = Vec::new();
    let mut prefix = Vec::new();
    prefix.push(from);
    self.find_paths_dfs(from, to, &mut prefix, &mut found, max_length);
    found
}
/// DFS helper for `find_all_paths`: extends `path` (a simple path ending at
/// `current`) toward `target`, recording completed paths in `all_paths`.
fn find_paths_dfs(
    &self,
    current: u32,
    target: u32,
    path: &mut Vec<u32>,
    all_paths: &mut Vec<Vec<u32>>,
    max_length: usize,
) {
    if current == target {
        all_paths.push(path.clone());
        return;
    }
    if path.len() >= max_length {
        return; // depth cap reached
    }
    if let Some(successors) = self.children.get(&current) {
        for &next in successors {
            if path.contains(&next) {
                continue; // keep paths simple: no repeated nodes
            }
            path.push(next);
            self.find_paths_dfs(next, target, path, all_paths, max_length);
            path.pop();
        }
    }
}
/// Undirected skeleton of the graph: each edge as a canonical
/// (smaller-id, larger-id) pair, so both orientations map to one entry.
pub fn skeleton(&self) -> HashSet<(u32, u32)> {
    self.edges()
        .map(|(a, b)| if a < b { (a, b) } else { (b, a) })
        .collect()
}
/// Find all v-structures (colliders) in the graph
///
/// A v-structure is a triple (A, B, C) with A -> B <- C where A and C are
/// not adjacent in the skeleton. These triples identify the Markov
/// equivalence class of a DAG.
pub fn v_structures(&self) -> Vec<(u32, u32, u32)> {
    let skeleton = self.skeleton();
    let mut found = Vec::new();
    for (&collider, parent_set) in &self.parents {
        // Fewer than two parents cannot form a collider; the loops below
        // simply produce no pairs in that case.
        let ps: Vec<u32> = parent_set.iter().copied().collect();
        for (idx, &a) in ps.iter().enumerate() {
            for &c in &ps[idx + 1..] {
                let key = if a < c { (a, c) } else { (c, a) };
                if !skeleton.contains(&key) {
                    found.push((a, collider, c));
                }
            }
        }
    }
    found
}
}
impl Default for DirectedGraph {
fn default() -> Self {
Self::new()
}
}
/// Topological ordering of nodes in a DAG
///
/// Produced by `DirectedGraph::topological_order`; wraps the node ids in
/// dependency order (every parent appears before its children).
#[derive(Debug, Clone)]
pub struct TopologicalOrder {
    // Node ids in dependency order.
    order: Vec<u32>,
}
impl TopologicalOrder {
    /// View the ordering as a slice of node ids.
    pub fn as_slice(&self) -> &[u32] {
        self.order.as_slice()
    }
    /// Index of `node` within the ordering, if present (linear scan).
    pub fn position(&self, node: u32) -> Option<usize> {
        self.order.iter().position(|&n| n == node)
    }
    /// True when `a` precedes `b`; false if either node is absent.
    pub fn comes_before(&self, a: u32, b: u32) -> bool {
        if let (Some(pos_a), Some(pos_b)) = (self.position(a), self.position(b)) {
            pos_a < pos_b
        } else {
            false
        }
    }
    /// Iterate over node ids in topological order.
    pub fn iter(&self) -> impl Iterator<Item = &u32> {
        self.order.iter()
    }
    /// Number of nodes in the ordering.
    pub fn len(&self) -> usize {
        self.order.len()
    }
    /// True when the ordering contains no nodes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl IntoIterator for TopologicalOrder {
type Item = u32;
type IntoIter = std::vec::IntoIter<u32>;
fn into_iter(self) -> Self::IntoIter {
self.order.into_iter()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_nodes_and_edges() {
        let mut g = DirectedGraph::new();
        g.add_node(0);
        g.add_node(1);
        g.add_edge(0, 1).unwrap();
        assert!(g.contains_node(0));
        assert!(g.contains_node(1));
        // Edges are directed: only 0 -> 1 was added.
        assert!(g.contains_edge(0, 1));
        assert!(!g.contains_edge(1, 0));
    }

    #[test]
    fn test_cycle_detection() {
        let mut g = DirectedGraph::new();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        // Closing the loop 2 -> 0 must be rejected.
        assert!(matches!(
            g.add_edge(2, 0),
            Err(DAGValidationError::CycleDetected(_))
        ));
    }

    #[test]
    fn test_self_loop_detection() {
        let mut g = DirectedGraph::new();
        assert!(matches!(
            g.add_edge(0, 0),
            Err(DAGValidationError::SelfLoop(0))
        ));
    }

    #[test]
    fn test_topological_order() {
        // Diamond: 0 -> {1, 2} -> 3.
        let mut g = DirectedGraph::new();
        for (a, b) in [(0, 1), (0, 2), (1, 3), (2, 3)] {
            g.add_edge(a, b).unwrap();
        }
        let order = g.topological_order().unwrap();
        assert_eq!(order.len(), 4);
        for (a, b) in [(0, 1), (0, 2), (1, 3), (2, 3)] {
            assert!(order.comes_before(a, b));
        }
    }

    #[test]
    fn test_ancestors_and_descendants() {
        // Chain: 0 -> 1 -> 2 -> 3.
        let mut g = DirectedGraph::new();
        for (a, b) in [(0, 1), (1, 2), (2, 3)] {
            g.add_edge(a, b).unwrap();
        }
        let anc = g.ancestors(3);
        for n in [0, 1, 2] {
            assert!(anc.contains(&n));
        }
        assert!(!anc.contains(&3));
        let desc = g.descendants(0);
        for n in [1, 2, 3] {
            assert!(desc.contains(&n));
        }
        assert!(!desc.contains(&0));
    }

    #[test]
    fn test_d_separation_chain() {
        // Chain X -> Z -> Y: conditioning on the mediator Z blocks the path.
        let mut g = DirectedGraph::new();
        g.add_node_with_label(0, "X");
        g.add_node_with_label(1, "Z");
        g.add_node_with_label(2, "Y");
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let x = HashSet::from([0]);
        let y = HashSet::from([2]);
        let z = HashSet::from([1]);
        // Open without conditioning, blocked given Z.
        assert!(!g.d_separated(&x, &y, &HashSet::new()));
        assert!(g.d_separated(&x, &y, &z));
    }

    #[test]
    fn test_d_separation_fork() {
        // Fork X <- Z -> Y: conditioning on the common cause Z blocks it.
        let mut g = DirectedGraph::new();
        g.add_node_with_label(0, "X");
        g.add_node_with_label(1, "Z");
        g.add_node_with_label(2, "Y");
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let x = HashSet::from([0]);
        let y = HashSet::from([2]);
        let z = HashSet::from([1]);
        assert!(!g.d_separated(&x, &y, &HashSet::new()));
        assert!(g.d_separated(&x, &y, &z));
    }

    #[test]
    fn test_d_separation_collider() {
        // Collider X -> Z <- Y: blocked a priori, opened by conditioning on
        // Z ("explaining away").
        let mut g = DirectedGraph::new();
        g.add_node_with_label(0, "X");
        g.add_node_with_label(1, "Z");
        g.add_node_with_label(2, "Y");
        g.add_edge(0, 1).unwrap();
        g.add_edge(2, 1).unwrap();
        let x = HashSet::from([0]);
        let y = HashSet::from([2]);
        let z = HashSet::from([1]);
        assert!(g.d_separated(&x, &y, &HashSet::new()));
        assert!(!g.d_separated(&x, &y, &z));
    }

    #[test]
    fn test_v_structures() {
        // X -> Z <- Y with X, Y non-adjacent is the single v-structure.
        let mut g = DirectedGraph::new();
        g.add_edge(0, 2).unwrap(); // X -> Z
        g.add_edge(1, 2).unwrap(); // Y -> Z
        let vs = g.v_structures();
        assert_eq!(vs.len(), 1);
        let (a, b, c) = vs[0];
        assert_eq!(b, 2); // Z is the collider
        assert!(a == 0 || a == 1);
        assert!(c == 0 || c == 1);
        assert_ne!(a, c);
    }

    #[test]
    fn test_labels() {
        let mut g = DirectedGraph::new();
        g.add_node_with_label(0, "Age");
        g.add_node_with_label(1, "Income");
        g.add_edge(0, 1).unwrap();
        assert_eq!(g.get_label(0), Some("Age"));
        assert_eq!(g.get_label(1), Some("Income"));
        assert_eq!(g.find_node_by_label("Age"), Some(0));
        assert_eq!(g.find_node_by_label("Unknown"), None);
    }

    #[test]
    fn test_find_all_paths() {
        // Diamond 0 -> {1, 2} -> 3 has exactly two simple paths 0 ~> 3.
        let mut g = DirectedGraph::new();
        for (a, b) in [(0, 1), (0, 2), (1, 3), (2, 3)] {
            g.add_edge(a, b).unwrap();
        }
        let paths = g.find_all_paths(0, 3, 10);
        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&vec![0, 1, 3]));
        assert!(paths.contains(&vec![0, 2, 3]));
    }

    #[test]
    fn test_skeleton() {
        let mut g = DirectedGraph::new();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 0).ok(); // rejected: would close a cycle
        g.add_edge(0, 2).unwrap(); // the acyclic orientation succeeds
        let skel = g.skeleton();
        assert_eq!(skel.len(), 3);
        for pair in [(0, 1), (0, 2), (1, 2)] {
            assert!(skel.contains(&pair));
        }
    }

    #[test]
    fn test_remove_edge() {
        let mut g = DirectedGraph::new();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        assert!(g.contains_edge(0, 1));
        g.remove_edge(0, 1).unwrap();
        assert!(!g.contains_edge(0, 1));
    }
}

View File

@@ -0,0 +1,271 @@
//! Causal Abstraction Networks for Prime-Radiant
//!
//! This module implements causal reasoning primitives based on structural causal models
//! (SCMs), causal abstraction theory, and do-calculus. Key capabilities:
//!
//! - **CausalModel**: Directed acyclic graph (DAG) of causal relationships with
//! structural equations defining each variable as a function of its parents
//! - **CausalAbstraction**: Maps between low-level and high-level causal models,
//! preserving interventional semantics
//! - **CausalCoherenceChecker**: Validates causal consistency of beliefs and detects
//! spurious correlations
//! - **Counterfactual Reasoning**: Computes counterfactual queries and causal effects
//!
//! ## Architecture
//!
//! The causal module integrates with Prime-Radiant's sheaf-theoretic framework:
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │ Prime-Radiant Core │
//! ├─────────────────────────────────────────────────────────────────┤
//! │ SheafGraph ◄──── causal_coherence_energy ────► CausalModel │
//! │ │ │ │
//! │ ▼ ▼ │
//! │ CoherenceEnergy CausalAbstraction │
//! │ │ │ │
//! │ └───────────► Combined Coherence ◄──────────────┘ │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! ## Usage
//!
//! ```rust,ignore
//! use prime_radiant::causal::{CausalModel, Intervention, counterfactual};
//!
//! // Build a causal model
//! let mut model = CausalModel::new();
//! model.add_variable("X", VariableType::Continuous);
//! model.add_variable("Y", VariableType::Continuous);
//! model.add_structural_equation("Y", &["X"], |parents| {
//! Value::Continuous(2.0 * parents[0].as_f64() + 0.5)
//! });
//!
//! // Perform intervention do(X = 1.0)
//! let intervention = Intervention::new("X", Value::Continuous(1.0));
//! let result = model.intervene(&intervention);
//!
//! // Compute counterfactual
//! let observation = Observation::new(&[("Y", Value::Continuous(3.0))]);
//! let cf = counterfactual(&model, &observation, &intervention);
//! ```
pub mod model;
pub mod abstraction;
pub mod coherence;
pub mod counterfactual;
pub mod graph;
pub mod do_calculus;
// Re-exports
pub use model::{
CausalModel, StructuralEquation, Variable, VariableId, VariableType, Value,
Mechanism, CausalModelError, MutilatedModel, Distribution, Observation,
IntervenedModel, CausalModelBuilder, Intervention,
};
pub use abstraction::{
CausalAbstraction, AbstractionMap, AbstractionError, ConsistencyResult,
};
pub use coherence::{
CausalCoherenceChecker, CausalConsistency, SpuriousCorrelation, Belief,
CausalQuery, CausalAnswer, CoherenceEnergy,
};
pub use counterfactual::{
counterfactual, causal_effect,
CounterfactualQuery, AverageTreatmentEffect,
};
pub use graph::{DirectedGraph, TopologicalOrder, DAGValidationError};
pub use do_calculus::{DoCalculus, Rule, Identification, IdentificationError};
/// Integration with Prime-Radiant's sheaf-theoretic framework
pub mod integration {
use super::*;
/// Placeholder for SheafGraph from the main Prime-Radiant module
pub struct SheafGraph {
    /// Node names; the index in this Vec is the node id used by `edges`.
    pub nodes: Vec<String>,
    /// Edges as `(i, j)` index pairs into `nodes`.
    pub edges: Vec<(usize, usize)>,
    /// Per-node section data (one numeric vector per node, indexed like `nodes`).
    pub sections: Vec<Vec<f64>>,
}
/// Compute combined coherence energy from structural and causal consistency
///
/// Bridges Prime-Radiant's sheaf cohomology with causal structure:
/// - sheaf/structural energy measures local-to-global coherence of beliefs;
/// - causal energy measures alignment with the causal DAG's ordering;
/// - intervention energy checks that causal links leave a statistical trace
///   in the sheaf sections.
///
/// The combined energy is minimized (and `is_coherent` set) when all three
/// components are (near) zero.
pub fn causal_coherence_energy(
    sheaf_graph: &SheafGraph,
    causal_model: &CausalModel,
) -> CoherenceEnergy {
    let structural = compute_structural_energy(sheaf_graph);
    let causal = compute_causal_energy(sheaf_graph, causal_model);
    let interventional = compute_intervention_energy(sheaf_graph, causal_model);
    let total = structural + causal + interventional;
    CoherenceEnergy {
        total,
        structural_component: structural,
        causal_component: causal,
        intervention_component: interventional,
        is_coherent: total < 1e-6,
    }
}
/// Sum of squared section differences along sheaf edges — a simplified
/// restriction-map disagreement. Sections of unequal length are compared
/// over their common prefix; edges pointing past `sections` are skipped.
fn compute_structural_energy(sheaf: &SheafGraph) -> f64 {
    let mut energy = 0.0;
    for &(i, j) in &sheaf.edges {
        if i >= sheaf.sections.len() || j >= sheaf.sections.len() {
            continue; // edge endpoint has no section data
        }
        let (si, sj) = (&sheaf.sections[i], &sheaf.sections[j]);
        // zip truncates to the shorter section, matching a common-prefix L2.
        energy += si
            .iter()
            .zip(sj.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>();
    }
    energy
}
/// Penalty for sheaf edges whose endpoints contradict the model's causal
/// ordering: adds 1.0 per violating edge. If the model has no valid
/// topological order, the energy is 0.
fn compute_causal_energy(sheaf: &SheafGraph, model: &CausalModel) -> f64 {
    // Check that sheaf structure respects causal ordering
    let mut energy = 0.0;
    if let Ok(topo_order) = model.topological_order() {
        // Map each variable to its rank in the causal order.
        // NOTE(review): assumes the items yielded by `topological_order()`
        // compare equal to the sheaf's node names — confirm against
        // CausalModel's API.
        let order_map: std::collections::HashMap<_, _> = topo_order
            .iter()
            .enumerate()
            .map(|(i, v)| (v.clone(), i))
            .collect();
        // Penalize edges that violate causal ordering
        for (i, j) in &sheaf.edges {
            if *i < sheaf.nodes.len() && *j < sheaf.nodes.len() {
                let node_i = &sheaf.nodes[*i];
                let node_j = &sheaf.nodes[*j];
                if let (Some(&order_i), Some(&order_j)) =
                    (order_map.get(node_i), order_map.get(node_j))
                {
                    // Edge from j to i should have order_j < order_i
                    // NOTE(review): this treats a sheaf edge (i, j) as
                    // running j -> i. The test in this file builds edges as
                    // (source, target), under which a correctly oriented
                    // edge X -> Y is penalized here — confirm the intended
                    // edge convention before relying on this component.
                    if order_j > order_i {
                        energy += 1.0;
                    }
                }
            }
        }
    }
    energy
}
/// Penalty for causal parent->child links that leave no statistical trace
/// in the sheaf sections.
///
/// For every model edge whose endpoints both appear as sheaf nodes with
/// non-empty section data, adds 0.1 when the correlation between the two
/// sections is below 0.01 in absolute value. The 0.01 threshold and 0.1
/// penalty are heuristic constants.
fn compute_intervention_energy(sheaf: &SheafGraph, model: &CausalModel) -> f64 {
    // Verify that interventions propagate correctly through sheaf
    let mut energy = 0.0;
    // For each potential intervention point, check consistency
    for (i, node) in sheaf.nodes.iter().enumerate() {
        if let Some(var_id) = model.get_variable_id(node) {
            if let Some(children) = model.children(&var_id) {
                for child in children {
                    if let Some(child_name) = model.get_variable_name(&child) {
                        // Find corresponding sheaf node
                        if let Some(j) = sheaf.nodes.iter().position(|n| n == &child_name) {
                            // Check if intervention effect is consistent
                            if i < sheaf.sections.len() && j < sheaf.sections.len() {
                                let parent_section = &sheaf.sections[i];
                                let child_section = &sheaf.sections[j];
                                // Simple check: child should be influenced by parent
                                if !parent_section.is_empty() && !child_section.is_empty() {
                                    // Correlation check (simplified): correlation is only
                                    // a proxy — it cannot distinguish causal influence
                                    // from coincidence or confounding.
                                    let correlation = compute_correlation(parent_section, child_section);
                                    if correlation.abs() < 0.01 {
                                        energy += 0.1; // Weak causal link penalty
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    energy
}
/// Pearson correlation of the overlapping prefix of `a` and `b`.
///
/// Returns 0.0 for empty input or when either series has (near-)zero
/// variance, so constant sections never divide by zero.
fn compute_correlation(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let mean_a = a[..n].iter().sum::<f64>() / n as f64;
    let mean_b = b[..n].iter().sum::<f64>() / n as f64;
    let (mut cov, mut var_a, mut var_b) = (0.0, 0.0, 0.0);
    for (&x, &y) in a[..n].iter().zip(&b[..n]) {
        let dx = x - mean_a;
        let dy = y - mean_b;
        cov += dx * dy;
        var_a += dx * dx;
        var_b += dy * dy;
    }
    let denom = (var_a * var_b).sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        cov / denom
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::integration::*;

    #[test]
    fn test_module_exports() {
        // Smoke-test that the re-exported core types are reachable.
        let _id = VariableId(0);
        let _val = Value::Continuous(1.0);
    }

    #[test]
    fn test_causal_coherence_energy() {
        // Two-node sheaf X -> Y with proportional sections, paired with the
        // matching causal model X -> Y.
        let sheaf = SheafGraph {
            nodes: vec!["X".to_string(), "Y".to_string()],
            edges: vec![(0, 1)],
            sections: vec![vec![1.0, 2.0], vec![2.0, 4.0]],
        };
        let mut model = CausalModel::new();
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();
        model.add_edge(x, y).unwrap();
        let energy = causal_coherence_energy(&sheaf, &model);
        // Both components are sums of non-negative penalties.
        assert!(energy.structural_component >= 0.0);
        assert!(energy.causal_component >= 0.0);
    }
}

File diff suppressed because it is too large Load Diff