Files
wifi-densepose/vendor/ruvector/examples/rvf/examples/hyperbolic_taxonomy.rs

323 lines
12 KiB
Rust

//! Exotic Capability: Hierarchy-Aware Search
//!
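//! Demonstrates RVF for taxonomy node embeddings where parent-child
//! relationships are captured in embedding space. The vectors are synthetic
//! stand-ins for Poincare embeddings: each child is a small perturbation of
//! its parent and the store compares them with L2 distance, so hierarchically
//! related nodes cluster closer together.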
//!
//! Taxonomy structure:
//! 5 top-level categories, 4 children each, 3 grandchildren each = 85 nodes
//!
//! Features:
//! - Taxonomy tree with hierarchy-aware embeddings
//! - Metadata: node_name, depth, parent_name
//! - Query for hierarchically related nodes
//! - Filter by depth to get specific levels
//!
//! RVF segments used: VEC_SEG, MANIFEST_SEG
//!
//! Run: cargo run --example hyperbolic_taxonomy
use rvf_runtime::{
FilterExpr, MetadataEntry, MetadataValue, QueryOptions, RvfOptions, RvfStore, SearchResult,
};
use rvf_runtime::filter::FilterValue;
use rvf_runtime::options::DistanceMetric;
use tempfile::TempDir;
/// Deterministic pseudo-random vector built from a 64-bit linear congruential
/// generator.
///
/// The generator state starts at `seed + 1` (wrapping) and is advanced once per
/// component. Each component is the state shifted right by 33 bits, divided by
/// `u32::MAX`, minus 0.5. Because only 31 bits survive the shift, every
/// component lies in roughly `[-0.5, 0.0]`.
///
/// Returns a vector with exactly `dim` components; the same `(dim, seed)` pair
/// always yields the same vector.
fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
    let mut state = seed.wrapping_add(1);
    (0..dim)
        .map(|_| {
            // Advance the LCG first, then derive the component from the new state.
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f32) / (u32::MAX as f32) - 0.5
        })
        .collect()
}
/// Derive a child embedding that sits near its parent.
///
/// A noise vector of the parent's dimensionality is generated from
/// `child_seed`, multiplied component-wise by `perturbation_scale`, and added
/// to `parent`. Callers pick the scale (a smaller scale keeps the child closer
/// to the parent), which is how this example imitates a Poincare-style
/// embedding where children cluster around their parents.
fn child_vector(parent: &[f32], child_seed: u64, perturbation_scale: f32) -> Vec<f32> {
    let jitter = random_vector(parent.len(), child_seed);
    let mut child = Vec::with_capacity(parent.len());
    for (base, offset) in parent.iter().zip(jitter.iter()) {
        child.push(base + offset * perturbation_scale);
    }
    child
}
/// A taxonomy node with its name, depth, parent name, and vector ID.
///
/// `main` creates one of these per embedding and keeps them in a `Vec` in
/// creation order, so a node's `id` is also its index in that list.
struct TaxonomyNode {
// Vector ID under which the node's embedding is ingested into the store.
id: u64,
// Display name; `main` names children "<parent name>_<suffix>".
name: String,
// Level in the tree: 0 for top-level categories, 1 for children, 2 for grandchildren.
depth: u64,
// Name of the parent node; top-level categories use the literal "ROOT".
parent_name: String,
}
/// Runs the taxonomy demo end to end.
///
/// Steps:
/// 1. Build a three-level taxonomy whose embeddings are perturbed copies of
///    their parent's embedding.
/// 2. Ingest every node into a temporary RVF store with name/depth/parent
///    metadata.
/// 3. Run unfiltered and depth-filtered nearest-neighbour queries.
/// 4. Print the tree and a summary.
///
/// The tree shape (node counts and index offsets) is derived from the three
/// name arrays below, so editing an array keeps the ID layout, the sanity
/// assertion, and the tree printout consistent.
///
/// # Panics
///
/// Panics if the temp dir or store cannot be created, if ingest or any query
/// fails, or if one of the demo's assertions does not hold.
fn main() {
    println!("=== Hyperbolic Taxonomy: Hierarchy-Aware Search ===\n");
    let dim = 64;

    // ====================================================================
    // 1. Build taxonomy tree
    // ====================================================================
    println!("--- 1. Build Taxonomy Tree ---");
    let top_categories = ["Science", "Technology", "Arts", "Medicine", "Law"];
    let child_suffixes = ["Theory", "Applied", "History", "Modern"];
    let grandchild_suffixes = ["Intro", "Advanced", "Research"];

    // Nodes are appended level by level, so the ID/index layout is:
    //   [0, level0_count)                              -> top-level categories
    //   [level1_base, level1_base + level1_count)      -> children
    //   [level2_base, level2_base + level2_count)      -> grandchildren
    let level0_count = top_categories.len();
    let level1_count = level0_count * child_suffixes.len();
    let level2_count = level1_count * grandchild_suffixes.len();
    let level1_base = level0_count;
    let level2_base = level0_count + level1_count;
    let expected_nodes = level0_count + level1_count + level2_count;

    let mut nodes: Vec<TaxonomyNode> = Vec::with_capacity(expected_nodes);
    let mut vectors_list: Vec<Vec<f32>> = Vec::with_capacity(expected_nodes);
    let mut next_id: u64 = 0;

    // Level 0: top-level categories get independent pseudo-random vectors.
    for cat in &top_categories {
        vectors_list.push(random_vector(dim, next_id * 1000 + 42));
        nodes.push(TaxonomyNode {
            id: next_id,
            name: cat.to_string(),
            depth: 0,
            parent_name: "ROOT".to_string(),
        });
        next_id += 1;
    }

    // Level 1: one child per suffix under every top-level category.
    for (cat_idx, cat) in top_categories.iter().enumerate() {
        for (child_idx, suffix) in child_suffixes.iter().enumerate() {
            let seed = next_id * 1000 + 100 + child_idx as u64;
            // The parent's vector lives at index `cat_idx`; `child_vector`
            // returns an owned Vec, so the borrow ends before the push below.
            let vec = child_vector(&vectors_list[cat_idx], seed, 0.15);
            vectors_list.push(vec);
            nodes.push(TaxonomyNode {
                id: next_id,
                name: format!("{}_{}", cat, suffix),
                depth: 1,
                parent_name: cat.to_string(),
            });
            next_id += 1;
        }
    }

    // Level 2: one grandchild per suffix under every level-1 node, with a
    // smaller perturbation so leaves stay close to their parent.
    for l1_idx in 0..level1_count {
        let parent_idx = level1_base + l1_idx;
        let parent_name = nodes[parent_idx].name.clone();
        for (gc_idx, suffix) in grandchild_suffixes.iter().enumerate() {
            let seed = next_id * 1000 + 200 + gc_idx as u64;
            let vec = child_vector(&vectors_list[parent_idx], seed, 0.08);
            vectors_list.push(vec);
            nodes.push(TaxonomyNode {
                id: next_id,
                name: format!("{}_{}", parent_name, suffix),
                depth: 2,
                parent_name: parent_name.clone(),
            });
            next_id += 1;
        }
    }

    let total_nodes = nodes.len();
    println!(" Taxonomy structure:");
    println!(" Level 0 (roots): {} categories", level0_count);
    println!(" Level 1 (children): {} subcategories", level1_count);
    println!(" Level 2 (grandchildren): {} leaf nodes", level2_count);
    println!(" Total nodes: {}", total_nodes);
    assert_eq!(total_nodes, level0_count + level1_count + level2_count);

    // ====================================================================
    // 2. Create store and ingest
    // ====================================================================
    println!("\n--- 2. Create Taxonomy Store ---");
    let tmp_dir = TempDir::new().expect("failed to create temp dir");
    let store_path = tmp_dir.path().join("taxonomy.rvf");
    let options = RvfOptions {
        dimension: dim as u16,
        metric: DistanceMetric::L2,
        ..Default::default()
    };
    let mut store = RvfStore::create(&store_path, options).expect("failed to create store");

    let vec_refs: Vec<&[f32]> = vectors_list.iter().map(|v| v.as_slice()).collect();
    let ids: Vec<u64> = nodes.iter().map(|n| n.id).collect();

    // Metadata: node_name (0), depth (1), parent_name (2)
    let mut metadata = Vec::with_capacity(total_nodes * 3);
    for node in &nodes {
        metadata.push(MetadataEntry {
            field_id: 0,
            value: MetadataValue::String(node.name.clone()),
        });
        metadata.push(MetadataEntry {
            field_id: 1,
            value: MetadataValue::U64(node.depth),
        });
        metadata.push(MetadataEntry {
            field_id: 2,
            value: MetadataValue::String(node.parent_name.clone()),
        });
    }

    let ingest = store
        .ingest_batch(&vec_refs, &ids, Some(&metadata))
        .expect("ingest failed");
    println!(" Ingested {} taxonomy nodes ({} dims)", ingest.accepted, dim);

    // ====================================================================
    // 3. Query for nodes similar to a root concept
    // ====================================================================
    println!("\n--- 3. Hierarchy-Aware Query ---");
    // Use "Science" (ID 0) as query vector
    let science_vec = &vectors_list[0];
    let k = 15;
    let results = store
        .query(science_vec, k, &QueryOptions::default())
        .expect("query failed");
    println!(" Query: nodes similar to 'Science' (top-{}):", k);
    print_taxonomy_results(&results, &nodes);

    // Check that Science itself is nearest (distance ~0)
    assert_eq!(results[0].id, 0, "Science should be nearest to itself");
    println!("\n Science's children should dominate top results.");

    // Count how many results share "Science" lineage. Node IDs double as
    // indices into `nodes` because both were assigned sequentially above.
    let science_related = results
        .iter()
        .filter(|r| {
            let node = &nodes[r.id as usize];
            node.name.starts_with("Science") || node.parent_name == "Science"
        })
        .count();
    println!(" Science-related in top-{}: {} nodes", k, science_related);

    // ====================================================================
    // 4. Filter by depth: roots only
    // ====================================================================
    println!("\n--- 4. Filter by Depth (Roots Only) ---");
    let filter_roots = FilterExpr::Eq(1, FilterValue::U64(0));
    let opts_roots = QueryOptions {
        filter: Some(filter_roots),
        ..Default::default()
    };
    let results_roots = store
        .query(science_vec, 5, &opts_roots)
        .expect("filtered query failed");
    println!(" Root nodes (depth=0):");
    print_taxonomy_results(&results_roots, &nodes);
    for r in &results_roots {
        assert_eq!(nodes[r.id as usize].depth, 0);
    }
    println!(" All results verified: depth == 0.");

    // ====================================================================
    // 5. Filter by depth: leaf nodes only
    // ====================================================================
    println!("\n--- 5. Filter by Depth (Leaves Only) ---");
    let filter_leaves = FilterExpr::Eq(1, FilterValue::U64(2));
    let opts_leaves = QueryOptions {
        filter: Some(filter_leaves),
        ..Default::default()
    };
    let results_leaves = store
        .query(science_vec, 10, &opts_leaves)
        .expect("filtered query failed");
    println!(" Leaf nodes (depth=2) nearest to 'Science':");
    print_taxonomy_results(&results_leaves, &nodes);
    for r in &results_leaves {
        assert_eq!(nodes[r.id as usize].depth, 2);
    }
    println!(" All results verified: depth == 2.");

    // ====================================================================
    // 6. Cross-branch query
    // ====================================================================
    println!("\n--- 6. Cross-Branch Query ---");
    // Query with a "Technology" vector to see cross-branch similarities
    let tech_vec = &vectors_list[1]; // Technology is ID 1
    let results_tech = store
        .query(tech_vec, 10, &QueryOptions::default())
        .expect("query failed");
    println!(" Nodes nearest to 'Technology':");
    print_taxonomy_results(&results_tech, &nodes);

    // ====================================================================
    // 7. Print taxonomy tree
    // ====================================================================
    println!("\n--- 7. Taxonomy Tree ---");
    let children_per_category = child_suffixes.len();
    let grandchildren_per_child = grandchild_suffixes.len();
    for cat_idx in 0..level0_count {
        let cat = &nodes[cat_idx];
        println!(" {} (ID={}, depth={})", cat.name, cat.id, cat.depth);
        for child_idx in 0..children_per_category {
            // Position of this child among all level-1 nodes.
            let l1_offset = cat_idx * children_per_category + child_idx;
            let child = &nodes[level1_base + l1_offset];
            println!(" +-- {} (ID={})", child.name, child.id);
            for gc_idx in 0..grandchildren_per_child {
                let l2_idx = level2_base + l1_offset * grandchildren_per_child + gc_idx;
                let grandchild = &nodes[l2_idx];
                println!(" +-- {} (ID={})", grandchild.name, grandchild.id);
            }
        }
    }

    // ====================================================================
    // Summary
    // ====================================================================
    println!("\n=== Hyperbolic Taxonomy Summary ===\n");
    println!(" Total nodes: {}", total_nodes);
    println!(" Embedding dims: {}", dim);
    println!(" Root query results: {}", results_roots.len());
    println!(" Leaf query results: {}", results_leaves.len());
    println!(" Science-related: {} in top-{}", science_related, k);

    store.close().expect("failed to close store");
    println!("\nDone.");
}
/// Prints search hits as a right-aligned table with columns ID, Distance,
/// Depth, Name and Parent.
///
/// Each hit's `id` is used directly as an index into `nodes` to look up the
/// depth, name and parent name, so the call panics if an ID is not a valid
/// index into that slice.
fn print_taxonomy_results(results: &[SearchResult], nodes: &[TaxonomyNode]) {
    // Column titles, then a dashed rule using the same column widths.
    println!(
        " {:>4} {:>12} {:>5} {:>30} {:>20}",
        "ID", "Distance", "Depth", "Name", "Parent"
    );
    println!(
        " {:->4} {:->12} {:->5} {:->30} {:->20}",
        "", "", "", "", ""
    );

    results.iter().for_each(|hit| {
        let entry = &nodes[hit.id as usize];
        println!(
            " {:>4} {:>12.6} {:>5} {:>30} {:>20}",
            hit.id, hit.distance, entry.depth, entry.name, entry.parent_name
        );
    });
}