//! Causal Coherence Checking
//!
//! This module provides tools for verifying that beliefs and data are
//! consistent with a causal model. Key capabilities:
//!
//! - Detecting spurious correlations (associations not explained by causation)
//! - Checking if beliefs satisfy causal constraints
//! - Answering causal queries using do-calculus
//! - Computing coherence energy for integration with Prime-Radiant
|
|
|
|
use std::collections::{HashMap, HashSet};
|
|
use thiserror::Error;
|
|
|
|
use super::model::{CausalModel, CausalModelError, Value, VariableId, VariableType, Mechanism};
|
|
use super::graph::DAGValidationError;
|
|
|
|
/// Error types for coherence operations
|
|
#[derive(Debug, Clone, Error)]
|
|
pub enum CoherenceError {
|
|
/// Model error
|
|
#[error("Model error: {0}")]
|
|
ModelError(#[from] CausalModelError),
|
|
|
|
/// Graph error
|
|
#[error("Graph error: {0}")]
|
|
GraphError(#[from] DAGValidationError),
|
|
|
|
/// Inconsistent belief
|
|
#[error("Inconsistent belief: {0}")]
|
|
InconsistentBelief(String),
|
|
|
|
/// Invalid query
|
|
#[error("Invalid query: {0}")]
|
|
InvalidQuery(String),
|
|
}
|
|
|
|
/// A belief about the relationship between two variables.
#[derive(Debug, Clone)]
pub struct Belief {
    /// Subject variable (the `X` in "X causes Y").
    pub subject: String,
    /// Object variable (the `Y` in "X causes Y").
    pub object: String,
    /// Kind of relationship asserted between subject and object.
    pub belief_type: BeliefType,
    /// Confidence in the belief (0.0 to 1.0).
    pub confidence: f64,
    /// Optional free-text evidence supporting the belief.
    pub evidence: Option<String>,
}

/// Types of causal beliefs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BeliefType {
    /// X causes Y.
    Causes,
    /// X is correlated with Y (may or may not be causal).
    CorrelatedWith,
    /// X is independent of Y.
    IndependentOf,
    /// X is independent of Y given the listed variables.
    ConditionallyIndependent { given: Vec<String> },
    /// X and Y have a common cause.
    CommonCause,
    /// Changing X would change Y (interventional).
    WouldChange,
}

impl Belief {
    /// Shared constructor: full confidence, no evidence.
    fn with_type(x: &str, y: &str, belief_type: BeliefType) -> Self {
        Self {
            subject: x.to_string(),
            object: y.to_string(),
            belief_type,
            confidence: 1.0,
            evidence: None,
        }
    }

    /// Create a causal belief: `x` causes `y`.
    pub fn causes(x: &str, y: &str) -> Self {
        Self::with_type(x, y, BeliefType::Causes)
    }

    /// Create a correlation belief: `x` is correlated with `y`.
    pub fn correlated(x: &str, y: &str) -> Self {
        Self::with_type(x, y, BeliefType::CorrelatedWith)
    }

    /// Create an independence belief: `x` is independent of `y`.
    pub fn independent(x: &str, y: &str) -> Self {
        Self::with_type(x, y, BeliefType::IndependentOf)
    }

    /// Create a conditional independence belief: `x` is independent of `y`
    /// given every variable named in `given`.
    pub fn conditionally_independent(x: &str, y: &str, given: &[&str]) -> Self {
        let given = given.iter().map(|s| s.to_string()).collect();
        Self::with_type(x, y, BeliefType::ConditionallyIndependent { given })
    }

    /// Set the confidence level, forcing it into the documented `0.0..=1.0`
    /// range.
    ///
    /// Out-of-range values are clamped. `NaN` is stored as `0.0`:
    /// `f64::clamp` passes `NaN` through unchanged, which would otherwise
    /// leave a value outside the range in the field.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        self
    }

