263 lines
7.0 KiB
Rust
263 lines
7.0 KiB
Rust
//! Distance Metrics for iOS/Browser WASM
|
|
//!
|
|
//! Implements all key Ruvector distance functions with SIMD optimization.
|
|
//! Supports: Euclidean, Cosine, Manhattan, DotProduct, Hamming
|
|
|
|
use crate::simd;
|
|
|
|
/// Distance metric type
|
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
|
#[repr(u8)]
|
|
pub enum DistanceMetric {
|
|
/// Euclidean (L2) distance
|
|
Euclidean = 0,
|
|
/// Cosine distance (1 - cosine_similarity)
|
|
Cosine = 1,
|
|
/// Dot product distance (negative dot for minimization)
|
|
DotProduct = 2,
|
|
/// Manhattan (L1) distance
|
|
Manhattan = 3,
|
|
/// Hamming distance (for binary vectors)
|
|
Hamming = 4,
|
|
}
|
|
|
|
impl DistanceMetric {
|
|
/// Parse from u8
|
|
pub fn from_u8(v: u8) -> Self {
|
|
match v {
|
|
0 => DistanceMetric::Euclidean,
|
|
1 => DistanceMetric::Cosine,
|
|
2 => DistanceMetric::DotProduct,
|
|
3 => DistanceMetric::Manhattan,
|
|
4 => DistanceMetric::Hamming,
|
|
_ => DistanceMetric::Cosine, // Default
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Calculate distance between two vectors
|
|
#[inline]
|
|
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
|
|
match metric {
|
|
DistanceMetric::Euclidean => euclidean_distance(a, b),
|
|
DistanceMetric::Cosine => cosine_distance(a, b),
|
|
DistanceMetric::DotProduct => dot_product_distance(a, b),
|
|
DistanceMetric::Manhattan => manhattan_distance(a, b),
|
|
DistanceMetric::Hamming => hamming_distance_float(a, b),
|
|
}
|
|
}
|
|
|
|
/// Euclidean (L2) distance
|
|
#[inline]
|
|
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
|
simd::l2_distance(a, b)
|
|
}
|
|
|
|
/// Squared Euclidean distance (faster, no sqrt)
|
|
#[inline]
|
|
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
|
|
let len = a.len().min(b.len());
|
|
let mut sum = 0.0f32;
|
|
for i in 0..len {
|
|
let diff = a[i] - b[i];
|
|
sum += diff * diff;
|
|
}
|
|
sum
|
|
}
|
|
|
|
/// Cosine distance (1 - cosine_similarity)
|
|
#[inline]
|
|
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
|
|
1.0 - simd::cosine_similarity(a, b)
|
|
}
|
|
|
|
/// Cosine similarity (not distance)
|
|
#[inline]
|
|
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
simd::cosine_similarity(a, b)
|
|
}
|
|
|
|
/// Dot product distance (negative for minimization)
|
|
#[inline]
|
|
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
|
|
-simd::dot_product(a, b)
|
|
}
|
|
|
|
/// Manhattan (L1) distance
|
|
#[inline]
|
|
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
|
|
let len = a.len().min(b.len());
|
|
let mut sum = 0.0f32;
|
|
for i in 0..len {
|
|
sum += (a[i] - b[i]).abs();
|
|
}
|
|
sum
|
|
}
|
|
|
|
/// Hamming distance for float vectors (count sign differences)
|
|
#[inline]
|
|
pub fn hamming_distance_float(a: &[f32], b: &[f32]) -> f32 {
|
|
let len = a.len().min(b.len());
|
|
let mut count = 0u32;
|
|
for i in 0..len {
|
|
if (a[i] > 0.0) != (b[i] > 0.0) {
|
|
count += 1;
|
|
}
|
|
}
|
|
count as f32
|
|
}
|
|
|
|
/// Hamming distance for binary packed vectors
|
|
#[inline]
|
|
pub fn hamming_distance_binary(a: &[u8], b: &[u8]) -> u32 {
|
|
let mut distance = 0u32;
|
|
for (&x, &y) in a.iter().zip(b.iter()) {
|
|
distance += (x ^ y).count_ones();
|
|
}
|
|
distance
|
|
}
|
|
|
|
// ============================================
|
|
// Batch Operations
|
|
// ============================================
|
|
|
|
/// Find k nearest neighbors from a set of vectors
|
|
pub fn find_nearest(
|
|
query: &[f32],
|
|
vectors: &[&[f32]],
|
|
k: usize,
|
|
metric: DistanceMetric,
|
|
) -> Vec<(usize, f32)> {
|
|
let mut distances: Vec<(usize, f32)> = vectors
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, v)| (i, distance(query, v, metric)))
|
|
.collect();
|
|
|
|
// Partial sort for top-k
|
|
if k < distances.len() {
|
|
distances.select_nth_unstable_by(k, |a, b| {
|
|
a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)
|
|
});
|
|
distances.truncate(k);
|
|
}
|
|
|
|
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal));
|
|
distances
|
|
}
|
|
|
|
/// Compute pairwise distances for a batch of queries
|
|
pub fn batch_distances(
|
|
queries: &[&[f32]],
|
|
vectors: &[&[f32]],
|
|
metric: DistanceMetric,
|
|
) -> Vec<Vec<f32>> {
|
|
queries
|
|
.iter()
|
|
.map(|q| {
|
|
vectors.iter().map(|v| distance(q, v, metric)).collect()
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
// ============================================
|
|
// WASM Exports
|
|
// ============================================
|
|
|
|
/// Calculate distance (WASM export)
|
|
#[no_mangle]
|
|
pub extern "C" fn calc_distance(
|
|
a_ptr: *const f32,
|
|
b_ptr: *const f32,
|
|
len: u32,
|
|
metric: u8,
|
|
) -> f32 {
|
|
unsafe {
|
|
let a = core::slice::from_raw_parts(a_ptr, len as usize);
|
|
let b = core::slice::from_raw_parts(b_ptr, len as usize);
|
|
distance(a, b, DistanceMetric::from_u8(metric))
|
|
}
|
|
}
|
|
|
|
/// Batch nearest neighbor search (WASM export)
|
|
/// Returns number of results written
|
|
#[no_mangle]
|
|
pub extern "C" fn find_nearest_batch(
|
|
query_ptr: *const f32,
|
|
query_len: u32,
|
|
vectors_ptr: *const f32,
|
|
num_vectors: u32,
|
|
vector_dim: u32,
|
|
k: u32,
|
|
metric: u8,
|
|
out_indices: *mut u32,
|
|
out_distances: *mut f32,
|
|
) -> u32 {
|
|
unsafe {
|
|
let query = core::slice::from_raw_parts(query_ptr, query_len as usize);
|
|
|
|
// Build vector slice references
|
|
let vector_data = core::slice::from_raw_parts(vectors_ptr, (num_vectors * vector_dim) as usize);
|
|
let vectors: Vec<&[f32]> = (0..num_vectors as usize)
|
|
.map(|i| {
|
|
let start = i * vector_dim as usize;
|
|
&vector_data[start..start + vector_dim as usize]
|
|
})
|
|
.collect();
|
|
|
|
let results = find_nearest(query, &vectors, k as usize, DistanceMetric::from_u8(metric));
|
|
|
|
// Write results
|
|
let indices = core::slice::from_raw_parts_mut(out_indices, results.len());
|
|
let distances = core::slice::from_raw_parts_mut(out_distances, results.len());
|
|
|
|
for (i, (idx, dist)) in results.iter().enumerate() {
|
|
indices[i] = *idx as u32;
|
|
distances[i] = *dist;
|
|
}
|
|
|
|
results.len() as u32
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_euclidean() {
|
|
let a = vec![1.0, 2.0, 3.0];
|
|
let b = vec![4.0, 5.0, 6.0];
|
|
let dist = euclidean_distance(&a, &b);
|
|
assert!((dist - 5.196).abs() < 0.01);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_identical() {
|
|
let a = vec![1.0, 2.0, 3.0];
|
|
let dist = cosine_distance(&a, &a);
|
|
assert!(dist.abs() < 0.001);
|
|
}
|
|
|
|
#[test]
|
|
fn test_manhattan() {
|
|
let a = vec![1.0, 2.0, 3.0];
|
|
let b = vec![4.0, 5.0, 6.0];
|
|
let dist = manhattan_distance(&a, &b);
|
|
assert!((dist - 9.0).abs() < 0.01);
|
|
}
|
|
|
|
#[test]
|
|
fn test_find_nearest() {
|
|
let query = vec![0.0, 0.0];
|
|
let v1 = vec![1.0, 0.0];
|
|
let v2 = vec![2.0, 0.0];
|
|
let v3 = vec![0.5, 0.0];
|
|
let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
|
|
|
|
let results = find_nearest(&query, &vectors, 2, DistanceMetric::Euclidean);
|
|
assert_eq!(results.len(), 2);
|
|
assert_eq!(results[0].0, 2); // v3 is closest
|
|
}
|
|
}
|