Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
262
vendor/ruvector/examples/wasm/ios/src/distance.rs
vendored
Normal file
262
vendor/ruvector/examples/wasm/ios/src/distance.rs
vendored
Normal file
@@ -0,0 +1,262 @@
|
||||
//! Distance Metrics for iOS/Browser WASM
|
||||
//!
|
||||
//! Implements the key Ruvector distance functions; Euclidean, Cosine and DotProduct delegate to the `simd` module.
|
||||
//! Supports: Euclidean, Cosine, Manhattan, DotProduct, Hamming
|
||||
|
||||
use crate::simd;
|
||||
|
||||
/// Distance metric selector used by the functions in this module.
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so every variant has
/// a fixed one-byte code; these are the same codes that
/// `DistanceMetric::from_u8` accepts.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so a metric can be
/// used as a key in hash-based collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    Euclidean = 0,
    /// Cosine distance (`1 - cosine_similarity`).
    Cosine = 1,
    /// Dot product distance (the dot product negated, so smaller is closer).
    DotProduct = 2,
    /// Manhattan (L1) distance.
    Manhattan = 3,
    /// Hamming distance (for binary vectors).
    Hamming = 4,
}
|
||||
|
||||
impl DistanceMetric {
|
||||
/// Parse from u8
|
||||
pub fn from_u8(v: u8) -> Self {
|
||||
match v {
|
||||
0 => DistanceMetric::Euclidean,
|
||||
1 => DistanceMetric::Cosine,
|
||||
2 => DistanceMetric::DotProduct,
|
||||
3 => DistanceMetric::Manhattan,
|
||||
4 => DistanceMetric::Hamming,
|
||||
_ => DistanceMetric::Cosine, // Default
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate distance between two vectors
|
||||
#[inline]
|
||||
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
|
||||
match metric {
|
||||
DistanceMetric::Euclidean => euclidean_distance(a, b),
|
||||
DistanceMetric::Cosine => cosine_distance(a, b),
|
||||
DistanceMetric::DotProduct => dot_product_distance(a, b),
|
||||
DistanceMetric::Manhattan => manhattan_distance(a, b),
|
||||
DistanceMetric::Hamming => hamming_distance_float(a, b),
|
||||
}
|
||||
}
|
||||
|
||||
/// Euclidean (L2) distance
|
||||
#[inline]
|
||||
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
simd::l2_distance(a, b)
|
||||
}
|
||||
|
||||
/// Squared Euclidean distance: the sum of `(a[i] - b[i])^2`, with no final
/// square root.
///
/// Only the first `min(a.len(), b.len())` element pairs are compared; any
/// surplus elements of the longer slice are ignored. If either slice is empty
/// the result is `0.0`.
#[inline]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    // `zip` stops at the shorter slice, which replaces the manual `min` +
    // indexing (and its per-element bounds checks). Folding from `0.0` keeps
    // the left-to-right accumulation order of a plain loop.
    a.iter().zip(b).fold(0.0f32, |sum, (&x, &y)| {
        let diff = x - y;
        sum + diff * diff
    })
}
|
||||
|
||||
/// Cosine distance (1 - cosine_similarity)
|
||||
#[inline]
|
||||
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
1.0 - simd::cosine_similarity(a, b)
|
||||
}
|
||||
|
||||
/// Cosine similarity (not distance)
|
||||
#[inline]
|
||||
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
simd::cosine_similarity(a, b)
|
||||
}
|
||||
|
||||
/// Dot product distance (negative for minimization)
|
||||
#[inline]
|
||||
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
|
||||
-simd::dot_product(a, b)
|
||||
}
|
||||
|
||||
/// Manhattan (L1) distance: the sum of `|a[i] - b[i]|`.
///
/// Only the first `min(a.len(), b.len())` element pairs are compared; any
/// surplus elements of the longer slice are ignored. If either slice is empty
/// the result is `0.0`.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    // `zip` truncates to the shorter slice and avoids indexed access; folding
    // from `0.0` accumulates in the same left-to-right order as a plain loop.
    a.iter()
        .zip(b)
        .fold(0.0f32, |sum, (&x, &y)| sum + (x - y).abs())
}
|
||||
|
||||
/// Hamming distance for float vectors: the number of positions whose
/// "is positive" flag differs between `a` and `b`.
///
/// An element counts as positive only when it is strictly greater than `0.0`,
/// so `0.0`, `-0.0`, negative values and NaN all fall on the same side.
/// Only the first `min(a.len(), b.len())` element pairs are compared. The
/// count is returned as an `f32` so it fits the common distance signature.
#[inline]
pub fn hamming_distance_float(a: &[f32], b: &[f32]) -> f32 {
    let differing = a
        .iter()
        .zip(b)
        .filter(|&(&x, &y)| (x > 0.0) != (y > 0.0))
        .count();
    differing as f32
}
|
||||
|
||||
/// Hamming distance for bit-packed vectors: the total number of differing
/// bits across corresponding bytes of `a` and `b`.
///
/// Bytes are paired up to the length of the shorter slice; surplus bytes of
/// the longer slice are ignored.
#[inline]
pub fn hamming_distance_binary(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        // XOR leaves a 1 in every bit position where the two bytes disagree.
        .map(|(&x, &y)| (x ^ y).count_ones())
        .sum()
}
|
||||
|
||||
// ============================================
|
||||
// Batch Operations
|
||||
// ============================================
|
||||
|
||||
/// Find k nearest neighbors from a set of vectors
|
||||
pub fn find_nearest(
|
||||
query: &[f32],
|
||||
vectors: &[&[f32]],
|
||||
k: usize,
|
||||
metric: DistanceMetric,
|
||||
) -> Vec<(usize, f32)> {
|
||||
let mut distances: Vec<(usize, f32)> = vectors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i, distance(query, v, metric)))
|
||||
.collect();
|
||||
|
||||
// Partial sort for top-k
|
||||
if k < distances.len() {
|
||||
distances.select_nth_unstable_by(k, |a, b| {
|
||||
a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)
|
||||
});
|
||||
distances.truncate(k);
|
||||
}
|
||||
|
||||
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal));
|
||||
distances
|
||||
}
|
||||
|
||||
/// Compute pairwise distances for a batch of queries
|
||||
pub fn batch_distances(
|
||||
queries: &[&[f32]],
|
||||
vectors: &[&[f32]],
|
||||
metric: DistanceMetric,
|
||||
) -> Vec<Vec<f32>> {
|
||||
queries
|
||||
.iter()
|
||||
.map(|q| {
|
||||
vectors.iter().map(|v| distance(q, v, metric)).collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// WASM Exports
|
||||
// ============================================
|
||||
|
||||
/// Calculate distance (WASM export)
|
||||
#[no_mangle]
|
||||
pub extern "C" fn calc_distance(
|
||||
a_ptr: *const f32,
|
||||
b_ptr: *const f32,
|
||||
len: u32,
|
||||
metric: u8,
|
||||
) -> f32 {
|
||||
unsafe {
|
||||
let a = core::slice::from_raw_parts(a_ptr, len as usize);
|
||||
let b = core::slice::from_raw_parts(b_ptr, len as usize);
|
||||
distance(a, b, DistanceMetric::from_u8(metric))
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch nearest neighbor search (WASM export)
|
||||
/// Returns number of results written
|
||||
#[no_mangle]
|
||||
pub extern "C" fn find_nearest_batch(
|
||||
query_ptr: *const f32,
|
||||
query_len: u32,
|
||||
vectors_ptr: *const f32,
|
||||
num_vectors: u32,
|
||||
vector_dim: u32,
|
||||
k: u32,
|
||||
metric: u8,
|
||||
out_indices: *mut u32,
|
||||
out_distances: *mut f32,
|
||||
) -> u32 {
|
||||
unsafe {
|
||||
let query = core::slice::from_raw_parts(query_ptr, query_len as usize);
|
||||
|
||||
// Build vector slice references
|
||||
let vector_data = core::slice::from_raw_parts(vectors_ptr, (num_vectors * vector_dim) as usize);
|
||||
let vectors: Vec<&[f32]> = (0..num_vectors as usize)
|
||||
.map(|i| {
|
||||
let start = i * vector_dim as usize;
|
||||
&vector_data[start..start + vector_dim as usize]
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = find_nearest(query, &vectors, k as usize, DistanceMetric::from_u8(metric));
|
||||
|
||||
// Write results
|
||||
let indices = core::slice::from_raw_parts_mut(out_indices, results.len());
|
||||
let distances = core::slice::from_raw_parts_mut(out_distances, results.len());
|
||||
|
||||
for (i, (idx, dist)) in results.iter().enumerate() {
|
||||
indices[i] = *idx as u32;
|
||||
distances[i] = *dist;
|
||||
}
|
||||
|
||||
results.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the metric value checks.
    const TOLERANCE: f32 = 0.01;

    #[test]
    fn test_euclidean() {
        // sqrt(3^2 + 3^2 + 3^2) = sqrt(27) ~= 5.196
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let expected = 5.196f32;
        assert!((euclidean_distance(&a, &b) - expected).abs() < TOLERANCE);
    }

    #[test]
    fn test_cosine_identical() {
        // A vector compared with itself should have (near) zero cosine distance.
        let a = [1.0f32, 2.0, 3.0];
        assert!(cosine_distance(&a, &a).abs() < 0.001);
    }

    #[test]
    fn test_manhattan() {
        // |1-4| + |2-5| + |3-6| = 9
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let expected = 9.0f32;
        assert!((manhattan_distance(&a, &b) - expected).abs() < TOLERANCE);
    }

    #[test]
    fn test_find_nearest() {
        let query = [0.0f32, 0.0];
        let v1 = [1.0f32, 0.0];
        let v2 = [2.0f32, 0.0];
        let v3 = [0.5f32, 0.0];
        let vectors: [&[f32]; 3] = [&v1, &v2, &v3];

        let results = find_nearest(&query, &vectors, 2, DistanceMetric::Euclidean);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 2); // v3 is closest
    }
}
|
||||
Reference in New Issue
Block a user