Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
//! Application layer for the Vector Space bounded context.
//!
//! Contains:
//! - Services: Use case implementations
//! - DTOs: Data transfer objects
//! - Commands/Queries: CQRS patterns
// The `services` submodule currently holds all application-layer items; its
// public items are re-exported here so callers can use `application::*`
// without naming the submodule.
pub mod services;
pub use services::*;

View File

@@ -0,0 +1,705 @@
//! Application services for the Vector Space bounded context.
//!
//! These services implement the use cases for vector indexing and search,
//! providing a high-level API that coordinates domain objects and repositories.
use std::sync::Arc;
use parking_lot::RwLock;
use tracing::{debug, info, instrument, warn};
use crate::distance::{cosine_similarity, normalize_vector};
use crate::domain::{
EmbeddingId, HnswConfig, SimilarityEdge, EdgeType,
VectorError,
};
use crate::infrastructure::hnsw_index::HnswIndex;
/// A search result neighbor with similarity information.
#[derive(Debug, Clone)]
pub struct Neighbor {
    /// The embedding ID of this neighbor.
    pub id: EmbeddingId,
    /// Distance from the query vector, as reported by the index.
    pub distance: f32,
    /// Similarity score, computed as `1 - distance` with the distance clamped
    /// into `[0, 1]` (see `Neighbor::new`); always in `[0, 1]`.
    pub similarity: f32,
    /// Rank in the result set (0 = closest). Re-assigned after any filtering.
    pub rank: usize,
}
impl Neighbor {
/// Create a new neighbor from search results.
pub fn new(id: EmbeddingId, distance: f32, rank: usize) -> Self {
Self {
id,
distance,
similarity: 1.0 - distance.clamp(0.0, 1.0),
rank,
}
}
/// Check if this neighbor exceeds a similarity threshold.
#[inline]
pub fn is_above_threshold(&self, threshold: f32) -> bool {
self.similarity >= threshold
}
}
/// Options for search queries.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// Maximum number of results to return.
    pub k: usize,
    /// Minimum similarity threshold (results below this are filtered).
    pub min_similarity: Option<f32>,
    /// Maximum distance threshold (results above this are filtered).
    pub max_distance: Option<f32>,
    /// ef_search parameter override (higher = more accurate but slower).
    // NOTE(review): as of this version the search path does not appear to
    // consult this field — confirm against `find_neighbors_with_options`.
    pub ef_search: Option<usize>,
    /// Whether to include the query vector in results if it exists.
    // NOTE(review): likewise not consulted by the current search path.
    pub include_query: bool,
}
impl Default for SearchOptions {
fn default() -> Self {
Self {
k: 10,
min_similarity: None,
max_distance: None,
ef_search: None,
include_query: false,
}
}
}
impl SearchOptions {
    /// Construct options requesting the top `k` results; everything else
    /// keeps its default.
    pub fn new(k: usize) -> Self {
        let mut options = Self::default();
        options.k = k;
        options
    }

    /// Keep only results whose similarity is at least `threshold`.
    pub fn with_min_similarity(self, threshold: f32) -> Self {
        Self {
            min_similarity: Some(threshold),
            ..self
        }
    }

    /// Keep only results whose distance is at most `distance`.
    pub fn with_max_distance(self, distance: f32) -> Self {
        Self {
            max_distance: Some(distance),
            ..self
        }
    }

    /// Override the index-level `ef_search` parameter for this query.
    pub fn with_ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: Some(ef),
            ..self
        }
    }

    /// Allow the query vector itself to appear in the result set.
    pub fn include_query(self) -> Self {
        Self {
            include_query: true,
            ..self
        }
    }
}
/// The main service for vector space operations.
///
/// This service provides a thread-safe interface for:
/// - Adding and removing embeddings
/// - Nearest neighbor search
/// - Building similarity graphs
pub struct VectorSpaceService {
    /// The underlying HNSW index, shared behind a parking_lot `RwLock` so the
    /// service can be cloned cheaply (see the `Clone` impl) while all clones
    /// operate on the same index.
    index: Arc<RwLock<HnswIndex>>,
    /// Configuration for this service; cloned per service handle.
    config: HnswConfig,
}
impl VectorSpaceService {
    /// Create a new vector space service with the given configuration.
    ///
    /// Builds a fresh, empty `HnswIndex` from `config`.
    pub fn new(config: HnswConfig) -> Self {
        let index = HnswIndex::new(&config);
        Self {
            index: Arc::new(RwLock::new(index)),
            config,
        }
    }
    /// Create a service from an existing index.
    ///
    /// Used by `load`; `config` is assumed to match the index contents and is
    /// not re-validated here.
    pub fn from_index(index: HnswIndex, config: HnswConfig) -> Self {
        Self {
            index: Arc::new(RwLock::new(index)),
            config,
        }
    }
    /// Get the index dimensions (from the configuration, not the index).
    #[inline]
    pub fn dimensions(&self) -> usize {
        self.config.dimensions
    }
    /// Get the current number of vectors (takes a read lock briefly).
    pub fn len(&self) -> usize {
        self.index.read().len()
    }
    /// Check if the index is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Get a reference to the configuration.
    pub fn config(&self) -> &HnswConfig {
        &self.config
    }
    /// Add a single embedding to the index.
    ///
    /// The vector will be normalized if the configuration specifies
    /// normalization. The write lock is held only for the insert itself and
    /// never across an await point.
    ///
    /// # Errors
    /// Returns `VectorError` if the vector's length differs from the
    /// configured dimensions, contains NaN/Inf, or the index insert fails.
    #[instrument(skip(self, vector), fields(vector_dim = vector.len()))]
    pub async fn add_embedding(
        &self,
        id: EmbeddingId,
        vector: Vec<f32>,
    ) -> Result<(), VectorError> {
        self.validate_vector(&vector)?;
        let vector = if self.config.normalize {
            normalize_vector(&vector)
        } else {
            vector
        };
        let mut index = self.index.write();
        index.insert(id, &vector)?;
        debug!(id = %id, "Added embedding to index");
        Ok(())
    }
    /// Add multiple embeddings in a batch.
    ///
    /// This is more efficient than multiple single adds due to
    /// amortized locking overhead.
    ///
    /// Validation failures abort the whole batch before any insert; insert
    /// failures for individual items are logged via `warn!` and skipped
    /// rather than propagated. Returns the number of items actually inserted.
    #[instrument(skip(self, items), fields(batch_size = items.len()))]
    pub async fn add_embeddings_batch(
        &self,
        items: Vec<(EmbeddingId, Vec<f32>)>,
    ) -> Result<usize, VectorError> {
        if items.is_empty() {
            return Ok(0);
        }
        // Validate all vectors first
        for (_, vector) in &items {
            self.validate_vector(vector)?;
        }
        // Normalize if needed
        let items: Vec<_> = if self.config.normalize {
            items
                .into_iter()
                .map(|(id, v)| (id, normalize_vector(&v)))
                .collect()
        } else {
            items
        };
        // Single write-lock acquisition for the whole batch.
        let mut index = self.index.write();
        let mut added = 0;
        for (id, vector) in &items {
            if let Err(e) = index.insert(*id, vector) {
                warn!(id = %id, error = %e, "Failed to add embedding in batch");
            } else {
                added += 1;
            }
        }
        info!(added, total = items.len(), "Batch insert completed");
        Ok(added)
    }
    /// Find the k nearest neighbors to a query vector.
    ///
    /// Convenience wrapper over `find_neighbors_with_options` with default
    /// options (no filtering).
    #[instrument(skip(self, query), fields(query_dim = query.len(), k))]
    pub async fn find_neighbors(
        &self,
        query: &[f32],
        k: usize,
    ) -> Result<Vec<Neighbor>, VectorError> {
        self.find_neighbors_with_options(query, SearchOptions::new(k))
            .await
    }
    /// Find neighbors with custom search options.
    ///
    /// Returns an empty list when the index is empty. Results are filtered by
    /// `min_similarity` / `max_distance` if set, truncated to `options.k`,
    /// and re-ranked starting at 0.
    // NOTE(review): `options.ef_search` and `options.include_query` are not
    // consulted anywhere in this method — TODO wire them into the index
    // search call / result filtering, or document them as unimplemented.
    #[instrument(skip(self, query, options), fields(query_dim = query.len()))]
    pub async fn find_neighbors_with_options(
        &self,
        query: &[f32],
        options: SearchOptions,
    ) -> Result<Vec<Neighbor>, VectorError> {
        self.validate_vector(query)?;
        let query = if self.config.normalize {
            normalize_vector(query)
        } else {
            query.to_vec()
        };
        let index = self.index.read();
        if index.is_empty() {
            return Ok(Vec::new());
        }
        // Request more results if we're filtering
        // (2x is a heuristic; heavy filtering can still yield fewer than k).
        let k_fetch = if options.min_similarity.is_some() || options.max_distance.is_some() {
            options.k * 2
        } else {
            options.k
        };
        let results = index.search(&query, k_fetch);
        let mut neighbors: Vec<_> = results
            .into_iter()
            .enumerate()
            .map(|(rank, (id, distance))| Neighbor::new(id, distance, rank))
            .collect();
        // Apply filters
        if let Some(min_sim) = options.min_similarity {
            neighbors.retain(|n| n.similarity >= min_sim);
        }
        if let Some(max_dist) = options.max_distance {
            neighbors.retain(|n| n.distance <= max_dist);
        }
        // Truncate to requested k
        neighbors.truncate(options.k);
        // Re-rank after filtering
        for (rank, neighbor) in neighbors.iter_mut().enumerate() {
            neighbor.rank = rank;
        }
        debug!(found = neighbors.len(), "Neighbor search completed");
        Ok(neighbors)
    }
    /// Find neighbors using a filter predicate.
    ///
    /// The filter function receives an EmbeddingId and returns true if the
    /// embedding should be included in results.
    ///
    /// Over-fetches 4x `k` candidates before filtering (a heuristic); if too
    /// few candidates pass the filter, fewer than `k` results are returned.
    #[instrument(skip(self, query, filter), fields(query_dim = query.len(), k))]
    pub async fn find_neighbors_with_filter<F>(
        &self,
        query: &[f32],
        k: usize,
        filter: F,
    ) -> Result<Vec<Neighbor>, VectorError>
    where
        F: Fn(&EmbeddingId) -> bool + Send + Sync,
    {
        self.validate_vector(query)?;
        let query = if self.config.normalize {
            normalize_vector(query)
        } else {
            query.to_vec()
        };
        let index = self.index.read();
        if index.is_empty() {
            return Ok(Vec::new());
        }
        // Fetch more results to account for filtering
        let k_fetch = k * 4;
        let results = index.search(&query, k_fetch);
        let mut neighbors: Vec<_> = results
            .into_iter()
            .filter(|(id, _)| filter(id))
            .take(k)
            .enumerate()
            .map(|(rank, (id, distance))| Neighbor::new(id, distance, rank))
            .collect();
        // Re-rank (ranks are already dense from enumerate, but kept for
        // symmetry with the options-based search path).
        for (rank, neighbor) in neighbors.iter_mut().enumerate() {
            neighbor.rank = rank;
        }
        Ok(neighbors)
    }
    /// Remove an embedding from the index.
    ///
    /// # Errors
    /// Propagates the index's error (e.g. when the id is unknown — behavior
    /// depends on `HnswIndex::remove`).
    #[instrument(skip(self))]
    pub async fn remove_embedding(&self, id: &EmbeddingId) -> Result<(), VectorError> {
        let mut index = self.index.write();
        index.remove(id)?;
        debug!(id = %id, "Removed embedding from index");
        Ok(())
    }
    /// Check if an embedding exists in the index.
    pub fn contains(&self, id: &EmbeddingId) -> bool {
        self.index.read().contains(id)
    }
    /// Get a (possibly normalized) stored vector by its ID.
    pub fn get_vector(&self, id: &EmbeddingId) -> Option<Vec<f32>> {
        self.index.read().get_vector(id)
    }
    /// Build similarity edges for an embedding.
    ///
    /// This finds the k nearest neighbors (at or above `min_similarity`) and
    /// creates `EdgeType::Similar` edges to them, excluding a self-edge when
    /// the embedding itself appears in the results.
    #[instrument(skip(self, vector))]
    pub async fn build_similarity_edges(
        &self,
        id: EmbeddingId,
        vector: &[f32],
        k: usize,
        min_similarity: f32,
    ) -> Result<Vec<SimilarityEdge>, VectorError> {
        let neighbors = self
            .find_neighbors_with_options(
                vector,
                SearchOptions::new(k).with_min_similarity(min_similarity),
            )
            .await?;
        let edges: Vec<_> = neighbors
            .into_iter()
            .filter(|n| n.id != id) // Exclude self
            .map(|n| {
                SimilarityEdge::new(id, n.id, n.distance)
                    .with_type(EdgeType::Similar)
            })
            .collect();
        Ok(edges)
    }
    /// Compute pairwise cosine similarities for a set of embeddings.
    ///
    /// Brute-force O(n^2) over all unordered pairs; each pair is emitted once
    /// as `(id_i, id_j, similarity)` with `i < j` in input order. Does not
    /// touch the index. Returns an empty list for fewer than two vectors.
    #[instrument(skip(self, vectors))]
    pub async fn compute_pairwise_similarities(
        &self,
        vectors: &[(EmbeddingId, Vec<f32>)],
    ) -> Result<Vec<(EmbeddingId, EmbeddingId, f32)>, VectorError> {
        if vectors.len() < 2 {
            return Ok(Vec::new());
        }
        // Validate all vectors
        for (_, vector) in vectors {
            self.validate_vector(vector)?;
        }
        // Normalize if needed
        let vectors: Vec<_> = if self.config.normalize {
            vectors
                .iter()
                .map(|(id, v)| (*id, normalize_vector(v)))
                .collect()
        } else {
            vectors.to_vec()
        };
        // Pre-size for n*(n-1)/2 unordered pairs.
        let mut similarities = Vec::with_capacity(vectors.len() * (vectors.len() - 1) / 2);
        for i in 0..vectors.len() {
            for j in (i + 1)..vectors.len() {
                let sim = cosine_similarity(&vectors[i].1, &vectors[j].1);
                similarities.push((vectors[i].0, vectors[j].0, sim));
            }
        }
        Ok(similarities)
    }
    /// Clear all embeddings from the index.
    pub async fn clear(&self) -> Result<(), VectorError> {
        let mut index = self.index.write();
        index.clear();
        info!("Cleared all embeddings from index");
        Ok(())
    }
    /// Save the index to a file (holds the read lock for the duration).
    pub async fn save(&self, path: &std::path::Path) -> Result<(), VectorError> {
        let index = self.index.read();
        index.save(path)?;
        info!(path = %path.display(), "Saved index to file");
        Ok(())
    }
    /// Load an index from a file.
    ///
    /// `config` is attached as-is; it is not checked against the loaded
    /// index's contents here.
    pub async fn load(path: &std::path::Path, config: HnswConfig) -> Result<Self, VectorError> {
        let index = HnswIndex::load(path)?;
        info!(path = %path.display(), "Loaded index from file");
        Ok(Self::from_index(index, config))
    }
    /// Get index statistics (snapshot under a read lock).
    // NOTE(review): `utilization` divides by `config.max_elements`; a zero
    // max_elements (only reachable by bypassing config validation) would
    // yield NaN/Inf here.
    pub fn stats(&self) -> IndexStatistics {
        let index = self.index.read();
        IndexStatistics {
            vector_count: index.len(),
            dimensions: self.config.dimensions,
            max_capacity: self.config.max_elements,
            utilization: index.len() as f64 / self.config.max_elements as f64,
        }
    }
    /// Validate a vector: length must match the configured dimensions and
    /// every component must be finite (no NaN, no +/-Inf).
    fn validate_vector(&self, vector: &[f32]) -> Result<(), VectorError> {
        if vector.len() != self.config.dimensions {
            return Err(VectorError::dimension_mismatch(
                self.config.dimensions,
                vector.len(),
            ));
        }
        // Check for NaN or Inf
        for (i, &v) in vector.iter().enumerate() {
            if v.is_nan() {
                return Err(VectorError::invalid_vector(format!(
                    "NaN value at index {i}"
                )));
            }
            if v.is_infinite() {
                return Err(VectorError::invalid_vector(format!(
                    "Infinite value at index {i}"
                )));
            }
        }
        Ok(())
    }
}
impl Clone for VectorSpaceService {
    /// Cloning is cheap: the underlying index is shared via an `Arc` bump,
    /// so all clones operate on the same index; only the configuration is
    /// deep-copied.
    fn clone(&self) -> Self {
        let index = Arc::clone(&self.index);
        let config = self.config.clone();
        Self { index, config }
    }
}
/// Statistics about the vector index, as returned by
/// `VectorSpaceService::stats` (a point-in-time snapshot).
#[derive(Debug, Clone)]
pub struct IndexStatistics {
    /// Number of vectors in the index.
    pub vector_count: usize,
    /// Dimensionality of vectors.
    pub dimensions: usize,
    /// Maximum capacity (from `HnswConfig::max_elements`).
    pub max_capacity: usize,
    /// Utilization ratio (`vector_count / max_capacity`, 0.0 - 1.0).
    pub utilization: f64,
}
/// Builder for `VectorSpaceService`.
///
/// Accumulates an `HnswConfig`; `build` validates it before constructing the
/// service.
pub struct VectorSpaceServiceBuilder {
    // Working configuration, starts at `HnswConfig::default()`.
    config: HnswConfig,
}
impl VectorSpaceServiceBuilder {
/// Create a new builder with default configuration.
pub fn new() -> Self {
Self {
config: HnswConfig::default(),
}
}
/// Set the dimensions.
pub fn dimensions(mut self, dim: usize) -> Self {
self.config.dimensions = dim;
self
}
/// Set the M parameter.
pub fn m(mut self, m: usize) -> Self {
self.config.m = m;
self
}
/// Set ef_construction.
pub fn ef_construction(mut self, ef: usize) -> Self {
self.config.ef_construction = ef;
self
}
/// Set ef_search.
pub fn ef_search(mut self, ef: usize) -> Self {
self.config.ef_search = ef;
self
}
/// Set max elements.
pub fn max_elements(mut self, max: usize) -> Self {
self.config.max_elements = max;
self
}
/// Enable or disable normalization.
pub fn normalize(mut self, normalize: bool) -> Self {
self.config.normalize = normalize;
self
}
/// Build the service.
pub fn build(self) -> Result<VectorSpaceService, VectorError> {
self.config.validate()?;
Ok(VectorSpaceService::new(self.config))
}
}
impl Default for VectorSpaceServiceBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Shared fixture: a small, non-normalizing 128-D index so raw vectors
    // round-trip through search unchanged.
    fn create_test_service() -> VectorSpaceService {
        let config = HnswConfig::for_dimension(128)
            .with_max_elements(1000)
            .with_normalize(false);
        VectorSpaceService::new(config)
    }
    // Inserting two vectors and querying with one of them should return both,
    // with the exact-match vector ranked first.
    #[tokio::test]
    async fn test_add_and_search() {
        let service = create_test_service();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let v1: Vec<f32> = (0..128).map(|i| i as f32 / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| (i as f32 + 1.0) / 128.0).collect();
        service.add_embedding(id1, v1.clone()).await.unwrap();
        service.add_embedding(id2, v2).await.unwrap();
        assert_eq!(service.len(), 2);
        let neighbors = service.find_neighbors(&v1, 2).await.unwrap();
        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0].id, id1);
    }
    // A 64-D vector against a 128-D index must be rejected before insertion.
    #[tokio::test]
    async fn test_dimension_mismatch() {
        let service = create_test_service();
        let id = EmbeddingId::new();
        let wrong_dim: Vec<f32> = vec![0.1; 64];
        let result = service.add_embedding(id, wrong_dim).await;
        assert!(matches!(
            result,
            Err(VectorError::DimensionMismatch { .. })
        ));
    }
    // Batch insert of 10 distinct vectors should report all 10 added.
    #[tokio::test]
    async fn test_batch_insert() {
        let service = create_test_service();
        let items: Vec<_> = (0..10)
            .map(|i| {
                let id = EmbeddingId::new();
                let vector: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 / 1280.0).collect();
                (id, vector)
            })
            .collect();
        let added = service.add_embeddings_batch(items).await.unwrap();
        assert_eq!(added, 10);
        assert_eq!(service.len(), 10);
    }
    // The filter predicate must be honored: only IDs passing the filter may
    // appear in the results.
    #[tokio::test]
    async fn test_search_with_filter() {
        let service = create_test_service();
        let ids: Vec<_> = (0..5).map(|_| EmbeddingId::new()).collect();
        for (i, id) in ids.iter().enumerate() {
            let vector: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 / 640.0).collect();
            service.add_embedding(*id, vector).await.unwrap();
        }
        let query: Vec<f32> = (0..128).map(|j| j as f32 / 640.0).collect();
        // Filter to only include odd indices
        let odd_ids: std::collections::HashSet<_> =
            ids.iter().enumerate().filter(|(i, _)| i % 2 == 1).map(|(_, id)| *id).collect();
        let neighbors = service
            .find_neighbors_with_filter(&query, 10, |id| odd_ids.contains(id))
            .await
            .unwrap();
        for n in &neighbors {
            assert!(odd_ids.contains(&n.id));
        }
    }
    // similarity = 1 - distance (distance clamped to [0, 1]).
    #[test]
    fn test_neighbor() {
        let neighbor = Neighbor::new(EmbeddingId::new(), 0.2, 0);
        assert!((neighbor.similarity - 0.8).abs() < 0.001);
        assert!(neighbor.is_above_threshold(0.7));
        assert!(!neighbor.is_above_threshold(0.9));
    }
    // Builder-style setters should populate the corresponding option fields.
    #[test]
    fn test_search_options() {
        let opts = SearchOptions::new(10)
            .with_min_similarity(0.8)
            .with_max_distance(0.3);
        assert_eq!(opts.k, 10);
        assert_eq!(opts.min_similarity, Some(0.8));
        assert_eq!(opts.max_distance, Some(0.3));
    }
    // Builder round-trip: the built service reflects the configured dimensions.
    #[test]
    fn test_builder() {
        let service = VectorSpaceServiceBuilder::new()
            .dimensions(256)
            .m(16)
            .ef_construction(100)
            .max_elements(5000)
            .build()
            .unwrap();
        assert_eq!(service.dimensions(), 256);
    }
}