    /// Attach a description of the evidence supporting the belief.
    pub fn with_evidence(mut self, evidence: &str) -> Self {
        self.evidence = Some(evidence.to_string());
        self
    }
}
|
|
|
|
/// Result of causal consistency checking
|
|
#[derive(Debug, Clone)]
|
|
pub struct CausalConsistency {
|
|
/// Overall consistency score (0.0 to 1.0)
|
|
pub score: f64,
|
|
/// Number of beliefs checked
|
|
pub beliefs_checked: usize,
|
|
/// Number of consistent beliefs
|
|
pub consistent_beliefs: usize,
|
|
/// Number of inconsistent beliefs
|
|
pub inconsistent_beliefs: usize,
|
|
/// Details of inconsistencies
|
|
pub inconsistencies: Vec<Inconsistency>,
|
|
/// Suggested model modifications
|
|
pub suggestions: Vec<String>,
|
|
}
|
|
|
|
impl CausalConsistency {
|
|
/// Create a fully consistent result
|
|
pub fn fully_consistent(beliefs_checked: usize) -> Self {
|
|
Self {
|
|
score: 1.0,
|
|
beliefs_checked,
|
|
consistent_beliefs: beliefs_checked,
|
|
inconsistent_beliefs: 0,
|
|
inconsistencies: vec![],
|
|
suggestions: vec![],
|
|
}
|
|
}
|
|
|
|
/// Check if fully consistent
|
|
pub fn is_consistent(&self) -> bool {
|
|
self.score >= 1.0 - 1e-10
|
|
}
|
|
}
|
|
|
|
/// Details of a causal inconsistency
|
|
#[derive(Debug, Clone)]
|
|
pub struct Inconsistency {
|
|
/// The belief that is inconsistent
|
|
pub belief: Belief,
|
|
/// Why it's inconsistent
|
|
pub reason: String,
|
|
/// Severity (0.0 to 1.0)
|
|
pub severity: f64,
|
|
}
|
|
|
|
/// A correlation between two variables that is attributed to shared causes.
#[derive(Debug, Clone)]
pub struct SpuriousCorrelation {
    /// Name of the first variable.
    pub var_a: String,
    /// Name of the second variable.
    pub var_b: String,
    /// The common cause(s) explaining the correlation.
    pub confounders: Vec<String>,
    /// Strength of the spurious correlation.
    pub strength: f64,
    /// Human-readable explanation.
    pub explanation: String,
}
|
|
|
|
/// A causal query
|
|
#[derive(Debug, Clone)]
|
|
pub struct CausalQuery {
|
|
/// The variable we're asking about
|
|
pub target: String,
|
|
/// Variables we're intervening on
|
|
pub interventions: Vec<(String, Value)>,
|
|
/// Variables we're conditioning on
|
|
pub conditions: Vec<(String, Value)>,
|
|
/// Query type
|
|
pub query_type: QueryType,
|
|
}
|
|
|
|
/// The kinds of causal query that can be posed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryType {
    /// Interventional query: P(Y | do(X=x)).
    Interventional,
    /// Observational query: P(Y | X=x).
    Observational,
    /// Counterfactual query: P(Y_x | X=x').
    Counterfactual,
    /// Conditional interventional query: P(Y | do(X=x), Z=z).
    ConditionalInterventional,
}
|
|
|
|
impl CausalQuery {
|
|
/// Create an interventional query: P(target | do(intervention))
|
|
pub fn interventional(target: &str, intervention_var: &str, intervention_val: Value) -> Self {
|
|
Self {
|
|
target: target.to_string(),
|
|
interventions: vec![(intervention_var.to_string(), intervention_val)],
|
|
conditions: vec![],
|
|
query_type: QueryType::Interventional,
|
|
}
|
|
}
|
|
|
|
/// Create an observational query: P(target | condition)
|
|
pub fn observational(target: &str, condition_var: &str, condition_val: Value) -> Self {
|
|
Self {
|
|
target: target.to_string(),
|
|
interventions: vec![],
|
|
conditions: vec![(condition_var.to_string(), condition_val)],
|
|
query_type: QueryType::Observational,
|
|
}
|
|
}
|
|
|
|
/// Add a condition
|
|
pub fn given(mut self, var: &str, val: Value) -> Self {
|
|
self.conditions.push((var.to_string(), val));
|
|
self
|
|
}
|
|
}
|
|
|
|
/// Answer to a causal query
|
|
#[derive(Debug, Clone)]
|
|
pub struct CausalAnswer {
|
|
/// The query that was answered
|
|
pub query: CausalQuery,
|
|
/// The estimated value/distribution
|
|
pub estimate: Value,
|
|
/// Confidence interval (if applicable)
|
|
pub confidence_interval: Option<(f64, f64)>,
|
|
/// Whether the query is identifiable from observational data
|
|
pub is_identifiable: bool,
|
|
/// Explanation of the answer
|
|
pub explanation: String,
|
|
}
|
|
|
|
/// Combined coherence energy for integration with Prime-Radiant.
#[derive(Debug, Clone)]
pub struct CoherenceEnergy {
    /// Total energy (lower is more coherent).
    pub total: f64,
    /// Structural component (from sheaf consistency).
    pub structural_component: f64,
    /// Causal component (from causal consistency).
    pub causal_component: f64,
    /// Intervention component (from intervention consistency).
    pub intervention_component: f64,
    /// Whether the system is coherent (energy below threshold).
    pub is_coherent: bool,
}

