633 lines
20 KiB
Rust
633 lines
20 KiB
Rust
//! WebAssembly Bindings for Hyperbolic HNSW
|
|
//!
|
|
//! This module provides JavaScript/TypeScript bindings for hyperbolic embeddings
|
|
//! and HNSW search in the browser and Node.js environments.
|
|
//!
|
|
//! # Usage in JavaScript
|
|
//!
|
|
//! ```javascript
|
|
//! import init, {
|
|
//! HyperbolicIndex,
|
|
//! poincareDistance,
|
|
//! mobiusAdd,
|
|
//! expMap,
|
|
//! logMap
|
|
//! } from 'ruvector-hyperbolic-hnsw-wasm';
|
|
//!
|
|
//! // Initialize WASM module
|
|
//! await init();
|
|
//!
|
|
//! // Create index
|
|
//! const index = new HyperbolicIndex(16, 1.0); // ef_search=16, curvature=1.0
|
|
//!
|
|
//! // Insert vectors
|
|
//! index.insert(new Float32Array([0.1, 0.2, 0.3]));
|
|
//! index.insert(new Float32Array([-0.1, 0.15, 0.25]));
|
|
//!
|
|
//! // Search
|
|
//! const results = index.search(new Float32Array([0.15, 0.1, 0.2]), 2);
|
|
//! console.log(results); // [{id: 0, distance: 0.123}, ...]
|
|
//!
|
|
//! // Use low-level math operations
|
|
//! const d = poincareDistance(
|
|
//! new Float32Array([0.3, 0.2]),
|
|
//! new Float32Array([-0.1, 0.4]),
|
|
//! 1.0
|
|
//! );
|
|
//! ```
|
|
|
|
use ruvector_hyperbolic_hnsw::{
|
|
exp_map, frechet_mean, log_map, mobius_add, mobius_scalar_mult, poincare_distance,
|
|
project_to_ball, HyperbolicHnsw, HyperbolicHnswConfig, PoincareConfig, ShardedHyperbolicHnsw,
|
|
TangentCache, DEFAULT_CURVATURE, EPS,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use wasm_bindgen::prelude::*;
|
|
|
|
#[cfg(feature = "console_error_panic_hook")]
|
|
fn set_panic_hook() {
|
|
console_error_panic_hook::set_once();
|
|
}
|
|
|
|
/// Initialize the WASM module
|
|
#[wasm_bindgen(start)]
|
|
pub fn init() {
|
|
#[cfg(feature = "console_error_panic_hook")]
|
|
set_panic_hook();
|
|
}
|
|
|
|
// ============================================================================
|
|
// Low-Level Math Operations
|
|
// ============================================================================
|
|
|
|
/// Compute Poincaré distance between two points
|
|
///
|
|
/// @param u - First point (Float32Array)
|
|
/// @param v - Second point (Float32Array)
|
|
/// @param curvature - Curvature parameter (positive)
|
|
/// @returns Geodesic distance in hyperbolic space
|
|
#[wasm_bindgen(js_name = poincareDistance)]
|
|
pub fn wasm_poincare_distance(u: &[f32], v: &[f32], curvature: f32) -> f32 {
|
|
poincare_distance(u, v, curvature)
|
|
}
|
|
|
|
/// Möbius addition in Poincaré ball
|
|
///
|
|
/// Computes the hyperbolic analog of vector addition: x ⊕_c y
|
|
///
|
|
/// @param x - First point (Float32Array)
|
|
/// @param y - Second point (Float32Array)
|
|
/// @param curvature - Curvature parameter
|
|
/// @returns Result of Möbius addition (Float32Array)
|
|
#[wasm_bindgen(js_name = mobiusAdd)]
|
|
pub fn wasm_mobius_add(x: &[f32], y: &[f32], curvature: f32) -> Vec<f32> {
|
|
mobius_add(x, y, curvature)
|
|
}
|
|
|
|
/// Möbius scalar multiplication
|
|
///
|
|
/// Computes r ⊗_c x for scalar r and point x
|
|
///
|
|
/// @param r - Scalar value
|
|
/// @param x - Point in Poincaré ball (Float32Array)
|
|
/// @param curvature - Curvature parameter
|
|
/// @returns Scaled point (Float32Array)
|
|
#[wasm_bindgen(js_name = mobiusScalarMult)]
|
|
pub fn wasm_mobius_scalar_mult(r: f32, x: &[f32], curvature: f32) -> Vec<f32> {
|
|
mobius_scalar_mult(r, x, curvature)
|
|
}
|
|
|
|
/// Exponential map at point p
|
|
///
|
|
/// Maps a tangent vector v at point p to the Poincaré ball
|
|
///
|
|
/// @param v - Tangent vector (Float32Array)
|
|
/// @param p - Base point (Float32Array)
|
|
/// @param curvature - Curvature parameter
|
|
/// @returns Point on the manifold (Float32Array)
|
|
#[wasm_bindgen(js_name = expMap)]
|
|
pub fn wasm_exp_map(v: &[f32], p: &[f32], curvature: f32) -> Vec<f32> {
|
|
exp_map(v, p, curvature)
|
|
}
|
|
|
|
/// Logarithmic map at point p
|
|
///
|
|
/// Maps a point y to the tangent space at point p
|
|
///
|
|
/// @param y - Target point (Float32Array)
|
|
/// @param p - Base point (Float32Array)
|
|
/// @param curvature - Curvature parameter
|
|
/// @returns Tangent vector at p (Float32Array)
|
|
#[wasm_bindgen(js_name = logMap)]
|
|
pub fn wasm_log_map(y: &[f32], p: &[f32], curvature: f32) -> Vec<f32> {
|
|
log_map(y, p, curvature)
|
|
}
|
|
|
|
/// Project point to Poincaré ball
|
|
///
|
|
/// Ensures ||x|| < 1/√c - eps for numerical stability
|
|
///
|
|
/// @param x - Point to project (Float32Array)
|
|
/// @param curvature - Curvature parameter
|
|
/// @returns Projected point (Float32Array)
|
|
#[wasm_bindgen(js_name = projectToBall)]
|
|
pub fn wasm_project_to_ball(x: &[f32], curvature: f32) -> Vec<f32> {
|
|
project_to_ball(x, curvature, EPS)
|
|
}
|
|
|
|
/// Compute Fréchet mean (hyperbolic centroid)
|
|
///
|
|
/// @param points - Array of points as flat Float32Array
|
|
/// @param dim - Dimension of each point
|
|
/// @param curvature - Curvature parameter
|
|
/// @returns Centroid point (Float32Array)
|
|
#[wasm_bindgen(js_name = frechetMean)]
|
|
pub fn wasm_frechet_mean(points: &[f32], dim: usize, curvature: f32) -> Result<Vec<f32>, JsValue> {
|
|
if points.is_empty() || dim == 0 {
|
|
return Err(JsValue::from_str("Empty points or invalid dimension"));
|
|
}
|
|
|
|
let point_vecs: Vec<Vec<f32>> = points.chunks(dim).map(|c| c.to_vec()).collect();
|
|
|
|
let point_refs: Vec<&[f32]> = point_vecs.iter().map(|v| v.as_slice()).collect();
|
|
|
|
let config = PoincareConfig::with_curvature(curvature)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))?;
|
|
|
|
frechet_mean(&point_refs, None, &config).map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
// ============================================================================
|
|
// Search Result Type
|
|
// ============================================================================
|
|
|
|
/// Search result from hyperbolic HNSW
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[wasm_bindgen]
|
|
pub struct WasmSearchResult {
|
|
/// Vector ID
|
|
pub id: usize,
|
|
/// Hyperbolic distance to query
|
|
pub distance: f32,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmSearchResult {
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(id: usize, distance: f32) -> Self {
|
|
Self { id, distance }
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Hyperbolic HNSW Index
|
|
// ============================================================================
|
|
|
|
/// Hyperbolic HNSW Index for hierarchy-aware vector search
|
|
///
|
|
/// @example
|
|
/// ```javascript
|
|
/// const index = new HyperbolicIndex(16, 1.0);
|
|
/// index.insert(new Float32Array([0.1, 0.2]));
|
|
/// index.insert(new Float32Array([-0.1, 0.3]));
|
|
/// const results = index.search(new Float32Array([0.05, 0.25]), 2);
|
|
/// ```
|
|
#[wasm_bindgen]
|
|
pub struct HyperbolicIndex {
|
|
inner: HyperbolicHnsw,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl HyperbolicIndex {
|
|
/// Create a new hyperbolic HNSW index
|
|
///
|
|
/// @param ef_search - Size of dynamic candidate list during search (default: 50)
|
|
/// @param curvature - Curvature parameter for Poincaré ball (default: 1.0)
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(ef_search: Option<usize>, curvature: Option<f32>) -> Self {
|
|
let mut config = HyperbolicHnswConfig::default();
|
|
config.ef_search = ef_search.unwrap_or(50);
|
|
config.curvature = curvature.unwrap_or(DEFAULT_CURVATURE);
|
|
|
|
Self {
|
|
inner: HyperbolicHnsw::new(config),
|
|
}
|
|
}
|
|
|
|
/// Create with custom configuration
|
|
///
|
|
/// @param config - JSON configuration object
|
|
#[wasm_bindgen(js_name = fromConfig)]
|
|
pub fn from_config(config: JsValue) -> Result<HyperbolicIndex, JsValue> {
|
|
let config: HyperbolicHnswConfig =
|
|
serde_wasm_bindgen::from_value(config).map_err(|e| JsValue::from_str(&e.to_string()))?;
|
|
Ok(Self {
|
|
inner: HyperbolicHnsw::new(config),
|
|
})
|
|
}
|
|
|
|
/// Insert a vector into the index
|
|
///
|
|
/// @param vector - Vector to insert (Float32Array)
|
|
/// @returns ID of inserted vector
|
|
#[wasm_bindgen]
|
|
pub fn insert(&mut self, vector: &[f32]) -> Result<usize, JsValue> {
|
|
self.inner
|
|
.insert(vector.to_vec())
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Insert batch of vectors
|
|
///
|
|
/// @param vectors - Flat array of vectors
|
|
/// @param dim - Dimension of each vector
|
|
/// @returns Array of inserted IDs
|
|
#[wasm_bindgen(js_name = insertBatch)]
|
|
pub fn insert_batch(&mut self, vectors: &[f32], dim: usize) -> Result<Vec<usize>, JsValue> {
|
|
let vecs: Vec<Vec<f32>> = vectors.chunks(dim).map(|c| c.to_vec()).collect();
|
|
self.inner
|
|
.insert_batch(vecs)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Search for k nearest neighbors
|
|
///
|
|
/// @param query - Query vector (Float32Array)
|
|
/// @param k - Number of neighbors to return
|
|
/// @returns Array of search results as JSON
|
|
#[wasm_bindgen]
|
|
pub fn search(&self, query: &[f32], k: usize) -> Result<JsValue, JsValue> {
|
|
let results = self
|
|
.inner
|
|
.search(query, k)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))?;
|
|
|
|
let wasm_results: Vec<WasmSearchResult> = results
|
|
.into_iter()
|
|
.map(|r| WasmSearchResult::new(r.id, r.distance))
|
|
.collect();
|
|
|
|
serde_wasm_bindgen::to_value(&wasm_results).map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Search with tangent space pruning (optimized)
|
|
///
|
|
/// @param query - Query vector (Float32Array)
|
|
/// @param k - Number of neighbors to return
|
|
/// @returns Array of search results as JSON
|
|
#[wasm_bindgen(js_name = searchWithPruning)]
|
|
pub fn search_with_pruning(&self, query: &[f32], k: usize) -> Result<JsValue, JsValue> {
|
|
let results = self
|
|
.inner
|
|
.search_with_pruning(query, k)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))?;
|
|
|
|
let wasm_results: Vec<WasmSearchResult> = results
|
|
.into_iter()
|
|
.map(|r| WasmSearchResult::new(r.id, r.distance))
|
|
.collect();
|
|
|
|
serde_wasm_bindgen::to_value(&wasm_results).map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Build tangent cache for optimized search
|
|
#[wasm_bindgen(js_name = buildTangentCache)]
|
|
pub fn build_tangent_cache(&mut self) -> Result<(), JsValue> {
|
|
self.inner
|
|
.build_tangent_cache()
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Get number of vectors in index
|
|
#[wasm_bindgen]
|
|
pub fn len(&self) -> usize {
|
|
self.inner.len()
|
|
}
|
|
|
|
/// Check if index is empty
|
|
#[wasm_bindgen(js_name = isEmpty)]
|
|
pub fn is_empty(&self) -> bool {
|
|
self.inner.is_empty()
|
|
}
|
|
|
|
/// Get vector dimension
|
|
#[wasm_bindgen]
|
|
pub fn dim(&self) -> Option<usize> {
|
|
self.inner.dim()
|
|
}
|
|
|
|
/// Update curvature parameter
|
|
///
|
|
/// @param curvature - New curvature value (must be positive)
|
|
#[wasm_bindgen(js_name = setCurvature)]
|
|
pub fn set_curvature(&mut self, curvature: f32) -> Result<(), JsValue> {
|
|
self.inner
|
|
.set_curvature(curvature)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Get a vector by ID
|
|
///
|
|
/// @param id - Vector ID
|
|
/// @returns Vector data or null if not found
|
|
#[wasm_bindgen(js_name = getVector)]
|
|
pub fn get_vector(&self, id: usize) -> Option<Vec<f32>> {
|
|
self.inner.get_vector(id).map(|v| v.to_vec())
|
|
}
|
|
|
|
/// Export index configuration as JSON
|
|
#[wasm_bindgen(js_name = exportConfig)]
|
|
pub fn export_config(&self) -> Result<JsValue, JsValue> {
|
|
serde_wasm_bindgen::to_value(&self.inner.config)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Sharded Index
|
|
// ============================================================================
|
|
|
|
/// Sharded Hyperbolic HNSW with per-shard curvature
|
|
///
|
|
/// @example
|
|
/// ```javascript
|
|
/// const manager = new ShardedIndex(1.0);
|
|
/// manager.insertToShard("taxonomy", new Float32Array([0.1, 0.2]), 0);
|
|
/// manager.insertToShard("taxonomy", new Float32Array([0.3, 0.1]), 3);
|
|
/// manager.updateCurvature("taxonomy", 0.5);
|
|
/// const results = manager.search(new Float32Array([0.2, 0.15]), 5);
|
|
/// ```
|
|
#[wasm_bindgen]
|
|
pub struct ShardedIndex {
|
|
inner: ShardedHyperbolicHnsw,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl ShardedIndex {
|
|
/// Create a new sharded index
|
|
///
|
|
/// @param default_curvature - Default curvature for new shards
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(default_curvature: f32) -> Self {
|
|
Self {
|
|
inner: ShardedHyperbolicHnsw::new(default_curvature),
|
|
}
|
|
}
|
|
|
|
/// Insert vector with automatic shard assignment
|
|
///
|
|
/// @param vector - Vector to insert (Float32Array)
|
|
/// @param depth - Optional hierarchy depth for shard assignment
|
|
/// @returns Global vector ID
|
|
#[wasm_bindgen]
|
|
pub fn insert(&mut self, vector: &[f32], depth: Option<usize>) -> Result<usize, JsValue> {
|
|
self.inner
|
|
.insert(vector.to_vec(), depth)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Insert vector into specific shard
|
|
///
|
|
/// @param shard_id - Target shard ID
|
|
/// @param vector - Vector to insert (Float32Array)
|
|
/// @returns Global vector ID
|
|
#[wasm_bindgen(js_name = insertToShard)]
|
|
pub fn insert_to_shard(&mut self, shard_id: &str, vector: &[f32]) -> Result<usize, JsValue> {
|
|
self.inner
|
|
.insert_to_shard(shard_id, vector.to_vec())
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Search across all shards
|
|
///
|
|
/// @param query - Query vector (Float32Array)
|
|
/// @param k - Number of neighbors to return
|
|
/// @returns Array of search results as JSON
|
|
#[wasm_bindgen]
|
|
pub fn search(&self, query: &[f32], k: usize) -> Result<JsValue, JsValue> {
|
|
let results = self
|
|
.inner
|
|
.search(query, k)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))?;
|
|
|
|
let wasm_results: Vec<WasmSearchResult> = results
|
|
.into_iter()
|
|
.map(|(id, r)| WasmSearchResult::new(id, r.distance))
|
|
.collect();
|
|
|
|
serde_wasm_bindgen::to_value(&wasm_results).map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Update curvature for a shard
|
|
///
|
|
/// @param shard_id - Shard ID
|
|
/// @param curvature - New curvature value
|
|
#[wasm_bindgen(js_name = updateCurvature)]
|
|
pub fn update_curvature(&mut self, shard_id: &str, curvature: f32) -> Result<(), JsValue> {
|
|
self.inner
|
|
.update_curvature(shard_id, curvature)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Set canary curvature for A/B testing
|
|
///
|
|
/// @param shard_id - Shard ID
|
|
/// @param curvature - Canary curvature value
|
|
/// @param traffic - Percentage of traffic for canary (0-100)
|
|
#[wasm_bindgen(js_name = setCanaryCurvature)]
|
|
pub fn set_canary_curvature(&mut self, shard_id: &str, curvature: f32, traffic: u8) {
|
|
self.inner.registry.set_canary(shard_id, curvature, traffic);
|
|
}
|
|
|
|
/// Promote canary to production
|
|
///
|
|
/// @param shard_id - Shard ID
|
|
#[wasm_bindgen(js_name = promoteCanary)]
|
|
pub fn promote_canary(&mut self, shard_id: &str) -> Result<(), JsValue> {
|
|
if let Some(shard_curv) = self.inner.registry.shards.get_mut(shard_id) {
|
|
shard_curv.promote_canary();
|
|
}
|
|
self.inner
|
|
.reload_curvatures()
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Rollback canary
|
|
///
|
|
/// @param shard_id - Shard ID
|
|
#[wasm_bindgen(js_name = rollbackCanary)]
|
|
pub fn rollback_canary(&mut self, shard_id: &str) {
|
|
if let Some(shard_curv) = self.inner.registry.shards.get_mut(shard_id) {
|
|
shard_curv.rollback_canary();
|
|
}
|
|
}
|
|
|
|
/// Build tangent caches for all shards
|
|
#[wasm_bindgen(js_name = buildCaches)]
|
|
pub fn build_caches(&mut self) -> Result<(), JsValue> {
|
|
self.inner
|
|
.build_caches()
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Get total vector count
|
|
#[wasm_bindgen]
|
|
pub fn len(&self) -> usize {
|
|
self.inner.len()
|
|
}
|
|
|
|
/// Check if empty
|
|
#[wasm_bindgen(js_name = isEmpty)]
|
|
pub fn is_empty(&self) -> bool {
|
|
self.inner.is_empty()
|
|
}
|
|
|
|
/// Get number of shards
|
|
#[wasm_bindgen(js_name = numShards)]
|
|
pub fn num_shards(&self) -> usize {
|
|
self.inner.num_shards()
|
|
}
|
|
|
|
/// Get curvature registry as JSON
|
|
#[wasm_bindgen(js_name = getRegistry)]
|
|
pub fn get_registry(&self) -> Result<JsValue, JsValue> {
|
|
serde_wasm_bindgen::to_value(&self.inner.registry)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Tangent Cache Operations
|
|
// ============================================================================
|
|
|
|
/// Tangent space cache for fast pruning
|
|
#[wasm_bindgen]
|
|
pub struct WasmTangentCache {
|
|
inner: TangentCache,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WasmTangentCache {
|
|
/// Create tangent cache from points
|
|
///
|
|
/// @param points - Flat array of points
|
|
/// @param dim - Dimension of each point
|
|
/// @param curvature - Curvature parameter
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(points: &[f32], dim: usize, curvature: f32) -> Result<WasmTangentCache, JsValue> {
|
|
let point_vecs: Vec<Vec<f32>> = points.chunks(dim).map(|c| c.to_vec()).collect();
|
|
let indices: Vec<usize> = (0..point_vecs.len()).collect();
|
|
|
|
let cache = TangentCache::new(&point_vecs, &indices, curvature)
|
|
.map_err(|e| JsValue::from_str(&e.to_string()))?;
|
|
|
|
Ok(Self { inner: cache })
|
|
}
|
|
|
|
/// Get centroid of the cache
|
|
#[wasm_bindgen]
|
|
pub fn centroid(&self) -> Vec<f32> {
|
|
self.inner.centroid.clone()
|
|
}
|
|
|
|
/// Get tangent coordinates for a query
|
|
///
|
|
/// @param query - Query point (Float32Array)
|
|
/// @returns Tangent coordinates (Float32Array)
|
|
#[wasm_bindgen(js_name = queryTangent)]
|
|
pub fn query_tangent(&self, query: &[f32]) -> Vec<f32> {
|
|
self.inner.query_tangent(query)
|
|
}
|
|
|
|
/// Compute tangent distance squared (for fast pruning)
|
|
///
|
|
/// @param query_tangent - Query in tangent space (Float32Array)
|
|
/// @param idx - Index of cached point
|
|
/// @returns Squared distance in tangent space
|
|
#[wasm_bindgen(js_name = tangentDistanceSquared)]
|
|
pub fn tangent_distance_squared(&self, query_tangent: &[f32], idx: usize) -> f32 {
|
|
self.inner.tangent_distance_squared(query_tangent, idx)
|
|
}
|
|
|
|
/// Get number of cached points
|
|
#[wasm_bindgen]
|
|
pub fn len(&self) -> usize {
|
|
self.inner.len()
|
|
}
|
|
|
|
/// Get dimension
|
|
#[wasm_bindgen]
|
|
pub fn dim(&self) -> usize {
|
|
self.inner.dim()
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Utility Functions
|
|
// ============================================================================
|
|
|
|
/// Get library version
|
|
#[wasm_bindgen(js_name = getVersion)]
|
|
pub fn get_version() -> String {
|
|
ruvector_hyperbolic_hnsw::VERSION.to_string()
|
|
}
|
|
|
|
/// Get default curvature value
|
|
#[wasm_bindgen(js_name = getDefaultCurvature)]
|
|
pub fn get_default_curvature() -> f32 {
|
|
DEFAULT_CURVATURE
|
|
}
|
|
|
|
/// Get numerical stability epsilon
|
|
#[wasm_bindgen(js_name = getEps)]
|
|
pub fn get_eps() -> f32 {
|
|
EPS
|
|
}
|
|
|
|
/// Compute vector norm
|
|
#[wasm_bindgen(js_name = vectorNorm)]
|
|
pub fn vector_norm(x: &[f32]) -> f32 {
|
|
ruvector_hyperbolic_hnsw::norm(x)
|
|
}
|
|
|
|
/// Compute squared vector norm
|
|
#[wasm_bindgen(js_name = vectorNormSquared)]
|
|
pub fn vector_norm_squared(x: &[f32]) -> f32 {
|
|
ruvector_hyperbolic_hnsw::norm_squared(x)
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use wasm_bindgen_test::*;
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_poincare_distance() {
|
|
let u = vec![0.3, 0.2];
|
|
let v = vec![-0.1, 0.4];
|
|
let d = wasm_poincare_distance(&u, &v, 1.0);
|
|
assert!(d > 0.0);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_mobius_add() {
|
|
let x = vec![0.2, 0.1];
|
|
let y = vec![0.1, -0.1];
|
|
let z = wasm_mobius_add(&x, &y, 1.0);
|
|
assert_eq!(z.len(), 2);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_hyperbolic_index() {
|
|
let mut index = HyperbolicIndex::new(Some(16), Some(1.0));
|
|
|
|
index.insert(&[0.1, 0.2, 0.3]).unwrap();
|
|
index.insert(&[-0.1, 0.15, 0.25]).unwrap();
|
|
index.insert(&[0.2, -0.1, 0.1]).unwrap();
|
|
|
|
assert_eq!(index.len(), 3);
|
|
assert!(!index.is_empty());
|
|
assert_eq!(index.dim(), Some(3));
|
|
}
|
|
}
|