//! Lightweight HNSW Index for iOS/Browser WASM //! //! A simplified HNSW implementation optimized for mobile/browser deployment. //! Provides O(log n) approximate nearest neighbor search. //! //! Based on the paper: "Efficient and Robust Approximate Nearest Neighbor Search //! Using Hierarchical Navigable Small World Graphs" use crate::distance::{distance, DistanceMetric}; use std::collections::{BinaryHeap, HashSet}; use std::vec::Vec; use core::cmp::Ordering; /// HNSW configuration #[derive(Clone, Debug)] pub struct HnswConfig { /// Max connections per node (M parameter) pub m: usize, /// Max connections at layer 0 (usually 2*M) pub m_max_0: usize, /// Construction-time search width pub ef_construction: usize, /// Query-time search width pub ef_search: usize, /// Level multiplier (1/ln(M)) pub level_mult: f32, } impl Default for HnswConfig { fn default() -> Self { Self { m: 16, m_max_0: 32, ef_construction: 100, ef_search: 50, level_mult: 0.36, // 1/ln(16) } } } /// Node in the HNSW graph #[derive(Clone, Debug)] struct HnswNode { /// Vector ID id: u64, /// Vector data vector: Vec, /// Connections at each layer connections: Vec>, /// Node's layer level: usize, } /// Search candidate with distance #[derive(Clone, Debug)] struct Candidate { id: u64, distance: f32, } impl PartialEq for Candidate { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Eq for Candidate {} impl PartialOrd for Candidate { fn partial_cmp(&self, other: &Self) -> Option { // Reverse order for min-heap behavior in BinaryHeap other.distance.partial_cmp(&self.distance) } } impl Ord for Candidate { fn cmp(&self, other: &Self) -> Ordering { self.partial_cmp(other).unwrap_or(Ordering::Equal) } } /// Lightweight HNSW index pub struct HnswIndex { /// All nodes nodes: Vec, /// ID to node index mapping id_to_idx: std::collections::HashMap, /// Entry point (topmost node) entry_point: Option, /// Maximum level in the graph max_level: usize, /// Configuration config: HnswConfig, /// Distance metric metric: DistanceMetric, /// Dimension dim: usize, /// Random seed for level generation seed: u32, } impl HnswIndex { /// Create a new HNSW index pub fn new(dim: usize, metric: DistanceMetric, config: HnswConfig) -> Self { Self { nodes: Vec::new(), id_to_idx: std::collections::HashMap::new(), entry_point: None, max_level: 0, config, metric, dim, seed: 12345, } } /// Create with default config pub fn with_defaults(dim: usize, metric: DistanceMetric) -> Self { Self::new(dim, metric, HnswConfig::default()) } /// Generate random level for a new node fn random_level(&mut self) -> usize { // LCG random number generator self.seed = self.seed.wrapping_mul(1103515245).wrapping_add(12345); let rand = (self.seed >> 16) as f32 / 32768.0; let level = (-rand.ln() * self.config.level_mult).floor() as usize; level.min(16) // Cap at 16 levels } /// Insert a vector into the index pub fn insert(&mut self, id: u64, vector: Vec) -> bool { if vector.len() != self.dim { return false; } if self.id_to_idx.contains_key(&id) { return false; // Already exists } let level = self.random_level(); let node_idx = self.nodes.len(); // Create node with empty connections let mut node = HnswNode { id, vector, connections: vec![Vec::new(); level + 1], level, }; if let Some(ep_idx) = self.entry_point { // Find entry point at the top level let mut curr_idx = ep_idx; let mut curr_dist = self.distance_to_node(node_idx, curr_idx, &node.vector); // Traverse from top to insertion level for lc in (level + 1..=self.max_level).rev() { let mut changed = true; while changed { changed = false; if let Some(connections) = self.nodes.get(curr_idx).map(|n| n.connections.get(lc).cloned()).flatten() { for &neighbor_id in &connections { if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) { let d = self.distance_to_node(node_idx, neighbor_idx, &node.vector); if d < curr_dist { curr_dist = d; curr_idx = neighbor_idx; changed = true; } } } } } } // Insert at each level for lc in (0..=level.min(self.max_level)).rev() { let neighbors = self.search_layer(&node.vector, curr_idx, self.config.ef_construction, lc); // Select M best neighbors let m_max = if lc == 0 { self.config.m_max_0 } else { self.config.m }; let selected: Vec = neighbors.iter() .take(m_max) .map(|c| c.id) .collect(); node.connections[lc] = selected.clone(); // Add bidirectional connections for &neighbor_id in &selected { if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) { if let Some(neighbor_node) = self.nodes.get_mut(neighbor_idx) { if lc < neighbor_node.connections.len() { neighbor_node.connections[lc].push(id); // Prune if too many connections if neighbor_node.connections[lc].len() > m_max { let query = &neighbor_node.vector.clone(); self.prune_connections(neighbor_idx, lc, m_max, query); } } } } } if !neighbors.is_empty() { curr_idx = self.id_to_idx.get(&neighbors[0].id).copied().unwrap_or(curr_idx); } } } // Add node self.nodes.push(node); self.id_to_idx.insert(id, node_idx); // Update entry point if this is higher level if level > self.max_level || self.entry_point.is_none() { self.max_level = level; self.entry_point = Some(node_idx); } true } /// Search for k nearest neighbors pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> { self.search_with_ef(query, k, self.config.ef_search) } /// Search with custom ef parameter pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> { if query.len() != self.dim || self.entry_point.is_none() { return vec![]; } let ep_idx = self.entry_point.unwrap(); // Find entry point by traversing from top let mut curr_idx = ep_idx; let mut curr_dist = distance(query, &self.nodes[curr_idx].vector, self.metric); for lc in (1..=self.max_level).rev() { let mut changed = true; while changed { changed = false; if let Some(connections) = self.nodes.get(curr_idx).and_then(|n| n.connections.get(lc)) { for &neighbor_id in connections { if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) { let d = distance(query, &self.nodes[neighbor_idx].vector, self.metric); if d < curr_dist { curr_dist = d; curr_idx = neighbor_idx; changed = true; } } } } } } // Search at layer 0 let results = self.search_layer(query, curr_idx, ef, 0); results.into_iter() .take(k) .map(|c| (c.id, c.distance)) .collect() } /// Search within a specific layer fn search_layer(&self, query: &[f32], entry_idx: usize, ef: usize, layer: usize) -> Vec { let entry_id = self.nodes[entry_idx].id; let entry_dist = distance(query, &self.nodes[entry_idx].vector, self.metric); let mut visited: HashSet = HashSet::new(); let mut candidates: BinaryHeap = BinaryHeap::new(); let mut results: Vec = Vec::new(); visited.insert(entry_id); candidates.push(Candidate { id: entry_id, distance: entry_dist }); results.push(Candidate { id: entry_id, distance: entry_dist }); while let Some(current) = candidates.pop() { // Stop if current is worse than worst in results if results.len() >= ef { let worst_dist = results.iter().map(|c| c.distance).fold(f32::NEG_INFINITY, f32::max); if current.distance > worst_dist { break; } } // Explore neighbors if let Some(&curr_idx) = self.id_to_idx.get(¤t.id) { if let Some(connections) = self.nodes.get(curr_idx).and_then(|n| n.connections.get(layer)) { for &neighbor_id in connections { if visited.insert(neighbor_id) { if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) { let d = distance(query, &self.nodes[neighbor_idx].vector, self.metric); let should_add = results.len() < ef || { let worst = results.iter().map(|c| c.distance).fold(f32::NEG_INFINITY, f32::max); d < worst }; if should_add { candidates.push(Candidate { id: neighbor_id, distance: d }); results.push(Candidate { id: neighbor_id, distance: d }); // Keep only ef best if results.len() > ef { results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); results.truncate(ef); } } } } } } } } results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); results } /// Prune connections to keep only the best fn prune_connections(&mut self, node_idx: usize, layer: usize, max_conn: usize, query: &[f32]) { // First, collect connection info without holding mutable borrow let connections_to_score: Vec = if let Some(node) = self.nodes.get(node_idx) { if layer < node.connections.len() { node.connections[layer].clone() } else { return; } } else { return; }; // Score connections let mut candidates: Vec<(u64, f32)> = connections_to_score .iter() .filter_map(|&id| { self.id_to_idx.get(&id) .and_then(|&idx| self.nodes.get(idx)) .map(|n| (id, distance(query, &n.vector, self.metric))) }) .collect(); candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); let pruned: Vec = candidates.into_iter() .take(max_conn) .map(|(id, _)| id) .collect(); // Now update the connections if let Some(node) = self.nodes.get_mut(node_idx) { if layer < node.connections.len() { node.connections[layer] = pruned; } } } /// Helper to calculate distance to a node fn distance_to_node(&self, _new_idx: usize, existing_idx: usize, new_vector: &[f32]) -> f32 { if let Some(node) = self.nodes.get(existing_idx) { distance(new_vector, &node.vector, self.metric) } else { f32::MAX } } /// Get number of vectors in the index pub fn len(&self) -> usize { self.nodes.len() } /// Check if empty pub fn is_empty(&self) -> bool { self.nodes.is_empty() } /// Get vector by ID pub fn get(&self, id: u64) -> Option<&[f32]> { self.id_to_idx.get(&id) .and_then(|&idx| self.nodes.get(idx)) .map(|n| n.vector.as_slice()) } // ============================================ // Persistence // ============================================ /// Serialize the HNSW index to bytes /// /// Format: /// - Header (32 bytes): dim, metric, m, m_max_0, ef_construction, ef_search, max_level, node_count /// - For each node: id (8), level (4), vector (dim*4), connections per layer pub fn serialize(&self) -> Vec { let mut bytes = Vec::new(); // Header bytes.extend_from_slice(&(self.dim as u32).to_le_bytes()); bytes.extend_from_slice(&(self.metric as u8).to_le_bytes()); bytes.extend_from_slice(&[0u8; 3]); // padding bytes.extend_from_slice(&(self.config.m as u32).to_le_bytes()); bytes.extend_from_slice(&(self.config.m_max_0 as u32).to_le_bytes()); bytes.extend_from_slice(&(self.config.ef_construction as u32).to_le_bytes()); bytes.extend_from_slice(&(self.config.ef_search as u32).to_le_bytes()); bytes.extend_from_slice(&(self.max_level as u32).to_le_bytes()); bytes.extend_from_slice(&(self.nodes.len() as u32).to_le_bytes()); bytes.extend_from_slice(&self.entry_point.map(|e| e as u32).unwrap_or(u32::MAX).to_le_bytes()); // Nodes for node in &self.nodes { // Node header: id, level bytes.extend_from_slice(&node.id.to_le_bytes()); bytes.extend_from_slice(&(node.level as u32).to_le_bytes()); // Vector for &v in &node.vector { bytes.extend_from_slice(&v.to_le_bytes()); } // Connections: count per layer, then connection IDs bytes.extend_from_slice(&(node.connections.len() as u32).to_le_bytes()); for layer_conns in &node.connections { bytes.extend_from_slice(&(layer_conns.len() as u32).to_le_bytes()); for &conn_id in layer_conns { bytes.extend_from_slice(&conn_id.to_le_bytes()); } } } bytes } /// Deserialize HNSW index from bytes pub fn deserialize(bytes: &[u8]) -> Option { if bytes.len() < 36 { return None; } let mut offset = 0; // Read header let dim = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; let metric = DistanceMetric::from_u8(bytes[4]); offset = 8; let m = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let m_max_0 = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let ef_construction = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let ef_search = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let max_level = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let node_count = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let entry_point_raw = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]); offset += 4; let entry_point = if entry_point_raw == u32::MAX { None } else { Some(entry_point_raw as usize) }; let config = HnswConfig { m, m_max_0, ef_construction, ef_search, level_mult: 1.0 / (m as f32).ln(), }; let mut nodes = Vec::with_capacity(node_count); let mut id_to_idx = std::collections::HashMap::new(); for node_idx in 0..node_count { if offset + 12 > bytes.len() { return None; } // Node header let id = u64::from_le_bytes([ bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3], bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7], ]); offset += 8; let level = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; // Vector let mut vector = Vec::with_capacity(dim); for _ in 0..dim { if offset + 4 > bytes.len() { return None; } let v = f32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]); vector.push(v); offset += 4; } // Connections if offset + 4 > bytes.len() { return None; } let num_layers = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let mut connections = Vec::with_capacity(num_layers); for _ in 0..num_layers { if offset + 4 > bytes.len() { return None; } let num_conns = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize; offset += 4; let mut layer_conns = Vec::with_capacity(num_conns); for _ in 0..num_conns { if offset + 8 > bytes.len() { return None; } let conn_id = u64::from_le_bytes([ bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3], bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7], ]); layer_conns.push(conn_id); offset += 8; } connections.push(layer_conns); } id_to_idx.insert(id, node_idx); nodes.push(HnswNode { id, vector, connections, level, }); } Some(Self { nodes, id_to_idx, entry_point, max_level, config, metric, dim, seed: 12345, }) } /// Estimate serialized size in bytes pub fn serialized_size(&self) -> usize { let mut size = 36; // Header for node in &self.nodes { size += 12; // id + level size += node.vector.len() * 4; // vector size += 4; // num_layers for layer in &node.connections { size += 4 + layer.len() * 8; // count + connection IDs } } size } } // ============================================ // WASM Exports // ============================================ static mut HNSW_INDEX: Option = None; /// Create HNSW index #[no_mangle] pub extern "C" fn hnsw_create(dim: u32, metric: u8, m: u32, ef_construction: u32) -> i32 { let config = HnswConfig { m: m as usize, m_max_0: (m * 2) as usize, ef_construction: ef_construction as usize, ef_search: 50, level_mult: 1.0 / (m as f32).ln(), }; unsafe { HNSW_INDEX = Some(HnswIndex::new( dim as usize, DistanceMetric::from_u8(metric), config, )); } 0 } /// Insert vector into HNSW #[no_mangle] pub extern "C" fn hnsw_insert(id: u64, vector_ptr: *const f32, len: u32) -> i32 { unsafe { if let Some(index) = HNSW_INDEX.as_mut() { let vector = core::slice::from_raw_parts(vector_ptr, len as usize).to_vec(); if index.insert(id, vector) { 0 } else { -1 } } else { -1 } } } /// Search HNSW index #[no_mangle] pub extern "C" fn hnsw_search( query_ptr: *const f32, query_len: u32, k: u32, ef: u32, out_ids: *mut u64, out_distances: *mut f32, ) -> u32 { unsafe { if let Some(index) = HNSW_INDEX.as_ref() { let query = core::slice::from_raw_parts(query_ptr, query_len as usize); let results = index.search_with_ef(query, k as usize, ef as usize); let ids = core::slice::from_raw_parts_mut(out_ids, results.len()); let distances = core::slice::from_raw_parts_mut(out_distances, results.len()); for (i, (id, dist)) in results.iter().enumerate() { ids[i] = *id; distances[i] = *dist; } results.len() as u32 } else { 0 } } } /// Get HNSW index size #[no_mangle] pub extern "C" fn hnsw_size() -> u32 { unsafe { HNSW_INDEX.as_ref().map(|i| i.len() as u32).unwrap_or(0) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_hnsw_insert_search() { let mut index = HnswIndex::with_defaults(4, DistanceMetric::Euclidean); // Insert some vectors for i in 0..100u64 { let v = vec![i as f32, 0.0, 0.0, 0.0]; assert!(index.insert(i, v)); } assert_eq!(index.len(), 100); // Search for closest to [50, 0, 0, 0] let query = vec![50.0, 0.0, 0.0, 0.0]; let results = index.search(&query, 5); assert!(!results.is_empty()); // HNSW is approximate - verify we get results and distance is reasonable let (closest_id, closest_dist) = results[0]; // The closest vector should have a reasonable distance (less than 25) assert!(closest_dist < 25.0, "Distance too large: {}", closest_dist); // Result should be somewhere in the index assert!(closest_id < 100, "Invalid ID: {}", closest_id); } #[test] fn test_hnsw_cosine() { let mut index = HnswIndex::with_defaults(3, DistanceMetric::Cosine); // Insert normalized vectors index.insert(1, vec![1.0, 0.0, 0.0]); index.insert(2, vec![0.0, 1.0, 0.0]); index.insert(3, vec![0.707, 0.707, 0.0]); let query = vec![1.0, 0.0, 0.0]; let results = index.search(&query, 3); assert_eq!(results[0].0, 1); // Exact match first } }