impl CoherenceEnergy {
    /// The fully coherent state: every component and the total are zero.
    pub fn coherent() -> Self {
        Self {
            total: 0.0,
            structural_component: 0.0,
            causal_component: 0.0,
            intervention_component: 0.0,
            is_coherent: true,
        }
    }

    /// Combine the three components into an energy value.
    ///
    /// The total is the plain sum of the components, and the state is marked
    /// coherent when that sum is strictly below `1e-6`.
    pub fn from_components(structural: f64, causal: f64, intervention: f64) -> Self {
        let total = structural + causal + intervention;
        let is_coherent = total < 1e-6;
        Self {
            total,
            structural_component: structural,
            causal_component: causal,
            intervention_component: intervention,
            is_coherent,
        }
    }
}
|
|
|
|
/// Dataset for spurious correlation detection: named columns of `f64`
/// values stored row by row.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Column names.
    pub columns: Vec<String>,
    /// Data rows (each row is a vector of values, one per column).
    pub rows: Vec<Vec<f64>>,
}

impl Dataset {
    /// Create an empty dataset with the given column names.
    pub fn new(columns: Vec<String>) -> Self {
        Self {
            columns,
            rows: Vec::new(),
        }
    }

    /// Append a row.
    ///
    /// A row whose length differs from the number of columns is silently
    /// ignored.
    pub fn add_row(&mut self, row: Vec<f64>) {
        if row.len() == self.columns.len() {
            self.rows.push(row);
        }
    }

    /// Position of the column called `name`, if there is one.
    pub fn column_index(&self, name: &str) -> Option<usize> {
        self.columns.iter().position(|c| c == name)
    }

    /// Copy out the values of the column called `name`, one per row.
    ///
    /// Returns `None` when the column does not exist, or when any row is too
    /// short to contain it. `rows` is a public field, so rows that bypassed
    /// [`Dataset::add_row`] may have the wrong length; this reports them
    /// instead of panicking on an out-of-bounds index.
    pub fn column(&self, name: &str) -> Option<Vec<f64>> {
        let idx = self.column_index(name)?;
        self.rows.iter().map(|row| row.get(idx).copied()).collect()
    }

    /// Pearson correlation between two columns.
    ///
    /// Returns `None` if either column is unavailable or the dataset has no
    /// rows. Returns `Some(0.0)` when either column is constant, since the
    /// correlation is undefined there. Otherwise returns the coefficient,
    /// clamped to `[-1.0, 1.0]` to absorb rounding.
    ///
    /// Whether a column is constant is judged relative to the magnitude of
    /// its own values, so the result does not depend on the units or scale
    /// of the data.
    pub fn correlation(&self, col_a: &str, col_b: &str) -> Option<f64> {
        let a = self.column(col_a)?;
        let b = self.column(col_b)?;

        // Both columns have exactly one entry per row, so only emptiness
        // needs checking.
        if a.is_empty() {
            return None;
        }

        let n = a.len() as f64;
        let mean_a = a.iter().sum::<f64>() / n;
        let mean_b = b.iter().sum::<f64>() / n;

        let (cov, var_a, var_b) = a.iter().zip(&b).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(cov, var_a, var_b), (&x, &y)| {
                let (dx, dy) = (x - mean_a, y - mean_b);
                (cov + dx * dy, var_a + dx * dx, var_b + dy * dy)
            },
        );

        if Self::is_effectively_constant(&a, var_a) || Self::is_effectively_constant(&b, var_b) {
            return Some(0.0);
        }

        // Take the roots separately: the product `var_a * var_b` can
        // underflow or overflow even when each factor is representable.
        let denom = var_a.sqrt() * var_b.sqrt();
        if denom == 0.0 {
            return Some(0.0);
        }

