464 lines
13 KiB
Rust
464 lines
13 KiB
Rust
//! Sheaf Data Structure
//!
//! A sheaf on a graph assigns:
//! - A vector space (stalk) to each vertex
//! - Restriction maps that send each vertex's stalk into the space shared with a
//!   neighbouring vertex across their edge
//!
//! This is the foundational structure for cohomology computation.
|
|
|
|
use crate::substrate::NodeId;
|
|
use crate::substrate::{RestrictionMap, SheafGraph};
|
|
use ndarray::{Array1, Array2};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
/// A stalk (fiber) at a vertex - the local data space
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Stalk {
|
|
/// Dimension of the stalk (vector space dimension)
|
|
pub dimension: usize,
|
|
/// Optional basis vectors (if not standard basis)
|
|
pub basis: Option<Array2<f64>>,
|
|
}
|
|
|
|
impl Stalk {
|
|
/// Create a stalk of given dimension with standard basis
|
|
pub fn new(dimension: usize) -> Self {
|
|
Self {
|
|
dimension,
|
|
basis: None,
|
|
}
|
|
}
|
|
|
|
/// Create a stalk with a custom basis
|
|
pub fn with_basis(dimension: usize, basis: Array2<f64>) -> Self {
|
|
assert_eq!(basis.ncols(), dimension, "Basis dimension mismatch");
|
|
Self {
|
|
dimension,
|
|
basis: Some(basis),
|
|
}
|
|
}
|
|
|
|
/// Get dimension
|
|
pub fn dim(&self) -> usize {
|
|
self.dimension
|
|
}
|
|
}
|
|
|
|
/// A local section assigns a value in the stalk at each vertex
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct LocalSection {
|
|
/// Vertex ID
|
|
pub vertex: NodeId,
|
|
/// Value in the stalk (as a vector)
|
|
pub value: Array1<f64>,
|
|
}
|
|
|
|
impl LocalSection {
|
|
/// Create a new local section
|
|
pub fn new(vertex: NodeId, value: Array1<f64>) -> Self {
|
|
Self { vertex, value }
|
|
}
|
|
|
|
/// Create from f32 slice
|
|
pub fn from_slice(vertex: NodeId, data: &[f32]) -> Self {
|
|
let value = Array1::from_iter(data.iter().map(|&x| x as f64));
|
|
Self { vertex, value }
|
|
}
|
|
|
|
/// Get dimension
|
|
pub fn dim(&self) -> usize {
|
|
self.value.len()
|
|
}
|
|
}
|
|
|
|
/// A sheaf section is a collection of local sections that are compatible
|
|
/// under restriction maps
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SheafSection {
|
|
/// Local sections indexed by vertex
|
|
pub sections: HashMap<NodeId, Array1<f64>>,
|
|
/// Whether this is a global section (fully consistent)
|
|
pub is_global: bool,
|
|
}
|
|
|
|
impl SheafSection {
|
|
/// Create an empty section
|
|
pub fn empty() -> Self {
|
|
Self {
|
|
sections: HashMap::new(),
|
|
is_global: false,
|
|
}
|
|
}
|
|
|
|
/// Create a section from local data
|
|
pub fn from_local(sections: HashMap<NodeId, Array1<f64>>) -> Self {
|
|
Self {
|
|
sections,
|
|
is_global: false,
|
|
}
|
|
}
|
|
|
|
/// Get the value at a vertex
|
|
pub fn get(&self, vertex: NodeId) -> Option<&Array1<f64>> {
|
|
self.sections.get(&vertex)
|
|
}
|
|
|
|
/// Set the value at a vertex
|
|
pub fn set(&mut self, vertex: NodeId, value: Array1<f64>) {
|
|
self.sections.insert(vertex, value);
|
|
self.is_global = false; // Need to recheck
|
|
}
|
|
|
|
/// Check if a vertex is in the section's domain
|
|
pub fn contains(&self, vertex: NodeId) -> bool {
|
|
self.sections.contains_key(&vertex)
|
|
}
|
|
|
|
/// Number of vertices with assigned values
|
|
pub fn support_size(&self) -> usize {
|
|
self.sections.len()
|
|
}
|
|
}
|
|
|
|
/// Type alias for restriction map function
|
|
pub type RestrictionFn = Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>;
|
|
|
|
/// A sheaf on a graph
|
|
///
|
|
/// Assigns stalks to vertices and restriction maps to edges
|
|
#[derive(Clone)]
|
|
pub struct Sheaf {
|
|
/// Stalks at each vertex
|
|
pub stalks: HashMap<NodeId, Stalk>,
|
|
/// Restriction maps indexed by (source, target) pairs
|
|
/// The map rho_{u->v} restricts from stalk at u to edge space
|
|
restriction_maps: HashMap<(NodeId, NodeId), RestrictionFn>,
|
|
/// Cached dimensions for performance
|
|
stalk_dims: HashMap<NodeId, usize>,
|
|
/// Total dimension (sum of all stalk dimensions)
|
|
total_dim: usize,
|
|
}
|
|
|
|
impl Sheaf {
|
|
/// Create a new empty sheaf
|
|
pub fn new() -> Self {
|
|
Self {
|
|
stalks: HashMap::new(),
|
|
restriction_maps: HashMap::new(),
|
|
stalk_dims: HashMap::new(),
|
|
total_dim: 0,
|
|
}
|
|
}
|
|
|
|
/// Build a sheaf from a SheafGraph
|
|
///
|
|
/// Uses the graph's state vectors as stalks and restriction maps from edges
|
|
pub fn from_graph(graph: &SheafGraph) -> Self {
|
|
let mut sheaf = Self::new();
|
|
|
|
// Add stalks from nodes
|
|
for node_id in graph.node_ids() {
|
|
if let Some(node) = graph.get_node(node_id) {
|
|
let dim = node.state.dim();
|
|
sheaf.add_stalk(node_id, Stalk::new(dim));
|
|
}
|
|
}
|
|
|
|
// Add restriction maps from edges
|
|
for edge_id in graph.edge_ids() {
|
|
if let Some(edge) = graph.get_edge(edge_id) {
|
|
let source = edge.source;
|
|
let target = edge.target;
|
|
|
|
// Create restriction functions from the edge's restriction maps
|
|
let source_rho = edge.rho_source.clone();
|
|
let target_rho = edge.rho_target.clone();
|
|
|
|
// Source restriction map
|
|
let source_fn: RestrictionFn = Arc::new(move |v: &Array1<f64>| {
|
|
let input: Vec<f32> = v.iter().map(|&x| x as f32).collect();
|
|
let output = source_rho.apply(&input);
|
|
Array1::from_iter(output.iter().map(|&x| x as f64))
|
|
});
|
|
|
|
// Target restriction map
|
|
let target_fn: RestrictionFn = Arc::new(move |v: &Array1<f64>| {
|
|
let input: Vec<f32> = v.iter().map(|&x| x as f32).collect();
|
|
let output = target_rho.apply(&input);
|
|
Array1::from_iter(output.iter().map(|&x| x as f64))
|
|
});
|
|
|
|
sheaf.add_restriction(source, target, source_fn.clone());
|
|
sheaf.add_restriction(target, source, target_fn);
|
|
}
|
|
}
|
|
|
|
sheaf
|
|
}
|
|
|
|
/// Add a stalk at a vertex
|
|
pub fn add_stalk(&mut self, vertex: NodeId, stalk: Stalk) {
|
|
let dim = stalk.dimension;
|
|
self.stalks.insert(vertex, stalk);
|
|
self.stalk_dims.insert(vertex, dim);
|
|
self.total_dim = self.stalk_dims.values().sum();
|
|
}
|
|
|
|
/// Add a restriction map
|
|
pub fn add_restriction(&mut self, source: NodeId, target: NodeId, map: RestrictionFn) {
|
|
self.restriction_maps.insert((source, target), map);
|
|
}
|
|
|
|
/// Get the stalk at a vertex
|
|
pub fn get_stalk(&self, vertex: NodeId) -> Option<&Stalk> {
|
|
self.stalks.get(&vertex)
|
|
}
|
|
|
|
/// Get stalk dimension
|
|
pub fn stalk_dim(&self, vertex: NodeId) -> Option<usize> {
|
|
self.stalk_dims.get(&vertex).copied()
|
|
}
|
|
|
|
/// Apply restriction map from source to target
|
|
pub fn restrict(
|
|
&self,
|
|
source: NodeId,
|
|
target: NodeId,
|
|
value: &Array1<f64>,
|
|
) -> Option<Array1<f64>> {
|
|
self.restriction_maps
|
|
.get(&(source, target))
|
|
.map(|rho| rho(value))
|
|
}
|
|
|
|
/// Check if a section is globally consistent
|
|
///
|
|
/// A section is consistent if for every edge (u,v):
|
|
/// rho_u(s(u)) = rho_v(s(v))
|
|
pub fn is_consistent(&self, section: &SheafSection, tolerance: f64) -> bool {
|
|
for &(source, target) in self.restriction_maps.keys() {
|
|
if let (Some(s_val), Some(t_val)) = (section.get(source), section.get(target)) {
|
|
let s_restricted = self.restrict(source, target, s_val);
|
|
let t_restricted = self.restrict(target, source, t_val);
|
|
|
|
if let (Some(s_r), Some(t_r)) = (s_restricted, t_restricted) {
|
|
let diff = &s_r - &t_r;
|
|
let norm: f64 = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
|
|
if norm > tolerance {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
true
|
|
}
|
|
|
|
/// Compute residual (inconsistency) at an edge
|
|
pub fn edge_residual(
|
|
&self,
|
|
source: NodeId,
|
|
target: NodeId,
|
|
section: &SheafSection,
|
|
) -> Option<Array1<f64>> {
|
|
let s_val = section.get(source)?;
|
|
let t_val = section.get(target)?;
|
|
|
|
let s_restricted = self.restrict(source, target, s_val)?;
|
|
let t_restricted = self.restrict(target, source, t_val)?;
|
|
|
|
Some(&s_restricted - &t_restricted)
|
|
}
|
|
|
|
/// Total dimension of the sheaf
|
|
pub fn total_dimension(&self) -> usize {
|
|
self.total_dim
|
|
}
|
|
|
|
/// Number of vertices
|
|
pub fn num_vertices(&self) -> usize {
|
|
self.stalks.len()
|
|
}
|
|
|
|
/// Iterator over vertices
|
|
pub fn vertices(&self) -> impl Iterator<Item = NodeId> + '_ {
|
|
self.stalks.keys().copied()
|
|
}
|
|
}
|
|
|
|
impl Default for Sheaf {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Debug for Sheaf {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("Sheaf")
|
|
.field("num_vertices", &self.stalks.len())
|
|
.field("num_restrictions", &self.restriction_maps.len())
|
|
.field("total_dimension", &self.total_dim)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
/// Builder for constructing sheaves
|
|
pub struct SheafBuilder {
|
|
sheaf: Sheaf,
|
|
}
|
|
|
|
impl SheafBuilder {
|
|
/// Create a new builder
|
|
pub fn new() -> Self {
|
|
Self {
|
|
sheaf: Sheaf::new(),
|
|
}
|
|
}
|
|
|
|
/// Add a stalk at a vertex
|
|
pub fn stalk(mut self, vertex: NodeId, dimension: usize) -> Self {
|
|
self.sheaf.add_stalk(vertex, Stalk::new(dimension));
|
|
self
|
|
}
|
|
|
|
/// Add an identity restriction between vertices
|
|
pub fn identity_restriction(mut self, source: NodeId, target: NodeId) -> Self {
|
|
let identity: RestrictionFn = Arc::new(|v: &Array1<f64>| v.clone());
|
|
self.sheaf.add_restriction(source, target, identity);
|
|
self
|
|
}
|
|
|
|
/// Add a scaling restriction
|
|
pub fn scaling_restriction(mut self, source: NodeId, target: NodeId, scale: f64) -> Self {
|
|
let scale_fn: RestrictionFn = Arc::new(move |v: &Array1<f64>| v * scale);
|
|
self.sheaf.add_restriction(source, target, scale_fn);
|
|
self
|
|
}
|
|
|
|
/// Add a projection restriction (select certain dimensions)
|
|
pub fn projection_restriction(
|
|
mut self,
|
|
source: NodeId,
|
|
target: NodeId,
|
|
indices: Vec<usize>,
|
|
) -> Self {
|
|
let proj_fn: RestrictionFn =
|
|
Arc::new(move |v: &Array1<f64>| Array1::from_iter(indices.iter().map(|&i| v[i])));
|
|
self.sheaf.add_restriction(source, target, proj_fn);
|
|
self
|
|
}
|
|
|
|
/// Add a linear restriction with a matrix
|
|
pub fn linear_restriction(
|
|
mut self,
|
|
source: NodeId,
|
|
target: NodeId,
|
|
matrix: Array2<f64>,
|
|
) -> Self {
|
|
let linear_fn: RestrictionFn = Arc::new(move |v: &Array1<f64>| matrix.dot(v));
|
|
self.sheaf.add_restriction(source, target, linear_fn);
|
|
self
|
|
}
|
|
|
|
/// Build the sheaf
|
|
pub fn build(self) -> Sheaf {
|
|
self.sheaf
|
|
}
|
|
}
|
|
|
|
impl Default for SheafBuilder {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    fn make_node_id() -> NodeId {
        Uuid::new_v4()
    }

    /// Builds a two-vertex sheaf with identity restrictions in both directions.
    fn two_vertex_identity_sheaf(v0: NodeId, v1: NodeId, dim: usize) -> Sheaf {
        SheafBuilder::new()
            .stalk(v0, dim)
            .stalk(v1, dim)
            .identity_restriction(v0, v1)
            .identity_restriction(v1, v0)
            .build()
    }

    #[test]
    fn test_sheaf_creation() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        let sheaf = two_vertex_identity_sheaf(v0, v1, 3);

        assert_eq!(sheaf.num_vertices(), 2);
        assert_eq!(sheaf.total_dimension(), 6);
    }

    #[test]
    fn test_stalk_replacement_updates_total_dimension() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        // Re-adding a stalk at v0 must replace, not accumulate, its dimension.
        let sheaf = SheafBuilder::new()
            .stalk(v0, 3)
            .stalk(v1, 2)
            .stalk(v0, 5)
            .build();

        assert_eq!(sheaf.num_vertices(), 2);
        assert_eq!(sheaf.stalk_dim(v0), Some(5));
        assert_eq!(sheaf.total_dimension(), 7);
    }

    #[test]
    fn test_consistent_section() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        let sheaf = two_vertex_identity_sheaf(v0, v1, 2);

        // Consistent section: same value at both vertices.
        let mut section = SheafSection::empty();
        section.set(v0, Array1::from_vec(vec![1.0, 2.0]));
        section.set(v1, Array1::from_vec(vec![1.0, 2.0]));

        assert!(sheaf.is_consistent(&section, 1e-10));
    }

    #[test]
    fn test_inconsistent_section() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        let sheaf = two_vertex_identity_sheaf(v0, v1, 2);

        // Inconsistent section: different values.
        let mut section = SheafSection::empty();
        section.set(v0, Array1::from_vec(vec![1.0, 2.0]));
        section.set(v1, Array1::from_vec(vec![3.0, 4.0]));

        assert!(!sheaf.is_consistent(&section, 1e-10));
    }

    #[test]
    fn test_partial_section_is_consistent() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        let sheaf = two_vertex_identity_sheaf(v0, v1, 2);

        // Only v0 has a value, so no edge can be checked.
        let mut section = SheafSection::empty();
        section.set(v0, Array1::from_vec(vec![1.0, 2.0]));

        assert!(sheaf.is_consistent(&section, 1e-10));
        assert!(sheaf.edge_residual(v0, v1, &section).is_none());
    }

