//! Memory service with HNSW vector search and graph storage
|
|
//!
|
|
//! Provides efficient vector similarity search using HNSW algorithm
|
|
//! with SIMD-accelerated distance computations.
|
|
|
|
use crate::config::MemoryConfig;
|
|
use crate::error::{Error, MemoryError, Result};
|
|
use crate::types::{EdgeType, MemoryEdge, MemoryNode, NodeType};
|
|
|
|
use dashmap::DashMap;
|
|
use parking_lot::RwLock;
|
|
use rand::Rng;
|
|
use std::collections::{BinaryHeap, HashMap, HashSet};
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
|
|
/// Search result from memory
///
/// Bundles the ranked nearest-neighbor candidates, the graph
/// neighborhood expanded around them, and per-query statistics.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Retrieved candidates, sorted by ascending distance to the query
    pub candidates: Vec<SearchCandidate>,
    /// Expanded subgraph around the candidate nodes
    pub subgraph: SubGraph,
    /// Statistics collected during this search
    pub stats: SearchStats,
}
|
|
|
|
/// Single search candidate
///
/// One nearest-neighbor hit: the node's id, its distance to the query,
/// and a clone of the stored node data.
#[derive(Debug, Clone)]
pub struct SearchCandidate {
    /// Node ID
    pub id: String,
    /// Cosine distance to the query (smaller = more similar)
    pub distance: f32,
    /// Cloned node data
    pub node: MemoryNode,
}
|
|
|
|
/// Subgraph from neighborhood expansion
///
/// Produced by `expand_neighborhood`: the nodes reached within the hop
/// budget and the edges traversed between them.
#[derive(Debug, Clone)]
pub struct SubGraph {
    /// Nodes in subgraph
    pub nodes: Vec<MemoryNode>,
    /// Edges in subgraph (collected so they never point outside `nodes`)
    pub edges: Vec<MemoryEdge>,
    /// IDs of the nodes the expansion started from
    pub center_ids: Vec<String>,
}
|
|
|
|
/// Search statistics
///
/// Per-query measurements filled in by `compute_stats`; `Default`
/// (all zeros) is used for empty result sets.
#[derive(Debug, Clone, Default)]
pub struct SearchStats {
    /// Number of candidates retrieved
    pub k_retrieved: usize,
    /// Mean of candidate distances
    pub distance_mean: f32,
    /// Population standard deviation of candidate distances
    pub distance_std: f32,
    /// Smallest candidate distance
    pub distance_min: f32,
    /// Largest candidate distance
    pub distance_max: f32,
    /// Graph depth (currently always 0 — not computed yet)
    pub graph_depth: usize,
    /// HNSW layers traversed during the search
    pub layers_traversed: usize,
    /// Distance computations performed during the search
    pub distance_computations: usize,
}
|
|
|
|
/// HNSW graph layer
///
/// A single level of the HNSW hierarchy, stored as an adjacency map
/// over vector-store indices.
struct HnswLayer {
    /// Connections: node_id -> connected node_ids (directed entries;
    /// bidirectionality is created by inserting both directions)
    connections: DashMap<usize, Vec<usize>>,
    /// Maximum connections per node
    /// NOTE(review): set in `new` but not read in the visible code —
    /// pruning takes the cap as a parameter instead; confirm whether
    /// this field is still needed.
    max_connections: usize,
}
|
|
|
|
impl HnswLayer {
|
|
fn new(max_connections: usize) -> Self {
|
|
Self {
|
|
connections: DashMap::new(),
|
|
max_connections,
|
|
}
|
|
}
|
|
|
|
fn add_connection(&self, from: usize, to: usize) {
|
|
self.connections
|
|
.entry(from)
|
|
.or_insert_with(Vec::new)
|
|
.push(to);
|
|
}
|
|
|
|
fn get_neighbors(&self, node: usize) -> Vec<usize> {
|
|
self.connections
|
|
.get(&node)
|
|
.map(|v| v.clone())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
fn prune_connections(&self, node: usize, vectors: &[Vec<f32>], max_conn: usize) {
|
|
if let Some(mut neighbors) = self.connections.get_mut(&node) {
|
|
if neighbors.len() > max_conn {
|
|
// Keep closest neighbors
|
|
let node_vec = &vectors[node];
|
|
let mut scored: Vec<(usize, f32)> = neighbors
|
|
.iter()
|
|
.map(|&n| (n, cosine_distance(node_vec, &vectors[n])))
|
|
.collect();
|
|
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
|
*neighbors = scored.into_iter().take(max_conn).map(|(n, _)| n).collect();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Candidate for priority queue (min-heap by distance)
///
/// `Ord` is intentionally reversed (smaller distance sorts as "greater")
/// so that `BinaryHeap<Candidate>` behaves as a min-heap by distance.
/// Wrapping in `std::cmp::Reverse` turns it back into a max-heap, which
/// the search code uses to track the furthest of the current results.
#[derive(Clone)]
struct Candidate {
    /// Distance from the query to this node
    distance: f32,
    /// Index of the node in the vector store
    node_id: usize,
}

impl PartialEq for Candidate {
    // Equality is identity-based (node index only); distance is a
    // derived value and deliberately not compared here.
    fn eq(&self, other: &Self) -> bool {
        self.node_id == other.node_id
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed comparison => min-heap by distance. `total_cmp` gives
        // a total order over f32 (NaN sorts deterministically), unlike
        // the previous `partial_cmp(..).unwrap_or(Equal)`, which could
        // silently scramble heap ordering if a NaN distance appeared.
        other.distance.total_cmp(&self.distance)
    }
}
|
|
|
|
/// Memory service providing vector search and graph operations
///
/// Combines a flat vector store, an HNSW index over it, and a
/// string-keyed node/edge graph. Vector index `i` corresponds to id
/// `index_to_id[i]`; `id_to_index` is the inverse mapping.
pub struct MemoryService {
    /// Dense vector storage; HNSW node ids are indices into this Vec
    vectors: RwLock<Vec<Vec<f32>>>,
    /// Node ID -> vector index
    id_to_index: DashMap<String, usize>,
    /// Vector index -> node ID (inverse of `id_to_index`)
    index_to_id: RwLock<Vec<String>>,
    /// Node storage keyed by node ID
    nodes: DashMap<String, MemoryNode>,
    /// Edge storage (src_id -> outgoing edges)
    edges: DashMap<String, Vec<MemoryEdge>>,
    /// HNSW layers; index 0 is the bottom (densest) layer
    hnsw_layers: RwLock<Vec<HnswLayer>>,
    /// Entry point for HNSW (None while the index is empty)
    entry_point: RwLock<Option<usize>>,
    /// Max layer (highest level currently populated)
    max_layer: RwLock<usize>,
    /// Configuration snapshot taken at construction
    config: MemoryConfig,
    /// Internal atomic counters
    stats: MemoryStats,
}
|
|
|
|
/// Memory service statistics
///
/// Monotonic counters updated with relaxed atomics; snapshots may be
/// slightly stale relative to concurrent operations.
struct MemoryStats {
    /// Total insertions since construction
    insertions: AtomicU64,
    /// Total searches since construction
    searches: AtomicU64,
    /// Total distance computations across all searches
    distance_computations: AtomicU64,
}
|
|
|
|
impl MemoryService {
|
|
/// Create a new memory service
|
|
pub async fn new(config: &MemoryConfig) -> Result<Self> {
|
|
// Note: ml (level multiplier) is computed per-insert in hnsw_insert()
|
|
// to avoid storing it and to handle edge cases properly
|
|
|
|
Ok(Self {
|
|
vectors: RwLock::new(Vec::new()),
|
|
id_to_index: DashMap::new(),
|
|
index_to_id: RwLock::new(Vec::new()),
|
|
nodes: DashMap::new(),
|
|
edges: DashMap::new(),
|
|
hnsw_layers: RwLock::new(vec![HnswLayer::new(config.hnsw_m * 2)]),
|
|
entry_point: RwLock::new(None),
|
|
max_layer: RwLock::new(0),
|
|
config: config.clone(),
|
|
stats: MemoryStats {
|
|
insertions: AtomicU64::new(0),
|
|
searches: AtomicU64::new(0),
|
|
distance_computations: AtomicU64::new(0),
|
|
},
|
|
})
|
|
}
|
|
|
|
/// Search with graph expansion using HNSW
|
|
pub async fn search_with_graph(
|
|
&self,
|
|
query: &[f32],
|
|
k: usize,
|
|
ef_search: usize,
|
|
max_hops: usize,
|
|
) -> Result<SearchResult> {
|
|
self.stats.searches.fetch_add(1, Ordering::Relaxed);
|
|
|
|
let vectors = self.vectors.read();
|
|
if vectors.is_empty() {
|
|
return Ok(SearchResult {
|
|
candidates: vec![],
|
|
subgraph: SubGraph {
|
|
nodes: vec![],
|
|
edges: vec![],
|
|
center_ids: vec![],
|
|
},
|
|
stats: SearchStats::default(),
|
|
});
|
|
}
|
|
|
|
// HNSW search
|
|
let (neighbors, layers_traversed, dist_comps) = self.hnsw_search(query, k, ef_search);
|
|
self.stats
|
|
.distance_computations
|
|
.fetch_add(dist_comps as u64, Ordering::Relaxed);
|
|
|
|
// Convert to candidates
|
|
let index_to_id = self.index_to_id.read();
|
|
let candidates: Vec<SearchCandidate> = neighbors
|
|
.into_iter()
|
|
.filter_map(|(idx, distance)| {
|
|
let id = index_to_id.get(idx)?.clone();
|
|
let node = self.nodes.get(&id)?.clone();
|
|
Some(SearchCandidate { id, distance, node })
|
|
})
|
|
.collect();
|
|
|
|
// Expand neighborhood
|
|
let center_ids: Vec<String> = candidates.iter().map(|c| c.id.clone()).collect();
|
|
let subgraph = self.expand_neighborhood(¢er_ids, max_hops)?;
|
|
|
|
// Compute stats
|
|
let stats = self.compute_stats(&candidates, layers_traversed, dist_comps);
|
|
|
|
Ok(SearchResult {
|
|
candidates,
|
|
subgraph,
|
|
stats,
|
|
})
|
|
}
|
|
|
|
    /// HNSW search implementation
    ///
    /// Returns up to `k` (vector-index, distance) pairs sorted by
    /// ascending distance, plus the number of layers traversed and the
    /// number of distance computations performed.
    fn hnsw_search(&self, query: &[f32], k: usize, ef: usize) -> (Vec<(usize, f32)>, usize, usize) {
        let vectors = self.vectors.read();
        let layers = self.hnsw_layers.read();
        let entry = *self.entry_point.read();
        let max_layer = *self.max_layer.read();

        let mut dist_comps = 0;
        let mut layers_traversed = 0;

        // Empty index: nothing to search.
        let entry_point = match entry {
            Some(ep) => ep,
            None => return (vec![], 0, 0),
        };

        // Start from entry point
        let mut current = entry_point;
        let mut current_dist = cosine_distance(query, &vectors[current]);
        dist_comps += 1;

        // Phase 1: greedy descent from the top layer down to layer 1,
        // always moving to the closest neighbor until no improvement.
        for layer_idx in (1..=max_layer).rev() {
            layers_traversed += 1;
            let layer = &layers[layer_idx];

            loop {
                let neighbors = layer.get_neighbors(current);
                let mut changed = false;

                for &neighbor in &neighbors {
                    // Guard against indices past the end of the store.
                    if neighbor < vectors.len() {
                        let dist = cosine_distance(query, &vectors[neighbor]);
                        dist_comps += 1;
                        if dist < current_dist {
                            current = neighbor;
                            current_dist = dist;
                            changed = true;
                        }
                    }
                }

                if !changed {
                    break;
                }
            }
        }

        // Phase 2: best-first search at layer 0 with beam width `ef`.
        layers_traversed += 1;
        let layer_0 = &layers[0];

        let mut visited = HashSet::new();
        // Min-heap of nodes still to expand (Candidate's Ord is reversed).
        let mut candidates = BinaryHeap::new();
        // Max-heap (via Reverse) of the best `ef` results; peek() gives
        // the furthest of the current best.
        let mut result = BinaryHeap::new();

        visited.insert(current);
        candidates.push(Candidate {
            distance: current_dist,
            node_id: current,
        });
        result.push(std::cmp::Reverse(Candidate {
            distance: current_dist,
            node_id: current,
        }));

        while let Some(Candidate {
            distance: _,
            node_id: current_node,
        }) = candidates.pop()
        {
            // Stop when the closest unexpanded candidate is further than
            // the furthest of the ef best results found so far.
            if let Some(std::cmp::Reverse(furthest)) = result.peek() {
                if result.len() >= ef {
                    let current_cand = candidates.peek();
                    if let Some(cc) = current_cand {
                        if cc.distance > furthest.distance {
                            break;
                        }
                    }
                }
            }

            // Explore neighbors
            let neighbors = layer_0.get_neighbors(current_node);
            for &neighbor in &neighbors {
                if !visited.contains(&neighbor) && neighbor < vectors.len() {
                    visited.insert(neighbor);
                    let dist = cosine_distance(query, &vectors[neighbor]);
                    dist_comps += 1;

                    // Accept while the beam is not full, or when closer
                    // than the current furthest result.
                    let should_add = result.len() < ef || {
                        if let Some(std::cmp::Reverse(furthest)) = result.peek() {
                            dist < furthest.distance
                        } else {
                            true
                        }
                    };

                    if should_add {
                        candidates.push(Candidate {
                            distance: dist,
                            node_id: neighbor,
                        });
                        result.push(std::cmp::Reverse(Candidate {
                            distance: dist,
                            node_id: neighbor,
                        }));

                        // Keep only the best `ef` results.
                        if result.len() > ef {
                            result.pop();
                        }
                    }
                }
            }
        }

        // Extract top-k results sorted by ascending distance.
        let mut final_results: Vec<(usize, f32)> = result
            .into_iter()
            .map(|std::cmp::Reverse(c)| (c.node_id, c.distance))
            .collect();
        final_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        final_results.truncate(k);

        (final_results, layers_traversed, dist_comps)
    }
|
|
|
|
/// Insert a node with HNSW indexing
|
|
pub fn insert_node(&self, node: MemoryNode) -> Result<String> {
|
|
let id = node.id.clone();
|
|
let vector = node.vector.clone();
|
|
|
|
// Check capacity
|
|
if self.nodes.len() >= self.config.max_nodes {
|
|
return Err(Error::Memory(MemoryError::CapacityExceeded));
|
|
}
|
|
|
|
// Add to storage
|
|
let index = {
|
|
let mut vectors = self.vectors.write();
|
|
let idx = vectors.len();
|
|
vectors.push(vector.clone());
|
|
idx
|
|
};
|
|
|
|
{
|
|
let mut index_to_id = self.index_to_id.write();
|
|
index_to_id.push(id.clone());
|
|
}
|
|
|
|
self.id_to_index.insert(id.clone(), index);
|
|
self.nodes.insert(id.clone(), node);
|
|
|
|
// Insert into HNSW
|
|
self.hnsw_insert(index, &vector);
|
|
self.stats.insertions.fetch_add(1, Ordering::Relaxed);
|
|
|
|
Ok(id)
|
|
}
|
|
|
|
    /// HNSW insertion
    ///
    /// Links `node_idx` into the multi-layer graph: draw a random level,
    /// greedily descend from the top layer to `level + 1`, then connect
    /// to nearest neighbors on each layer from `level` down to 0.
    fn hnsw_insert(&self, node_idx: usize, vector: &[f32]) {
        let m = self.config.hnsw_m;
        let m_max = m * 2;
        // Guard against m=1 which would cause ln(1)=0 and division by zero
        // Use m=2 as minimum for level calculation
        let m_for_level = m.max(2) as f32;
        let ml = 1.0 / m_for_level.ln();

        // Determine level for this node
        let level = self.random_level(ml);

        // NOTE: the vectors read lock plus three write locks are held
        // for the entire insertion; callers must not already hold any
        // of these locks.
        let vectors = self.vectors.read();
        let mut layers = self.hnsw_layers.write();
        let mut entry = self.entry_point.write();
        let mut max_layer = self.max_layer.write();

        // Ensure we have enough layers
        while layers.len() <= level {
            layers.push(HnswLayer::new(m_max));
        }

        // If first node, set as entry point and skip linking entirely.
        if entry.is_none() {
            *entry = Some(node_idx);
            *max_layer = level;
            return;
        }

        let entry_point = entry.unwrap();
        let mut current = entry_point;
        let mut current_dist = cosine_distance(vector, &vectors[current]);

        // Phase 1: greedy descent from the top layer down to level+1,
        // moving to the closest neighbor until no improvement.
        for layer_idx in (level + 1..=*max_layer).rev() {
            let layer = &layers[layer_idx];
            loop {
                let neighbors = layer.get_neighbors(current);
                let mut changed = false;
                for &neighbor in &neighbors {
                    if neighbor < vectors.len() {
                        let dist = cosine_distance(vector, &vectors[neighbor]);
                        if dist < current_dist {
                            current = neighbor;
                            current_dist = dist;
                            changed = true;
                        }
                    }
                }
                if !changed {
                    break;
                }
            }
        }

        // Phase 2: insert at each layer from level down to 0
        for layer_idx in (0..=level.min(*max_layer)).rev() {
            let layer = &layers[layer_idx];
            // Layer 0 allows twice as many connections as upper layers.
            let max_conn = if layer_idx == 0 { m_max } else { m };

            // Find ef_construction nearest neighbors
            let ef = self.config.hnsw_ef_construction;
            let neighbors = self.search_layer(&vectors, vector, current, ef, layer);

            // Connect to m nearest
            let connections: Vec<usize> = neighbors
                .into_iter()
                .take(max_conn)
                .map(|(idx, _)| idx)
                .collect();

            // Add bidirectional connections
            for &conn in &connections {
                layer.add_connection(node_idx, conn);
                layer.add_connection(conn, node_idx);
                // Prune if too many connections
                layer.prune_connections(conn, &vectors, max_conn);
            }

            // Update entry point for next layer: the closest found here.
            if !connections.is_empty() {
                current = connections[0];
            }
        }

        // Update entry point if the new node has the highest level.
        if level > *max_layer {
            *entry = Some(node_idx);
            *max_layer = level;
        }
    }
|
|
|
|
/// Search within a single layer
|
|
fn search_layer(
|
|
&self,
|
|
vectors: &[Vec<f32>],
|
|
query: &[f32],
|
|
entry: usize,
|
|
ef: usize,
|
|
layer: &HnswLayer,
|
|
) -> Vec<(usize, f32)> {
|
|
let mut visited = HashSet::new();
|
|
let mut candidates = BinaryHeap::new();
|
|
let mut result = Vec::new();
|
|
|
|
let entry_dist = cosine_distance(query, &vectors[entry]);
|
|
visited.insert(entry);
|
|
candidates.push(Candidate {
|
|
distance: entry_dist,
|
|
node_id: entry,
|
|
});
|
|
result.push((entry, entry_dist));
|
|
|
|
while let Some(Candidate {
|
|
distance: _,
|
|
node_id,
|
|
}) = candidates.pop()
|
|
{
|
|
if result.len() >= ef {
|
|
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
|
if let Some(&(_, furthest_dist)) = result.last() {
|
|
if let Some(closest) = candidates.peek() {
|
|
if closest.distance > furthest_dist {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let neighbors = layer.get_neighbors(node_id);
|
|
for &neighbor in &neighbors {
|
|
if !visited.contains(&neighbor) && neighbor < vectors.len() {
|
|
visited.insert(neighbor);
|
|
let dist = cosine_distance(query, &vectors[neighbor]);
|
|
candidates.push(Candidate {
|
|
distance: dist,
|
|
node_id: neighbor,
|
|
});
|
|
result.push((neighbor, dist));
|
|
}
|
|
}
|
|
}
|
|
|
|
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
|
result.truncate(ef);
|
|
result
|
|
}
|
|
|
|
/// Random level for HNSW (exponential distribution)
|
|
fn random_level(&self, ml: f32) -> usize {
|
|
let mut rng = rand::thread_rng();
|
|
let r: f32 = rng.gen();
|
|
// Guard against r=0 which would cause ln(0) = -inf
|
|
// Also clamp result to prevent overflow when casting to usize
|
|
if r <= f32::EPSILON {
|
|
return 0;
|
|
}
|
|
let level = (-r.ln() * ml).floor();
|
|
// Clamp to reasonable max level to prevent overflow
|
|
level.min(32.0) as usize
|
|
}
|
|
|
|
/// Insert an edge
|
|
pub fn insert_edge(&self, edge: MemoryEdge) -> Result<String> {
|
|
let id = edge.id.clone();
|
|
self.edges
|
|
.entry(edge.src.clone())
|
|
.or_insert_with(Vec::new)
|
|
.push(edge);
|
|
Ok(id)
|
|
}
|
|
|
|
/// Update edge weight
|
|
pub fn update_edge_weight(&self, src: &str, dst: &str, delta: f32) -> Result<()> {
|
|
if let Some(mut edges) = self.edges.get_mut(src) {
|
|
for edge in edges.iter_mut() {
|
|
if edge.dst == dst {
|
|
edge.weight = (edge.weight + delta).clamp(0.0, 1.0);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
    /// Get node count
    ///
    /// Number of nodes currently stored.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
|
|
|
|
/// Get edge count
|
|
pub fn edge_count(&self) -> usize {
|
|
self.edges.iter().map(|e| e.len()).sum()
|
|
}
|
|
|
|
/// Get node by ID
|
|
pub fn get_node(&self, id: &str) -> Option<MemoryNode> {
|
|
self.nodes.get(id).map(|n| n.clone())
|
|
}
|
|
|
|
/// Get edges from a node
|
|
pub fn get_edges(&self, src: &str) -> Vec<MemoryEdge> {
|
|
self.edges.get(src).map(|e| e.clone()).unwrap_or_default()
|
|
}
|
|
|
|
/// Batch insert nodes
|
|
pub fn insert_batch(&self, nodes: Vec<MemoryNode>) -> Result<Vec<String>> {
|
|
nodes.into_iter().map(|n| self.insert_node(n)).collect()
|
|
}
|
|
|
|
    /// Flush pending writes (for persistence)
    ///
    /// Currently a no-op placeholder; a production implementation would
    /// persist vectors, nodes, and edges to durable storage here.
    pub async fn flush(&self) -> Result<()> {
        // In production, this would persist to disk
        Ok(())
    }
|
|
|
|
/// Get memory statistics
|
|
pub fn get_stats(&self) -> MemoryServiceStats {
|
|
MemoryServiceStats {
|
|
node_count: self.nodes.len(),
|
|
edge_count: self.edge_count(),
|
|
total_insertions: self.stats.insertions.load(Ordering::Relaxed),
|
|
total_searches: self.stats.searches.load(Ordering::Relaxed),
|
|
total_distance_computations: self.stats.distance_computations.load(Ordering::Relaxed),
|
|
hnsw_layers: self.hnsw_layers.read().len(),
|
|
}
|
|
}
|
|
|
|
    /// Expand neighborhood via graph traversal
    ///
    /// Breadth-first expansion from `center_ids` for up to `max_hops`
    /// hops. Nodes are collected on every hop, but edges are collected
    /// only on non-final hops, so no returned edge points outside the
    /// returned node set.
    fn expand_neighborhood(&self, center_ids: &[String], max_hops: usize) -> Result<SubGraph> {
        let mut visited = HashSet::new();
        let mut all_nodes = Vec::new();
        let mut all_edges = Vec::new();
        let mut frontier: Vec<String> = center_ids.to_vec();

        for hop in 0..=max_hops {
            let mut next_frontier = Vec::new();
            let is_last_hop = hop == max_hops;

            for node_id in &frontier {
                // Each node is processed at most once, so edges are also
                // collected at most once per source node.
                if visited.contains(node_id) {
                    continue;
                }
                visited.insert(node_id.clone());

                // Get node (ids without a stored node are skipped silently)
                if let Some(node) = self.nodes.get(node_id) {
                    all_nodes.push(node.clone());
                }

                // Get edges (only collect if not on last hop, to avoid edges leading outside)
                if !is_last_hop {
                    if let Some(edges) = self.edges.get(node_id) {
                        for edge in edges.iter() {
                            all_edges.push(edge.clone());
                            if !visited.contains(&edge.dst) {
                                next_frontier.push(edge.dst.clone());
                            }
                        }
                    }
                }
            }

            frontier = next_frontier;
        }

        Ok(SubGraph {
            nodes: all_nodes,
            edges: all_edges,
            center_ids: center_ids.to_vec(),
        })
    }
|
|
|
|
fn compute_stats(
|
|
&self,
|
|
candidates: &[SearchCandidate],
|
|
layers: usize,
|
|
dist_comps: usize,
|
|
) -> SearchStats {
|
|
if candidates.is_empty() {
|
|
return SearchStats::default();
|
|
}
|
|
|
|
let distances: Vec<f32> = candidates.iter().map(|c| c.distance).collect();
|
|
let mean = distances.iter().sum::<f32>() / distances.len() as f32;
|
|
let var =
|
|
distances.iter().map(|d| (d - mean).powi(2)).sum::<f32>() / distances.len() as f32;
|
|
|
|
SearchStats {
|
|
k_retrieved: candidates.len(),
|
|
distance_mean: mean,
|
|
distance_std: var.sqrt(),
|
|
distance_min: distances.iter().cloned().fold(f32::INFINITY, f32::min),
|
|
distance_max: distances.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
|
|
graph_depth: 0,
|
|
layers_traversed: layers,
|
|
distance_computations: dist_comps,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Public statistics about memory service
///
/// Point-in-time snapshot returned by `MemoryService::get_stats`.
#[derive(Debug, Clone)]
pub struct MemoryServiceStats {
    /// Number of nodes currently stored
    pub node_count: usize,
    /// Number of edges currently stored
    pub edge_count: usize,
    /// Total insertions since service creation
    pub total_insertions: u64,
    /// Total searches since service creation
    pub total_searches: u64,
    /// Total distance computations across all searches
    pub total_distance_computations: u64,
    /// Number of HNSW layers currently allocated
    pub hnsw_layers: usize,
}
|
|
|
|
/// SIMD-accelerated cosine distance using simsimd when available
///
/// NOTE(review): this code assumes `SpatialSimilarity::cosine` returns
/// cosine *similarity*; recent simsimd versions document it as returning
/// cosine *distance* directly, in which case the `1.0 - …` below would
/// invert the metric — verify against the pinned simsimd version.
/// On `None` (e.g. length mismatch) the fallback of 0.0 maps to a
/// distance of 1.0, presumably to match the scalar degenerate case.
#[cfg(feature = "simd")]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    use simsimd::SpatialSimilarity;
    let cos_sim = f32::cosine(a, b).unwrap_or(0.0);
    1.0 - cos_sim
}
|
|
|
|
/// Scalar fallback for cosine distance: 1 - (a·b)/(|a||b|).
///
/// Ranges over [0, 2]: 0 = identical direction, 1 = orthogonal,
/// 2 = opposite. Degenerate inputs (either vector has zero norm) are
/// assigned the "orthogonal" distance of 1.0.
#[cfg(not(feature = "simd"))]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    let norm_a = a.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    let norm_b = b.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        1.0 - dot / (norm_a * norm_b)
    } else {
        1.0
    }
}
|
|
|
|
/// Euclidean (L2) distance between two vectors.
///
/// Pairs are truncated to the shorter of the two slices.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();
    sum_sq.sqrt()
}
|
|
|
|
/// Inner product (negative for use as distance)
///
/// More similar vectors (larger dot product) get a smaller value, so
/// the result orders correctly as a distance.
pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
    /// Build a minimal `MemoryNode` fixture with the given id and vector.
    fn create_test_node(id: &str, vector: Vec<f32>) -> MemoryNode {
        MemoryNode {
            id: id.into(),
            vector,
            text: format!("Test node {}", id),
            node_type: NodeType::Document,
            source: "test".into(),
            metadata: HashMap::new(),
        }
    }
|
|
|
|
    /// Inserting a single node and searching with its own vector should
    /// return exactly that node at distance ~0.
    #[tokio::test]
    async fn test_memory_insert_and_search() {
        let config = MemoryConfig::default();
        let memory = MemoryService::new(&config).await.unwrap();

        let node = create_test_node("test-1", vec![1.0, 0.0, 0.0]);
        memory.insert_node(node).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let result = memory.search_with_graph(&query, 10, 64, 2).await.unwrap();

        assert_eq!(result.candidates.len(), 1);
        assert_eq!(result.candidates[0].id, "test-1");
        assert!(result.candidates[0].distance < 0.001);
    }
|
|
|
|
    /// With 100 random unit vectors indexed, querying with one of them
    /// should return it as the top hit (exact-match recall).
    #[tokio::test]
    async fn test_hnsw_search_accuracy() {
        let mut config = MemoryConfig::default();
        config.hnsw_m = 16;
        config.hnsw_ef_construction = 100;
        let memory = MemoryService::new(&config).await.unwrap();

        // Insert 100 random vectors
        let dim = 128;
        let mut rng = rand::thread_rng();
        let mut vectors = Vec::new();

        for i in 0..100 {
            let mut vec: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() - 0.5).collect();
            // Normalize to unit length so cosine distance is well-behaved
            let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
            vec.iter_mut().for_each(|x| *x /= norm);
            vectors.push(vec.clone());

            let node = create_test_node(&format!("node-{}", i), vec);
            memory.insert_node(node).unwrap();
        }

        // Search for a specific vector (no graph expansion: max_hops = 0)
        let query = vectors[42].clone();
        let result = memory.search_with_graph(&query, 10, 64, 0).await.unwrap();

        // The closest should be the exact match
        assert!(!result.candidates.is_empty());
        assert_eq!(result.candidates[0].id, "node-42");
        assert!(result.candidates[0].distance < 0.001);
    }
|
|
|
|
    /// On the chain 0 -> 1 -> 2 -> 3 -> 4, a 2-hop expansion from node-0
    /// must reach exactly {0, 1, 2} with the two connecting edges.
    #[tokio::test]
    async fn test_graph_expansion() {
        let config = MemoryConfig::default();
        let memory = MemoryService::new(&config).await.unwrap();

        // Create nodes
        for i in 0..5 {
            let node = create_test_node(&format!("node-{}", i), vec![i as f32, 0.0, 0.0]);
            memory.insert_node(node).unwrap();
        }

        // Create edges: 0 -> 1 -> 2 -> 3 -> 4
        for i in 0..4 {
            let edge = MemoryEdge {
                id: format!("edge-{}", i),
                src: format!("node-{}", i),
                dst: format!("node-{}", i + 1),
                edge_type: EdgeType::Follows,
                weight: 1.0,
                metadata: HashMap::new(),
            };
            memory.insert_edge(edge).unwrap();
        }

        // Expand from node-0 with 2 hops
        let subgraph = memory.expand_neighborhood(&["node-0".into()], 2).unwrap();

        // Should include node-0, node-1, node-2 (edges stop one hop early)
        assert_eq!(subgraph.nodes.len(), 3);
        assert_eq!(subgraph.edges.len(), 2);
    }
|
|
|
|
    /// Batch insertion returns one id per node and stores all of them.
    #[tokio::test]
    async fn test_batch_insert() {
        let config = MemoryConfig::default();
        let memory = MemoryService::new(&config).await.unwrap();

        let nodes: Vec<MemoryNode> = (0..10)
            .map(|i| create_test_node(&format!("batch-{}", i), vec![i as f32; 3]))
            .collect();

        let ids = memory.insert_batch(nodes).unwrap();
        assert_eq!(ids.len(), 10);
        assert_eq!(memory.node_count(), 10);
    }
|
|
|
|
    /// Cosine distance sanity: identical -> 0, orthogonal -> 1,
    /// opposite -> 2.
    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(cosine_distance(&a, &b) < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&a, &c) - 1.0).abs() < 0.001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &d) - 2.0).abs() < 0.001);
    }
|
|
|
|
    /// Updating an edge weight by +0.2 from 0.5 should land at 0.7.
    /// Uses a manual runtime since `MemoryService::new` is async but the
    /// body under test is synchronous.
    #[test]
    fn test_edge_weight_update() {
        let config = MemoryConfig::default();
        let rt = tokio::runtime::Runtime::new().unwrap();
        let memory = rt.block_on(MemoryService::new(&config)).unwrap();

        let edge = MemoryEdge {
            id: "e1".into(),
            src: "n1".into(),
            dst: "n2".into(),
            edge_type: EdgeType::Cites,
            weight: 0.5,
            metadata: HashMap::new(),
        };
        memory.insert_edge(edge).unwrap();

        // Update weight
        memory.update_edge_weight("n1", "n2", 0.2).unwrap();

        let edges = memory.get_edges("n1");
        assert_eq!(edges.len(), 1);
        assert!((edges[0].weight - 0.7).abs() < 0.001);
    }
|
|
|
|
    /// Counters reflect the operations performed: 5 inserts and 1 search.
    #[tokio::test]
    async fn test_memory_stats() {
        let config = MemoryConfig::default();
        let memory = MemoryService::new(&config).await.unwrap();

        // Insert some nodes
        for i in 0..5 {
            let node = create_test_node(&format!("stat-{}", i), vec![i as f32; 3]);
            memory.insert_node(node).unwrap();
        }

        // Perform a search
        memory
            .search_with_graph(&[0.0, 0.0, 0.0], 5, 32, 0)
            .await
            .unwrap();

        let stats = memory.get_stats();
        assert_eq!(stats.node_count, 5);
        assert_eq!(stats.total_insertions, 5);
        assert_eq!(stats.total_searches, 1);
    }
|
|
}
|