692 lines
23 KiB
Rust
692 lines
23 KiB
Rust
//! Lightweight HNSW Index for iOS/Browser WASM
//!
//! A simplified HNSW implementation optimized for mobile/browser deployment.
//! Provides O(log n) approximate nearest neighbor search.
//!
//! Based on the paper: "Efficient and Robust Approximate Nearest Neighbor Search
//! Using Hierarchical Navigable Small World Graphs" (Malkov & Yashunin).
|
|
|
|
use crate::distance::{distance, DistanceMetric};
|
|
use std::collections::{BinaryHeap, HashSet};
|
|
use std::vec::Vec;
|
|
use core::cmp::Ordering;
|
|
|
|
/// Tuning parameters for an HNSW graph.
#[derive(Clone, Debug)]
pub struct HnswConfig {
    /// Max connections per node on layers above 0 (the `M` parameter).
    pub m: usize,
    /// Max connections per node on layer 0 (usually `2 * m`).
    pub m_max_0: usize,
    /// Candidate-list width used while inserting.
    pub ef_construction: usize,
    /// Candidate-list width used while querying.
    pub ef_search: usize,
    /// Level multiplier, `1 / ln(m)`; scales the random layer drawn for each
    /// new node.
    pub level_mult: f32,
}

impl Default for HnswConfig {
    /// Defaults: `m = 16`, `m_max_0 = 32`, `ef_construction = 100`,
    /// `ef_search = 50`.
    ///
    /// `level_mult` is derived from `m` rather than hard-coded to a rounded
    /// literal, so the default agrees exactly with the `1 / ln(m)` formula.
    fn default() -> Self {
        let m: usize = 16;
        Self {
            m,
            m_max_0: 2 * m,
            ef_construction: 100,
            ef_search: 50,
            level_mult: 1.0 / (m as f32).ln(),
        }
    }
}
|
|
|
|
/// One stored vector together with its adjacency lists.
#[derive(Debug, Clone)]
struct HnswNode {
    /// Caller-supplied identifier of the vector.
    id: u64,
    /// The vector's components.
    vector: Vec<f32>,
    /// Neighbour ids, one list per layer (`connections[layer]`).
    connections: Vec<Vec<u64>>,
    /// Layer assigned to this node when it was inserted.
    level: usize,
}
|
|
|
|
/// Search candidate: a node id paired with its distance to the query.
///
/// The ordering is *reversed* on distance so that `BinaryHeap<Candidate>`
/// (a max-heap) pops the closest candidate first. All four comparison traits
/// are driven by the single total order defined in [`Ord::cmp`], so `==`
/// agrees with `cmp(..) == Equal` as the `Eq`/`Ord` contracts require.
#[derive(Clone, Debug)]
struct Candidate {
    id: u64,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Smaller distance compares as *greater* (min-heap behaviour in
    /// `BinaryHeap`). A NaN distance is treated as farther than any number,
    /// so it is popped last instead of comparing `Equal` to everything.
    /// Ties on distance are broken by id (smaller id first) to keep the
    /// order total and deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        let (theirs, ours) = (other.distance, self.distance);
        theirs
            .partial_cmp(&ours)
            // Only reached when at least one side is NaN.
            .unwrap_or_else(|| theirs.is_nan().cmp(&ours.is_nan()))
            .then_with(|| other.id.cmp(&self.id))
    }
}
|
|
|
|
/// Lightweight HNSW index
|
|
pub struct HnswIndex {
|
|
/// All nodes
|
|
nodes: Vec<HnswNode>,
|
|
/// ID to node index mapping
|
|
id_to_idx: std::collections::HashMap<u64, usize>,
|
|
/// Entry point (topmost node)
|
|
entry_point: Option<usize>,
|
|
/// Maximum level in the graph
|
|
max_level: usize,
|
|
/// Configuration
|
|
config: HnswConfig,
|
|
/// Distance metric
|
|
metric: DistanceMetric,
|
|
/// Dimension
|
|
dim: usize,
|
|
/// Random seed for level generation
|
|
seed: u32,
|
|
}
|
|
|
|
impl HnswIndex {
|
|
/// Create a new HNSW index
|
|
pub fn new(dim: usize, metric: DistanceMetric, config: HnswConfig) -> Self {
|
|
Self {
|
|
nodes: Vec::new(),
|
|
id_to_idx: std::collections::HashMap::new(),
|
|
entry_point: None,
|
|
max_level: 0,
|
|
config,
|
|
metric,
|
|
dim,
|
|
seed: 12345,
|
|
}
|
|
}
|
|
|
|
/// Create with default config
|
|
pub fn with_defaults(dim: usize, metric: DistanceMetric) -> Self {
|
|
Self::new(dim, metric, HnswConfig::default())
|
|
}
|
|
|
|
/// Generate random level for a new node
|
|
fn random_level(&mut self) -> usize {
|
|
// LCG random number generator
|
|
self.seed = self.seed.wrapping_mul(1103515245).wrapping_add(12345);
|
|
let rand = (self.seed >> 16) as f32 / 32768.0;
|
|
|
|
let level = (-rand.ln() * self.config.level_mult).floor() as usize;
|
|
level.min(16) // Cap at 16 levels
|
|
}
|
|
|
|
/// Insert a vector into the index
|
|
pub fn insert(&mut self, id: u64, vector: Vec<f32>) -> bool {
|
|
if vector.len() != self.dim {
|
|
return false;
|
|
}
|
|
|
|
if self.id_to_idx.contains_key(&id) {
|
|
return false; // Already exists
|
|
}
|
|
|
|
let level = self.random_level();
|
|
let node_idx = self.nodes.len();
|
|
|
|
// Create node with empty connections
|
|
let mut node = HnswNode {
|
|
id,
|
|
vector,
|
|
connections: vec![Vec::new(); level + 1],
|
|
level,
|
|
};
|
|
|
|
if let Some(ep_idx) = self.entry_point {
|
|
// Find entry point at the top level
|
|
let mut curr_idx = ep_idx;
|
|
let mut curr_dist = self.distance_to_node(node_idx, curr_idx, &node.vector);
|
|
|
|
// Traverse from top to insertion level
|
|
for lc in (level + 1..=self.max_level).rev() {
|
|
let mut changed = true;
|
|
while changed {
|
|
changed = false;
|
|
if let Some(connections) = self.nodes.get(curr_idx).map(|n| n.connections.get(lc).cloned()).flatten() {
|
|
for &neighbor_id in &connections {
|
|
if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
|
|
let d = self.distance_to_node(node_idx, neighbor_idx, &node.vector);
|
|
if d < curr_dist {
|
|
curr_dist = d;
|
|
curr_idx = neighbor_idx;
|
|
changed = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Insert at each level
|
|
for lc in (0..=level.min(self.max_level)).rev() {
|
|
let neighbors = self.search_layer(&node.vector, curr_idx, self.config.ef_construction, lc);
|
|
|
|
// Select M best neighbors
|
|
let m_max = if lc == 0 { self.config.m_max_0 } else { self.config.m };
|
|
let selected: Vec<u64> = neighbors.iter()
|
|
.take(m_max)
|
|
.map(|c| c.id)
|
|
.collect();
|
|
|
|
node.connections[lc] = selected.clone();
|
|
|
|
// Add bidirectional connections
|
|
for &neighbor_id in &selected {
|
|
if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
|
|
if let Some(neighbor_node) = self.nodes.get_mut(neighbor_idx) {
|
|
if lc < neighbor_node.connections.len() {
|
|
neighbor_node.connections[lc].push(id);
|
|
|
|
// Prune if too many connections
|
|
if neighbor_node.connections[lc].len() > m_max {
|
|
let query = &neighbor_node.vector.clone();
|
|
self.prune_connections(neighbor_idx, lc, m_max, query);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !neighbors.is_empty() {
|
|
curr_idx = self.id_to_idx.get(&neighbors[0].id).copied().unwrap_or(curr_idx);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add node
|
|
self.nodes.push(node);
|
|
self.id_to_idx.insert(id, node_idx);
|
|
|
|
// Update entry point if this is higher level
|
|
if level > self.max_level || self.entry_point.is_none() {
|
|
self.max_level = level;
|
|
self.entry_point = Some(node_idx);
|
|
}
|
|
|
|
true
|
|
}
|
|
|
|
/// Search for k nearest neighbors
|
|
pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
|
|
self.search_with_ef(query, k, self.config.ef_search)
|
|
}
|
|
|
|
/// Search with custom ef parameter
|
|
pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> {
|
|
if query.len() != self.dim || self.entry_point.is_none() {
|
|
return vec![];
|
|
}
|
|
|
|
let ep_idx = self.entry_point.unwrap();
|
|
|
|
// Find entry point by traversing from top
|
|
let mut curr_idx = ep_idx;
|
|
let mut curr_dist = distance(query, &self.nodes[curr_idx].vector, self.metric);
|
|
|
|
for lc in (1..=self.max_level).rev() {
|
|
let mut changed = true;
|
|
while changed {
|
|
changed = false;
|
|
if let Some(connections) = self.nodes.get(curr_idx).and_then(|n| n.connections.get(lc)) {
|
|
for &neighbor_id in connections {
|
|
if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
|
|
let d = distance(query, &self.nodes[neighbor_idx].vector, self.metric);
|
|
if d < curr_dist {
|
|
curr_dist = d;
|
|
curr_idx = neighbor_idx;
|
|
changed = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Search at layer 0
|
|
let results = self.search_layer(query, curr_idx, ef, 0);
|
|
|
|
results.into_iter()
|
|
.take(k)
|
|
.map(|c| (c.id, c.distance))
|
|
.collect()
|
|
}
|
|
|
|
/// Search within a specific layer
|
|
fn search_layer(&self, query: &[f32], entry_idx: usize, ef: usize, layer: usize) -> Vec<Candidate> {
|
|
let entry_id = self.nodes[entry_idx].id;
|
|
let entry_dist = distance(query, &self.nodes[entry_idx].vector, self.metric);
|
|
|
|
let mut visited: HashSet<u64> = HashSet::new();
|
|
let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
|
|
let mut results: Vec<Candidate> = Vec::new();
|
|
|
|
visited.insert(entry_id);
|
|
candidates.push(Candidate { id: entry_id, distance: entry_dist });
|
|
results.push(Candidate { id: entry_id, distance: entry_dist });
|
|
|
|
while let Some(current) = candidates.pop() {
|
|
// Stop if current is worse than worst in results
|
|
if results.len() >= ef {
|
|
let worst_dist = results.iter().map(|c| c.distance).fold(f32::NEG_INFINITY, f32::max);
|
|
if current.distance > worst_dist {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Explore neighbors
|
|
if let Some(&curr_idx) = self.id_to_idx.get(¤t.id) {
|
|
if let Some(connections) = self.nodes.get(curr_idx).and_then(|n| n.connections.get(layer)) {
|
|
for &neighbor_id in connections {
|
|
if visited.insert(neighbor_id) {
|
|
if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
|
|
let d = distance(query, &self.nodes[neighbor_idx].vector, self.metric);
|
|
|
|
let should_add = results.len() < ef || {
|
|
let worst = results.iter().map(|c| c.distance).fold(f32::NEG_INFINITY, f32::max);
|
|
d < worst
|
|
};
|
|
|
|
if should_add {
|
|
candidates.push(Candidate { id: neighbor_id, distance: d });
|
|
results.push(Candidate { id: neighbor_id, distance: d });
|
|
|
|
// Keep only ef best
|
|
if results.len() > ef {
|
|
results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
|
|
results.truncate(ef);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
|
|
results
|
|
}
|
|
|
|
/// Prune connections to keep only the best
|
|
fn prune_connections(&mut self, node_idx: usize, layer: usize, max_conn: usize, query: &[f32]) {
|
|
// First, collect connection info without holding mutable borrow
|
|
let connections_to_score: Vec<u64> = if let Some(node) = self.nodes.get(node_idx) {
|
|
if layer < node.connections.len() {
|
|
node.connections[layer].clone()
|
|
} else {
|
|
return;
|
|
}
|
|
} else {
|
|
return;
|
|
};
|
|
|
|
// Score connections
|
|
let mut candidates: Vec<(u64, f32)> = connections_to_score
|
|
.iter()
|
|
.filter_map(|&id| {
|
|
self.id_to_idx.get(&id)
|
|
.and_then(|&idx| self.nodes.get(idx))
|
|
.map(|n| (id, distance(query, &n.vector, self.metric)))
|
|
})
|
|
.collect();
|
|
|
|
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
|
let pruned: Vec<u64> = candidates.into_iter()
|
|
.take(max_conn)
|
|
.map(|(id, _)| id)
|
|
.collect();
|
|
|
|
// Now update the connections
|
|
if let Some(node) = self.nodes.get_mut(node_idx) {
|
|
if layer < node.connections.len() {
|
|
node.connections[layer] = pruned;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Helper to calculate distance to a node
|
|
fn distance_to_node(&self, _new_idx: usize, existing_idx: usize, new_vector: &[f32]) -> f32 {
|
|
if let Some(node) = self.nodes.get(existing_idx) {
|
|
distance(new_vector, &node.vector, self.metric)
|
|
} else {
|
|
f32::MAX
|
|
}
|
|
}
|
|
|
|
/// Get number of vectors in the index
|
|
pub fn len(&self) -> usize {
|
|
self.nodes.len()
|
|
}
|
|
|
|
/// Check if empty
|
|
pub fn is_empty(&self) -> bool {
|
|
self.nodes.is_empty()
|
|
}
|
|
|
|
/// Get vector by ID
|
|
pub fn get(&self, id: u64) -> Option<&[f32]> {
|
|
self.id_to_idx.get(&id)
|
|
.and_then(|&idx| self.nodes.get(idx))
|
|
.map(|n| n.vector.as_slice())
|
|
}
|
|
|
|
// ============================================
|
|
// Persistence
|
|
// ============================================
|
|
|
|
/// Serialize the HNSW index to bytes
|
|
///
|
|
/// Format:
|
|
/// - Header (32 bytes): dim, metric, m, m_max_0, ef_construction, ef_search, max_level, node_count
|
|
/// - For each node: id (8), level (4), vector (dim*4), connections per layer
|
|
pub fn serialize(&self) -> Vec<u8> {
|
|
let mut bytes = Vec::new();
|
|
|
|
// Header
|
|
bytes.extend_from_slice(&(self.dim as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&(self.metric as u8).to_le_bytes());
|
|
bytes.extend_from_slice(&[0u8; 3]); // padding
|
|
bytes.extend_from_slice(&(self.config.m as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&(self.config.m_max_0 as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&(self.config.ef_construction as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&(self.config.ef_search as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&(self.max_level as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&(self.nodes.len() as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&self.entry_point.map(|e| e as u32).unwrap_or(u32::MAX).to_le_bytes());
|
|
|
|
// Nodes
|
|
for node in &self.nodes {
|
|
// Node header: id, level
|
|
bytes.extend_from_slice(&node.id.to_le_bytes());
|
|
bytes.extend_from_slice(&(node.level as u32).to_le_bytes());
|
|
|
|
// Vector
|
|
for &v in &node.vector {
|
|
bytes.extend_from_slice(&v.to_le_bytes());
|
|
}
|
|
|
|
// Connections: count per layer, then connection IDs
|
|
bytes.extend_from_slice(&(node.connections.len() as u32).to_le_bytes());
|
|
for layer_conns in &node.connections {
|
|
bytes.extend_from_slice(&(layer_conns.len() as u32).to_le_bytes());
|
|
for &conn_id in layer_conns {
|
|
bytes.extend_from_slice(&conn_id.to_le_bytes());
|
|
}
|
|
}
|
|
}
|
|
|
|
bytes
|
|
}
|
|
|
|
/// Deserialize HNSW index from bytes
|
|
pub fn deserialize(bytes: &[u8]) -> Option<Self> {
|
|
if bytes.len() < 36 {
|
|
return None;
|
|
}
|
|
|
|
let mut offset = 0;
|
|
|
|
// Read header
|
|
let dim = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
|
|
let metric = DistanceMetric::from_u8(bytes[4]);
|
|
offset = 8;
|
|
|
|
let m = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
let m_max_0 = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
let ef_construction = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
let ef_search = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
let max_level = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
let node_count = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
let entry_point_raw = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]);
|
|
offset += 4;
|
|
let entry_point = if entry_point_raw == u32::MAX { None } else { Some(entry_point_raw as usize) };
|
|
|
|
let config = HnswConfig {
|
|
m,
|
|
m_max_0,
|
|
ef_construction,
|
|
ef_search,
|
|
level_mult: 1.0 / (m as f32).ln(),
|
|
};
|
|
|
|
let mut nodes = Vec::with_capacity(node_count);
|
|
let mut id_to_idx = std::collections::HashMap::new();
|
|
|
|
for node_idx in 0..node_count {
|
|
if offset + 12 > bytes.len() {
|
|
return None;
|
|
}
|
|
|
|
// Node header
|
|
let id = u64::from_le_bytes([
|
|
bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
|
|
bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7],
|
|
]);
|
|
offset += 8;
|
|
let level = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
|
|
// Vector
|
|
let mut vector = Vec::with_capacity(dim);
|
|
for _ in 0..dim {
|
|
if offset + 4 > bytes.len() {
|
|
return None;
|
|
}
|
|
let v = f32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]);
|
|
vector.push(v);
|
|
offset += 4;
|
|
}
|
|
|
|
// Connections
|
|
if offset + 4 > bytes.len() {
|
|
return None;
|
|
}
|
|
let num_layers = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
|
|
let mut connections = Vec::with_capacity(num_layers);
|
|
for _ in 0..num_layers {
|
|
if offset + 4 > bytes.len() {
|
|
return None;
|
|
}
|
|
let num_conns = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
|
|
offset += 4;
|
|
|
|
let mut layer_conns = Vec::with_capacity(num_conns);
|
|
for _ in 0..num_conns {
|
|
if offset + 8 > bytes.len() {
|
|
return None;
|
|
}
|
|
let conn_id = u64::from_le_bytes([
|
|
bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
|
|
bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7],
|
|
]);
|
|
layer_conns.push(conn_id);
|
|
offset += 8;
|
|
}
|
|
connections.push(layer_conns);
|
|
}
|
|
|
|
id_to_idx.insert(id, node_idx);
|
|
nodes.push(HnswNode {
|
|
id,
|
|
vector,
|
|
connections,
|
|
level,
|
|
});
|
|
}
|
|
|
|
Some(Self {
|
|
nodes,
|
|
id_to_idx,
|
|
entry_point,
|
|
max_level,
|
|
config,
|
|
metric,
|
|
dim,
|
|
seed: 12345,
|
|
})
|
|
}
|
|
|
|
/// Estimate serialized size in bytes
|
|
pub fn serialized_size(&self) -> usize {
|
|
let mut size = 36; // Header
|
|
for node in &self.nodes {
|
|
size += 12; // id + level
|
|
size += node.vector.len() * 4; // vector
|
|
size += 4; // num_layers
|
|
for layer in &node.connections {
|
|
size += 4 + layer.len() * 8; // count + connection IDs
|
|
}
|
|
}
|
|
size
|
|
}
|
|
}
|
|
|
|
// ============================================
|
|
// WASM Exports
|
|
// ============================================
|
|
|
|
/// The single index instance shared by the `hnsw_*` C-ABI exports.
///
/// Starts as `None`; `hnsw_create` stores a fresh index here and the other
/// exports read or mutate it.
// NOTE(review): a `static mut` has no synchronisation of its own — presumably
// the exports are only ever called from one thread (WASM host); confirm.
static mut HNSW_INDEX: Option<HnswIndex> = None;
|
|
|
|
/// Create HNSW index
|
|
#[no_mangle]
|
|
pub extern "C" fn hnsw_create(dim: u32, metric: u8, m: u32, ef_construction: u32) -> i32 {
|
|
let config = HnswConfig {
|
|
m: m as usize,
|
|
m_max_0: (m * 2) as usize,
|
|
ef_construction: ef_construction as usize,
|
|
ef_search: 50,
|
|
level_mult: 1.0 / (m as f32).ln(),
|
|
};
|
|
|
|
unsafe {
|
|
HNSW_INDEX = Some(HnswIndex::new(
|
|
dim as usize,
|
|
DistanceMetric::from_u8(metric),
|
|
config,
|
|
));
|
|
}
|
|
0
|
|
}
|
|
|
|
/// Insert vector into HNSW
|
|
#[no_mangle]
|
|
pub extern "C" fn hnsw_insert(id: u64, vector_ptr: *const f32, len: u32) -> i32 {
|
|
unsafe {
|
|
if let Some(index) = HNSW_INDEX.as_mut() {
|
|
let vector = core::slice::from_raw_parts(vector_ptr, len as usize).to_vec();
|
|
if index.insert(id, vector) { 0 } else { -1 }
|
|
} else {
|
|
-1
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Search HNSW index
|
|
#[no_mangle]
|
|
pub extern "C" fn hnsw_search(
|
|
query_ptr: *const f32,
|
|
query_len: u32,
|
|
k: u32,
|
|
ef: u32,
|
|
out_ids: *mut u64,
|
|
out_distances: *mut f32,
|
|
) -> u32 {
|
|
unsafe {
|
|
if let Some(index) = HNSW_INDEX.as_ref() {
|
|
let query = core::slice::from_raw_parts(query_ptr, query_len as usize);
|
|
let results = index.search_with_ef(query, k as usize, ef as usize);
|
|
|
|
let ids = core::slice::from_raw_parts_mut(out_ids, results.len());
|
|
let distances = core::slice::from_raw_parts_mut(out_distances, results.len());
|
|
|
|
for (i, (id, dist)) in results.iter().enumerate() {
|
|
ids[i] = *id;
|
|
distances[i] = *dist;
|
|
}
|
|
|
|
results.len() as u32
|
|
} else {
|
|
0
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Get HNSW index size
|
|
#[no_mangle]
|
|
pub extern "C" fn hnsw_size() -> u32 {
|
|
unsafe {
|
|
HNSW_INDEX.as_ref().map(|i| i.len() as u32).unwrap_or(0)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 points on a line: a query at 50 must return a plausible hit.
    #[test]
    fn test_hnsw_insert_search() {
        let mut index = HnswIndex::with_defaults(4, DistanceMetric::Euclidean);

        // Points (i, 0, 0, 0) for i in 0..100, each under its own id.
        (0..100u64).for_each(|i| {
            assert!(index.insert(i, vec![i as f32, 0.0, 0.0, 0.0]));
        });
        assert_eq!(index.len(), 100);

        let results = index.search(&[50.0, 0.0, 0.0, 0.0], 5);
        assert!(!results.is_empty());

        // HNSW is approximate, so only sanity bounds are checked: the best
        // hit is reasonably close and carries an id that was inserted.
        let (closest_id, closest_dist) = results[0];
        assert!(closest_dist < 25.0, "Distance too large: {}", closest_dist);
        assert!(closest_id < 100, "Invalid ID: {}", closest_id);
    }

    /// With the cosine metric, an identical vector must rank first.
    #[test]
    fn test_hnsw_cosine() {
        let mut index = HnswIndex::with_defaults(3, DistanceMetric::Cosine);

        // Normalized vectors.
        let entries = vec![
            (1u64, vec![1.0, 0.0, 0.0]),
            (2u64, vec![0.0, 1.0, 0.0]),
            (3u64, vec![0.707, 0.707, 0.0]),
        ];
        for (id, vector) in entries {
            index.insert(id, vector);
        }

        let results = index.search(&[1.0, 0.0, 0.0], 3);

        assert_eq!(results[0].0, 1); // Exact match first
    }
}
|