        Some((cov / denom).clamp(-1.0, 1.0))
    }

    /// Whether the summed squared deviations `sum_sq_dev` of `values` are no
    /// larger than what floating-point rounding alone could produce at the
    /// magnitude of the data.
    fn is_effectively_constant(values: &[f64], sum_sq_dev: f64) -> bool {
        let n = values.len() as f64;
        let scale = values.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
        // Per-element deviation attributable to rounding while summing `n`
        // values of size `scale`.
        let noise = f64::EPSILON * n * scale;
        sum_sq_dev <= n * noise * noise
    }
}
|
|
|
|
/// Causal coherence checker
|
|
pub struct CausalCoherenceChecker<'a> {
|
|
/// The causal model
|
|
model: &'a CausalModel,
|
|
/// Correlation threshold for "significant" correlation
|
|
correlation_threshold: f64,
|
|
}
|
|
|
|
impl<'a> CausalCoherenceChecker<'a> {
|
|
/// Create a new checker
|
|
pub fn new(model: &'a CausalModel) -> Self {
|
|
Self {
|
|
model,
|
|
correlation_threshold: 0.1,
|
|
}
|
|
}
|
|
|
|
/// Set correlation threshold
|
|
pub fn with_correlation_threshold(mut self, threshold: f64) -> Self {
|
|
self.correlation_threshold = threshold;
|
|
self
|
|
}
|
|
|
|
/// Check if a set of beliefs is consistent with the causal model
|
|
pub fn check_causal_consistency(&self, beliefs: &[Belief]) -> CausalConsistency {
|
|
let mut consistent_count = 0;
|
|
let mut inconsistencies = Vec::new();
|
|
let mut suggestions = Vec::new();
|
|
|
|
for belief in beliefs {
|
|
match self.check_single_belief(belief) {
|
|
Ok(()) => consistent_count += 1,
|
|
Err(reason) => {
|
|
inconsistencies.push(Inconsistency {
|
|
belief: belief.clone(),
|
|
reason: reason.clone(),
|
|
severity: 1.0 - belief.confidence,
|
|
});
|
|
|
|
// Generate suggestion
|
|
if let Some(suggestion) = self.generate_suggestion(belief, &reason) {
|
|
suggestions.push(suggestion);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let beliefs_checked = beliefs.len();
|
|
let score = if beliefs_checked > 0 {
|
|
consistent_count as f64 / beliefs_checked as f64
|
|
} else {
|
|
1.0
|
|
};
|
|
|
|
CausalConsistency {
|
|
score,
|
|
beliefs_checked,
|
|
consistent_beliefs: consistent_count,
|
|
inconsistent_beliefs: beliefs_checked - consistent_count,
|
|
inconsistencies,
|
|
suggestions,
|
|
}
|
|
}
|
|
|
|
/// Check a single belief against the model
|
|
fn check_single_belief(&self, belief: &Belief) -> Result<(), String> {
|
|
let subject_id = self.model.get_variable_id(&belief.subject)
|
|
.ok_or_else(|| format!("Variable '{}' not in model", belief.subject))?;
|
|
let object_id = self.model.get_variable_id(&belief.object)
|
|
.ok_or_else(|| format!("Variable '{}' not in model", belief.object))?;
|
|
|
|
match &belief.belief_type {
|
|
BeliefType::Causes => {
|
|
// Check if there's a directed path from subject to object
|
|
let descendants = self.model.graph().descendants(subject_id.0);
|
|
if !descendants.contains(&object_id.0) {
|
|
return Err(format!(
|
|
"No causal path from {} to {} in model",
|
|
belief.subject, belief.object
|
|
));
|
|
}
|
|
}
|
|
|
|
BeliefType::IndependentOf => {
|
|
// Check if they're d-separated given empty set
|
|
if !self.model.d_separated(subject_id, object_id, &[]) {
|
|
return Err(format!(
|
|
"{} and {} are not independent according to model",
|
|
belief.subject, belief.object
|
|
));
|
|
}
|
|
}
|
|
|
|
BeliefType::ConditionallyIndependent { given } => {
|
|
let given_ids: Result<Vec<VariableId>, _> = given.iter()
|
|
.map(|name| {
|
|
self.model.get_variable_id(name)
|
|
.ok_or_else(|| format!("Variable '{}' not in model", name))
|
|
})
|
|
.collect();
|
|
let given_ids = given_ids?;
|
|
|
|
if !self.model.d_separated(subject_id, object_id, &given_ids) {
|
|
return Err(format!(
|
|
"{} and {} are not conditionally independent given {:?}",
|
|
belief.subject, belief.object, given
|
|
));
|
|
}
|
|
}
|
|
|
|
BeliefType::CommonCause => {
|
|
// Check if they share a common ancestor
|
|
let ancestors_a = self.model.graph().ancestors(subject_id.0);
|
|
let ancestors_b = self.model.graph().ancestors(object_id.0);
|
|
|
|
let common: HashSet<_> = ancestors_a.intersection(&ancestors_b).collect();
|
|
if common.is_empty() {
|
|
return Err(format!(
|
|
"No common cause found for {} and {}",
|
|
belief.subject, belief.object
|
|
));
|
|
}
|
|
}
|
|
|
|
BeliefType::CorrelatedWith | BeliefType::WouldChange => {
|
|
// These are not directly checkable against model structure alone
|
|
// They would need data or simulation
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Generate a suggestion for fixing an inconsistency
|
|
fn generate_suggestion(&self, belief: &Belief, _reason: &str) -> Option<String> {
|
|
match &belief.belief_type {
|
|
BeliefType::Causes => {
|
|
Some(format!(
|
|
"Consider adding edge {} -> {} to the model, or revising the belief",
|
|
belief.subject, belief.object
|
|
))
|
|
}
|
|
BeliefType::IndependentOf => {
|
|
Some(format!(
|
|
"Consider conditioning on a confounding variable, or revising the model structure"
|
|
))
|
|
}
|
|
BeliefType::ConditionallyIndependent { given } => {
|
|
Some(format!(
|
|
"The conditioning set {:?} may be insufficient; consider additional variables",
|
|
given
|
|
))
|
|
}
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// Detect spurious correlations in data given the causal model
|
|
pub fn detect_spurious_correlations(&self, data: &Dataset) -> Vec<SpuriousCorrelation> {
|
|
let mut spurious = Vec::new();
|
|
|
|
// Check all pairs of variables
|
|
for i in 0..data.columns.len() {
|
|
for j in (i + 1)..data.columns.len() {
|
|
let col_a = &data.columns[i];
|
|
let col_b = &data.columns[j];
|
|
|
|
// Get correlation from data
|
|
let correlation = match data.correlation(col_a, col_b) {
|
|
Some(c) => c,
|
|
None => continue,
|
|
};
|
|
|
|
// If significantly correlated
|
|
if correlation.abs() > self.correlation_threshold {
|
|
// Check if causally linked
|
|
if let (Some(id_a), Some(id_b)) = (
|
|
self.model.get_variable_id(col_a),
|
|
self.model.get_variable_id(col_b),
|
|
) {
|
|
// Check if there's a direct causal path
|
|
let a_causes_b = self.model.graph().descendants(id_a.0).contains(&id_b.0);
|
|
let b_causes_a = self.model.graph().descendants(id_b.0).contains(&id_a.0);
|
|
|
|
if !a_causes_b && !b_causes_a {
|
|
// Correlation without direct causation - find confounders
|
|
let confounders = self.find_confounders(id_a, id_b);
|
|
|
|
if !confounders.is_empty() {
|
|
spurious.push(SpuriousCorrelation {
|
|
var_a: col_a.clone(),
|
|
var_b: col_b.clone(),
|
|
confounders: confounders.clone(),
|
|
strength: correlation.abs(),
|
|
explanation: format!(
|
|
"Correlation (r={:.3}) between {} and {} is explained by common cause(s): {}",
|
|
correlation, col_a, col_b, confounders.join(", ")
|
|
),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
spurious
|
|
}
|
|
|
|
/// Find common causes (confounders) of two variables
|
|
fn find_confounders(&self, a: VariableId, b: VariableId) -> Vec<String> {
|
|
let ancestors_a = self.model.graph().ancestors(a.0);
|
|
let ancestors_b = self.model.graph().ancestors(b.0);
|
|
|
|
let common: Vec<_> = ancestors_a.intersection(&ancestors_b)
|
|
.filter_map(|&id| self.model.get_variable_name(&VariableId(id)))
|
|
.collect();
|
|
|
|
common
|
|
}
|
|
|
|
/// Answer a causal query using do-calculus
|
|
pub fn enforce_do_calculus(&self, query: &CausalQuery) -> Result<CausalAnswer, CoherenceError> {
|
|
// Get target variable
|
|
let target_id = self.model.get_variable_id(&query.target)
|
|
.ok_or_else(|| CoherenceError::InvalidQuery(
|
|
format!("Target variable '{}' not in model", query.target)
|
|
))?;
|
|
|
|
match query.query_type {
|
|
QueryType::Interventional => {
|
|
self.answer_interventional_query(query, target_id)
|
|
}
|
|
QueryType::Observational => {
|
|
self.answer_observational_query(query, target_id)
|
|
}
|
|
QueryType::Counterfactual => {
|
|
self.answer_counterfactual_query(query, target_id)
|
|
}
|
|
QueryType::ConditionalInterventional => {
|
|
self.answer_conditional_interventional_query(query, target_id)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn answer_interventional_query(
|
|
&self,
|
|
query: &CausalQuery,
|
|
target_id: VariableId,
|
|
) -> Result<CausalAnswer, CoherenceError> {
|
|
// Convert intervention specification to Intervention objects
|
|
let interventions: Result<Vec<_>, _> = query.interventions.iter()
|
|
.map(|(var, val)| {
|
|
self.model.get_variable_id(var)
|
|
.map(|id| super::model::Intervention::new(id, val.clone()))
|
|
.ok_or_else(|| CoherenceError::InvalidQuery(
|
|
format!("Intervention variable '{}' not in model", var)
|
|
))
|
|
})
|
|
.collect();
|
|
let interventions = interventions?;
|
|
|
|
// Perform intervention
|
|
let intervened = self.model.intervene_with(&interventions)?;
|
|
|
|
// Simulate to get the target value
|
|
let values = intervened.simulate(&HashMap::new())?;
|
|
|
|
let estimate = values.get(&target_id).cloned().unwrap_or(Value::Missing);
|
|
|
|
// Check identifiability
|
|
let is_identifiable = self.check_identifiability(query);
|
|
|
|
Ok(CausalAnswer {
|
|
query: query.clone(),
|
|
estimate,
|
|
confidence_interval: None,
|
|
is_identifiable,
|
|
explanation: format!(
|
|
"Computed P({} | do({})) by intervention simulation",
|
|
query.target,
|
|
query.interventions.iter()
|
|
.map(|(v, val)| format!("{}={:?}", v, val))
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
),
|
|
})
|
|
}
|
|
|
|
fn answer_observational_query(
|
|
&self,
|
|
query: &CausalQuery,
|
|
target_id: VariableId,
|
|
) -> Result<CausalAnswer, CoherenceError> {
|
|
// For observational queries, we need to condition
|
|
// This requires probabilistic reasoning which we approximate
|
|
|
|
let explanation = format!(
|
|
"Observational query P({} | {}) - requires probabilistic inference",
|
|
query.target,
|
|
query.conditions.iter()
|
|
.map(|(v, val)| format!("{}={:?}", v, val))
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
);
|
|
|
|
Ok(CausalAnswer {
|
|
query: query.clone(),
|
|
estimate: Value::Missing, // Would need actual probabilistic computation
|
|
confidence_interval: None,
|
|
is_identifiable: true, // Observational queries are always identifiable
|
|
explanation,
|
|
})
|
|
}
|
|
|
|
fn answer_counterfactual_query(
|
|
&self,
|
|
query: &CausalQuery,
|
|
_target_id: VariableId,
|
|
) -> Result<CausalAnswer, CoherenceError> {
|
|
// Counterfactual queries require abduction-action-prediction
|
|
let explanation = format!(
|
|
"Counterfactual query for {} - requires three-step process: abduction, action, prediction",
|
|
query.target
|
|
);
|
|
|
|
Ok(CausalAnswer {
|
|
query: query.clone(),
|
|
estimate: Value::Missing,
|
|
confidence_interval: None,
|
|
is_identifiable: false, // Counterfactuals often not identifiable
|
|
explanation,
|
|
})
|
|
}
|
|
|
|
fn answer_conditional_interventional_query(
|
|
&self,
|
|
query: &CausalQuery,
|
|
target_id: VariableId,
|
|
) -> Result<CausalAnswer, CoherenceError> {
|
|
// Combines intervention with conditioning
|
|
let explanation = format!(
|
|
"Conditional interventional query P({} | do({}), {}) - may require adjustment formula",
|
|
query.target,
|
|
query.interventions.iter()
|
|
.map(|(v, val)| format!("{}={:?}", v, val))
|
|
.collect::<Vec<_>>()
|
|
.join(", "),
|
|
query.conditions.iter()
|
|
.map(|(v, val)| format!("{}={:?}", v, val))
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
);
|
|
|
|
Ok(CausalAnswer {
|
|
query: query.clone(),
|
|
estimate: Value::Missing,
|
|
confidence_interval: None,
|
|
is_identifiable: self.check_identifiability(query),
|
|
explanation,
|
|
})
|
|
}
|
|
|
|
/// Check if a causal query is identifiable from observational data
|
|
fn check_identifiability(&self, query: &CausalQuery) -> bool {
|
|
// Simplified identifiability check
|
|
// Full implementation would use do-calculus rules
|
|
|
|
if query.interventions.is_empty() {
|
|
return true; // Observational queries are identifiable
|
|
}
|
|
|
|
// Check if intervention variables have unobserved confounders with target
|
|
for (var, _) in &query.interventions {
|
|
if let (Some(var_id), Some(target_id)) = (
|
|
self.model.get_variable_id(var),
|
|
self.model.get_variable_id(&query.target),
|
|
) {
|
|
// If there's a backdoor path that can't be blocked, not identifiable
|
|
// This is a simplified check
|
|
let var_ancestors = self.model.graph().ancestors(var_id.0);
|
|
let target_ancestors = self.model.graph().ancestors(target_id.0);
|
|
|
|
// If they share unobserved common ancestors, might not be identifiable
|
|
let common = var_ancestors.intersection(&target_ancestors).count();
|
|
if common > 0 && !self.has_valid_adjustment_set(var_id, target_id) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
true
|
|
}
|
|
|
|
/// Check if there's a valid adjustment set for identifying causal effect
|
|
fn has_valid_adjustment_set(&self, treatment: VariableId, outcome: VariableId) -> bool {
|
|
// Check backdoor criterion
|
|
// A set Z satisfies backdoor criterion if:
|
|
// 1. No node in Z is a descendant of X
|
|
// 2. Z blocks every path from X to Y that contains an arrow into X
|
|
|
|
let descendants = self.model.graph().descendants(treatment.0);
|
|
|
|
// Try the set of all non-descendants as potential adjustment set
|
|
let all_vars: Vec<_> = self.model.variables()
|
|
.filter(|v| v.id != treatment && v.id != outcome)
|
|
.filter(|v| !descendants.contains(&v.id.0))
|
|
.map(|v| v.id)
|
|
.collect();
|
|
|
|
// Check if conditioning on all non-descendants blocks backdoor paths
|
|
self.model.d_separated(treatment, outcome, &all_vars)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::model::{CausalModelBuilder, VariableType};

    /// Shared fixture: `Age -> Education -> Income` and `Age -> Health`,
    /// with one structural equation per non-root variable.
    fn create_test_model() -> CausalModel {
        let mut model = CausalModel::with_name("Test");

        for name in &["Age", "Education", "Income", "Health"] {
            model.add_variable(*name, VariableType::Continuous).unwrap();
        }

        let age = model.get_variable_id("Age").unwrap();
        let edu = model.get_variable_id("Education").unwrap();
        let income = model.get_variable_id("Income").unwrap();
        let health = model.get_variable_id("Health").unwrap();

        // Edges: Age -> Education, Age -> Health, Education -> Income.
        model.add_edge(age, edu).unwrap();
        model.add_edge(age, health).unwrap();
        model.add_edge(edu, income).unwrap();

        // Structural equations.
        let edu_from_age = Mechanism::new(|p| Value::Continuous(12.0 + p[0].as_f64() * 0.1));
        model.add_structural_equation(edu, &[age], edu_from_age).unwrap();

        let income_from_edu =
            Mechanism::new(|p| Value::Continuous(30000.0 + p[0].as_f64() * 5000.0));
        model
            .add_structural_equation(income, &[edu], income_from_edu)
            .unwrap();

        let health_from_age = Mechanism::new(|p| Value::Continuous(100.0 - p[0].as_f64() * 0.5));
        model
            .add_structural_equation(health, &[age], health_from_age)
            .unwrap();

        model
    }

    /// Builder methods populate the expected fields.
    #[test]
    fn test_belief_creation() {
        let belief = Belief::causes("Age", "Education").with_confidence(0.9);

        assert_eq!(belief.subject, "Age");
        assert_eq!(belief.object, "Education");
        assert_eq!(belief.confidence, 0.9);
    }

    /// Beliefs that follow the model's edges are all consistent.
    #[test]
    fn test_causal_consistency() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        let beliefs = [
            Belief::causes("Age", "Education"),
            Belief::causes("Education", "Income"),
        ];
        let result = checker.check_causal_consistency(&beliefs);

        assert!(result.is_consistent());
        assert_eq!(result.consistent_beliefs, 2);
    }

    /// A causal claim against the edge direction is rejected.
    #[test]
    fn test_inconsistent_belief() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        // Wrong direction: the model has no path from Income to Age.
        let beliefs = [Belief::causes("Income", "Age")];
        let result = checker.check_causal_consistency(&beliefs);

        assert!(!result.is_consistent());
        assert_eq!(result.inconsistent_beliefs, 1);
    }

    /// Education and Health should be independent given Age.
    #[test]
    fn test_conditional_independence() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        let beliefs = [Belief::conditionally_independent(
            "Education",
            "Health",
            &["Age"],
        )];
        let result = checker.check_causal_consistency(&beliefs);

        assert!(result.is_consistent());
    }

    /// The Education-Health correlation is reported as spurious via Age.
    #[test]
    fn test_spurious_correlation_detection() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model).with_correlation_threshold(0.1);

        let columns = vec![
            "Age".to_string(),
            "Education".to_string(),
            "Health".to_string(),
        ];
        let mut data = Dataset::new(columns);

        // Both Education and Health are driven by Age, plus a small wobble.
        for i in 0..100 {
            let step = i as f64;
            let age = 20.0 + step * 0.5;
            let edu = 12.0 + age * 0.1 + (step * 0.1).sin();
            let health = 100.0 - age * 0.5 + (step * 0.2).cos();
            data.add_row(vec![age, edu, health]);
        }

        let spurious = checker.detect_spurious_correlations(&data);

        let is_edu_health = |s: &&SpuriousCorrelation| {
            (s.var_a == "Education" && s.var_b == "Health")
                || (s.var_a == "Health" && s.var_b == "Education")
        };
        let edu_health = spurious.iter().find(is_edu_health);

        assert!(edu_health.is_some());
        if let Some(s) = edu_health {
            assert!(s.confounders.contains(&"Age".to_string()));
        }
    }

    /// An interventional query is answered and reported as identifiable.
    #[test]
    fn test_interventional_query() {
        let model = create_test_model();
        let checker = CausalCoherenceChecker::new(&model);

        let query = CausalQuery::interventional("Income", "Education", Value::Continuous(16.0));
        let answer = checker.enforce_do_calculus(&query).unwrap();

        assert!(answer.is_identifiable);
        assert!(matches!(
            answer.query.query_type,
            QueryType::Interventional
        ));
    }

    /// Components sum into the total; non-zero energy is not coherent.
    #[test]
    fn test_coherence_energy() {
        let energy = CoherenceEnergy::from_components(0.1, 0.2, 0.05);
        assert!((energy.total - 0.35).abs() < 1e-10);
        assert!(!energy.is_coherent);

        let coherent = CoherenceEnergy::coherent();
        assert!(coherent.is_coherent);
    }

    /// Perfectly (anti-)correlated columns give +1 / -1.
    #[test]
    fn test_dataset_correlation() {
        // Perfect positive correlation.
        let mut rising = Dataset::new(vec!["X".to_string(), "Y".to_string()]);
        for i in 0..10 {
            rising.add_row(vec![i as f64, i as f64]);
        }
        let corr = rising.correlation("X", "Y").unwrap();
        assert!((corr - 1.0).abs() < 1e-10);

        // Perfect negative correlation.
        let mut falling = Dataset::new(vec!["A".to_string(), "B".to_string()]);
        for i in 0..10 {
            falling.add_row(vec![i as f64, (10 - i) as f64]);
        }
        let corr2 = falling.correlation("A", "B").unwrap();
        assert!((corr2 + 1.0).abs() < 1e-10);
    }

    /// `given` appends a condition without touching the interventions.
    #[test]
    fn test_causal_query_builder() {
        let query = CausalQuery::interventional("Y", "X", Value::Continuous(1.0))
            .given("Z", Value::Continuous(2.0));

        assert_eq!(query.target, "Y");
        assert_eq!(query.interventions.len(), 1);
        assert_eq!(query.conditions.len(), 1);
    }
}
|