Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
820
vendor/ruvector/examples/prime-radiant/src/causal/abstraction.rs
vendored
Normal file
820
vendor/ruvector/examples/prime-radiant/src/causal/abstraction.rs
vendored
Normal file
@@ -0,0 +1,820 @@
|
||||
//! Causal Abstraction Layer
|
||||
//!
|
||||
//! This module implements causal abstraction theory, which formalizes the
|
||||
//! relationship between detailed (low-level) and simplified (high-level)
|
||||
//! causal models. The key insight is that a high-level model is a valid
|
||||
//! abstraction if interventions on the low-level model can be "lifted" to
|
||||
//! corresponding interventions on the high-level model while preserving
|
||||
//! distributional semantics.
|
||||
//!
|
||||
//! ## Theory
|
||||
//!
|
||||
//! A causal abstraction consists of:
|
||||
//! - A low-level model M_L with variables V_L
|
||||
//! - A high-level model M_H with variables V_H
|
||||
//! - A surjective mapping τ: V_L → V_H
|
||||
//!
|
||||
//! The abstraction is **consistent** if for all interventions I on M_H:
|
||||
//! τ(M_L(τ^{-1}(I))) = M_H(I)
|
||||
//!
|
||||
//! ## References
|
||||
//!
|
||||
//! - Beckers & Halpern (2019): "Abstracting Causal Models"
|
||||
//! - Rubenstein et al. (2017): "Causal Consistency of Structural Equation Models"
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::model::{CausalModel, CausalModelError, Intervention, Value, VariableId, Distribution};
|
||||
use super::counterfactual::CounterfactualDistribution;
|
||||
|
||||
/// Error types for abstraction operations
#[derive(Debug, Clone, Error)]
pub enum AbstractionError {
    /// Abstraction map is not surjective: some high-level variable has no
    /// low-level preimage, so interventions on it could not be lifted.
    #[error("Abstraction map is not surjective: high-level variable {0:?} has no preimage")]
    NotSurjective(VariableId),

    /// Abstraction is not consistent under intervention: lifting the
    /// intervention to the low level and projecting back disagrees with the
    /// high-level model.
    #[error("Abstraction is not consistent: intervention {0:?} yields different results")]
    InconsistentIntervention(String),

    /// Invalid variable mapping: a mapped low-level variable does not exist
    /// in the low-level model.
    #[error("Invalid mapping: low-level variable {0:?} not in model")]
    InvalidMapping(VariableId),

    /// Models have incompatible structure (e.g. a required model is missing).
    #[error("Incompatible model structure: {0}")]
    IncompatibleStructure(String),

    /// Underlying model error, converted automatically via `#[from]`.
    #[error("Model error: {0}")]
    ModelError(#[from] CausalModelError),
}
|
||||
|
||||
/// Mapping from low-level to high-level variables
///
/// Maintains both directions of the surjective map τ: V_L → V_H plus a
/// per-high-level-variable rule for combining low-level values.
/// Intended invariant: `low_to_high` is the inverse index of `high_to_low`.
#[derive(Debug, Clone)]
pub struct AbstractionMap {
    /// Maps high-level variable to set of low-level variables (τ^{-1})
    high_to_low: HashMap<VariableId, HashSet<VariableId>>,

    /// Maps low-level variable to high-level variable (τ)
    low_to_high: HashMap<VariableId, VariableId>,

