Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
169
vendor/ruvector/crates/prime-radiant/src/hyperbolic/config.rs
vendored
Normal file
169
vendor/ruvector/crates/prime-radiant/src/hyperbolic/config.rs
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
//! Hyperbolic Coherence Configuration
|
||||
//!
|
||||
//! Configuration for hyperbolic coherence computation.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for hyperbolic coherence computation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HyperbolicCoherenceConfig {
|
||||
/// State vector dimension
|
||||
pub dimension: usize,
|
||||
|
||||
/// Curvature of the hyperbolic space (must be negative)
|
||||
/// Typical values: -1.0 (unit curvature), -0.5 (flatter), -2.0 (more curved)
|
||||
pub curvature: f32,
|
||||
|
||||
/// Epsilon for numerical stability (projection boundary)
|
||||
pub epsilon: f32,
|
||||
|
||||
/// Maximum number of iterations for Frechet mean computation
|
||||
pub frechet_max_iters: usize,
|
||||
|
||||
/// Convergence threshold for Frechet mean
|
||||
pub frechet_tolerance: f32,
|
||||
|
||||
/// Depth weight function type
|
||||
pub depth_weight_type: DepthWeightType,
|
||||
|
||||
/// HNSW M parameter (max connections per node)
|
||||
pub hnsw_m: usize,
|
||||
|
||||
/// HNSW ef_construction parameter
|
||||
pub hnsw_ef_construction: usize,
|
||||
|
||||
/// Enable sharding for large collections
|
||||
pub enable_sharding: bool,
|
||||
|
||||
/// Default shard curvature
|
||||
pub default_shard_curvature: f32,
|
||||
}
|
||||
|
||||
impl Default for HyperbolicCoherenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
dimension: 64,
|
||||
curvature: -1.0,
|
||||
epsilon: 1e-5,
|
||||
frechet_max_iters: 100,
|
||||
frechet_tolerance: 1e-6,
|
||||
depth_weight_type: DepthWeightType::Logarithmic,
|
||||
hnsw_m: 16,
|
||||
hnsw_ef_construction: 200,
|
||||
enable_sharding: false,
|
||||
default_shard_curvature: -1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HyperbolicCoherenceConfig {
|
||||
/// Create a configuration for small collections (< 10K nodes)
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
dimension: 64,
|
||||
curvature: -1.0,
|
||||
hnsw_m: 8,
|
||||
hnsw_ef_construction: 100,
|
||||
enable_sharding: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration for large collections (> 100K nodes)
|
||||
pub fn large() -> Self {
|
||||
Self {
|
||||
dimension: 64,
|
||||
curvature: -1.0,
|
||||
hnsw_m: 32,
|
||||
hnsw_ef_construction: 400,
|
||||
enable_sharding: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate configuration
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.curvature >= 0.0 {
|
||||
return Err(format!(
|
||||
"Curvature must be negative, got {}",
|
||||
self.curvature
|
||||
));
|
||||
}
|
||||
if self.dimension == 0 {
|
||||
return Err("Dimension must be positive".to_string());
|
||||
}
|
||||
if self.epsilon <= 0.0 {
|
||||
return Err("Epsilon must be positive".to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute depth weight using configured function type
|
||||
pub fn depth_weight_fn(&self, depth: f32) -> f32 {
|
||||
self.depth_weight_type.compute(depth)
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of depth weighting function
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DepthWeightType {
|
||||
/// Constant weight (no depth scaling)
|
||||
Constant,
|
||||
/// Linear: 1 + depth
|
||||
Linear,
|
||||
/// Logarithmic: 1 + ln(max(depth, 1))
|
||||
Logarithmic,
|
||||
/// Quadratic: 1 + depth^2
|
||||
Quadratic,
|
||||
/// Exponential: e^(depth * scale)
|
||||
Exponential,
|
||||
}
|
||||
|
||||
impl Default for DepthWeightType {
|
||||
fn default() -> Self {
|
||||
Self::Logarithmic
|
||||
}
|
||||
}
|
||||
|
||||
impl DepthWeightType {
|
||||
/// Compute depth weight
|
||||
pub fn compute(&self, depth: f32) -> f32 {
|
||||
match self {
|
||||
Self::Constant => 1.0,
|
||||
Self::Linear => 1.0 + depth,
|
||||
Self::Logarithmic => 1.0 + depth.max(1.0).ln(),
|
||||
Self::Quadratic => 1.0 + depth * depth,
|
||||
Self::Exponential => (depth * 0.5).exp().min(10.0), // Capped at 10x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = HyperbolicCoherenceConfig::default();
|
||||
assert_eq!(config.curvature, -1.0);
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_curvature() {
|
||||
let config = HyperbolicCoherenceConfig {
|
||||
curvature: 1.0, // Invalid - must be negative
|
||||
..Default::default()
|
||||
};
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_weights() {
|
||||
assert_eq!(DepthWeightType::Constant.compute(5.0), 1.0);
|
||||
assert_eq!(DepthWeightType::Linear.compute(5.0), 6.0);
|
||||
|
||||
let log_weight = DepthWeightType::Logarithmic.compute(2.718281828);
|
||||
assert!((log_weight - 2.0).abs() < 0.01);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user