    #[test]
    fn test_scaling_restriction_consistency() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        let sheaf = SheafBuilder::new()
            .stalk(v0, 2)
            .stalk(v1, 2)
            .scaling_restriction(v0, v1, 2.0)
            .identity_restriction(v1, v0)
            .build();

        let mut section = SheafSection::empty();
        section.set(v0, Array1::from_vec(vec![1.0, 2.0]));
        section.set(v1, Array1::from_vec(vec![2.0, 4.0]));

        assert!(sheaf.is_consistent(&section, 1e-10));
    }

    #[test]
    fn test_edge_residual() {
        let v0 = make_node_id();
        let v1 = make_node_id();

        let sheaf = two_vertex_identity_sheaf(v0, v1, 2);

        let mut section = SheafSection::empty();
        section.set(v0, Array1::from_vec(vec![1.0, 2.0]));
        section.set(v1, Array1::from_vec(vec![1.5, 2.5]));

        let residual = sheaf.edge_residual(v0, v1, &section).unwrap();

        // Residual should be [1.0, 2.0] - [1.5, 2.5] = [-0.5, -0.5].
        assert!((residual[0] - (-0.5)).abs() < 1e-10);
        assert!((residual[1] - (-0.5)).abs() < 1e-10);
    }

    #[test]
    fn test_local_section_from_slice() {
        let v0 = make_node_id();
        let local = LocalSection::from_slice(v0, &[1.0f32, 2.5f32]);

        assert_eq!(local.vertex, v0);
        assert_eq!(local.dim(), 2);
        assert_eq!(local.value[1], 2.5);
    }
}
|