View File

@@ -0,0 +1,375 @@
//! Distance metrics for vector similarity computation.
//!
//! This module provides optimized implementations of common distance metrics:
//! - Cosine distance/similarity
//! - Euclidean (L2) distance
//! - Dot product
//!
//! ## SIMD Optimization
//!
//! When the `simd` feature is enabled, these functions use SIMD intrinsics
//! for improved performance on supported architectures.
#![allow(dead_code)] // Utility functions for future use
/// Compute the cosine distance between two vectors.
///
/// Defined as `1 - cosine_similarity(a, b)`, yielding a value in [0, 2]:
/// - 0 = identical direction
/// - 1 = orthogonal
/// - 2 = opposite direction
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let similarity = cosine_similarity(a, b);
    1.0 - similarity
}
/// Compute the cosine similarity between two vectors.
///
/// Returns a value in [-1, 1]:
/// - 1 = identical direction
/// - 0 = orthogonal
/// - -1 = opposite direction
///
/// If either vector has (near-)zero norm the similarity is defined as 0.
/// The dot product and both norms are accumulated in a single pass.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut dot = 0.0_f32;
    let mut sum_sq_a = 0.0_f32;
    let mut sum_sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sum_sq_a += x * x;
        sum_sq_b += y * y;
    }
    let norm_a = sum_sq_a.sqrt();
    let norm_b = sum_sq_b.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Compute the Euclidean (L2) distance between two vectors.
///
/// Returns the straight-line distance in n-dimensional space (the square
/// root of the summed squared component differences).
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let squared: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
    squared.sqrt()
}
/// Compute the squared Euclidean distance.
///
/// Faster than `euclidean_distance` when only comparing relative distances,
/// since it skips the square root.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    #[cfg(feature = "simd")]
    {
        simd_squared_euclidean(a, b)
    }
    #[cfg(not(feature = "simd"))]
    {
        // Scalar fallback: accumulate (x - y)^2 over paired components.
        a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
    }
}
/// Compute the dot product of two vectors.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    #[cfg(feature = "simd")]
    {
        simd_dot_product(a, b)
    }
    #[cfg(not(feature = "simd"))]
    {
        // Scalar fallback: running sum of paired products.
        a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + x * y)
    }
}
/// Compute the L2 (Euclidean) norm of a vector: the square root of the sum
/// of squared components.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
/// Compute the L1 (Manhattan) norm of a vector: the sum of absolute
/// component values.
#[inline]
pub fn l1_norm(v: &[f32]) -> f32 {
    v.iter().fold(0.0_f32, |acc, x| acc + x.abs())
}
/// Normalize a vector to unit length (L2 normalization).
///
/// Returns a zero vector of the same length if the input has zero or
/// near-zero norm (< 1e-10), avoiding division blow-up.
#[inline]
pub fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm < 1e-10 {
        vec![0.0; v.len()]
    } else {
        v.iter().map(|x| x / norm).collect()
    }
}
/// Normalize a vector to unit length in place.
///
/// Zero or near-zero-norm inputs (< 1e-10) are overwritten with zeros.
#[inline]
pub fn normalize_vector_inplace(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm < 1e-10 {
        v.fill(0.0);
    } else {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
/// Compute the Manhattan (L1) distance between two vectors: the sum of
/// absolute component differences.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + (x - y).abs())
}
/// Compute the Chebyshev (L-infinity) distance between two vectors.
///
/// This is the maximum absolute difference along any single dimension.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn chebyshev_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut max_diff = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        max_diff = max_diff.max((x - y).abs());
    }
    max_diff
}
/// Compute angular distance (based on cosine similarity).
///
/// Returns the angle in radians between the two vectors, divided by pi so
/// the result lies in [0, 1] (0 = same direction, 1 = opposite).
#[inline]
pub fn angular_distance(a: &[f32], b: &[f32]) -> f32 {
    use std::f32::consts::PI;
    let cos = cosine_similarity(a, b);
    cos.acos() / PI
}
/// Add two vectors element-wise, returning a new vector.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn vector_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x + y);
    }
    out
}
/// Subtract two vectors element-wise (`a - b`), returning a new vector.
///
/// # Panics
/// Panics in debug mode if vectors have different lengths.
#[inline]
pub fn vector_sub(a: &[f32], b: &[f32]) -> Vec<f32> {
    debug_assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x - y);
    }
    out
}
/// Scale every component of a vector by `scalar`, returning a new vector.
#[inline]
pub fn vector_scale(v: &[f32], scalar: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(v.len());
    for &x in v {
        out.push(x * scalar);
    }
    out
}
/// Compute the centroid (component-wise average) of multiple vectors.
///
/// The dimensionality is taken from the first vector. Returns `None` for an
/// empty input slice.
///
/// # Panics
/// Panics in debug mode if any vector's length differs from the first's.
pub fn centroid(vectors: &[&[f32]]) -> Option<Vec<f32>> {
    let first = vectors.first()?;
    let dim = first.len();
    let mut sums = vec![0.0_f32; dim];
    for v in vectors {
        debug_assert_eq!(v.len(), dim, "Vector dimension mismatch");
        for (i, &x) in v.iter().enumerate() {
            sums[i] += x;
        }
    }
    let count = vectors.len() as f32;
    for s in sums.iter_mut() {
        *s /= count;
    }
    Some(sums)
}
/// Check whether a vector has (approximately) unit L2 norm.
///
/// Returns `true` when `| ||v|| - 1 | < tolerance`.
#[inline]
pub fn is_normalized(v: &[f32], tolerance: f32) -> bool {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    (norm - 1.0).abs() < tolerance
}
// SIMD implementations
#[cfg(feature = "simd")]
fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
    // Placeholder: currently a scalar loop; swap in std::arch intrinsics
    // (AVX/AVX2/AVX-512) here when a platform-specific path is needed.
    a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + x * y)
}
#[cfg(feature = "simd")]
fn simd_squared_euclidean(a: &[f32], b: &[f32]) -> f32 {
    // Placeholder: currently a scalar loop identical to the non-simd path.
    a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    // cos(v, v) = 1 for any non-zero v.
    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert_relative_eq!(sim, 1.0, epsilon = 1e-6);
    }
    // Anti-parallel vectors have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![-1.0, 0.0, 0.0];
        let sim = cosine_similarity(&v1, &v2);
        assert_relative_eq!(sim, -1.0, epsilon = 1e-6);
    }
    // Orthogonal vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&v1, &v2);
        assert_relative_eq!(sim, 0.0, epsilon = 1e-6);
    }
    // cosine_distance = 1 - similarity, so orthogonal vectors are at 1.
    #[test]
    fn test_cosine_distance() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];
        let dist = cosine_distance(&v1, &v2);
        assert_relative_eq!(dist, 1.0, epsilon = 1e-6);
    }
    // Classic 3-4-5 right triangle.
    #[test]
    fn test_euclidean_distance() {
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![3.0, 4.0, 0.0];
        let dist = euclidean_distance(&v1, &v2);
        assert_relative_eq!(dist, 5.0, epsilon = 1e-6);
    }
    // [3, 4] has norm 5; normalizing yields [0.6, 0.8] with unit norm.
    #[test]
    fn test_normalize_vector() {
        let v = vec![3.0, 4.0];
        let normalized = normalize_vector(&v);
        assert_relative_eq!(l2_norm(&normalized), 1.0, epsilon = 1e-6);
        assert_relative_eq!(normalized[0], 0.6, epsilon = 1e-6);
        assert_relative_eq!(normalized[1], 0.8, epsilon = 1e-6);
    }
    // Zero vectors must stay zero rather than produce NaN.
    #[test]
    fn test_normalize_zero_vector() {
        let v = vec![0.0, 0.0, 0.0];
        let normalized = normalize_vector(&v);
        assert!(normalized.iter().all(|&x| x == 0.0));
    }
    #[test]
    fn test_dot_product() {
        let v1 = vec![1.0, 2.0, 3.0];
        let v2 = vec![4.0, 5.0, 6.0];
        let dot = dot_product(&v1, &v2);
        assert_relative_eq!(dot, 32.0, epsilon = 1e-6); // 1*4 + 2*5 + 3*6 = 32
    }
    // L1 distance sums component differences: 3 + 4 = 7.
    #[test]
    fn test_manhattan_distance() {
        let v1 = vec![0.0, 0.0];
        let v2 = vec![3.0, 4.0];
        let dist = manhattan_distance(&v1, &v2);
        assert_relative_eq!(dist, 7.0, epsilon = 1e-6);
    }
    // L-infinity distance takes the max component difference: max(3, 4) = 4.
    #[test]
    fn test_chebyshev_distance() {
        let v1 = vec![0.0, 0.0];
        let v2 = vec![3.0, 4.0];
        let dist = chebyshev_distance(&v1, &v2);
        assert_relative_eq!(dist, 4.0, epsilon = 1e-6);
    }
    // Mean of collinear points is the middle point.
    #[test]
    fn test_centroid() {
        let v1 = vec![0.0, 0.0];
        let v2 = vec![2.0, 2.0];
        let v3 = vec![4.0, 4.0];
        let c = centroid(&[&v1, &v2, &v3]).unwrap();
        assert_relative_eq!(c[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(c[1], 2.0, epsilon = 1e-6);
    }
    #[test]
    fn test_is_normalized() {
        let v = normalize_vector(&[3.0, 4.0]);
        assert!(is_normalized(&v, 1e-6));
        let v2 = vec![1.0, 2.0, 3.0];
        assert!(!is_normalized(&v2, 1e-6));
    }
    // Element-wise add / sub / scale round-trip on small fixtures.
    #[test]
    fn test_vector_operations() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let sum = vector_add(&v1, &v2);
        assert_eq!(sum, vec![4.0, 6.0]);
        let diff = vector_sub(&v1, &v2);
        assert_eq!(diff, vec![-2.0, -2.0]);
        let scaled = vector_scale(&v1, 2.0);
        assert_eq!(scaled, vec![2.0, 4.0]);
    }
    #[test]
    fn test_angular_distance() {
        let v1 = vec![1.0, 0.0];
        let v2 = vec![0.0, 1.0];
        let dist = angular_distance(&v1, &v2);
        // 90 degrees = pi/2, normalized = 0.5
        assert_relative_eq!(dist, 0.5, epsilon = 1e-6);
    }
}