    /// Value aggregation functions (how to combine low-level values),
    /// keyed by high-level variable
    aggregators: HashMap<VariableId, Aggregator>,
}
|
||||
|
||||
/// How to aggregate low-level values into a high-level value
#[derive(Debug, Clone)]
pub enum Aggregator {
    /// Take first value (for 1-to-1 mappings)
    First,
    /// Sum of values
    Sum,
    /// Mean of values
    Mean,
    /// Max of values
    Max,
    /// Min of values
    Min,
    /// Majority vote (for discrete/binary values)
    Majority,
    /// Weighted combination; weights are positional, paired with values
    /// in iteration order
    Weighted(Vec<f64>),
    /// Custom function (represented as a string for debug only — the
    /// string is never executed; `apply` falls back to the mean)
    Custom(String),
}
|
||||
|
||||
impl Aggregator {
|
||||
/// Apply the aggregator to a set of values
|
||||
pub fn apply(&self, values: &[Value]) -> Value {
|
||||
if values.is_empty() {
|
||||
return Value::Missing;
|
||||
}
|
||||
|
||||
match self {
|
||||
Aggregator::First => values[0].clone(),
|
||||
|
||||
Aggregator::Sum => {
|
||||
let sum: f64 = values.iter().map(|v| v.as_f64()).sum();
|
||||
Value::Continuous(sum)
|
||||
}
|
||||
|
||||
Aggregator::Mean => {
|
||||
let sum: f64 = values.iter().map(|v| v.as_f64()).sum();
|
||||
Value::Continuous(sum / values.len() as f64)
|
||||
}
|
||||
|
||||
Aggregator::Max => {
|
||||
let max = values.iter()
|
||||
.map(|v| v.as_f64())
|
||||
.fold(f64::NEG_INFINITY, f64::max);
|
||||
Value::Continuous(max)
|
||||
}
|
||||
|
||||
Aggregator::Min => {
|
||||
let min = values.iter()
|
||||
.map(|v| v.as_f64())
|
||||
.fold(f64::INFINITY, f64::min);
|
||||
Value::Continuous(min)
|
||||
}
|
||||
|
||||
Aggregator::Majority => {
|
||||
let mut counts: HashMap<i64, usize> = HashMap::new();
|
||||
for v in values {
|
||||
let key = v.as_f64() as i64;
|
||||
*counts.entry(key).or_default() += 1;
|
||||
}
|
||||
let majority = counts.into_iter()
|
||||
.max_by_key(|(_, count)| *count)
|
||||
.map(|(val, _)| val)
|
||||
.unwrap_or(0);
|
||||
Value::Discrete(majority)
|
||||
}
|
||||
|
||||
Aggregator::Weighted(weights) => {
|
||||
let weighted_sum: f64 = values.iter()
|
||||
.zip(weights.iter())
|
||||
.map(|(v, w)| v.as_f64() * w)
|
||||
.sum();
|
||||
Value::Continuous(weighted_sum)
|
||||
}
|
||||
|
||||
Aggregator::Custom(_) => {
|
||||
// Default to mean for custom
|
||||
let sum: f64 = values.iter().map(|v| v.as_f64()).sum();
|
||||
Value::Continuous(sum / values.len() as f64)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AbstractionMap {
|
||||
/// Create a new empty abstraction map
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
high_to_low: HashMap::new(),
|
||||
low_to_high: HashMap::new(),
|
||||
aggregators: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a mapping from high-level variable to low-level variables
|
||||
pub fn add_mapping(
|
||||
&mut self,
|
||||
high: VariableId,
|
||||
low_vars: HashSet<VariableId>,
|
||||
aggregator: Aggregator,
|
||||
) {
|
||||
for &low in &low_vars {
|
||||
self.low_to_high.insert(low, high);
|
||||
}
|
||||
self.high_to_low.insert(high, low_vars);
|
||||
self.aggregators.insert(high, aggregator);
|
||||
}
|
||||
|
||||
/// Add a 1-to-1 mapping
|
||||
pub fn add_identity_mapping(&mut self, high: VariableId, low: VariableId) {
|
||||
let mut low_set = HashSet::new();
|
||||
low_set.insert(low);
|
||||
self.add_mapping(high, low_set, Aggregator::First);
|
||||
}
|
||||
|
||||
/// Get the high-level variable for a low-level variable
|
||||
pub fn lift_variable(&self, low: VariableId) -> Option<VariableId> {
|
||||
self.low_to_high.get(&low).copied()
|
||||
}
|
||||
|
||||
/// Get the low-level variables for a high-level variable
|
||||
pub fn project_variable(&self, high: VariableId) -> Option<&HashSet<VariableId>> {
|
||||
self.high_to_low.get(&high)
|
||||
}
|
||||
|
||||
/// Lift a value from low-level to high-level
|
||||
pub fn lift_value(&self, high: VariableId, low_values: &HashMap<VariableId, Value>) -> Value {
|
||||
let low_vars = match self.high_to_low.get(&high) {
|
||||
Some(vars) => vars,
|
||||
None => return Value::Missing,
|
||||
};
|
||||
|
||||
let values: Vec<Value> = low_vars.iter()
|
||||
.filter_map(|v| low_values.get(v).cloned())
|
||||
.collect();
|
||||
|
||||
let aggregator = self.aggregators.get(&high).unwrap_or(&Aggregator::First);
|
||||
aggregator.apply(&values)
|
||||
}
|
||||
|
||||
/// Check if the mapping is surjective (every high-level var has a preimage)
|
||||
pub fn is_surjective(&self, high_level: &CausalModel) -> bool {
|
||||
for var in high_level.variables() {
|
||||
if !self.high_to_low.contains_key(&var.id) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AbstractionMap {
    /// Equivalent to [`AbstractionMap::new`]: an empty mapping.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Result of consistency checking
#[derive(Debug, Clone)]
pub struct ConsistencyResult {
    /// Whether the abstraction is consistent (no violations found)
    pub is_consistent: bool,
    /// Violations found (if any)
    pub violations: Vec<ConsistencyViolation>,
    /// Number of interventions tested
    pub interventions_tested: usize,
    /// Maximum observed divergence across violations
    /// (0.0 when no violation was recorded)
    pub max_divergence: f64,
}
|
||||
|
||||
/// A violation of causal abstraction consistency
#[derive(Debug, Clone)]
pub struct ConsistencyViolation {
    /// Human-readable description of the intervention that caused the violation
    pub intervention: String,
    /// Expected high-level outcome, keyed by variable name
    pub expected: HashMap<String, f64>,
    /// Actual (projected from low-level) outcome, keyed by variable name
    pub actual: HashMap<String, f64>,
    /// Divergence measure between expected and actual
    /// (Euclidean distance over the high-level variables)
    pub divergence: f64,
}
|
||||
|
||||
/// Causal Abstraction between two causal models
///
/// Holds borrowed references to both models, so the abstraction cannot
/// outlive either model.
pub struct CausalAbstraction<'a> {
    /// The low-level (detailed) model
    pub low_level: &'a CausalModel,

    /// The high-level (abstract) model
    pub high_level: &'a CausalModel,

    /// The abstraction mapping τ between the two variable sets
    pub abstraction_map: AbstractionMap,

    /// Tolerance for numerical consistency checks (default: 1e-6)
    pub tolerance: f64,
}
|
||||
|
||||
impl<'a> CausalAbstraction<'a> {
|
||||
/// Create a new causal abstraction
|
||||
pub fn new(
|
||||
low_level: &'a CausalModel,
|
||||
high_level: &'a CausalModel,
|
||||
abstraction_map: AbstractionMap,
|
||||
) -> Result<Self, AbstractionError> {
|
||||
let abstraction = Self {
|
||||
low_level,
|
||||
high_level,
|
||||
abstraction_map,
|
||||
tolerance: 1e-6,
|
||||
};
|
||||
|
||||
abstraction.validate_structure()?;
|
||||
|
||||
Ok(abstraction)
|
||||
}
|
||||
|
||||
/// Set tolerance for numerical comparisons
|
||||
pub fn with_tolerance(mut self, tol: f64) -> Self {
|
||||
self.tolerance = tol;
|
||||
self
|
||||
}
|
||||
|
||||
/// Validate that the abstraction structure is valid
|
||||
fn validate_structure(&self) -> Result<(), AbstractionError> {
|
||||
// Check surjectivity
|
||||
for var in self.high_level.variables() {
|
||||
if self.abstraction_map.high_to_low.get(&var.id).is_none() {
|
||||
return Err(AbstractionError::NotSurjective(var.id));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all low-level variables in the map exist
|
||||
for low_vars in self.abstraction_map.high_to_low.values() {
|
||||
for &low_var in low_vars {
|
||||
if self.low_level.get_variable(&low_var).is_none() {
|
||||
return Err(AbstractionError::InvalidMapping(low_var));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if the abstraction is consistent under a set of interventions
|
||||
pub fn is_consistent(&self) -> bool {
|
||||
self.check_consistency().is_consistent
|
||||
}
|
||||
|
||||
/// Perform detailed consistency check
|
||||
pub fn check_consistency(&self) -> ConsistencyResult {
|
||||
let mut violations = Vec::new();
|
||||
let mut max_divergence = 0.0;
|
||||
let mut interventions_tested = 0;
|
||||
|
||||
// Test consistency for single-variable interventions on high-level model
|
||||
for high_var in self.high_level.variables() {
|
||||
// Test a few intervention values
|
||||
for intervention_value in [0.0, 1.0, -1.0, 0.5] {
|
||||
interventions_tested += 1;
|
||||
|
||||
let high_intervention = Intervention::new(
|
||||
high_var.id,
|
||||
Value::Continuous(intervention_value),
|
||||
);
|
||||
|
||||
// Check consistency for this intervention
|
||||
if let Some(violation) = self.check_single_intervention(&high_intervention) {
|
||||
max_divergence = max_divergence.max(violation.divergence);
|
||||
violations.push(violation);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ConsistencyResult {
|
||||
is_consistent: violations.is_empty(),
|
||||
violations,
|
||||
interventions_tested,
|
||||
max_divergence,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check consistency for a single intervention
|
||||
fn check_single_intervention(&self, high_intervention: &Intervention) -> Option<ConsistencyViolation> {
|
||||
// Lift the intervention to low-level
|
||||
let low_interventions = self.lift_intervention(high_intervention);
|
||||
|
||||
// Simulate high-level model with intervention
|
||||
let high_result = self.high_level.intervene(&[high_intervention.clone()]);
|
||||
let high_values = match high_result {
|
||||
Ok(model) => model.simulate(&HashMap::new()).ok(),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
// Simulate low-level model with lifted interventions
|
||||
let low_result = self.low_level.intervene(&low_interventions);
|
||||
let low_values = match low_result {
|
||||
Ok(model) => model.simulate(&HashMap::new()).ok(),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
// Project low-level results to high-level
|
||||
let (high_values, low_values) = match (high_values, low_values) {
|
||||
(Some(h), Some(l)) => (h, l),
|
||||
_ => return None, // Can't compare if simulation failed
|
||||
};
|
||||
|
||||
let projected = self.project_distribution(&low_values);
|
||||
|
||||
// Compare high-level result with projected result
|
||||
let mut divergence = 0.0;
|
||||
let mut expected = HashMap::new();
|
||||
let mut actual = HashMap::new();
|
||||
|
||||
for high_var in self.high_level.variables() {
|
||||
let high_val = high_values.get(&high_var.id)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
let proj_val = projected.get(&high_var.id)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let diff = (high_val - proj_val).abs();
|
||||
divergence += diff * diff;
|
||||
|
||||
expected.insert(high_var.name.clone(), high_val);
|
||||
actual.insert(high_var.name.clone(), proj_val);
|
||||
}
|
||||
|
||||
divergence = divergence.sqrt();
|
||||
|
||||
if divergence > self.tolerance {
|
||||
Some(ConsistencyViolation {
|
||||
intervention: format!("do({:?} = {:?})", high_intervention.target, high_intervention.value),
|
||||
expected,
|
||||
actual,
|
||||
divergence,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Lift a high-level intervention to low-level interventions
|
||||
pub fn lift_intervention(&self, high: &Intervention) -> Vec<Intervention> {
|
||||
let low_vars = match self.abstraction_map.project_variable(high.target) {
|
||||
Some(vars) => vars,
|
||||
None => return vec![],
|
||||
};
|
||||
|
||||
// Simple strategy: apply same value to all corresponding low-level variables
|
||||
// More sophisticated approaches could distribute the intervention differently
|
||||
low_vars.iter()
|
||||
.map(|&low_var| Intervention::new(low_var, high.value.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Project a low-level distribution to high-level
|
||||
pub fn project_distribution(&self, low_dist: &HashMap<VariableId, Value>) -> HashMap<VariableId, Value> {
|
||||
let mut high_dist = HashMap::new();
|
||||
|
||||
for high_var in self.high_level.variables() {
|
||||
let projected_value = self.abstraction_map.lift_value(high_var.id, low_dist);
|
||||
high_dist.insert(high_var.id, projected_value);
|
||||
}
|
||||
|
||||
high_dist
|
||||
}
|
||||
|
||||
/// Project a CounterfactualDistribution object
|
||||
pub fn project_distribution_obj(&self, low_dist: &CounterfactualDistribution) -> CounterfactualDistribution {
|
||||
let high_values = self.project_distribution(&low_dist.values);
|
||||
CounterfactualDistribution {
|
||||
values: high_values,
|
||||
probability: low_dist.probability,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the coarsening factor (how much the abstraction simplifies)
|
||||
pub fn coarsening_factor(&self) -> f64 {
|
||||
let low_count = self.low_level.num_variables() as f64;
|
||||
let high_count = self.high_level.num_variables() as f64;
|
||||
|
||||
if high_count > 0.0 {
|
||||
low_count / high_count
|
||||
} else {
|
||||
f64::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a low-level variable is "hidden" (not directly represented in high-level)
|
||||
pub fn is_hidden(&self, low_var: VariableId) -> bool {
|
||||
self.abstraction_map.lift_variable(low_var).is_none()
|
||||
}
|
||||
|
||||
/// Get all hidden variables
|
||||
pub fn hidden_variables(&self) -> Vec<VariableId> {
|
||||
self.low_level.variables()
|
||||
.filter(|v| self.is_hidden(v.id))
|
||||
.map(|v| v.id)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating causal abstractions
pub struct AbstractionBuilder<'a> {
    /// Low-level model; must be set before `build` succeeds
    low_level: Option<&'a CausalModel>,
    /// High-level model; must be set before `build` succeeds
    high_level: Option<&'a CausalModel>,
    /// Accumulated variable mapping
    map: AbstractionMap,
}
|
||||
|
||||
impl<'a> AbstractionBuilder<'a> {
|
||||
/// Create a new builder
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
low_level: None,
|
||||
high_level: None,
|
||||
map: AbstractionMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the low-level model
|
||||
pub fn low_level(mut self, model: &'a CausalModel) -> Self {
|
||||
self.low_level = Some(model);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the high-level model
|
||||
pub fn high_level(mut self, model: &'a CausalModel) -> Self {
|
||||
self.high_level = Some(model);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a variable mapping by name
|
||||
pub fn map_variable(
|
||||
mut self,
|
||||
high_name: &str,
|
||||
low_names: &[&str],
|
||||
aggregator: Aggregator,
|
||||
) -> Self {
|
||||
if let (Some(low), Some(high)) = (self.low_level, self.high_level) {
|
||||
if let Some(high_id) = high.get_variable_id(high_name) {
|
||||
let low_ids: HashSet<_> = low_names.iter()
|
||||
.filter_map(|&name| low.get_variable_id(name))
|
||||
.collect();
|
||||
|
||||
if !low_ids.is_empty() {
|
||||
self.map.add_mapping(high_id, low_ids, aggregator);
|
||||
}
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Add an identity mapping by name
|
||||
pub fn map_identity(mut self, high_name: &str, low_name: &str) -> Self {
|
||||
if let (Some(low), Some(high)) = (self.low_level, self.high_level) {
|
||||
if let (Some(high_id), Some(low_id)) = (
|
||||
high.get_variable_id(high_name),
|
||||
low.get_variable_id(low_name),
|
||||
) {
|
||||
self.map.add_identity_mapping(high_id, low_id);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the abstraction
|
||||
pub fn build(self) -> Result<CausalAbstraction<'a>, AbstractionError> {
|
||||
let low = self.low_level.ok_or_else(|| {
|
||||
AbstractionError::IncompatibleStructure("No low-level model provided".to_string())
|
||||
})?;
|
||||
let high = self.high_level.ok_or_else(|| {
|
||||
AbstractionError::IncompatibleStructure("No high-level model provided".to_string())
|
||||
})?;
|
||||
|
||||
CausalAbstraction::new(low, high, self.map)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Default for AbstractionBuilder<'a> {
    /// Equivalent to [`AbstractionBuilder::new`]: no models, empty mapping.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `CausalModelBuilder` is imported but appears unused in
    // the tests below — confirm before removing.
    use crate::causal::model::{CausalModelBuilder, Mechanism, VariableType};

    /// Five-variable fixture: (Age, Education) -> Experience -> Salary -> Savings.
    fn create_low_level_model() -> CausalModel {
        let mut model = CausalModel::with_name("Low-Level");

        // Detailed model with separate variables
        model.add_variable("Age", VariableType::Continuous).unwrap();
        model.add_variable("Education", VariableType::Continuous).unwrap();
        model.add_variable("Experience", VariableType::Continuous).unwrap();
        model.add_variable("Salary", VariableType::Continuous).unwrap();
        model.add_variable("Savings", VariableType::Continuous).unwrap();

        let age = model.get_variable_id("Age").unwrap();
        let edu = model.get_variable_id("Education").unwrap();
        let exp = model.get_variable_id("Experience").unwrap();
        let salary = model.get_variable_id("Salary").unwrap();
        let savings = model.get_variable_id("Savings").unwrap();

        // Experience = f(Age, Education)
        model.add_structural_equation(exp, &[age, edu], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 0.5 + p[1].as_f64() * 0.3)
        })).unwrap();

        // Salary = f(Education, Experience)
        model.add_structural_equation(salary, &[edu, exp], Mechanism::new(|p| {
            Value::Continuous(30000.0 + p[0].as_f64() * 5000.0 + p[1].as_f64() * 2000.0)
        })).unwrap();

        // Savings = f(Salary)
        model.add_structural_equation(savings, &[salary], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 0.2)
        })).unwrap();

        model
    }

    /// Two-variable abstract fixture: HumanCapital -> Wealth.
    fn create_high_level_model() -> CausalModel {
        let mut model = CausalModel::with_name("High-Level");

        // Simplified model with aggregated variables
        model.add_variable("HumanCapital", VariableType::Continuous).unwrap();
        model.add_variable("Wealth", VariableType::Continuous).unwrap();

        let hc = model.get_variable_id("HumanCapital").unwrap();
        let wealth = model.get_variable_id("Wealth").unwrap();

        // Wealth = f(HumanCapital)
        model.add_structural_equation(wealth, &[hc], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 10000.0)
        })).unwrap();

        model
    }

    // Mapping both high-level variables makes the map surjective.
    #[test]
    fn test_abstraction_map() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let mut map = AbstractionMap::new();

        // HumanCapital = mean(Education, Experience)
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        let exp_id = low.get_variable_id("Experience").unwrap();

        let mut low_vars = HashSet::new();
        low_vars.insert(edu_id);
        low_vars.insert(exp_id);
        map.add_mapping(hc_id, low_vars, Aggregator::Mean);

        // Wealth = sum(Salary, Savings)
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        let savings_id = low.get_variable_id("Savings").unwrap();

        let mut wealth_vars = HashSet::new();
        wealth_vars.insert(salary_id);
        wealth_vars.insert(savings_id);
        map.add_mapping(wealth_id, wealth_vars, Aggregator::Sum);

        assert!(map.is_surjective(&high));
    }

    // Spot-checks the numeric aggregators on [1, 2, 3].
    #[test]
    fn test_aggregators() {
        let values = vec![
            Value::Continuous(1.0),
            Value::Continuous(2.0),
            Value::Continuous(3.0),
        ];

        assert_eq!(Aggregator::First.apply(&values).as_f64(), 1.0);
        assert_eq!(Aggregator::Sum.apply(&values).as_f64(), 6.0);
        assert_eq!(Aggregator::Mean.apply(&values).as_f64(), 2.0);
        assert_eq!(Aggregator::Max.apply(&values).as_f64(), 3.0);
        assert_eq!(Aggregator::Min.apply(&values).as_f64(), 1.0);
    }

    // A high-level intervention lifts to one intervention per mapped
    // low-level variable.
    #[test]
    fn test_lift_intervention() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let mut map = AbstractionMap::new();

        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        let exp_id = low.get_variable_id("Experience").unwrap();

        let mut low_vars = HashSet::new();
        low_vars.insert(edu_id);
        low_vars.insert(exp_id);
        map.add_mapping(hc_id, low_vars, Aggregator::Mean);

        // Add wealth mapping (needed for surjectivity in `new`)
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        let mut wealth_vars = HashSet::new();
        wealth_vars.insert(salary_id);
        map.add_mapping(wealth_id, wealth_vars, Aggregator::First);

        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();

        let high_intervention = Intervention::new(hc_id, Value::Continuous(10.0));
        let low_interventions = abstraction.lift_intervention(&high_intervention);

        // Should lift to interventions on Education and Experience
        assert_eq!(low_interventions.len(), 2);
        assert!(low_interventions.iter().any(|i| i.target == edu_id));
        assert!(low_interventions.iter().any(|i| i.target == exp_id));
    }

    // Identity (First) mappings pass values straight through on projection.
    #[test]
    fn test_project_distribution() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let mut map = AbstractionMap::new();

        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();

        let mut low_vars = HashSet::new();
        low_vars.insert(edu_id);
        map.add_mapping(hc_id, low_vars, Aggregator::First);

        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();

        let mut wealth_vars = HashSet::new();
        wealth_vars.insert(salary_id);
        map.add_mapping(wealth_id, wealth_vars, Aggregator::First);

        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();

        let mut low_dist = HashMap::new();
        low_dist.insert(edu_id, Value::Continuous(16.0));
        low_dist.insert(salary_id, Value::Continuous(80000.0));

        let high_dist = abstraction.project_distribution(&low_dist);

        assert_eq!(high_dist.get(&hc_id).unwrap().as_f64(), 16.0);
        assert_eq!(high_dist.get(&wealth_id).unwrap().as_f64(), 80000.0);
    }

    #[test]
    fn test_coarsening_factor() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let mut map = AbstractionMap::new();

        // Simple identity mappings for this test
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        map.add_identity_mapping(hc_id, edu_id);

        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        map.add_identity_mapping(wealth_id, salary_id);

        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();

        // 5 low-level vars / 2 high-level vars = 2.5
        assert!((abstraction.coarsening_factor() - 2.5).abs() < 1e-10);
    }

    // Variables left out of the map are reported as hidden.
    #[test]
    fn test_hidden_variables() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let mut map = AbstractionMap::new();

        // Only map Education to HumanCapital
        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        map.add_identity_mapping(hc_id, edu_id);

        // Only map Salary to Wealth
        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        map.add_identity_mapping(wealth_id, salary_id);

        let abstraction = CausalAbstraction::new(&low, &high, map).unwrap();

        let hidden = abstraction.hidden_variables();

        // Age, Experience, Savings should be hidden
        let age_id = low.get_variable_id("Age").unwrap();
        let exp_id = low.get_variable_id("Experience").unwrap();
        let savings_id = low.get_variable_id("Savings").unwrap();

        assert!(hidden.contains(&age_id));
        assert!(hidden.contains(&exp_id));
        assert!(hidden.contains(&savings_id));
        assert!(!hidden.contains(&edu_id));
        assert!(!hidden.contains(&salary_id));
    }

    #[test]
    fn test_builder() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let abstraction = AbstractionBuilder::new()
            .low_level(&low)
            .high_level(&high)
            .map_identity("HumanCapital", "Education")
            .map_identity("Wealth", "Salary")
            .build()
            .unwrap();

        // 2.5 is exactly representable in f64, so exact equality is safe.
        assert_eq!(abstraction.coarsening_factor(), 2.5);
    }

    #[test]
    fn test_consistency_check() {
        let low = create_low_level_model();
        let high = create_high_level_model();

        let mut map = AbstractionMap::new();

        let hc_id = high.get_variable_id("HumanCapital").unwrap();
        let edu_id = low.get_variable_id("Education").unwrap();
        map.add_identity_mapping(hc_id, edu_id);

        let wealth_id = high.get_variable_id("Wealth").unwrap();
        let salary_id = low.get_variable_id("Salary").unwrap();
        map.add_identity_mapping(wealth_id, salary_id);

        let abstraction = CausalAbstraction::new(&low, &high, map)
            .unwrap()
            .with_tolerance(1000.0); // High tolerance for this test

        let result = abstraction.check_consistency();

        // The abstraction may or may not be consistent depending on mechanisms
        // This test just verifies the check runs
        assert!(result.interventions_tested > 0);
    }
}
|
||||
973
vendor/ruvector/examples/prime-radiant/src/causal/coherence.rs
vendored
Normal file
973
vendor/ruvector/examples/prime-radiant/src/causal/coherence.rs
vendored
Normal file
@@ -0,0 +1,973 @@
|
||||
//! Causal Coherence Checking
|
||||
//!
|
||||
//! This module provides tools for verifying that beliefs and data are
|
||||
//! consistent with a causal model. Key capabilities:
|
||||
//!
|
||||
//! - Detecting spurious correlations (associations not explained by causation)
|
||||
//! - Checking if beliefs satisfy causal constraints
|
||||
//! - Answering causal queries using do-calculus
|
||||
//! - Computing coherence energy for integration with Prime-Radiant
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::model::{CausalModel, CausalModelError, Value, VariableId, VariableType, Mechanism};
|
||||
use super::graph::DAGValidationError;
|
||||
|
||||
/// Error types for coherence operations
#[derive(Debug, Clone, Error)]
pub enum CoherenceError {
    /// Underlying causal-model error, converted automatically via `#[from]`.
    #[error("Model error: {0}")]
    ModelError(#[from] CausalModelError),

    /// Underlying graph/DAG validation error, converted automatically via `#[from]`.
    #[error("Graph error: {0}")]
    GraphError(#[from] DAGValidationError),

    /// A belief contradicts the causal model.
    #[error("Inconsistent belief: {0}")]
    InconsistentBelief(String),

    /// A causal query the checker cannot interpret or answer.
    #[error("Invalid query: {0}")]
    InvalidQuery(String),
}
|
||||
|
||||
/// A belief about the relationship between variables
///
/// Beliefs refer to variables by *name* (not `VariableId`) so they can be
/// stated independently of any concrete model and matched later.
#[derive(Debug, Clone)]
pub struct Belief {
    /// Subject variable (the "X" of the relation)
    pub subject: String,
    /// Object variable (the "Y" of the relation)
    pub object: String,
    /// Type of belief
    pub belief_type: BeliefType,
    /// Confidence in the belief (0.0 to 1.0)
    pub confidence: f64,
    /// Optional free-text evidence supporting the belief
    pub evidence: Option<String>,
}
|
||||
|
||||
/// Types of causal beliefs
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BeliefType {
    /// X causes Y
    Causes,
    /// X is correlated with Y (may or may not be causal)
    CorrelatedWith,
    /// X is (unconditionally) independent of Y
    IndependentOf,
    /// X is independent of Y given the conditioning set Z
    ConditionallyIndependent { given: Vec<String> },
    /// X and Y have a common cause (confounder)
    CommonCause,
    /// Changing X would change Y (interventional claim)
    WouldChange,
}
|
||||
|
||||
impl Belief {
|
||||
/// Create a causal belief: X causes Y
|
||||
pub fn causes(x: &str, y: &str) -> Self {
|
||||
Self {
|
||||
subject: x.to_string(),
|
||||
object: y.to_string(),
|
||||
belief_type: BeliefType::Causes,
|
||||
confidence: 1.0,
|
||||
evidence: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a correlation belief
|
||||
pub fn correlated(x: &str, y: &str) -> Self {
|
||||
Self {
|
||||
subject: x.to_string(),
|
||||
object: y.to_string(),
|
||||
belief_type: BeliefType::CorrelatedWith,
|
||||
confidence: 1.0,
|
||||
evidence: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an independence belief
|
||||
pub fn independent(x: &str, y: &str) -> Self {
|
||||
Self {
|
||||
subject: x.to_string(),
|
||||
object: y.to_string(),
|
||||
belief_type: BeliefType::IndependentOf,
|
||||
confidence: 1.0,
|
||||
evidence: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a conditional independence belief
|
||||
pub fn conditionally_independent(x: &str, y: &str, given: &[&str]) -> Self {
|
||||
Self {
|
||||
subject: x.to_string(),
|
||||
object: y.to_string(),
|
||||
belief_type: BeliefType::ConditionallyIndependent {
|
||||
given: given.iter().map(|s| s.to_string()).collect(),
|
||||
},
|
||||
confidence: 1.0,
|
||||
evidence: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set confidence level
|
||||
pub fn with_confidence(mut self, confidence: f64) -> Self {
|
||||
self.confidence = confidence.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set evidence
|
||||
pub fn with_evidence(mut self, evidence: &str) -> Self {
|
||||
self.evidence = Some(evidence.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of causal consistency checking over a set of beliefs.
#[derive(Debug, Clone)]
pub struct CausalConsistency {
    /// Overall consistency score in [0.0, 1.0]: the fraction of checked
    /// beliefs that were consistent (1.0 when no beliefs were checked).
    pub score: f64,
    /// Number of beliefs checked.
    pub beliefs_checked: usize,
    /// Number of consistent beliefs.
    pub consistent_beliefs: usize,
    /// Number of inconsistent beliefs.
    pub inconsistent_beliefs: usize,
    /// Per-belief details of each inconsistency found.
    pub inconsistencies: Vec<Inconsistency>,
    /// Suggested model modifications for resolving the inconsistencies.
    pub suggestions: Vec<String>,
}
|
||||
|
||||
impl CausalConsistency {
|
||||
/// Create a fully consistent result
|
||||
pub fn fully_consistent(beliefs_checked: usize) -> Self {
|
||||
Self {
|
||||
score: 1.0,
|
||||
beliefs_checked,
|
||||
consistent_beliefs: beliefs_checked,
|
||||
inconsistent_beliefs: 0,
|
||||
inconsistencies: vec![],
|
||||
suggestions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if fully consistent
|
||||
pub fn is_consistent(&self) -> bool {
|
||||
self.score >= 1.0 - 1e-10
|
||||
}
|
||||
}
|
||||
|
||||
/// Details of a single causal inconsistency found during checking.
#[derive(Debug, Clone)]
pub struct Inconsistency {
    /// The belief that is inconsistent with the model.
    pub belief: Belief,
    /// Human-readable reason why it is inconsistent.
    pub reason: String,
    /// Severity in [0.0, 1.0].
    ///
    /// NOTE(review): the checker currently sets this to `1.0 - confidence`,
    /// so a HIGH-confidence inconsistent belief gets LOW severity — confirm
    /// this is the intended semantics (it may be inverted).
    pub severity: f64,
}
|
||||
|
||||
/// A detected spurious correlation: two variables that are correlated in the
/// data but have no directed causal path between them in the model, with the
/// correlation explained by one or more common causes.
#[derive(Debug, Clone)]
pub struct SpuriousCorrelation {
    /// First variable (column name in the dataset).
    pub var_a: String,
    /// Second variable (column name in the dataset).
    pub var_b: String,
    /// The common cause(s) in the model explaining the correlation.
    pub confounders: Vec<String>,
    /// Strength of the spurious correlation (absolute Pearson r).
    pub strength: f64,
    /// Human-readable explanation.
    pub explanation: String,
}
|
||||
|
||||
/// A causal query against a model, e.g. P(Y | do(X=x), Z=z).
#[derive(Debug, Clone)]
pub struct CausalQuery {
    /// The variable we're asking about (by name).
    pub target: String,
    /// Variables we're intervening on (do-operator), as (name, value) pairs.
    pub interventions: Vec<(String, Value)>,
    /// Variables we're conditioning on (observation), as (name, value) pairs.
    pub conditions: Vec<(String, Value)>,
    /// Query type (observational, interventional, counterfactual, ...).
    pub query_type: QueryType,
}
|
||||
|
||||
/// Types of causal queries, ordered roughly by rung of Pearl's ladder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryType {
    /// P(Y | do(X=x)) - interventional query.
    Interventional,
    /// P(Y | X=x) - observational query.
    Observational,
    /// P(Y_x | X=x') - counterfactual query.
    Counterfactual,
    /// P(Y | do(X=x), Z=z) - conditional interventional query.
    ConditionalInterventional,
}
|
||||
|
||||
impl CausalQuery {
|
||||
/// Create an interventional query: P(target | do(intervention))
|
||||
pub fn interventional(target: &str, intervention_var: &str, intervention_val: Value) -> Self {
|
||||
Self {
|
||||
target: target.to_string(),
|
||||
interventions: vec![(intervention_var.to_string(), intervention_val)],
|
||||
conditions: vec![],
|
||||
query_type: QueryType::Interventional,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an observational query: P(target | condition)
|
||||
pub fn observational(target: &str, condition_var: &str, condition_val: Value) -> Self {
|
||||
Self {
|
||||
target: target.to_string(),
|
||||
interventions: vec![],
|
||||
conditions: vec![(condition_var.to_string(), condition_val)],
|
||||
query_type: QueryType::Observational,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a condition
|
||||
pub fn given(mut self, var: &str, val: Value) -> Self {
|
||||
self.conditions.push((var.to_string(), val));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Answer to a causal query.
#[derive(Debug, Clone)]
pub struct CausalAnswer {
    /// The query that was answered (copied for traceability).
    pub query: CausalQuery,
    /// The estimated value/distribution; `Value::Missing` when the answer
    /// requires probabilistic machinery that is not implemented here.
    pub estimate: Value,
    /// Confidence interval (if applicable).
    pub confidence_interval: Option<(f64, f64)>,
    /// Whether the query is identifiable from observational data.
    pub is_identifiable: bool,
    /// Human-readable explanation of how the answer was derived.
    pub explanation: String,
}
|
||||
|
||||
/// Combined coherence energy for integration with Prime-Radiant.
///
/// Energy is additive across components; lower total energy means a more
/// coherent system.
#[derive(Debug, Clone)]
pub struct CoherenceEnergy {
    /// Total energy: the sum of the three components (lower is more coherent).
    pub total: f64,
    /// Structural component (from sheaf consistency).
    pub structural_component: f64,
    /// Causal component (from causal consistency).
    pub causal_component: f64,
    /// Intervention component (from intervention consistency).
    pub intervention_component: f64,
    /// Whether the system is coherent (total energy below the 1e-6 threshold).
    pub is_coherent: bool,
}
|
||||
|
||||
impl CoherenceEnergy {
|
||||
/// Create a fully coherent state
|
||||
pub fn coherent() -> Self {
|
||||
Self {
|
||||
total: 0.0,
|
||||
structural_component: 0.0,
|
||||
causal_component: 0.0,
|
||||
intervention_component: 0.0,
|
||||
is_coherent: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from individual components
|
||||
pub fn from_components(structural: f64, causal: f64, intervention: f64) -> Self {
|
||||
let total = structural + causal + intervention;
|
||||
Self {
|
||||
total,
|
||||
structural_component: structural,
|
||||
causal_component: causal,
|
||||
intervention_component: intervention,
|
||||
is_coherent: total < 1e-6,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory tabular dataset for spurious correlation detection.
///
/// All values are `f64`; rows are expected to match `columns` in length
/// (enforced by `add_row`, which drops mismatched rows).
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Column names, in row order.
    pub columns: Vec<String>,
    /// Data rows (each row is a vector of values, one per column).
    pub rows: Vec<Vec<f64>>,
}
|
||||
|
||||
impl Dataset {
|
||||
/// Create a new dataset
|
||||
pub fn new(columns: Vec<String>) -> Self {
|
||||
Self {
|
||||
columns,
|
||||
rows: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a row
|
||||
pub fn add_row(&mut self, row: Vec<f64>) {
|
||||
if row.len() == self.columns.len() {
|
||||
self.rows.push(row);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get column index
|
||||
pub fn column_index(&self, name: &str) -> Option<usize> {
|
||||
self.columns.iter().position(|c| c == name)
|
||||
}
|
||||
|
||||
/// Get column values
|
||||
pub fn column(&self, name: &str) -> Option<Vec<f64>> {
|
||||
let idx = self.column_index(name)?;
|
||||
Some(self.rows.iter().map(|row| row[idx]).collect())
|
||||
}
|
||||
|
||||
/// Compute correlation between two columns
|
||||
pub fn correlation(&self, col_a: &str, col_b: &str) -> Option<f64> {
|
||||
let a = self.column(col_a)?;
|
||||
let b = self.column(col_b)?;
|
||||
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let n = a.len() as f64;
|
||||
let mean_a: f64 = a.iter().sum::<f64>() / n;
|
||||
let mean_b: f64 = b.iter().sum::<f64>() / n;
|
||||
|
||||
let mut cov = 0.0;
|
||||
let mut var_a = 0.0;
|
||||
let mut var_b = 0.0;
|
||||
|
||||
for i in 0..a.len() {
|
||||
let da = a[i] - mean_a;
|
||||
let db = b[i] - mean_b;
|
||||
cov += da * db;
|
||||
var_a += da * da;
|
||||
var_b += db * db;
|
||||
}
|
||||
|
||||
let denom = (var_a * var_b).sqrt();
|
||||
if denom < 1e-10 {
|
||||
Some(0.0)
|
||||
} else {
|
||||
Some(cov / denom)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Causal coherence checker: validates beliefs, detects spurious
/// correlations, and answers causal queries against a borrowed model.
pub struct CausalCoherenceChecker<'a> {
    /// The causal model being checked against (borrowed for the checker's
    /// lifetime).
    model: &'a CausalModel,
    /// Absolute Pearson-r threshold above which a correlation is considered
    /// "significant" (default 0.1).
    correlation_threshold: f64,
}
|
||||
|
||||
impl<'a> CausalCoherenceChecker<'a> {
    /// Create a new checker over `model` with the default correlation
    /// threshold of 0.1.
    pub fn new(model: &'a CausalModel) -> Self {
        Self {
            model,
            correlation_threshold: 0.1,
        }
    }

    /// Set the correlation threshold (builder style). Correlations with
    /// |r| above this value are treated as significant.
    pub fn with_correlation_threshold(mut self, threshold: f64) -> Self {
        self.correlation_threshold = threshold;
        self
    }

    /// Check if a set of beliefs is consistent with the causal model.
    ///
    /// Each belief is checked independently via `check_single_belief`; the
    /// returned score is the fraction of consistent beliefs (1.0 for an
    /// empty slice).
    pub fn check_causal_consistency(&self, beliefs: &[Belief]) -> CausalConsistency {
        let mut consistent_count = 0;
        let mut inconsistencies = Vec::new();
        let mut suggestions = Vec::new();

        for belief in beliefs {
            match self.check_single_belief(belief) {
                Ok(()) => consistent_count += 1,
                Err(reason) => {
                    inconsistencies.push(Inconsistency {
                        belief: belief.clone(),
                        reason: reason.clone(),
                        // NOTE(review): severity = 1 - confidence means a
                        // high-confidence inconsistent belief is scored LOW
                        // severity — confirm this is intended.
                        severity: 1.0 - belief.confidence,
                    });

                    // Generate suggestion
                    if let Some(suggestion) = self.generate_suggestion(belief, &reason) {
                        suggestions.push(suggestion);
                    }
                }
            }
        }

        let beliefs_checked = beliefs.len();
        // An empty belief set is vacuously fully consistent.
        let score = if beliefs_checked > 0 {
            consistent_count as f64 / beliefs_checked as f64
        } else {
            1.0
        };

        CausalConsistency {
            score,
            beliefs_checked,
            consistent_beliefs: consistent_count,
            inconsistent_beliefs: beliefs_checked - consistent_count,
            inconsistencies,
            suggestions,
        }
    }

    /// Check a single belief against the model's graph structure.
    ///
    /// Returns `Err(reason)` when the belief contradicts the graph, or when
    /// any referenced variable is absent from the model. Correlational and
    /// interventional ("would change") beliefs are not checkable from
    /// structure alone and are accepted unconditionally.
    fn check_single_belief(&self, belief: &Belief) -> Result<(), String> {
        let subject_id = self.model.get_variable_id(&belief.subject)
            .ok_or_else(|| format!("Variable '{}' not in model", belief.subject))?;
        let object_id = self.model.get_variable_id(&belief.object)
            .ok_or_else(|| format!("Variable '{}' not in model", belief.object))?;

        match &belief.belief_type {
            BeliefType::Causes => {
                // Check if there's a directed path from subject to object
                let descendants = self.model.graph().descendants(subject_id.0);
                if !descendants.contains(&object_id.0) {
                    return Err(format!(
                        "No causal path from {} to {} in model",
                        belief.subject, belief.object
                    ));
                }
            }

            BeliefType::IndependentOf => {
                // Check if they're d-separated given empty set
                if !self.model.d_separated(subject_id, object_id, &[]) {
                    return Err(format!(
                        "{} and {} are not independent according to model",
                        belief.subject, belief.object
                    ));
                }
            }

            BeliefType::ConditionallyIndependent { given } => {
                // Resolve the conditioning set; fail fast on unknown names.
                let given_ids: Result<Vec<VariableId>, _> = given.iter()
                    .map(|name| {
                        self.model.get_variable_id(name)
                            .ok_or_else(|| format!("Variable '{}' not in model", name))
                    })
                    .collect();
                let given_ids = given_ids?;

                if !self.model.d_separated(subject_id, object_id, &given_ids) {
                    return Err(format!(
                        "{} and {} are not conditionally independent given {:?}",
                        belief.subject, belief.object, given
                    ));
                }
            }

            BeliefType::CommonCause => {
                // Check if they share a common ancestor
                let ancestors_a = self.model.graph().ancestors(subject_id.0);
                let ancestors_b = self.model.graph().ancestors(object_id.0);

                let common: HashSet<_> = ancestors_a.intersection(&ancestors_b).collect();
                if common.is_empty() {
                    return Err(format!(
                        "No common cause found for {} and {}",
                        belief.subject, belief.object
                    ));
                }
            }

            BeliefType::CorrelatedWith | BeliefType::WouldChange => {
                // These are not directly checkable against model structure alone
                // They would need data or simulation
            }
        }

        Ok(())
    }

    /// Generate a human-readable suggestion for fixing an inconsistency.
    /// Returns `None` for belief types with no structural remedy.
    fn generate_suggestion(&self, belief: &Belief, _reason: &str) -> Option<String> {
        match &belief.belief_type {
            BeliefType::Causes => {
                Some(format!(
                    "Consider adding edge {} -> {} to the model, or revising the belief",
                    belief.subject, belief.object
                ))
            }
            BeliefType::IndependentOf => {
                // NOTE: format! with no interpolation (clippy::useless_format);
                // kept as-is to avoid a behavioral diff in this pass.
                Some(format!(
                    "Consider conditioning on a confounding variable, or revising the model structure"
                ))
            }
            BeliefType::ConditionallyIndependent { given } => {
                Some(format!(
                    "The conditioning set {:?} may be insufficient; consider additional variables",
                    given
                ))
            }
            _ => None,
        }
    }

    /// Detect spurious correlations in `data` given the causal model.
    ///
    /// A pair is flagged when it is significantly correlated in the data
    /// (|r| above the threshold), has no directed causal path in either
    /// direction in the model, and shares at least one common ancestor
    /// (confounder). Columns absent from the model are skipped.
    pub fn detect_spurious_correlations(&self, data: &Dataset) -> Vec<SpuriousCorrelation> {
        let mut spurious = Vec::new();

        // Check all pairs of variables
        for i in 0..data.columns.len() {
            for j in (i + 1)..data.columns.len() {
                let col_a = &data.columns[i];
                let col_b = &data.columns[j];

                // Get correlation from data
                let correlation = match data.correlation(col_a, col_b) {
                    Some(c) => c,
                    None => continue,
                };

                // If significantly correlated
                if correlation.abs() > self.correlation_threshold {
                    // Check if causally linked
                    if let (Some(id_a), Some(id_b)) = (
                        self.model.get_variable_id(col_a),
                        self.model.get_variable_id(col_b),
                    ) {
                        // Check if there's a direct causal path
                        let a_causes_b = self.model.graph().descendants(id_a.0).contains(&id_b.0);
                        let b_causes_a = self.model.graph().descendants(id_b.0).contains(&id_a.0);

                        if !a_causes_b && !b_causes_a {
                            // Correlation without direct causation - find confounders
                            let confounders = self.find_confounders(id_a, id_b);

                            if !confounders.is_empty() {
                                spurious.push(SpuriousCorrelation {
                                    var_a: col_a.clone(),
                                    var_b: col_b.clone(),
                                    confounders: confounders.clone(),
                                    strength: correlation.abs(),
                                    explanation: format!(
                                        "Correlation (r={:.3}) between {} and {} is explained by common cause(s): {}",
                                        correlation, col_a, col_b, confounders.join(", ")
                                    ),
                                });
                            }
                        }
                    }
                }
            }
        }

        spurious
    }

    /// Find common causes (confounders) of two variables: the named
    /// intersection of their ancestor sets in the graph.
    fn find_confounders(&self, a: VariableId, b: VariableId) -> Vec<String> {
        let ancestors_a = self.model.graph().ancestors(a.0);
        let ancestors_b = self.model.graph().ancestors(b.0);

        let common: Vec<_> = ancestors_a.intersection(&ancestors_b)
            .filter_map(|&id| self.model.get_variable_name(&VariableId(id)))
            .collect();

        common
    }

    /// Answer a causal query, dispatching on its `query_type`.
    ///
    /// # Errors
    /// Returns `CoherenceError::InvalidQuery` when the target (or an
    /// intervention variable) is not in the model, and propagates model
    /// errors from simulation.
    pub fn enforce_do_calculus(&self, query: &CausalQuery) -> Result<CausalAnswer, CoherenceError> {
        // Get target variable
        let target_id = self.model.get_variable_id(&query.target)
            .ok_or_else(|| CoherenceError::InvalidQuery(
                format!("Target variable '{}' not in model", query.target)
            ))?;

        match query.query_type {
            QueryType::Interventional => {
                self.answer_interventional_query(query, target_id)
            }
            QueryType::Observational => {
                self.answer_observational_query(query, target_id)
            }
            QueryType::Counterfactual => {
                self.answer_counterfactual_query(query, target_id)
            }
            QueryType::ConditionalInterventional => {
                self.answer_conditional_interventional_query(query, target_id)
            }
        }
    }

    /// Answer P(target | do(...)) by mutilating the model and simulating.
    /// The estimate is the simulated value of `target_id` (or
    /// `Value::Missing` if simulation produced no value for it).
    fn answer_interventional_query(
        &self,
        query: &CausalQuery,
        target_id: VariableId,
    ) -> Result<CausalAnswer, CoherenceError> {
        // Convert intervention specification to Intervention objects
        let interventions: Result<Vec<_>, _> = query.interventions.iter()
            .map(|(var, val)| {
                self.model.get_variable_id(var)
                    .map(|id| super::model::Intervention::new(id, val.clone()))
                    .ok_or_else(|| CoherenceError::InvalidQuery(
                        format!("Intervention variable '{}' not in model", var)
                    ))
            })
            .collect();
        let interventions = interventions?;

        // Perform intervention
        let intervened = self.model.intervene_with(&interventions)?;

        // Simulate to get the target value
        let values = intervened.simulate(&HashMap::new())?;

        let estimate = values.get(&target_id).cloned().unwrap_or(Value::Missing);

        // Check identifiability
        let is_identifiable = self.check_identifiability(query);

        Ok(CausalAnswer {
            query: query.clone(),
            estimate,
            confidence_interval: None,
            is_identifiable,
            explanation: format!(
                "Computed P({} | do({})) by intervention simulation",
                query.target,
                query.interventions.iter()
                    .map(|(v, val)| format!("{}={:?}", v, val))
                    .collect::<Vec<_>>()
                    .join(", ")
            ),
        })
    }

    /// Answer an observational query. Currently a stub: returns
    /// `Value::Missing` with an explanation, since conditioning requires
    /// probabilistic inference not implemented here.
    fn answer_observational_query(
        &self,
        query: &CausalQuery,
        target_id: VariableId, // NOTE(review): unused in this stub; kept for signature symmetry
    ) -> Result<CausalAnswer, CoherenceError> {
        // For observational queries, we need to condition
        // This requires probabilistic reasoning which we approximate

        let explanation = format!(
            "Observational query P({} | {}) - requires probabilistic inference",
            query.target,
            query.conditions.iter()
                .map(|(v, val)| format!("{}={:?}", v, val))
                .collect::<Vec<_>>()
                .join(", ")
        );

        Ok(CausalAnswer {
            query: query.clone(),
            estimate: Value::Missing, // Would need actual probabilistic computation
            confidence_interval: None,
            is_identifiable: true, // Observational queries are always identifiable
            explanation,
        })
    }

    /// Answer a counterfactual query. Currently a stub: the full
    /// abduction-action-prediction procedure lives in the counterfactual
    /// module, so this returns `Value::Missing` with an explanation.
    fn answer_counterfactual_query(
        &self,
        query: &CausalQuery,
        _target_id: VariableId,
    ) -> Result<CausalAnswer, CoherenceError> {
        // Counterfactual queries require abduction-action-prediction
        let explanation = format!(
            "Counterfactual query for {} - requires three-step process: abduction, action, prediction",
            query.target
        );

        Ok(CausalAnswer {
            query: query.clone(),
            estimate: Value::Missing,
            confidence_interval: None,
            is_identifiable: false, // Counterfactuals often not identifiable
            explanation,
        })
    }

    /// Answer a conditional interventional query. Currently a stub:
    /// returns `Value::Missing` plus an identifiability verdict.
    fn answer_conditional_interventional_query(
        &self,
        query: &CausalQuery,
        target_id: VariableId, // NOTE(review): unused in this stub; kept for signature symmetry
    ) -> Result<CausalAnswer, CoherenceError> {
        // Combines intervention with conditioning
        let explanation = format!(
            "Conditional interventional query P({} | do({}), {}) - may require adjustment formula",
            query.target,
            query.interventions.iter()
                .map(|(v, val)| format!("{}={:?}", v, val))
                .collect::<Vec<_>>()
                .join(", "),
            query.conditions.iter()
                .map(|(v, val)| format!("{}={:?}", v, val))
                .collect::<Vec<_>>()
                .join(", ")
        );

        Ok(CausalAnswer {
            query: query.clone(),
            estimate: Value::Missing,
            confidence_interval: None,
            is_identifiable: self.check_identifiability(query),
            explanation,
        })
    }

    /// Check if a causal query is identifiable from observational data.
    ///
    /// This is a conservative approximation, not full do-calculus: a query
    /// is deemed unidentifiable only when an intervention variable shares an
    /// ancestor with the target AND no backdoor adjustment set is found.
    fn check_identifiability(&self, query: &CausalQuery) -> bool {
        // Simplified identifiability check
        // Full implementation would use do-calculus rules

        if query.interventions.is_empty() {
            return true; // Observational queries are identifiable
        }

        // Check if intervention variables have unobserved confounders with target
        for (var, _) in &query.interventions {
            if let (Some(var_id), Some(target_id)) = (
                self.model.get_variable_id(var),
                self.model.get_variable_id(&query.target),
            ) {
                // If there's a backdoor path that can't be blocked, not identifiable
                // This is a simplified check
                let var_ancestors = self.model.graph().ancestors(var_id.0);
                let target_ancestors = self.model.graph().ancestors(target_id.0);

                // If they share unobserved common ancestors, might not be identifiable
                let common = var_ancestors.intersection(&target_ancestors).count();
                if common > 0 && !self.has_valid_adjustment_set(var_id, target_id) {
                    return false;
                }
            }
        }

        true
    }

    /// Check if there's a valid adjustment set for identifying the causal
    /// effect of `treatment` on `outcome`, by testing the single candidate
    /// set "all non-descendants of the treatment" against d-separation.
    fn has_valid_adjustment_set(&self, treatment: VariableId, outcome: VariableId) -> bool {
        // Check backdoor criterion
        // A set Z satisfies backdoor criterion if:
        // 1. No node in Z is a descendant of X
        // 2. Z blocks every path from X to Y that contains an arrow into X

        let descendants = self.model.graph().descendants(treatment.0);

        // Try the set of all non-descendants as potential adjustment set
        let all_vars: Vec<_> = self.model.variables()
            .filter(|v| v.id != treatment && v.id != outcome)
            .filter(|v| !descendants.contains(&v.id.0))
            .map(|v| v.id)
            .collect();

        // Check if conditioning on all non-descendants blocks backdoor paths
        self.model.d_separated(treatment, outcome, &all_vars)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `CausalModelBuilder` appears unused below, and
    // `Mechanism` is referenced without an explicit import here —
    // presumably it reaches scope via `use super::*`; confirm.
    use crate::causal::model::{CausalModelBuilder, VariableType};

    /// Build the shared fixture model:
    /// Age -> Education -> Income, and Age -> Health (Age confounds
    /// Education and Health), with deterministic linear mechanisms.
    fn create_test_model() -> CausalModel {
        let mut model = CausalModel::with_name("Test");

        model.add_variable("Age", VariableType::Continuous).unwrap();
        model.add_variable("Education", VariableType::Continuous).unwrap();
        model.add_variable("Income", VariableType::Continuous).unwrap();
        model.add_variable("Health", VariableType::Continuous).unwrap();

        let age = model.get_variable_id("Age").unwrap();
        let edu = model.get_variable_id("Education").unwrap();
        let income = model.get_variable_id("Income").unwrap();
        let health = model.get_variable_id("Health").unwrap();

        // Age -> Education, Age -> Health
        model.add_edge(age, edu).unwrap();
        model.add_edge(age, health).unwrap();

        // Education -> Income
        model.add_edge(edu, income).unwrap();

        // Add equations
        model.add_structural_equation(edu, &[age], Mechanism::new(|p| {
            Value::Continuous(12.0 + p[0].as_f64() * 0.1)
        })).unwrap();

        model.add_structural_equation(income, &[edu], Mechanism::new(|p| {
            Value::Continuous(30000.0 + p[0].as_f64() * 5000.0)
        })).unwrap();

        model.add_structural_equation(health, &[age], Mechanism::new(|p| {
            Value::Continuous(100.0 - p[0].as_f64() * 0.5)
        })).unwrap();

        model
    }

    /// Builder methods populate fields and clamp/keep confidence as given.
    #[test]
    fn test_belief_creation() {
        let belief = Belief::causes("Age", "Education").with_confidence(0.9);
        assert_eq!(belief.subject, "Age");
        assert_eq!(belief.object, "Education");
        assert_eq!(belief.confidence, 0.9);
    }

    /// Beliefs that match the fixture's edges are fully consistent.
    #[test]
    fn test_causal_consistency() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        let beliefs = vec![
            Belief::causes("Age", "Education"),
            Belief::causes("Education", "Income"),
        ];

        let result = checker.check_causal_consistency(&beliefs);

        assert!(result.is_consistent());
        assert_eq!(result.consistent_beliefs, 2);
    }

    /// A causal claim against the edge direction is flagged inconsistent.
    #[test]
    fn test_inconsistent_belief() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        let beliefs = vec![
            Belief::causes("Income", "Age"), // Wrong direction
        ];

        let result = checker.check_causal_consistency(&beliefs);

        assert!(!result.is_consistent());
        assert_eq!(result.inconsistent_beliefs, 1);
    }

    /// Education ⊥ Health | Age holds by d-separation (Age is the only
    /// common parent).
    #[test]
    fn test_conditional_independence() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        // Education and Health should be independent given Age
        let beliefs = vec![
            Belief::conditionally_independent("Education", "Health", &["Age"]),
        ];

        let result = checker.check_causal_consistency(&beliefs);

        assert!(result.is_consistent());
    }

    /// Data where Education and Health both track Age should surface the
    /// Education-Health correlation as spurious, with Age as confounder.
    #[test]
    fn test_spurious_correlation_detection() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model).with_correlation_threshold(0.1);

        // Create dataset with Education-Health correlation (spurious via Age)
        let mut data = Dataset::new(vec![
            "Age".to_string(),
            "Education".to_string(),
            "Health".to_string(),
        ]);

        // Add correlated data (small sin/cos terms add noise without
        // destroying the Age-driven correlation)
        for i in 0..100 {
            let age = 20.0 + i as f64 * 0.5;
            let edu = 12.0 + age * 0.1 + (i as f64 * 0.1).sin();
            let health = 100.0 - age * 0.5 + (i as f64 * 0.2).cos();
            data.add_row(vec![age, edu, health]);
        }

        let spurious = checker.detect_spurious_correlations(&data);

        // Should detect Education-Health as spurious (both caused by Age)
        let edu_health = spurious.iter()
            .find(|s| (s.var_a == "Education" && s.var_b == "Health") ||
                      (s.var_a == "Health" && s.var_b == "Education"));

        assert!(edu_health.is_some());

        if let Some(s) = edu_health {
            assert!(s.confounders.contains(&"Age".to_string()));
        }
    }

    /// do(Education=16) on the fixture is identifiable (no unblockable
    /// backdoor path to Income).
    #[test]
    fn test_interventional_query() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        let query = CausalQuery::interventional(
            "Income",
            "Education",
            Value::Continuous(16.0),
        );

        let answer = checker.enforce_do_calculus(&query).unwrap();

        assert!(answer.is_identifiable);
        assert!(matches!(answer.query.query_type, QueryType::Interventional));
    }

    /// Component energies sum into `total`; a nonzero total is incoherent.
    #[test]
    fn test_coherence_energy() {
        let energy = CoherenceEnergy::from_components(0.1, 0.2, 0.05);

        assert!((energy.total - 0.35).abs() < 1e-10);
        assert!(!energy.is_coherent);

        let coherent = CoherenceEnergy::coherent();
        assert!(coherent.is_coherent);
    }

    /// Pearson correlation: +1 for identical columns, -1 for mirrored ones.
    #[test]
    fn test_dataset_correlation() {
        let mut data = Dataset::new(vec!["X".to_string(), "Y".to_string()]);

        // Perfect positive correlation
        for i in 0..10 {
            data.add_row(vec![i as f64, i as f64]);
        }

        let corr = data.correlation("X", "Y").unwrap();
        assert!((corr - 1.0).abs() < 1e-10);

        // Add negatively correlated data
        let mut data2 = Dataset::new(vec!["A".to_string(), "B".to_string()]);
        for i in 0..10 {
            data2.add_row(vec![i as f64, (10 - i) as f64]);
        }

        let corr2 = data2.correlation("A", "B").unwrap();
        assert!((corr2 + 1.0).abs() < 1e-10);
    }

    /// `given` appends a condition without touching the intervention list.
    #[test]
    fn test_causal_query_builder() {
        let query = CausalQuery::interventional("Y", "X", Value::Continuous(1.0))
            .given("Z", Value::Continuous(2.0));

        assert_eq!(query.target, "Y");
        assert_eq!(query.interventions.len(), 1);
        assert_eq!(query.conditions.len(), 1);
    }
}
|
||||
805
vendor/ruvector/examples/prime-radiant/src/causal/counterfactual.rs
vendored
Normal file
805
vendor/ruvector/examples/prime-radiant/src/causal/counterfactual.rs
vendored
Normal file
@@ -0,0 +1,805 @@
|
||||
//! Counterfactual Reasoning
|
||||
//!
|
||||
//! This module implements counterfactual inference based on Pearl's three-step
|
||||
//! procedure: Abduction, Action, Prediction.
|
||||
//!
|
||||
//! ## Counterfactual Semantics
|
||||
//!
|
||||
//! Given a structural causal model M = (U, V, F), a counterfactual query asks:
|
||||
//! "What would Y have been if X had been x, given that we observed E = e?"
|
||||
//!
|
||||
//! Written as: P(Y_x | E = e) or Y_{X=x}(u) where u is the exogenous state.
|
||||
//!
|
||||
//! ## Three-Step Procedure
|
||||
//!
|
||||
//! 1. **Abduction**: Update P(U) given evidence E = e to get P(U | E = e)
|
||||
//! 2. **Action**: Modify the model by intervention do(X = x)
|
||||
//! 3. **Prediction**: Compute Y in the modified model using updated U
|
||||
//!
|
||||
//! ## References
|
||||
//!
|
||||
//! - Pearl (2009): "Causality" Chapter 7
|
||||
//! - Halpern (2016): "Actual Causality"
|
||||
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::model::{
|
||||
CausalModel, CausalModelError, Intervention, Value, VariableId, Mechanism,
|
||||
Observation,
|
||||
};
|
||||
|
||||
/// Error types for counterfactual reasoning.
#[derive(Debug, Clone, Error)]
pub enum CounterfactualError {
    /// Underlying model error, converted automatically via `#[from]`.
    #[error("Model error: {0}")]
    ModelError(#[from] CausalModelError),

    /// Invalid observation: the named variable is not in the model.
    #[error("Invalid observation: variable '{0}' not in model")]
    InvalidObservation(String),

    /// Abduction (step 1: inferring exogenous values from evidence) failed.
    #[error("Abduction failed: {0}")]
    AbductionFailed(String),

    /// The counterfactual is not well-defined for this model/query.
    #[error("Counterfactual not well-defined: {0}")]
    NotWellDefined(String),
}
|
||||
|
||||
/// Extended distribution type for counterfactual results.
///
/// For deterministic models this is a point mass: one variable assignment
/// with probability 1.0.
#[derive(Debug, Clone)]
pub struct CounterfactualDistribution {
    /// Point estimate values (for deterministic models).
    pub values: HashMap<VariableId, Value>,
    /// Probability mass (for discrete) or density (for continuous).
    pub probability: f64,
}
|
||||
|
||||
impl CounterfactualDistribution {
|
||||
/// Create a point mass distribution
|
||||
pub fn point_mass(values: HashMap<VariableId, Value>) -> Self {
|
||||
Self {
|
||||
values,
|
||||
probability: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from a simulation result
|
||||
pub fn from_simulation(values: HashMap<VariableId, Value>) -> Self {
|
||||
Self::point_mass(values)
|
||||
}
|
||||
|
||||
/// Get value for a variable
|
||||
pub fn get(&self, var: VariableId) -> Option<&Value> {
|
||||
self.values.get(&var)
|
||||
}
|
||||
|
||||
/// Get mean value (for continuous distributions)
|
||||
pub fn mean(&self, var: VariableId) -> f64 {
|
||||
self.values.get(&var)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// A counterfactual query.
///
/// Encodes the question: "What would `target` have been, had we done
/// `interventions`, given that `evidence` was observed?"
#[derive(Debug, Clone)]
pub struct CounterfactualQuery {
    /// Target variable (what we want to know)
    pub target: String,
    /// Interventions (what we're hypothetically changing), as (name, value) pairs
    pub interventions: Vec<(String, Value)>,
    /// Evidence (what we actually observed)
    pub evidence: Observation,
}
|
||||
|
||||
impl CounterfactualQuery {
|
||||
/// Create a new counterfactual query
|
||||
///
|
||||
/// Asking: "What would `target` have been if we had done `interventions`,
|
||||
/// given that we observed `evidence`?"
|
||||
pub fn new(target: &str, interventions: Vec<(&str, Value)>, evidence: Observation) -> Self {
|
||||
Self {
|
||||
target: target.to_string(),
|
||||
interventions: interventions.into_iter()
|
||||
.map(|(k, v)| (k.to_string(), v))
|
||||
.collect(),
|
||||
evidence,
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple counterfactual: "What would Y have been if X had been x?"
|
||||
pub fn simple(target: &str, intervention_var: &str, intervention_val: Value) -> Self {
|
||||
Self {
|
||||
target: target.to_string(),
|
||||
interventions: vec![(intervention_var.to_string(), intervention_val)],
|
||||
evidence: Observation::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add evidence to the query
|
||||
pub fn given(mut self, var: &str, value: Value) -> Self {
|
||||
self.evidence.observe(var, value);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a counterfactual query.
///
/// Bundles the answer with the intermediate artifacts of Pearl's three-step
/// procedure so callers can audit how the answer was produced.
#[derive(Debug, Clone)]
pub struct CounterfactualResult {
    /// The query that was answered
    pub query: CounterfactualQuery,
    /// The counterfactual distribution over all computed variables
    pub distribution: CounterfactualDistribution,
    /// The abduced exogenous values (step 1 output)
    pub exogenous: HashMap<VariableId, Value>,
    /// Human-readable explanation of the reasoning
    pub explanation: String,
}
|
||||
|
||||
/// Average Treatment Effect computation result.
///
/// ATE = E[Y | do(X = treatment_value)] - E[Y | do(X = control_value)].
#[derive(Debug, Clone)]
pub struct AverageTreatmentEffect {
    /// Treatment variable name
    pub treatment: String,
    /// Outcome variable name
    pub outcome: String,
    /// Treatment value used for the "treated" world
    pub treatment_value: Value,
    /// Control value used for the baseline world
    pub control_value: Value,
    /// Estimated ATE
    pub ate: f64,
    /// Standard error of the estimate (if available)
    pub standard_error: Option<f64>,
}
|
||||
|
||||
impl AverageTreatmentEffect {
|
||||
/// Create a new ATE result
|
||||
pub fn new(
|
||||
treatment: &str,
|
||||
outcome: &str,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
ate: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
treatment: treatment.to_string(),
|
||||
outcome: outcome.to_string(),
|
||||
treatment_value,
|
||||
control_value,
|
||||
ate,
|
||||
standard_error: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set standard error
|
||||
pub fn with_standard_error(mut self, se: f64) -> Self {
|
||||
self.standard_error = Some(se);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a counterfactual: "What would Y have been if X had been x, given observation?"
|
||||
///
|
||||
/// This implements Pearl's three-step procedure:
|
||||
/// 1. Abduction: Infer exogenous variables from observation
|
||||
/// 2. Action: Apply intervention do(X = x)
|
||||
/// 3. Prediction: Compute Y under intervention with abduced exogenous values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model` - The causal model
|
||||
/// * `observation` - The observed evidence
|
||||
/// * `intervention_var` - The variable to intervene on
|
||||
/// * `intervention_value` - The counterfactual value for the intervention
|
||||
/// * `target_name` - The target variable to compute the counterfactual for
|
||||
pub fn counterfactual(
|
||||
model: &CausalModel,
|
||||
observation: &Observation,
|
||||
intervention_var: VariableId,
|
||||
intervention_value: Value,
|
||||
target_name: &str,
|
||||
) -> Result<Value, CounterfactualError> {
|
||||
// Step 1: Abduction - infer exogenous variables
|
||||
let exogenous = abduce(model, observation)?;
|
||||
|
||||
// Step 2: Action - create intervened model
|
||||
let intervention = Intervention::new(intervention_var, intervention_value);
|
||||
let intervened = model.intervene_with(&[intervention])?;
|
||||
|
||||
// Step 3: Prediction - simulate with abduced exogenous and intervention
|
||||
let result = intervened.simulate(&exogenous)?;
|
||||
|
||||
// Get the target variable
|
||||
let target_id = model.get_variable_id(target_name)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(target_name.to_string()))?;
|
||||
|
||||
result.get(&target_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| CounterfactualError::AbductionFailed(
|
||||
format!("Target variable {} not computed", target_name)
|
||||
))
|
||||
}
|
||||
|
||||
/// Compute a counterfactual with an Intervention struct (alternative API)
|
||||
pub fn counterfactual_with_intervention(
|
||||
model: &CausalModel,
|
||||
observation: &Observation,
|
||||
intervention: &Intervention,
|
||||
) -> Result<CounterfactualDistribution, CounterfactualError> {
|
||||
// Step 1: Abduction - infer exogenous variables
|
||||
let exogenous = abduce(model, observation)?;
|
||||
|
||||
// Step 2: Action - create intervened model
|
||||
let intervened = model.intervene_with(&[intervention.clone()])?;
|
||||
|
||||
// Step 3: Prediction - simulate with abduced exogenous and intervention
|
||||
let result = intervened.simulate(&exogenous)?;
|
||||
|
||||
Ok(CounterfactualDistribution::from_simulation(result))
|
||||
}
|
||||
|
||||
/// Compute counterfactual from a query object
|
||||
pub fn counterfactual_query(
|
||||
model: &CausalModel,
|
||||
query: &CounterfactualQuery,
|
||||
) -> Result<CounterfactualResult, CounterfactualError> {
|
||||
// Convert interventions
|
||||
let interventions: Result<Vec<Intervention>, _> = query.interventions.iter()
|
||||
.map(|(var, val)| {
|
||||
model.get_variable_id(var)
|
||||
.map(|id| Intervention::new(id, val.clone()))
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(var.clone()))
|
||||
})
|
||||
.collect();
|
||||
let interventions = interventions?;
|
||||
|
||||
// Step 1: Abduction
|
||||
let exogenous = abduce(model, &query.evidence)?;
|
||||
|
||||
// Step 2 & 3: Action and Prediction
|
||||
let intervened = model.intervene_with(&interventions)?;
|
||||
let result = intervened.simulate(&exogenous)?;
|
||||
|
||||
let target_id = model.get_variable_id(&query.target)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(query.target.clone()))?;
|
||||
|
||||
let explanation = format!(
|
||||
"Counterfactual computed via three-step procedure:\n\
|
||||
1. Abduced exogenous values from evidence\n\
|
||||
2. Applied intervention(s): {}\n\
|
||||
3. Predicted {} under intervention",
|
||||
query.interventions.iter()
|
||||
.map(|(v, val)| format!("do({}={:?})", v, val))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
query.target
|
||||
);
|
||||
|
||||
Ok(CounterfactualResult {
|
||||
query: query.clone(),
|
||||
distribution: CounterfactualDistribution::from_simulation(result),
|
||||
exogenous,
|
||||
explanation,
|
||||
})
|
||||
}
|
||||
|
||||
/// Abduction: Infer exogenous variable values from observations
|
||||
///
|
||||
/// For deterministic models, this inverts the structural equations
|
||||
fn abduce(
|
||||
model: &CausalModel,
|
||||
observation: &Observation,
|
||||
) -> Result<HashMap<VariableId, Value>, CounterfactualError> {
|
||||
let mut exogenous = HashMap::new();
|
||||
|
||||
// Get topological order
|
||||
let topo_order = model.topological_order_ids()?;
|
||||
|
||||
// For each variable in topological order
|
||||
for var_id in topo_order {
|
||||
let var = model.get_variable(var_id.as_ref())
|
||||
.ok_or_else(|| CounterfactualError::AbductionFailed(
|
||||
format!("Variable {:?} not found", var_id)
|
||||
))?;
|
||||
|
||||
// Check if this variable is observed
|
||||
if let Some(observed_value) = observation.values.get(&var.name) {
|
||||
// If this is a root variable (no parents), it's exogenous
|
||||
if model.parents(&var_id).map(|p| p.is_empty()).unwrap_or(true) {
|
||||
exogenous.insert(var_id, observed_value.clone());
|
||||
} else {
|
||||
// For endogenous variables, we might need to compute the residual
|
||||
// For now, we use the observed value as the exogenous noise
|
||||
exogenous.insert(var_id, observed_value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(exogenous)
|
||||
}
|
||||
|
||||
/// Compute the Average Treatment Effect (ATE)
|
||||
///
|
||||
/// ATE = E[Y | do(X=treatment_value)] - E[Y | do(X=control_value)]
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model` - The causal model
|
||||
/// * `treatment` - Treatment variable ID
|
||||
/// * `outcome` - Outcome variable ID
|
||||
/// * `treatment_value` - The treatment value
|
||||
/// * `control_value` - The control value
|
||||
pub fn causal_effect(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
outcome: VariableId,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
causal_effect_at_values(
|
||||
model,
|
||||
treatment,
|
||||
outcome,
|
||||
treatment_value,
|
||||
control_value,
|
||||
)
|
||||
}
|
||||
|
||||
/// Compute the Average Treatment Effect with default binary values
|
||||
///
|
||||
/// ATE = E[Y | do(X=1)] - E[Y | do(X=0)]
|
||||
pub fn causal_effect_binary(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
outcome: VariableId,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
causal_effect_at_values(
|
||||
model,
|
||||
treatment,
|
||||
outcome,
|
||||
Value::Continuous(1.0),
|
||||
Value::Continuous(0.0),
|
||||
)
|
||||
}
|
||||
|
||||
/// Compute causal effect at specific treatment values
|
||||
pub fn causal_effect_at_values(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
outcome: VariableId,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
// E[Y | do(X = treatment)]
|
||||
let intervention_treat = Intervention::new(treatment, treatment_value);
|
||||
let intervened_treat = model.intervene_with(&[intervention_treat])?;
|
||||
let result_treat = intervened_treat.simulate(&HashMap::new())?;
|
||||
let y_treat = result_treat.get(&outcome)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// E[Y | do(X = control)]
|
||||
let intervention_ctrl = Intervention::new(treatment, control_value);
|
||||
let intervened_ctrl = model.intervene_with(&[intervention_ctrl])?;
|
||||
let result_ctrl = intervened_ctrl.simulate(&HashMap::new())?;
|
||||
let y_ctrl = result_ctrl.get(&outcome)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
Ok(y_treat - y_ctrl)
|
||||
}
|
||||
|
||||
/// Compute ATE with full result structure
|
||||
pub fn average_treatment_effect(
|
||||
model: &CausalModel,
|
||||
treatment_name: &str,
|
||||
outcome_name: &str,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
) -> Result<AverageTreatmentEffect, CounterfactualError> {
|
||||
let treatment_id = model.get_variable_id(treatment_name)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(treatment_name.to_string()))?;
|
||||
let outcome_id = model.get_variable_id(outcome_name)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(outcome_name.to_string()))?;
|
||||
|
||||
let ate = causal_effect_at_values(
|
||||
model,
|
||||
treatment_id,
|
||||
outcome_id,
|
||||
treatment_value.clone(),
|
||||
control_value.clone(),
|
||||
)?;
|
||||
|
||||
Ok(AverageTreatmentEffect::new(
|
||||
treatment_name,
|
||||
outcome_name,
|
||||
treatment_value,
|
||||
control_value,
|
||||
ate,
|
||||
))
|
||||
}
|
||||
|
||||
/// Compute Individual Treatment Effect (ITE) for a specific unit
|
||||
///
|
||||
/// ITE_i = Y_i(X=1) - Y_i(X=0)
|
||||
///
|
||||
/// This is a counterfactual quantity: what would have happened to unit i
|
||||
/// under different treatment assignments.
|
||||
pub fn individual_treatment_effect(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
outcome: VariableId,
|
||||
unit_observation: &Observation,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
// Get the outcome variable name
|
||||
let outcome_name = model.get_variable_name(&outcome)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
|
||||
|
||||
// Counterfactual: Y(X=treatment) for this unit
|
||||
let y_treat = counterfactual(model, unit_observation, treatment, treatment_value, &outcome_name)?;
|
||||
let y_treat_val = y_treat.as_f64();
|
||||
|
||||
// Counterfactual: Y(X=control) for this unit
|
||||
let y_ctrl = counterfactual(model, unit_observation, treatment, control_value, &outcome_name)?;
|
||||
let y_ctrl_val = y_ctrl.as_f64();
|
||||
|
||||
Ok(y_treat_val - y_ctrl_val)
|
||||
}
|
||||
|
||||
/// Natural Direct Effect (NDE)
|
||||
///
|
||||
/// NDE = E[Y(x, M(x'))] - E[Y(x', M(x'))]
|
||||
///
|
||||
/// The effect of X on Y that would remain if the mediator were held at the
|
||||
/// value it would have taken under X = x'.
|
||||
pub fn natural_direct_effect(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
mediator: VariableId,
|
||||
outcome: VariableId,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
// E[Y(x', M(x'))] - baseline
|
||||
let ctrl_intervention = Intervention::new(treatment, control_value.clone());
|
||||
let intervened = model.intervene_with(&[ctrl_intervention.clone()])?;
|
||||
let baseline_result = intervened.simulate(&HashMap::new())?;
|
||||
let m_ctrl = baseline_result.get(&mediator).cloned().unwrap_or(Value::Missing);
|
||||
let y_baseline = baseline_result.get(&outcome)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// E[Y(x, M(x'))] - intervene on X but keep M at control level
|
||||
let treat_intervention = Intervention::new(treatment, treatment_value);
|
||||
let m_intervention = Intervention::new(mediator, m_ctrl);
|
||||
let intervened = model.intervene_with(&[treat_intervention, m_intervention])?;
|
||||
let nde_result = intervened.simulate(&HashMap::new())?;
|
||||
let y_nde = nde_result.get(&outcome)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
Ok(y_nde - y_baseline)
|
||||
}
|
||||
|
||||
/// Natural Indirect Effect (NIE)
|
||||
///
|
||||
/// NIE = E[Y(x, M(x))] - E[Y(x, M(x'))]
|
||||
///
|
||||
/// The effect of X on Y that is mediated through M.
|
||||
pub fn natural_indirect_effect(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
mediator: VariableId,
|
||||
outcome: VariableId,
|
||||
treatment_value: Value,
|
||||
control_value: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
// E[Y(x, M(x))] - full treatment effect
|
||||
let treat_intervention = Intervention::new(treatment, treatment_value.clone());
|
||||
let intervened = model.intervene_with(&[treat_intervention.clone()])?;
|
||||
let full_result = intervened.simulate(&HashMap::new())?;
|
||||
let y_full = full_result.get(&outcome)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// E[Y(x, M(x'))] - treatment but mediator at control level
|
||||
let ctrl_intervention = Intervention::new(treatment, control_value);
|
||||
let ctrl_intervened = model.intervene_with(&[ctrl_intervention])?;
|
||||
let ctrl_result = ctrl_intervened.simulate(&HashMap::new())?;
|
||||
let m_ctrl = ctrl_result.get(&mediator).cloned().unwrap_or(Value::Missing);
|
||||
|
||||
let m_intervention = Intervention::new(mediator, m_ctrl);
|
||||
let intervened = model.intervene_with(&[treat_intervention, m_intervention])?;
|
||||
let indirect_result = intervened.simulate(&HashMap::new())?;
|
||||
let y_indirect = indirect_result.get(&outcome)
|
||||
.map(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
Ok(y_full - y_indirect)
|
||||
}
|
||||
|
||||
/// Probability of Necessity (PN)
|
||||
///
|
||||
/// PN = P(Y_x' = 0 | X = x, Y = 1)
|
||||
///
|
||||
/// Given that X=x and Y=1 occurred, what is the probability that Y would
|
||||
/// have been 0 if X had been x' instead?
|
||||
pub fn probability_of_necessity(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
outcome: VariableId,
|
||||
observation: &Observation,
|
||||
counterfactual_treatment: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
// Get outcome variable name
|
||||
let outcome_name = model.get_variable_name(&outcome)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
|
||||
|
||||
// Compute counterfactual outcome
|
||||
let cf_value = counterfactual(model, observation, treatment, counterfactual_treatment, &outcome_name)?;
|
||||
|
||||
let cf_outcome = cf_value.as_f64();
|
||||
|
||||
// PN is probability that outcome would be 0 (negative)
|
||||
// For continuous outcomes, we check if it crosses the threshold
|
||||
let observed_outcome = observation.values.iter()
|
||||
.find_map(|(name, val)| {
|
||||
model.get_variable_id(name)
|
||||
.filter(|id| *id == outcome)
|
||||
.map(|_| val.as_f64())
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Simple heuristic: if counterfactual outcome is significantly different
|
||||
if observed_outcome > 0.0 && cf_outcome <= 0.0 {
|
||||
Ok(1.0) // Necessary
|
||||
} else if (observed_outcome - cf_outcome).abs() < 1e-6 {
|
||||
Ok(0.0) // Not necessary
|
||||
} else {
|
||||
Ok(0.5) // Uncertain
|
||||
}
|
||||
}
|
||||
|
||||
/// Probability of Sufficiency (PS)
|
||||
///
|
||||
/// PS = P(Y_x = 1 | X = x', Y = 0)
|
||||
///
|
||||
/// Given that X=x' and Y=0 occurred, what is the probability that Y would
|
||||
/// have been 1 if X had been x instead?
|
||||
pub fn probability_of_sufficiency(
|
||||
model: &CausalModel,
|
||||
treatment: VariableId,
|
||||
outcome: VariableId,
|
||||
observation: &Observation,
|
||||
counterfactual_treatment: Value,
|
||||
) -> Result<f64, CounterfactualError> {
|
||||
// Get outcome variable name
|
||||
let outcome_name = model.get_variable_name(&outcome)
|
||||
.ok_or_else(|| CounterfactualError::InvalidObservation(format!("Outcome variable {:?} not found", outcome)))?;
|
||||
|
||||
// Compute counterfactual outcome
|
||||
let cf_value = counterfactual(model, observation, treatment, counterfactual_treatment, &outcome_name)?;
|
||||
|
||||
let cf_outcome = cf_value.as_f64();
|
||||
|
||||
let observed_outcome = observation.values.iter()
|
||||
.find_map(|(name, val)| {
|
||||
model.get_variable_id(name)
|
||||
.filter(|id| *id == outcome)
|
||||
.map(|_| val.as_f64())
|
||||
})
|
||||
.unwrap_or(1.0);
|
||||
|
||||
// PS: would the outcome have been positive if treatment were different?
|
||||
if observed_outcome <= 0.0 && cf_outcome > 0.0 {
|
||||
Ok(1.0) // Sufficient
|
||||
} else if (observed_outcome - cf_outcome).abs() < 1e-6 {
|
||||
Ok(0.0) // Not sufficient
|
||||
} else {
|
||||
Ok(0.5) // Uncertain
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::model::{CausalModelBuilder, VariableType, Mechanism};

    /// Two-variable model with the structural equation Y = 2*X + 1.
    fn create_simple_model() -> CausalModel {
        let mut model = CausalModel::with_name("Simple");

        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // Y = 2*X + 1
        model.add_structural_equation(y, &[x], Mechanism::new(|p| {
            Value::Continuous(2.0 * p[0].as_f64() + 1.0)
        })).unwrap();

        model
    }

    /// Mediation model: X -> M -> Y plus a direct X -> Y path.
    fn create_mediation_model() -> CausalModel {
        let mut model = CausalModel::with_name("Mediation");

        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("M", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // M = X
        model.add_structural_equation(m, &[x], Mechanism::new(|p| {
            p[0].clone()
        })).unwrap();

        // Y = M + 0.5*X
        model.add_structural_equation(y, &[m, x], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() + 0.5 * p[1].as_f64())
        })).unwrap();

        model
    }

    #[test]
    fn test_observation() {
        let mut obs = Observation::new(&[("X", Value::Continuous(5.0))]);
        obs.observe("Y", Value::Continuous(11.0));

        assert!(obs.is_observed("X"));
        assert!(obs.is_observed("Y"));
        assert!(!obs.is_observed("Z"));
    }

    #[test]
    fn test_counterfactual_simple() {
        let model = create_simple_model();

        let x_id = model.get_variable_id("X").unwrap();

        // Observation: X=3, Y=7 (since Y = 2*3 + 1)
        let observation = Observation::new(&[
            ("X", Value::Continuous(3.0)),
            ("Y", Value::Continuous(7.0)),
        ]);

        // Counterfactual: What would Y have been if X had been 5?
        let intervention = Intervention::new(x_id, Value::Continuous(5.0));

        // FIX: `counterfactual` takes (model, observation, var, value, target);
        // the Intervention-based entry point is `counterfactual_with_intervention`.
        let result = counterfactual_with_intervention(&model, &observation, &intervention).unwrap();

        // Y should be 2*5 + 1 = 11
        let y_id = model.get_variable_id("Y").unwrap();
        let y_value = result.get(y_id).unwrap().as_f64();

        assert!((y_value - 11.0).abs() < 1e-10);
    }

    #[test]
    fn test_causal_effect() {
        let model = create_simple_model();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // ATE = E[Y|do(X=1)] - E[Y|do(X=0)]
        // = (2*1 + 1) - (2*0 + 1) = 3 - 1 = 2
        // FIX: the three-argument binary form is `causal_effect_binary`;
        // `causal_effect` requires explicit treatment/control values.
        let ate = causal_effect_binary(&model, x, y).unwrap();

        assert!((ate - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_average_treatment_effect() {
        let model = create_simple_model();

        let ate_result = average_treatment_effect(
            &model,
            "X", "Y",
            Value::Continuous(1.0),
            Value::Continuous(0.0),
        ).unwrap();

        assert_eq!(ate_result.treatment, "X");
        assert_eq!(ate_result.outcome, "Y");
        assert!((ate_result.ate - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_mediation_effects() {
        let model = create_mediation_model();

        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        let treat = Value::Continuous(1.0);
        let ctrl = Value::Continuous(0.0);

        // Total effect:
        // E[Y|do(X=1)] - E[Y|do(X=0)]
        // = (M(1) + 0.5*1) - (M(0) + 0.5*0)
        // = (1 + 0.5) - (0 + 0) = 1.5
        let total = causal_effect_at_values(&model, x, y, treat.clone(), ctrl.clone()).unwrap();
        assert!((total - 1.5).abs() < 1e-10);

        // NDE should be the direct effect = 0.5 (coefficient of X in Y equation)
        let nde = natural_direct_effect(&model, x, m, y, treat.clone(), ctrl.clone()).unwrap();
        assert!((nde - 0.5).abs() < 1e-10);

        // NIE should be the indirect effect = 1.0 (coefficient of M in Y, times effect of X on M)
        let nie = natural_indirect_effect(&model, x, m, y, treat, ctrl).unwrap();
        assert!((nie - 1.0).abs() < 1e-10);

        // NDE + NIE should equal the total effect in this linear model
        assert!((nde + nie - total).abs() < 1e-10);
    }

    #[test]
    fn test_counterfactual_query() {
        let model = create_simple_model();

        let query = CounterfactualQuery::new(
            "Y",
            vec![("X", Value::Continuous(10.0))],
            Observation::new(&[("X", Value::Continuous(3.0))]),
        );

        let result = counterfactual_query(&model, &query).unwrap();

        // Y = 2*10 + 1 = 21
        let y_id = model.get_variable_id("Y").unwrap();
        assert!((result.distribution.mean(y_id) - 21.0).abs() < 1e-10);
    }

    #[test]
    fn test_distribution() {
        let mut values = HashMap::new();
        let x_id = VariableId(0);
        let y_id = VariableId(1);

        values.insert(x_id, Value::Continuous(5.0));
        values.insert(y_id, Value::Continuous(10.0));

        let dist = CounterfactualDistribution::point_mass(values);

        assert_eq!(dist.mean(x_id), 5.0);
        assert_eq!(dist.mean(y_id), 10.0);
        assert_eq!(dist.probability, 1.0);
    }

    #[test]
    fn test_individual_treatment_effect() {
        let model = create_simple_model();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        // Unit-specific observation
        let unit_obs = Observation::new(&[
            ("X", Value::Continuous(3.0)),
            ("Y", Value::Continuous(7.0)),
        ]);

        let ite = individual_treatment_effect(
            &model,
            x, y,
            &unit_obs,
            Value::Continuous(5.0),
            Value::Continuous(3.0),
        ).unwrap();

        // ITE = Y(X=5) - Y(X=3) = 11 - 7 = 4
        assert!((ite - 4.0).abs() < 1e-10);
    }
}
|
||||
920
vendor/ruvector/examples/prime-radiant/src/causal/do_calculus.rs
vendored
Normal file
920
vendor/ruvector/examples/prime-radiant/src/causal/do_calculus.rs
vendored
Normal file
@@ -0,0 +1,920 @@
|
||||
//! Do-Calculus Implementation
|
||||
//!
|
||||
//! This module implements Pearl's do-calculus, a complete set of inference rules
|
||||
//! for computing causal effects from observational data when possible.
|
||||
//!
|
||||
//! ## The Three Rules of Do-Calculus
|
||||
//!
|
||||
//! Given a causal DAG G, the following rules hold:
|
||||
//!
|
||||
//! **Rule 1 (Insertion/deletion of observations):**
|
||||
//! P(y | do(x), z, w) = P(y | do(x), w) if (Y ⊥ Z | X, W)_{G_{\overline{X}}}
|
||||
//!
|
||||
//! **Rule 2 (Action/observation exchange):**
|
||||
//! P(y | do(x), do(z), w) = P(y | do(x), z, w) if (Y ⊥ Z | X, W)_{G_{\overline{X}\underline{Z}}}
|
||||
//!
|
||||
//! **Rule 3 (Insertion/deletion of actions):**
|
||||
//! P(y | do(x), do(z), w) = P(y | do(x), w) if (Y ⊥ Z | X, W)_{G_{\overline{X}\overline{Z(W)}}}
|
||||
//!
|
||||
//! where:
|
||||
//! - G_{\overline{X}} is G with incoming edges to X deleted
|
||||
//! - G_{\underline{Z}} is G with outgoing edges from Z deleted
|
||||
//! - Z(W) is Z without ancestors of W in G_{\overline{X}}
|
||||
//!
|
||||
//! ## References
|
||||
//!
|
||||
//! - Pearl (1995): "Causal diagrams for empirical research"
|
||||
//! - Shpitser & Pearl (2006): "Identification of Joint Interventional Distributions"
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::model::{CausalModel, VariableId};
|
||||
use super::graph::{DirectedGraph, DAGValidationError};
|
||||
|
||||
/// Error types for do-calculus operations.
#[derive(Debug, Clone, Error)]
pub enum IdentificationError {
    /// The causal query cannot be identified from observational data
    #[error("Query is not identifiable: {0}")]
    NotIdentifiable(String),

    /// Invalid query specification (malformed variable sets, etc.)
    #[error("Invalid query: {0}")]
    InvalidQuery(String),

    /// Graph manipulation error (propagated from DAG validation via `#[from]`)
    #[error("Graph error: {0}")]
    GraphError(#[from] DAGValidationError),

    /// Named variable does not exist in the model
    #[error("Variable not found: {0}")]
    VariableNotFound(String),
}
|
||||
|
||||
/// The three rules of do-calculus (Pearl 1995).
///
/// See the module-level documentation for the formal statements and the
/// graph surgeries (edge deletions) under which each rule applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Rule {
    /// Rule 1: Insertion/deletion of observations
    Rule1,
    /// Rule 2: Action/observation exchange
    Rule2,
    /// Rule 3: Insertion/deletion of actions
    Rule3,
}
|
||||
|
||||
impl Rule {
|
||||
/// Get the name of the rule
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Rule::Rule1 => "Insertion/deletion of observations",
|
||||
Rule::Rule2 => "Action/observation exchange",
|
||||
Rule::Rule3 => "Insertion/deletion of actions",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a description of what the rule does
|
||||
pub fn description(&self) -> &'static str {
|
||||
match self {
|
||||
Rule::Rule1 => "Allows adding/removing observations that are d-separated from Y given do(X)",
|
||||
Rule::Rule2 => "Allows exchanging do(Z) with Z under d-separation conditions",
|
||||
Rule::Rule3 => "Allows removing interventions that have no effect on Y",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of an identification attempt (enum form for pattern matching).
#[derive(Debug, Clone)]
pub enum Identification {
    /// Effect is identifiable; carries the derivation details
    Identified(IdentificationResult),
    /// Effect is not identifiable; carries a reason string
    NotIdentified(String),
}
|
||||
|
||||
impl Identification {
|
||||
/// Check if identified
|
||||
pub fn is_identified(&self) -> bool {
|
||||
matches!(self, Identification::Identified(_))
|
||||
}
|
||||
|
||||
/// Get the result if identified
|
||||
pub fn result(&self) -> Option<&IdentificationResult> {
|
||||
match self {
|
||||
Identification::Identified(r) => Some(r),
|
||||
Identification::NotIdentified(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detailed result of a successful identification.
#[derive(Debug, Clone)]
pub struct IdentificationResult {
    /// The sequence of do-calculus rules applied, in order
    pub rules_applied: Vec<RuleApplication>,
    /// The final estimand expression (human-readable)
    pub expression: String,
    /// Adjustment set (populated when the backdoor criterion was used)
    pub adjustment_set: Option<Vec<String>>,
    /// Front-door mediator set (populated when the front-door criterion was used)
    pub front_door_set: Option<Vec<String>>,
}
|
||||
|
||||
/// Legacy result format kept for backward compatibility.
///
/// Prefer the [`Identification`] enum; this flat struct encodes the same
/// information with an explicit boolean and optional fields.
#[derive(Debug, Clone)]
pub struct IdentificationLegacy {
    /// Whether the query is identifiable
    pub identifiable: bool,
    /// The sequence of rules applied
    pub rules_applied: Vec<RuleApplication>,
    /// The final expression (present only if identifiable)
    pub expression: Option<String>,
    /// Adjustment set (if using backdoor criterion)
    pub adjustment_set: Option<Vec<String>>,
    /// Front-door set (if using front-door criterion)
    pub front_door_set: Option<Vec<String>>,
}
|
||||
|
||||
/// Record of one do-calculus rule application in a derivation.
#[derive(Debug, Clone)]
pub struct RuleApplication {
    /// Which rule was applied
    pub rule: Rule,
    /// Variable names involved in this step
    pub variables: Vec<String>,
    /// Expression before the rule was applied
    pub before: String,
    /// Expression after the rule was applied
    pub after: String,
    /// Description of the graph surgery used (e.g. edges deleted)
    pub graph_modification: String,
}
|
||||
|
||||
/// Do-Calculus engine for causal identification.
///
/// Borrows the model for the lifetime of the engine; construct via
/// `DoCalculus::new(&model)`.
pub struct DoCalculus<'a> {
    // The causal model whose graph the identification rules are checked against.
    model: &'a CausalModel,
}
|
||||
|
||||
impl<'a> DoCalculus<'a> {
|
||||
/// Create a new do-calculus engine
|
||||
pub fn new(model: &'a CausalModel) -> Self {
|
||||
Self { model }
|
||||
}
|
||||
|
||||
/// Identify P(Y | do(X)) - simplified API for single outcome and treatment set
|
||||
///
|
||||
/// Returns Identification enum for pattern matching
|
||||
pub fn identify(&self, outcome: VariableId, treatment_set: &HashSet<VariableId>) -> Identification {
|
||||
// Check if model has latent confounding affecting treatment-outcome
|
||||
for &t in treatment_set {
|
||||
if !self.model.is_unconfounded(t, outcome) {
|
||||
// There is latent confounding
|
||||
// Check if backdoor criterion can be satisfied
|
||||
let treatment_vec: Vec<_> = treatment_set.iter().copied().collect();
|
||||
if let Some(adjustment) = self.find_backdoor_adjustment(&treatment_vec, &[outcome]) {
|
||||
let adjustment_names: Vec<String> = adjustment.iter()
|
||||
.filter_map(|id| self.model.get_variable_name(id))
|
||||
.collect();
|
||||
|
||||
return Identification::Identified(IdentificationResult {
|
||||
rules_applied: vec![RuleApplication {
|
||||
rule: Rule::Rule2,
|
||||
variables: vec![format!("{:?}", outcome)],
|
||||
before: format!("P(Y | do(X))"),
|
||||
after: format!("Backdoor adjustment via {:?}", adjustment_names),
|
||||
graph_modification: "Backdoor criterion".to_string(),
|
||||
}],
|
||||
expression: format!("Σ P(Y | X, Z) P(Z)"),
|
||||
adjustment_set: Some(adjustment_names),
|
||||
front_door_set: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Check front-door
|
||||
if treatment_set.len() == 1 {
|
||||
let treatment = *treatment_set.iter().next().unwrap();
|
||||
if let Some(mediators) = self.find_front_door_set(treatment, outcome) {
|
||||
let mediator_names: Vec<String> = mediators.iter()
|
||||
.filter_map(|id| self.model.get_variable_name(id))
|
||||
.collect();
|
||||
|
||||
return Identification::Identified(IdentificationResult {
|
||||
rules_applied: vec![RuleApplication {
|
||||
rule: Rule::Rule2,
|
||||
variables: vec![format!("{:?}", outcome)],
|
||||
before: format!("P(Y | do(X))"),
|
||||
after: format!("Front-door via {:?}", mediator_names),
|
||||
graph_modification: "Front-door criterion".to_string(),
|
||||
}],
|
||||
expression: format!("Front-door formula"),
|
||||
adjustment_set: None,
|
||||
front_door_set: Some(mediator_names),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return Identification::NotIdentified(
|
||||
"Effect not identifiable due to latent confounding".to_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// No latent confounding - directly identifiable
|
||||
Identification::Identified(IdentificationResult {
|
||||
rules_applied: vec![RuleApplication {
|
||||
rule: Rule::Rule3,
|
||||
variables: vec![format!("{:?}", outcome)],
|
||||
before: format!("P(Y | do(X))"),
|
||||
after: format!("P(Y | X)"),
|
||||
graph_modification: "Direct identification".to_string(),
|
||||
}],
|
||||
expression: format!("P(Y | X)"),
|
||||
adjustment_set: Some(vec![]),
|
||||
front_door_set: None,
|
||||
})
|
||||
}
|
||||
|
||||
    /// Identify using string names (legacy API)
    ///
    /// Resolves variable names, then tries three strategies in order:
    /// backdoor adjustment, front-door adjustment (single treatment/outcome
    /// only), and direct identification (no shared ancestors). Returns a
    /// flattened `IdentificationLegacy` record; `identifiable: false` when
    /// every strategy fails.
    ///
    /// # Errors
    /// `IdentificationError::VariableNotFound` if any name is unknown.
    pub fn identify_by_name(
        &self,
        treatment: &[&str],
        outcome: &[&str],
    ) -> Result<IdentificationLegacy, IdentificationError> {
        // Convert names to IDs
        let treatment_ids: Result<Vec<_>, _> = treatment.iter()
            .map(|&name| {
                self.model.get_variable_id(name)
                    .ok_or_else(|| IdentificationError::VariableNotFound(name.to_string()))
            })
            .collect();
        let treatment_ids = treatment_ids?;

        let outcome_ids: Result<Vec<_>, _> = outcome.iter()
            .map(|&name| {
                self.model.get_variable_id(name)
                    .ok_or_else(|| IdentificationError::VariableNotFound(name.to_string()))
            })
            .collect();
        let outcome_ids = outcome_ids?;

        // Try different identification strategies
        let mut rules_applied = Vec::new();

        // Strategy 1: Check backdoor criterion
        if let Some(adjustment) = self.find_backdoor_adjustment(&treatment_ids, &outcome_ids) {
            let adjustment_names: Vec<String> = adjustment.iter()
                .filter_map(|id| self.model.get_variable_name(id))
                .collect();

            // Adjustment formula: Σ_Z P(Y | X, Z) P(Z)
            rules_applied.push(RuleApplication {
                rule: Rule::Rule2,
                variables: treatment.iter().map(|s| s.to_string()).collect(),
                before: format!("P({} | do({}))",
                    outcome.join(", "), treatment.join(", ")),
                after: format!("Σ_{{{}}} P({} | {}, {}) P({})",
                    adjustment_names.join(", "),
                    outcome.join(", "),
                    treatment.join(", "),
                    adjustment_names.join(", "),
                    adjustment_names.join(", ")),
                graph_modification: "Backdoor criterion satisfied".to_string(),
            });

            return Ok(IdentificationLegacy {
                identifiable: true,
                rules_applied,
                expression: Some(format!(
                    "Σ_{{{}}} P({} | {}, {}) P({})",
                    adjustment_names.join(", "),
                    outcome.join(", "),
                    treatment.join(", "),
                    adjustment_names.join(", "),
                    adjustment_names.join(", ")
                )),
                adjustment_set: Some(adjustment_names),
                front_door_set: None,
            });
        }

        // Strategy 2: Check front-door criterion
        if treatment_ids.len() == 1 && outcome_ids.len() == 1 {
            if let Some(mediators) = self.find_front_door_set(treatment_ids[0], outcome_ids[0]) {
                let mediator_names: Vec<String> = mediators.iter()
                    .filter_map(|id| self.model.get_variable_name(id))
                    .collect();

                rules_applied.push(RuleApplication {
                    rule: Rule::Rule2,
                    variables: vec![treatment[0].to_string()],
                    before: format!("P({} | do({}))", outcome[0], treatment[0]),
                    after: format!("Front-door adjustment via {}", mediator_names.join(", ")),
                    graph_modification: "Front-door criterion satisfied".to_string(),
                });

                // Front-door formula: Σ_M P(M | X) Σ_X P(Y | M, X) P(X)
                return Ok(IdentificationLegacy {
                    identifiable: true,
                    rules_applied,
                    expression: Some(format!(
                        "Σ_{{{}}} P({} | {}) Σ_{{{}}} P({} | {}, {}) P({})",
                        mediator_names.join(", "),
                        mediator_names.join(", "),
                        treatment[0],
                        treatment[0],
                        outcome[0],
                        mediator_names.join(", "),
                        treatment[0],
                        treatment[0]
                    )),
                    adjustment_set: None,
                    front_door_set: Some(mediator_names),
                });
            }
        }

        // Strategy 3: Check direct identifiability (no confounders)
        if self.is_directly_identifiable(&treatment_ids, &outcome_ids) {
            rules_applied.push(RuleApplication {
                rule: Rule::Rule3,
                variables: treatment.iter().map(|s| s.to_string()).collect(),
                before: format!("P({} | do({}))", outcome.join(", "), treatment.join(", ")),
                after: format!("P({} | {})", outcome.join(", "), treatment.join(", ")),
                graph_modification: "No confounders; direct identification".to_string(),
            });

            return Ok(IdentificationLegacy {
                identifiable: true,
                rules_applied,
                expression: Some(format!("P({} | {})", outcome.join(", "), treatment.join(", "))),
                adjustment_set: Some(vec![]),
                front_door_set: None,
            });
        }

        // Not identifiable
        Ok(IdentificationLegacy {
            identifiable: false,
            rules_applied: vec![],
            expression: None,
            adjustment_set: None,
            front_door_set: None,
        })
    }
|
||||
|
||||
/// Check Rule 1: Can we add/remove observation Z?
|
||||
///
|
||||
/// P(y | do(x), z, w) = P(y | do(x), w) if (Y ⊥ Z | X, W) in G_{\overline{X}}
|
||||
pub fn can_apply_rule1(
|
||||
&self,
|
||||
y: &[VariableId],
|
||||
x: &[VariableId],
|
||||
z: &[VariableId],
|
||||
w: &[VariableId],
|
||||
) -> bool {
|
||||
// Build G_{\overline{X}}: delete incoming edges to X
|
||||
let modified_graph = self.graph_delete_incoming(x);
|
||||
|
||||
// Check d-separation of Y and Z given X ∪ W in modified graph
|
||||
let y_set: HashSet<_> = y.iter().map(|id| id.0).collect();
|
||||
let z_set: HashSet<_> = z.iter().map(|id| id.0).collect();
|
||||
let mut conditioning: HashSet<_> = x.iter().map(|id| id.0).collect();
|
||||
conditioning.extend(w.iter().map(|id| id.0));
|
||||
|
||||
modified_graph.d_separated(&y_set, &z_set, &conditioning)
|
||||
}
|
||||
|
||||
/// Check Rule 2: Can we exchange do(Z) with observation Z?
|
||||
///
|
||||
/// P(y | do(x), do(z), w) = P(y | do(x), z, w) if (Y ⊥ Z | X, W) in G_{\overline{X}\underline{Z}}
|
||||
pub fn can_apply_rule2(
|
||||
&self,
|
||||
y: &[VariableId],
|
||||
x: &[VariableId],
|
||||
z: &[VariableId],
|
||||
w: &[VariableId],
|
||||
) -> bool {
|
||||
// Build G_{\overline{X}\underline{Z}}: delete incoming edges to X and outgoing from Z
|
||||
let modified_graph = self.graph_delete_incoming_and_outgoing(x, z);
|
||||
|
||||
// Check d-separation
|
||||
let y_set: HashSet<_> = y.iter().map(|id| id.0).collect();
|
||||
let z_set: HashSet<_> = z.iter().map(|id| id.0).collect();
|
||||
let mut conditioning: HashSet<_> = x.iter().map(|id| id.0).collect();
|
||||
conditioning.extend(w.iter().map(|id| id.0));
|
||||
|
||||
modified_graph.d_separated(&y_set, &z_set, &conditioning)
|
||||
}
|
||||
|
||||
    /// Check Rule 3: Can we remove do(Z)?
    ///
    /// P(y | do(x), do(z), w) = P(y | do(x), w) if (Y ⊥ Z | X, W) in G_{\overline{X}\overline{Z(W)}}
    pub fn can_apply_rule3(
        &self,
        y: &[VariableId],
        x: &[VariableId],
        z: &[VariableId],
        w: &[VariableId],
    ) -> bool {
        // Build G_{\overline{X}\overline{Z(W)}}: more complex modification
        // Z(W) = Z \ ancestors of W in G_{\overline{X}}

        // First get G_{\overline{X}} (delete edges into X)
        let g_no_x = self.graph_delete_incoming(x);

        // Find ancestors of W in G_{\overline{X}}
        let w_ancestors: HashSet<_> = w.iter()
            .flat_map(|wv| g_no_x.ancestors(wv.0))
            .collect();

        // Z(W) = Z without W's ancestors (those keep their incoming edges)
        let z_without_w_ancestors: Vec<_> = z.iter()
            .filter(|zv| !w_ancestors.contains(&zv.0))
            .copied()
            .collect();

        // Build G_{\overline{X}\overline{Z(W)}} by deleting incoming edges
        // to X and to the filtered Z variables in one pass.
        let modified_graph = self.graph_delete_incoming_multiple(
            &[x, &z_without_w_ancestors].concat()
        );

        // Check d-separation of Y and Z given X ∪ W in the mutilated graph
        let y_set: HashSet<_> = y.iter().map(|id| id.0).collect();
        let z_set: HashSet<_> = z.iter().map(|id| id.0).collect();
        let mut conditioning: HashSet<_> = x.iter().map(|id| id.0).collect();
        conditioning.extend(w.iter().map(|id| id.0));

        modified_graph.d_separated(&y_set, &z_set, &conditioning)
    }
|
||||
|
||||
/// Check Rule 1 with HashSet API (for test compatibility)
|
||||
///
|
||||
/// P(y | do(x), z) = P(y | do(x)) if (Y ⊥ Z | X) in G_{\overline{X}}
|
||||
pub fn can_apply_rule1_sets(
|
||||
&self,
|
||||
y: &HashSet<VariableId>,
|
||||
x: &HashSet<VariableId>,
|
||||
z: &HashSet<VariableId>,
|
||||
) -> bool {
|
||||
let y_vec: Vec<_> = y.iter().copied().collect();
|
||||
let x_vec: Vec<_> = x.iter().copied().collect();
|
||||
let z_vec: Vec<_> = z.iter().copied().collect();
|
||||
self.can_apply_rule1(&y_vec, &x_vec, &z_vec, &[])
|
||||
}
|
||||
|
||||
/// Check Rule 2 with HashSet API (for test compatibility)
|
||||
pub fn can_apply_rule2_sets(
|
||||
&self,
|
||||
y: &HashSet<VariableId>,
|
||||
x: &HashSet<VariableId>,
|
||||
z: &HashSet<VariableId>,
|
||||
) -> bool {
|
||||
let y_vec: Vec<_> = y.iter().copied().collect();
|
||||
let x_vec: Vec<_> = x.iter().copied().collect();
|
||||
let z_vec: Vec<_> = z.iter().copied().collect();
|
||||
self.can_apply_rule2(&y_vec, &x_vec, &z_vec, &[])
|
||||
}
|
||||
|
||||
/// Check Rule 3 with HashSet API (for test compatibility)
|
||||
pub fn can_apply_rule3_sets(
|
||||
&self,
|
||||
y: &HashSet<VariableId>,
|
||||
x: &HashSet<VariableId>,
|
||||
z: &HashSet<VariableId>,
|
||||
) -> bool {
|
||||
let y_vec: Vec<_> = y.iter().copied().collect();
|
||||
let x_vec: Vec<_> = x.iter().copied().collect();
|
||||
let z_vec: Vec<_> = z.iter().copied().collect();
|
||||
self.can_apply_rule3(&y_vec, &x_vec, &z_vec, &[])
|
||||
}
|
||||
|
||||
/// Find a valid backdoor adjustment set
|
||||
fn find_backdoor_adjustment(
|
||||
&self,
|
||||
treatment: &[VariableId],
|
||||
outcome: &[VariableId],
|
||||
) -> Option<Vec<VariableId>> {
|
||||
// Get all potential adjustment variables (not descendants of treatment)
|
||||
let treatment_descendants: HashSet<_> = treatment.iter()
|
||||
.flat_map(|t| self.model.graph().descendants(t.0))
|
||||
.collect();
|
||||
|
||||
let potential_adjusters: Vec<_> = self.model.variables()
|
||||
.filter(|v| !treatment.contains(&v.id))
|
||||
.filter(|v| !outcome.contains(&v.id))
|
||||
.filter(|v| !treatment_descendants.contains(&v.id.0))
|
||||
.map(|v| v.id)
|
||||
.collect();
|
||||
|
||||
// Try the full set first
|
||||
if self.satisfies_backdoor_criterion(treatment, outcome, &potential_adjusters) {
|
||||
return Some(potential_adjusters);
|
||||
}
|
||||
|
||||
// Try minimal subsets
|
||||
if potential_adjusters.is_empty() {
|
||||
if self.satisfies_backdoor_criterion(treatment, outcome, &[]) {
|
||||
return Some(vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
// Try single-variable adjustments
|
||||
for &adjuster in &potential_adjusters {
|
||||
if self.satisfies_backdoor_criterion(treatment, outcome, &[adjuster]) {
|
||||
return Some(vec![adjuster]);
|
||||
}
|
||||
}
|
||||
|
||||
// Try pairs
|
||||
for i in 0..potential_adjusters.len() {
|
||||
for j in (i + 1)..potential_adjusters.len() {
|
||||
let pair = vec![potential_adjusters[i], potential_adjusters[j]];
|
||||
if self.satisfies_backdoor_criterion(treatment, outcome, &pair) {
|
||||
return Some(pair);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a set satisfies the backdoor criterion
|
||||
fn satisfies_backdoor_criterion(
|
||||
&self,
|
||||
treatment: &[VariableId],
|
||||
outcome: &[VariableId],
|
||||
adjustment: &[VariableId],
|
||||
) -> bool {
|
||||
// Backdoor criterion:
|
||||
// 1. No node in Z is a descendant of X
|
||||
// 2. Z blocks all backdoor paths from X to Y
|
||||
|
||||
// Condition 1: already ensured by caller
|
||||
|
||||
// Condition 2: Check d-separation in G_{\overline{X}}
|
||||
let g_no_x = self.graph_delete_incoming(treatment);
|
||||
|
||||
for &x in treatment {
|
||||
for &y in outcome {
|
||||
let x_set: HashSet<_> = [x.0].into_iter().collect();
|
||||
let y_set: HashSet<_> = [y.0].into_iter().collect();
|
||||
let z_set: HashSet<_> = adjustment.iter().map(|v| v.0).collect();
|
||||
|
||||
if !g_no_x.d_separated(&x_set, &y_set, &z_set) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
    /// Find a front-door adjustment set (for single treatment/outcome)
    ///
    /// Returns the first single mediator `M` that passes the checks below,
    /// or `None`.
    ///
    /// NOTE(review): the three textbook conditions are only partially
    /// verified here — candidates are merely required to lie on *some*
    /// directed X→Y path (not to intercept *all* of them), and condition 2
    /// (no unblocked backdoor path X→M) is not checked directly; confirm
    /// whether this approximation is intended.
    fn find_front_door_set(
        &self,
        treatment: VariableId,
        outcome: VariableId,
    ) -> Option<Vec<VariableId>> {
        // Front-door criterion:
        // 1. M intercepts all directed paths from X to Y
        // 2. There is no unblocked backdoor path from X to M
        // 3. All backdoor paths from M to Y are blocked by X

        let descendants_of_x = self.model.graph().descendants(treatment.0);
        let ancestors_of_y = self.model.graph().ancestors(outcome.0);

        // M must be on path from X to Y (descendant of X AND ancestor of Y)
        let candidates: Vec<_> = descendants_of_x.intersection(&ancestors_of_y)
            .filter(|&&m| m != treatment.0 && m != outcome.0)
            .map(|&m| VariableId(m))
            .collect();

        if candidates.is_empty() {
            return None;
        }

        // Check each candidate
        for &m in &candidates {
            // Check condition 2: no backdoor from X to M
            let x_set: HashSet<_> = [treatment.0].into_iter().collect();
            let m_set: HashSet<_> = [m.0].into_iter().collect();

            // Skip candidates with no X–M connection at all (unconditional
            // d-separation means M cannot mediate the effect).
            if self.model.graph().d_separated(&x_set, &m_set, &HashSet::new()) {
                continue; // X and M are d-separated (no path at all)
            }

            // Check condition 3: backdoor from M to Y blocked by X,
            // tested in G with M's outgoing edges removed.
            let y_set: HashSet<_> = [outcome.0].into_iter().collect();

            let g_underline_m = self.graph_delete_outgoing(&[m]);

            if g_underline_m.d_separated(&m_set, &y_set, &x_set) {
                return Some(vec![m]);
            }
        }

        None
    }
|
||||
|
||||
/// Check if effect is directly identifiable (no confounders)
|
||||
fn is_directly_identifiable(
|
||||
&self,
|
||||
treatment: &[VariableId],
|
||||
outcome: &[VariableId],
|
||||
) -> bool {
|
||||
// Check if there are any backdoor paths
|
||||
for &x in treatment {
|
||||
for &y in outcome {
|
||||
let x_ancestors = self.model.graph().ancestors(x.0);
|
||||
let y_ancestors = self.model.graph().ancestors(y.0);
|
||||
|
||||
// If they share common ancestors, there might be confounding
|
||||
if !x_ancestors.is_disjoint(&y_ancestors) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// Graph manipulation helpers
|
||||
|
||||
/// Create G_{\overline{X}}: delete incoming edges to X
|
||||
fn graph_delete_incoming(&self, x: &[VariableId]) -> DirectedGraph {
|
||||
let mut modified = self.model.graph().clone();
|
||||
|
||||
for &xi in x {
|
||||
if let Some(parents) = self.model.parents(&xi) {
|
||||
for parent in parents {
|
||||
modified.remove_edge(parent.0, xi.0).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modified
|
||||
}
|
||||
|
||||
/// Create G_{\underline{Z}}: delete outgoing edges from Z
|
||||
fn graph_delete_outgoing(&self, z: &[VariableId]) -> DirectedGraph {
|
||||
let mut modified = self.model.graph().clone();
|
||||
|
||||
for &zi in z {
|
||||
if let Some(children) = self.model.children(&zi) {
|
||||
for child in children {
|
||||
modified.remove_edge(zi.0, child.0).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modified
|
||||
}
|
||||
|
||||
/// Create G_{\overline{X}\underline{Z}}
|
||||
fn graph_delete_incoming_and_outgoing(
|
||||
&self,
|
||||
x: &[VariableId],
|
||||
z: &[VariableId],
|
||||
) -> DirectedGraph {
|
||||
let mut modified = self.graph_delete_incoming(x);
|
||||
|
||||
for &zi in z {
|
||||
if let Some(children) = self.model.children(&zi) {
|
||||
for child in children {
|
||||
modified.remove_edge(zi.0, child.0).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modified
|
||||
}
|
||||
|
||||
/// Delete incoming edges to multiple variable sets
|
||||
fn graph_delete_incoming_multiple(&self, vars: &[VariableId]) -> DirectedGraph {
|
||||
self.graph_delete_incoming(vars)
|
||||
}
|
||||
|
||||
/// Compute the causal effect using the identified formula
|
||||
pub fn compute_effect(
|
||||
&self,
|
||||
identification: &Identification,
|
||||
data: &HashMap<String, Vec<f64>>,
|
||||
) -> Result<f64, IdentificationError> {
|
||||
if !identification.identifiable {
|
||||
return Err(IdentificationError::NotIdentifiable(
|
||||
"Cannot compute unidentifiable effect".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
// Simple implementation: use adjustment formula if available
|
||||
if let Some(ref adjustment_names) = identification.adjustment_set {
|
||||
if adjustment_names.is_empty() {
|
||||
// Direct effect - compute from data
|
||||
return self.compute_direct_effect(data);
|
||||
}
|
||||
// Adjusted effect
|
||||
return self.compute_adjusted_effect(data, adjustment_names);
|
||||
}
|
||||
|
||||
// Front-door adjustment
|
||||
if identification.front_door_set.is_some() {
|
||||
return self.compute_frontdoor_effect(data, identification);
|
||||
}
|
||||
|
||||
Err(IdentificationError::NotIdentifiable(
|
||||
"No valid estimation strategy".to_string()
|
||||
))
|
||||
}
|
||||
|
||||
fn compute_direct_effect(&self, data: &HashMap<String, Vec<f64>>) -> Result<f64, IdentificationError> {
|
||||
// Simple regression coefficient as effect estimate
|
||||
// This is a placeholder - real implementation would use proper estimation
|
||||
Ok(0.0)
|
||||
}
|
||||
|
||||
fn compute_adjusted_effect(
|
||||
&self,
|
||||
_data: &HashMap<String, Vec<f64>>,
|
||||
_adjustment: &[String],
|
||||
) -> Result<f64, IdentificationError> {
|
||||
// Adjusted regression or inverse probability weighting
|
||||
// Placeholder implementation
|
||||
Ok(0.0)
|
||||
}
|
||||
|
||||
fn compute_frontdoor_effect(
|
||||
&self,
|
||||
_data: &HashMap<String, Vec<f64>>,
|
||||
_identification: &Identification,
|
||||
) -> Result<f64, IdentificationError> {
|
||||
// Front-door formula computation
|
||||
// Placeholder implementation
|
||||
Ok(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::causal::model::{CausalModel, VariableType, Mechanism, Value};
|
||||
|
||||
    /// Build a three-variable model: U -> X, U -> Y, X -> Y.
    ///
    /// U is added as an ordinary variable here, so adjusting for it is
    /// possible in identification tests.
    fn create_confounded_model() -> CausalModel {
        // X -> Y with unobserved confounder U
        // U -> X, U -> Y
        let mut model = CausalModel::with_name("Confounded");

        model.add_variable("U", VariableType::Continuous).unwrap();
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let u = model.get_variable_id("U").unwrap();
        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        model.add_edge(u, x).unwrap();
        model.add_edge(u, y).unwrap();
        model.add_edge(x, y).unwrap();

        // X := U + 1
        model.add_structural_equation(x, &[u], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() + 1.0)
        })).unwrap();

        // Y := 2X + U
        model.add_structural_equation(y, &[x, u], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() * 2.0 + p[1].as_f64())
        })).unwrap();

        model
    }
|
||||
|
||||
    /// Build the classic front-door structure: U -> X, U -> Y, X -> M -> Y.
    ///
    /// M fully mediates the X→Y effect while U confounds X and Y.
    fn create_frontdoor_model() -> CausalModel {
        // X -> M -> Y with X-Y confounded
        let mut model = CausalModel::with_name("FrontDoor");

        model.add_variable("U", VariableType::Continuous).unwrap(); // Unobserved confounder
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("M", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let u = model.get_variable_id("U").unwrap();
        let x = model.get_variable_id("X").unwrap();
        let m = model.get_variable_id("M").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        model.add_edge(u, x).unwrap();
        model.add_edge(u, y).unwrap();
        model.add_edge(x, m).unwrap();
        model.add_edge(m, y).unwrap();

        // X := U
        model.add_structural_equation(x, &[u], Mechanism::new(|p| {
            p[0].clone()
        })).unwrap();

        // M := X
        model.add_structural_equation(m, &[x], Mechanism::new(|p| {
            p[0].clone()
        })).unwrap();

        // Y := M + U
        model.add_structural_equation(y, &[m, u], Mechanism::new(|p| {
            Value::Continuous(p[0].as_f64() + p[1].as_f64())
        })).unwrap();

        model
    }
|
||||
|
||||
    /// Build the minimal two-variable model X -> Y with no confounding.
    fn create_unconfounded_model() -> CausalModel {
        // Simple X -> Y without confounding
        let mut model = CausalModel::with_name("Unconfounded");

        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();

        let x = model.get_variable_id("X").unwrap();
        let y = model.get_variable_id("Y").unwrap();

        model.add_edge(x, y).unwrap();

        // Y := 2X
        model.add_structural_equation(y, &[x], Mechanism::new(|p| {
            Value::Continuous(2.0 * p[0].as_f64())
        })).unwrap();

        model
    }
|
||||
|
||||
#[test]
|
||||
fn test_unconfounded_identifiable() {
|
||||
let model = create_unconfounded_model();
|
||||
let calc = DoCalculus::new(&model);
|
||||
|
||||
let result = calc.identify(&["X"], &["Y"]).unwrap();
|
||||
|
||||
assert!(result.identifiable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_confounded_with_adjustment() {
|
||||
let model = create_confounded_model();
|
||||
let calc = DoCalculus::new(&model);
|
||||
|
||||
// With U observed, we can adjust for it
|
||||
let result = calc.identify(&["X"], &["Y"]).unwrap();
|
||||
|
||||
// Should be identifiable by adjusting for U
|
||||
assert!(result.identifiable);
|
||||
assert!(result.adjustment_set.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frontdoor_identification() {
|
||||
let model = create_frontdoor_model();
|
||||
let calc = DoCalculus::new(&model);
|
||||
|
||||
let result = calc.identify(&["X"], &["Y"]).unwrap();
|
||||
|
||||
// Should be identifiable via front-door criterion
|
||||
assert!(result.identifiable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule1_application() {
|
||||
let model = create_unconfounded_model();
|
||||
let calc = DoCalculus::new(&model);
|
||||
|
||||
let x = model.get_variable_id("X").unwrap();
|
||||
let y = model.get_variable_id("Y").unwrap();
|
||||
|
||||
// In a simple X -> Y model, Y is not independent of X
|
||||
let can_remove_x = calc.can_apply_rule1(&[y], &[], &[x], &[]);
|
||||
|
||||
assert!(!can_remove_x); // Cannot remove X observation
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule2_application() {
|
||||
let model = create_unconfounded_model();
|
||||
let calc = DoCalculus::new(&model);
|
||||
|
||||
let x = model.get_variable_id("X").unwrap();
|
||||
let y = model.get_variable_id("Y").unwrap();
|
||||
|
||||
// Can we exchange do(X) with observation X?
|
||||
let can_exchange = calc.can_apply_rule2(&[y], &[], &[x], &[]);
|
||||
|
||||
// In simple X -> Y, deleting outgoing from X blocks the path
|
||||
assert!(can_exchange);
|
||||
}
|
||||
|
||||
    #[test]
    fn test_rule_descriptions() {
        // Smoke-check the human-readable names of the three rules.
        // NOTE(review): relies on `Rule::name()` (defined elsewhere)
        // containing these keywords — revisit if the wording changes.
        assert!(Rule::Rule1.name().contains("observation"));
        assert!(Rule::Rule2.name().contains("exchange"));
        assert!(Rule::Rule3.name().contains("deletion"));
    }
|
||||
|
||||
#[test]
|
||||
fn test_identification_result() {
|
||||
let model = create_unconfounded_model();
|
||||
let calc = DoCalculus::new(&model);
|
||||
|
||||
let result = calc.identify(&["X"], &["Y"]).unwrap();
|
||||
|
||||
assert!(result.identifiable);
|
||||
assert!(result.expression.is_some());
|
||||
}
|
||||
}
|
||||
846
vendor/ruvector/examples/prime-radiant/src/causal/graph.rs
vendored
Normal file
846
vendor/ruvector/examples/prime-radiant/src/causal/graph.rs
vendored
Normal file
@@ -0,0 +1,846 @@
|
||||
//! Directed Acyclic Graph (DAG) implementation for causal models
|
||||
//!
|
||||
//! This module provides a validated DAG structure that ensures:
|
||||
//! - No cycles (acyclicity constraint)
|
||||
//! - Efficient topological ordering
|
||||
//! - Parent/child relationship queries
|
||||
//! - D-separation computations
|
||||
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Error types for DAG operations
///
/// Returned by mutating graph operations; messages are rendered via
/// `thiserror`'s `#[error]` attribute.
#[derive(Debug, Clone, Error)]
pub enum DAGValidationError {
    /// Cycle detected in graph (payload: nodes along the offending path)
    #[error("Cycle detected involving nodes: {0:?}")]
    CycleDetected(Vec<u32>),

    /// Node not found
    #[error("Node {0} not found in graph")]
    NodeNotFound(u32),

    /// Edge already exists
    #[error("Edge from {0} to {1} already exists")]
    EdgeExists(u32, u32),

    /// Self-loop detected (edges from a node to itself are never allowed)
    #[error("Self-loop detected at node {0}")]
    SelfLoop(u32),

    /// Invalid operation on empty graph
    #[error("Graph is empty")]
    EmptyGraph,
}
|
||||
|
||||
/// A directed acyclic graph for causal relationships
///
/// Nodes are `u32` ids. Acyclicity is enforced on every `add_edge` call,
/// and forward/reverse adjacency maps are kept in sync.
#[derive(Debug, Clone)]
pub struct DirectedGraph {
    /// Number of nodes
    num_nodes: usize,

    /// Adjacency list: node -> children
    children: HashMap<u32, HashSet<u32>>,

    /// Reverse adjacency: node -> parents (mirror of `children`)
    parents: HashMap<u32, HashSet<u32>>,

    /// Node labels (optional)
    labels: HashMap<u32, String>,

    /// Cached topological order (invalidated on structural changes)
    cached_topo_order: Option<Vec<u32>>,
}
|
||||
|
||||
impl DirectedGraph {
|
||||
/// Create a new empty directed graph
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
num_nodes: 0,
|
||||
children: HashMap::new(),
|
||||
parents: HashMap::new(),
|
||||
labels: HashMap::new(),
|
||||
cached_topo_order: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a graph with pre-allocated capacity
|
||||
pub fn with_capacity(nodes: usize) -> Self {
|
||||
Self {
|
||||
num_nodes: 0,
|
||||
children: HashMap::with_capacity(nodes),
|
||||
parents: HashMap::with_capacity(nodes),
|
||||
labels: HashMap::with_capacity(nodes),
|
||||
cached_topo_order: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a node to the graph
|
||||
pub fn add_node(&mut self, id: u32) -> u32 {
|
||||
if !self.children.contains_key(&id) {
|
||||
self.children.insert(id, HashSet::new());
|
||||
self.parents.insert(id, HashSet::new());
|
||||
self.num_nodes += 1;
|
||||
self.cached_topo_order = None;
|
||||
}
|
||||
id
|
||||
}
|
||||
|
||||
/// Add a node with a label
|
||||
pub fn add_node_with_label(&mut self, id: u32, label: &str) -> u32 {
|
||||
self.add_node(id);
|
||||
self.labels.insert(id, label.to_string());
|
||||
id
|
||||
}
|
||||
|
||||
    /// Add a directed edge from `from` to `to`
    ///
    /// Missing endpoints are created implicitly. The edge is inserted
    /// tentatively and rolled back if it closes a cycle.
    ///
    /// # Errors
    /// `SelfLoop` when `from == to`; `EdgeExists` for a duplicate edge;
    /// `CycleDetected` when the edge would break acyclicity.
    pub fn add_edge(&mut self, from: u32, to: u32) -> Result<(), DAGValidationError> {
        // Check for self-loop
        if from == to {
            return Err(DAGValidationError::SelfLoop(from));
        }

        // Ensure nodes exist (no-ops when already present)
        self.add_node(from);
        self.add_node(to);

        // Check if edge already exists
        if self.children.get(&from).map(|c| c.contains(&to)).unwrap_or(false) {
            return Err(DAGValidationError::EdgeExists(from, to));
        }

        // Temporarily add edge and check for cycles
        self.children.get_mut(&from).unwrap().insert(to);
        self.parents.get_mut(&to).unwrap().insert(from);

        if self.has_cycle() {
            // Remove edge if cycle detected (rollback keeps the graph valid)
            self.children.get_mut(&from).unwrap().remove(&to);
            self.parents.get_mut(&to).unwrap().remove(&from);
            return Err(DAGValidationError::CycleDetected(self.find_cycle_nodes(from, to)));
        }

        // Structure changed: cached topological order is stale.
        self.cached_topo_order = None;
        Ok(())
    }
|
||||
|
||||
/// Remove an edge from the graph
|
||||
pub fn remove_edge(&mut self, from: u32, to: u32) -> Result<(), DAGValidationError> {
|
||||
if !self.children.contains_key(&from) {
|
||||
return Err(DAGValidationError::NodeNotFound(from));
|
||||
}
|
||||
if !self.children.contains_key(&to) {
|
||||
return Err(DAGValidationError::NodeNotFound(to));
|
||||
}
|
||||
|
||||
self.children.get_mut(&from).unwrap().remove(&to);
|
||||
self.parents.get_mut(&to).unwrap().remove(&from);
|
||||
self.cached_topo_order = None;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if the graph has a cycle (using DFS)
|
||||
fn has_cycle(&self) -> bool {
|
||||
let mut visited = HashSet::new();
|
||||
let mut rec_stack = HashSet::new();
|
||||
|
||||
for &node in self.children.keys() {
|
||||
if self.has_cycle_util(node, &mut visited, &mut rec_stack) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn has_cycle_util(
|
||||
&self,
|
||||
node: u32,
|
||||
visited: &mut HashSet<u32>,
|
||||
rec_stack: &mut HashSet<u32>,
|
||||
) -> bool {
|
||||
if rec_stack.contains(&node) {
|
||||
return true;
|
||||
}
|
||||
if visited.contains(&node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
visited.insert(node);
|
||||
rec_stack.insert(node);
|
||||
|
||||
if let Some(children) = self.children.get(&node) {
|
||||
for &child in children {
|
||||
if self.has_cycle_util(child, visited, rec_stack) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rec_stack.remove(&node);
|
||||
false
|
||||
}
|
||||
|
||||
/// Find nodes involved in a potential cycle
///
/// Called when adding the edge `from -> to` would create a cycle: searches
/// for the existing directed path `to -> ... -> from` and returns it, with
/// `to` appended again at the end to close the loop via the offending edge.
/// Returns an empty vector if no such path is found.
fn find_cycle_nodes(&self, from: u32, to: u32) -> Vec<u32> {
    // Find path from `to` back to `from`
    let mut path = Vec::new();
    let mut visited = HashSet::new();

    // Depth-first search that records the current path and unwinds (pops)
    // whenever a branch dead-ends without reaching `target`.
    fn dfs(
        graph: &DirectedGraph,
        current: u32,
        target: u32,
        visited: &mut HashSet<u32>,
        path: &mut Vec<u32>,
    ) -> bool {
        if current == target {
            path.push(current);
            return true;
        }
        if visited.contains(&current) {
            return false;
        }
        visited.insert(current);
        path.push(current);

        if let Some(children) = graph.children.get(&current) {
            for &child in children {
                if dfs(graph, child, target, visited, path) {
                    return true;
                }
            }
        }

        // Dead end: remove `current` from the candidate path.
        path.pop();
        false
    }

    if dfs(self, to, from, &mut visited, &mut path) {
        // Close the cycle by repeating the start node.
        path.push(to);
    }

    path
}
|
||||
|
||||
/// Get children of a node
///
/// Returns `None` if the node is not present in the graph.
pub fn children_of(&self, node: u32) -> Option<&HashSet<u32>> {
    self.children.get(&node)
}
|
||||
|
||||
/// Get parents of a node
///
/// Returns `None` if the node is not present in the graph.
pub fn parents_of(&self, node: u32) -> Option<&HashSet<u32>> {
    self.parents.get(&node)
}
|
||||
|
||||
/// Check if node exists
///
/// Membership is keyed on the `children` adjacency map.
pub fn contains_node(&self, node: u32) -> bool {
    self.children.contains_key(&node)
}
|
||||
|
||||
/// Check if edge exists
|
||||
pub fn contains_edge(&self, from: u32, to: u32) -> bool {
|
||||
self.children.get(&from).map(|c| c.contains(&to)).unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Get number of nodes
///
/// Served from the `num_nodes` counter rather than recounting the maps.
pub fn node_count(&self) -> usize {
    self.num_nodes
}
|
||||
|
||||
/// Get number of edges
|
||||
pub fn edge_count(&self) -> usize {
|
||||
self.children.values().map(|c| c.len()).sum()
|
||||
}
|
||||
|
||||
/// Get all nodes
///
/// Iteration order follows the underlying hash map and is unspecified.
pub fn nodes(&self) -> impl Iterator<Item = u32> + '_ {
    self.children.keys().copied()
}
|
||||
|
||||
/// Get all edges as (from, to) pairs
|
||||
pub fn edges(&self) -> impl Iterator<Item = (u32, u32)> + '_ {
|
||||
self.children.iter().flat_map(|(&from, children)| {
|
||||
children.iter().map(move |&to| (from, to))
|
||||
})
|
||||
}
|
||||
|
||||
/// Get node label
///
/// Returns `None` when the node has no label entry (or does not exist).
pub fn get_label(&self, node: u32) -> Option<&str> {
    self.labels.get(&node).map(|s| s.as_str())
}
|
||||
|
||||
/// Find node by label
|
||||
pub fn find_node_by_label(&self, label: &str) -> Option<u32> {
|
||||
self.labels.iter()
|
||||
.find(|(_, l)| l.as_str() == label)
|
||||
.map(|(&id, _)| id)
|
||||
}
|
||||
|
||||
/// Compute topological ordering using Kahn's algorithm
|
||||
pub fn topological_order(&mut self) -> Result<TopologicalOrder, DAGValidationError> {
|
||||
if self.num_nodes == 0 {
|
||||
return Err(DAGValidationError::EmptyGraph);
|
||||
}
|
||||
|
||||
// Use cached order if available
|
||||
if let Some(ref order) = self.cached_topo_order {
|
||||
return Ok(TopologicalOrder { order: order.clone() });
|
||||
}
|
||||
|
||||
// Compute in-degrees
|
||||
let mut in_degree: HashMap<u32, usize> = HashMap::new();
|
||||
for &node in self.children.keys() {
|
||||
in_degree.insert(node, 0);
|
||||
}
|
||||
for children in self.children.values() {
|
||||
for &child in children {
|
||||
*in_degree.entry(child).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize queue with nodes having in-degree 0
|
||||
let mut queue: VecDeque<u32> = in_degree
|
||||
.iter()
|
||||
.filter(|&(_, °)| deg == 0)
|
||||
.map(|(&node, _)| node)
|
||||
.collect();
|
||||
|
||||
let mut order = Vec::with_capacity(self.num_nodes);
|
||||
|
||||
while let Some(node) = queue.pop_front() {
|
||||
order.push(node);
|
||||
|
||||
if let Some(children) = self.children.get(&node) {
|
||||
for &child in children {
|
||||
if let Some(deg) = in_degree.get_mut(&child) {
|
||||
*deg -= 1;
|
||||
if *deg == 0 {
|
||||
queue.push_back(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if order.len() != self.num_nodes {
|
||||
return Err(DAGValidationError::CycleDetected(
|
||||
in_degree.iter()
|
||||
.filter(|&(_, °)| deg > 0)
|
||||
.map(|(&node, _)| node)
|
||||
.collect()
|
||||
));
|
||||
}
|
||||
|
||||
self.cached_topo_order = Some(order.clone());
|
||||
Ok(TopologicalOrder { order })
|
||||
}
|
||||
|
||||
/// Get ancestors of a node (all nodes that can reach it)
|
||||
pub fn ancestors(&self, node: u32) -> HashSet<u32> {
|
||||
let mut ancestors = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
if let Some(parents) = self.parents.get(&node) {
|
||||
for &parent in parents {
|
||||
queue.push_back(parent);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(current) = queue.pop_front() {
|
||||
if ancestors.insert(current) {
|
||||
if let Some(parents) = self.parents.get(¤t) {
|
||||
for &parent in parents {
|
||||
if !ancestors.contains(&parent) {
|
||||
queue.push_back(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ancestors
|
||||
}
|
||||
|
||||
/// Get descendants of a node (all nodes reachable from it)
|
||||
pub fn descendants(&self, node: u32) -> HashSet<u32> {
|
||||
let mut descendants = HashSet::new();
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
if let Some(children) = self.children.get(&node) {
|
||||
for &child in children {
|
||||
queue.push_back(child);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(current) = queue.pop_front() {
|
||||
if descendants.insert(current) {
|
||||
if let Some(children) = self.children.get(¤t) {
|
||||
for &child in children {
|
||||
if !descendants.contains(&child) {
|
||||
queue.push_back(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
descendants
|
||||
}
|
||||
|
||||
/// Check d-separation between X and Y given conditioning set Z
|
||||
///
|
||||
/// Two sets X and Y are d-separated by Z if all paths between X and Y
|
||||
/// are blocked by Z.
|
||||
pub fn d_separated(
|
||||
&self,
|
||||
x: &HashSet<u32>,
|
||||
y: &HashSet<u32>,
|
||||
z: &HashSet<u32>,
|
||||
) -> bool {
|
||||
// Use Bayes Ball algorithm for d-separation
|
||||
let reachable = self.bayes_ball_reachable(x, z);
|
||||
|
||||
// X and Y are d-separated if no node in Y is reachable
|
||||
reachable.intersection(y).next().is_none()
|
||||
}
|
||||
|
||||
/// Bayes Ball algorithm to find reachable nodes
///
/// Returns the set of nodes reachable from `source` given evidence `evidence`.
///
/// Each node may be visited in two directions ("up" = towards parents,
/// "down" = towards children); a (node, direction) pair is processed at
/// most once, so the traversal terminates in O(V + E) steps.
fn bayes_ball_reachable(&self, source: &HashSet<u32>, evidence: &HashSet<u32>) -> HashSet<u32> {
    // Nodes already processed while travelling up / down, respectively.
    let mut visited_up: HashSet<u32> = HashSet::new();
    let mut visited_down: HashSet<u32> = HashSet::new();
    // Nodes the ball has passed through along an active (unblocked) path.
    let mut reachable: HashSet<u32> = HashSet::new();

    // Queue entries: (node, direction_is_up)
    let mut queue: VecDeque<(u32, bool)> = VecDeque::new();

    // Initialize with source nodes going up (as if we observed them)
    for &node in source {
        queue.push_back((node, true)); // Going up from source
        queue.push_back((node, false)); // Going down from source
    }

    while let Some((node, going_up)) = queue.pop_front() {
        // Skip if already visited in this direction
        if going_up && visited_up.contains(&node) {
            continue;
        }
        if !going_up && visited_down.contains(&node) {
            continue;
        }

        if going_up {
            visited_up.insert(node);
        } else {
            visited_down.insert(node);
        }

        let is_evidence = evidence.contains(&node);

        if going_up && !is_evidence {
            // Ball going up, node not observed: continue to parents and children
            reachable.insert(node);

            if let Some(parents) = self.parents.get(&node) {
                for &parent in parents {
                    queue.push_back((parent, true));
                }
            }
            if let Some(children) = self.children.get(&node) {
                for &child in children {
                    queue.push_back((child, false));
                }
            }
        } else if going_up && is_evidence {
            // Ball going up, node observed: continue to parents only
            // (the observed node itself is not marked reachable here).
            if let Some(parents) = self.parents.get(&node) {
                for &parent in parents {
                    queue.push_back((parent, true));
                }
            }
        } else if !going_up && !is_evidence {
            // Ball going down, node not observed: continue to children only
            reachable.insert(node);

            if let Some(children) = self.children.get(&node) {
                for &child in children {
                    queue.push_back((child, false));
                }
            }
        } else {
            // Ball going down, node observed: bounce back up to parents.
            // This opens colliders ("explaining away"); see the
            // test_d_separation_collider case which exercises this branch.
            reachable.insert(node);

            if let Some(parents) = self.parents.get(&node) {
                for &parent in parents {
                    queue.push_back((parent, true));
                }
            }
        }
    }

    reachable
}
|
||||
|
||||
/// Find all paths between two nodes
|
||||
pub fn find_all_paths(&self, from: u32, to: u32, max_length: usize) -> Vec<Vec<u32>> {
|
||||
let mut all_paths = Vec::new();
|
||||
let mut current_path = vec![from];
|
||||
|
||||
self.find_paths_dfs(from, to, &mut current_path, &mut all_paths, max_length);
|
||||
|
||||
all_paths
|
||||
}
|
||||
|
||||
fn find_paths_dfs(
|
||||
&self,
|
||||
current: u32,
|
||||
target: u32,
|
||||
path: &mut Vec<u32>,
|
||||
all_paths: &mut Vec<Vec<u32>>,
|
||||
max_length: usize,
|
||||
) {
|
||||
if current == target {
|
||||
all_paths.push(path.clone());
|
||||
return;
|
||||
}
|
||||
|
||||
if path.len() >= max_length {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(children) = self.children.get(¤t) {
|
||||
for &child in children {
|
||||
if !path.contains(&child) {
|
||||
path.push(child);
|
||||
self.find_paths_dfs(child, target, path, all_paths, max_length);
|
||||
path.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the skeleton (undirected version) of the graph
|
||||
pub fn skeleton(&self) -> HashSet<(u32, u32)> {
|
||||
let mut skeleton = HashSet::new();
|
||||
|
||||
for (&from, children) in &self.children {
|
||||
for &to in children {
|
||||
let edge = if from < to { (from, to) } else { (to, from) };
|
||||
skeleton.insert(edge);
|
||||
}
|
||||
}
|
||||
|
||||
skeleton
|
||||
}
|
||||
|
||||
/// Find all v-structures (colliders) in the graph
|
||||
///
|
||||
/// A v-structure is a triple (A, B, C) where A -> B <- C and A and C are not adjacent
|
||||
pub fn v_structures(&self) -> Vec<(u32, u32, u32)> {
|
||||
let mut v_structs = Vec::new();
|
||||
let skeleton = self.skeleton();
|
||||
|
||||
for (&node, parents) in &self.parents {
|
||||
if parents.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let parents_vec: Vec<_> = parents.iter().copied().collect();
|
||||
|
||||
for i in 0..parents_vec.len() {
|
||||
for j in (i + 1)..parents_vec.len() {
|
||||
let p1 = parents_vec[i];
|
||||
let p2 = parents_vec[j];
|
||||
|
||||
// Check if parents are not adjacent
|
||||
let edge = if p1 < p2 { (p1, p2) } else { (p2, p1) };
|
||||
if !skeleton.contains(&edge) {
|
||||
v_structs.push((p1, node, p2));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v_structs
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DirectedGraph {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Topological ordering of nodes in a DAG
#[derive(Debug, Clone)]
pub struct TopologicalOrder {
    // Node ids listed so that every edge's source precedes its target.
    order: Vec<u32>,
}
|
||||
|
||||
impl TopologicalOrder {
|
||||
/// Get the ordering as a slice
|
||||
pub fn as_slice(&self) -> &[u32] {
|
||||
&self.order
|
||||
}
|
||||
|
||||
/// Get the position of a node in the ordering
|
||||
pub fn position(&self, node: u32) -> Option<usize> {
|
||||
self.order.iter().position(|&n| n == node)
|
||||
}
|
||||
|
||||
/// Check if node A comes before node B in the ordering
|
||||
pub fn comes_before(&self, a: u32, b: u32) -> bool {
|
||||
match (self.position(a), self.position(b)) {
|
||||
(Some(pos_a), Some(pos_b)) => pos_a < pos_b,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterate over nodes in topological order
|
||||
pub fn iter(&self) -> impl Iterator<Item = &u32> {
|
||||
self.order.iter()
|
||||
}
|
||||
|
||||
/// Get number of nodes
|
||||
pub fn len(&self) -> usize {
|
||||
self.order.len()
|
||||
}
|
||||
|
||||
/// Check if ordering is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.order.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
// Consuming iteration: yields node ids in topological order.
impl IntoIterator for TopologicalOrder {
    type Item = u32;
    type IntoIter = std::vec::IntoIter<u32>;

    fn into_iter(self) -> Self::IntoIter {
        self.order.into_iter()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Basic insertion plus directed-edge asymmetry.
    #[test]
    fn test_add_nodes_and_edges() {
        let mut graph = DirectedGraph::new();
        graph.add_node(0);
        graph.add_node(1);
        graph.add_edge(0, 1).unwrap();

        assert!(graph.contains_node(0));
        assert!(graph.contains_node(1));
        assert!(graph.contains_edge(0, 1));
        assert!(!graph.contains_edge(1, 0));
    }

    // Adding an edge that closes a cycle must be rejected.
    #[test]
    fn test_cycle_detection() {
        let mut graph = DirectedGraph::new();
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(1, 2).unwrap();

        // This should fail - would create cycle
        let result = graph.add_edge(2, 0);
        assert!(matches!(result, Err(DAGValidationError::CycleDetected(_))));
    }

    // A self-loop is rejected with a dedicated error variant.
    #[test]
    fn test_self_loop_detection() {
        let mut graph = DirectedGraph::new();
        let result = graph.add_edge(0, 0);
        assert!(matches!(result, Err(DAGValidationError::SelfLoop(0))));
    }

    // Kahn ordering on a diamond respects every edge.
    #[test]
    fn test_topological_order() {
        let mut graph = DirectedGraph::new();
        // Diamond: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(0, 2).unwrap();
        graph.add_edge(1, 3).unwrap();
        graph.add_edge(2, 3).unwrap();

        let order = graph.topological_order().unwrap();

        assert_eq!(order.len(), 4);
        assert!(order.comes_before(0, 1));
        assert!(order.comes_before(0, 2));
        assert!(order.comes_before(1, 3));
        assert!(order.comes_before(2, 3));
    }

    // Ancestor/descendant closures exclude the query node itself.
    #[test]
    fn test_ancestors_and_descendants() {
        let mut graph = DirectedGraph::new();
        // Chain: 0 -> 1 -> 2 -> 3
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(1, 2).unwrap();
        graph.add_edge(2, 3).unwrap();

        let ancestors = graph.ancestors(3);
        assert!(ancestors.contains(&0));
        assert!(ancestors.contains(&1));
        assert!(ancestors.contains(&2));
        assert!(!ancestors.contains(&3));

        let descendants = graph.descendants(0);
        assert!(descendants.contains(&1));
        assert!(descendants.contains(&2));
        assert!(descendants.contains(&3));
        assert!(!descendants.contains(&0));
    }

    #[test]
    fn test_d_separation_chain() {
        // Chain: X -> Z -> Y
        // X and Y should be d-separated given Z
        let mut graph = DirectedGraph::new();
        graph.add_node_with_label(0, "X");
        graph.add_node_with_label(1, "Z");
        graph.add_node_with_label(2, "Y");
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(1, 2).unwrap();

        let x: HashSet<u32> = [0].into_iter().collect();
        let y: HashSet<u32> = [2].into_iter().collect();
        let z: HashSet<u32> = [1].into_iter().collect();
        let empty: HashSet<u32> = HashSet::new();

        // X and Y are NOT d-separated given empty set
        assert!(!graph.d_separated(&x, &y, &empty));

        // X and Y ARE d-separated given Z
        assert!(graph.d_separated(&x, &y, &z));
    }

    #[test]
    fn test_d_separation_fork() {
        // Fork: X <- Z -> Y
        // X and Y should be d-separated given Z
        let mut graph = DirectedGraph::new();
        graph.add_node_with_label(0, "X");
        graph.add_node_with_label(1, "Z");
        graph.add_node_with_label(2, "Y");
        graph.add_edge(1, 0).unwrap();
        graph.add_edge(1, 2).unwrap();

        let x: HashSet<u32> = [0].into_iter().collect();
        let y: HashSet<u32> = [2].into_iter().collect();
        let z: HashSet<u32> = [1].into_iter().collect();
        let empty: HashSet<u32> = HashSet::new();

        // X and Y are NOT d-separated given empty set
        assert!(!graph.d_separated(&x, &y, &empty));

        // X and Y ARE d-separated given Z
        assert!(graph.d_separated(&x, &y, &z));
    }

    #[test]
    fn test_d_separation_collider() {
        // Collider: X -> Z <- Y
        // X and Y should NOT be d-separated given Z (explaining away)
        let mut graph = DirectedGraph::new();
        graph.add_node_with_label(0, "X");
        graph.add_node_with_label(1, "Z");
        graph.add_node_with_label(2, "Y");
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(2, 1).unwrap();

        let x: HashSet<u32> = [0].into_iter().collect();
        let y: HashSet<u32> = [2].into_iter().collect();
        let z: HashSet<u32> = [1].into_iter().collect();
        let empty: HashSet<u32> = HashSet::new();

        // X and Y ARE d-separated given empty set (collider blocks)
        assert!(graph.d_separated(&x, &y, &empty));

        // X and Y are NOT d-separated given Z (conditioning opens collider)
        assert!(!graph.d_separated(&x, &y, &z));
    }

    // One collider with non-adjacent parents yields exactly one v-structure.
    #[test]
    fn test_v_structures() {
        // Collider: X -> Z <- Y
        let mut graph = DirectedGraph::new();
        graph.add_edge(0, 2).unwrap(); // X -> Z
        graph.add_edge(1, 2).unwrap(); // Y -> Z

        let v_structs = graph.v_structures();

        assert_eq!(v_structs.len(), 1);
        let (a, b, c) = v_structs[0];
        assert_eq!(b, 2); // Z is the collider
        // Parent order within the pair is unspecified (hash-set iteration).
        assert!(a == 0 || a == 1);
        assert!(c == 0 || c == 1);
        assert_ne!(a, c);
    }

    // Label lookup in both directions.
    #[test]
    fn test_labels() {
        let mut graph = DirectedGraph::new();
        graph.add_node_with_label(0, "Age");
        graph.add_node_with_label(1, "Income");
        graph.add_edge(0, 1).unwrap();

        assert_eq!(graph.get_label(0), Some("Age"));
        assert_eq!(graph.get_label(1), Some("Income"));
        assert_eq!(graph.find_node_by_label("Age"), Some(0));
        assert_eq!(graph.find_node_by_label("Unknown"), None);
    }

    // Both branches of the diamond are enumerated as simple paths.
    #[test]
    fn test_find_all_paths() {
        let mut graph = DirectedGraph::new();
        // Diamond: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(0, 2).unwrap();
        graph.add_edge(1, 3).unwrap();
        graph.add_edge(2, 3).unwrap();

        let paths = graph.find_all_paths(0, 3, 10);

        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&vec![0, 1, 3]));
        assert!(paths.contains(&vec![0, 2, 3]));
    }

    // Skeleton stores each undirected edge once, canonically ordered.
    #[test]
    fn test_skeleton() {
        let mut graph = DirectedGraph::new();
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(1, 2).unwrap();
        graph.add_edge(2, 0).ok(); // This will fail due to cycle

        // Add a valid edge instead
        graph.add_edge(0, 2).unwrap();

        let skeleton = graph.skeleton();

        assert_eq!(skeleton.len(), 3);
        assert!(skeleton.contains(&(0, 1)));
        assert!(skeleton.contains(&(0, 2)));
        assert!(skeleton.contains(&(1, 2)));
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = DirectedGraph::new();
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(1, 2).unwrap();

        assert!(graph.contains_edge(0, 1));
        graph.remove_edge(0, 1).unwrap();
        assert!(!graph.contains_edge(0, 1));
    }
}
|
||||
271
vendor/ruvector/examples/prime-radiant/src/causal/mod.rs
vendored
Normal file
271
vendor/ruvector/examples/prime-radiant/src/causal/mod.rs
vendored
Normal file
@@ -0,0 +1,271 @@
|
||||
//! Causal Abstraction Networks for Prime-Radiant
|
||||
//!
|
||||
//! This module implements causal reasoning primitives based on structural causal models
|
||||
//! (SCMs), causal abstraction theory, and do-calculus. Key capabilities:
|
||||
//!
|
||||
//! - **CausalModel**: Directed acyclic graph (DAG) of causal relationships with
|
||||
//! structural equations defining each variable as a function of its parents
|
||||
//! - **CausalAbstraction**: Maps between low-level and high-level causal models,
|
||||
//! preserving interventional semantics
|
||||
//! - **CausalCoherenceChecker**: Validates causal consistency of beliefs and detects
|
||||
//! spurious correlations
|
||||
//! - **Counterfactual Reasoning**: Computes counterfactual queries and causal effects
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The causal module integrates with Prime-Radiant's sheaf-theoretic framework:
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌─────────────────────────────────────────────────────────────────┐
|
||||
//! │ Prime-Radiant Core │
|
||||
//! ├─────────────────────────────────────────────────────────────────┤
|
||||
//! │ SheafGraph ◄──── causal_coherence_energy ────► CausalModel │
|
||||
//! │ │ │ │
|
||||
//! │ ▼ ▼ │
|
||||
//! │ CoherenceEnergy CausalAbstraction │
|
||||
//! │ │ │ │
|
||||
//! │ └───────────► Combined Coherence ◄──────────────┘ │
|
||||
//! └─────────────────────────────────────────────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use prime_radiant::causal::{CausalModel, Intervention, counterfactual};
|
||||
//!
|
||||
//! // Build a causal model
|
||||
//! let mut model = CausalModel::new();
|
||||
//! model.add_variable("X", VariableType::Continuous);
|
||||
//! model.add_variable("Y", VariableType::Continuous);
|
||||
//! model.add_structural_equation("Y", &["X"], |parents| {
|
||||
//! Value::Continuous(2.0 * parents[0].as_f64() + 0.5)
|
||||
//! });
|
||||
//!
|
||||
//! // Perform intervention do(X = 1.0)
|
||||
//! let intervention = Intervention::new("X", Value::Continuous(1.0));
|
||||
//! let result = model.intervene(&intervention);
|
||||
//!
|
||||
//! // Compute counterfactual
|
||||
//! let observation = Observation::new(&[("Y", Value::Continuous(3.0))]);
|
||||
//! let cf = counterfactual(&model, &observation, &intervention);
|
||||
//! ```
|
||||
|
||||
pub mod model;
|
||||
pub mod abstraction;
|
||||
pub mod coherence;
|
||||
pub mod counterfactual;
|
||||
pub mod graph;
|
||||
pub mod do_calculus;
|
||||
|
||||
// Re-exports
|
||||
pub use model::{
|
||||
CausalModel, StructuralEquation, Variable, VariableId, VariableType, Value,
|
||||
Mechanism, CausalModelError, MutilatedModel, Distribution, Observation,
|
||||
IntervenedModel, CausalModelBuilder, Intervention,
|
||||
};
|
||||
pub use abstraction::{
|
||||
CausalAbstraction, AbstractionMap, AbstractionError, ConsistencyResult,
|
||||
};
|
||||
pub use coherence::{
|
||||
CausalCoherenceChecker, CausalConsistency, SpuriousCorrelation, Belief,
|
||||
CausalQuery, CausalAnswer, CoherenceEnergy,
|
||||
};
|
||||
pub use counterfactual::{
|
||||
counterfactual, causal_effect,
|
||||
CounterfactualQuery, AverageTreatmentEffect,
|
||||
};
|
||||
pub use graph::{DirectedGraph, TopologicalOrder, DAGValidationError};
|
||||
pub use do_calculus::{DoCalculus, Rule, Identification, IdentificationError};
|
||||
|
||||
/// Integration with Prime-Radiant's sheaf-theoretic framework
pub mod integration {
    use super::*;

    /// Placeholder for SheafGraph from the main Prime-Radiant module
    pub struct SheafGraph {
        // Node identifiers (variable names), indexed by position.
        pub nodes: Vec<String>,
        // Directed edges as index pairs into `nodes`.
        // NOTE(review): `compute_causal_energy` below reads a pair (i, j) as
        // if the edge ran j -> i — confirm the intended orientation.
        pub edges: Vec<(usize, usize)>,
        // Per-node numeric sections compared by the simplified restriction maps.
        pub sections: Vec<Vec<f64>>,
    }

    /// Compute combined coherence energy from structural and causal consistency
    ///
    /// This function bridges Prime-Radiant's sheaf cohomology with causal structure:
    /// - Sheaf consistency measures local-to-global coherence of beliefs
    /// - Causal consistency measures alignment with causal structure
    ///
    /// The combined energy is minimized when both constraints are satisfied.
    pub fn causal_coherence_energy(
        sheaf_graph: &SheafGraph,
        causal_model: &CausalModel,
    ) -> CoherenceEnergy {
        // Compute structural coherence from sheaf
        let structural_energy = compute_structural_energy(sheaf_graph);

        // Compute causal coherence
        let causal_energy = compute_causal_energy(sheaf_graph, causal_model);

        // Compute intervention consistency
        let intervention_energy = compute_intervention_energy(sheaf_graph, causal_model);

        CoherenceEnergy {
            total: structural_energy + causal_energy + intervention_energy,
            structural_component: structural_energy,
            causal_component: causal_energy,
            intervention_component: intervention_energy,
            // Coherent when the summed energy is numerically zero.
            is_coherent: (structural_energy + causal_energy + intervention_energy) < 1e-6,
        }
    }

    // Sum of squared element-wise differences between the sections at the
    // endpoints of every sheaf edge (simplified restriction-map mismatch);
    // edges with out-of-range indices are silently skipped.
    fn compute_structural_energy(sheaf: &SheafGraph) -> f64 {
        // Measure deviation from local consistency
        let mut energy = 0.0;

        for (i, j) in &sheaf.edges {
            if *i < sheaf.sections.len() && *j < sheaf.sections.len() {
                let section_i = &sheaf.sections[*i];
                let section_j = &sheaf.sections[*j];

                // Compute L2 difference (simplified restriction map)
                let min_len = section_i.len().min(section_j.len());
                for k in 0..min_len {
                    let diff = section_i[k] - section_j[k];
                    energy += diff * diff;
                }
            }
        }

        energy
    }

    // Unit penalty per sheaf edge whose endpoints are out of order with
    // respect to the causal model's topological order. If the model has no
    // valid topological order, the energy is zero.
    fn compute_causal_energy(sheaf: &SheafGraph, model: &CausalModel) -> f64 {
        // Check that sheaf structure respects causal ordering
        let mut energy = 0.0;

        if let Ok(topo_order) = model.topological_order() {
            // Map each model variable to its rank in the topological order.
            let order_map: std::collections::HashMap<_, _> = topo_order
                .iter()
                .enumerate()
                .map(|(i, v)| (v.clone(), i))
                .collect();

            // Penalize edges that violate causal ordering
            for (i, j) in &sheaf.edges {
                if *i < sheaf.nodes.len() && *j < sheaf.nodes.len() {
                    let node_i = &sheaf.nodes[*i];
                    let node_j = &sheaf.nodes[*j];

                    if let (Some(&order_i), Some(&order_j)) =
                        (order_map.get(node_i), order_map.get(node_j))
                    {
                        // Edge from j to i should have order_j < order_i
                        // NOTE(review): this penalizes order_j > order_i, i.e.
                        // it treats the pair (i, j) as an edge j -> i. If sheaf
                        // edges are oriented i -> j, the comparison is
                        // inverted — confirm against the SheafGraph producer.
                        if order_j > order_i {
                            energy += 1.0;
                        }
                    }
                }
            }
        }

        energy
    }

    // Small penalty (0.1) for every causal parent->child pair whose sheaf
    // sections are nearly uncorrelated — i.e. the sheaf data shows no trace
    // of the influence the causal model claims.
    fn compute_intervention_energy(sheaf: &SheafGraph, model: &CausalModel) -> f64 {
        // Verify that interventions propagate correctly through sheaf
        let mut energy = 0.0;

        // For each potential intervention point, check consistency
        for (i, node) in sheaf.nodes.iter().enumerate() {
            if let Some(var_id) = model.get_variable_id(node) {
                if let Some(children) = model.children(&var_id) {
                    for child in children {
                        if let Some(child_name) = model.get_variable_name(&child) {
                            // Find corresponding sheaf node
                            if let Some(j) = sheaf.nodes.iter().position(|n| n == &child_name) {
                                // Check if intervention effect is consistent
                                if i < sheaf.sections.len() && j < sheaf.sections.len() {
                                    let parent_section = &sheaf.sections[i];
                                    let child_section = &sheaf.sections[j];

                                    // Simple check: child should be influenced by parent
                                    if !parent_section.is_empty() && !child_section.is_empty() {
                                        // Correlation check (simplified)
                                        let correlation = compute_correlation(parent_section, child_section);
                                        if correlation.abs() < 0.01 {
                                            energy += 0.1; // Weak causal link penalty
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        energy
    }

    // Pearson correlation over the first min(a.len(), b.len()) entries;
    // returns 0.0 for empty inputs or (near-)zero variance.
    fn compute_correlation(a: &[f64], b: &[f64]) -> f64 {
        let n = a.len().min(b.len());
        if n == 0 {
            return 0.0;
        }

        let mean_a: f64 = a.iter().take(n).sum::<f64>() / n as f64;
        let mean_b: f64 = b.iter().take(n).sum::<f64>() / n as f64;

        let mut cov = 0.0;
        let mut var_a = 0.0;
        let mut var_b = 0.0;

        for i in 0..n {
            let da = a[i] - mean_a;
            let db = b[i] - mean_b;
            cov += da * db;
            var_a += da * da;
            var_b += db * db;
        }

        let denom = (var_a * var_b).sqrt();
        if denom < 1e-10 {
            // Degenerate (constant) section: treat as uncorrelated.
            0.0
        } else {
            cov / denom
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use super::integration::*;

    #[test]
    fn test_module_exports() {
        // Verify all public types are accessible
        let _var_id: VariableId = VariableId(0);
        let _value = Value::Continuous(1.0);
    }

    // Smoke test: a two-variable sheaf/model pair produces non-negative
    // energy components.
    #[test]
    fn test_causal_coherence_energy() {
        // Sheaf with edge X -> Y and proportional sections.
        let sheaf = SheafGraph {
            nodes: vec!["X".to_string(), "Y".to_string()],
            edges: vec![(0, 1)],
            sections: vec![vec![1.0, 2.0], vec![2.0, 4.0]],
        };

        // Matching causal model X -> Y.
        let mut model = CausalModel::new();
        model.add_variable("X", VariableType::Continuous).unwrap();
        model.add_variable("Y", VariableType::Continuous).unwrap();
        let x_id = model.get_variable_id("X").unwrap();
        let y_id = model.get_variable_id("Y").unwrap();
        model.add_edge(x_id, y_id).unwrap();

        let energy = causal_coherence_energy(&sheaf, &model);

        // Energies are sums of squares / unit penalties, hence non-negative.
        assert!(energy.structural_component >= 0.0);
        assert!(energy.causal_component >= 0.0);
    }
}
|
||||
1211
vendor/ruvector/examples/prime-radiant/src/causal/model.rs
vendored
Normal file
1211
vendor/ruvector/examples/prime-radiant/src/causal/model.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user