//! Memory service with HNSW vector search and graph storage //! //! Provides efficient vector similarity search using HNSW algorithm //! with SIMD-accelerated distance computations. use crate::config::MemoryConfig; use crate::error::{Error, MemoryError, Result}; use crate::types::{EdgeType, MemoryEdge, MemoryNode, NodeType}; use dashmap::DashMap; use parking_lot::RwLock; use rand::Rng; use std::collections::{BinaryHeap, HashMap, HashSet}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; /// Search result from memory #[derive(Debug, Clone)] pub struct SearchResult { /// Retrieved candidates pub candidates: Vec, /// Expanded subgraph pub subgraph: SubGraph, /// Statistics pub stats: SearchStats, } /// Single search candidate #[derive(Debug, Clone)] pub struct SearchCandidate { /// Node ID pub id: String, /// Distance to query pub distance: f32, /// Node data pub node: MemoryNode, } /// Subgraph from neighborhood expansion #[derive(Debug, Clone)] pub struct SubGraph { /// Nodes in subgraph pub nodes: Vec, /// Edges in subgraph pub edges: Vec, /// Center node IDs pub center_ids: Vec, } /// Search statistics #[derive(Debug, Clone, Default)] pub struct SearchStats { /// Number of candidates pub k_retrieved: usize, /// Distance statistics pub distance_mean: f32, pub distance_std: f32, pub distance_min: f32, pub distance_max: f32, /// Graph depth pub graph_depth: usize, /// HNSW layers traversed pub layers_traversed: usize, /// Distance computations performed pub distance_computations: usize, } /// HNSW graph layer struct HnswLayer { /// Connections: node_id -> connected node_ids connections: DashMap>, /// Maximum connections per node max_connections: usize, } impl HnswLayer { fn new(max_connections: usize) -> Self { Self { connections: DashMap::new(), max_connections, } } fn add_connection(&self, from: usize, to: usize) { self.connections .entry(from) .or_insert_with(Vec::new) .push(to); } fn get_neighbors(&self, node: usize) -> Vec { self.connections .get(&node) .map(|v| v.clone()) .unwrap_or_default() } fn prune_connections(&self, node: usize, vectors: &[Vec], max_conn: usize) { if let Some(mut neighbors) = self.connections.get_mut(&node) { if neighbors.len() > max_conn { // Keep closest neighbors let node_vec = &vectors[node]; let mut scored: Vec<(usize, f32)> = neighbors .iter() .map(|&n| (n, cosine_distance(node_vec, &vectors[n]))) .collect(); scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); *neighbors = scored.into_iter().take(max_conn).map(|(n, _)| n).collect(); } } } } /// Candidate for priority queue (min-heap by distance) #[derive(Clone)] struct Candidate { distance: f32, node_id: usize, } impl PartialEq for Candidate { fn eq(&self, other: &Self) -> bool { self.node_id == other.node_id } } impl Eq for Candidate {} impl PartialOrd for Candidate { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for Candidate { fn cmp(&self, other: &Self) -> std::cmp::Ordering { // Reverse for min-heap (smaller distance = higher priority) other .distance .partial_cmp(&self.distance) .unwrap_or(std::cmp::Ordering::Equal) } } /// Memory service providing vector search and graph operations pub struct MemoryService { /// Vectors storage vectors: RwLock>>, /// Node ID to index mapping id_to_index: DashMap, /// Index to node ID mapping index_to_id: RwLock>, /// Node storage nodes: DashMap, /// Edge storage (src_id -> edges) edges: DashMap>, /// HNSW layers hnsw_layers: RwLock>, /// Entry point for HNSW entry_point: RwLock>, /// Max layer (highest level) max_layer: RwLock, /// Configuration config: MemoryConfig, /// Statistics stats: MemoryStats, } /// Memory service statistics struct MemoryStats { /// Total insertions insertions: AtomicU64, /// Total searches searches: AtomicU64, /// Total distance computations distance_computations: AtomicU64, } impl MemoryService { /// Create a new memory service pub async fn new(config: &MemoryConfig) -> Result { // Note: ml (level multiplier) is computed per-insert in hnsw_insert() // to avoid storing it and to handle edge cases properly Ok(Self { vectors: RwLock::new(Vec::new()), id_to_index: DashMap::new(), index_to_id: RwLock::new(Vec::new()), nodes: DashMap::new(), edges: DashMap::new(), hnsw_layers: RwLock::new(vec![HnswLayer::new(config.hnsw_m * 2)]), entry_point: RwLock::new(None), max_layer: RwLock::new(0), config: config.clone(), stats: MemoryStats { insertions: AtomicU64::new(0), searches: AtomicU64::new(0), distance_computations: AtomicU64::new(0), }, }) } /// Search with graph expansion using HNSW pub async fn search_with_graph( &self, query: &[f32], k: usize, ef_search: usize, max_hops: usize, ) -> Result { self.stats.searches.fetch_add(1, Ordering::Relaxed); let vectors = self.vectors.read(); if vectors.is_empty() { return Ok(SearchResult { candidates: vec![], subgraph: SubGraph { nodes: vec![], edges: vec![], center_ids: vec![], }, stats: SearchStats::default(), }); } // HNSW search let (neighbors, layers_traversed, dist_comps) = self.hnsw_search(query, k, ef_search); self.stats .distance_computations .fetch_add(dist_comps as u64, Ordering::Relaxed); // Convert to candidates let index_to_id = self.index_to_id.read(); let candidates: Vec = neighbors .into_iter() .filter_map(|(idx, distance)| { let id = index_to_id.get(idx)?.clone(); let node = self.nodes.get(&id)?.clone(); Some(SearchCandidate { id, distance, node }) }) .collect(); // Expand neighborhood let center_ids: Vec = candidates.iter().map(|c| c.id.clone()).collect(); let subgraph = self.expand_neighborhood(¢er_ids, max_hops)?; // Compute stats let stats = self.compute_stats(&candidates, layers_traversed, dist_comps); Ok(SearchResult { candidates, subgraph, stats, }) } /// HNSW search implementation fn hnsw_search(&self, query: &[f32], k: usize, ef: usize) -> (Vec<(usize, f32)>, usize, usize) { let vectors = self.vectors.read(); let layers = self.hnsw_layers.read(); let entry = *self.entry_point.read(); let max_layer = *self.max_layer.read(); let mut dist_comps = 0; let mut layers_traversed = 0; let entry_point = match entry { Some(ep) => ep, None => return (vec![], 0, 0), }; // Start from entry point let mut current = entry_point; let mut current_dist = cosine_distance(query, &vectors[current]); dist_comps += 1; // Traverse from top layer to layer 1 for layer_idx in (1..=max_layer).rev() { layers_traversed += 1; let layer = &layers[layer_idx]; loop { let neighbors = layer.get_neighbors(current); let mut changed = false; for &neighbor in &neighbors { if neighbor < vectors.len() { let dist = cosine_distance(query, &vectors[neighbor]); dist_comps += 1; if dist < current_dist { current = neighbor; current_dist = dist; changed = true; } } } if !changed { break; } } } // Search at layer 0 with ef layers_traversed += 1; let layer_0 = &layers[0]; let mut visited = HashSet::new(); let mut candidates = BinaryHeap::new(); let mut result = BinaryHeap::new(); visited.insert(current); candidates.push(Candidate { distance: current_dist, node_id: current, }); result.push(std::cmp::Reverse(Candidate { distance: current_dist, node_id: current, })); while let Some(Candidate { distance: _, node_id: current_node, }) = candidates.pop() { // Check if we should stop if let Some(std::cmp::Reverse(furthest)) = result.peek() { if result.len() >= ef { let current_cand = candidates.peek(); if let Some(cc) = current_cand { if cc.distance > furthest.distance { break; } } } } // Explore neighbors let neighbors = layer_0.get_neighbors(current_node); for &neighbor in &neighbors { if !visited.contains(&neighbor) && neighbor < vectors.len() { visited.insert(neighbor); let dist = cosine_distance(query, &vectors[neighbor]); dist_comps += 1; let should_add = result.len() < ef || { if let Some(std::cmp::Reverse(furthest)) = result.peek() { dist < furthest.distance } else { true } }; if should_add { candidates.push(Candidate { distance: dist, node_id: neighbor, }); result.push(std::cmp::Reverse(Candidate { distance: dist, node_id: neighbor, })); if result.len() > ef { result.pop(); } } } } } // Extract top-k results let mut final_results: Vec<(usize, f32)> = result .into_iter() .map(|std::cmp::Reverse(c)| (c.node_id, c.distance)) .collect(); final_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); final_results.truncate(k); (final_results, layers_traversed, dist_comps) } /// Insert a node with HNSW indexing pub fn insert_node(&self, node: MemoryNode) -> Result { let id = node.id.clone(); let vector = node.vector.clone(); // Check capacity if self.nodes.len() >= self.config.max_nodes { return Err(Error::Memory(MemoryError::CapacityExceeded)); } // Add to storage let index = { let mut vectors = self.vectors.write(); let idx = vectors.len(); vectors.push(vector.clone()); idx }; { let mut index_to_id = self.index_to_id.write(); index_to_id.push(id.clone()); } self.id_to_index.insert(id.clone(), index); self.nodes.insert(id.clone(), node); // Insert into HNSW self.hnsw_insert(index, &vector); self.stats.insertions.fetch_add(1, Ordering::Relaxed); Ok(id) } /// HNSW insertion fn hnsw_insert(&self, node_idx: usize, vector: &[f32]) { let m = self.config.hnsw_m; let m_max = m * 2; // Guard against m=1 which would cause ln(1)=0 and division by zero // Use m=2 as minimum for level calculation let m_for_level = m.max(2) as f32; let ml = 1.0 / m_for_level.ln(); // Determine level for this node let level = self.random_level(ml); let vectors = self.vectors.read(); let mut layers = self.hnsw_layers.write(); let mut entry = self.entry_point.write(); let mut max_layer = self.max_layer.write(); // Ensure we have enough layers while layers.len() <= level { layers.push(HnswLayer::new(m_max)); } // If first node, set as entry point if entry.is_none() { *entry = Some(node_idx); *max_layer = level; return; } let entry_point = entry.unwrap(); let mut current = entry_point; let mut current_dist = cosine_distance(vector, &vectors[current]); // Traverse from top layer down to level+1 for layer_idx in (level + 1..=*max_layer).rev() { let layer = &layers[layer_idx]; loop { let neighbors = layer.get_neighbors(current); let mut changed = false; for &neighbor in &neighbors { if neighbor < vectors.len() { let dist = cosine_distance(vector, &vectors[neighbor]); if dist < current_dist { current = neighbor; current_dist = dist; changed = true; } } } if !changed { break; } } } // Insert at each layer from level down to 0 for layer_idx in (0..=level.min(*max_layer)).rev() { let layer = &layers[layer_idx]; let max_conn = if layer_idx == 0 { m_max } else { m }; // Find ef_construction nearest neighbors let ef = self.config.hnsw_ef_construction; let neighbors = self.search_layer(&vectors, vector, current, ef, layer); // Connect to m nearest let connections: Vec = neighbors .into_iter() .take(max_conn) .map(|(idx, _)| idx) .collect(); // Add bidirectional connections for &conn in &connections { layer.add_connection(node_idx, conn); layer.add_connection(conn, node_idx); // Prune if too many connections layer.prune_connections(conn, &vectors, max_conn); } // Update entry point for next layer if !connections.is_empty() { current = connections[0]; } } // Update entry point if necessary if level > *max_layer { *entry = Some(node_idx); *max_layer = level; } } /// Search within a single layer fn search_layer( &self, vectors: &[Vec], query: &[f32], entry: usize, ef: usize, layer: &HnswLayer, ) -> Vec<(usize, f32)> { let mut visited = HashSet::new(); let mut candidates = BinaryHeap::new(); let mut result = Vec::new(); let entry_dist = cosine_distance(query, &vectors[entry]); visited.insert(entry); candidates.push(Candidate { distance: entry_dist, node_id: entry, }); result.push((entry, entry_dist)); while let Some(Candidate { distance: _, node_id, }) = candidates.pop() { if result.len() >= ef { result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); if let Some(&(_, furthest_dist)) = result.last() { if let Some(closest) = candidates.peek() { if closest.distance > furthest_dist { break; } } } } let neighbors = layer.get_neighbors(node_id); for &neighbor in &neighbors { if !visited.contains(&neighbor) && neighbor < vectors.len() { visited.insert(neighbor); let dist = cosine_distance(query, &vectors[neighbor]); candidates.push(Candidate { distance: dist, node_id: neighbor, }); result.push((neighbor, dist)); } } } result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); result.truncate(ef); result } /// Random level for HNSW (exponential distribution) fn random_level(&self, ml: f32) -> usize { let mut rng = rand::thread_rng(); let r: f32 = rng.gen(); // Guard against r=0 which would cause ln(0) = -inf // Also clamp result to prevent overflow when casting to usize if r <= f32::EPSILON { return 0; } let level = (-r.ln() * ml).floor(); // Clamp to reasonable max level to prevent overflow level.min(32.0) as usize } /// Insert an edge pub fn insert_edge(&self, edge: MemoryEdge) -> Result { let id = edge.id.clone(); self.edges .entry(edge.src.clone()) .or_insert_with(Vec::new) .push(edge); Ok(id) } /// Update edge weight pub fn update_edge_weight(&self, src: &str, dst: &str, delta: f32) -> Result<()> { if let Some(mut edges) = self.edges.get_mut(src) { for edge in edges.iter_mut() { if edge.dst == dst { edge.weight = (edge.weight + delta).clamp(0.0, 1.0); break; } } } Ok(()) } /// Get node count pub fn node_count(&self) -> usize { self.nodes.len() } /// Get edge count pub fn edge_count(&self) -> usize { self.edges.iter().map(|e| e.len()).sum() } /// Get node by ID pub fn get_node(&self, id: &str) -> Option { self.nodes.get(id).map(|n| n.clone()) } /// Get edges from a node pub fn get_edges(&self, src: &str) -> Vec { self.edges.get(src).map(|e| e.clone()).unwrap_or_default() } /// Batch insert nodes pub fn insert_batch(&self, nodes: Vec) -> Result> { nodes.into_iter().map(|n| self.insert_node(n)).collect() } /// Flush pending writes (for persistence) pub async fn flush(&self) -> Result<()> { // In production, this would persist to disk Ok(()) } /// Get memory statistics pub fn get_stats(&self) -> MemoryServiceStats { MemoryServiceStats { node_count: self.nodes.len(), edge_count: self.edge_count(), total_insertions: self.stats.insertions.load(Ordering::Relaxed), total_searches: self.stats.searches.load(Ordering::Relaxed), total_distance_computations: self.stats.distance_computations.load(Ordering::Relaxed), hnsw_layers: self.hnsw_layers.read().len(), } } /// Expand neighborhood via graph traversal fn expand_neighborhood(&self, center_ids: &[String], max_hops: usize) -> Result { let mut visited = HashSet::new(); let mut all_nodes = Vec::new(); let mut all_edges = Vec::new(); let mut frontier: Vec = center_ids.to_vec(); for hop in 0..=max_hops { let mut next_frontier = Vec::new(); let is_last_hop = hop == max_hops; for node_id in &frontier { if visited.contains(node_id) { continue; } visited.insert(node_id.clone()); // Get node if let Some(node) = self.nodes.get(node_id) { all_nodes.push(node.clone()); } // Get edges (only collect if not on last hop, to avoid edges leading outside) if !is_last_hop { if let Some(edges) = self.edges.get(node_id) { for edge in edges.iter() { all_edges.push(edge.clone()); if !visited.contains(&edge.dst) { next_frontier.push(edge.dst.clone()); } } } } } frontier = next_frontier; } Ok(SubGraph { nodes: all_nodes, edges: all_edges, center_ids: center_ids.to_vec(), }) } fn compute_stats( &self, candidates: &[SearchCandidate], layers: usize, dist_comps: usize, ) -> SearchStats { if candidates.is_empty() { return SearchStats::default(); } let distances: Vec = candidates.iter().map(|c| c.distance).collect(); let mean = distances.iter().sum::() / distances.len() as f32; let var = distances.iter().map(|d| (d - mean).powi(2)).sum::() / distances.len() as f32; SearchStats { k_retrieved: candidates.len(), distance_mean: mean, distance_std: var.sqrt(), distance_min: distances.iter().cloned().fold(f32::INFINITY, f32::min), distance_max: distances.iter().cloned().fold(f32::NEG_INFINITY, f32::max), graph_depth: 0, layers_traversed: layers, distance_computations: dist_comps, } } } /// Public statistics about memory service #[derive(Debug, Clone)] pub struct MemoryServiceStats { /// Number of nodes pub node_count: usize, /// Number of edges pub edge_count: usize, /// Total insertions pub total_insertions: u64, /// Total searches pub total_searches: u64, /// Total distance computations pub total_distance_computations: u64, /// Number of HNSW layers pub hnsw_layers: usize, } /// SIMD-accelerated cosine distance using simsimd when available #[cfg(feature = "simd")] pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { use simsimd::SpatialSimilarity; let cos_sim = f32::cosine(a, b).unwrap_or(0.0); 1.0 - cos_sim } #[cfg(not(feature = "simd"))] pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); if norm_a > 0.0 && norm_b > 0.0 { 1.0 - dot / (norm_a * norm_b) } else { 1.0 } } /// Euclidean distance pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { a.iter() .zip(b.iter()) .map(|(x, y)| (x - y).powi(2)) .sum::() .sqrt() } /// Inner product (negative for use as distance) pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 { -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::() } #[cfg(test)] mod tests { use super::*; fn create_test_node(id: &str, vector: Vec) -> MemoryNode { MemoryNode { id: id.into(), vector, text: format!("Test node {}", id), node_type: NodeType::Document, source: "test".into(), metadata: HashMap::new(), } } #[tokio::test] async fn test_memory_insert_and_search() { let config = MemoryConfig::default(); let memory = MemoryService::new(&config).await.unwrap(); let node = create_test_node("test-1", vec![1.0, 0.0, 0.0]); memory.insert_node(node).unwrap(); let query = vec![1.0, 0.0, 0.0]; let result = memory.search_with_graph(&query, 10, 64, 2).await.unwrap(); assert_eq!(result.candidates.len(), 1); assert_eq!(result.candidates[0].id, "test-1"); assert!(result.candidates[0].distance < 0.001); } #[tokio::test] async fn test_hnsw_search_accuracy() { let mut config = MemoryConfig::default(); config.hnsw_m = 16; config.hnsw_ef_construction = 100; let memory = MemoryService::new(&config).await.unwrap(); // Insert 100 random vectors let dim = 128; let mut rng = rand::thread_rng(); let mut vectors = Vec::new(); for i in 0..100 { let mut vec: Vec = (0..dim).map(|_| rng.gen::() - 0.5).collect(); // Normalize let norm: f32 = vec.iter().map(|x| x * x).sum::().sqrt(); vec.iter_mut().for_each(|x| *x /= norm); vectors.push(vec.clone()); let node = create_test_node(&format!("node-{}", i), vec); memory.insert_node(node).unwrap(); } // Search for a specific vector let query = vectors[42].clone(); let result = memory.search_with_graph(&query, 10, 64, 0).await.unwrap(); // The closest should be the exact match assert!(!result.candidates.is_empty()); assert_eq!(result.candidates[0].id, "node-42"); assert!(result.candidates[0].distance < 0.001); } #[tokio::test] async fn test_graph_expansion() { let config = MemoryConfig::default(); let memory = MemoryService::new(&config).await.unwrap(); // Create nodes for i in 0..5 { let node = create_test_node(&format!("node-{}", i), vec![i as f32, 0.0, 0.0]); memory.insert_node(node).unwrap(); } // Create edges: 0 -> 1 -> 2 -> 3 -> 4 for i in 0..4 { let edge = MemoryEdge { id: format!("edge-{}", i), src: format!("node-{}", i), dst: format!("node-{}", i + 1), edge_type: EdgeType::Follows, weight: 1.0, metadata: HashMap::new(), }; memory.insert_edge(edge).unwrap(); } // Expand from node-0 with 2 hops let subgraph = memory.expand_neighborhood(&["node-0".into()], 2).unwrap(); // Should include node-0, node-1, node-2 assert_eq!(subgraph.nodes.len(), 3); assert_eq!(subgraph.edges.len(), 2); } #[tokio::test] async fn test_batch_insert() { let config = MemoryConfig::default(); let memory = MemoryService::new(&config).await.unwrap(); let nodes: Vec = (0..10) .map(|i| create_test_node(&format!("batch-{}", i), vec![i as f32; 3])) .collect(); let ids = memory.insert_batch(nodes).unwrap(); assert_eq!(ids.len(), 10); assert_eq!(memory.node_count(), 10); } #[test] fn test_cosine_distance() { let a = vec![1.0, 0.0, 0.0]; let b = vec![1.0, 0.0, 0.0]; assert!(cosine_distance(&a, &b) < 0.001); let c = vec![0.0, 1.0, 0.0]; assert!((cosine_distance(&a, &c) - 1.0).abs() < 0.001); let d = vec![-1.0, 0.0, 0.0]; assert!((cosine_distance(&a, &d) - 2.0).abs() < 0.001); } #[test] fn test_edge_weight_update() { let config = MemoryConfig::default(); let rt = tokio::runtime::Runtime::new().unwrap(); let memory = rt.block_on(MemoryService::new(&config)).unwrap(); let edge = MemoryEdge { id: "e1".into(), src: "n1".into(), dst: "n2".into(), edge_type: EdgeType::Cites, weight: 0.5, metadata: HashMap::new(), }; memory.insert_edge(edge).unwrap(); // Update weight memory.update_edge_weight("n1", "n2", 0.2).unwrap(); let edges = memory.get_edges("n1"); assert_eq!(edges.len(), 1); assert!((edges[0].weight - 0.7).abs() < 0.001); } #[tokio::test] async fn test_memory_stats() { let config = MemoryConfig::default(); let memory = MemoryService::new(&config).await.unwrap(); // Insert some nodes for i in 0..5 { let node = create_test_node(&format!("stat-{}", i), vec![i as f32; 3]); memory.insert_node(node).unwrap(); } // Perform a search memory .search_with_graph(&[0.0, 0.0, 0.0], 5, 32, 0) .await .unwrap(); let stats = memory.get_stats(); assert_eq!(stats.node_count, 5); assert_eq!(stats.total_insertions, 5); assert_eq!(stats.total_searches, 1); } }