Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
332
vendor/ruvector/crates/prime-radiant/src/hyperbolic/adapter.rs
vendored
Normal file
332
vendor/ruvector/crates/prime-radiant/src/hyperbolic/adapter.rs
vendored
Normal file
@@ -0,0 +1,332 @@
|
||||
//! Adapter to ruvector-hyperbolic-hnsw
|
||||
//!
|
||||
//! Provides a domain-specific interface for hyperbolic coherence operations.
|
||||
|
||||
use super::{HyperbolicCoherenceConfig, HyperbolicCoherenceError, NodeId, Result};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Epsilon for numerical stability
|
||||
const EPS: f32 = 1e-5;
|
||||
|
||||
/// Adapter wrapping ruvector-hyperbolic-hnsw functionality
|
||||
///
|
||||
/// This adapter provides coherence-specific operations built on top of
|
||||
/// the hyperbolic HNSW index, including:
|
||||
/// - Poincare ball projection
|
||||
/// - Distance computation with curvature awareness
|
||||
/// - Frechet mean calculation
|
||||
/// - Similarity search
|
||||
#[derive(Debug)]
|
||||
pub struct HyperbolicAdapter {
|
||||
/// Configuration
|
||||
config: HyperbolicCoherenceConfig,
|
||||
/// Node vectors (projected to ball)
|
||||
vectors: HashMap<NodeId, Vec<f32>>,
|
||||
/// Index for similarity search (simple implementation)
|
||||
/// In production, this would use ShardedHyperbolicHnsw
|
||||
index_built: bool,
|
||||
}
|
||||
|
||||
impl HyperbolicAdapter {
|
||||
/// Create a new adapter
|
||||
pub fn new(config: HyperbolicCoherenceConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
vectors: HashMap::new(),
|
||||
index_built: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Project a vector to the Poincare ball
|
||||
///
|
||||
/// Ensures the vector has norm < 1 (within ball radius)
|
||||
pub fn project_to_ball(&self, vector: &[f32]) -> Result<Vec<f32>> {
|
||||
let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
|
||||
let norm = norm_sq.sqrt();
|
||||
|
||||
if norm < 1.0 - self.config.epsilon {
|
||||
// Already inside ball
|
||||
return Ok(vector.to_vec());
|
||||
}
|
||||
|
||||
// Project to boundary with epsilon margin
|
||||
let max_norm = 1.0 - self.config.epsilon;
|
||||
let scale = max_norm / (norm + EPS);
|
||||
|
||||
let projected: Vec<f32> = vector.iter().map(|x| x * scale).collect();
|
||||
|
||||
Ok(projected)
|
||||
}
|
||||
|
||||
/// Insert a vector (must already be projected)
|
||||
pub fn insert(&mut self, node_id: NodeId, vector: Vec<f32>) -> Result<()> {
|
||||
self.vectors.insert(node_id, vector);
|
||||
self.index_built = false; // Invalidate index
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update a vector
|
||||
pub fn update(&mut self, node_id: NodeId, vector: Vec<f32>) -> Result<()> {
|
||||
if !self.vectors.contains_key(&node_id) {
|
||||
return Err(HyperbolicCoherenceError::NodeNotFound(node_id));
|
||||
}
|
||||
self.vectors.insert(node_id, vector);
|
||||
self.index_built = false;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a vector
|
||||
pub fn get(&self, node_id: NodeId) -> Option<&Vec<f32>> {
|
||||
self.vectors.get(&node_id)
|
||||
}
|
||||
|
||||
/// Compute Poincare distance between two points
|
||||
///
|
||||
/// d(x, y) = acosh(1 + 2 * |x-y|^2 / ((1-|x|^2)(1-|y|^2))) / sqrt(-c)
|
||||
pub fn poincare_distance(&self, x: &[f32], y: &[f32]) -> f32 {
|
||||
let c = -self.config.curvature; // Make positive for computation
|
||||
|
||||
let norm_x_sq: f32 = x.iter().map(|v| v * v).sum();
|
||||
let norm_y_sq: f32 = y.iter().map(|v| v * v).sum();
|
||||
|
||||
let diff_sq: f32 = x.iter().zip(y.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
|
||||
|
||||
let denom = (1.0 - norm_x_sq).max(EPS) * (1.0 - norm_y_sq).max(EPS);
|
||||
let inner = 1.0 + 2.0 * diff_sq / denom;
|
||||
|
||||
// acosh(x) = ln(x + sqrt(x^2 - 1))
|
||||
let acosh_inner = if inner >= 1.0 {
|
||||
(inner + (inner * inner - 1.0).sqrt()).ln()
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
acosh_inner / c.sqrt()
|
||||
}
|
||||
|
||||
/// Compute Frechet mean of multiple points in Poincare ball
|
||||
///
|
||||
/// Uses iterative gradient descent on the hyperbolic manifold.
|
||||
pub fn frechet_mean(&self, points: &[&Vec<f32>]) -> Result<Vec<f32>> {
|
||||
if points.is_empty() {
|
||||
return Err(HyperbolicCoherenceError::EmptyCollection);
|
||||
}
|
||||
|
||||
if points.len() == 1 {
|
||||
return Ok(points[0].clone());
|
||||
}
|
||||
|
||||
let dim = points[0].len();
|
||||
|
||||
// Initialize with Euclidean mean projected to ball
|
||||
let mut mean: Vec<f32> = vec![0.0; dim];
|
||||
for p in points {
|
||||
for (m, &v) in mean.iter_mut().zip(p.iter()) {
|
||||
*m += v;
|
||||
}
|
||||
}
|
||||
for m in mean.iter_mut() {
|
||||
*m /= points.len() as f32;
|
||||
}
|
||||
mean = self.project_to_ball(&mean)?;
|
||||
|
||||
// Iterative refinement
|
||||
for _ in 0..self.config.frechet_max_iters {
|
||||
let mut grad = vec![0.0f32; dim];
|
||||
let mut total_dist = 0.0f32;
|
||||
|
||||
for &p in points {
|
||||
// Log map from mean to point
|
||||
let log = self.log_map(&mean, p);
|
||||
for (g, l) in grad.iter_mut().zip(log.iter()) {
|
||||
*g += l;
|
||||
}
|
||||
total_dist += self.poincare_distance(&mean, p);
|
||||
}
|
||||
|
||||
// Average gradient
|
||||
for g in grad.iter_mut() {
|
||||
*g /= points.len() as f32;
|
||||
}
|
||||
|
||||
// Check convergence
|
||||
let grad_norm: f32 = grad.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if grad_norm < self.config.frechet_tolerance {
|
||||
break;
|
||||
}
|
||||
|
||||
// Exponential map to move along gradient
|
||||
let step_size = 0.1f32.min(1.0 / (total_dist + 1.0));
|
||||
let step: Vec<f32> = grad.iter().map(|g| g * step_size).collect();
|
||||
mean = self.exp_map(&mean, &step)?;
|
||||
mean = self.project_to_ball(&mean)?;
|
||||
}
|
||||
|
||||
Ok(mean)
|
||||
}
|
||||
|
||||
/// Logarithmic map: tangent vector from base to point
|
||||
fn log_map(&self, base: &[f32], point: &[f32]) -> Vec<f32> {
|
||||
let c = -self.config.curvature;
|
||||
|
||||
let diff: Vec<f32> = point.iter().zip(base.iter()).map(|(p, b)| p - b).collect();
|
||||
let diff_norm: f32 = diff.iter().map(|x| x * x).sum::<f32>().sqrt().max(EPS);
|
||||
|
||||
let base_norm_sq: f32 = base.iter().map(|x| x * x).sum();
|
||||
let lambda_base = 2.0 / (1.0 - base_norm_sq).max(EPS);
|
||||
|
||||
let dist = self.poincare_distance(base, point);
|
||||
let scale = dist * lambda_base.sqrt() / (c.sqrt() * diff_norm);
|
||||
|
||||
diff.iter().map(|d| d * scale).collect()
|
||||
}
|
||||
|
||||
/// Exponential map: move from base along tangent vector
|
||||
fn exp_map(&self, base: &[f32], tangent: &[f32]) -> Result<Vec<f32>> {
|
||||
let c = -self.config.curvature;
|
||||
|
||||
let tangent_norm: f32 = tangent.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if tangent_norm < EPS {
|
||||
return Ok(base.to_vec());
|
||||
}
|
||||
|
||||
let base_norm_sq: f32 = base.iter().map(|x| x * x).sum();
|
||||
let lambda_base = 2.0 / (1.0 - base_norm_sq).max(EPS);
|
||||
|
||||
let normalized: Vec<f32> = tangent.iter().map(|t| t / tangent_norm).collect();
|
||||
let scaled_norm = tangent_norm / lambda_base.sqrt();
|
||||
|
||||
// tanh(sqrt(c) * t / 2)
|
||||
let tanh_arg = c.sqrt() * scaled_norm;
|
||||
let tanh_val = tanh_arg.tanh();
|
||||
|
||||
let scale = tanh_val / c.sqrt();
|
||||
|
||||
let mut result: Vec<f32> = base.to_vec();
|
||||
for (r, n) in result.iter_mut().zip(normalized.iter()) {
|
||||
*r += scale * n;
|
||||
}
|
||||
|
||||
self.project_to_ball(&result)
|
||||
}
|
||||
|
||||
/// Search for k nearest neighbors
|
||||
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
|
||||
if self.vectors.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Simple brute-force search (in production, use HNSW)
|
||||
let mut distances: Vec<(NodeId, f32)> = self
|
||||
.vectors
|
||||
.iter()
|
||||
.map(|(&id, vec)| (id, self.poincare_distance(query, vec)))
|
||||
.collect();
|
||||
|
||||
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
distances.truncate(k);
|
||||
|
||||
Ok(distances)
|
||||
}
|
||||
|
||||
/// Get number of vectors
|
||||
pub fn len(&self) -> usize {
|
||||
self.vectors.len()
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.vectors.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_projection() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let adapter = HyperbolicAdapter::new(config);
|
||||
|
||||
// Vector inside ball - should be unchanged
|
||||
let inside = vec![0.1, 0.1, 0.1, 0.1];
|
||||
let projected = adapter.project_to_ball(&inside).unwrap();
|
||||
assert!((projected[0] - inside[0]).abs() < 0.01);
|
||||
|
||||
// Vector outside ball - should be projected
|
||||
let outside = vec![0.9, 0.9, 0.9, 0.9];
|
||||
let projected = adapter.project_to_ball(&outside).unwrap();
|
||||
let norm: f32 = projected.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
assert!(norm < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_poincare_distance() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let adapter = HyperbolicAdapter::new(config);
|
||||
|
||||
let origin = vec![0.0, 0.0, 0.0, 0.0];
|
||||
let point = vec![0.5, 0.0, 0.0, 0.0];
|
||||
|
||||
let dist = adapter.poincare_distance(&origin, &point);
|
||||
assert!(dist > 0.0);
|
||||
|
||||
// Distance from point to itself should be 0
|
||||
let self_dist = adapter.poincare_distance(&point, &point);
|
||||
assert!(self_dist < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frechet_mean() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let adapter = HyperbolicAdapter::new(config);
|
||||
|
||||
let points = vec![
|
||||
vec![0.1, 0.0, 0.0, 0.0],
|
||||
vec![-0.1, 0.0, 0.0, 0.0],
|
||||
vec![0.0, 0.1, 0.0, 0.0],
|
||||
vec![0.0, -0.1, 0.0, 0.0],
|
||||
];
|
||||
|
||||
let refs: Vec<&Vec<f32>> = points.iter().collect();
|
||||
let mean = adapter.frechet_mean(&refs).unwrap();
|
||||
|
||||
// Mean should be near origin
|
||||
let mean_norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
assert!(mean_norm < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let mut adapter = HyperbolicAdapter::new(config);
|
||||
|
||||
adapter.insert(1, vec![0.1, 0.0, 0.0, 0.0]).unwrap();
|
||||
adapter.insert(2, vec![0.2, 0.0, 0.0, 0.0]).unwrap();
|
||||
adapter.insert(3, vec![0.5, 0.0, 0.0, 0.0]).unwrap();
|
||||
|
||||
let query = vec![0.15, 0.0, 0.0, 0.0];
|
||||
let results = adapter.search(&query, 2).unwrap();
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
// Closest should be node 1 or 2
|
||||
assert!(results[0].0 == 1 || results[0].0 == 2);
|
||||
}
|
||||
}
|
||||
169
vendor/ruvector/crates/prime-radiant/src/hyperbolic/config.rs
vendored
Normal file
169
vendor/ruvector/crates/prime-radiant/src/hyperbolic/config.rs
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
//! Hyperbolic Coherence Configuration
|
||||
//!
|
||||
//! Configuration for hyperbolic coherence computation.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for hyperbolic coherence computation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HyperbolicCoherenceConfig {
|
||||
/// State vector dimension
|
||||
pub dimension: usize,
|
||||
|
||||
/// Curvature of the hyperbolic space (must be negative)
|
||||
/// Typical values: -1.0 (unit curvature), -0.5 (flatter), -2.0 (more curved)
|
||||
pub curvature: f32,
|
||||
|
||||
/// Epsilon for numerical stability (projection boundary)
|
||||
pub epsilon: f32,
|
||||
|
||||
/// Maximum number of iterations for Frechet mean computation
|
||||
pub frechet_max_iters: usize,
|
||||
|
||||
/// Convergence threshold for Frechet mean
|
||||
pub frechet_tolerance: f32,
|
||||
|
||||
/// Depth weight function type
|
||||
pub depth_weight_type: DepthWeightType,
|
||||
|
||||
/// HNSW M parameter (max connections per node)
|
||||
pub hnsw_m: usize,
|
||||
|
||||
/// HNSW ef_construction parameter
|
||||
pub hnsw_ef_construction: usize,
|
||||
|
||||
/// Enable sharding for large collections
|
||||
pub enable_sharding: bool,
|
||||
|
||||
/// Default shard curvature
|
||||
pub default_shard_curvature: f32,
|
||||
}
|
||||
|
||||
impl Default for HyperbolicCoherenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dimension: 64,
|
||||
curvature: -1.0,
|
||||
epsilon: 1e-5,
|
||||
frechet_max_iters: 100,
|
||||
frechet_tolerance: 1e-6,
|
||||
depth_weight_type: DepthWeightType::Logarithmic,
|
||||
hnsw_m: 16,
|
||||
hnsw_ef_construction: 200,
|
||||
enable_sharding: false,
|
||||
default_shard_curvature: -1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HyperbolicCoherenceConfig {
|
||||
/// Create a configuration for small collections (< 10K nodes)
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
dimension: 64,
|
||||
curvature: -1.0,
|
||||
hnsw_m: 8,
|
||||
hnsw_ef_construction: 100,
|
||||
enable_sharding: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration for large collections (> 100K nodes)
|
||||
pub fn large() -> Self {
|
||||
Self {
|
||||
dimension: 64,
|
||||
curvature: -1.0,
|
||||
hnsw_m: 32,
|
||||
hnsw_ef_construction: 400,
|
||||
enable_sharding: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate configuration
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.curvature >= 0.0 {
|
||||
return Err(format!(
|
||||
"Curvature must be negative, got {}",
|
||||
self.curvature
|
||||
));
|
||||
}
|
||||
if self.dimension == 0 {
|
||||
return Err("Dimension must be positive".to_string());
|
||||
}
|
||||
if self.epsilon <= 0.0 {
|
||||
return Err("Epsilon must be positive".to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute depth weight using configured function type
|
||||
pub fn depth_weight_fn(&self, depth: f32) -> f32 {
|
||||
self.depth_weight_type.compute(depth)
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of depth weighting function
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DepthWeightType {
|
||||
/// Constant weight (no depth scaling)
|
||||
Constant,
|
||||
/// Linear: 1 + depth
|
||||
Linear,
|
||||
/// Logarithmic: 1 + ln(max(depth, 1))
|
||||
Logarithmic,
|
||||
/// Quadratic: 1 + depth^2
|
||||
Quadratic,
|
||||
/// Exponential: e^(depth * scale)
|
||||
Exponential,
|
||||
}
|
||||
|
||||
impl Default for DepthWeightType {
|
||||
fn default() -> Self {
|
||||
Self::Logarithmic
|
||||
}
|
||||
}
|
||||
|
||||
impl DepthWeightType {
|
||||
/// Compute depth weight
|
||||
pub fn compute(&self, depth: f32) -> f32 {
|
||||
match self {
|
||||
Self::Constant => 1.0,
|
||||
Self::Linear => 1.0 + depth,
|
||||
Self::Logarithmic => 1.0 + depth.max(1.0).ln(),
|
||||
Self::Quadratic => 1.0 + depth * depth,
|
||||
Self::Exponential => (depth * 0.5).exp().min(10.0), // Capped at 10x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = HyperbolicCoherenceConfig::default();
|
||||
assert_eq!(config.curvature, -1.0);
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_curvature() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
curvature: 1.0, // Invalid - must be negative
|
||||
..Default::default()
|
||||
};
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_weights() {
|
||||
assert_eq!(DepthWeightType::Constant.compute(5.0), 1.0);
|
||||
assert_eq!(DepthWeightType::Linear.compute(5.0), 6.0);
|
||||
|
||||
let log_weight = DepthWeightType::Logarithmic.compute(2.718281828);
|
||||
assert!((log_weight - 2.0).abs() < 0.01);
|
||||
}
|
||||
}
|
||||
214
vendor/ruvector/crates/prime-radiant/src/hyperbolic/depth.rs
vendored
Normal file
214
vendor/ruvector/crates/prime-radiant/src/hyperbolic/depth.rs
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
//! Depth Computation for Hyperbolic Hierarchy
|
||||
//!
|
||||
//! Computes hierarchical depth from Poincare ball coordinates.
|
||||
|
||||
/// Epsilon for numerical stability
|
||||
const EPS: f32 = 1e-5;
|
||||
|
||||
/// Computes depth in the Poincare ball model
|
||||
///
|
||||
/// Depth is defined as the hyperbolic distance from the origin,
|
||||
/// which correlates with hierarchy level in embedded trees.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DepthComputer {
|
||||
/// Curvature of the hyperbolic space
|
||||
curvature: f32,
|
||||
/// Threshold boundaries for hierarchy levels
|
||||
level_thresholds: [f32; 4],
|
||||
}
|
||||
|
||||
impl DepthComputer {
|
||||
/// Create a new depth computer
|
||||
pub fn new(curvature: f32) -> Self {
|
||||
// Default thresholds based on typical hierarchy depths
|
||||
Self {
|
||||
curvature,
|
||||
level_thresholds: [0.5, 1.0, 2.0, 3.0],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom thresholds
|
||||
pub fn with_thresholds(curvature: f32, thresholds: [f32; 4]) -> Self {
|
||||
Self {
|
||||
curvature,
|
||||
level_thresholds: thresholds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute depth as hyperbolic distance from origin
|
||||
///
|
||||
/// In the Poincare ball, depth = 2 * arctanh(|x|) / sqrt(-c)
|
||||
pub fn compute_depth(&self, point: &[f32]) -> f32 {
|
||||
let norm_sq: f32 = point.iter().map(|x| x * x).sum();
|
||||
let norm = norm_sq.sqrt();
|
||||
|
||||
if norm < EPS {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let c = -self.curvature;
|
||||
|
||||
// arctanh(x) = 0.5 * ln((1+x)/(1-x))
|
||||
let clamped_norm = norm.min(1.0 - EPS);
|
||||
let arctanh = 0.5 * ((1.0 + clamped_norm) / (1.0 - clamped_norm)).ln();
|
||||
|
||||
2.0 * arctanh / c.sqrt()
|
||||
}
|
||||
|
||||
/// Compute normalized depth (0 to 1 range based on typical max)
|
||||
pub fn normalized_depth(&self, point: &[f32]) -> f32 {
|
||||
let depth = self.compute_depth(point);
|
||||
// Typical max depth around 5-6 for deep hierarchies
|
||||
(depth / 5.0).min(1.0)
|
||||
}
|
||||
|
||||
/// Classify depth into hierarchy level
|
||||
pub fn classify_level(&self, depth: f32) -> HierarchyLevel {
|
||||
if depth < self.level_thresholds[0] {
|
||||
HierarchyLevel::Root
|
||||
} else if depth < self.level_thresholds[1] {
|
||||
HierarchyLevel::High
|
||||
} else if depth < self.level_thresholds[2] {
|
||||
HierarchyLevel::Mid
|
||||
} else if depth < self.level_thresholds[3] {
|
||||
HierarchyLevel::Deep
|
||||
} else {
|
||||
HierarchyLevel::VeryDeep
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute radius at which a given depth is achieved
|
||||
pub fn radius_for_depth(&self, target_depth: f32) -> f32 {
|
||||
let c = -self.curvature;
|
||||
// Inverse of depth formula: r = tanh(depth * sqrt(c) / 2)
|
||||
(target_depth * c.sqrt() / 2.0).tanh()
|
||||
}
|
||||
|
||||
/// Get curvature
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.curvature
|
||||
}
|
||||
}
|
||||
|
||||
/// Hierarchy level classification
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum HierarchyLevel {
|
||||
/// Root level (depth < 0.5)
|
||||
Root,
|
||||
/// High level (0.5 <= depth < 1.0)
|
||||
High,
|
||||
/// Mid level (1.0 <= depth < 2.0)
|
||||
Mid,
|
||||
/// Deep level (2.0 <= depth < 3.0)
|
||||
Deep,
|
||||
/// Very deep level (depth >= 3.0)
|
||||
VeryDeep,
|
||||
}
|
||||
|
||||
impl HierarchyLevel {
|
||||
/// Get numeric level (0 = Root, 4 = VeryDeep)
|
||||
pub fn as_level(&self) -> usize {
|
||||
match self {
|
||||
Self::Root => 0,
|
||||
Self::High => 1,
|
||||
Self::Mid => 2,
|
||||
Self::Deep => 3,
|
||||
Self::VeryDeep => 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get weight multiplier for this level
|
||||
pub fn weight_multiplier(&self) -> f32 {
|
||||
match self {
|
||||
Self::Root => 1.0,
|
||||
Self::High => 1.2,
|
||||
Self::Mid => 1.5,
|
||||
Self::Deep => 2.0,
|
||||
Self::VeryDeep => 3.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get human-readable name
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Root => "root",
|
||||
Self::High => "high",
|
||||
Self::Mid => "mid",
|
||||
Self::Deep => "deep",
|
||||
Self::VeryDeep => "very_deep",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HierarchyLevel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_depth_at_origin() {
|
||||
let computer = DepthComputer::new(-1.0);
|
||||
let origin = vec![0.0, 0.0, 0.0, 0.0];
|
||||
let depth = computer.compute_depth(&origin);
|
||||
assert!(depth < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_increases_with_radius() {
|
||||
let computer = DepthComputer::new(-1.0);
|
||||
|
||||
let point1 = vec![0.1, 0.0, 0.0, 0.0];
|
||||
let point2 = vec![0.5, 0.0, 0.0, 0.0];
|
||||
let point3 = vec![0.9, 0.0, 0.0, 0.0];
|
||||
|
||||
let d1 = computer.compute_depth(&point1);
|
||||
let d2 = computer.compute_depth(&point2);
|
||||
let d3 = computer.compute_depth(&point3);
|
||||
|
||||
assert!(d1 < d2);
|
||||
assert!(d2 < d3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hierarchy_levels() {
|
||||
let computer = DepthComputer::new(-1.0);
|
||||
|
||||
assert_eq!(computer.classify_level(0.3), HierarchyLevel::Root);
|
||||
assert_eq!(computer.classify_level(0.7), HierarchyLevel::High);
|
||||
assert_eq!(computer.classify_level(1.5), HierarchyLevel::Mid);
|
||||
assert_eq!(computer.classify_level(2.5), HierarchyLevel::Deep);
|
||||
assert_eq!(computer.classify_level(4.0), HierarchyLevel::VeryDeep);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_radius_for_depth() {
|
||||
let computer = DepthComputer::new(-1.0);
|
||||
|
||||
let radius = computer.radius_for_depth(1.0);
|
||||
let point = vec![radius, 0.0, 0.0, 0.0];
|
||||
let computed_depth = computer.compute_depth(&point);
|
||||
|
||||
assert!((computed_depth - 1.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalized_depth() {
|
||||
let computer = DepthComputer::new(-1.0);
|
||||
|
||||
let shallow = vec![0.1, 0.0, 0.0, 0.0];
|
||||
let deep = vec![0.95, 0.0, 0.0, 0.0];
|
||||
|
||||
let norm_shallow = computer.normalized_depth(&shallow);
|
||||
let norm_deep = computer.normalized_depth(&deep);
|
||||
|
||||
assert!(norm_shallow < 0.2);
|
||||
assert!(norm_deep > 0.5);
|
||||
assert!(norm_shallow <= 1.0);
|
||||
assert!(norm_deep <= 1.0);
|
||||
}
|
||||
}
|
||||
351
vendor/ruvector/crates/prime-radiant/src/hyperbolic/energy.rs
vendored
Normal file
351
vendor/ruvector/crates/prime-radiant/src/hyperbolic/energy.rs
vendored
Normal file
@@ -0,0 +1,351 @@
|
||||
//! Hyperbolic Energy Computation
|
||||
//!
|
||||
//! Structures for representing depth-weighted coherence energy.
|
||||
|
||||
use super::NodeId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Result of computing a weighted residual for a single edge
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WeightedResidual {
|
||||
/// Source node ID
|
||||
pub source_id: NodeId,
|
||||
/// Target node ID
|
||||
pub target_id: NodeId,
|
||||
/// Depth of source node
|
||||
pub source_depth: f32,
|
||||
/// Depth of target node
|
||||
pub target_depth: f32,
|
||||
/// Depth-based weight multiplier
|
||||
pub depth_weight: f32,
|
||||
/// Squared norm of the residual vector
|
||||
pub residual_norm_sq: f32,
|
||||
/// Base weight from edge definition
|
||||
pub base_weight: f32,
|
||||
/// Final weighted energy: base_weight * residual_norm_sq * depth_weight
|
||||
pub weighted_energy: f32,
|
||||
}
|
||||
|
||||
impl WeightedResidual {
|
||||
/// Get average depth of the edge
|
||||
pub fn avg_depth(&self) -> f32 {
|
||||
(self.source_depth + self.target_depth) / 2.0
|
||||
}
|
||||
|
||||
/// Get maximum depth
|
||||
pub fn max_depth(&self) -> f32 {
|
||||
self.source_depth.max(self.target_depth)
|
||||
}
|
||||
|
||||
/// Get unweighted energy (without depth scaling)
|
||||
pub fn unweighted_energy(&self) -> f32 {
|
||||
self.base_weight * self.residual_norm_sq
|
||||
}
|
||||
|
||||
/// Get depth contribution to energy
|
||||
pub fn depth_contribution(&self) -> f32 {
|
||||
self.weighted_energy - self.unweighted_energy()
|
||||
}
|
||||
}
|
||||
|
||||
/// Aggregated hyperbolic coherence energy
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HyperbolicEnergy {
|
||||
/// Total weighted energy across all edges
|
||||
pub total_energy: f32,
|
||||
/// Per-edge weighted residuals
|
||||
pub edge_energies: Vec<WeightedResidual>,
|
||||
/// Curvature used for computation
|
||||
pub curvature: f32,
|
||||
/// Maximum depth encountered
|
||||
pub max_depth: f32,
|
||||
/// Minimum depth encountered
|
||||
pub min_depth: f32,
|
||||
/// Number of edges
|
||||
pub num_edges: usize,
|
||||
}
|
||||
|
||||
impl HyperbolicEnergy {
|
||||
/// Create empty energy
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
total_energy: 0.0,
|
||||
edge_energies: vec![],
|
||||
curvature: -1.0,
|
||||
max_depth: 0.0,
|
||||
min_depth: 0.0,
|
||||
num_edges: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if coherent (energy below threshold)
|
||||
pub fn is_coherent(&self, threshold: f32) -> bool {
|
||||
self.total_energy < threshold
|
||||
}
|
||||
|
||||
/// Get average energy per edge
|
||||
pub fn avg_energy(&self) -> f32 {
|
||||
if self.num_edges == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_energy / self.num_edges as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Get average depth across all edges
|
||||
pub fn avg_depth(&self) -> f32 {
|
||||
if self.edge_energies.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let sum: f32 = self.edge_energies.iter().map(|e| e.avg_depth()).sum();
|
||||
sum / self.edge_energies.len() as f32
|
||||
}
|
||||
|
||||
/// Get total unweighted energy (without depth scaling)
|
||||
pub fn total_unweighted_energy(&self) -> f32 {
|
||||
self.edge_energies
|
||||
.iter()
|
||||
.map(|e| e.unweighted_energy())
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Get depth contribution ratio
|
||||
pub fn depth_contribution_ratio(&self) -> f32 {
|
||||
let unweighted = self.total_unweighted_energy();
|
||||
if unweighted < 1e-10 {
|
||||
return 1.0;
|
||||
}
|
||||
self.total_energy / unweighted
|
||||
}
|
||||
|
||||
/// Find highest energy edge
|
||||
pub fn highest_energy_edge(&self) -> Option<&WeightedResidual> {
|
||||
self.edge_energies
|
||||
.iter()
|
||||
.max_by(|a, b| a.weighted_energy.partial_cmp(&b.weighted_energy).unwrap())
|
||||
}
|
||||
|
||||
/// Find deepest edge
|
||||
pub fn deepest_edge(&self) -> Option<&WeightedResidual> {
|
||||
self.edge_energies
|
||||
.iter()
|
||||
.max_by(|a, b| a.avg_depth().partial_cmp(&b.avg_depth()).unwrap())
|
||||
}
|
||||
|
||||
/// Get edges above energy threshold
|
||||
pub fn edges_above_threshold(&self, threshold: f32) -> Vec<&WeightedResidual> {
|
||||
self.edge_energies
|
||||
.iter()
|
||||
.filter(|e| e.weighted_energy > threshold)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get edges at specific depth level
|
||||
pub fn edges_at_depth(&self, min_depth: f32, max_depth: f32) -> Vec<&WeightedResidual> {
|
||||
self.edge_energies
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
let avg = e.avg_depth();
|
||||
avg >= min_depth && avg < max_depth
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute energy distribution by depth buckets
|
||||
pub fn energy_by_depth_buckets(&self, num_buckets: usize) -> Vec<DepthBucketEnergy> {
|
||||
if self.edge_energies.is_empty() || num_buckets == 0 {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let depth_range = self.max_depth - self.min_depth;
|
||||
let bucket_size = if depth_range > 0.0 {
|
||||
depth_range / num_buckets as f32
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let mut buckets: Vec<DepthBucketEnergy> = (0..num_buckets)
|
||||
.map(|i| DepthBucketEnergy {
|
||||
bucket_index: i,
|
||||
depth_min: self.min_depth + i as f32 * bucket_size,
|
||||
depth_max: self.min_depth + (i + 1) as f32 * bucket_size,
|
||||
total_energy: 0.0,
|
||||
num_edges: 0,
|
||||
})
|
||||
.collect();
|
||||
|
||||
for edge in &self.edge_energies {
|
||||
let avg_depth = edge.avg_depth();
|
||||
let bucket_idx = ((avg_depth - self.min_depth) / bucket_size).floor() as usize;
|
||||
let bucket_idx = bucket_idx.min(num_buckets - 1);
|
||||
|
||||
buckets[bucket_idx].total_energy += edge.weighted_energy;
|
||||
buckets[bucket_idx].num_edges += 1;
|
||||
}
|
||||
|
||||
buckets
|
||||
}
|
||||
|
||||
/// Merge with another HyperbolicEnergy
|
||||
pub fn merge(&mut self, other: HyperbolicEnergy) {
|
||||
self.total_energy += other.total_energy;
|
||||
self.edge_energies.extend(other.edge_energies);
|
||||
self.max_depth = self.max_depth.max(other.max_depth);
|
||||
self.min_depth = self.min_depth.min(other.min_depth);
|
||||
self.num_edges += other.num_edges;
|
||||
}
|
||||
}
|
||||
|
||||
/// Energy aggregated by depth bucket
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DepthBucketEnergy {
|
||||
/// Bucket index (0 = shallowest)
|
||||
pub bucket_index: usize,
|
||||
/// Minimum depth in bucket
|
||||
pub depth_min: f32,
|
||||
/// Maximum depth in bucket
|
||||
pub depth_max: f32,
|
||||
/// Total energy in bucket
|
||||
pub total_energy: f32,
|
||||
/// Number of edges in bucket
|
||||
pub num_edges: usize,
|
||||
}
|
||||
|
||||
impl DepthBucketEnergy {
|
||||
/// Get average energy per edge in bucket
|
||||
pub fn avg_energy(&self) -> f32 {
|
||||
if self.num_edges == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.total_energy / self.num_edges as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Get bucket midpoint depth
|
||||
pub fn midpoint_depth(&self) -> f32 {
|
||||
(self.depth_min + self.depth_max) / 2.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_weighted_residual(
|
||||
source: NodeId,
|
||||
target: NodeId,
|
||||
source_depth: f32,
|
||||
target_depth: f32,
|
||||
energy: f32,
|
||||
) -> WeightedResidual {
|
||||
WeightedResidual {
|
||||
source_id: source,
|
||||
target_id: target,
|
||||
source_depth,
|
||||
target_depth,
|
||||
depth_weight: 1.0 + (source_depth + target_depth).ln().max(0.0) / 2.0,
|
||||
residual_norm_sq: energy / 2.0,
|
||||
base_weight: 1.0,
|
||||
weighted_energy: energy,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_energy() {
|
||||
let energy = HyperbolicEnergy::empty();
|
||||
assert_eq!(energy.total_energy, 0.0);
|
||||
assert_eq!(energy.num_edges, 0);
|
||||
assert!(energy.is_coherent(1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_energy_aggregation() {
|
||||
let edge1 = make_weighted_residual(1, 2, 0.5, 0.5, 0.1);
|
||||
let edge2 = make_weighted_residual(2, 3, 1.0, 1.5, 0.2);
|
||||
let edge3 = make_weighted_residual(3, 4, 2.0, 2.5, 0.3);
|
||||
|
||||
let energy = HyperbolicEnergy {
|
||||
total_energy: 0.6,
|
||||
edge_energies: vec![edge1, edge2, edge3],
|
||||
curvature: -1.0,
|
||||
max_depth: 2.5,
|
||||
min_depth: 0.5,
|
||||
num_edges: 3,
|
||||
};
|
||||
|
||||
assert_eq!(energy.num_edges, 3);
|
||||
assert!((energy.avg_energy() - 0.2).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_highest_energy_edge() {
|
||||
let edge1 = make_weighted_residual(1, 2, 0.5, 0.5, 0.1);
|
||||
let edge2 = make_weighted_residual(2, 3, 1.0, 1.5, 0.5); // Highest
|
||||
let edge3 = make_weighted_residual(3, 4, 2.0, 2.5, 0.2);
|
||||
|
||||
let energy = HyperbolicEnergy {
|
||||
total_energy: 0.8,
|
||||
edge_energies: vec![edge1, edge2, edge3],
|
||||
curvature: -1.0,
|
||||
max_depth: 2.5,
|
||||
min_depth: 0.5,
|
||||
num_edges: 3,
|
||||
};
|
||||
|
||||
let highest = energy.highest_energy_edge().unwrap();
|
||||
assert_eq!(highest.source_id, 2);
|
||||
assert_eq!(highest.target_id, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_buckets() {
|
||||
let edge1 = make_weighted_residual(1, 2, 0.5, 0.5, 0.1);
|
||||
let edge2 = make_weighted_residual(2, 3, 1.5, 1.5, 0.2);
|
||||
let edge3 = make_weighted_residual(3, 4, 2.5, 2.5, 0.3);
|
||||
|
||||
let energy = HyperbolicEnergy {
|
||||
total_energy: 0.6,
|
||||
edge_energies: vec![edge1, edge2, edge3],
|
||||
curvature: -1.0,
|
||||
max_depth: 2.5,
|
||||
min_depth: 0.5,
|
||||
num_edges: 3,
|
||||
};
|
||||
|
||||
let buckets = energy.energy_by_depth_buckets(2);
|
||||
assert_eq!(buckets.len(), 2);
|
||||
|
||||
// Shallow bucket should have edge1
|
||||
assert_eq!(buckets[0].num_edges, 1);
|
||||
// Deep bucket should have edge2 and edge3
|
||||
assert_eq!(buckets[1].num_edges, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge() {
|
||||
let mut energy1 = HyperbolicEnergy {
|
||||
total_energy: 0.5,
|
||||
edge_energies: vec![make_weighted_residual(1, 2, 0.5, 0.5, 0.5)],
|
||||
curvature: -1.0,
|
||||
max_depth: 0.5,
|
||||
min_depth: 0.5,
|
||||
num_edges: 1,
|
||||
};
|
||||
|
||||
let energy2 = HyperbolicEnergy {
|
||||
total_energy: 0.3,
|
||||
edge_energies: vec![make_weighted_residual(3, 4, 2.0, 2.0, 0.3)],
|
||||
curvature: -1.0,
|
||||
max_depth: 2.0,
|
||||
min_depth: 2.0,
|
||||
num_edges: 1,
|
||||
};
|
||||
|
||||
energy1.merge(energy2);
|
||||
|
||||
assert!((energy1.total_energy - 0.8).abs() < 0.01);
|
||||
assert_eq!(energy1.num_edges, 2);
|
||||
assert_eq!(energy1.max_depth, 2.0);
|
||||
assert_eq!(energy1.min_depth, 0.5);
|
||||
}
|
||||
}
|
||||
362
vendor/ruvector/crates/prime-radiant/src/hyperbolic/mod.rs
vendored
Normal file
362
vendor/ruvector/crates/prime-radiant/src/hyperbolic/mod.rs
vendored
Normal file
@@ -0,0 +1,362 @@
|
||||
//! Hyperbolic Coherence Module
|
||||
//!
|
||||
//! Hierarchy-aware coherence computation using hyperbolic geometry.
|
||||
//! Leverages `ruvector-hyperbolic-hnsw` for Poincare ball operations.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - Depth-aware energy weighting: deeper nodes get higher violation weights
|
||||
//! - Poincare ball projection for hierarchy-aware storage
|
||||
//! - Curvature-adaptive residual computation
|
||||
//! - Sharded hyperbolic index for scalability
|
||||
//!
|
||||
//! # Mathematical Foundation
|
||||
//!
|
||||
//! In the Poincare ball model, distance from origin correlates with hierarchy depth.
|
||||
//! Nodes closer to the boundary (|x| -> 1) are "deeper" in the hierarchy.
|
||||
//!
|
||||
//! Energy weighting: E_weighted = w_e * |r_e|^2 * depth_weight(e)
|
||||
//! where depth_weight = 1 + ln(max(avg_depth, 1))
|
||||
|
||||
mod adapter;
|
||||
mod config;
|
||||
mod depth;
|
||||
mod energy;
|
||||
|
||||
pub use adapter::HyperbolicAdapter;
|
||||
pub use config::HyperbolicCoherenceConfig;
|
||||
pub use depth::{DepthComputer, HierarchyLevel};
|
||||
pub use energy::{HyperbolicEnergy, WeightedResidual};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Node identifier type alias
|
||||
pub type NodeId = u64;
|
||||
|
||||
/// Edge identifier type alias
|
||||
pub type EdgeId = u64;
|
||||
|
||||
/// Result type for hyperbolic coherence operations
|
||||
pub type Result<T> = std::result::Result<T, HyperbolicCoherenceError>;
|
||||
|
||||
/// Errors that can occur in hyperbolic coherence computation
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum HyperbolicCoherenceError {
|
||||
/// Node not found in the index
|
||||
#[error("Node not found: {0}")]
|
||||
NodeNotFound(NodeId),
|
||||
|
||||
/// Invalid vector dimension
|
||||
#[error("Dimension mismatch: expected {expected}, got {actual}")]
|
||||
DimensionMismatch { expected: usize, actual: usize },
|
||||
|
||||
/// Curvature out of valid range
|
||||
#[error("Invalid curvature: {0} (must be negative)")]
|
||||
InvalidCurvature(f32),
|
||||
|
||||
/// Projection failed (vector outside ball)
|
||||
#[error("Projection failed: vector norm {0} exceeds ball radius")]
|
||||
ProjectionFailed(f32),
|
||||
|
||||
/// Underlying HNSW error
|
||||
#[error("HNSW error: {0}")]
|
||||
HnswError(String),
|
||||
|
||||
/// Empty collection
|
||||
#[error("Empty collection")]
|
||||
EmptyCollection,
|
||||
}
|
||||
|
||||
/// Main hyperbolic coherence engine
|
||||
///
|
||||
/// Computes hierarchy-aware coherence energy using the Poincare ball model.
|
||||
/// Deeper nodes (further from origin) receive higher weights for violations,
|
||||
/// encoding the intuition that deeper hierarchical nodes should be more consistent.
|
||||
#[derive(Debug)]
|
||||
pub struct HyperbolicCoherence {
|
||||
/// Configuration
|
||||
config: HyperbolicCoherenceConfig,
|
||||
/// Adapter to underlying hyperbolic HNSW
|
||||
adapter: HyperbolicAdapter,
|
||||
/// Depth computer
|
||||
depth: DepthComputer,
|
||||
/// Node states (node_id -> state vector)
|
||||
node_states: HashMap<NodeId, Vec<f32>>,
|
||||
/// Node depths (cached)
|
||||
node_depths: HashMap<NodeId, f32>,
|
||||
}
|
||||
|
||||
impl HyperbolicCoherence {
|
||||
/// Create a new hyperbolic coherence engine
|
||||
pub fn new(config: HyperbolicCoherenceConfig) -> Self {
|
||||
let adapter = HyperbolicAdapter::new(config.clone());
|
||||
let depth = DepthComputer::new(config.curvature);
|
||||
|
||||
Self {
|
||||
config,
|
||||
adapter,
|
||||
depth,
|
||||
node_states: HashMap::new(),
|
||||
node_depths: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(HyperbolicCoherenceConfig::default())
|
||||
}
|
||||
|
||||
/// Insert a node state
|
||||
pub fn insert_node(&mut self, node_id: NodeId, state: Vec<f32>) -> Result<()> {
|
||||
// Validate dimension
|
||||
if !self.node_states.is_empty() {
|
||||
let expected_dim = self.config.dimension;
|
||||
if state.len() != expected_dim {
|
||||
return Err(HyperbolicCoherenceError::DimensionMismatch {
|
||||
expected: expected_dim,
|
||||
actual: state.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Project to Poincare ball
|
||||
let projected = self.adapter.project_to_ball(&state)?;
|
||||
|
||||
// Compute and cache depth
|
||||
let depth = self.depth.compute_depth(&projected);
|
||||
self.node_depths.insert(node_id, depth);
|
||||
|
||||
// Store in adapter and local cache
|
||||
self.adapter.insert(node_id, projected.clone())?;
|
||||
self.node_states.insert(node_id, projected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update a node state
|
||||
pub fn update_node(&mut self, node_id: NodeId, state: Vec<f32>) -> Result<()> {
|
||||
if !self.node_states.contains_key(&node_id) {
|
||||
return Err(HyperbolicCoherenceError::NodeNotFound(node_id));
|
||||
}
|
||||
|
||||
// Project and update
|
||||
let projected = self.adapter.project_to_ball(&state)?;
|
||||
let depth = self.depth.compute_depth(&projected);
|
||||
|
||||
self.node_depths.insert(node_id, depth);
|
||||
self.adapter.update(node_id, projected.clone())?;
|
||||
self.node_states.insert(node_id, projected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get node state
|
||||
pub fn get_node(&self, node_id: NodeId) -> Option<&Vec<f32>> {
|
||||
self.node_states.get(&node_id)
|
||||
}
|
||||
|
||||
/// Get node depth
|
||||
pub fn get_depth(&self, node_id: NodeId) -> Option<f32> {
|
||||
self.node_depths.get(&node_id).copied()
|
||||
}
|
||||
|
||||
/// Compute depth-weighted energy for an edge
|
||||
///
|
||||
/// The energy is weighted by the average depth of the connected nodes.
|
||||
/// Deeper nodes receive higher violation weights.
|
||||
pub fn weighted_edge_energy(
|
||||
&self,
|
||||
source_id: NodeId,
|
||||
target_id: NodeId,
|
||||
residual: &[f32],
|
||||
base_weight: f32,
|
||||
) -> Result<WeightedResidual> {
|
||||
let source_depth = self
|
||||
.node_depths
|
||||
.get(&source_id)
|
||||
.ok_or(HyperbolicCoherenceError::NodeNotFound(source_id))?;
|
||||
let target_depth = self
|
||||
.node_depths
|
||||
.get(&target_id)
|
||||
.ok_or(HyperbolicCoherenceError::NodeNotFound(target_id))?;
|
||||
|
||||
let avg_depth = (source_depth + target_depth) / 2.0;
|
||||
|
||||
// Depth weight: higher for deeper nodes
|
||||
let depth_weight = self.config.depth_weight_fn(avg_depth);
|
||||
|
||||
// Residual norm squared
|
||||
let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
|
||||
|
||||
let weighted_energy = base_weight * residual_norm_sq * depth_weight;
|
||||
|
||||
Ok(WeightedResidual {
|
||||
source_id,
|
||||
target_id,
|
||||
source_depth: *source_depth,
|
||||
target_depth: *target_depth,
|
||||
depth_weight,
|
||||
residual_norm_sq,
|
||||
base_weight,
|
||||
weighted_energy,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute total hyperbolic energy for a set of edges
|
||||
pub fn compute_total_energy(
|
||||
&self,
|
||||
edges: &[(NodeId, NodeId, Vec<f32>, f32)], // (source, target, residual, weight)
|
||||
) -> Result<HyperbolicEnergy> {
|
||||
if edges.is_empty() {
|
||||
return Ok(HyperbolicEnergy::empty());
|
||||
}
|
||||
|
||||
let mut edge_energies = Vec::with_capacity(edges.len());
|
||||
let mut total_energy = 0.0f32;
|
||||
let mut max_depth = 0.0f32;
|
||||
let mut min_depth = f32::MAX;
|
||||
|
||||
for (source, target, residual, weight) in edges {
|
||||
let weighted = self.weighted_edge_energy(*source, *target, residual, *weight)?;
|
||||
total_energy += weighted.weighted_energy;
|
||||
max_depth = max_depth.max(weighted.source_depth.max(weighted.target_depth));
|
||||
min_depth = min_depth.min(weighted.source_depth.min(weighted.target_depth));
|
||||
edge_energies.push(weighted);
|
||||
}
|
||||
|
||||
Ok(HyperbolicEnergy {
|
||||
total_energy,
|
||||
edge_energies,
|
||||
curvature: self.config.curvature,
|
||||
max_depth,
|
||||
min_depth,
|
||||
num_edges: edges.len(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Find similar nodes in hyperbolic space
|
||||
pub fn find_similar(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
|
||||
let projected = self.adapter.project_to_ball(query)?;
|
||||
self.adapter.search(&projected, k)
|
||||
}
|
||||
|
||||
/// Get hierarchy level for a node based on depth
|
||||
pub fn hierarchy_level(&self, node_id: NodeId) -> Result<HierarchyLevel> {
|
||||
let depth = self
|
||||
.node_depths
|
||||
.get(&node_id)
|
||||
.ok_or(HyperbolicCoherenceError::NodeNotFound(node_id))?;
|
||||
|
||||
Ok(self.depth.classify_level(*depth))
|
||||
}
|
||||
|
||||
/// Compute Frechet mean of a set of nodes
|
||||
pub fn frechet_mean(&self, node_ids: &[NodeId]) -> Result<Vec<f32>> {
|
||||
if node_ids.is_empty() {
|
||||
return Err(HyperbolicCoherenceError::EmptyCollection);
|
||||
}
|
||||
|
||||
let states: Vec<&Vec<f32>> = node_ids
|
||||
.iter()
|
||||
.filter_map(|id| self.node_states.get(id))
|
||||
.collect();
|
||||
|
||||
if states.is_empty() {
|
||||
return Err(HyperbolicCoherenceError::EmptyCollection);
|
||||
}
|
||||
|
||||
self.adapter.frechet_mean(&states)
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &HyperbolicCoherenceConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get number of nodes
|
||||
pub fn num_nodes(&self) -> usize {
|
||||
self.node_states.len()
|
||||
}
|
||||
|
||||
/// Get curvature
|
||||
pub fn curvature(&self) -> f32 {
|
||||
self.config.curvature
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_basic_coherence() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let mut coherence = HyperbolicCoherence::new(config);
|
||||
|
||||
// Insert nodes
|
||||
coherence.insert_node(1, vec![0.1, 0.1, 0.1, 0.1]).unwrap();
|
||||
coherence.insert_node(2, vec![0.2, 0.2, 0.2, 0.2]).unwrap();
|
||||
coherence.insert_node(3, vec![0.5, 0.5, 0.5, 0.5]).unwrap();
|
||||
|
||||
assert_eq!(coherence.num_nodes(), 3);
|
||||
|
||||
// Node 3 should be deeper (further from origin)
|
||||
let depth1 = coherence.get_depth(1).unwrap();
|
||||
let depth3 = coherence.get_depth(3).unwrap();
|
||||
assert!(depth3 > depth1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weighted_energy() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let mut coherence = HyperbolicCoherence::new(config);
|
||||
|
||||
coherence.insert_node(1, vec![0.1, 0.1, 0.1, 0.1]).unwrap();
|
||||
coherence.insert_node(2, vec![0.5, 0.5, 0.5, 0.5]).unwrap();
|
||||
|
||||
let residual = vec![0.1, 0.1, 0.1, 0.1];
|
||||
let weighted = coherence
|
||||
.weighted_edge_energy(1, 2, &residual, 1.0)
|
||||
.unwrap();
|
||||
|
||||
assert!(weighted.weighted_energy > 0.0);
|
||||
assert!(weighted.depth_weight > 1.0); // Should have depth scaling
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hierarchy_levels() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
dimension: 4,
|
||||
curvature: -1.0,
|
||||
..Default::default()
|
||||
};
|
||||
let mut coherence = HyperbolicCoherence::new(config);
|
||||
|
||||
coherence
|
||||
.insert_node(1, vec![0.05, 0.05, 0.05, 0.05])
|
||||
.unwrap();
|
||||
coherence.insert_node(2, vec![0.7, 0.7, 0.0, 0.0]).unwrap();
|
||||
|
||||
let level1 = coherence.hierarchy_level(1).unwrap();
|
||||
let level2 = coherence.hierarchy_level(2).unwrap();
|
||||
|
||||
// Node 1 should be at higher level (closer to root)
|
||||
assert!(matches!(
|
||||
level1,
|
||||
HierarchyLevel::Root | HierarchyLevel::High
|
||||
));
|
||||
// Node 2 should be deeper
|
||||
assert!(matches!(
|
||||
level2,
|
||||
HierarchyLevel::Deep | HierarchyLevel::VeryDeep
|
||||
));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user