View File

@@ -0,0 +1,639 @@
//! Domain entities for the Vector Space bounded context.
//!
//! These are the core domain objects that represent the vector indexing domain.
use serde::{Deserialize, Serialize};
use std::fmt;
use uuid::Uuid;
/// A unique identifier for an embedding vector.
///
/// Newtype over a `Uuid` providing domain-specific semantics; `Copy` and
/// `Hash` so it can be used freely as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(Uuid);
impl EmbeddingId {
/// Create a new random embedding ID.
#[inline]
pub fn new() -> Self {
Self(Uuid::new_v4())
}
/// Create an embedding ID from a UUID.
#[inline]
pub const fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
/// Parse an embedding ID from a string.
pub fn parse(s: &str) -> Result<Self, uuid::Error> {
Ok(Self(Uuid::parse_str(s)?))
}
/// Get the inner UUID.
#[inline]
pub const fn as_uuid(&self) -> &Uuid {
&self.0
}
/// Convert to bytes for storage.
#[inline]
pub fn as_bytes(&self) -> &[u8; 16] {
self.0.as_bytes()
}
/// Create from bytes.
#[inline]
pub fn from_bytes(bytes: [u8; 16]) -> Self {
Self(Uuid::from_bytes(bytes))
}
}
impl Default for EmbeddingId {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for EmbeddingId {
    /// Render as the inner UUID's canonical hyphenated form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl From<Uuid> for EmbeddingId {
fn from(uuid: Uuid) -> Self {
Self(uuid)
}
}
impl From<EmbeddingId> for Uuid {
fn from(id: EmbeddingId) -> Self {
id.0
}
}
/// Unix timestamp in milliseconds (signed, so pre-epoch instants are
/// representable). Ordered and comparable via the derived traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Timestamp(i64);
impl Timestamp {
/// Create a timestamp for the current moment.
pub fn now() -> Self {
Self(chrono::Utc::now().timestamp_millis())
}
/// Create a timestamp from milliseconds since Unix epoch.
#[inline]
pub const fn from_millis(millis: i64) -> Self {
Self(millis)
}
/// Get milliseconds since Unix epoch.
#[inline]
pub const fn as_millis(&self) -> i64 {
self.0
}
/// Convert to chrono DateTime.
pub fn to_datetime(&self) -> chrono::DateTime<chrono::Utc> {
chrono::DateTime::from_timestamp_millis(self.0)
.unwrap_or_else(|| chrono::DateTime::UNIX_EPOCH)
}
}
impl Default for Timestamp {
fn default() -> Self {
Self::now()
}
}
impl fmt::Display for Timestamp {
    /// Human-readable UTC form with millisecond precision.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let formatted = self.to_datetime().format("%Y-%m-%d %H:%M:%S%.3f UTC");
        write!(f, "{formatted}")
    }
}
/// Configuration for the HNSW index.
///
/// These parameters control the trade-off between search accuracy,
/// index build time, and memory usage. Use `validate` before building an
/// index to catch inconsistent settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Number of bi-directional links per element.
    /// Higher values improve recall but increase memory.
    /// Recommended: 32 for 1536-dimensional vectors.
    pub m: usize,
    /// Size of dynamic candidate list during construction.
    /// Higher values improve index quality but slow construction.
    /// Recommended: 200 for high-quality indices.
    pub ef_construction: usize,
    /// Size of dynamic candidate list during search.
    /// Higher values improve recall but slow queries.
    /// Recommended: 128 for balanced accuracy/speed.
    pub ef_search: usize,
    /// Maximum number of elements the index can hold.
    /// Pre-allocating improves construction performance.
    pub max_elements: usize,
    /// Dimensionality of vectors in this index.
    pub dimensions: usize,
    /// Whether to normalize vectors before indexing.
    pub normalize: bool,
    /// Distance metric to use.
    pub distance_metric: DistanceMetric,
}
impl HnswConfig {
    /// Create a configuration tuned for vectors of dimension `dim`.
    ///
    /// High-dimensional vectors (>= 1024) get a denser graph (M = 32);
    /// everything else uses M = 16. Normalization and cosine distance are on
    /// by default.
    pub fn for_dimension(dim: usize) -> Self {
        let m = if dim >= 1024 { 32 } else { 16 };
        Self {
            m,
            ef_construction: 200,
            ef_search: 128,
            max_elements: 1_000_000,
            dimensions: dim,
            normalize: true,
            distance_metric: DistanceMetric::Cosine,
        }
    }

    /// Preset for OpenAI-style 1536-D embeddings.
    pub fn for_openai_embeddings() -> Self {
        Self::for_dimension(1536)
    }

    /// Preset for smaller sentence-transformer embeddings (384-D).
    pub fn for_sentence_transformers() -> Self {
        Self::for_dimension(384)
    }

    /// Builder: set the M parameter.
    pub fn with_m(self, m: usize) -> Self {
        Self { m, ..self }
    }

    /// Builder: set the ef_construction parameter.
    pub fn with_ef_construction(self, ef: usize) -> Self {
        Self {
            ef_construction: ef,
            ..self
        }
    }

    /// Builder: set the ef_search parameter.
    pub fn with_ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: ef,
            ..self
        }
    }

    /// Builder: set the maximum number of elements.
    pub fn with_max_elements(self, max: usize) -> Self {
        Self {
            max_elements: max,
            ..self
        }
    }

    /// Builder: set the distance metric.
    pub fn with_distance_metric(self, metric: DistanceMetric) -> Self {
        Self {
            distance_metric: metric,
            ..self
        }
    }

    /// Builder: enable or disable normalization.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Validate the configuration.
    ///
    /// Checks are reported in priority order: M, then ef_construction,
    /// then dimensions, then max_elements.
    pub fn validate(&self) -> Result<(), ConfigValidationError> {
        if self.m < 2 {
            Err(ConfigValidationError::InvalidM(self.m))
        } else if self.ef_construction < self.m {
            Err(ConfigValidationError::EfTooSmall {
                ef: self.ef_construction,
                m: self.m,
            })
        } else if self.dimensions == 0 {
            Err(ConfigValidationError::ZeroDimensions)
        } else if self.max_elements == 0 {
            Err(ConfigValidationError::ZeroMaxElements)
        } else {
            Ok(())
        }
    }
}
impl Default for HnswConfig {
fn default() -> Self {
Self::for_openai_embeddings()
}
}
/// Distance metric for vector similarity.
///
/// Serialized in snake_case (e.g. `"dot_product"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DistanceMetric {
    /// Cosine distance (1 - cosine_similarity).
    /// Best for normalized embeddings.
    Cosine,
    /// Euclidean (L2) distance.
    /// Best for spatial data.
    Euclidean,
    /// Dot product (negative for similarity ranking).
    /// Best for when vectors are already normalized.
    DotProduct,
    /// Poincaré distance in hyperbolic space.
    /// Best for hierarchical relationships.
    Poincare,
}
impl Default for DistanceMetric {
fn default() -> Self {
Self::Cosine
}
}
/// Configuration validation errors returned by `HnswConfig::validate`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ConfigValidationError {
    /// The HNSW graph needs at least 2 links per element.
    #[error("M parameter must be >= 2, got {0}")]
    InvalidM(usize),
    /// `ef_construction` must be at least M or construction quality degrades.
    #[error("ef_construction ({ef}) must be >= M ({m})")]
    EfTooSmall { ef: usize, m: usize },
    /// Zero-dimensional vectors are meaningless.
    #[error("dimensions cannot be zero")]
    ZeroDimensions,
    /// A zero-capacity index cannot hold anything.
    #[error("max_elements cannot be zero")]
    ZeroMaxElements,
}
/// Metadata about a vector index (descriptive record, not the index itself).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndex {
    /// Unique identifier for this index.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Number of dimensions per vector (copied from the config at creation).
    pub dimensions: usize,
    /// Current number of vectors in the index; updated via `update_size`.
    pub size: usize,
    /// Configuration used for this index.
    pub config: HnswConfig,
    /// When the index was created.
    pub created_at: Timestamp,
    /// When the index was last modified.
    pub updated_at: Timestamp,
    /// Optional free-form description.
    pub description: Option<String>,
}
impl VectorIndex {
    /// Create metadata for a freshly-built index.
    ///
    /// `size` starts at zero and `dimensions` is copied from the config;
    /// creation and modification timestamps are both set to now.
    pub fn new(id: impl Into<String>, name: impl Into<String>, config: HnswConfig) -> Self {
        let stamp = Timestamp::now();
        Self {
            id: id.into(),
            name: name.into(),
            dimensions: config.dimensions,
            size: 0,
            config,
            created_at: stamp,
            updated_at: stamp,
            description: None,
        }
    }
    /// Record a new element count and bump the modification timestamp.
    pub fn update_size(&mut self, size: usize) {
        self.updated_at = Timestamp::now();
        self.size = size;
    }
    /// Builder-style setter for the optional description.
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }
}
/// Type of relationship between embeddings.
///
/// Serialized in snake_case (e.g. `same_cluster`) via serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EdgeType {
    /// Embeddings are similar based on vector proximity.
    Similar,
    /// Embeddings are sequential (temporal ordering).
    Sequential,
    /// Embeddings belong to the same cluster.
    SameCluster,
    /// Embeddings are from the same source/recording.
    SameSource,
    /// Custom relationship type.
    Custom,
}
impl Default for EdgeType {
    fn default() -> Self {
        // Plain vector-proximity similarity is the default relationship.
        Self::Similar
    }
}
impl fmt::Display for EdgeType {
    /// Render the snake_case label for each variant (matches the serde names).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Similar => "similar",
            Self::Sequential => "sequential",
            Self::SameCluster => "same_cluster",
            Self::SameSource => "same_source",
            Self::Custom => "custom",
        };
        f.write_str(label)
    }
}
/// An edge in the similarity graph between embeddings.
///
/// `distance` is interpreted on a 0..=1 scale when converted to a
/// similarity score (values are clamped).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityEdge {
    /// Source embedding ID.
    pub from_id: EmbeddingId,
    /// Target embedding ID.
    pub to_id: EmbeddingId,
    /// Distance between the embeddings.
    pub distance: f32,
    /// Type of relationship.
    pub edge_type: EdgeType,
    /// When this edge was created.
    pub created_at: Timestamp,
    /// Optional weight for weighted graph operations.
    pub weight: Option<f32>,
    /// Optional metadata.
    pub metadata: Option<EdgeMetadata>,
}
impl SimilarityEdge {
    /// Create a plain similarity edge (type `Similar`, no weight/metadata).
    pub fn new(from_id: EmbeddingId, to_id: EmbeddingId, distance: f32) -> Self {
        Self {
            from_id,
            to_id,
            distance,
            edge_type: EdgeType::Similar,
            created_at: Timestamp::now(),
            weight: None,
            metadata: None,
        }
    }
    /// Create a sequential (temporal-ordering) edge; its distance is fixed at 0.
    pub fn sequential(from_id: EmbeddingId, to_id: EmbeddingId) -> Self {
        Self::new(from_id, to_id, 0.0).with_type(EdgeType::Sequential)
    }
    /// Builder: override the edge type.
    pub fn with_type(mut self, edge_type: EdgeType) -> Self {
        self.edge_type = edge_type;
        self
    }
    /// Builder: attach a weight for weighted-graph algorithms.
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = Some(weight);
        self
    }
    /// Builder: attach optional metadata.
    pub fn with_metadata(mut self, metadata: EdgeMetadata) -> Self {
        self.metadata = Some(metadata);
        self
    }
    /// Similarity on a 0..=1 scale: `1 - clamp(distance, 0, 1)`.
    /// Meaningful for cosine-style distances.
    #[inline]
    pub fn similarity(&self) -> f32 {
        let bounded = self.distance.clamp(0.0, 1.0);
        1.0 - bounded
    }
    /// True when the similarity meets or exceeds `threshold`.
    #[inline]
    pub fn is_strong(&self, threshold: f32) -> bool {
        self.similarity() >= threshold
    }
}
/// Optional metadata for edges.
///
/// All fields are optional/empty by default; use the builder methods to fill them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EdgeMetadata {
    /// Source of this relationship.
    pub source: Option<String>,
    /// Confidence score for this relationship.
    pub confidence: Option<f32>,
    /// Additional key-value pairs.
    pub attributes: hashbrown::HashMap<String, String>,
}
impl EdgeMetadata {
    /// Construct empty metadata (all fields unset).
    pub fn new() -> Self {
        Self::default()
    }
    /// Builder: record where this relationship came from.
    pub fn with_source(mut self, source: impl Into<String>) -> Self {
        self.source.replace(source.into());
        self
    }
    /// Builder: attach a confidence score.
    pub fn with_confidence(mut self, confidence: f32) -> Self {
        self.confidence.replace(confidence);
        self
    }
    /// Builder: insert a single key/value attribute.
    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.attributes.insert(key.into(), value.into());
        self
    }
}
/// A stored vector with its ID and optional metadata.
///
/// This is the unit of storage returned by repository lookups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredVector {
    /// Unique identifier.
    pub id: EmbeddingId,
    /// The vector data.
    pub vector: Vec<f32>,
    /// When this vector was stored.
    pub created_at: Timestamp,
    /// Optional metadata.
    pub metadata: Option<VectorMetadata>,
}
impl StoredVector {
    /// Wrap a raw vector with its ID; `created_at` is set to now, no metadata.
    pub fn new(id: EmbeddingId, vector: Vec<f32>) -> Self {
        let created_at = Timestamp::now();
        Self { id, vector, created_at, metadata: None }
    }
    /// Builder: attach optional metadata.
    pub fn with_metadata(mut self, metadata: VectorMetadata) -> Self {
        self.metadata = Some(metadata);
        self
    }
    /// Number of dimensions in the stored vector.
    #[inline]
    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }
}
/// Optional metadata for stored vectors.
///
/// All fields are optional/empty by default; use the builder methods to fill them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VectorMetadata {
    /// Source file or recording ID.
    pub source_id: Option<String>,
    /// Timestamp within the source (e.g. audio timestamp).
    pub source_timestamp: Option<f64>,
    /// Labels or tags.
    pub labels: Vec<String>,
    /// Additional key-value pairs (arbitrary JSON values).
    pub attributes: hashbrown::HashMap<String, serde_json::Value>,
}
impl VectorMetadata {
    /// Construct empty metadata (all fields unset).
    pub fn new() -> Self {
        Self::default()
    }
    /// Builder: record the source file/recording ID.
    pub fn with_source_id(mut self, id: impl Into<String>) -> Self {
        self.source_id.replace(id.into());
        self
    }
    /// Builder: record a timestamp within the source.
    pub fn with_source_timestamp(mut self, ts: f64) -> Self {
        self.source_timestamp.replace(ts);
        self
    }
    /// Builder: append one label/tag.
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.labels.push(label.into());
        self
    }
    /// Builder: insert one key/value attribute.
    pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.attributes.insert(key.into(), value);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // IDs are UUID-style: two fresh IDs must never collide.
    #[test]
    fn test_embedding_id_creation() {
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        assert_ne!(id1, id2);
    }
    // Round-trip through the string representation.
    #[test]
    fn test_embedding_id_parse() {
        let id = EmbeddingId::new();
        let s = id.to_string();
        let parsed = EmbeddingId::parse(&s).unwrap();
        assert_eq!(id, parsed);
    }
    #[test]
    fn test_hnsw_config_default() {
        let config = HnswConfig::default();
        assert_eq!(config.dimensions, 1536);
        assert_eq!(config.m, 32);
        assert!(config.validate().is_ok());
    }
    #[test]
    fn test_hnsw_config_validation() {
        let config = HnswConfig::default().with_m(1);
        assert!(config.validate().is_err());
        let config = HnswConfig::default().with_ef_construction(10);
        assert!(config.validate().is_err());
    }
    #[test]
    fn test_similarity_edge() {
        let from = EmbeddingId::new();
        let to = EmbeddingId::new();
        let edge = SimilarityEdge::new(from, to, 0.2);
        // Exact equality relies on 1.0 - 0.2 rounding exactly to 0.8 in f32.
        assert_eq!(edge.similarity(), 0.8);
        assert!(edge.is_strong(0.7));
        assert!(!edge.is_strong(0.9));
    }
    #[test]
    fn test_timestamp() {
        let ts1 = Timestamp::now();
        // Sleep to guarantee a measurable difference between the two stamps.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let ts2 = Timestamp::now();
        assert!(ts2 > ts1);
    }
}

