# Causal Attention Networks (CAN) ## Overview ### Problem Statement Standard attention mechanisms in GNNs ignore temporal and causal ordering, allowing future information to influence past states. This creates three critical issues: 1. **Information Leakage**: Future documents can influence retrieval of past documents 2. **Invalid Counterfactuals**: Cannot answer "what if this event never occurred?" 3. **Temporal Inconsistency**: Legal citations, event logs, and versioned documents require strict causal ordering ### Proposed Solution Implement causal attention that respects temporal ordering through: - Directed acyclic graph (DAG) structure enforcing causality - Masked attention preventing future→past information flow - Counterfactual query engine for "what-if" analysis - Temporal consistency guarantees for ordered data ### Expected Benefits - **100% prevention** of temporal information leakage - **Counterfactual queries**: Answer "what if X didn't exist?" questions - **Legal compliance**: Proper citation precedence in legal documents - **Event causality**: Correct cause-effect relationships in logs - **Version control**: Proper document evolution tracking ### Novelty Claim First integration of strict causal inference principles into vector search. Unlike temporal embeddings (which encode time but don't enforce causality) or recurrent models (which only process sequences), CAN provides: - Formal causal guarantees via DAG structure - Counterfactual reasoning via intervention calculus - Bi-directional queries (forward: "what did this cause?" backward: "what caused this?") ## Technical Design ### Architecture Diagram ``` ┌──────────────────────────────────────────────────────────────────┐ │ Causal Attention Network │ │ │ │ ┌────────────────────────────────────────────────────────┐ │ │ │ Causal DAG Layer │ │ │ │ │ │ │ │ t₀ t₁ t₂ t₃ t₄ │ │ │ │ ●────────▶●────────▶●────────▶●────────▶● │ │ │ │ │ │╲ │╲ │ │ │ │ │ │ │ │ ╲ │ ╲ │ │ │ │ │ │ │ │ ╲ │ ╲ │ │ │ │ │ │ │ ▼ ╲ ▼ ╲ ▼ ▼ │ │ │ │ │ ● └───▶● └───▶●────────▶● │ │ │ │ │ │ │ │ │ │ │ │ │ └────────▶●────────▶●────────▶●────────▶● │ │ │ │ │ │ │ │ Legend: ● = Node with timestamp │ │ │ │ ──▶ = Causal edge (past → future) │ │ │ └────────────────────────────────────────────────────────┘ │ │ │ │ │ ▼ │ │ ┌────────────────────────────────────────────────────────┐ │ │ │ Masked Attention Matrix │ │ │ │ │ │ │ │ q₀ q₁ q₂ q₃ q₄ │ │ │ │ k₀ [ 1.0 0.0 0.0 0.0 0.0 ] ◄─ No future info │ │ │ │ k₁ [ 0.7 1.0 0.0 0.0 0.0 ] │ │ │ │ k₂ [ 0.4 0.6 1.0 0.0 0.0 ] │ │ │ │ k₃ [ 0.2 0.3 0.5 1.0 0.0 ] │ │ │ │ k₄ [ 0.1 0.2 0.3 0.6 1.0 ] │ │ │ │ ▲ │ │ │ │ └─ Upper triangle masked (set to -∞) │ │ │ └────────────────────────────────────────────────────────┘ │ │ │ │ │ ▼ │ │ ┌────────────────────────────────────────────────────────┐ │ │ │ Counterfactual Query Engine │ │ │ │ │ │ │ │ Query: "Results if document D₂ never existed?" │ │ │ │ │ │ │ │ 1. Identify intervention: do(remove D₂) │ │ │ │ 2. Propagate intervention through DAG │ │ │ │ 3. Recompute attention without D₂'s influence │ │ │ │ 4. Compare: Actual vs Counterfactual results │ │ │ │ │ │ │ └────────────────────────────────────────────────────────┘ │ └──────────────────────────────────────────────────────────────────┘ ``` ### Core Data Structures ```rust /// Causal graph structure (DAG) #[derive(Clone, Debug)] pub struct CausalGraph { /// Nodes with temporal ordering nodes: Vec, /// Adjacency list (only forward edges: past → future) edges: Vec>, /// Topological ordering cache topo_order: Vec, /// Temporal index for fast time-based queries temporal_index: BTreeMap>, /// Reverse index (for backward causal queries) reverse_edges: Vec>, } /// Node with causal metadata #[derive(Clone, Debug)] pub struct CausalNode { /// Unique identifier pub id: NodeId, /// Embedding vector pub embedding: Vec, /// Timestamp (must be monotonic) pub timestamp: Timestamp, /// Causal parents (nodes that influenced this one) pub parents: Vec, /// Causal children (nodes influenced by this one) pub children: Vec, /// Metadata (document type, version, etc.) pub metadata: HashMap, } /// Timestamp with total ordering #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Timestamp { /// Seconds since epoch pub seconds: i64, /// Nanoseconds (for sub-second precision) pub nanos: u32, /// Logical clock (for events at same physical time) pub logical: u64, } /// Causal attention mask #[derive(Clone, Debug)] pub struct CausalMask { /// Sparse mask representation /// Only store allowed attention pairs allowed_pairs: HashSet<(NodeId, NodeId)>, /// Cached dense mask for small graphs dense_mask: Option>, /// Mask generation strategy strategy: MaskStrategy, } /// Mask generation strategies #[derive(Clone, Debug)] pub enum MaskStrategy { /// Strict: Only past nodes (timestamp < current) Strict, /// Window: Past N time units TimeWindow { duration: Duration }, /// Topological: Follow DAG structure Topological { max_depth: usize }, /// Custom predicate Custom(fn(&CausalNode, &CausalNode) -> bool), } /// Counterfactual intervention #[derive(Clone, Debug)] pub struct Intervention { /// Type of intervention pub kind: InterventionKind, /// Target nodes pub targets: Vec, /// Intervention strength (0.0 = no effect, 1.0 = complete removal) pub strength: f32, } #[derive(Clone, Debug)] pub enum InterventionKind { /// Remove node entirely Remove, /// Set embedding to specific value SetValue(Vec), /// Block causal influence (cut edges) BlockInfluence, /// Add hypothetical node AddNode(CausalNode), } /// Counterfactual query result #[derive(Clone, Debug)] pub struct CounterfactualResult { /// Actual (factual) results pub factual: Vec, /// Counterfactual results (with intervention) pub counterfactual: Vec, /// Difference analysis pub differences: Vec, /// Causal effect size pub effect_size: f32, } #[derive(Clone, Debug)] pub struct Difference { pub node_id: NodeId, pub rank_change: i32, pub score_change: f32, pub explanation: String, } ``` ### Key Algorithms ```rust // Pseudocode for causal attention /// Build causal mask from temporal ordering fn build_causal_mask( graph: &CausalGraph, strategy: MaskStrategy ) -> CausalMask { let mut allowed_pairs = HashSet::new(); for node in &graph.nodes { match strategy { MaskStrategy::Strict => { // Allow attention only to earlier nodes for other in &graph.nodes { if other.timestamp < node.timestamp { allowed_pairs.insert((node.id, other.id)); } } }, MaskStrategy::TimeWindow { duration } => { // Allow attention within time window let cutoff = node.timestamp - duration; for other in &graph.nodes { if other.timestamp >= cutoff && other.timestamp < node.timestamp { allowed_pairs.insert((node.id, other.id)); } } }, MaskStrategy::Topological { max_depth } => { // Allow attention to ancestors in DAG let ancestors = find_ancestors(graph, node.id, max_depth); for ancestor in ancestors { allowed_pairs.insert((node.id, ancestor)); } }, MaskStrategy::Custom(predicate) => { for other in &graph.nodes { if predicate(node, other) { allowed_pairs.insert((node.id, other.id)); } } }, } } CausalMask { allowed_pairs, dense_mask: None, // Lazily computed strategy, } } /// Causal attention computation fn causal_attention( query: &[f32], graph: &CausalGraph, mask: &CausalMask, k: usize ) -> Vec { let mut scores = Vec::new(); // Compute attention scores for node in &graph.nodes { let score = cosine_similarity(query, &node.embedding); scores.push((node.id, score)); } // Apply causal mask scores.retain(|(node_id, _)| { // For query at "current time", only attend to past let query_time = Timestamp::now(); let node = &graph.nodes[*node_id]; node.timestamp < query_time }); // Sort by score and return top-k scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); scores.into_iter() .take(k) .map(|(id, score)| SearchResult { id, score }) .collect() } /// Counterfactual query with intervention fn counterfactual_query( query: &[f32], graph: &CausalGraph, intervention: &Intervention, k: usize ) -> CounterfactualResult { // Step 1: Compute factual results (no intervention) let factual = causal_attention(query, graph, &graph.default_mask, k); // Step 2: Apply intervention let mut modified_graph = graph.clone(); apply_intervention(&mut modified_graph, intervention); // Step 3: Compute counterfactual results let counterfactual = causal_attention( query, &modified_graph, &modified_graph.default_mask, k ); // Step 4: Analyze differences let differences = compute_differences(&factual, &counterfactual); // Step 5: Compute causal effect size let effect_size = compute_effect_size(&factual, &counterfactual); CounterfactualResult { factual, counterfactual, differences, effect_size, } } /// Apply intervention to graph fn apply_intervention( graph: &mut CausalGraph, intervention: &Intervention ) { match &intervention.kind { InterventionKind::Remove => { // Remove nodes and their causal influence for target in &intervention.targets { // Mark node as removed graph.nodes[*target].metadata.insert( "removed".to_string(), "true".to_string() ); // Cut all outgoing edges (prevent future influence) graph.edges[*target].clear(); // Remove incoming edges (erase past influence) for parent in &graph.nodes[*target].parents.clone() { graph.edges[*parent].retain(|e| { graph.get_edge(*e).target != *target }); } } // Recompute topological order graph.recompute_topo_order(); }, InterventionKind::SetValue(new_embedding) => { // Change embedding value for target in &intervention.targets { graph.nodes[*target].embedding = new_embedding.clone(); } }, InterventionKind::BlockInfluence => { // Cut outgoing edges but keep node for target in &intervention.targets { graph.edges[*target].clear(); } }, InterventionKind::AddNode(new_node) => { // Add hypothetical node graph.add_node(new_node.clone()); graph.recompute_topo_order(); }, } } /// Topological sort for DAG fn topological_sort(graph: &CausalGraph) -> Vec { let mut in_degree = vec![0; graph.nodes.len()]; // Compute in-degrees for edges in &graph.edges { for edge_id in edges { let target = graph.get_edge(*edge_id).target; in_degree[target] += 1; } } // Kahn's algorithm let mut queue: VecDeque = in_degree.iter() .enumerate() .filter(|(_, °)| deg == 0) .map(|(id, _)| id) .collect(); let mut result = Vec::new(); while let Some(node) = queue.pop_front() { result.push(node); for edge_id in &graph.edges[node] { let target = graph.get_edge(*edge_id).target; in_degree[target] -= 1; if in_degree[target] == 0 { queue.push_back(target); } } } assert_eq!(result.len(), graph.nodes.len(), "Graph has cycle!"); result } ``` ### API Design ```rust /// Public API for Causal Attention Networks pub trait CausalAttention { /// Create causal graph from timestamped documents fn new(documents: Vec, config: CausalConfig) -> Self; /// Search with causal constraints fn search( &self, query: &[f32], k: usize, options: CausalSearchOptions, ) -> Result, CanError>; /// Counterfactual query fn counterfactual( &self, query: &[f32], intervention: Intervention, k: usize, ) -> Result; /// Forward causal query: "What did X cause?" fn forward_causal( &self, source: NodeId, max_depth: usize, ) -> Result, CanError>; /// Backward causal query: "What caused X?" fn backward_causal( &self, target: NodeId, max_depth: usize, ) -> Result, CanError>; /// Add new node with temporal ordering fn add_node(&mut self, node: CausalNode) -> Result; /// Verify causal consistency fn verify_consistency(&self) -> Result<(), CanError>; /// Export causal graph for visualization fn export_graph(&self) -> CausalGraphExport; } /// Configuration for causal attention #[derive(Clone, Debug)] pub struct CausalConfig { /// Mask generation strategy pub mask_strategy: MaskStrategy, /// Allow concurrent events (same timestamp)? pub allow_concurrent: bool, /// Automatic edge inference from timestamps pub infer_edges: bool, /// Maximum causal depth for queries pub max_depth: usize, } /// Search options with causal constraints #[derive(Clone, Debug)] pub struct CausalSearchOptions { /// Search only before this timestamp pub before: Option, /// Search only after this timestamp pub after: Option, /// Require specific causal path pub require_path: Option>, /// Exclude nodes and their descendants pub exclude: Vec, } /// Causal graph export format #[derive(Clone, Debug, Serialize)] pub struct CausalGraphExport { pub nodes: Vec, pub edges: Vec, pub metadata: HashMap, } #[derive(Clone, Debug, Serialize)] pub struct ExportNode { pub id: NodeId, pub timestamp: Timestamp, pub label: String, pub position: (f32, f32), // For visualization } #[derive(Clone, Debug, Serialize)] pub struct ExportEdge { pub source: NodeId, pub target: NodeId, pub weight: f32, } ``` ## Integration Points ### Affected Crates/Modules 1. **`crates/ruvector-core/src/hnsw/`** - Extend to support directed edges (DAG structure) - Add temporal metadata to nodes - Modify search to respect causal constraints 2. **`crates/ruvector-gnn/src/attention/`** - Add causal masking to attention mechanisms - Integrate with existing attention variants 3. **`crates/ruvector-core/src/index/`** - Add temporal indexing for fast time-based queries - Support DAG-based navigation ### New Modules to Create 1. **`crates/ruvector-gnn/src/causal/`** - `graph.rs` - Causal DAG implementation - `mask.rs` - Causal masking strategies - `intervention.rs` - Counterfactual interventions - `search.rs` - Causal search algorithms - `verify.rs` - Consistency checking - `temporal.rs` - Timestamp and ordering utilities 2. **`crates/ruvector-core/src/temporal/`** - `index.rs` - Temporal indexing structures - `ordering.rs` - Total order on timestamps - `version.rs` - Document versioning support ### Dependencies on Other Features - **Feature 10 (Gravitational Fields)**: GEF must respect causal ordering (no backward pull) - **Feature 12 (Topology-Aware Routing)**: Topology metrics need DAG-aware computation - **Feature 13 (Crystallization)**: Hierarchies must respect temporal precedence ## Regression Prevention ### Existing Functionality at Risk 1. **Undirected Graph Search** - Risk: Breaking existing HNSW bidirectional search - Prevention: Maintain separate directed/undirected graph modes 2. **Performance Overhead** - Risk: Topological sort and mask computation add latency - Prevention: Cache masks, lazy computation, optional feature 3. **Storage Overhead** - Risk: Timestamp + edge direction doubles metadata - Prevention: Optional temporal metadata, compressed timestamps ### Test Cases to Prevent Regressions ```rust #[cfg(test)] mod regression_tests { /// Verify no temporal leakage #[test] fn test_no_future_information() { let mut graph = CausalGraph::new(CausalConfig::default()); // Add nodes with increasing timestamps let past = graph.add_node(node_at_time(t0)); let present = graph.add_node(node_at_time(t1)); let future = graph.add_node(node_at_time(t2)); // Query from present: should not see future let results = graph.search(&query, 10, CausalSearchOptions { before: Some(t1), ..Default::default() }); assert!(!results.contains(&future)); assert!(results.contains(&past)); } /// Counterfactual removal test #[test] fn test_counterfactual_removal() { let graph = create_legal_citation_graph(); // Factual: Case A cites Case B let factual = graph.search(&case_a_query, 10); assert!(factual.contains(&case_b)); // Counterfactual: What if Case B never existed? let intervention = Intervention { kind: InterventionKind::Remove, targets: vec![case_b], strength: 1.0, }; let counterfactual = graph.counterfactual( &case_a_query, intervention, 10 ); assert!(!counterfactual.counterfactual.contains(&case_b)); assert_ne!(factual, counterfactual.factual); } /// DAG consistency #[test] fn test_dag_no_cycles() { let graph = create_random_causal_graph(1000); // Should not panic (cycle detection) let topo = graph.topological_sort(); assert_eq!(topo.len(), 1000); // Verify all edges go forward in topological order for (source, edges) in graph.edges.iter().enumerate() { for edge in edges { let target = graph.get_edge(*edge).target; let source_pos = topo.iter().position(|&id| id == source).unwrap(); let target_pos = topo.iter().position(|&id| id == target).unwrap(); assert!(source_pos < target_pos, "Edge goes backward!"); } } } } ``` ### Backward Compatibility Strategy 1. **Dual Mode**: Support both causal and non-causal graphs 2. **Automatic Detection**: Infer causality from timestamp metadata 3. **Migration Tool**: Convert existing graphs to causal structure 4. **Graceful Degradation**: If no timestamps, fall back to standard search ## Implementation Phases ### Phase 1: Research Validation (2 weeks) **Goal**: Validate causal inference on real-world data - Implement basic DAG structure and topological sort - Create legal citation dataset with known causal structure - Test counterfactual queries on synthetic data - Measure temporal leakage prevention - **Deliverable**: Research report with causal correctness proofs ### Phase 2: Core Implementation (3 weeks) **Goal**: Production causal graph system - Implement `CausalGraph` with temporal indexing - Develop causal masking strategies - Build intervention engine - Add forward/backward causal queries - Implement consistency verification - **Deliverable**: Working CAN module with unit tests ### Phase 3: Integration (2 weeks) **Goal**: Integrate with RuVector ecosystem - Add temporal metadata to HNSW nodes - Implement DAG serialization/deserialization - Create API bindings (Python, Node.js) - Add visualization tools (Graphviz export) - Write integration tests - **Deliverable**: CAN integrated into main codebase ### Phase 4: Optimization (2 weeks) **Goal**: Production performance - Profile and optimize topological sort - Implement sparse mask representations - Add incremental updates (streaming DAG) - Create benchmarks for legal/event datasets - Write documentation and examples - **Deliverable**: Production-ready, documented feature ## Success Metrics ### Performance Benchmarks | Metric | Baseline | Target | Measurement | |--------|----------|--------|-------------| | Temporal leakage rate | N/A | 0% | Verified by test suite | | Causal query latency | N/A | <2ms | 99th percentile, 10K nodes | | Counterfactual overhead | N/A | <5x | vs. standard search | | Memory overhead | 0MB | <100MB | Per 1M nodes (timestamps+edges) | | DAG update latency | N/A | <1ms | Add node with edge inference | ### Accuracy Metrics 1. **Temporal Correctness**: No future information in results - Target: 100% correctness (formal verification) 2. **Counterfactual Validity**: Interventions produce expected changes - Target: >95% agreement with manual counterfactual analysis 3. **Causal Path Accuracy**: Correct ancestor/descendant relationships - Target: 100% correctness on citation graphs ### Comparison to Baselines Test against: 1. **Standard Attention**: Temporal leakage analysis 2. **Temporal Embeddings**: Counterfactual capability comparison 3. **RNNs/LSTMs**: Bi-directional causal query performance Datasets: - Legal citations (Caselaw Access Project, 6M cases) - arXiv citations (2M papers with temporal metadata) - Wikipedia edit history (versioned documents) - Event logs (system logs, user actions) ## Risks and Mitigations ### Technical Risks | Risk | Impact | Probability | Mitigation | |------|--------|-------------|------------| | Cycle detection bugs | High | Low | Extensive testing, formal verification | | Timestamp conflicts | Medium | Medium | Logical clocks, conflict resolution | | Counterfactual explosion | High | Medium | Limit intervention scope, caching | | DAG update complexity | Medium | High | Incremental algorithms, batching | | Poor timestamp quality | High | High | Automatic inference, validation | ### Detailed Mitigations 1. **Cycle Detection Bugs** - Implement multiple cycle detection algorithms (DFS, Kahn's) - Property-based testing (QuickCheck) - Formal proof of DAG invariants - Fallback: Reject graphs with cycles 2. **Timestamp Conflicts** - Use hybrid logical clocks (HLC) for concurrent events - Implement timestamp resolution strategies - Allow manual timestamp assignment - Fallback: Use insertion order as logical time 3. **Counterfactual Explosion** - Limit intervention depth (max descendants affected) - Implement intervention caching - Use approximate counterfactuals for large graphs - Fallback: Disable counterfactuals for >1M nodes 4. **DAG Update Complexity** - Implement incremental topological sort (Pearce-Kelly) - Batch insertions for better amortized cost - Use lazy recomputation strategies - Fallback: Full recomputation only when needed 5. **Poor Timestamp Quality** - Infer timestamps from document metadata - Cross-reference multiple time sources - Implement timestamp validation heuristics - Fallback: Warn user and disable causal guarantees ## Applications ### Legal Document Search - Citation precedence: Only cite earlier cases - Counterfactual: "Would this case still apply if landmark case X was overturned?" - Temporal queries: "Find cases before 2020 about patent law" ### Event Log Analysis - Root cause analysis: "What caused this failure?" - Impact analysis: "What did this configuration change affect?" - Counterfactual: "What if we hadn't deployed version 2.3?" ### Version Control - Document evolution: "Show me earlier versions of this section" - Blame analysis: "Which change introduced this concept?" - Counterfactual: "What would docs look like without the API redesign?" ### Knowledge Graphs - Temporal reasoning: "What was known about X in 2015?" - Causal inference: "Did discovery A enable discovery B?" - Counterfactual: "What if theory X was never proposed?" ## References ### Causal Inference Theory - Pearl's causality framework (do-calculus) - Directed Acyclic Graphs (DAGs) for causality - Counterfactual reasoning and interventions - Granger causality for time series ### Temporal Modeling - Temporal knowledge graphs - Hybrid logical clocks (HLC) - Version control theory (DAG of commits) - Event sourcing and CQRS ### Implementation Techniques - Incremental topological sorting - Sparse attention masks - Efficient DAG operations - Temporal indexing structures