wifi-densepose/examples/prime-radiant/docs/adr/ADR-005-causal-abstraction.md


# ADR-005: Causal Abstraction for Mechanistic Interpretability
**Status**: Accepted
**Date**: 2024-12-15
**Authors**: RuVector Team
**Supersedes**: None

---
## Context
Understanding *why* neural networks produce their outputs requires more than correlation analysis. We need:
1. **Causal mechanisms**: Which components actually cause specific behaviors
2. **Interventional reasoning**: What happens when we modify internal states
3. **Abstraction levels**: How low-level computations relate to high-level concepts
4. **Alignment verification**: Whether learned mechanisms match intended behavior

Traditional interpretability approaches provide:
- Attention visualization (correlational, not causal)
- Gradient-based attribution (local approximations)
- Probing classifiers (detect presence, not causation)

These fail to distinguish "correlates with output" from "causes output."
### Why Causal Abstraction?
Causal abstraction theory (Geiger et al., 2021) provides a rigorous framework for:
1. **Defining interpretations**: Mapping neural computations to high-level concepts
2. **Testing interpretations**: Using interventions to verify causal structure
3. **Measuring alignment**: Quantifying how well neural mechanisms match intended algorithms
4. **Localizing circuits**: Finding minimal subnetworks that implement behaviors
---
## Decision
We implement **causal abstraction** as the foundation for mechanistic interpretability in Prime-Radiant.
### Core Concepts
#### 1. Causal Models
```rust
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

/// A causal model with variables and structural equations
#[derive(Clone)]
pub struct CausalModel {
    /// Variable nodes
    variables: HashMap<VariableId, Variable>,
    /// Directed edges (cause -> effect)
    edges: HashSet<(VariableId, VariableId)>,
    /// Structural equations: V = f(Pa(V), noise)
    equations: HashMap<VariableId, StructuralEquation>,
    /// Exogenous noise distributions
    noise: HashMap<VariableId, NoiseDistribution>,
}

/// A variable in the causal model
#[derive(Clone)]
pub struct Variable {
    pub id: VariableId,
    pub name: String,
    pub domain: VariableDomain,
    pub level: AbstractionLevel,
}

/// Structural equation defining a variable's value
#[derive(Clone)]
pub enum StructuralEquation {
    /// f(inputs) -> output
    /// (`Arc` rather than `Box` so the model can be cloned in `intervene`)
    Function(Arc<dyn Fn(&[Value]) -> Value>),
    /// Neural network component
    Neural(NeuralComponent),
    /// Identity (exogenous variable)
    Exogenous,
}
```
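To make the structural-equation idea concrete, here is a minimal sketch of evaluating a tiny model in topological order. It is not the Prime-Radiant implementation: `Equation`, `example_model`, and `evaluate` are hypothetical stand-ins, values are plain `f64`, and noise terms are omitted.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified stand-in for `StructuralEquation`.
enum Equation {
    /// Value is read from the inputs
    Exogenous,
    /// Value is computed from the named parents
    Function(Vec<&'static str>, fn(&[f64]) -> f64),
}

fn add(args: &[f64]) -> f64 {
    args.iter().sum()
}

fn double(args: &[f64]) -> f64 {
    2.0 * args[0]
}

/// Toy model: x and y are exogenous; s = x + y; out = 2 * s.
/// Variables are listed in topological order.
fn example_model() -> Vec<(&'static str, Equation)> {
    vec![
        ("x", Equation::Exogenous),
        ("y", Equation::Exogenous),
        ("s", Equation::Function(vec!["x", "y"], add)),
        ("out", Equation::Function(vec!["s"], double)),
    ]
}

/// Evaluate every variable from the exogenous inputs.
/// Panics if an exogenous variable is missing from `inputs`.
fn evaluate(
    order: &[(&'static str, Equation)],
    inputs: &HashMap<&'static str, f64>,
) -> HashMap<&'static str, f64> {
    let mut values: HashMap<&'static str, f64> = HashMap::new();
    for (name, eq) in order {
        let v = match eq {
            Equation::Exogenous => inputs[name],
            Equation::Function(parents, f) => {
                // Parents were evaluated earlier because of the topological order
                let args: Vec<f64> = parents.iter().map(|p| values[p]).collect();
                f(&args)
            }
        };
        values.insert(*name, v);
    }
    values
}
```

Calling `evaluate(&example_model(), &inputs)` with `x = 1` and `y = 2` yields `s = 3` and `out = 6`. Each entry's parent list plays the role of the incoming `edges` in `CausalModel`.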
#### 2. Interventions
```rust
use std::sync::Arc;

/// An intervention on a causal model
pub enum Intervention {
    /// Set variable to constant value: do(X = x)
    Hard(VariableId, Value),
    /// Modify value by function: do(X = f(X))
    /// (`Arc` so the closure can be shared with the modified model)
    Soft(VariableId, Arc<dyn Fn(Value) -> Value>),
    /// Interchange values between runs
    Interchange(VariableId, SourceId),
    /// Activation patching
    Patch(VariableId, Vec<f32>),
}

impl CausalModel {
    /// Apply an intervention and return the modified model.
    /// Effects are observed by running the returned model.
    pub fn intervene(&self, intervention: &Intervention) -> CausalModel {
        // Requires `CausalModel: Clone`
        let mut modified = self.clone();
        match intervention {
            Intervention::Hard(var, value) => {
                // Graph surgery: remove all incoming edges
                modified.edges.retain(|(_, target)| target != var);
                // Set constant equation
                modified
                    .equations
                    .insert(*var, StructuralEquation::constant(value.clone()));
            }
            Intervention::Soft(var, f) => {
                // Compose with the existing equation; a variable that has
                // no equation is left unchanged instead of panicking
                if let Some(old_eq) = modified.equations.remove(var) {
                    modified
                        .equations
                        .insert(*var, old_eq.compose(Arc::clone(f)));
                }
            }
            // ...
        }
        modified
    }
}
```
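The hard intervention above is graph surgery: the target's equation becomes a constant and its incoming edges disappear. The following sketch shows that on a three-variable chain. `Model`, `Equation`, `chain`, `do_hard`, and `parents` are hypothetical names chosen for illustration, with `f64` values standing in for `Value`.

```rust
use std::collections::HashMap;

#[derive(Clone)]
enum Equation {
    Exogenous,
    Constant(f64),
    Function(Vec<&'static str>, fn(&[f64]) -> f64),
}

#[derive(Clone)]
struct Model {
    /// Variables in topological order with their equations
    nodes: Vec<(&'static str, Equation)>,
}

fn inc(a: &[f64]) -> f64 {
    a[0] + 1.0
}

fn double(a: &[f64]) -> f64 {
    2.0 * a[0]
}

/// Toy chain: x -> s -> out, with s = x + 1 and out = 2 * s.
fn chain() -> Model {
    Model {
        nodes: vec![
            ("x", Equation::Exogenous),
            ("s", Equation::Function(vec!["x"], inc)),
            ("out", Equation::Function(vec!["s"], double)),
        ],
    }
}

impl Model {
    /// Parents of `var` (its incoming edges); empty for unknown variables.
    fn parents(&self, var: &str) -> Vec<&'static str> {
        self.nodes
            .iter()
            .find(|(name, _)| *name == var)
            .map(|(_, eq)| match eq {
                Equation::Function(ps, _) => ps.clone(),
                _ => Vec::new(),
            })
            .unwrap_or_default()
    }

    /// do(var = value): replacing the equation with a constant also
    /// drops the parent list, which cuts every incoming edge.
    fn do_hard(&self, var: &str, value: f64) -> Model {
        let mut modified = self.clone();
        for (name, eq) in modified.nodes.iter_mut() {
            if *name == var {
                *eq = Equation::Constant(value);
            }
        }
        modified
    }

    /// Evaluate all variables; panics if an exogenous input is missing.
    fn run(&self, inputs: &HashMap<&'static str, f64>) -> HashMap<&'static str, f64> {
        let mut values: HashMap<&'static str, f64> = HashMap::new();
        for (name, eq) in &self.nodes {
            let v = match eq {
                Equation::Exogenous => inputs[name],
                Equation::Constant(c) => *c,
                Equation::Function(ps, f) => {
                    let args: Vec<f64> = ps.iter().map(|p| values[p]).collect();
                    f(&args)
                }
            };
            values.insert(*name, v);
        }
        values
    }
}
```

With `x = 1`, the unmodified chain gives `out = 4`. After `do_hard("s", 10.0)` the downstream value becomes `out = 20`, while the upstream `x` stays at 1 and `s` no longer has parents. The original model is untouched, mirroring the clone-then-modify shape of `intervene`.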
#### 3. Causal Abstraction
```rust
/// A causal abstraction between two models
pub struct CausalAbstraction {
    /// Low-level (concrete) model
    low: CausalModel,
    /// High-level (abstract) model
    high: CausalModel,
    /// Variable mapping: low -> high
    tau: HashMap<VariableId, VariableId>,
    /// Intervention mapping: low-level intervention -> high-level intervention
    intervention_map: Box<dyn Fn(&Intervention) -> Intervention>,
}

impl CausalAbstraction {
    /// Check if abstraction is valid (interventional consistency).
    /// `test_interventions` are low-level interventions.
    pub fn is_valid(&self, test_interventions: &[Intervention]) -> bool {
        for intervention in test_interventions {
            // Map intervention to high level
            let high_intervention = (self.intervention_map)(intervention);
            // Intervene on both models
            let low_result = self.low.intervene(intervention);
            let high_result = self.high.intervene(&high_intervention);
            // Check outputs match (up to tau)
            let low_output = low_result.output();
            let high_output = high_result.output();
            if !self.outputs_match(&low_output, &high_output) {
                return false;
            }
        }
        true
    }

    /// Compute interchange intervention accuracy (IIA) for the high-level
    /// variable `target_var`: the fraction of (base, source) pairs on which
    /// both models agree after the interchange.
    pub fn iia(
        &self,
        base_inputs: &[Input],
        source_inputs: &[Input],
        target_var: VariableId,
    ) -> f64 {
        let total = base_inputs.len() * source_inputs.len();
        if total == 0 {
            return 0.0; // no pairs to evaluate; avoids 0/0 = NaN
        }
        // `tau` maps low -> high, so the low-level variable aligned with
        // `target_var` is found by inverting it. If several low-level
        // variables map to `target_var`, an arbitrary one is used.
        let low_var = *self
            .tau
            .iter()
            .find(|(_, high)| **high == target_var)
            .map(|(low, _)| low)
            .expect("target_var has no low-level counterpart in tau");
        let mut correct = 0usize;
        for base in base_inputs {
            for source in source_inputs {
                // High level: run the source, then rerun the base with
                // `target_var` set to the value it took on the source run
                let high_source = self.high.run(source);
                let high_interchanged = self
                    .high
                    .intervene(&Intervention::Interchange(target_var, high_source.id))
                    .run(base);
                // Low level: the corresponding interchange on the aligned variable
                let low_source = self.low.run(source);
                let low_interchanged = self
                    .low
                    .intervene(&Intervention::Interchange(low_var, low_source.id))
                    .run(base);
                // Check if behaviors match
                if self.outputs_match(&low_interchanged, &high_interchanged) {
                    correct += 1;
                }
            }
        }
        correct as f64 / total as f64
    }
}
```
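As a worked instance of interchange intervention accuracy, the sketch below compares a high-level algorithm (`S = a + b; out = S + c`) with a toy two-unit "network" whose hidden units compute `a + b` and `c`. The functions `high`, `low`, and `iia` are illustrative assumptions, not the API above; the point is that a correct alignment scores 1.0 while a wrong one scores lower.

```rust
type Input = [i32; 3];

/// High-level algorithm: S = a + b; out = S + c.
/// `s_override` replaces S, which is how the interchange is applied.
fn high(x: Input, s_override: Option<i32>) -> (i32, i32) {
    let s = s_override.unwrap_or(x[0] + x[1]);
    (s, s + x[2])
}

/// Low-level toy network: hidden units h = [a + b, c]; out = h[0] + h[1].
/// `patch` overwrites one hidden unit with a given value.
fn low(x: Input, patch: Option<(usize, i32)>) -> ([i32; 2], i32) {
    let mut h = [x[0] + x[1], x[2]];
    if let Some((idx, v)) = patch {
        h[idx] = v;
    }
    (h, h[0] + h[1])
}

/// IIA for the hypothesis "high-level S is implemented by hidden unit `unit`".
fn iia(bases: &[Input], sources: &[Input], unit: usize) -> f64 {
    let total = bases.len() * sources.len();
    if total == 0 {
        return 0.0;
    }
    let mut correct = 0usize;
    for &base in bases {
        for &source in sources {
            // High level: take S from the source run, rerun the base with it
            let (s_src, _) = high(source, None);
            let (_, high_out) = high(base, Some(s_src));
            // Low level: take the aligned unit from the source run, patch the base
            let (h_src, _) = low(source, None);
            let (_, low_out) = low(base, Some((unit, h_src[unit])));
            if high_out == low_out {
                correct += 1;
            }
        }
    }
    correct as f64 / total as f64
}
```

With bases `[[1, 2, 3], [0, 0, 1]]` and sources `[[4, 1, 0], [1, 2, 3]]`, aligning `S` with unit 0 gives an IIA of 1.0, whereas aligning it with unit 1 gives 0.25: only the pair where base and source are identical agrees, which is why IIA is usually read against a diverse set of input pairs.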
### Activation Patching
```rust
/// Activation patching for neural network interpretability
pub struct ActivationPatcher {
    /// Target layer/component
    target: NeuralComponent,
    /// Patch source
    source: PatchSource,
}

pub enum PatchSource {
    /// From another input's activation
    OtherInput(InputId),
    /// Fixed vector
    Fixed(Vec<f32>),
    /// Noise ablation
    Noise(NoiseDistribution),
    /// Mean ablation
    Mean,
    /// Zero ablation
    Zero,
}

impl ActivationPatcher {
    /// Measure causal effect of patching
    pub fn causal_effect(
        &self,
        model: &NeuralNetwork,
        base_input: &Input,
        metric: &Metric,
    ) -> f64 {
        // Run without patching
        let base_output = model.forward(base_input);
        let base_metric = metric.compute(&base_output);
        // Run with patching
        let patched_output = model.forward_with_patch(base_input, self);
        let patched_metric = metric.compute(&patched_output);
        // Causal effect is the difference
        patched_metric - base_metric
    }
}
```
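A small numeric sketch of the same measurement, assuming a hypothetical two-unit ReLU layer (`TinyNet`) in place of `NeuralNetwork` and the network output itself as the metric. It shows a unit that is active yet has zero causal effect on the output, the distinction that purely correlational methods miss.

```rust
/// Hypothetical toy network: h = relu(W x), y = v . h
struct TinyNet {
    w: [[f64; 2]; 2],
    v: [f64; 2],
}

/// Simplified patch sources (a subset of the ADR's `PatchSource`)
enum Patch {
    Zero,
    Fixed(f64),
}

impl TinyNet {
    /// Forward pass; `patch` optionally overwrites one hidden unit.
    fn forward(&self, x: [f64; 2], patch: Option<(usize, &Patch)>) -> f64 {
        let mut h = [0.0; 2];
        for i in 0..2 {
            h[i] = (self.w[i][0] * x[0] + self.w[i][1] * x[1]).max(0.0);
        }
        if let Some((unit, p)) = patch {
            h[unit] = match p {
                Patch::Zero => 0.0,
                Patch::Fixed(val) => *val,
            };
        }
        self.v[0] * h[0] + self.v[1] * h[1]
    }
}

/// Patched output minus clean output, as in `causal_effect` above.
fn causal_effect(net: &TinyNet, x: [f64; 2], unit: usize, patch: &Patch) -> f64 {
    net.forward(x, Some((unit, patch))) - net.forward(x, None)
}

/// Identity weights; only hidden unit 0 feeds the output.
fn example_net() -> TinyNet {
    TinyNet {
        w: [[1.0, 0.0], [0.0, 1.0]],
        v: [2.0, 0.0],
    }
}
```

For input `[3.0, 5.0]`, zero-ablating unit 0 changes the output by -6, while zero-ablating unit 1 changes it by 0 even though unit 1 has the larger activation. Patching unit 0 with `Fixed(1.0)` gives an effect of -4.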
### Circuit Discovery
```rust
/// Discover minimal circuits implementing a behavior
pub struct CircuitDiscovery {
    /// Target behavior to explain
    behavior: Behavior,
    /// Candidate components
    components: Vec<NeuralComponent>,
    /// Discovered circuits
    circuits: Vec<Circuit>,
}

pub struct Circuit {
    /// Components in the circuit
    components: Vec<NeuralComponent>,
    /// Edges (data flow)
    edges: Vec<(NeuralComponent, NeuralComponent)>,
    /// Faithfulness score (how well circuit explains behavior)
    faithfulness: f64,
    /// Completeness score (how much of behavior is captured)
    completeness: f64,
}

impl CircuitDiscovery {
    /// Use activation patching to find important components
    pub fn find_circuit(&mut self, model: &NeuralNetwork, inputs: &[Input]) -> Circuit {
        let mut important = Vec::new();
        // Test each component
        for component in &self.components {
            let patcher = ActivationPatcher {
                target: component.clone(),
                source: PatchSource::Zero,
            };
            let avg_effect: f64 = inputs
                .iter()
                .map(|input| patcher.causal_effect(model, input, &self.behavior.metric))
                .sum::<f64>()
                / inputs.len() as f64;
            if avg_effect.abs() > IMPORTANCE_THRESHOLD {
                important.push((component.clone(), avg_effect));
            }
        }
        // Build circuit from important components
        self.build_circuit(important)
    }
}
```
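The selection step in `find_circuit` can be sketched on precomputed ablation effects. The component names, effect values, and the threshold of 0.1 below are made up for illustration. The sketch also compares the signed-mean criterion used above with a mean-absolute-effect variant, since effects of opposite sign on different inputs cancel under a signed mean; the variant is an alternative to consider, not something the ADR prescribes.

```rust
/// Hypothetical threshold; the ADR does not specify a value.
const IMPORTANCE_THRESHOLD: f64 = 0.1;

fn mean(xs: &[f64]) -> f64 {
    if xs.is_empty() {
        0.0
    } else {
        xs.iter().sum::<f64>() / xs.len() as f64
    }
}

/// Select components whose score exceeds the threshold.
/// `use_abs = false`: |mean(effect)|, the criterion in `find_circuit`.
/// `use_abs = true`: mean(|effect|), which does not cancel across inputs.
fn select(effects: &[(&'static str, Vec<f64>)], use_abs: bool) -> Vec<&'static str> {
    effects
        .iter()
        .filter(|(_, e)| {
            let score = if use_abs {
                let abs: Vec<f64> = e.iter().map(|x| x.abs()).collect();
                mean(&abs)
            } else {
                mean(e).abs()
            };
            score > IMPORTANCE_THRESHOLD
        })
        .map(|(name, _)| *name)
        .collect()
}

/// Made-up per-input ablation effects for three components.
fn example_effects() -> Vec<(&'static str, Vec<f64>)> {
    vec![
        ("head_0", vec![-0.8, -0.6]),  // consistently harmful to ablate
        ("head_1", vec![0.5, -0.5]),   // large effects that cancel
        ("mlp_0", vec![0.01, -0.02]),  // negligible either way
    ]
}
```

On this data the signed-mean criterion keeps only `head_0`, while the mean-absolute criterion also keeps `head_1`. Which behavior is preferable depends on whether sign-inconsistent components should count as part of the circuit.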
---
## Consequences
### Positive
1. **Rigorous causality**: Distinguishes correlation from causation
2. **Multi-level analysis**: Connects low-level activations to high-level concepts
3. **Testable interpretations**: Interventions provide empirical verification
4. **Circuit localization**: Identifies minimal subnetworks for behaviors
5. **Alignment checking**: Verifies mechanisms match specifications
### Negative
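1. **Combinatorial explosion**: The number of possible interventions grows exponentially with the number of variables, so exhaustive testing is infeasible
2. **Approximation required**: Full causal analysis is computationally intractable, so conclusions rest on a sampled subset of interventions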
3. **Abstraction design**: Choosing the right high-level model requires insight
4. **Noise sensitivity**: Small variations can affect intervention outcomes
### Mitigations
1. **Importance sampling**: Focus on high-impact interventions
2. **Hierarchical search**: Use coarse-to-fine circuit discovery
3. **Learned abstractions**: Train models to find good variable mappings
4. **Robust statistics**: Use multiple samples and statistical tests
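
One way to read the robust-statistics mitigation is to report each measured effect together with its standard error across samples, and to treat an effect as reliable only when it clearly exceeds that noise level. The sketch below assumes that reading; `mean_and_stderr`, `is_reliable`, and the multiplier `k` are illustrative choices rather than part of the design.

```rust
/// Sample mean and standard error of the mean.
/// Returns `None` for fewer than two samples, where the variance is undefined.
fn mean_and_stderr(samples: &[f64]) -> Option<(f64, f64)> {
    let n = samples.len();
    if n < 2 {
        return None;
    }
    let mean = samples.iter().sum::<f64>() / n as f64;
    // Unbiased sample variance (divides by n - 1)
    let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
    Some((mean, (var / n as f64).sqrt()))
}

/// Rough reliability check: the mean effect must exceed `k` standard errors.
/// The choice of `k` (for example 2.0) is an assumption, not a calibrated test.
fn is_reliable(samples: &[f64], k: f64) -> bool {
    match mean_and_stderr(samples) {
        Some((m, se)) => m.abs() > k * se,
        None => false,
    }
}
```

For effects `[2.0, 4.0]` the mean is 3 with a standard error of 1, so the effect passes at `k = 2.0` but not at `k = 4.0`. Effects that alternate in sign, such as `[1.0, -1.0, 1.0, -1.0]`, average to zero and are rejected.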
---
## Integration with Prime-Radiant
### Connection to Sheaf Cohomology
Causal structure forms a sheaf:
- Open sets: Subnetworks
- Sections: Causal mechanisms
- Restriction maps: Marginalization
- Cohomology: Obstruction to global causal explanation
### Connection to Category Theory
Causal abstraction is a functor:
- Objects: Causal models
- Morphisms: Interventional maps
- Composition: Hierarchical abstraction
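
Composition of abstractions can be illustrated at the level of variable mappings: if one `tau` sends low-level variables to a middle level and another sends the middle level to a high level, their composite sends low directly to high. This is a minimal sketch with string names in place of `VariableId`; `compose` is a hypothetical helper, and it covers only the variable maps, not the accompanying intervention maps.

```rust
use std::collections::HashMap;

/// Compose two variable mappings: (low -> mid) then (mid -> high).
/// Low-level variables whose mid-level image has no high-level counterpart
/// are dropped, so the composite may be partial.
fn compose<'a>(
    low_to_mid: &HashMap<&'a str, &'a str>,
    mid_to_high: &HashMap<&'a str, &'a str>,
) -> HashMap<&'a str, &'a str> {
    low_to_mid
        .iter()
        .filter_map(|(low, mid)| mid_to_high.get(mid).map(|high| (*low, *high)))
        .collect()
}
```

For example, with `h0 -> S`, `h1 -> C`, `h2 -> aux` at the first level and `S -> Sum`, `C -> Sum` at the second, the composite maps both `h0` and `h1` to `Sum` and leaves `h2` unmapped.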
---
## References
1. Geiger, A., et al. (2021). "Causal Abstractions of Neural Networks." NeurIPS.
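2. Pearl, J. (2009). "Causality: Models, Reasoning, and Inference" (2nd ed.). Cambridge University Press.
3. Conmy, A., et al. (2023). "Towards Automated Circuit Discovery for Mechanistic Interpretability." NeurIPS.
4. Wang, K., et al. (2023). "Interpretability in the Wild: A Circuit for Indirect Object Identification in GPT-2 Small." ICLR.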
5. Goldowsky-Dill, N., et al. (2023). "Localizing Model Behavior with Path Patching." arXiv.