View File

@@ -0,0 +1,193 @@
//! Error types for the Vector Space bounded context.
use std::path::PathBuf;
use thiserror::Error;
use super::entities::{ConfigValidationError, EmbeddingId};
/// Main error type for vector operations.
///
/// Construction helpers and classification predicates (`is_retriable`,
/// `is_not_found`) are provided in the inherent impl.
#[derive(Debug, Error)]
pub enum VectorError {
    /// Dimension mismatch between vector and index.
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch {
        /// Expected dimensions from index configuration.
        expected: usize,
        /// Actual dimensions of the provided vector.
        got: usize,
    },
    /// Vector with this ID already exists.
    #[error("Vector with ID {0} already exists")]
    DuplicateId(EmbeddingId),
    /// Vector with this ID was not found.
    #[error("Vector with ID {0} not found")]
    NotFound(EmbeddingId),
    /// Index capacity exceeded.
    #[error("Index capacity exceeded: max {max}, current {current}")]
    CapacityExceeded {
        /// Maximum capacity.
        max: usize,
        /// Current size.
        current: usize,
    },
    /// Invalid vector data (e.g., contains NaN or Inf).
    #[error("Invalid vector data: {0}")]
    InvalidVector(String),
    /// Configuration error (converted automatically via `#[from]`).
    #[error("Configuration error: {0}")]
    ConfigError(#[from] ConfigValidationError),
    /// Index is empty.
    #[error("Index is empty")]
    EmptyIndex,
    /// Serialization error.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// IO error during persistence (converted automatically via `#[from]`).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// File not found.
    #[error("File not found: {0}")]
    FileNotFound(PathBuf),
    /// Corrupted index file.
    #[error("Corrupted index file: {0}")]
    CorruptedFile(String),
    /// Lock acquisition failed.
    #[error("Failed to acquire lock: {0}")]
    LockError(String),
    /// Operation timeout.
    #[error("Operation timed out after {0}ms")]
    Timeout(u64),
    /// Index not initialized.
    #[error("Index not initialized")]
    NotInitialized,
    /// Concurrent modification detected.
    #[error("Concurrent modification detected")]
    ConcurrentModification,
    /// Graph operation error.
    #[error("Graph error: {0}")]
    GraphError(String),
    /// Search parameters invalid.
    #[error("Invalid search parameters: {0}")]
    InvalidSearchParams(String),
    /// Internal error (should not happen in normal operation).
    #[error("Internal error: {0}")]
    Internal(String),
}
impl VectorError {
    /// Shorthand for [`VectorError::DimensionMismatch`].
    pub fn dimension_mismatch(expected: usize, got: usize) -> Self {
        Self::DimensionMismatch { expected, got }
    }
    /// Shorthand for [`VectorError::CapacityExceeded`].
    pub fn capacity_exceeded(max: usize, current: usize) -> Self {
        Self::CapacityExceeded { max, current }
    }
    /// Shorthand for [`VectorError::InvalidVector`].
    pub fn invalid_vector(msg: impl Into<String>) -> Self { Self::InvalidVector(msg.into()) }
    /// Shorthand for [`VectorError::SerializationError`].
    pub fn serialization(msg: impl Into<String>) -> Self { Self::SerializationError(msg.into()) }
    /// Shorthand for [`VectorError::CorruptedFile`].
    pub fn corrupted(msg: impl Into<String>) -> Self { Self::CorruptedFile(msg.into()) }
    /// Shorthand for [`VectorError::LockError`].
    pub fn lock(msg: impl Into<String>) -> Self { Self::LockError(msg.into()) }
    /// Shorthand for [`VectorError::GraphError`].
    pub fn graph(msg: impl Into<String>) -> Self { Self::GraphError(msg.into()) }
    /// Shorthand for [`VectorError::InvalidSearchParams`].
    pub fn invalid_search(msg: impl Into<String>) -> Self { Self::InvalidSearchParams(msg.into()) }
    /// Shorthand for [`VectorError::Internal`].
    pub fn internal(msg: impl Into<String>) -> Self { Self::Internal(msg.into()) }
    /// Whether the operation may succeed if retried (transient conditions:
    /// lock contention, timeout, concurrent modification).
    pub fn is_retriable(&self) -> bool {
        matches!(
            self,
            Self::LockError(_) | Self::Timeout(_) | Self::ConcurrentModification
        )
    }
    /// Whether this error reports a missing vector or missing file.
    pub fn is_not_found(&self) -> bool {
        matches!(self, Self::NotFound(_) | Self::FileNotFound(_))
    }
}
impl From<bincode::Error> for VectorError {
    fn from(e: bincode::Error) -> Self {
        // Flatten bincode failures into the generic serialization variant,
        // keeping only the message.
        Self::SerializationError(e.to_string())
    }
}
impl From<serde_json::Error> for VectorError {
    fn from(e: serde_json::Error) -> Self {
        // Flatten serde_json failures into the generic serialization variant,
        // keeping only the message.
        Self::SerializationError(e.to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Display output should interpolate the dimension values and the ID.
    #[test]
    fn test_error_messages() {
        let err = VectorError::dimension_mismatch(1536, 768);
        assert!(err.to_string().contains("1536"));
        assert!(err.to_string().contains("768"));
        let err = VectorError::NotFound(EmbeddingId::new());
        assert!(err.to_string().contains("not found"));
    }
    #[test]
    fn test_is_retriable() {
        assert!(VectorError::lock("test").is_retriable());
        assert!(VectorError::Timeout(1000).is_retriable());
        assert!(!VectorError::EmptyIndex.is_retriable());
    }
    #[test]
    fn test_is_not_found() {
        assert!(VectorError::NotFound(EmbeddingId::new()).is_not_found());
        assert!(VectorError::FileNotFound(PathBuf::from("/test")).is_not_found());
        assert!(!VectorError::EmptyIndex.is_not_found());
    }
}

View File

@@ -0,0 +1,15 @@
//! Domain layer for the Vector Space bounded context.
//!
//! Contains:
//! - Entities: Core domain objects with identity
//! - Value Objects: Immutable objects defined by their attributes
//! - Repository Traits: Abstractions for persistence
//! - Domain Errors: Error types specific to this context
pub mod entities;
pub mod repository;
pub mod error;
pub use entities::*;
pub use repository::*;
pub use error::*;

View File

@@ -0,0 +1,328 @@
//! Repository traits for the Vector Space bounded context.
//!
//! These traits define the persistence abstractions that allow the domain
//! to remain independent of specific storage implementations.
use async_trait::async_trait;
use super::entities::{EmbeddingId, SimilarityEdge, EdgeType, StoredVector, VectorMetadata};
use super::error::VectorError;
/// Result type for repository operations.
pub type RepoResult<T> = Result<T, VectorError>;
/// Repository trait for vector index operations.
///
/// This trait abstracts the HNSW index storage, allowing for different
/// implementations (in-memory, file-backed, distributed).
#[async_trait]
pub trait VectorIndexRepository: Send + Sync {
    /// Insert a single vector into the index.
    ///
    /// # Arguments
    /// * `id` - Unique identifier for this vector
    /// * `vector` - The embedding vector data
    ///
    /// # Errors
    /// Returns error if the vector dimensions don't match the index configuration.
    async fn insert(&self, id: &EmbeddingId, vector: &[f32]) -> RepoResult<()>;
    /// Search for the k nearest neighbors to a query vector.
    ///
    /// # Arguments
    /// * `query` - The query vector
    /// * `k` - Number of neighbors to return
    ///
    /// # Returns
    /// A vector of (id, distance) tuples, sorted by ascending distance.
    async fn search(&self, query: &[f32], k: usize) -> RepoResult<Vec<(EmbeddingId, f32)>>;
    /// Insert multiple vectors in a batch.
    ///
    /// This is more efficient than multiple single inserts due to
    /// amortized locking and potential parallelization.
    ///
    /// # Arguments
    /// * `items` - Slice of (id, vector) pairs to insert
    async fn batch_insert(&self, items: &[(EmbeddingId, Vec<f32>)]) -> RepoResult<()>;
    /// Remove a vector from the index.
    ///
    /// # Arguments
    /// * `id` - The ID of the vector to remove
    ///
    /// # Note
    /// Not all HNSW implementations support efficient removal.
    /// Some may mark as deleted without reclaiming space.
    async fn remove(&self, id: &EmbeddingId) -> RepoResult<()>;
    /// Check if a vector exists in the index.
    async fn contains(&self, id: &EmbeddingId) -> RepoResult<bool>;
    /// Get the current number of vectors in the index.
    async fn len(&self) -> RepoResult<usize>;
    /// Check if the index is empty.
    ///
    /// Default implementation delegates to [`len`](Self::len).
    async fn is_empty(&self) -> RepoResult<bool> {
        Ok(self.len().await? == 0)
    }
    /// Clear all vectors from the index.
    async fn clear(&self) -> RepoResult<()>;
    /// Get the dimensionality of vectors in this index.
    /// Synchronous: the dimensionality is fixed configuration, not stored state.
    fn dimensions(&self) -> usize;
}
/// Extended repository trait with additional query capabilities.
///
/// Note: `search_with_filter` is generic over its predicate, so this trait
/// is not object-safe (cannot be used as `dyn VectorIndexRepositoryExt`).
#[async_trait]
pub trait VectorIndexRepositoryExt: VectorIndexRepository {
    /// Search with a filter predicate.
    ///
    /// # Arguments
    /// * `query` - The query vector
    /// * `k` - Number of neighbors to return
    /// * `filter` - Predicate that must return true for results to be included
    async fn search_with_filter<F>(
        &self,
        query: &[f32],
        k: usize,
        filter: F,
    ) -> RepoResult<Vec<(EmbeddingId, f32)>>
    where
        F: Fn(&EmbeddingId) -> bool + Send + Sync;
    /// Search within a distance threshold.
    ///
    /// Returns all vectors within the given distance, up to a maximum count.
    async fn search_within_radius(
        &self,
        query: &[f32],
        radius: f32,
        max_results: usize,
    ) -> RepoResult<Vec<(EmbeddingId, f32)>>;
    /// Get multiple vectors by their IDs.
    /// The result preserves input order; missing IDs yield `None`.
    async fn get_vectors(&self, ids: &[EmbeddingId]) -> RepoResult<Vec<Option<StoredVector>>>;
    /// Get a single vector by ID.
    async fn get_vector(&self, id: &EmbeddingId) -> RepoResult<Option<StoredVector>>;
    /// Update the metadata for a vector.
    async fn update_metadata(&self, id: &EmbeddingId, metadata: VectorMetadata) -> RepoResult<()>;
    /// List all vector IDs in the index (paginated via offset/limit).
    async fn list_ids(&self, offset: usize, limit: usize) -> RepoResult<Vec<EmbeddingId>>;
}
/// Repository trait for graph edge operations.
///
/// This manages the similarity graph between embeddings, supporting
/// graph-based queries and traversals.
#[async_trait]
pub trait GraphEdgeRepository: Send + Sync {
    /// Add an edge between two embeddings.
    async fn add_edge(&self, edge: SimilarityEdge) -> RepoResult<()>;
    /// Add multiple edges in a batch.
    async fn add_edges(&self, edges: &[SimilarityEdge]) -> RepoResult<()>;
    /// Remove an edge between two embeddings.
    async fn remove_edge(&self, from: &EmbeddingId, to: &EmbeddingId) -> RepoResult<()>;
    /// Get all edges from a given embedding (outgoing).
    async fn get_edges_from(&self, id: &EmbeddingId) -> RepoResult<Vec<SimilarityEdge>>;
    /// Get all edges to a given embedding (incoming).
    async fn get_edges_to(&self, id: &EmbeddingId) -> RepoResult<Vec<SimilarityEdge>>;
    /// Get edges of a specific type from an embedding.
    async fn get_edges_by_type(
        &self,
        id: &EmbeddingId,
        edge_type: EdgeType,
    ) -> RepoResult<Vec<SimilarityEdge>>;
    /// Find edges with similarity above a threshold.
    async fn get_strong_edges(
        &self,
        id: &EmbeddingId,
        min_similarity: f32,
    ) -> RepoResult<Vec<SimilarityEdge>>;
    /// Get the number of edges in the graph.
    async fn edge_count(&self) -> RepoResult<usize>;
    /// Clear all edges.
    async fn clear(&self) -> RepoResult<()>;
    /// Remove all edges connected to an embedding (incoming and outgoing).
    async fn remove_edges_for(&self, id: &EmbeddingId) -> RepoResult<()>;
}
/// Trait for graph traversal operations.
///
/// Builds on [`GraphEdgeRepository`] to offer whole-graph queries.
#[async_trait]
pub trait GraphTraversal: GraphEdgeRepository {
    /// Find the shortest path between two embeddings.
    /// Returns `None` when no path exists within `max_depth` hops.
    async fn shortest_path(
        &self,
        from: &EmbeddingId,
        to: &EmbeddingId,
        max_depth: usize,
    ) -> RepoResult<Option<Vec<EmbeddingId>>>;
    /// Find all embeddings within n hops of a given embedding.
    /// Each result pairs the embedding with its hop distance.
    async fn neighbors_within_hops(
        &self,
        id: &EmbeddingId,
        hops: usize,
    ) -> RepoResult<Vec<(EmbeddingId, usize)>>;
    /// Find connected components in the graph.
    async fn connected_components(&self) -> RepoResult<Vec<Vec<EmbeddingId>>>;
    /// Calculate PageRank-style centrality for embeddings.
    async fn centrality_scores(&self) -> RepoResult<Vec<(EmbeddingId, f32)>>;
}
/// Trait for persistence operations.
#[async_trait]
pub trait IndexPersistence: Send + Sync {
    /// Save the index to a file.
    async fn save(&self, path: &std::path::Path) -> RepoResult<()>;
    /// Load the index from a file.
    /// Associated constructor: returns a new instance rather than mutating self.
    async fn load(path: &std::path::Path) -> RepoResult<Self>
    where
        Self: Sized;
    /// Export the index to a portable format.
    async fn export(&self, path: &std::path::Path, format: ExportFormat) -> RepoResult<()>;
    /// Import vectors from a portable format.
    /// Returns the number of vectors imported.
    async fn import(&self, path: &std::path::Path, format: ExportFormat) -> RepoResult<usize>;
}
/// Export formats for index data.
///
/// Used by [`IndexPersistence::export`] and [`IndexPersistence::import`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// Binary format using bincode (fast, compact).
    Bincode,
    /// JSON format (readable, portable).
    Json,
    /// NumPy-compatible format for vector data.
    Numpy,
    /// CSV format for interoperability.
    Csv,
}
/// Statistics about an index, as returned by [`IndexMonitoring::stats`].
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Total number of vectors.
    pub vector_count: usize,
    /// Number of bytes used.
    pub memory_bytes: usize,
    /// Average search latency in microseconds.
    pub avg_search_latency_us: f64,
    /// Index build time in milliseconds.
    pub build_time_ms: u64,
    /// Dimensionality.
    pub dimensions: usize,
    /// Number of levels in the HNSW graph.
    pub levels: usize,
    /// Average connections per node.
    pub avg_connections: f64,
}
/// Trait for index statistics and monitoring.
#[async_trait]
pub trait IndexMonitoring: Send + Sync {
    /// Get current index statistics.
    async fn stats(&self) -> RepoResult<IndexStats>;
    /// Get memory usage breakdown.
    async fn memory_usage(&self) -> RepoResult<MemoryUsage>;
    /// Run a self-check to verify index integrity.
    /// A failed check is reported in the result, not as an `Err`.
    async fn verify(&self) -> RepoResult<VerificationResult>;
}
/// Memory usage breakdown, as returned by [`IndexMonitoring::memory_usage`].
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Memory used by vector data.
    pub vectors_bytes: usize,
    /// Memory used by the HNSW graph structure.
    pub graph_bytes: usize,
    /// Memory used by ID mappings.
    pub id_map_bytes: usize,
    /// Memory used by metadata.
    pub metadata_bytes: usize,
    /// Total memory.
    pub total_bytes: usize,
}
/// Result of index verification, as returned by [`IndexMonitoring::verify`].
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// Whether the index is valid.
    pub is_valid: bool,
    /// List of issues found.
    pub issues: Vec<String>,
    /// Number of orphaned nodes.
    pub orphaned_nodes: usize,
    /// Number of broken links.
    pub broken_links: usize,
}
impl VerificationResult {
    /// A clean result: valid, with no issues and zero counters.
    pub fn ok() -> Self {
        Self {
            is_valid: true,
            issues: Vec::new(),
            orphaned_nodes: 0,
            broken_links: 0,
        }
    }
    /// A failing result carrying the list of detected issues.
    /// Node/link counters default to zero.
    pub fn failed(issues: Vec<String>) -> Self {
        Self {
            is_valid: false,
            issues,
            ..Self::ok()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Both constructors should set `is_valid` consistently with their issues.
    #[test]
    fn test_verification_result() {
        let ok = VerificationResult::ok();
        assert!(ok.is_valid);
        assert!(ok.issues.is_empty());
        let failed = VerificationResult::failed(vec!["test issue".into()]);
        assert!(!failed.is_valid);
        assert_eq!(failed.issues.len(), 1);
    }
}

View File

@@ -0,0 +1,454 @@
//! Hyperbolic geometry operations for hierarchical embeddings.
//!
//! This module implements operations in the Poincare ball model of hyperbolic space,
//! which is particularly useful for representing hierarchical relationships
//! in embeddings (e.g., taxonomy trees, part-whole relationships).
//!
//! ## Poincare Ball Model
//!
//! The Poincare ball is the open unit ball B^n = {x in R^n : ||x|| < 1}
//! equipped with the Riemannian metric:
//!
//! g_x = (2 / (1 - ||x||^2))^2 * I
//!
//! This metric causes distances to grow exponentially near the boundary,
//! making it ideal for tree-like structures.
//!
//! ## Key Operations
//!
//! - `exp_map`: Project from tangent space to hyperbolic space
//! - `log_map`: Project from hyperbolic space to tangent space
//! - `mobius_add`: Gyrovector addition (parallel transport)
//! - `poincare_distance`: Geodesic distance on the manifold
#![allow(dead_code)] // Hyperbolic geometry utilities for future use
/// Default curvature for the Poincare ball model.
/// Negative curvature corresponds to hyperbolic space.
pub const DEFAULT_CURVATURE: f32 = -1.0;
/// Epsilon for numerical stability.
/// Used both as a divide-by-zero guard and as a "treat as zero" norm threshold.
const EPS: f32 = 1e-7;
/// Maximum norm to prevent points from reaching the boundary.
/// Norms are clamped / points projected so they stay strictly inside the unit ball.
const MAX_NORM: f32 = 1.0 - 1e-5;
/// Compute the Poincare distance between two points in the Poincare ball.
///
/// The geodesic distance in the Poincare ball model is:
///
/// d(u, v) = (1/sqrt(-c)) * arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2)))
///
/// where c is the (negative) curvature.
///
/// # Arguments
/// * `u` - First point in the Poincare ball
/// * `v` - Second point in the Poincare ball
/// * `curvature` - Curvature of the space (negative for hyperbolic)
///
/// # Returns
/// The geodesic distance between u and v.
pub fn poincare_distance(u: &[f32], v: &[f32], curvature: f32) -> f32 {
    debug_assert_eq!(u.len(), v.len(), "Vector length mismatch");
    debug_assert!(curvature < 0.0, "Curvature must be negative for hyperbolic space");
    let sqrt_c = (-curvature).sqrt();
    // Clamp squared norms so both points lie strictly inside the ball,
    // keeping the denominator positive.
    let norm_u_sq = squared_norm(u).min(MAX_NORM * MAX_NORM);
    let norm_v_sq = squared_norm(v).min(MAX_NORM * MAX_NORM);
    let diff_sq = squared_distance(u, v);
    let denominator = (1.0 - norm_u_sq) * (1.0 - norm_v_sq);
    let argument = 1.0 + 2.0 * diff_sq / (denominator + EPS);
    // `argument` is mathematically >= 1; clamp before `acosh` to guard
    // against rounding dipping it just below the domain boundary.
    // Using the std `acosh` replaces the hand-rolled ln(x + sqrt(x^2 - 1))
    // form, which loses precision for arguments near 1.
    argument.max(1.0).acosh() / sqrt_c
}
/// Exponential map: project from tangent space at origin to the Poincare ball.
///
/// Maps a Euclidean vector v from the tangent space T_0 B^n at the origin
/// to a point on the Poincare ball:
///
/// exp_0(v) = tanh(sqrt(-c) * ||v|| / 2) * v / (sqrt(-c) * ||v||)
///
/// # Arguments
/// * `v` - Vector in tangent space
/// * `curvature` - Curvature of the space (negative for hyperbolic)
///
/// # Returns
/// Point in the Poincare ball; a (near-)zero tangent vector maps to the origin.
pub fn exp_map(v: &[f32], curvature: f32) -> Vec<f32> {
    let sqrt_c = (-curvature).sqrt();
    let norm = l2_norm(v);
    if norm < EPS {
        // Degenerate tangent vector: maps to the ball's origin.
        vec![0.0; v.len()]
    } else {
        let scale = (sqrt_c * norm / 2.0).tanh() / (sqrt_c * norm);
        v.iter().map(|&component| component * scale).collect()
    }
}
/// Exponential map from an arbitrary base point.
///
/// exp_x(v) = mobius_add(x, exp_0(v), c)
///
/// # Arguments
/// * `x` - Base point in the Poincare ball
/// * `v` - Vector in tangent space at x
/// * `curvature` - Curvature of the space
pub fn exp_map_at(x: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
    // Map v through the origin's exponential map, then transport to x.
    mobius_add(x, &exp_map(v, curvature), curvature)
}
/// Logarithmic map: project from Poincare ball to tangent space at origin.
///
/// Inverse of the exponential map:
///
/// log_0(y) = (2 / sqrt(-c)) * arctanh(sqrt(-c) * ||y||) * y / ||y||
///
/// # Arguments
/// * `y` - Point in the Poincare ball
/// * `curvature` - Curvature of the space (negative for hyperbolic)
///
/// # Returns
/// Vector in tangent space at origin; the ball's origin maps to the zero vector.
pub fn log_map(y: &[f32], curvature: f32) -> Vec<f32> {
    let sqrt_c = (-curvature).sqrt();
    // Clamp the norm so `atanh` stays inside its domain even for points
    // numerically at/outside the boundary.
    let norm = l2_norm(y).min(MAX_NORM);
    if norm < EPS {
        vec![0.0; y.len()]
    } else {
        let scale = (2.0 / sqrt_c) * (sqrt_c * norm).atanh() / norm;
        y.iter().map(|&component| component * scale).collect()
    }
}
/// Logarithmic map from an arbitrary base point.
///
/// log_x(y) = log_0(mobius_add(-x, y, c))
///
/// # Arguments
/// * `x` - Base point in the Poincare ball
/// * `y` - Target point in the Poincare ball
/// * `curvature` - Curvature of the space
pub fn log_map_at(x: &[f32], y: &[f32], curvature: f32) -> Vec<f32> {
    // Translate so that x becomes the origin, then take the origin log map.
    let minus_x: Vec<f32> = x.iter().map(|&c| -c).collect();
    log_map(&mobius_add(&minus_x, y, curvature), curvature)
}
/// Mobius addition (gyrovector addition).
///
/// The Mobius addition is the binary operation in the Poincare ball
/// that generalizes vector addition. It can be seen as parallel transport
/// followed by addition:
///
/// u ⊕ v = ((1 + 2c<u,v> + c||v||^2)u + (1 - c||u||^2)v) /
///         (1 + 2c<u,v> + c^2||u||^2||v||^2)
///
/// # Arguments
/// * `u` - First point in the Poincare ball
/// * `v` - Second point in the Poincare ball
/// * `curvature` - Curvature of the space (negative for hyperbolic)
///
/// # Returns
/// Result of Mobius addition u ⊕ v, projected back into the ball if needed.
pub fn mobius_add(u: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
    debug_assert_eq!(u.len(), v.len(), "Vector length mismatch");
    let c = -curvature;
    let uu = squared_norm(u);
    let vv = squared_norm(v);
    let uv = dot_product(u, v);
    let coef_u = 1.0 + 2.0 * c * uv + c * vv;
    let coef_v = 1.0 - c * uu;
    // EPS keeps the division well-defined when the denominator is tiny.
    let denom = 1.0 + 2.0 * c * uv + c * c * uu * vv;
    let mut out: Vec<f32> = u
        .iter()
        .zip(v.iter())
        .map(|(&a, &b)| (coef_u * a + coef_v * b) / (denom + EPS))
        .collect();
    // Numerical error can push the sum outside the ball; pull it back in.
    project_to_ball(&mut out);
    out
}
/// Mobius scalar multiplication.
///
/// r ⊗ x = (1/sqrt(c)) * tanh(r * arctanh(sqrt(c) * ||x||)) * x / ||x||
///
/// # Arguments
/// * `r` - Scalar multiplier
/// * `x` - Point in the Poincare ball
/// * `curvature` - Curvature of the space
pub fn mobius_scalar_mul(r: f32, x: &[f32], curvature: f32) -> Vec<f32> {
    let sqrt_c = (-curvature).sqrt();
    // Clamp so `atanh` stays in its domain for boundary points.
    let norm = l2_norm(x).min(MAX_NORM);
    if norm < EPS {
        // Scaling the origin yields the origin.
        return vec![0.0; x.len()];
    }
    let scale = (r * (sqrt_c * norm).atanh()).tanh() / (sqrt_c * norm);
    x.iter().map(|&component| component * scale).collect()
}
/// Compute the hyperbolic midpoint of two points.
///
/// The midpoint is the point on the geodesic between u and v
/// that is equidistant from both.
pub fn hyperbolic_midpoint(u: &[f32], v: &[f32], curvature: f32) -> Vec<f32> {
    // Take half of the tangent-space displacement from u toward v,
    // then map it back onto the ball at u.
    let half_step: Vec<f32> = log_map_at(u, v, curvature)
        .into_iter()
        .map(|x| x * 0.5)
        .collect();
    exp_map_at(u, &half_step, curvature)
}
/// Compute the hyperbolic centroid of multiple points.
///
/// This is the Einstein (Frechet) mean in hyperbolic space, approximated
/// by a fixed number of tangent-space averaging steps.
///
/// Returns `None` for an empty input.
pub fn hyperbolic_centroid(points: &[&[f32]], curvature: f32) -> Option<Vec<f32>> {
    let dim = points.first()?.len();
    let count = points.len() as f32;
    // Seed with the Euclidean mean, projected into the ball.
    let mut centroid = vec![0.0; dim];
    for p in points {
        for (acc, &x) in centroid.iter_mut().zip(p.iter()) {
            *acc += x;
        }
    }
    for acc in centroid.iter_mut() {
        *acc /= count;
    }
    project_to_ball(&mut centroid);
    // Refine with fixed-point iterations: average the tangent-space
    // directions toward all points and step along the mean direction.
    // (Simplified scheme — Riemannian gradient descent would be more accurate.)
    for _ in 0..10 {
        let mut mean_dir = vec![0.0; dim];
        for p in points {
            let toward_p = log_map_at(&centroid, p, curvature);
            for (acc, &x) in mean_dir.iter_mut().zip(toward_p.iter()) {
                *acc += x;
            }
        }
        for acc in mean_dir.iter_mut() {
            *acc /= count;
        }
        centroid = exp_map_at(&centroid, &mean_dir, curvature);
    }
    Some(centroid)
}
/// Convert a Euclidean embedding to a Poincare ball embedding.
///
/// Uses the exponential map at the origin; inverse of [`poincare_to_euclidean`].
pub fn euclidean_to_poincare(euclidean: &[f32], curvature: f32) -> Vec<f32> {
    exp_map(euclidean, curvature)
}
/// Convert a Poincare ball embedding to Euclidean space.
///
/// Uses the logarithmic map at the origin; inverse of [`euclidean_to_poincare`].
pub fn poincare_to_euclidean(poincare: &[f32], curvature: f32) -> Vec<f32> {
    log_map(poincare, curvature)
}
/// Project a point into the Poincare ball if it lies outside.
///
/// Rescales `v` in place so its norm is at most `MAX_NORM`
/// (strictly inside the unit ball); shorter vectors are untouched.
fn project_to_ball(v: &mut [f32]) {
    let norm = l2_norm(v);
    if norm >= MAX_NORM {
        let shrink = MAX_NORM / norm;
        v.iter_mut().for_each(|x| *x *= shrink);
    }
}
/// Compute the squared L2 norm of a vector (sum of squared components).
#[inline]
fn squared_norm(v: &[f32]) -> f32 {
    v.iter().fold(0.0, |acc, &x| acc + x * x)
}
/// Compute the L2 norm (Euclidean length) of a vector.
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
/// Compute the squared Euclidean distance between two vectors
/// (sum of squared component differences).
#[inline]
fn squared_distance(u: &[f32], v: &[f32]) -> f32 {
    u.iter().zip(v.iter()).fold(0.0, |acc, (a, b)| {
        let delta = a - b;
        acc + delta * delta
    })
}
/// Compute the dot product of two vectors.
#[inline]
fn dot_product(u: &[f32], v: &[f32]) -> f32 {
    u.iter().zip(v.iter()).fold(0.0, |acc, (a, b)| acc + a * b)
}
/// Conformal factor at a point (metric scaling).
///
/// The conformal factor lambda(x) = 2 / (1 - ||x||^2)
/// determines how much distances are scaled at point x.
/// The squared norm is clamped so the factor stays finite at the boundary.
pub fn conformal_factor(x: &[f32]) -> f32 {
    let clamped = squared_norm(x).min(MAX_NORM * MAX_NORM);
    2.0 / (1.0 - clamped)
}
/// Check if a point is strictly inside the (open) Poincare unit ball,
/// i.e. ||x||^2 < 1. Boundary points return `false`.
pub fn is_in_ball(x: &[f32]) -> bool {
    squared_norm(x) < 1.0
}
/// Compute hyperbolic angle between vectors in tangent space.
///
/// Returns the angle in radians in [0, pi]; degenerate (near-zero)
/// vectors yield an angle of 0.
pub fn hyperbolic_angle(u: &[f32], v: &[f32]) -> f32 {
    let (len_u, len_v) = (l2_norm(u), l2_norm(v));
    if len_u < EPS || len_v < EPS {
        return 0.0;
    }
    // Clamp the cosine to guard acos against rounding just outside [-1, 1].
    (dot_product(u, v) / (len_u * len_v)).clamp(-1.0, 1.0).acos()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    // Geodesic distance from a point to itself must be zero.
    #[test]
    fn test_poincare_distance_same_point() {
        let u = vec![0.1, 0.2, 0.3];
        let dist = poincare_distance(&u, &u, DEFAULT_CURVATURE);
        assert_relative_eq!(dist, 0.0, epsilon = 1e-5);
    }
    #[test]
    fn test_poincare_distance_origin() {
        let origin = vec![0.0, 0.0, 0.0];
        let v = vec![0.5, 0.0, 0.0];
        let dist = poincare_distance(&origin, &v, DEFAULT_CURVATURE);
        assert!(dist > 0.0);
    }
    // log_0 should invert exp_0 up to floating-point tolerance.
    #[test]
    fn test_exp_log_inverse() {
        let v = vec![0.5, 0.3, 0.1];
        let exp_v = exp_map(&v, DEFAULT_CURVATURE);
        let log_exp_v = log_map(&exp_v, DEFAULT_CURVATURE);
        for (a, b) in v.iter().zip(log_exp_v.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-4);
        }
    }
    // The origin is the identity element of Mobius addition.
    #[test]
    fn test_mobius_add_zero() {
        let u = vec![0.1, 0.2, 0.3];
        let zero = vec![0.0, 0.0, 0.0];
        let result = mobius_add(&u, &zero, DEFAULT_CURVATURE);
        for (a, b) in u.iter().zip(result.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-5);
        }
    }
    #[test]
    fn test_mobius_add_stays_in_ball() {
        let u = vec![0.8, 0.0, 0.0];
        let v = vec![0.0, 0.8, 0.0];
        let result = mobius_add(&u, &v, DEFAULT_CURVATURE);
        let norm = l2_norm(&result);
        assert!(norm < 1.0);
    }
    #[test]
    fn test_hyperbolic_midpoint() {
        let u = vec![0.1, 0.0, 0.0];
        let v = vec![0.5, 0.0, 0.0];
        let mid = hyperbolic_midpoint(&u, &v, DEFAULT_CURVATURE);
        // Midpoint should be between u and v
        assert!(mid[0] > u[0] && mid[0] < v[0]);
        // Distances should be approximately equal
        let dist_u = poincare_distance(&u, &mid, DEFAULT_CURVATURE);
        let dist_v = poincare_distance(&v, &mid, DEFAULT_CURVATURE);
        assert_relative_eq!(dist_u, dist_v, epsilon = 1e-3);
    }
    // Round-trip Euclidean -> Poincare -> Euclidean must land inside the
    // ball and approximately recover the input.
    #[test]
    fn test_euclidean_poincare_conversion() {
        let euclidean = vec![0.3, 0.2, 0.1];
        let poincare = euclidean_to_poincare(&euclidean, DEFAULT_CURVATURE);
        assert!(is_in_ball(&poincare));
        let back = poincare_to_euclidean(&poincare, DEFAULT_CURVATURE);
        for (a, b) in euclidean.iter().zip(back.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-4);
        }
    }
    // lambda(0) = 2 exactly; the factor blows up near the boundary.
    #[test]
    fn test_conformal_factor() {
        let origin = vec![0.0, 0.0, 0.0];
        assert_relative_eq!(conformal_factor(&origin), 2.0, epsilon = 1e-5);
        // Near boundary, factor should be large
        let near_boundary = vec![0.99, 0.0, 0.0];
        assert!(conformal_factor(&near_boundary) > 10.0);
    }
    #[test]
    fn test_is_in_ball() {
        assert!(is_in_ball(&[0.0, 0.0, 0.0]));
        assert!(is_in_ball(&[0.5, 0.5, 0.0]));
        assert!(!is_in_ball(&[1.0, 0.0, 0.0]));
        assert!(!is_in_ball(&[0.6, 0.6, 0.6])); // norm > 1
    }
}

View File

@@ -0,0 +1,653 @@
//! Graph storage for similarity relationships between embeddings.
//!
//! This module provides storage and querying for the similarity graph,
//! supporting edge types like SIMILAR, SEQUENTIAL, and SAME_CLUSTER.
use std::collections::{HashMap, HashSet, VecDeque};
use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument};
use crate::domain::{
EdgeType, EmbeddingId, GraphEdgeRepository, GraphTraversal, SimilarityEdge, VectorError,
};
use crate::domain::repository::RepoResult;
/// In-memory graph store for similarity edges.
///
/// This implementation uses adjacency lists for efficient edge traversal
/// and supports bidirectional lookups.
///
/// Invariant: every edge is stored twice — once in `forward` keyed by its
/// source and once in `reverse` keyed by its target — so both outgoing and
/// incoming edges can be listed without a full scan. Writers in this module
/// always acquire the locks in `forward`, `reverse`, `count` order, which
/// avoids lock-order inversions.
#[derive(Debug, Default)]
pub struct InMemoryGraphStore {
    /// Forward edges: from_id -> list of edges
    forward: RwLock<HashMap<EmbeddingId, Vec<SimilarityEdge>>>,
    /// Reverse edges: to_id -> list of edges
    reverse: RwLock<HashMap<EmbeddingId, Vec<SimilarityEdge>>>,
    /// Total edge count
    /// (maintained incrementally; one logical edge = one count unit)
    count: RwLock<usize>,
}
impl InMemoryGraphStore {
/// Create a new empty graph store.
pub fn new() -> Self {
Self::default()
}
/// Get the number of edges.
pub fn len(&self) -> usize {
*self.count.read()
}
/// Check if the store is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Get all unique node IDs in the graph.
pub fn node_ids(&self) -> Vec<EmbeddingId> {
let forward = self.forward.read();
let reverse = self.reverse.read();
let mut ids: HashSet<EmbeddingId> = forward.keys().copied().collect();
ids.extend(reverse.keys().copied());
ids.into_iter().collect()
}
/// Get the degree (number of connections) for a node.
pub fn degree(&self, id: &EmbeddingId) -> usize {
let forward = self.forward.read();
let reverse = self.reverse.read();
let out_degree = forward.get(id).map(|e| e.len()).unwrap_or(0);
let in_degree = reverse.get(id).map(|e| e.len()).unwrap_or(0);
out_degree + in_degree
}
/// Export the graph for serialization.
pub fn export(&self) -> GraphExport {
let forward = self.forward.read();
let edges: Vec<_> = forward.values().flatten().cloned().collect();
GraphExport { edges }
}
/// Import a graph from serialized data.
pub fn import(&self, data: GraphExport) -> RepoResult<()> {
let mut forward = self.forward.write();
let mut reverse = self.reverse.write();
let mut count = self.count.write();
forward.clear();
reverse.clear();
*count = 0;
for edge in data.edges {
forward
.entry(edge.from_id)
.or_default()
.push(edge.clone());
reverse
.entry(edge.to_id)
.or_default()
.push(edge);
*count += 1;
}
Ok(())
}
}
#[async_trait]
impl GraphEdgeRepository for InMemoryGraphStore {
    /// Insert a single edge into both adjacency maps.
    ///
    /// Duplicate (from, to) pairs are not deduplicated; each call adds a
    /// parallel edge and increments the count.
    #[instrument(skip(self, edge))]
    async fn add_edge(&self, edge: SimilarityEdge) -> RepoResult<()> {
        let mut forward = self.forward.write();
        let mut reverse = self.reverse.write();
        let mut count = self.count.write();
        forward
            .entry(edge.from_id)
            .or_default()
            .push(edge.clone());
        reverse
            .entry(edge.to_id)
            .or_default()
            .push(edge);
        *count += 1;
        debug!("Added edge, total count: {}", *count);
        Ok(())
    }
    /// Insert a batch of edges under a single lock acquisition.
    #[instrument(skip(self, edges), fields(count = edges.len()))]
    async fn add_edges(&self, edges: &[SimilarityEdge]) -> RepoResult<()> {
        let mut forward = self.forward.write();
        let mut reverse = self.reverse.write();
        let mut count = self.count.write();
        for edge in edges {
            forward
                .entry(edge.from_id)
                .or_default()
                .push(edge.clone());
            reverse
                .entry(edge.to_id)
                .or_default()
                .push(edge.clone());
            *count += 1;
        }
        debug!("Added {} edges, total count: {}", edges.len(), *count);
        Ok(())
    }
    /// Remove all edges from `from` to `to` (parallel edges included).
    ///
    /// Fix: `retain` drops *every* matching parallel edge, so the counter
    /// must be decremented by the number actually removed — the previous
    /// implementation decremented by exactly 1 and let `count` drift when
    /// duplicate edges existed.
    async fn remove_edge(&self, from: &EmbeddingId, to: &EmbeddingId) -> RepoResult<()> {
        let mut forward = self.forward.write();
        let mut reverse = self.reverse.write();
        let mut count = self.count.write();
        let mut removed = 0usize;
        if let Some(edges) = forward.get_mut(from) {
            let len_before = edges.len();
            edges.retain(|e| &e.to_id != to);
            removed = len_before - edges.len();
            if edges.is_empty() {
                forward.remove(from);
            }
        }
        if let Some(edges) = reverse.get_mut(to) {
            edges.retain(|e| &e.from_id != from);
            if edges.is_empty() {
                reverse.remove(to);
            }
        }
        *count = count.saturating_sub(removed);
        Ok(())
    }
    /// All outgoing edges of `id` (empty vec if the node is unknown).
    async fn get_edges_from(&self, id: &EmbeddingId) -> RepoResult<Vec<SimilarityEdge>> {
        let forward = self.forward.read();
        Ok(forward.get(id).cloned().unwrap_or_default())
    }
    /// All incoming edges of `id` (empty vec if the node is unknown).
    async fn get_edges_to(&self, id: &EmbeddingId) -> RepoResult<Vec<SimilarityEdge>> {
        let reverse = self.reverse.read();
        Ok(reverse.get(id).cloned().unwrap_or_default())
    }
    /// Outgoing edges of `id` filtered by edge type.
    async fn get_edges_by_type(
        &self,
        id: &EmbeddingId,
        edge_type: EdgeType,
    ) -> RepoResult<Vec<SimilarityEdge>> {
        let forward = self.forward.read();
        Ok(forward
            .get(id)
            .map(|edges| {
                edges
                    .iter()
                    .filter(|e| e.edge_type == edge_type)
                    .cloned()
                    .collect()
            })
            .unwrap_or_default())
    }
    /// Outgoing edges of `id` whose similarity meets `min_similarity`.
    async fn get_strong_edges(
        &self,
        id: &EmbeddingId,
        min_similarity: f32,
    ) -> RepoResult<Vec<SimilarityEdge>> {
        let forward = self.forward.read();
        Ok(forward
            .get(id)
            .map(|edges| {
                edges
                    .iter()
                    .filter(|e| e.similarity() >= min_similarity)
                    .cloned()
                    .collect()
            })
            .unwrap_or_default())
    }
    /// Total number of edges currently stored.
    async fn edge_count(&self) -> RepoResult<usize> {
        Ok(*self.count.read())
    }
    /// Remove every edge and reset the count to zero.
    async fn clear(&self) -> RepoResult<()> {
        self.forward.write().clear();
        self.reverse.write().clear();
        *self.count.write() = 0;
        Ok(())
    }
    /// Remove all edges touching `id`, in either direction, and scrub the
    /// dangling references from the opposite adjacency map.
    async fn remove_edges_for(&self, id: &EmbeddingId) -> RepoResult<()> {
        let mut forward = self.forward.write();
        let mut reverse = self.reverse.write();
        let mut count = self.count.write();
        // Remove outgoing edges
        if let Some(edges) = forward.remove(id) {
            *count = count.saturating_sub(edges.len());
            // Clean up reverse references
            for edge in edges {
                if let Some(rev_edges) = reverse.get_mut(&edge.to_id) {
                    rev_edges.retain(|e| &e.from_id != id);
                    if rev_edges.is_empty() {
                        reverse.remove(&edge.to_id);
                    }
                }
            }
        }
        // Remove incoming edges
        if let Some(edges) = reverse.remove(id) {
            *count = count.saturating_sub(edges.len());
            // Clean up forward references
            for edge in edges {
                if let Some(fwd_edges) = forward.get_mut(&edge.from_id) {
                    fwd_edges.retain(|e| &e.to_id != id);
                    if fwd_edges.is_empty() {
                        forward.remove(&edge.from_id);
                    }
                }
            }
        }
        Ok(())
    }
}
#[async_trait]
impl GraphTraversal for InMemoryGraphStore {
    /// Breadth-first search over *outgoing* edges from `from` to `to`.
    ///
    /// Returns the first (fewest-hop) path found, including both endpoints.
    /// Partial paths longer than `max_depth` nodes are pruned rather than
    /// expanded, so a returned path has at most `max_depth + 1` nodes.
    /// Returns `Ok(None)` when no path exists within the limit.
    async fn shortest_path(
        &self,
        from: &EmbeddingId,
        to: &EmbeddingId,
        max_depth: usize,
    ) -> RepoResult<Option<Vec<EmbeddingId>>> {
        if from == to {
            return Ok(Some(vec![*from]));
        }
        let forward = self.forward.read();
        // BFS
        let mut visited: HashSet<EmbeddingId> = HashSet::new();
        let mut queue: VecDeque<(EmbeddingId, Vec<EmbeddingId>)> = VecDeque::new();
        visited.insert(*from);
        queue.push_back((*from, vec![*from]));
        while let Some((current, path)) = queue.pop_front() {
            if path.len() > max_depth {
                continue;
            }
            if let Some(edges) = forward.get(&current) {
                for edge in edges {
                    // Early exit: the target may be reached before it is
                    // ever enqueued, which is why it is checked first.
                    if &edge.to_id == to {
                        let mut result = path.clone();
                        result.push(edge.to_id);
                        return Ok(Some(result));
                    }
                    // Mark visited at enqueue time so each node is expanded
                    // at most once.
                    if !visited.contains(&edge.to_id) {
                        visited.insert(edge.to_id);
                        let mut new_path = path.clone();
                        new_path.push(edge.to_id);
                        queue.push_back((edge.to_id, new_path));
                    }
                }
            }
        }
        Ok(None)
    }
    /// All nodes reachable from `id` within `hops` steps along *outgoing*
    /// edges, each paired with its BFS depth (1..=hops). The starting node
    /// itself is excluded from the result.
    async fn neighbors_within_hops(
        &self,
        id: &EmbeddingId,
        hops: usize,
    ) -> RepoResult<Vec<(EmbeddingId, usize)>> {
        let forward = self.forward.read();
        let mut visited: HashMap<EmbeddingId, usize> = HashMap::new();
        let mut queue: VecDeque<(EmbeddingId, usize)> = VecDeque::new();
        visited.insert(*id, 0);
        queue.push_back((*id, 0));
        while let Some((current, depth)) = queue.pop_front() {
            // Nodes at the hop limit are recorded but not expanded further.
            if depth >= hops {
                continue;
            }
            if let Some(edges) = forward.get(&current) {
                for edge in edges {
                    if !visited.contains_key(&edge.to_id) {
                        visited.insert(edge.to_id, depth + 1);
                        queue.push_back((edge.to_id, depth + 1));
                    }
                }
            }
        }
        // Remove the starting node
        visited.remove(id);
        Ok(visited.into_iter().collect())
    }
    /// Connected components of the graph viewed as *undirected*: both edge
    /// directions are followed during the iterative DFS.
    async fn connected_components(&self) -> RepoResult<Vec<Vec<EmbeddingId>>> {
        let forward = self.forward.read();
        let reverse = self.reverse.read();
        // Get all nodes
        let mut all_nodes: HashSet<EmbeddingId> = forward.keys().copied().collect();
        all_nodes.extend(reverse.keys().copied());
        let mut visited: HashSet<EmbeddingId> = HashSet::new();
        let mut components: Vec<Vec<EmbeddingId>> = Vec::new();
        for &start in &all_nodes {
            if visited.contains(&start) {
                continue;
            }
            // Explicit stack (iterative DFS) avoids recursion depth limits.
            let mut component: Vec<EmbeddingId> = Vec::new();
            let mut stack: Vec<EmbeddingId> = vec![start];
            while let Some(current) = stack.pop() {
                if visited.contains(&current) {
                    continue;
                }
                visited.insert(current);
                component.push(current);
                // Add neighbors (both directions for undirected view)
                if let Some(edges) = forward.get(&current) {
                    for edge in edges {
                        if !visited.contains(&edge.to_id) {
                            stack.push(edge.to_id);
                        }
                    }
                }
                if let Some(edges) = reverse.get(&current) {
                    for edge in edges {
                        if !visited.contains(&edge.from_id) {
                            stack.push(edge.from_id);
                        }
                    }
                }
            }
            if !component.is_empty() {
                components.push(component);
            }
        }
        Ok(components)
    }
    /// Degree centrality per node: (out-degree + in-degree) normalized by
    /// the maximum observed degree. Only nodes with at least one edge get
    /// a score; ordering of the result is unspecified (HashMap iteration).
    async fn centrality_scores(&self) -> RepoResult<Vec<(EmbeddingId, f32)>> {
        // Simple degree centrality (normalized)
        let forward = self.forward.read();
        let reverse = self.reverse.read();
        let mut degrees: HashMap<EmbeddingId, usize> = HashMap::new();
        for (id, edges) in forward.iter() {
            *degrees.entry(*id).or_default() += edges.len();
        }
        for (id, edges) in reverse.iter() {
            *degrees.entry(*id).or_default() += edges.len();
        }
        // Fallback of 1 avoids division by zero on an empty graph.
        let max_degree = degrees.values().copied().max().unwrap_or(1) as f32;
        let scores: Vec<_> = degrees
            .into_iter()
            .map(|(id, degree)| (id, degree as f32 / max_degree))
            .collect();
        Ok(scores)
    }
}
/// Serializable export format for the graph.
///
/// Produced by `InMemoryGraphStore::export` and consumed by
/// `InMemoryGraphStore::import`; persisted to disk with bincode via
/// [`GraphExport::save`] / [`GraphExport::load`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExport {
    /// All edges in the graph.
    pub edges: Vec<SimilarityEdge>,
}
impl GraphExport {
    /// Create a new empty export.
    pub fn new() -> Self {
        Self { edges: Vec::new() }
    }

    /// Save to a file (bincode-encoded, buffered).
    pub fn save(&self, path: &std::path::Path) -> Result<(), VectorError> {
        let writer = std::io::BufWriter::new(std::fs::File::create(path)?);
        bincode::serialize_into(writer, self)?;
        Ok(())
    }

    /// Load from a file previously written by [`Self::save`].
    pub fn load(path: &std::path::Path) -> Result<Self, VectorError> {
        let reader = std::io::BufReader::new(std::fs::File::open(path)?);
        Ok(bincode::deserialize_from(reader)?)
    }
}
impl Default for GraphExport {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // Async unit tests: each builds a small graph from fresh EmbeddingIds
    // and exercises the repository / traversal traits end to end.
    use super::*;
    #[tokio::test]
    async fn test_add_and_query_edges() {
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let id3 = EmbeddingId::new();
        let edge1 = SimilarityEdge::new(id1, id2, 0.1);
        let edge2 = SimilarityEdge::new(id1, id3, 0.2);
        store.add_edge(edge1).await.unwrap();
        store.add_edge(edge2).await.unwrap();
        assert_eq!(store.edge_count().await.unwrap(), 2);
        let from_edges = store.get_edges_from(&id1).await.unwrap();
        assert_eq!(from_edges.len(), 2);
        let to_edges = store.get_edges_to(&id2).await.unwrap();
        assert_eq!(to_edges.len(), 1);
    }
    #[tokio::test]
    async fn test_remove_edge() {
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap();
        assert_eq!(store.edge_count().await.unwrap(), 1);
        store.remove_edge(&id1, &id2).await.unwrap();
        assert_eq!(store.edge_count().await.unwrap(), 0);
    }
    #[tokio::test]
    async fn test_edges_by_type() {
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let id3 = EmbeddingId::new();
        store
            .add_edge(SimilarityEdge::new(id1, id2, 0.1).with_type(EdgeType::Similar))
            .await
            .unwrap();
        store
            .add_edge(SimilarityEdge::sequential(id1, id3))
            .await
            .unwrap();
        let similar = store.get_edges_by_type(&id1, EdgeType::Similar).await.unwrap();
        assert_eq!(similar.len(), 1);
        let sequential = store.get_edges_by_type(&id1, EdgeType::Sequential).await.unwrap();
        assert_eq!(sequential.len(), 1);
    }
    #[tokio::test]
    async fn test_strong_edges() {
        // Edge distance 0.1 => similarity 0.9 (passes threshold 0.8);
        // distance 0.5 => similarity 0.5 (filtered out).
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let id3 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap(); // 0.9 similarity
        store.add_edge(SimilarityEdge::new(id1, id3, 0.5)).await.unwrap(); // 0.5 similarity
        let strong = store.get_strong_edges(&id1, 0.8).await.unwrap();
        assert_eq!(strong.len(), 1);
        assert_eq!(strong[0].to_id, id2);
    }
    #[tokio::test]
    async fn test_shortest_path() {
        // Linear chain id1 -> id2 -> id3; BFS should return all 3 nodes.
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let id3 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap();
        store.add_edge(SimilarityEdge::new(id2, id3, 0.1)).await.unwrap();
        let path = store.shortest_path(&id1, &id3, 10).await.unwrap();
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0], id1);
        assert_eq!(path[2], id3);
    }
    #[tokio::test]
    async fn test_neighbors_within_hops() {
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let id3 = EmbeddingId::new();
        let id4 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap();
        store.add_edge(SimilarityEdge::new(id2, id3, 0.1)).await.unwrap();
        store.add_edge(SimilarityEdge::new(id3, id4, 0.1)).await.unwrap();
        let neighbors = store.neighbors_within_hops(&id1, 2).await.unwrap();
        let neighbor_ids: HashSet<_> = neighbors.iter().map(|(id, _)| *id).collect();
        assert!(neighbor_ids.contains(&id2));
        assert!(neighbor_ids.contains(&id3));
        assert!(!neighbor_ids.contains(&id4)); // 3 hops away
    }
    #[tokio::test]
    async fn test_connected_components() {
        let store = InMemoryGraphStore::new();
        // Component 1
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap();
        // Component 2
        let id3 = EmbeddingId::new();
        let id4 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id3, id4, 0.1)).await.unwrap();
        let components = store.connected_components().await.unwrap();
        assert_eq!(components.len(), 2);
    }
    #[tokio::test]
    async fn test_remove_edges_for() {
        // Removing a node must drop its incoming AND outgoing edges.
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let id3 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap();
        store.add_edge(SimilarityEdge::new(id1, id3, 0.1)).await.unwrap();
        store.add_edge(SimilarityEdge::new(id3, id1, 0.1)).await.unwrap();
        assert_eq!(store.edge_count().await.unwrap(), 3);
        store.remove_edges_for(&id1).await.unwrap();
        assert_eq!(store.edge_count().await.unwrap(), 0);
    }
    #[tokio::test]
    async fn test_export_import() {
        // Round trip: export from one store, import into a fresh one.
        let store = InMemoryGraphStore::new();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        store.add_edge(SimilarityEdge::new(id1, id2, 0.1)).await.unwrap();
        let export = store.export();
        assert_eq!(export.edges.len(), 1);
        let new_store = InMemoryGraphStore::new();
        new_store.import(export).unwrap();
        assert_eq!(new_store.edge_count().await.unwrap(), 1);
    }
}

View File

@@ -0,0 +1,574 @@
//! HNSW index implementation for high-performance vector search.
//!
//! This module wraps the `instant-distance` crate to provide a thread-safe,
//! serializable HNSW index optimized for embedding vectors.
//!
//! ## Performance Characteristics
//!
//! - Insert: O(log n) average, O(n) worst case
//! - Search: O(log n) average for k-NN queries
//! - Memory: O(n * (m + dim)) where m is connections per node
//!
//! ## Target: 150x speedup over brute-force
//!
//! For 1M vectors at 1536 dimensions:
//! - Brute-force: ~500ms per query
//! - HNSW: ~3ms per query (166x speedup)
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use instant_distance::{Builder, HnswMap, Point, Search};
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument};
use crate::domain::{EmbeddingId, HnswConfig, VectorError};
/// A point wrapper for instant-distance that holds the vector data.
#[derive(Clone, Debug)]
struct VectorPoint {
    /// Raw vector components; compared with cosine distance
    /// (see the `instant_distance::Point` impl below).
    data: Vec<f32>,
}
impl instant_distance::Point for VectorPoint {
    /// Cosine distance: `1 - cos_sim(self, other)`, in `[0, 2]`.
    ///
    /// Zero-magnitude vectors have no direction, so any comparison with
    /// one yields the maximum meaningful distance of 1.0.
    fn distance(&self, other: &Self) -> f32 {
        // Single pass accumulating dot product and both squared norms.
        let mut dot = 0.0f32;
        let mut sq_a = 0.0f32;
        let mut sq_b = 0.0f32;
        for (a, b) in self.data.iter().zip(other.data.iter()) {
            dot += a * b;
            sq_a += a * a;
            sq_b += b * b;
        }
        let norm_a = sq_a.sqrt();
        let norm_b = sq_b.sqrt();
        if norm_a < 1e-10 || norm_b < 1e-10 {
            return 1.0; // Maximum distance for zero vectors
        }
        let cosine_sim = dot / (norm_a * norm_b);
        1.0 - cosine_sim.clamp(-1.0, 1.0)
    }
}
/// Serializable representation of the index for persistence.
///
/// Only the raw vectors and their dimensionality are persisted; the HNSW
/// graph itself is rebuilt from scratch in `HnswIndex::load`.
#[derive(Serialize, Deserialize)]
struct SerializedIndex {
    /// All vectors stored in the index.
    vectors: Vec<(EmbeddingId, Vec<f32>)>,
    /// Dimensions of vectors.
    dimensions: usize,
}
/// HNSW index for fast approximate nearest neighbor search.
///
/// This index provides O(log n) search performance with high recall,
/// making it suitable for large-scale embedding search.
pub struct HnswIndex {
    /// The HNSW map from instant-distance.
    /// Uses Option to allow rebuilding.
    inner: Option<HnswMap<VectorPoint, EmbeddingId>>,
    /// Mapping from embedding ID to internal index.
    id_to_idx: HashMap<EmbeddingId, usize>,
    /// Storage of vectors for reconstruction and serialization.
    /// This is the source of truth; `inner` is derived from it by `build`.
    vectors: Vec<(EmbeddingId, Vec<f32>)>,
    /// Configuration.
    config: HnswConfig,
    /// Whether the index needs rebuilding.
    /// Set by `insert`/`remove`; cleared by `build`.
    needs_rebuild: bool,
    /// Search buffer (reused across queries).
    /// NOTE(review): currently unused — `search` allocates a fresh
    /// `Search` per query because it only has `&self`.
    _search_buf: Search,
}
impl HnswIndex {
    /// Create a new empty HNSW index with the given configuration.
    pub fn new(config: &HnswConfig) -> Self {
        Self {
            inner: None,
            id_to_idx: HashMap::new(),
            vectors: Vec::new(),
            config: config.clone(),
            needs_rebuild: false,
            _search_buf: Search::default(),
        }
    }
    /// Insert a vector into the index.
    ///
    /// The vector is staged in `self.vectors`; the HNSW graph is only
    /// (re)built by [`Self::build`] / [`Self::rebuild_if_needed`]. Until
    /// then searches fall back to an exact brute-force scan, so results
    /// always reflect the staged vector. Batch-insert then `build` once
    /// for best performance.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `vector.len() != config.dimensions`.
    /// - `CapacityExceeded` if the index already holds `max_elements`.
    /// - `DuplicateId` if `id` is already present.
    #[instrument(skip(self, vector), fields(dim = vector.len()))]
    pub fn insert(&mut self, id: EmbeddingId, vector: &[f32]) -> Result<(), VectorError> {
        // Validate dimensions
        if vector.len() != self.config.dimensions {
            return Err(VectorError::dimension_mismatch(
                self.config.dimensions,
                vector.len(),
            ));
        }
        // Check capacity
        if self.vectors.len() >= self.config.max_elements {
            return Err(VectorError::capacity_exceeded(
                self.config.max_elements,
                self.vectors.len(),
            ));
        }
        // Check for duplicate
        if self.id_to_idx.contains_key(&id) {
            return Err(VectorError::DuplicateId(id));
        }
        // Store vector; the HNSW graph no longer matches storage.
        let idx = self.vectors.len();
        self.id_to_idx.insert(id, idx);
        self.vectors.push((id, vector.to_vec()));
        self.needs_rebuild = true;
        debug!(id = %id, idx = idx, "Inserted vector");
        Ok(())
    }
    /// Search for the k nearest neighbors.
    ///
    /// Returns a vector of (id, distance) tuples sorted by ascending
    /// cosine distance.
    ///
    /// The HNSW graph is consulted only when it is in sync with the stored
    /// vectors. If vectors were inserted or removed since the last
    /// [`Self::build`], the graph is stale — it would return removed IDs
    /// and miss new vectors — so we fall back to an exact brute-force scan
    /// instead. (The previous implementation searched the stale graph.)
    #[instrument(skip(self, query), fields(dim = query.len(), k))]
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(EmbeddingId, f32)> {
        if self.vectors.is_empty() {
            return Vec::new();
        }
        // Use the HNSW graph only if it exists AND reflects current contents.
        let inner = match &self.inner {
            Some(inner) if !self.needs_rebuild => inner,
            _ => return self.brute_force_search(query, k),
        };
        let query_point = VectorPoint {
            data: query.to_vec(),
        };
        // Create a new search buffer for this query
        let mut search = Search::default();
        let results = inner.search(&query_point, &mut search);
        results
            .take(k)
            .map(|item| {
                let id = *item.value;
                let distance = item.distance;
                (id, distance)
            })
            .collect()
    }
    /// Exact brute-force search: O(n * dim) scan over all stored vectors.
    ///
    /// Used when the HNSW graph has not been built or is stale.
    fn brute_force_search(&self, query: &[f32], k: usize) -> Vec<(EmbeddingId, f32)> {
        let query_point = VectorPoint {
            data: query.to_vec(),
        };
        let mut distances: Vec<_> = self
            .vectors
            .iter()
            .map(|(id, vec)| {
                let point = VectorPoint { data: vec.clone() };
                let dist = query_point.distance(&point);
                (*id, dist)
            })
            .collect();
        // NaN-safe ordering: incomparable pairs are treated as equal.
        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        distances.truncate(k);
        distances
    }
    /// Remove a vector from the index.
    ///
    /// HNSW doesn't efficiently support deletions, so the vector is removed
    /// from storage immediately and searches fall back to brute force until
    /// the next [`Self::build`] — removed IDs are therefore never returned.
    ///
    /// # Errors
    /// Returns `NotFound` if `id` is not in the index.
    pub fn remove(&mut self, id: &EmbeddingId) -> Result<(), VectorError> {
        let idx = self
            .id_to_idx
            .remove(id)
            .ok_or_else(|| VectorError::NotFound(*id))?;
        // Remove from vectors (swap-remove for efficiency)
        self.vectors.swap_remove(idx);
        // Update index mapping for swapped element
        if idx < self.vectors.len() {
            let swapped_id = self.vectors[idx].0;
            self.id_to_idx.insert(swapped_id, idx);
        }
        self.needs_rebuild = true;
        debug!(id = %id, "Removed vector");
        Ok(())
    }
    /// Check if a vector exists in the index.
    #[inline]
    pub fn contains(&self, id: &EmbeddingId) -> bool {
        self.id_to_idx.contains_key(id)
    }
    /// Get the number of vectors in the index.
    #[inline]
    pub fn len(&self) -> usize {
        self.vectors.len()
    }
    /// Check if the index is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }
    /// Get a copy of a stored vector by its ID, if present.
    pub fn get_vector(&self, id: &EmbeddingId) -> Option<Vec<f32>> {
        self.id_to_idx
            .get(id)
            .map(|&idx| self.vectors[idx].1.clone())
    }
    /// Clear all vectors from the index.
    pub fn clear(&mut self) {
        self.vectors.clear();
        self.id_to_idx.clear();
        self.inner = None;
        self.needs_rebuild = false;
    }
    /// Build or rebuild the HNSW index from the current vectors.
    ///
    /// This should be called after batch insertions for optimal performance;
    /// until it runs, searches use the exact brute-force fallback.
    #[instrument(skip(self))]
    pub fn build(&mut self) -> Result<(), VectorError> {
        if self.vectors.is_empty() {
            self.inner = None;
            self.needs_rebuild = false;
            return Ok(());
        }
        let points: Vec<VectorPoint> = self
            .vectors
            .iter()
            .map(|(_, vec)| VectorPoint { data: vec.clone() })
            .collect();
        let values: Vec<EmbeddingId> = self.vectors.iter().map(|(id, _)| *id).collect();
        // Build the HNSW index
        let hnsw = Builder::default()
            .ef_construction(self.config.ef_construction)
            .build(points, values);
        self.inner = Some(hnsw);
        self.needs_rebuild = false;
        debug!(
            vectors = self.vectors.len(),
            "Built HNSW index"
        );
        Ok(())
    }
    /// Rebuild the index if insertions/removals have made it stale.
    pub fn rebuild_if_needed(&mut self) -> Result<(), VectorError> {
        if self.needs_rebuild {
            self.build()
        } else {
            Ok(())
        }
    }
    /// Save the index to a file (bincode; vectors + dimensions only —
    /// the HNSW graph is reconstructed on load).
    #[instrument(skip(self))]
    pub fn save(&self, path: &Path) -> Result<(), VectorError> {
        let file = File::create(path)?;
        let writer = BufWriter::new(file);
        let serialized = SerializedIndex {
            vectors: self.vectors.clone(),
            dimensions: self.config.dimensions,
        };
        bincode::serialize_into(writer, &serialized)?;
        debug!(path = %path.display(), vectors = self.vectors.len(), "Saved index");
        Ok(())
    }
    /// Load an index from a file written by [`Self::save`].
    ///
    /// NOTE: only the dimensionality is persisted, so the loaded index uses
    /// `HnswConfig::for_dimension` defaults for all other config fields.
    #[instrument]
    pub fn load(path: &Path) -> Result<Self, VectorError> {
        if !path.exists() {
            return Err(VectorError::FileNotFound(path.to_path_buf()));
        }
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let serialized: SerializedIndex = bincode::deserialize_from(reader)
            .map_err(|e| VectorError::corrupted(format!("Failed to deserialize: {e}")))?;
        let config = HnswConfig::for_dimension(serialized.dimensions);
        let mut index = Self::new(&config);
        // Restore vectors
        for (id, vector) in serialized.vectors {
            let idx = index.vectors.len();
            index.id_to_idx.insert(id, idx);
            index.vectors.push((id, vector));
        }
        // Build the HNSW structure
        index.build()?;
        debug!(
            path = %path.display(),
            vectors = index.vectors.len(),
            "Loaded index"
        );
        Ok(index)
    }
    /// Get all embedding IDs in the index (arbitrary order).
    pub fn ids(&self) -> impl Iterator<Item = &EmbeddingId> {
        self.id_to_idx.keys()
    }
    /// Iterate over all (id, vector) pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&EmbeddingId, &[f32])> {
        self.vectors.iter().map(|(id, vec)| (id, vec.as_slice()))
    }
}
impl Default for HnswIndex {
    /// An empty index configured with the default `HnswConfig`.
    fn default() -> Self {
        let config = HnswConfig::default();
        Self::new(&config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;
    // 64-dimensional index capped at 1000 elements, shared by most tests.
    fn create_test_index() -> HnswIndex {
        let config = HnswConfig::for_dimension(64).with_max_elements(1000);
        HnswIndex::new(&config)
    }
    // Deterministic pseudo-random vector in [-1, 1), seeded via DefaultHasher
    // so tests are reproducible without an RNG dependency.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        (0..dim)
            .map(|i| {
                let mut hasher = DefaultHasher::new();
                (seed, i).hash(&mut hasher);
                let h = hasher.finish();
                ((h % 1000) as f32 / 1000.0) * 2.0 - 1.0
            })
            .collect()
    }
    #[test]
    fn test_insert_and_search() {
        let mut index = create_test_index();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let v1 = random_vector(64, 1);
        let v2 = random_vector(64, 2);
        index.insert(id1, &v1).unwrap();
        index.insert(id2, &v2).unwrap();
        index.build().unwrap();
        let results = index.search(&v1, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1); // Closest should be itself
    }
    #[test]
    fn test_dimension_validation() {
        // Inserting a 32-dim vector into a 64-dim index must fail.
        let mut index = create_test_index();
        let id = EmbeddingId::new();
        let wrong_dim = random_vector(32, 1);
        let result = index.insert(id, &wrong_dim);
        assert!(matches!(
            result,
            Err(VectorError::DimensionMismatch { expected: 64, got: 32 })
        ));
    }
    #[test]
    fn test_duplicate_detection() {
        let mut index = create_test_index();
        let id = EmbeddingId::new();
        let v = random_vector(64, 1);
        index.insert(id, &v).unwrap();
        let result = index.insert(id, &v);
        assert!(matches!(result, Err(VectorError::DuplicateId(_))));
    }
    #[test]
    fn test_remove() {
        let mut index = create_test_index();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        index.insert(id1, &random_vector(64, 1)).unwrap();
        index.insert(id2, &random_vector(64, 2)).unwrap();
        assert_eq!(index.len(), 2);
        assert!(index.contains(&id1));
        index.remove(&id1).unwrap();
        assert_eq!(index.len(), 1);
        assert!(!index.contains(&id1));
        assert!(index.contains(&id2));
    }
    #[test]
    fn test_capacity_limit() {
        // max_elements = 2, so the third insert must be rejected.
        let config = HnswConfig::for_dimension(64).with_max_elements(2);
        let mut index = HnswIndex::new(&config);
        index.insert(EmbeddingId::new(), &random_vector(64, 1)).unwrap();
        index.insert(EmbeddingId::new(), &random_vector(64, 2)).unwrap();
        let result = index.insert(EmbeddingId::new(), &random_vector(64, 3));
        assert!(matches!(result, Err(VectorError::CapacityExceeded { .. })));
    }
    #[test]
    fn test_save_and_load() {
        // Round trip through a temp file; all IDs must survive.
        let mut index = create_test_index();
        let ids: Vec<_> = (0..10).map(|_| EmbeddingId::new()).collect();
        for (i, id) in ids.iter().enumerate() {
            index.insert(*id, &random_vector(64, i as u64)).unwrap();
        }
        index.build().unwrap();
        let file = NamedTempFile::new().unwrap();
        index.save(file.path()).unwrap();
        let loaded = HnswIndex::load(file.path()).unwrap();
        assert_eq!(loaded.len(), index.len());
        for id in &ids {
            assert!(loaded.contains(id));
        }
    }
    #[test]
    fn test_brute_force_fallback() {
        let mut index = create_test_index();
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        let v1 = random_vector(64, 1);
        let v2 = random_vector(64, 2);
        index.insert(id1, &v1).unwrap();
        index.insert(id2, &v2).unwrap();
        // Don't build - should use brute force
        let results = index.search(&v1, 2);
        assert_eq!(results.len(), 2);
    }
    #[test]
    fn test_get_vector() {
        let mut index = create_test_index();
        let id = EmbeddingId::new();
        let v = random_vector(64, 1);
        index.insert(id, &v).unwrap();
        let retrieved = index.get_vector(&id).unwrap();
        assert_eq!(retrieved, v);
        let unknown = EmbeddingId::new();
        assert!(index.get_vector(&unknown).is_none());
    }
    #[test]
    fn test_search_accuracy() {
        // Test that HNSW finds correct nearest neighbors
        let config = HnswConfig::for_dimension(64)
            .with_max_elements(100)
            .with_ef_construction(200)
            .with_ef_search(128);
        let mut index = HnswIndex::new(&config);
        // Insert vectors with known relationships
        let base: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
        let id_base = EmbeddingId::new();
        index.insert(id_base, &base).unwrap();
        // Insert similar vectors (small perturbations)
        let similar_ids: Vec<_> = (0..5)
            .map(|i| {
                let id = EmbeddingId::new();
                let v: Vec<f32> = base
                    .iter()
                    .map(|&x| x + 0.01 * (i as f32 + 1.0))
                    .collect();
                index.insert(id, &v).unwrap();
                id
            })
            .collect();
        // Insert dissimilar vectors
        for i in 0..10 {
            let id = EmbeddingId::new();
            let v: Vec<f32> = (0..64).map(|j| ((i + j) % 7) as f32 / 7.0).collect();
            index.insert(id, &v).unwrap();
        }
        index.build().unwrap();
        // Search for vectors similar to base
        let results = index.search(&base, 6);
        // The base vector should be first
        assert_eq!(results[0].0, id_base);
        // Similar vectors should be in top results
        let top_ids: std::collections::HashSet<_> =
            results.iter().take(6).map(|(id, _)| *id).collect();
        for similar_id in &similar_ids {
            assert!(
                top_ids.contains(similar_id),
                "Similar vector not found in top results"
            );
        }
    }
}

View File

@@ -0,0 +1,12 @@
//! Infrastructure layer for the Vector Space bounded context.
//!
//! Contains:
//! - HNSW index implementation
//! - Graph storage adapters
//! - Persistence implementations
pub mod hnsw_index;
pub mod graph_store;
pub use hnsw_index::HnswIndex;
pub use graph_store::InMemoryGraphStore;

View File

@@ -0,0 +1,116 @@
//! # sevensense-vector
//!
//! Vector database operations and HNSW indexing for the 7sense bioacoustics platform.
//!
//! This crate provides:
//! - Local HNSW index with 150x search speedup over brute-force
//! - Optional Qdrant client wrapper for distributed deployments
//! - Collection management
//! - Similarity search with filtering
//! - Batch operations and persistence
//! - Hyperbolic embeddings for hierarchical relationships
//!
//! ## Architecture
//!
//! Following Domain-Driven Design:
//! ```text
//! sevensense-vector
//! ├── domain/ # Core entities, value objects, repository traits
//! │ ├── entities.rs # EmbeddingId, HnswConfig, SimilarityEdge
//! │ ├── repository.rs # VectorIndexRepository, GraphEdgeRepository
//! │ └── error.rs # VectorError
//! ├── application/ # Service layer with use cases
//! │ └── services.rs # VectorSpaceService
//! └── infrastructure/ # HNSW implementation and storage adapters
//! ├── hnsw_index.rs # Local HNSW index
//! └── graph_store.rs # Edge storage
//! ```
//!
//! ## Performance Targets
//!
//! - 150x search speedup over brute-force linear scan
//! - Sub-millisecond queries for up to 1M vectors
//! - Efficient batch insertion with parallelization
//!
//! ## Example
//!
//! ```rust,ignore
//! use sevensense_vector::prelude::*;
//!
//! // Create a vector space service
//! let config = HnswConfig::for_dimension(1536);
//! let service = VectorSpaceService::new(config);
//!
//! // Add embeddings
//! let id = EmbeddingId::new();
//! let vector = vec![0.1; 1536];
//! service.add_embedding(id, vector).await?;
//!
//! // Search for neighbors
//! let query = vec![0.15; 1536];
//! let neighbors = service.find_neighbors(&query, 10).await?;
//! ```
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![warn(clippy::nursery)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
pub mod domain;
pub mod application;
pub mod infrastructure;
mod distance;
mod hyperbolic;
// Re-export commonly used types
pub use domain::entities::{
EmbeddingId, HnswConfig, SimilarityEdge, EdgeType, VectorIndex, Timestamp,
};
pub use domain::repository::{VectorIndexRepository, GraphEdgeRepository};
pub use application::services::{VectorSpaceService, Neighbor, SearchOptions};
pub use infrastructure::hnsw_index::HnswIndex;
pub use distance::{cosine_distance, euclidean_distance, cosine_similarity, normalize_vector};
pub use hyperbolic::{poincare_distance, exp_map, log_map, mobius_add};
/// Error types for vector operations
///
/// Re-exports everything from `crate::domain::error` so callers can write
/// `use sevensense_vector::error::VectorError;` without reaching into the
/// domain layer directly.
pub mod error {
    pub use crate::domain::error::*;
}
/// Prelude module for convenient imports
pub mod prelude {
    //! Common imports for vector operations.
    //!
    //! Intended usage: `use sevensense_vector::prelude::*;`.
    pub use crate::domain::entities::{
        EmbeddingId, HnswConfig, SimilarityEdge, EdgeType, VectorIndex,
    };
    pub use crate::domain::repository::{VectorIndexRepository, GraphEdgeRepository};
    pub use crate::application::services::{VectorSpaceService, Neighbor, SearchOptions};
    pub use crate::infrastructure::hnsw_index::HnswIndex;
    pub use crate::distance::{cosine_distance, euclidean_distance, cosine_similarity};
    pub use crate::error::VectorError;
}
/// Crate version information
///
/// Captured from the `CARGO_PKG_VERSION` environment variable at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_library_compiles() {
        // Basic smoke test
        // Also pins the default HNSW connectivity parameter m = 32.
        let config = HnswConfig::default();
        assert_eq!(config.m, 32);
    }
    #[test]
    fn test_version() {
        // env! resolves at compile time, so this only fails if Cargo.toml
        // somehow declares an empty version string.
        assert!(!VERSION.is_empty());
    }
}