Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,358 @@
//! Attention Mechanism Module for iOS WASM
//!
//! Lightweight self-attention for content ranking and sequence modeling.
//! Optimized for minimal memory footprint on mobile devices.
/// Maximum sequence length for attention
const MAX_SEQ_LEN: usize = 64;
/// Single attention head
/// Single attention head.
///
/// Holds the Q/K/V projection matrices as flat row-major `Vec<f32>`s of
/// length `input_dim * dim` (element `[j][i]` lives at `j * dim + i`,
/// see `project`).
pub struct AttentionHead {
    /// Dimension of key/query/value
    dim: usize,
    /// Query projection weights (flattened `input_dim x dim` matrix)
    w_query: Vec<f32>,
    /// Key projection weights (same shape as `w_query`)
    w_key: Vec<f32>,
    /// Value projection weights (same shape as `w_query`)
    w_value: Vec<f32>,
    /// Scaling factor (1/sqrt(dim)) applied to dot-product scores
    scale: f32,
}
impl AttentionHead {
    /// Create a new attention head with deterministic pseudo-random
    /// initialization.
    ///
    /// `input_dim` is the incoming feature size, `head_dim` the head's
    /// internal Q/K/V dimension; `seed` fixes the weight stream so results
    /// are reproducible across runs.
    pub fn new(input_dim: usize, head_dim: usize, seed: u32) -> Self {
        let dim = head_dim;
        let weight_size = input_dim * dim;
        // Xavier initialization with deterministic pseudo-random
        let std_dev = (2.0 / (input_dim + dim) as f32).sqrt();
        // Offset the seed so Q, K and V draw from distinct LCG streams.
        let w_query = Self::init_weights(weight_size, seed, std_dev);
        let w_key = Self::init_weights(weight_size, seed.wrapping_add(1), std_dev);
        let w_value = Self::init_weights(weight_size, seed.wrapping_add(2), std_dev);
        Self {
            dim,
            w_query,
            w_key,
            w_value,
            // Standard scaled-dot-product factor 1/sqrt(d).
            scale: 1.0 / (dim as f32).sqrt(),
        }
    }

    /// Initialize `size` weights from a glibc-style LCG, scaled by `std_dev`.
    ///
    /// NOTE(review): the LCG output is uniform in ~[-1, 1), so the effective
    /// standard deviation is std_dev/sqrt(3), not std_dev — close enough for
    /// this lightweight use, but not a true Xavier-normal init.
    fn init_weights(size: usize, seed: u32, std_dev: f32) -> Vec<f32> {
        let mut weights = Vec::with_capacity(size);
        let mut s = seed;
        for _ in 0..size {
            // LCG step; avoids pulling in a rand crate (binary-size budget).
            s = s.wrapping_mul(1103515245).wrapping_add(12345);
            // High 16 bits mapped to ~[-1, 1).
            let uniform = ((s >> 16) as f32 / 32768.0) - 1.0;
            weights.push(uniform * std_dev);
        }
        weights
    }

    /// Project `input` into this head's space: `output[i] = Σ_j input[j] * W[j][i]`.
    ///
    /// Input components beyond `input_dim` are ignored; shorter inputs simply
    /// contribute fewer terms.
    #[inline]
    fn project(&self, input: &[f32], weights: &[f32]) -> Vec<f32> {
        // All three weight matrices share one shape, so w_query's length is
        // used to recover input_dim regardless of which matrix was passed.
        let input_dim = self.w_query.len() / self.dim;
        let mut output = vec![0.0; self.dim];
        for (i, o) in output.iter_mut().enumerate() {
            for (j, &inp) in input.iter().take(input_dim).enumerate() {
                let idx = j * self.dim + i;
                if idx < weights.len() {
                    *o += inp * weights[idx];
                }
            }
        }
        output
    }

    /// Scaled dot-product score between one query and one key vector.
    #[inline]
    fn attention_score(&self, query: &[f32], key: &[f32]) -> f32 {
        let dot: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
        dot * self.scale
    }

    /// In-place softmax over `scores`.
    ///
    /// Skips normalization when the exponential sum is ~0 (all scores are
    /// extremely negative), leaving near-zero values rather than dividing by 0.
    fn softmax(scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }
        // Numerical stability: subtract max
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for s in scores.iter_mut() {
            *s = (*s - max_score).exp();
            sum += *s;
        }
        if sum > 1e-8 {
            for s in scores.iter_mut() {
                *s /= sum;
            }
        }
    }

    /// Compute self-attention over a sequence.
    ///
    /// Sequences longer than `MAX_SEQ_LEN` are truncated; an empty input
    /// yields an empty output. Returns one `dim`-sized vector per position.
    pub fn forward(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let seq_len = sequence.len().min(MAX_SEQ_LEN);
        if seq_len == 0 {
            return vec![];
        }
        // Project to Q, K, V
        let queries: Vec<Vec<f32>> = sequence.iter().take(seq_len)
            .map(|x| self.project(x, &self.w_query))
            .collect();
        let keys: Vec<Vec<f32>> = sequence.iter().take(seq_len)
            .map(|x| self.project(x, &self.w_key))
            .collect();
        let values: Vec<Vec<f32>> = sequence.iter().take(seq_len)
            .map(|x| self.project(x, &self.w_value))
            .collect();
        // Compute attention for each position (full, non-causal attention:
        // every position attends to every other).
        let mut outputs = Vec::with_capacity(seq_len);
        for q in &queries {
            // Compute attention scores against all keys
            let mut scores: Vec<f32> = keys.iter()
                .map(|k| self.attention_score(q, k))
                .collect();
            Self::softmax(&mut scores);
            // Weighted sum of values
            let mut output = vec![0.0; self.dim];
            for (score, value) in scores.iter().zip(values.iter()) {
                for (o, v) in output.iter_mut().zip(value.iter()) {
                    *o += score * v;
                }
            }
            outputs.push(output);
        }
        outputs
    }

    /// Get output dimension (the head's `head_dim`).
    pub fn dim(&self) -> usize {
        self.dim
    }
}
/// Multi-head attention layer
/// Multi-head attention layer.
///
/// Runs several independent [`AttentionHead`]s, concatenates their outputs
/// and projects the result back to `output_dim` (= the input dimension).
pub struct MultiHeadAttention {
    /// Independent attention heads, each with its own Q/K/V weights
    heads: Vec<AttentionHead>,
    /// Output projection weights (flattened `concat_dim x output_dim` matrix)
    w_out: Vec<f32>,
    /// Final output dimension (equals the constructor's `input_dim`)
    output_dim: usize,
}
impl MultiHeadAttention {
    /// Create new multi-head attention.
    ///
    /// Builds `num_heads` heads of dimension `head_dim` each, plus an output
    /// projection mapping the concatenated head outputs back to `input_dim`.
    /// Each head gets a distinct derived seed for reproducible init.
    pub fn new(input_dim: usize, num_heads: usize, head_dim: usize, seed: u32) -> Self {
        let heads: Vec<AttentionHead> = (0..num_heads)
            .map(|i| AttentionHead::new(input_dim, head_dim, seed.wrapping_add(i as u32 * 10)))
            .collect();
        let concat_dim = num_heads * head_dim;
        let output_dim = input_dim;
        // Xavier-style scale for the output projection.
        let w_out = AttentionHead::init_weights(
            concat_dim * output_dim,
            seed.wrapping_add(1000),
            (2.0 / (concat_dim + output_dim) as f32).sqrt(),
        );
        Self {
            heads,
            w_out,
            output_dim,
        }
    }

    /// Forward pass through multi-head attention.
    ///
    /// Returns one `output_dim`-sized vector per input position; empty input
    /// (or a layer constructed with zero heads) yields an empty output.
    pub fn forward(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
        // BUGFIX: the original indexed head_outputs[0] before checking
        // heads.is_empty(), panicking when constructed with num_heads == 0.
        if sequence.is_empty() || self.heads.is_empty() {
            return vec![];
        }
        // Get outputs from all heads
        let head_outputs: Vec<Vec<Vec<f32>>> = self.heads.iter()
            .map(|head| head.forward(sequence))
            .collect();
        // Concatenate and project
        let seq_len = head_outputs[0].len();
        let head_dim = self.heads[0].dim();
        let concat_dim = self.heads.len() * head_dim;
        let mut outputs = Vec::with_capacity(seq_len);
        for pos in 0..seq_len {
            // Concatenate heads
            let mut concat = Vec::with_capacity(concat_dim);
            for head_out in &head_outputs {
                concat.extend_from_slice(&head_out[pos]);
            }
            // Output projection: output[i] = Σ_j concat[j] * W[j][i]
            let mut output = vec![0.0; self.output_dim];
            for (i, o) in output.iter_mut().enumerate() {
                for (j, &c) in concat.iter().enumerate() {
                    let idx = j * self.output_dim + i;
                    if idx < self.w_out.len() {
                        *o += c * self.w_out[idx];
                    }
                }
            }
            outputs.push(output);
        }
        outputs
    }

    /// Apply attention then mean-pool over the sequence to a single vector.
    ///
    /// Empty input returns a zero vector of `output_dim`.
    pub fn pool(&self, sequence: &[Vec<f32>]) -> Vec<f32> {
        let attended = self.forward(sequence);
        if attended.is_empty() {
            return vec![0.0; self.output_dim];
        }
        // Mean pooling over sequence
        let mut pooled = vec![0.0; self.output_dim];
        for item in &attended {
            for (p, v) in pooled.iter_mut().zip(item.iter()) {
                *p += v;
            }
        }
        let n = attended.len() as f32;
        for p in &mut pooled {
            *p /= n;
        }
        pooled
    }
}
/// Context-aware content ranker using attention
/// Context-aware content ranker using attention.
///
/// Prepends a transformed user query to the item sequence, runs multi-head
/// attention over the whole set, and scores items by dot product against the
/// attended query (see `rank`).
pub struct AttentionRanker {
    /// Attention layer applied over [query, item_0, item_1, ...]
    attention: MultiHeadAttention,
    /// Query transformation weights (flattened `dim x dim` matrix)
    w_query_transform: Vec<f32>,
    /// Expected dimension of the query and every item vector
    dim: usize,
}
impl AttentionRanker {
    /// Create new attention-based ranker.
    ///
    /// `head_dim` is derived as `dim / num_heads` (num_heads clamped to >= 1
    /// for the division). Seeds are fixed constants so ranking is
    /// deterministic for a given build.
    pub fn new(dim: usize, num_heads: usize) -> Self {
        let head_dim = dim / num_heads.max(1);
        let attention = MultiHeadAttention::new(dim, num_heads, head_dim, 54321);
        let w_query_transform = AttentionHead::init_weights(
            dim * dim,
            99999,
            (2.0 / (dim * 2) as f32).sqrt(),
        );
        Self {
            attention,
            w_query_transform,
            dim,
        }
    }

    /// Rank content items based on user context.
    ///
    /// Returns `(item_index, score)` pairs sorted by relevance score,
    /// descending. Returns an empty vec when `items` is empty or the query
    /// has the wrong dimension. Items with dimension != `dim` still
    /// participate (attention truncates/ignores extra components).
    pub fn rank(&self, query: &[f32], items: &[Vec<f32>]) -> Vec<(usize, f32)> {
        if items.is_empty() || query.len() != self.dim {
            return vec![];
        }
        // Transform query: tq[i] = Σ_j q[j] * W[j][i]
        let mut transformed_query = vec![0.0; self.dim];
        for (i, tq) in transformed_query.iter_mut().enumerate() {
            for (j, &q) in query.iter().enumerate() {
                let idx = j * self.dim + i;
                if idx < self.w_query_transform.len() {
                    *tq += q * self.w_query_transform[idx];
                }
            }
        }
        // Create sequence with query prepended
        let mut sequence = vec![transformed_query.clone()];
        sequence.extend(items.iter().cloned());
        // Apply attention; sequence is non-empty (it contains the query),
        // so `attended` has at least one element and attended[0] is safe.
        let attended = self.attention.forward(&sequence);
        // Score each item by similarity to attended query
        let query_attended = &attended[0];
        let mut scores: Vec<(usize, f32)> = attended[1..].iter()
            .enumerate()
            .map(|(i, item)| {
                let sim: f32 = query_attended.iter()
                    .zip(item.iter())
                    .map(|(q, v)| q * v)
                    .sum();
                (i, sim)
            })
            .collect();
        // Sort by score descending; NaN scores compare as equal and keep
        // their relative order.
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
        scores
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single head maps each position to a head_dim-sized vector.
    #[test]
    fn test_attention_head() {
        let head = AttentionHead::new(64, 16, 12345);
        let seq: Vec<Vec<f32>> = (0..5).map(|_| vec![0.5; 64]).collect();
        let out = head.forward(&seq);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 16);
    }

    /// Multi-head output is projected back to the input dimension.
    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(64, 4, 16, 12345);
        let seq: Vec<Vec<f32>> = (0..5).map(|_| vec![0.5; 64]).collect();
        let out = mha.forward(&seq);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 64);
    }

    /// Ranker returns one (index, score) pair per item.
    #[test]
    fn test_attention_ranker() {
        let ranker = AttentionRanker::new(64, 4);
        let context = vec![0.5; 64];
        let catalog = vec![vec![0.3; 64], vec![0.7; 64], vec![0.1; 64]];
        assert_eq!(ranker.rank(&context, &catalog).len(), 3);
    }
}

View File

@@ -0,0 +1,262 @@
//! Distance Metrics for iOS/Browser WASM
//!
//! Implements all key Ruvector distance functions with SIMD optimization.
//! Supports: Euclidean, Cosine, Manhattan, DotProduct, Hamming
use crate::simd;
/// Distance metric type
/// Distance metric type.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance
    Euclidean = 0,
    /// Cosine distance (1 - cosine_similarity)
    Cosine = 1,
    /// Dot product distance (negative dot for minimization)
    DotProduct = 2,
    /// Manhattan (L1) distance
    Manhattan = 3,
    /// Hamming distance (for binary vectors)
    Hamming = 4,
}

impl DistanceMetric {
    /// Parse from u8; any unknown code falls back to `Cosine`.
    pub fn from_u8(v: u8) -> Self {
        // Discriminants are contiguous from 0, so a lookup table suffices.
        const BY_CODE: [DistanceMetric; 5] = [
            DistanceMetric::Euclidean,
            DistanceMetric::Cosine,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
            DistanceMetric::Hamming,
        ];
        BY_CODE
            .get(v as usize)
            .copied()
            .unwrap_or(DistanceMetric::Cosine)
    }
}
/// Calculate distance between two vectors
/// Calculate distance between two vectors, dispatching on `metric`.
///
/// All metrics return "smaller is closer" values (dot product is negated
/// for that reason — see `dot_product_distance`).
#[inline]
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::Euclidean => euclidean_distance(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
        DistanceMetric::DotProduct => dot_product_distance(a, b),
        DistanceMetric::Manhattan => manhattan_distance(a, b),
        DistanceMetric::Hamming => hamming_distance_float(a, b),
    }
}
/// Euclidean (L2) distance
/// Euclidean (L2) distance; delegates to the SIMD-accelerated kernel.
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::l2_distance(a, b)
}
/// Squared Euclidean distance (faster, no sqrt)
/// Squared Euclidean distance (no sqrt; cheaper when only ordering matters).
///
/// Pairs up elements and ignores any tail beyond the shorter slice.
#[inline]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum()
}
/// Cosine distance (1 - cosine_similarity)
/// Cosine distance (1 - cosine_similarity); 0 for parallel vectors.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    1.0 - simd::cosine_similarity(a, b)
}
/// Cosine similarity (not distance)
/// Cosine similarity (not distance); delegates to the SIMD kernel.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    simd::cosine_similarity(a, b)
}
/// Dot product distance (negative for minimization)
/// Dot product distance: negated so that larger dot products rank as
/// "closer" under distance minimization.
#[inline]
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    -simd::dot_product(a, b)
}
/// Manhattan (L1) distance
/// Manhattan (L1) distance: sum of absolute component differences.
///
/// Any tail beyond the shorter slice is ignored.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y).abs())
        .sum()
}
/// Hamming distance for float vectors (count sign differences)
/// Hamming distance for float vectors: counts positions where the two
/// vectors disagree in sign (strictly-positive vs. not).
///
/// Any tail beyond the shorter slice is ignored.
#[inline]
pub fn hamming_distance_float(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .filter(|&(&x, &y)| (x > 0.0) != (y > 0.0))
        .count() as f32
}
/// Hamming distance for binary packed vectors
/// Hamming distance for binary packed vectors: total count of differing
/// bits across the paired bytes (tail beyond the shorter slice ignored).
#[inline]
pub fn hamming_distance_binary(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x ^ y).count_ones())
        .sum()
}
// ============================================
// Batch Operations
// ============================================
/// Find k nearest neighbors from a set of vectors
/// Find k nearest neighbors from a set of vectors.
///
/// Returns up to `k` `(index, distance)` pairs sorted ascending by distance.
/// Uses O(n) `select_nth_unstable_by` to isolate the top-k before sorting
/// only those k, instead of sorting the full candidate list.
pub fn find_nearest(
    query: &[f32],
    vectors: &[&[f32]],
    k: usize,
    metric: DistanceMetric,
) -> Vec<(usize, f32)> {
    let mut distances: Vec<(usize, f32)> = vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (i, distance(query, v, metric)))
        .collect();
    // Partial sort for top-k; NaN distances compare as equal.
    if k < distances.len() {
        distances.select_nth_unstable_by(k, |a, b| {
            a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)
        });
        distances.truncate(k);
    }
    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal));
    distances
}
/// Compute pairwise distances for a batch of queries
/// Compute pairwise distances for a batch of queries.
///
/// Returns one row per query; row `i` holds `distance(queries[i], v)` for
/// every `v` in `vectors`, in order.
pub fn batch_distances(
    queries: &[&[f32]],
    vectors: &[&[f32]],
    metric: DistanceMetric,
) -> Vec<Vec<f32>> {
    let mut rows = Vec::with_capacity(queries.len());
    for q in queries {
        let mut row = Vec::with_capacity(vectors.len());
        for v in vectors {
            row.push(distance(q, v, metric));
        }
        rows.push(row);
    }
    rows
}
// ============================================
// WASM Exports
// ============================================
/// Calculate distance (WASM export)
/// Calculate distance (WASM export).
///
/// # Safety
/// `a_ptr` and `b_ptr` must each point to at least `len` valid, aligned
/// `f32` values owned by the caller for the duration of the call. The WASM
/// host is trusted to uphold this.
#[no_mangle]
pub extern "C" fn calc_distance(
    a_ptr: *const f32,
    b_ptr: *const f32,
    len: u32,
    metric: u8,
) -> f32 {
    // SAFETY: caller contract above — both pointers are valid for `len` f32s.
    unsafe {
        let a = core::slice::from_raw_parts(a_ptr, len as usize);
        let b = core::slice::from_raw_parts(b_ptr, len as usize);
        distance(a, b, DistanceMetric::from_u8(metric))
    }
}
/// Batch nearest neighbor search (WASM export)
/// Returns number of results written
#[no_mangle]
pub extern "C" fn find_nearest_batch(
query_ptr: *const f32,
query_len: u32,
vectors_ptr: *const f32,
num_vectors: u32,
vector_dim: u32,
k: u32,
metric: u8,
out_indices: *mut u32,
out_distances: *mut f32,
) -> u32 {
unsafe {
let query = core::slice::from_raw_parts(query_ptr, query_len as usize);
// Build vector slice references
let vector_data = core::slice::from_raw_parts(vectors_ptr, (num_vectors * vector_dim) as usize);
let vectors: Vec<&[f32]> = (0..num_vectors as usize)
.map(|i| {
let start = i * vector_dim as usize;
&vector_data[start..start + vector_dim as usize]
})
.collect();
let results = find_nearest(query, &vectors, k as usize, DistanceMetric::from_u8(metric));
// Write results
let indices = core::slice::from_raw_parts_mut(out_indices, results.len());
let distances = core::slice::from_raw_parts_mut(out_distances, results.len());
for (i, (idx, dist)) in results.iter().enumerate() {
indices[i] = *idx as u32;
distances[i] = *dist;
}
results.len() as u32
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// L2 of (3,3,3) is sqrt(27) ≈ 5.196.
    #[test]
    fn test_euclidean() {
        let p = [1.0f32, 2.0, 3.0];
        let q = [4.0f32, 5.0, 6.0];
        assert!((euclidean_distance(&p, &q) - 5.196).abs() < 0.01);
    }

    /// A vector has zero cosine distance to itself.
    #[test]
    fn test_cosine_identical() {
        let v = [1.0f32, 2.0, 3.0];
        assert!(cosine_distance(&v, &v).abs() < 0.001);
    }

    /// L1 of (3,3,3) is 9.
    #[test]
    fn test_manhattan() {
        let p = [1.0f32, 2.0, 3.0];
        let q = [4.0f32, 5.0, 6.0];
        assert!((manhattan_distance(&p, &q) - 9.0).abs() < 0.01);
    }

    /// Top-2 search returns the two closest points, nearest first.
    #[test]
    fn test_find_nearest() {
        let origin = [0.0f32, 0.0];
        let near = [1.0f32, 0.0];
        let far = [2.0f32, 0.0];
        let nearest = [0.5f32, 0.0];
        let pool: Vec<&[f32]> = vec![&near, &far, &nearest];
        let top = find_nearest(&origin, &pool, 2, DistanceMetric::Euclidean);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, 2); // `nearest` wins
    }
}

View File

@@ -0,0 +1,212 @@
//! Content Embedding Module for iOS WASM
//!
//! Lightweight embedding generation for content recommendations.
//! Optimized for minimal binary size and sub-100ms latency on iPhone 12+.
/// Maximum embedding dimensions (memory budget constraint)
pub const MAX_EMBEDDING_DIM: usize = 256;
/// Default embedding dimension for content
pub const DEFAULT_DIM: usize = 64;
/// Content metadata for embedding generation
/// Content metadata for embedding generation.
///
/// All fields are plain scalars so the struct stays FFI- and
/// serialization-friendly for the WASM boundary.
#[derive(Clone, Debug)]
pub struct ContentMetadata {
    /// Content identifier
    pub id: u64,
    /// Content type (0=video, 1=audio, 2=image, 3=text)
    pub content_type: u8,
    /// Duration in seconds (for video/audio; 0 otherwise)
    pub duration_secs: u32,
    /// Category tags (bit flags)
    pub category_flags: u32,
    /// Popularity score (0.0 - 1.0)
    pub popularity: f32,
    /// Recency score (0.0 - 1.0)
    pub recency: f32,
}
impl Default for ContentMetadata {
    /// Neutral metadata: zeroed identity/flags with mid-scale (0.5)
    /// popularity and recency, so an unknown item ranks neither hot nor cold.
    fn default() -> Self {
        Self {
            id: 0,
            content_type: 0,
            duration_secs: 0,
            category_flags: 0,
            popularity: 0.5,
            recency: 0.5,
        }
    }
}
/// Lightweight content embedder optimized for iOS
/// Lightweight content embedder optimized for iOS.
///
/// Maps 8 hand-engineered features to a `dim`-sized L2-normalized vector
/// via a fixed random projection (see `embed` / `embed_features`).
pub struct ContentEmbedder {
    /// Output embedding dimension (capped at MAX_EMBEDDING_DIM)
    dim: usize,
    // Pre-computed projection weights (random but deterministic),
    // flattened `dim x 8` matrix: row i holds the weights for output i.
    projection: Vec<f32>,
}
impl ContentEmbedder {
    /// Create a new embedder with specified dimension (clamped to
    /// `MAX_EMBEDDING_DIM`).
    pub fn new(dim: usize) -> Self {
        let dim = dim.min(MAX_EMBEDDING_DIM);
        // Initialize deterministic pseudo-random projection
        // Using simple LCG for reproducibility without rand crate.
        // One row of 8 weights per output dimension.
        let mut projection = Vec::with_capacity(dim * 8);
        let mut seed: u32 = 12345;
        for _ in 0..(dim * 8) {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            // High 16 bits mapped to ~[-1, 1).
            let val = ((seed >> 16) as f32 / 32768.0) - 1.0;
            projection.push(val * 0.1); // Scale factor
        }
        Self { dim, projection }
    }

    /// Embed content metadata into an L2-normalized vector.
    ///
    /// Extracts 8 scalar features (type, log-duration, flags, popularity,
    /// recency and three hashed slices of the id) and projects them.
    /// NOTE(review): `id as f32` loses precision for ids above 2^24, so very
    /// large ids may collide in feature space — confirm id range upstream.
    #[inline]
    pub fn embed(&self, content: &ContentMetadata) -> Vec<f32> {
        let mut embedding = vec![0.0f32; self.dim];
        // Feature extraction with rough normalization of each component
        let features = [
            content.content_type as f32 / 4.0,
            (content.duration_secs as f32).ln_1p() / 10.0,
            (content.category_flags as f32).sqrt() / 64.0,
            content.popularity,
            content.recency,
            // Three id-derived features from different bit ranges, each in [0, 1)
            content.id as f32 % 1000.0 / 1000.0,
            ((content.id >> 10) as f32 % 1000.0) / 1000.0,
            ((content.id >> 20) as f32 % 1000.0) / 1000.0,
        ];
        // Project features to embedding space: e[i] = Σ_j feat[j] * P[i][j]
        for (i, e) in embedding.iter_mut().enumerate() {
            for (j, &feat) in features.iter().enumerate() {
                let proj_idx = i * 8 + j;
                if proj_idx < self.projection.len() {
                    *e += feat * self.projection[proj_idx];
                }
            }
        }
        // L2 normalize
        self.normalize(&mut embedding);
        embedding
    }

    /// Embed a raw feature vector (only the first 8 components are used).
    #[inline]
    pub fn embed_features(&self, features: &[f32]) -> Vec<f32> {
        let mut embedding = vec![0.0f32; self.dim];
        for (i, e) in embedding.iter_mut().enumerate() {
            for (j, &feat) in features.iter().take(8).enumerate() {
                let proj_idx = i * 8 + j;
                if proj_idx < self.projection.len() {
                    *e += feat * self.projection[proj_idx];
                }
            }
        }
        self.normalize(&mut embedding);
        embedding
    }

    /// L2 normalize a vector in place; near-zero vectors are left unchanged
    /// rather than dividing by ~0.
    #[inline]
    fn normalize(&self, vec: &mut [f32]) {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for x in vec.iter_mut() {
                *x /= norm;
            }
        }
    }

    /// Compute cosine similarity between two embeddings.
    ///
    /// Assumes both inputs are already L2-normalized (as produced by
    /// `embed`/`embed_features`); returns a plain dot product, or 0.0 on
    /// length mismatch.
    #[inline]
    pub fn similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Get embedding dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }
}
/// User vibe/preference state for personalized recommendations
#[derive(Clone, Debug, Default)]
pub struct VibeState {
/// Energy level (0.0 = calm, 1.0 = energetic)
pub energy: f32,
/// Mood valence (-1.0 = negative, 1.0 = positive)
pub mood: f32,
/// Focus level (0.0 = relaxed, 1.0 = focused)
pub focus: f32,
/// Time of day preference (0.0 = morning, 1.0 = night)
pub time_context: f32,
/// Custom preference weights
pub preferences: [f32; 4],
}
impl VibeState {
    /// Convert this vibe state into an embedding via the shared 8-feature
    /// projection. Mood is remapped from [-1, 1] to [0, 1] so all features
    /// share a scale.
    pub fn to_embedding(&self, embedder: &ContentEmbedder) -> Vec<f32> {
        let mood01 = (self.mood + 1.0) / 2.0; // valence rescaled to 0..1
        let mut feats = [0.0f32; 8];
        feats[0] = self.energy;
        feats[1] = mood01;
        feats[2] = self.focus;
        feats[3] = self.time_context;
        feats[4..8].copy_from_slice(&self.preferences);
        embedder.embed_features(&feats)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Requested dimension is honored (below the cap).
    #[test]
    fn test_embedder_creation() {
        assert_eq!(ContentEmbedder::new(64).dim(), 64);
    }

    /// Embeddings come out L2-normalized.
    #[test]
    fn test_embedding_normalized() {
        let emb = ContentEmbedder::new(64).embed(&ContentMetadata::default());
        let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    /// Cosine similarity of normalized embeddings stays within [-1, 1].
    #[test]
    fn test_similarity_range() {
        let embedder = ContentEmbedder::new(64);
        let first = embedder.embed(&ContentMetadata { id: 1, ..Default::default() });
        let second = embedder.embed(&ContentMetadata { id: 2, ..Default::default() });
        let sim = ContentEmbedder::similarity(&first, &second);
        assert!((-1.0..=1.0).contains(&sim));
    }
}

View File

@@ -0,0 +1,691 @@
//! Lightweight HNSW Index for iOS/Browser WASM
//!
//! A simplified HNSW implementation optimized for mobile/browser deployment.
//! Provides O(log n) approximate nearest neighbor search.
//!
//! Based on the paper: "Efficient and Robust Approximate Nearest Neighbor Search
//! Using Hierarchical Navigable Small World Graphs"
use crate::distance::{distance, DistanceMetric};
use std::collections::{BinaryHeap, HashSet};
use std::vec::Vec;
use core::cmp::Ordering;
/// HNSW configuration
/// HNSW configuration (standard parameters from the HNSW paper).
#[derive(Clone, Debug)]
pub struct HnswConfig {
    /// Max connections per node (M parameter) on layers > 0
    pub m: usize,
    /// Max connections at layer 0 (usually 2*M)
    pub m_max_0: usize,
    /// Construction-time search width (efConstruction)
    pub ef_construction: usize,
    /// Query-time search width (ef)
    pub ef_search: usize,
    /// Level multiplier (1/ln(M)) controlling layer assignment probability
    pub level_mult: f32,
}
impl Default for HnswConfig {
    /// Paper-recommended defaults: M=16, layer-0 degree 2*M, with
    /// `level_mult` precomputed as 1/ln(16) ≈ 0.36.
    fn default() -> Self {
        Self {
            m: 16,
            m_max_0: 32,
            ef_construction: 100,
            ef_search: 50,
            level_mult: 0.36, // 1/ln(16)
        }
    }
}
/// Node in the HNSW graph
/// Node in the HNSW graph.
#[derive(Clone, Debug)]
struct HnswNode {
    /// External vector ID (stable across inserts; graph edges reference these)
    id: u64,
    /// Vector data
    vector: Vec<f32>,
    /// Connections at each layer, indexed 0..=level; edges store external IDs,
    /// not node indices
    connections: Vec<Vec<u64>>,
    /// Node's top layer (it appears on layers 0..=level)
    level: usize,
}
/// Search candidate with distance
/// Search candidate with distance, used inside `BinaryHeap` during search.
#[derive(Clone, Debug)]
struct Candidate {
    id: u64,
    distance: f32,
}
/// Equality is by ID only.
/// NOTE(review): this is inconsistent with `Ord` below (which compares
/// distances) — two candidates can be `==` yet compare as non-Equal.
/// `BinaryHeap` only uses `Ord`, so this is currently harmless, but it
/// violates the Eq/Ord consistency contract and would break sorted
/// containers; worth reconciling.
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Reverse order for min-heap behavior in BinaryHeap (std's heap is
        // a max-heap, so the smallest distance must compare greatest).
        other.distance.partial_cmp(&self.distance)
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN distances fall back to Equal rather than panicking.
        self.partial_cmp(other).unwrap_or(Ordering::Equal)
    }
}
/// Lightweight HNSW index
pub struct HnswIndex {
/// All nodes
nodes: Vec<HnswNode>,
/// ID to node index mapping
id_to_idx: std::collections::HashMap<u64, usize>,
/// Entry point (topmost node)
entry_point: Option<usize>,
/// Maximum level in the graph
max_level: usize,
/// Configuration
config: HnswConfig,
/// Distance metric
metric: DistanceMetric,
/// Dimension
dim: usize,
/// Random seed for level generation
seed: u32,
}
impl HnswIndex {
/// Create a new HNSW index
pub fn new(dim: usize, metric: DistanceMetric, config: HnswConfig) -> Self {
Self {
nodes: Vec::new(),
id_to_idx: std::collections::HashMap::new(),
entry_point: None,
max_level: 0,
config,
metric,
dim,
seed: 12345,
}
}
/// Create with default config
/// Create with default config (see `HnswConfig::default`).
pub fn with_defaults(dim: usize, metric: DistanceMetric) -> Self {
    Self::new(dim, metric, HnswConfig::default())
}
/// Generate random level for a new node
/// Generate random level for a new node.
///
/// Samples from the HNSW geometric-like distribution
/// `floor(-ln(U) * level_mult)`, advancing the index's own LCG state so the
/// sequence is deterministic; levels are capped at 16.
fn random_level(&mut self) -> usize {
    // LCG random number generator (same constants used elsewhere in the crate)
    self.seed = self.seed.wrapping_mul(1103515245).wrapping_add(12345);
    let rand = (self.seed >> 16) as f32 / 32768.0;
    let level = (-rand.ln() * self.config.level_mult).floor() as usize;
    level.min(16) // Cap at 16 levels
}
/// Insert a vector into the index
/// Insert a vector into the index.
///
/// Returns `false` (and changes nothing) when the vector has the wrong
/// dimension or the ID already exists. Otherwise performs the standard
/// HNSW insert: greedy descent from the entry point to the node's level,
/// then neighbor selection and bidirectional linking on each layer.
pub fn insert(&mut self, id: u64, vector: Vec<f32>) -> bool {
    if vector.len() != self.dim {
        return false;
    }
    if self.id_to_idx.contains_key(&id) {
        return false; // Already exists
    }
    let level = self.random_level();
    // Index the node will occupy once pushed (nodes is append-only).
    let node_idx = self.nodes.len();
    // Create node with empty connections
    let mut node = HnswNode {
        id,
        vector,
        connections: vec![Vec::new(); level + 1],
        level,
    };
    if let Some(ep_idx) = self.entry_point {
        // Greedy descent: on each layer above the insertion level, hop to
        // the closest neighbor until no hop improves the distance.
        let mut curr_idx = ep_idx;
        let mut curr_dist = self.distance_to_node(node_idx, curr_idx, &node.vector);
        // Traverse from top to insertion level
        for lc in (level + 1..=self.max_level).rev() {
            let mut changed = true;
            while changed {
                changed = false;
                // Connections are cloned so the borrow on self.nodes ends
                // before distance_to_node re-borrows self.
                if let Some(connections) = self.nodes.get(curr_idx).map(|n| n.connections.get(lc).cloned()).flatten() {
                    for &neighbor_id in &connections {
                        if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
                            let d = self.distance_to_node(node_idx, neighbor_idx, &node.vector);
                            if d < curr_dist {
                                curr_dist = d;
                                curr_idx = neighbor_idx;
                                changed = true;
                            }
                        }
                    }
                }
            }
        }
        // Insert at each level from min(level, max_level) down to 0
        for lc in (0..=level.min(self.max_level)).rev() {
            let neighbors = self.search_layer(&node.vector, curr_idx, self.config.ef_construction, lc);
            // Select M best neighbors (neighbors are sorted by distance)
            let m_max = if lc == 0 { self.config.m_max_0 } else { self.config.m };
            let selected: Vec<u64> = neighbors.iter()
                .take(m_max)
                .map(|c| c.id)
                .collect();
            node.connections[lc] = selected.clone();
            // Add bidirectional connections
            for &neighbor_id in &selected {
                if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
                    if let Some(neighbor_node) = self.nodes.get_mut(neighbor_idx) {
                        if lc < neighbor_node.connections.len() {
                            neighbor_node.connections[lc].push(id);
                            // Prune if too many connections
                            if neighbor_node.connections[lc].len() > m_max {
                                let query = &neighbor_node.vector.clone();
                                // NOTE(review): the new node is not yet in
                                // id_to_idx, so prune_connections (which
                                // resolves IDs through the map) silently
                                // drops the just-added back-link whenever
                                // pruning triggers here — confirm intended.
                                self.prune_connections(neighbor_idx, lc, m_max, query);
                            }
                        }
                    }
                }
            }
            // Continue the descent from the best neighbor found on this layer.
            if !neighbors.is_empty() {
                curr_idx = self.id_to_idx.get(&neighbors[0].id).copied().unwrap_or(curr_idx);
            }
        }
    }
    // Add node
    self.nodes.push(node);
    self.id_to_idx.insert(id, node_idx);
    // Update entry point if this is higher level (or index was empty)
    if level > self.max_level || self.entry_point.is_none() {
        self.max_level = level;
        self.entry_point = Some(node_idx);
    }
    true
}
/// Search for k nearest neighbors
/// Search for k nearest neighbors using the configured `ef_search` width.
pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
    self.search_with_ef(query, k, self.config.ef_search)
}
/// Search with custom ef parameter
/// Search with custom ef parameter.
///
/// Returns up to `k` `(id, distance)` pairs sorted ascending by distance;
/// empty when the query dimension is wrong or the index is empty.
pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> {
    if query.len() != self.dim || self.entry_point.is_none() {
        return vec![];
    }
    let ep_idx = self.entry_point.unwrap();
    // Greedy descent from the top layer down to layer 1: hop to the
    // closest neighbor until no hop improves the distance.
    let mut curr_idx = ep_idx;
    let mut curr_dist = distance(query, &self.nodes[curr_idx].vector, self.metric);
    for lc in (1..=self.max_level).rev() {
        let mut changed = true;
        while changed {
            changed = false;
            if let Some(connections) = self.nodes.get(curr_idx).and_then(|n| n.connections.get(lc)) {
                for &neighbor_id in connections {
                    if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
                        let d = distance(query, &self.nodes[neighbor_idx].vector, self.metric);
                        if d < curr_dist {
                            curr_dist = d;
                            curr_idx = neighbor_idx;
                            changed = true;
                        }
                    }
                }
            }
        }
    }
    // Wide beam search at layer 0, then keep the k best.
    let results = self.search_layer(query, curr_idx, ef, 0);
    results.into_iter()
        .take(k)
        .map(|c| (c.id, c.distance))
        .collect()
}
/// Search within a specific layer
/// Beam search within a single layer, starting from `entry_idx`.
///
/// Maintains a min-heap of frontier candidates (`Candidate`'s reversed
/// Ord) and a result pool capped at `ef`; returns candidates sorted
/// ascending by distance.
fn search_layer(&self, query: &[f32], entry_idx: usize, ef: usize, layer: usize) -> Vec<Candidate> {
    let entry_id = self.nodes[entry_idx].id;
    let entry_dist = distance(query, &self.nodes[entry_idx].vector, self.metric);
    let mut visited: HashSet<u64> = HashSet::new();
    let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
    let mut results: Vec<Candidate> = Vec::new();
    visited.insert(entry_id);
    candidates.push(Candidate { id: entry_id, distance: entry_dist });
    results.push(Candidate { id: entry_id, distance: entry_dist });
    while let Some(current) = candidates.pop() {
        // Stop if the closest frontier candidate is already worse than the
        // worst kept result (standard HNSW termination).
        if results.len() >= ef {
            let worst_dist = results.iter().map(|c| c.distance).fold(f32::NEG_INFINITY, f32::max);
            if current.distance > worst_dist {
                break;
            }
        }
        // Explore neighbors
        if let Some(&curr_idx) = self.id_to_idx.get(&current.id) {
            if let Some(connections) = self.nodes.get(curr_idx).and_then(|n| n.connections.get(layer)) {
                for &neighbor_id in connections {
                    // insert() returns true only on first visit.
                    if visited.insert(neighbor_id) {
                        if let Some(&neighbor_idx) = self.id_to_idx.get(&neighbor_id) {
                            let d = distance(query, &self.nodes[neighbor_idx].vector, self.metric);
                            let should_add = results.len() < ef || {
                                let worst = results.iter().map(|c| c.distance).fold(f32::NEG_INFINITY, f32::max);
                                d < worst
                            };
                            if should_add {
                                candidates.push(Candidate { id: neighbor_id, distance: d });
                                results.push(Candidate { id: neighbor_id, distance: d });
                                // Keep only ef best (sort + truncate; O(ef log ef)
                                // per overflow — acceptable for small ef).
                                if results.len() > ef {
                                    results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
                                    results.truncate(ef);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
    results
}
/// Prune connections to keep only the best
/// Prune a node's connections on `layer` down to the `max_conn` closest
/// (relative to `query`, the node's own vector).
///
/// IDs that cannot be resolved through `id_to_idx` are dropped entirely;
/// see the NOTE at the call site in `insert` about the not-yet-registered
/// new node.
fn prune_connections(&mut self, node_idx: usize, layer: usize, max_conn: usize, query: &[f32]) {
    // First, collect connection info without holding mutable borrow
    let connections_to_score: Vec<u64> = if let Some(node) = self.nodes.get(node_idx) {
        if layer < node.connections.len() {
            node.connections[layer].clone()
        } else {
            return;
        }
    } else {
        return;
    };
    // Score connections by distance to the node's vector
    let mut candidates: Vec<(u64, f32)> = connections_to_score
        .iter()
        .filter_map(|&id| {
            self.id_to_idx.get(&id)
                .and_then(|&idx| self.nodes.get(idx))
                .map(|n| (id, distance(query, &n.vector, self.metric)))
        })
        .collect();
    candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    let pruned: Vec<u64> = candidates.into_iter()
        .take(max_conn)
        .map(|(id, _)| id)
        .collect();
    // Now update the connections
    if let Some(node) = self.nodes.get_mut(node_idx) {
        if layer < node.connections.len() {
            node.connections[layer] = pruned;
        }
    }
}
/// Helper to calculate distance to a node
/// Distance from `new_vector` to the stored node at `existing_idx`;
/// `f32::MAX` when the index is out of range. (`_new_idx` is unused but
/// kept for call-site symmetry.)
fn distance_to_node(&self, _new_idx: usize, existing_idx: usize, new_vector: &[f32]) -> f32 {
    match self.nodes.get(existing_idx) {
        Some(node) => distance(new_vector, &node.vector, self.metric),
        None => f32::MAX,
    }
}
/// Get number of vectors in the index
/// Number of vectors stored in the index.
pub fn len(&self) -> usize {
    self.nodes.len()
}
/// True when the index holds no vectors.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Get vector by ID
/// Look up a stored vector by its external ID.
pub fn get(&self, id: u64) -> Option<&[f32]> {
    let idx = *self.id_to_idx.get(&id)?;
    Some(self.nodes.get(idx)?.vector.as_slice())
}
// ============================================
// Persistence
// ============================================
/// Serialize the HNSW index to bytes.
///
/// Format (all integers little-endian):
/// - Header (36 bytes): dim (4), metric (1), padding (3), m (4), m_max_0 (4),
///   ef_construction (4), ef_search (4), max_level (4), node_count (4),
///   entry_point (4, `u32::MAX` encodes None)
/// - For each node: id (8), level (4), vector (dim*4), layer count (4),
///   then per layer: connection count (4) + connection IDs (8 each)
pub fn serialize(&self) -> Vec<u8> {
    let mut bytes = Vec::new();
    // Header
    bytes.extend_from_slice(&(self.dim as u32).to_le_bytes());
    bytes.extend_from_slice(&(self.metric as u8).to_le_bytes());
    bytes.extend_from_slice(&[0u8; 3]); // padding to keep u32 fields aligned
    bytes.extend_from_slice(&(self.config.m as u32).to_le_bytes());
    bytes.extend_from_slice(&(self.config.m_max_0 as u32).to_le_bytes());
    bytes.extend_from_slice(&(self.config.ef_construction as u32).to_le_bytes());
    bytes.extend_from_slice(&(self.config.ef_search as u32).to_le_bytes());
    bytes.extend_from_slice(&(self.max_level as u32).to_le_bytes());
    bytes.extend_from_slice(&(self.nodes.len() as u32).to_le_bytes());
    // u32::MAX is the sentinel for "no entry point" (empty index).
    bytes.extend_from_slice(&self.entry_point.map(|e| e as u32).unwrap_or(u32::MAX).to_le_bytes());
    // Nodes
    for node in &self.nodes {
        // Node header: id, level
        bytes.extend_from_slice(&node.id.to_le_bytes());
        bytes.extend_from_slice(&(node.level as u32).to_le_bytes());
        // Vector
        for &v in &node.vector {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        // Connections: count per layer, then connection IDs
        bytes.extend_from_slice(&(node.connections.len() as u32).to_le_bytes());
        for layer_conns in &node.connections {
            bytes.extend_from_slice(&(layer_conns.len() as u32).to_le_bytes());
            for &conn_id in layer_conns {
                bytes.extend_from_slice(&conn_id.to_le_bytes());
            }
        }
    }
    bytes
}
/// Deserialize HNSW index from bytes
/// Deserialize HNSW index from bytes (see `serialize` for the format).
///
/// Returns `None` on any truncated or malformed input; every read is
/// bounds-checked before indexing. `level_mult` is recomputed from `m`
/// rather than stored, and the level-generation seed restarts at its
/// default (only affects levels of future inserts, not the loaded graph).
pub fn deserialize(bytes: &[u8]) -> Option<Self> {
    // 36 bytes is the full fixed-size header.
    if bytes.len() < 36 {
        return None;
    }
    // NOTE(review): the initial 0 is immediately overwritten below
    // (dim/metric are read with absolute offsets); harmless dead store.
    let mut offset = 0;
    // Read header
    let dim = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
    let metric = DistanceMetric::from_u8(bytes[4]);
    // Skip the 3 padding bytes after the metric byte.
    offset = 8;
    let m = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
    offset += 4;
    let m_max_0 = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
    offset += 4;
    let ef_construction = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
    offset += 4;
    let ef_search = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
    offset += 4;
    let max_level = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
    offset += 4;
    let node_count = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
    offset += 4;
    let entry_point_raw = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]);
    offset += 4;
    // u32::MAX is the serialized sentinel for "no entry point".
    let entry_point = if entry_point_raw == u32::MAX { None } else { Some(entry_point_raw as usize) };
    let config = HnswConfig {
        m,
        m_max_0,
        ef_construction,
        ef_search,
        level_mult: 1.0 / (m as f32).ln(),
    };
    let mut nodes = Vec::with_capacity(node_count);
    let mut id_to_idx = std::collections::HashMap::new();
    for node_idx in 0..node_count {
        // Need at least the node header (id: 8 + level: 4).
        if offset + 12 > bytes.len() {
            return None;
        }
        // Node header
        let id = u64::from_le_bytes([
            bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
            bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7],
        ]);
        offset += 8;
        let level = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
        offset += 4;
        // Vector: dim little-endian f32s
        let mut vector = Vec::with_capacity(dim);
        for _ in 0..dim {
            if offset + 4 > bytes.len() {
                return None;
            }
            let v = f32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]);
            vector.push(v);
            offset += 4;
        }
        // Connections: layer count, then per-layer (count + u64 IDs)
        if offset + 4 > bytes.len() {
            return None;
        }
        let num_layers = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
        offset += 4;
        let mut connections = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            if offset + 4 > bytes.len() {
                return None;
            }
            let num_conns = u32::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3]]) as usize;
            offset += 4;
            let mut layer_conns = Vec::with_capacity(num_conns);
            for _ in 0..num_conns {
                if offset + 8 > bytes.len() {
                    return None;
                }
                let conn_id = u64::from_le_bytes([
                    bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
                    bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7],
                ]);
                layer_conns.push(conn_id);
                offset += 8;
            }
            connections.push(layer_conns);
        }
        id_to_idx.insert(id, node_idx);
        nodes.push(HnswNode {
            id,
            vector,
            connections,
            level,
        });
    }
    Some(Self {
        nodes,
        id_to_idx,
        entry_point,
        max_level,
        config,
        metric,
        dim,
        // RNG restarts from the default seed; only future inserts are affected.
        seed: 12345,
    })
}
/// Estimate serialized size in bytes
pub fn serialized_size(&self) -> usize {
let mut size = 36; // Header
for node in &self.nodes {
size += 12; // id + level
size += node.vector.len() * 4; // vector
size += 4; // num_layers
for layer in &node.connections {
size += 4 + layer.len() * 8; // count + connection IDs
}
}
size
}
}
// ============================================
// WASM Exports
// ============================================
// Global singleton backing the C-ABI exports below.
// SAFETY NOTE: every access goes through `unsafe` on the assumption that the
// WASM module runs single-threaded; revisit if threads are ever enabled.
static mut HNSW_INDEX: Option<HnswIndex> = None;
/// Create HNSW index (C ABI), replacing any existing global index.
///
/// # Parameters
/// - `dim`: vector dimensionality; must be non-zero.
/// - `metric`: distance metric code, decoded via `DistanceMetric::from_u8`.
/// - `m`: max connections per node; must be >= 2 (see below).
/// - `ef_construction`: beam width used during insertion.
///
/// Returns 0 on success, -1 on invalid parameters.
#[no_mangle]
pub extern "C" fn hnsw_create(dim: u32, metric: u8, m: u32, ef_construction: u32) -> i32 {
    // Reject degenerate parameters: ln(1) == 0 and ln(0) is undefined, so
    // m < 2 would make `level_mult` infinite/NaN and corrupt level sampling.
    if dim == 0 || m < 2 {
        return -1;
    }
    let config = HnswConfig {
        m: m as usize,
        m_max_0: (m * 2) as usize,
        ef_construction: ef_construction as usize,
        ef_search: 50,
        level_mult: 1.0 / (m as f32).ln(),
    };
    // SAFETY: the WASM module is single-threaded, so the static mut is not raced.
    unsafe {
        HNSW_INDEX = Some(HnswIndex::new(
            dim as usize,
            DistanceMetric::from_u8(metric),
            config,
        ));
    }
    0
}
/// Insert vector into HNSW (C ABI).
///
/// Returns 0 on success, -1 when the index is uninitialized, the pointer is
/// null, or the insert is rejected by the index.
///
/// # Safety
/// `vector_ptr` must point to `len` readable f32 values.
#[no_mangle]
pub extern "C" fn hnsw_insert(id: u64, vector_ptr: *const f32, len: u32) -> i32 {
    // Defensive guard: never build a slice from a null host pointer.
    if vector_ptr.is_null() {
        return -1;
    }
    // SAFETY: single-threaded WASM; pointer checked non-null above and the
    // caller guarantees it addresses `len` readable f32s.
    unsafe {
        match HNSW_INDEX.as_mut() {
            Some(index) => {
                let vector = core::slice::from_raw_parts(vector_ptr, len as usize).to_vec();
                if index.insert(id, vector) { 0 } else { -1 }
            }
            None => -1,
        }
    }
}
/// Search HNSW index (C ABI).
///
/// Writes up to `k` (id, distance) pairs into `out_ids` / `out_distances`
/// and returns the number of results written. Returns 0 when the index is
/// uninitialized or any pointer is null.
///
/// # Safety
/// `query_ptr` must reference `query_len` readable f32s; both output buffers
/// must have room for at least `k` elements.
#[no_mangle]
pub extern "C" fn hnsw_search(
    query_ptr: *const f32,
    query_len: u32,
    k: u32,
    ef: u32,
    out_ids: *mut u64,
    out_distances: *mut f32,
) -> u32 {
    // Defensive guard: never dereference null host pointers.
    if query_ptr.is_null() || out_ids.is_null() || out_distances.is_null() {
        return 0;
    }
    // SAFETY: single-threaded WASM; pointers checked non-null above; the
    // caller guarantees the buffer sizes documented under `# Safety`.
    unsafe {
        if let Some(index) = HNSW_INDEX.as_ref() {
            let query = core::slice::from_raw_parts(query_ptr, query_len as usize);
            let results = index.search_with_ef(query, k as usize, ef as usize);
            let ids = core::slice::from_raw_parts_mut(out_ids, results.len());
            let distances = core::slice::from_raw_parts_mut(out_distances, results.len());
            for (i, (id, dist)) in results.iter().enumerate() {
                ids[i] = *id;
                distances[i] = *dist;
            }
            results.len() as u32
        } else {
            0
        }
    }
}
/// Get HNSW index size: number of stored vectors, or 0 when uninitialized.
#[no_mangle]
pub extern "C" fn hnsw_size() -> u32 {
    // SAFETY: single-threaded WASM module; no concurrent access to the static.
    unsafe {
        match HNSW_INDEX.as_ref() {
            Some(index) => index.len() as u32,
            None => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserting 100 colinear points and querying the midpoint should yield
    /// a plausible nearby neighbor (HNSW is approximate, so we only bound
    /// the distance rather than demand the exact nearest point).
    #[test]
    fn test_hnsw_insert_search() {
        let mut index = HnswIndex::with_defaults(4, DistanceMetric::Euclidean);
        for i in 0..100u64 {
            assert!(index.insert(i, vec![i as f32, 0.0, 0.0, 0.0]));
        }
        assert_eq!(index.len(), 100);
        let query = [50.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 5);
        assert!(!results.is_empty());
        let (closest_id, closest_dist) = results[0];
        assert!(closest_dist < 25.0, "Distance too large: {}", closest_dist);
        assert!(closest_id < 100, "Invalid ID: {}", closest_id);
    }

    /// Under cosine distance an exactly aligned vector must rank first.
    #[test]
    fn test_hnsw_cosine() {
        let mut index = HnswIndex::with_defaults(3, DistanceMetric::Cosine);
        index.insert(1, vec![1.0, 0.0, 0.0]);
        index.insert(2, vec![0.0, 1.0, 0.0]);
        index.insert(3, vec![0.707, 0.707, 0.0]);
        let query = [1.0, 0.0, 0.0];
        let hits = index.search(&query, 3);
        assert_eq!(hits[0].0, 1); // Exact match first
    }
}

View File

@@ -0,0 +1,352 @@
//! iOS Capability Detection & Optimization Module
//!
//! Provides runtime detection of iOS-specific features and optimization hints.
//! Works with both WasmKit native and Safari WebAssembly runtimes.
// ============================================
// Capability Flags
// ============================================
/// iOS device capability flags (bit flags)
///
/// Each variant is a distinct power-of-two bit so a capability set fits in a
/// single `u32` bitfield (see `RuntimeCapabilities::flags`). Derives
/// `Clone`/`Copy`/`Debug`/`PartialEq`/`Eq` for consistency with the other
/// enums in this module, so flags can be compared and passed by value freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u32)]
pub enum Capability {
    /// WASM SIMD128 support (iOS 16.4+)
    Simd128 = 1 << 0,
    /// Bulk memory operations
    BulkMemory = 1 << 1,
    /// Mutable globals
    MutableGlobals = 1 << 2,
    /// Reference types
    ReferenceTypes = 1 << 3,
    /// Multi-value returns
    MultiValue = 1 << 4,
    /// Tail call optimization
    TailCall = 1 << 5,
    /// Relaxed SIMD (iOS 17+)
    RelaxedSimd = 1 << 6,
    /// Exception handling
    ExceptionHandling = 1 << 7,
    /// Memory64 (large memory)
    Memory64 = 1 << 8,
    /// Threads (SharedArrayBuffer)
    Threads = 1 << 9,
}
/// Runtime capabilities structure
///
/// Snapshot of what the current device/runtime supports, plus coarse
/// hardware hints used to size workloads.
#[derive(Clone, Debug, Default)]
pub struct RuntimeCapabilities {
    /// Bitfield of supported capabilities
    pub flags: u32,
    /// Estimated CPU cores (for parallelism hints)
    pub cpu_cores: u8,
    /// Available memory in MB
    pub memory_mb: u32,
    /// Device generation hint (A11=11, A12=12, etc.)
    pub device_gen: u8,
    /// iOS version major (16, 17, etc.)
    pub ios_version: u8,
}

impl RuntimeCapabilities {
    /// True when the bit for `cap` is set in `flags`.
    #[inline]
    pub fn has(&self, cap: Capability) -> bool {
        self.flags & (cap as u32) != 0
    }

    /// Convenience probe for WASM SIMD128.
    #[inline]
    pub fn has_simd(&self) -> bool {
        self.has(Capability::Simd128)
    }

    /// Convenience probe for relaxed SIMD (FMA, etc.).
    #[inline]
    pub fn has_relaxed_simd(&self) -> bool {
        self.has(Capability::RelaxedSimd)
    }

    /// Convenience probe for threads (SharedArrayBuffer).
    #[inline]
    pub fn has_threads(&self) -> bool {
        self.has(Capability::Threads)
    }

    /// Recommended batch size for vector operations, keyed on SIMD support
    /// and the device's chip generation.
    #[inline]
    pub fn recommended_batch_size(&self) -> usize {
        if !self.has_simd() {
            return 32; // Scalar fallback
        }
        match self.device_gen {
            g if g >= 15 => 256, // A15+ (iPhone 13+)
            g if g >= 13 => 128, // A13-A14
            _ => 64,             // A11-A12
        }
    }

    /// Recommended embedding cache size, scaled by device RAM.
    #[inline]
    pub fn recommended_cache_size(&self) -> usize {
        match self.memory_mb {
            m if m >= 4096 => 1000, // 4GB+ devices
            m if m >= 2048 => 500,
            _ => 100,
        }
    }
}
// ============================================
// Compile-time Detection
// ============================================
/// Detect capabilities at compile time
///
/// Each `cfg!` expands to a boolean literal, so the whole expression is a
/// compile-time constant bitfield: a term contributes its capability bit
/// when the corresponding target feature is enabled, or zero otherwise.
pub const fn compile_time_capabilities() -> u32 {
    ((cfg!(target_feature = "simd128") as u32) * (Capability::Simd128 as u32))
        | ((cfg!(target_feature = "bulk-memory") as u32) * (Capability::BulkMemory as u32))
        | ((cfg!(target_feature = "mutable-globals") as u32) * (Capability::MutableGlobals as u32))
}
/// Get compile-time capability report
///
/// C-ABI wrapper so the Swift/JS host can read the bitfield computed by
/// `compile_time_capabilities` (runtime-only features such as threads are
/// not reported here).
#[no_mangle]
pub extern "C" fn get_compile_capabilities() -> u32 {
    compile_time_capabilities()
}
// ============================================
// Optimization Strategies
// ============================================
/// Optimization strategy for different device tiers
///
/// Ordered from weakest (`Minimal`) to strongest (`Ultra`) hardware.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum OptimizationTier {
    /// Minimal - older devices, focus on memory
    Minimal = 0,
    /// Balanced - mid-range devices
    Balanced = 1,
    /// Performance - high-end devices, maximize speed
    Performance = 2,
    /// Ultra - latest devices with all features
    Ultra = 3,
}

impl OptimizationTier {
    /// Classify a device into a tier from its detected capabilities.
    pub fn from_capabilities(caps: &RuntimeCapabilities) -> Self {
        // Strongest tier first, falling through to weaker ones.
        match (caps.device_gen, caps.has_simd(), caps.has_relaxed_simd()) {
            (gen, _, true) if gen >= 15 => OptimizationTier::Ultra,
            (gen, true, _) if gen >= 13 => OptimizationTier::Performance,
            (_, true, _) => OptimizationTier::Balanced,
            _ => OptimizationTier::Minimal,
        }
    }

    /// Embedding dimension this tier can afford.
    pub fn embedding_dim(&self) -> usize {
        match self {
            OptimizationTier::Ultra => 128,
            OptimizationTier::Performance | OptimizationTier::Balanced => 64,
            OptimizationTier::Minimal => 32,
        }
    }

    /// Attention head count this tier can afford.
    pub fn attention_heads(&self) -> usize {
        match self {
            OptimizationTier::Ultra => 8,
            OptimizationTier::Performance | OptimizationTier::Balanced => 4,
            OptimizationTier::Minimal => 2,
        }
    }

    /// Q-learning state bucket count for this tier.
    pub fn state_buckets(&self) -> usize {
        match self {
            OptimizationTier::Ultra => 64,
            OptimizationTier::Performance => 32,
            OptimizationTier::Balanced => 16,
            OptimizationTier::Minimal => 8,
        }
    }
}
// ============================================
// Memory Optimization
// ============================================
/// Memory pool configuration for iOS
///
/// Sizing knobs for the allocator and caches, chosen per optimization tier.
#[derive(Clone, Debug)]
pub struct MemoryConfig {
    /// Main pool size in bytes
    pub main_pool_bytes: usize,
    /// Embedding cache entries
    pub cache_entries: usize,
    /// History buffer size
    pub history_size: usize,
    /// Use memory-mapped I/O hint
    pub use_mmap: bool,
}

impl MemoryConfig {
    /// Build the configuration appropriate for an optimization tier.
    pub fn for_tier(tier: OptimizationTier) -> Self {
        // (pool bytes, cache entries, history entries, mmap hint) per tier.
        let (main_pool_bytes, cache_entries, history_size, use_mmap) = match tier {
            OptimizationTier::Ultra => (4 * 1024 * 1024, 1000, 200, true),
            OptimizationTier::Performance => (2 * 1024 * 1024, 500, 100, true),
            OptimizationTier::Balanced => (1024 * 1024, 200, 50, false),
            OptimizationTier::Minimal => (512 * 1024, 100, 25, false),
        };
        Self {
            main_pool_bytes,
            cache_entries,
            history_size,
            use_mmap,
        }
    }
}
// ============================================
// Swift Bridge Info
// ============================================
/// Information for Swift integration
///
/// `#[repr(C)]` fixes the layout so the struct can cross the C ABI to Swift
/// unchanged. Populated by `get_bridge_info`.
#[repr(C)]
pub struct SwiftBridgeInfo {
    /// WASM module version
    pub version_major: u8,
    pub version_minor: u8,
    pub version_patch: u8,
    /// Feature flags (compile-time `Capability` bitfield)
    pub feature_flags: u32,
    /// Recommended embedding dimension
    pub embedding_dim: u16,
    /// Recommended batch size
    pub batch_size: u16,
}
/// Get bridge info for Swift
///
/// Returns the module version plus compile-time feature flags and tuning
/// hints; the batch size doubles to 128 when built with SIMD128.
#[no_mangle]
pub extern "C" fn get_bridge_info() -> SwiftBridgeInfo {
    let batch_size = if cfg!(target_feature = "simd128") { 128 } else { 32 };
    SwiftBridgeInfo {
        version_major: 0,
        version_minor: 1,
        version_patch: 0,
        feature_flags: compile_time_capabilities(),
        embedding_dim: 64,
        batch_size,
    }
}
// ============================================
// Neural Engine Offload Hints
// ============================================
/// Operations that could benefit from Neural Engine offload
///
/// Used with `should_offload_to_ane` to decide, per workload size, whether
/// the host should route the work to the Neural Engine instead of WASM.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum NeuralEngineOp {
    /// Batch embedding generation
    BatchEmbed = 0,
    /// Large matrix multiply (attention)
    MatMul = 1,
    /// Softmax over large sequences
    Softmax = 2,
    /// Similarity search over many vectors
    BatchSimilarity = 3,
}
/// Check if operation should be offloaded to Neural Engine
///
/// The ANE only pays off above a per-operation minimum batch/problem size;
/// below that, the dispatch overhead dominates.
pub fn should_offload_to_ane(op: NeuralEngineOp, size: usize) -> bool {
    let threshold = match op {
        NeuralEngineOp::BatchEmbed => 50,
        NeuralEngineOp::MatMul | NeuralEngineOp::BatchSimilarity => 100,
        NeuralEngineOp::Softmax => 256,
    };
    size >= threshold
}
// ============================================
// Performance Hints Export
// ============================================
/// Get recommended parameters for given device memory (MB)
///
/// Returns the tuning tuple packed into a u64, high-to-low:
/// [cache_size:16][batch_size:16][embedding_dim:16][attention_heads:16].
#[no_mangle]
pub extern "C" fn get_recommended_config(memory_mb: u32) -> u64 {
    // Tier table: (cache entries, batch size, embedding dim, attention heads).
    let (cache, batch, dim, heads): (u16, u16, u16, u16) = match memory_mb {
        m if m >= 4096 => (1000, 256, 128, 8),
        m if m >= 2048 => (500, 128, 64, 4),
        m if m >= 1024 => (200, 64, 64, 4),
        _ => (100, 32, 32, 2),
    };
    (u64::from(cache) << 48) | (u64::from(batch) << 32) | (u64::from(dim) << 16) | u64::from(heads)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_compile_capabilities() {
        let caps = compile_time_capabilities();
        // Vacuously true unless the build enables bulk-memory, in which case
        // at least that capability bit must be set (so caps != 0).
        assert!(caps != 0 || !cfg!(target_feature = "bulk-memory"));
    }
    #[test]
    fn test_optimization_tier() {
        // An A14-class device with plain SIMD (no relaxed SIMD) must land in
        // the Performance tier, not Ultra.
        let caps = RuntimeCapabilities {
            flags: Capability::Simd128 as u32,
            cpu_cores: 6,
            memory_mb: 4096,
            device_gen: 14,
            ios_version: 17,
        };
        let tier = OptimizationTier::from_capabilities(&caps);
        assert_eq!(tier, OptimizationTier::Performance);
    }
    #[test]
    fn test_memory_config() {
        // Tier-to-config mapping spot check.
        let config = MemoryConfig::for_tier(OptimizationTier::Performance);
        assert_eq!(config.cache_entries, 500);
    }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,354 @@
//! Q-Learning Module for iOS WASM
//!
//! Lightweight reinforcement learning for adaptive recommendations.
//! Uses tabular Q-learning with function approximation for state generalization.
/// Maximum number of actions (content recommendations) a learner will track;
/// larger requests are clamped in `QLearner::new`.
const MAX_ACTIONS: usize = 100;
/// Number of discrete buckets the state-embedding hash maps into
/// (the fixed width of the Q-table's state axis).
const STATE_BUCKETS: usize = 16;
/// User interaction types
///
/// The observable ways a user can engage with a recommended item; each maps
/// to a scalar reward via [`InteractionType::to_reward`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum InteractionType {
    /// User viewed content
    View = 0,
    /// User liked/saved content
    Like = 1,
    /// User shared content
    Share = 2,
    /// User skipped content
    Skip = 3,
    /// User completed content (video/audio)
    Complete = 4,
    /// User dismissed/hid content
    Dismiss = 5,
}

impl InteractionType {
    /// Map an interaction to a reward in roughly [-0.5, 1.0]: positive
    /// engagement earns positive reward, skips/dismissals are penalized.
    #[inline]
    pub fn to_reward(self) -> f32 {
        use InteractionType::*;
        match self {
            Share => 1.0,
            Like => 0.8,
            Complete => 0.6,
            View => 0.1,
            Skip => -0.1,
            Dismiss => -0.5,
        }
    }
}
/// User interaction event
///
/// One observed user action; converted into a reward signal inside
/// `QLearner::update`.
#[derive(Clone, Debug)]
pub struct UserInteraction {
    /// Content ID that was interacted with
    pub content_id: u64,
    /// Type of interaction
    pub interaction: InteractionType,
    /// Time spent in seconds (feeds a dwell-time reward bonus)
    pub time_spent: f32,
    /// Position in recommendation list (0-indexed; earlier slots earn a bonus)
    pub position: u8,
}
/// Q-Learning agent for personalized recommendations
///
/// Tabular Q-learning over a hashed/discretized state space. Both tables are
/// flat row-major arrays indexed as `state * action_dim + action`.
pub struct QLearner {
    /// Q-values: state_bucket x action -> value (row-major flat array)
    q_table: Vec<f32>,
    /// Learning rate (alpha)
    learning_rate: f32,
    /// Discount factor (gamma)
    discount: f32,
    /// Exploration rate (epsilon); decays over the course of updates
    exploration: f32,
    /// Number of state buckets
    state_dim: usize,
    /// Number of actions
    action_dim: usize,
    /// Visit counts for UCB exploration (same indexing as `q_table`)
    visit_counts: Vec<u32>,
    /// Total updates
    total_updates: u64,
}
impl QLearner {
/// Create a new Q-learner
pub fn new(action_dim: usize) -> Self {
let action_dim = action_dim.min(MAX_ACTIONS);
let state_dim = STATE_BUCKETS;
let table_size = state_dim * action_dim;
Self {
q_table: vec![0.0; table_size],
learning_rate: 0.1,
discount: 0.95,
exploration: 0.1,
state_dim,
action_dim,
visit_counts: vec![0; table_size],
total_updates: 0,
}
}
/// Create with custom hyperparameters
pub fn with_params(
action_dim: usize,
learning_rate: f32,
discount: f32,
exploration: f32,
) -> Self {
let mut learner = Self::new(action_dim);
learner.learning_rate = learning_rate.clamp(0.001, 1.0);
learner.discount = discount.clamp(0.0, 1.0);
learner.exploration = exploration.clamp(0.0, 1.0);
learner
}
/// Discretize state embedding to bucket index
#[inline]
fn discretize_state(&self, state_embedding: &[f32]) -> usize {
if state_embedding.is_empty() {
return 0;
}
// Use first few dimensions to compute hash
let mut hash: u32 = 0;
for (i, &val) in state_embedding.iter().take(8).enumerate() {
let quantized = ((val + 1.0) * 127.0) as u32;
hash = hash.wrapping_add(quantized << (i * 4));
}
(hash as usize) % self.state_dim
}
/// Get Q-value for state-action pair
#[inline]
fn get_q(&self, state: usize, action: usize) -> f32 {
let idx = state * self.action_dim + action;
if idx < self.q_table.len() {
self.q_table[idx]
} else {
0.0
}
}
/// Set Q-value for state-action pair
#[inline]
fn set_q(&mut self, state: usize, action: usize, value: f32) {
let idx = state * self.action_dim + action;
if idx < self.q_table.len() {
self.q_table[idx] = value;
self.visit_counts[idx] += 1;
}
}
/// Select action using epsilon-greedy with UCB exploration bonus
pub fn select_action(&self, state_embedding: &[f32], rng_seed: u32) -> usize {
let state = self.discretize_state(state_embedding);
// Epsilon-greedy exploration
let explore_threshold = (rng_seed % 1000) as f32 / 1000.0;
if explore_threshold < self.exploration {
// Random action
return (rng_seed as usize) % self.action_dim;
}
// Greedy action with UCB bonus
let mut best_action = 0;
let mut best_value = f32::NEG_INFINITY;
let total_visits = self.total_updates.max(1) as f32;
for action in 0..self.action_dim {
let q_val = self.get_q(state, action);
let visits = self.visit_counts[state * self.action_dim + action].max(1) as f32;
// UCB exploration bonus
let ucb_bonus = (2.0 * total_visits.ln() / visits).sqrt() * 0.5;
let value = q_val + ucb_bonus;
if value > best_value {
best_value = value;
best_action = action;
}
}
best_action
}
    /// Update Q-value based on interaction
    ///
    /// Applies one tabular Q-learning step:
    ///   Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
    /// where the reward blends the interaction type with a dwell-time bonus
    /// (up to +0.2, saturating at 60s) and a list-position bonus (up to +0.1
    /// at position 0, vanishing by position 10). Epsilon decays by 1% every
    /// 100 updates, floored at 0.01.
    pub fn update(
        &mut self,
        state_embedding: &[f32],
        action: usize,
        interaction: &UserInteraction,
        next_state_embedding: &[f32],
    ) {
        let state = self.discretize_state(state_embedding);
        let next_state = self.discretize_state(next_state_embedding);
        // Compute reward
        let base_reward = interaction.interaction.to_reward();
        let time_bonus = (interaction.time_spent / 60.0).min(1.0) * 0.2;
        let position_bonus = (1.0 - interaction.position as f32 / 10.0).max(0.0) * 0.1;
        let reward = base_reward + time_bonus + position_bonus;
        // Find max Q-value for next state
        let mut max_next_q = f32::NEG_INFINITY;
        for a in 0..self.action_dim {
            let q = self.get_q(next_state, a);
            if q > max_next_q {
                max_next_q = q;
            }
        }
        // With zero actions the scan never ran; treat the successor value as 0.
        if max_next_q == f32::NEG_INFINITY {
            max_next_q = 0.0;
        }
        // Q-learning update
        let current_q = self.get_q(state, action);
        let td_target = reward + self.discount * max_next_q;
        let new_q = current_q + self.learning_rate * (td_target - current_q);
        self.set_q(state, action, new_q);
        self.total_updates += 1;
        // Decay exploration over time
        if self.total_updates % 100 == 0 {
            self.exploration = (self.exploration * 0.99).max(0.01);
        }
    }
/// Get action rankings for a state (returns sorted action indices)
pub fn rank_actions(&self, state_embedding: &[f32]) -> Vec<usize> {
let state = self.discretize_state(state_embedding);
let mut action_values: Vec<(usize, f32)> = (0..self.action_dim)
.map(|a| (a, self.get_q(state, a)))
.collect();
action_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
action_values.into_iter().map(|(a, _)| a).collect()
}
/// Serialize Q-table to bytes for persistence
pub fn serialize(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(self.q_table.len() * 4 + 32);
// Header
bytes.extend_from_slice(&(self.state_dim as u32).to_le_bytes());
bytes.extend_from_slice(&(self.action_dim as u32).to_le_bytes());
bytes.extend_from_slice(&self.learning_rate.to_le_bytes());
bytes.extend_from_slice(&self.discount.to_le_bytes());
bytes.extend_from_slice(&self.exploration.to_le_bytes());
bytes.extend_from_slice(&self.total_updates.to_le_bytes());
// Q-table
for &q in &self.q_table {
bytes.extend_from_slice(&q.to_le_bytes());
}
bytes
}
/// Deserialize Q-table from bytes
pub fn deserialize(bytes: &[u8]) -> Option<Self> {
// Header: 4+4+4+4+4+8 = 28 bytes
const HEADER_SIZE: usize = 28;
if bytes.len() < HEADER_SIZE {
return None;
}
let state_dim = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
let action_dim = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as usize;
let learning_rate = f32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
let discount = f32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]);
let exploration = f32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]);
let total_updates = u64::from_le_bytes([
bytes[20], bytes[21], bytes[22], bytes[23],
bytes[24], bytes[25], bytes[26], bytes[27],
]);
let table_size = state_dim * action_dim;
let expected_len = HEADER_SIZE + table_size * 4;
if bytes.len() < expected_len {
return None;
}
let mut q_table = Vec::with_capacity(table_size);
for i in 0..table_size {
let offset = HEADER_SIZE + i * 4;
let q = f32::from_le_bytes([
bytes[offset], bytes[offset + 1], bytes[offset + 2], bytes[offset + 3],
]);
q_table.push(q);
}
Some(Self {
q_table,
learning_rate,
discount,
exploration,
state_dim,
action_dim,
visit_counts: vec![0; table_size],
total_updates,
})
}
    /// Current epsilon (exploration probability); decays during `update`.
    pub fn exploration_rate(&self) -> f32 {
        self.exploration
    }
    /// Total number of Q-value updates applied so far.
    pub fn update_count(&self) -> u64 {
        self.total_updates
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_qlearner_creation() {
        // A request below MAX_ACTIONS is kept as-is (no clamping).
        let learner = QLearner::new(50);
        assert_eq!(learner.action_dim, 50);
    }
    #[test]
    fn test_action_selection() {
        let learner = QLearner::new(10);
        let state = vec![0.5; 64];
        // Whatever the seed, the selected action must be in range.
        let action = learner.select_action(&state, 42);
        assert!(action < 10);
    }
    #[test]
    fn test_serialization_roundtrip() {
        let mut learner = QLearner::with_params(10, 0.1, 0.9, 0.2);
        // Do some updates
        let state = vec![0.5; 64];
        let interaction = UserInteraction {
            content_id: 1,
            interaction: InteractionType::Like,
            time_spent: 30.0,
            position: 0,
        };
        learner.update(&state, 0, &interaction, &state);
        // Serialize and deserialize
        let bytes = learner.serialize();
        let restored = QLearner::deserialize(&bytes).unwrap();
        // Structural fields must survive the roundtrip (visit counts do not).
        assert_eq!(restored.action_dim, learner.action_dim);
        assert_eq!(restored.total_updates, learner.total_updates);
    }
}

View File

@@ -0,0 +1,531 @@
//! Quantization Techniques for iOS/Browser WASM
//!
//! Memory-efficient vector compression for mobile devices.
//! - Scalar Quantization: 4x compression (f32 → u8)
//! - Binary Quantization: 32x compression (f32 → 1 bit)
//! - Product Quantization: 8-16x compression
use std::vec::Vec;
// ============================================
// Scalar Quantization (4x compression)
// ============================================
/// Scalar-quantized vector (f32 → u8)
///
/// Stores one 8-bit code per component plus the (min, scale) pair needed to
/// map codes back to floats: value ≈ min + code * scale. Roughly 4x smaller
/// than the original f32 data.
#[derive(Clone, Debug)]
pub struct ScalarQuantized {
    /// Quantized values
    pub data: Vec<u8>,
    /// Minimum value for reconstruction
    pub min: f32,
    /// Scale factor for reconstruction
    pub scale: f32,
}

impl ScalarQuantized {
    /// Quantize a float vector by mapping [min, max] linearly onto [0, 255].
    pub fn quantize(vector: &[f32]) -> Self {
        if vector.is_empty() {
            return Self {
                data: vec![],
                min: 0.0,
                scale: 1.0,
            };
        }
        let min = vector.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = vector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // A (near-)constant vector would give scale 0; use 1.0 so every
        // component maps to code 0 and reconstructs to `min`.
        let scale = if (max - min).abs() < f32::EPSILON {
            1.0
        } else {
            (max - min) / 255.0
        };
        let mut data = Vec::with_capacity(vector.len());
        for &v in vector {
            data.push(((v - min) / scale).round().clamp(0.0, 255.0) as u8);
        }
        Self { data, min, scale }
    }

    /// Map every 8-bit code back to its approximate float value.
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.data.len());
        for &code in &self.data {
            out.push(self.min + code as f32 * self.scale);
        }
        out
    }

    /// Approximate Euclidean distance computed directly on the u8 codes,
    /// rescaled by the coarser of the two quantization scales.
    pub fn distance(&self, other: &Self) -> f32 {
        let sum: i32 = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| {
                let d = a as i32 - b as i32;
                d * d
            })
            .sum();
        (sum as f32).sqrt() * self.scale.max(other.scale)
    }

    /// Euclidean distance between a float query and this quantized vector,
    /// reconstructing each stored component on the fly.
    pub fn asymmetric_distance(&self, query: &[f32]) -> f32 {
        let mut sum = 0.0f32;
        for (&code, &q) in self.data.iter().zip(query.iter()) {
            let diff = (self.min + code as f32 * self.scale) - q;
            sum += diff * diff;
        }
        sum.sqrt()
    }

    /// Approximate in-memory footprint: one byte per code plus min + scale.
    pub fn memory_size(&self) -> usize {
        self.data.len() + 8
    }

    /// Serialize as: min (f32 LE) | scale (f32 LE) | raw codes.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(8 + self.data.len());
        out.extend_from_slice(&self.min.to_le_bytes());
        out.extend_from_slice(&self.scale.to_le_bytes());
        out.extend_from_slice(&self.data);
        out
    }

    /// Inverse of `serialize`; `None` when the 8-byte header is missing.
    pub fn deserialize(bytes: &[u8]) -> Option<Self> {
        let header = bytes.get(..8)?;
        let min = f32::from_le_bytes([header[0], header[1], header[2], header[3]]);
        let scale = f32::from_le_bytes([header[4], header[5], header[6], header[7]]);
        Some(Self {
            data: bytes[8..].to_vec(),
            min,
            scale,
        })
    }

    /// Exact byte count `serialize` will produce.
    pub fn serialized_size(&self) -> usize {
        8 + self.data.len()
    }
}
// ============================================
// Binary Quantization (32x compression)
// ============================================
/// Binary-quantized vector (f32 → 1 bit)
///
/// Keeps one bit per dimension (set when the value exceeds a threshold),
/// packed 8 per byte LSB-first — roughly 32x compression.
#[derive(Clone, Debug)]
pub struct BinaryQuantized {
    /// Packed bits (8 dimensions per byte)
    pub bits: Vec<u8>,
    /// Original dimension count
    pub dimensions: usize,
}

impl BinaryQuantized {
    /// Quantize float vector to binary (sign-based: bit set when value > 0).
    pub fn quantize(vector: &[f32]) -> Self {
        // Sign-based quantization is just thresholding at zero; delegate to
        // avoid duplicating the packing loop (the original repeated it).
        Self::quantize_with_threshold(vector, 0.0)
    }

    /// Quantize with an explicit threshold: bit i set when vector[i] > threshold.
    pub fn quantize_with_threshold(vector: &[f32], threshold: f32) -> Self {
        let dimensions = vector.len();
        // One byte holds 8 dimensions; round up for the tail.
        let num_bytes = (dimensions + 7) / 8;
        let mut bits = vec![0u8; num_bytes];
        for (i, &v) in vector.iter().enumerate() {
            if v > threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                bits[byte_idx] |= 1 << bit_idx;
            }
        }
        Self { bits, dimensions }
    }

    /// Hamming distance (number of differing bits) between two binary vectors.
    pub fn distance(&self, other: &Self) -> u32 {
        let mut distance = 0u32;
        for (&a, &b) in self.bits.iter().zip(other.bits.iter()) {
            distance += (a ^ b).count_ones();
        }
        distance
    }

    /// Hamming-style distance to a float query, binarizing the query on the
    /// fly (a query component counts as a set bit when > 0).
    pub fn asymmetric_distance(&self, query: &[f32]) -> f32 {
        let mut distance = 0u32;
        for (i, &q) in query.iter().take(self.dimensions).enumerate() {
            let byte_idx = i / 8;
            let bit_idx = i % 8;
            let bit = (self.bits.get(byte_idx).unwrap_or(&0) >> bit_idx) & 1;
            let query_bit = if q > 0.0 { 1 } else { 0 };
            if bit != query_bit {
                distance += 1;
            }
        }
        distance as f32
    }

    /// Reconstruct to a +1/-1 vector (set bit → +1.0, clear bit → -1.0).
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.dimensions);
        for i in 0..self.dimensions {
            let byte_idx = i / 8;
            let bit_idx = i % 8;
            let bit = (self.bits.get(byte_idx).unwrap_or(&0) >> bit_idx) & 1;
            result.push(if bit == 1 { 1.0 } else { -1.0 });
        }
        result
    }

    /// Approximate in-memory footprint in bytes.
    pub fn memory_size(&self) -> usize {
        self.bits.len() + 8 // packed bits + the `dimensions` word
    }

    /// Serialize as: dimensions (u32 LE) | packed bits.
    pub fn serialize(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(4 + self.bits.len());
        bytes.extend_from_slice(&(self.dimensions as u32).to_le_bytes());
        bytes.extend_from_slice(&self.bits);
        bytes
    }

    /// Inverse of `serialize`; `None` when the 4-byte header is missing.
    pub fn deserialize(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < 4 {
            return None;
        }
        let dimensions = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
        let bits = bytes[4..].to_vec();
        Some(Self { bits, dimensions })
    }

    /// Exact byte count `serialize` will produce.
    pub fn serialized_size(&self) -> usize {
        4 + self.bits.len()
    }
}
// ============================================
// Simple Product Quantization (8-16x compression)
// ============================================
/// Product-quantized vector
///
/// Stores one u8 centroid code per subspace; decoding requires the
/// `PQCodebook` that produced it.
#[derive(Clone, Debug)]
pub struct ProductQuantized {
    /// Quantized codes (one per subspace)
    pub codes: Vec<u8>,
    /// Number of subspaces
    pub num_subspaces: usize,
}
/// Product quantization codebook
///
/// Holds per-subspace k-means centroids used to encode, decode, and compute
/// asymmetric distances for `ProductQuantized` vectors.
#[derive(Clone, Debug)]
pub struct PQCodebook {
    /// Centroids for each subspace [subspace][centroid][dim]
    pub centroids: Vec<Vec<Vec<f32>>>,
    /// Number of subspaces
    pub num_subspaces: usize,
    /// Dimension per subspace
    pub subspace_dim: usize,
    /// Number of centroids (usually 256 for u8 codes)
    pub num_centroids: usize,
}
impl PQCodebook {
    /// Train a PQ codebook using k-means
    ///
    /// Splits each training vector into `num_subspaces` contiguous chunks and
    /// runs k-means independently per chunk. Assumes every vector has the
    /// same length; dimensions beyond `num_subspaces * subspace_dim` are
    /// silently dropped by the integer division.
    ///
    /// NOTE(review): `num_subspaces == 0` divides by zero below, and vectors
    /// shorter than `dim` would panic when sliced — callers must validate.
    pub fn train(
        vectors: &[Vec<f32>],
        num_subspaces: usize,
        num_centroids: usize,
        iterations: usize,
    ) -> Self {
        // No training data: return an empty (untrained) codebook.
        if vectors.is_empty() {
            return Self {
                centroids: vec![],
                num_subspaces,
                subspace_dim: 0,
                num_centroids,
            };
        }
        let dim = vectors[0].len();
        let subspace_dim = dim / num_subspaces;
        let mut centroids = Vec::with_capacity(num_subspaces);
        // Train each subspace independently
        for s in 0..num_subspaces {
            let start = s * subspace_dim;
            let end = start + subspace_dim;
            // Extract subvectors
            let subvectors: Vec<Vec<f32>> = vectors
                .iter()
                .map(|v| v[start..end].to_vec())
                .collect();
            // Run k-means
            let subspace_centroids = kmeans(&subvectors, num_centroids, iterations);
            centroids.push(subspace_centroids);
        }
        Self {
            centroids,
            num_subspaces,
            subspace_dim,
            num_centroids,
        }
    }
    /// Encode a vector using this codebook
    ///
    /// Picks, per subspace, the nearest centroid's index as the u8 code.
    /// Panics if `vector` is shorter than num_subspaces * subspace_dim, or
    /// if a distance compares as NaN (the `unwrap` on `partial_cmp`).
    pub fn encode(&self, vector: &[f32]) -> ProductQuantized {
        let mut codes = Vec::with_capacity(self.num_subspaces);
        for (s, subspace_centroids) in self.centroids.iter().enumerate() {
            let start = s * self.subspace_dim;
            let end = start + self.subspace_dim;
            let subvector = &vector[start..end];
            // Find nearest centroid
            let code = subspace_centroids
                .iter()
                .enumerate()
                .map(|(i, c)| {
                    let dist = euclidean_squared(subvector, c);
                    (i, dist)
                })
                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .map(|(i, _)| i as u8)
                .unwrap_or(0);
            codes.push(code);
        }
        ProductQuantized {
            codes,
            num_subspaces: self.num_subspaces,
        }
    }
    /// Decode a PQ vector back to approximate floats
    ///
    /// Concatenates the centroid for each code; out-of-range codes or
    /// subspaces are skipped, which shortens the output silently.
    pub fn decode(&self, pq: &ProductQuantized) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.num_subspaces * self.subspace_dim);
        for (s, &code) in pq.codes.iter().enumerate() {
            if s < self.centroids.len() && (code as usize) < self.centroids[s].len() {
                result.extend_from_slice(&self.centroids[s][code as usize]);
            }
        }
        result
    }
    /// Compute distance using precomputed distance table (ADC)
    ///
    /// Asymmetric distance computation: sums, per subspace, the squared
    /// distance between the query chunk and the stored code's centroid,
    /// then takes one square root at the end. Panics if `query` is shorter
    /// than the encoded dimensionality.
    pub fn asymmetric_distance(&self, pq: &ProductQuantized, query: &[f32]) -> f32 {
        let mut dist = 0.0f32;
        for (s, &code) in pq.codes.iter().enumerate() {
            let start = s * self.subspace_dim;
            let end = start + self.subspace_dim;
            let query_sub = &query[start..end];
            if s < self.centroids.len() && (code as usize) < self.centroids[s].len() {
                let centroid = &self.centroids[s][code as usize];
                dist += euclidean_squared(query_sub, centroid);
            }
        }
        dist.sqrt()
    }
}
// ============================================
// Helper Functions
// ============================================
/// Squared Euclidean distance between two slices; extra elements of the
/// longer slice are ignored by the `zip`.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc
}
/// Naive k-means over `vectors`, returning `k` centroids of dimension
/// `vectors[0].len()`.
///
/// Initialization is deterministic (first `k` vectors, zero-padded when the
/// dataset is smaller than `k`), so results are reproducible. Runs exactly
/// `iterations` Lloyd steps; empty clusters keep their previous centroid.
/// Panics if a distance compares as NaN (the `unwrap` on `partial_cmp`).
fn kmeans(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
    if vectors.is_empty() || k == 0 {
        return vec![];
    }
    let dim = vectors[0].len();
    // Initialize centroids (first k vectors or zero padding).
    let mut centroids: Vec<Vec<f32>> = vectors.iter().take(k).cloned().collect();
    while centroids.len() < k {
        centroids.push(vec![0.0; dim]);
    }
    for _ in 0..iterations {
        // Assignment step: record member *indices* per cluster instead of
        // cloning every training vector each pass (the original allocated a
        // full copy of the dataset per iteration).
        let mut members: Vec<Vec<usize>> = vec![vec![]; k];
        for (vi, vector) in vectors.iter().enumerate() {
            let nearest = centroids
                .iter()
                .enumerate()
                .map(|(ci, c)| (ci, euclidean_squared(vector, c)))
                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .map(|(ci, _)| ci)
                .unwrap_or(0);
            members[nearest].push(vi);
        }
        // Update step: move each centroid to the mean of its members,
        // accumulating in the same order as before so results are identical.
        for (centroid, assigned) in centroids.iter_mut().zip(members.iter()) {
            if !assigned.is_empty() {
                for (d, c) in centroid.iter_mut().enumerate() {
                    *c = assigned.iter().map(|&vi| vectors[vi][d]).sum::<f32>()
                        / assigned.len() as f32;
                }
            }
        }
    }
    centroids
}
// ============================================
// WASM Exports
// ============================================
/// Scalar quantize a vector
///
/// FFI entry point: reads `len` floats from `input_ptr`, quantizes them to
/// one byte each via `ScalarQuantized::quantize`, and copies the bytes to
/// `out_data`. The dequantization parameters are written through
/// `out_min` / `out_scale` so the host can later reconstruct approximate
/// floats (see `ScalarQuantized::reconstruct`).
///
/// NOTE(review): the caller must have allocated `out_data` with room for
/// `len` bytes — presumably quantization is one byte per element; confirm
/// against `ScalarQuantized::quantize`.
#[no_mangle]
pub extern "C" fn scalar_quantize(
    input_ptr: *const f32,
    len: u32,
    out_data: *mut u8,
    out_min: *mut f32,
    out_scale: *mut f32,
) {
    // SAFETY: the WASM host must pass valid, non-aliasing pointers:
    // `input_ptr` readable for `len` f32s, `out_data` writable for the
    // quantized byte count, and `out_min` / `out_scale` writable f32 slots.
    unsafe {
        let input = core::slice::from_raw_parts(input_ptr, len as usize);
        let sq = ScalarQuantized::quantize(input);
        let out = core::slice::from_raw_parts_mut(out_data, sq.data.len());
        out.copy_from_slice(&sq.data);
        *out_min = sq.min;
        *out_scale = sq.scale;
    }
}
/// Binary quantize a vector
///
/// FFI entry point: reads `len` floats from `input_ptr`, packs them into a
/// sign-bit representation via `BinaryQuantized::quantize`, and copies the
/// packed bytes into `out_bits`. Returns the number of bytes written so
/// the host knows how much of `out_bits` is valid.
///
/// NOTE(review): the caller must size `out_bits` for the packed output —
/// presumably one bit per input element (len/8 rounded up, as the unit
/// tests suggest); confirm against `BinaryQuantized::quantize`.
#[no_mangle]
pub extern "C" fn binary_quantize(
    input_ptr: *const f32,
    len: u32,
    out_bits: *mut u8,
) -> u32 {
    // SAFETY: caller must pass `input_ptr` valid for `len` f32 reads and
    // `out_bits` valid for writes of the packed byte count.
    unsafe {
        let input = core::slice::from_raw_parts(input_ptr, len as usize);
        let bq = BinaryQuantized::quantize(input);
        let out = core::slice::from_raw_parts_mut(out_bits, bq.bits.len());
        out.copy_from_slice(&bq.bits);
        bq.bits.len() as u32
    }
}
/// Hamming distance between two binary vectors
///
/// Counts the number of differing bits across the two `len`-byte buffers
/// by XOR-ing byte pairs and summing their popcounts.
#[no_mangle]
pub extern "C" fn hamming_distance(
    a_ptr: *const u8,
    b_ptr: *const u8,
    len: u32,
) -> u32 {
    // SAFETY: caller must pass pointers valid for `len` byte reads each.
    unsafe {
        let lhs = core::slice::from_raw_parts(a_ptr, len as usize);
        let rhs = core::slice::from_raw_parts(b_ptr, len as usize);
        lhs.iter()
            .zip(rhs.iter())
            .map(|(&x, &y)| (x ^ y).count_ones())
            .sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-trip through 8-bit scalar quantization should be accurate to
    // roughly the quantization step for values in [0, 1].
    #[test]
    fn test_scalar_quantization() {
        let v = vec![0.0, 0.5, 1.0, 0.25, 0.75];
        let sq = ScalarQuantized::quantize(&v);
        let reconstructed = sq.reconstruct();
        for (orig, recon) in v.iter().zip(reconstructed.iter()) {
            assert!((orig - recon).abs() < 0.01);
        }
    }
    // Sign-bit packing: 4 dims fit in one byte; positive entries set bits.
    #[test]
    fn test_binary_quantization() {
        let v = vec![1.0, -1.0, 0.5, -0.5];
        let bq = BinaryQuantized::quantize(&v);
        assert_eq!(bq.dimensions, 4);
        assert_eq!(bq.bits.len(), 1);
        assert_eq!(bq.bits[0], 0b0101); // positions 0 and 2 are positive
    }
    // Hamming distance counts differing sign bits (positions 1 and 3 here).
    #[test]
    fn test_hamming_distance() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0];
        let v2 = vec![1.0, -1.0, 1.0, -1.0];
        let bq1 = BinaryQuantized::quantize(&v1);
        let bq2 = BinaryQuantized::quantize(&v2);
        assert_eq!(bq1.distance(&bq2), 2);
    }
    // PQ round-trip: the decoded vector must have the original length (8);
    // values are only approximate by design, so only the length is checked.
    #[test]
    fn test_pq_encode_decode() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32 / 100.0; 8])
            .collect();
        let codebook = PQCodebook::train(&vectors, 2, 16, 10);
        let pq = codebook.encode(&vectors[50]);
        let decoded = codebook.decode(&pq);
        assert_eq!(decoded.len(), 8);
    }
}

View File

@@ -0,0 +1,487 @@
//! SIMD-Optimized Vector Operations for iOS WASM
//!
//! Provides 4-8x speedup on iOS devices with Safari 16.4+ (iOS 16.4+)
//! Uses WebAssembly SIMD128 instructions for vectorized math.
//!
//! ## Supported Operations
//! - Dot product (cosine similarity numerator)
//! - L2 distance (Euclidean)
//! - Vector normalization
//! - Batch similarity computation
//!
//! ## Requirements
//! - Build with: `RUSTFLAGS="-C target-feature=+simd128"`
//! - Runtime: Safari 16.4+ / iOS 16.4+ / WasmKit with SIMD
#[cfg(target_feature = "simd128")]
use core::arch::wasm32::*;
/// Compile-time probe for WebAssembly SIMD128 support.
///
/// Returns `true` only when this crate was built with
/// `RUSTFLAGS="-C target-feature=+simd128"`; there is no runtime detection,
/// so the answer is fixed per build.
#[inline]
pub const fn simd_available() -> bool {
    cfg!(target_feature = "simd128")
}
// ============================================
// SIMD-Optimized Operations
// ============================================
#[cfg(target_feature = "simd128")]
mod simd_impl {
    use super::*;
    /// SIMD dot product - processes 4 floats per instruction
    ///
    /// Performance: ~4x faster than scalar for vectors >= 16 elements
    ///
    /// # Panics
    /// Panics if `a` and `b` differ in length.
    #[inline]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        let len = a.len();
        // Largest multiple of 4 <= len; the tail is handled scalar below.
        let simd_len = len - (len % 4);
        let mut sum = f32x4_splat(0.0);
        // Process 4 elements at a time
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + 4 <= simd_len <= len, so each 16-byte load stays
            // in bounds of `a` / `b`. WASM v128 loads accept unaligned
            // addresses (alignment is only a hint in the spec), so no
            // alignment precondition is needed.
            unsafe {
                let va = v128_load(a.as_ptr().add(i) as *const v128);
                let vb = v128_load(b.as_ptr().add(i) as *const v128);
                sum = f32x4_add(sum, f32x4_mul(va, vb));
            }
            i += 4;
        }
        // Horizontal sum of SIMD lanes
        let mut result = f32x4_extract_lane::<0>(sum)
            + f32x4_extract_lane::<1>(sum)
            + f32x4_extract_lane::<2>(sum)
            + f32x4_extract_lane::<3>(sum);
        // Handle remainder
        for j in simd_len..len {
            result += a[j] * b[j];
        }
        result
    }
    /// SIMD L2 norm (vector magnitude)
    #[inline]
    pub fn l2_norm(v: &[f32]) -> f32 {
        dot_product(v, v).sqrt()
    }
    /// SIMD L2 distance between two vectors
    ///
    /// # Panics
    /// Panics if `a` and `b` differ in length.
    #[inline]
    pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        let len = a.len();
        let simd_len = len - (len % 4);
        let mut sum = f32x4_splat(0.0);
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + 4 <= simd_len <= len keeps both loads in bounds.
            unsafe {
                let va = v128_load(a.as_ptr().add(i) as *const v128);
                let vb = v128_load(b.as_ptr().add(i) as *const v128);
                let diff = f32x4_sub(va, vb);
                sum = f32x4_add(sum, f32x4_mul(diff, diff));
            }
            i += 4;
        }
        let mut result = f32x4_extract_lane::<0>(sum)
            + f32x4_extract_lane::<1>(sum)
            + f32x4_extract_lane::<2>(sum)
            + f32x4_extract_lane::<3>(sum);
        // Scalar tail for the last len % 4 elements.
        for j in simd_len..len {
            let diff = a[j] - b[j];
            result += diff * diff;
        }
        result.sqrt()
    }
    /// SIMD vector normalization (in-place)
    ///
    /// Near-zero vectors (norm < 1e-8) are left unchanged to avoid
    /// division blow-up.
    #[inline]
    pub fn normalize(v: &mut [f32]) {
        let norm = l2_norm(v);
        if norm < 1e-8 {
            return;
        }
        let len = v.len();
        let simd_len = len - (len % 4);
        // Multiply by the reciprocal instead of dividing per element.
        let inv_norm = f32x4_splat(1.0 / norm);
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + 4 <= simd_len <= len; the load and store target
            // the same in-bounds 16-byte span of `v`, held mutably here.
            unsafe {
                let ptr = v.as_mut_ptr().add(i) as *mut v128;
                let val = v128_load(ptr as *const v128);
                let normalized = f32x4_mul(val, inv_norm);
                v128_store(ptr, normalized);
            }
            i += 4;
        }
        let scalar_inv = 1.0 / norm;
        for j in simd_len..len {
            v[j] *= scalar_inv;
        }
    }
    /// SIMD cosine similarity
    ///
    /// Returns 0.0 when either vector is near-zero.
    #[inline]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot = dot_product(a, b);
        let norm_a = l2_norm(a);
        let norm_b = l2_norm(b);
        if norm_a < 1e-8 || norm_b < 1e-8 {
            return 0.0;
        }
        dot / (norm_a * norm_b)
    }
    /// Batch dot products - compute similarity of query against multiple vectors
    /// Returns scores in the output slice
    ///
    /// Only the first `min(vectors.len(), out.len())` entries are written.
    #[inline]
    pub fn batch_dot_products(query: &[f32], vectors: &[&[f32]], out: &mut [f32]) {
        for (i, vec) in vectors.iter().enumerate() {
            if i < out.len() {
                out[i] = dot_product(query, vec);
            }
        }
    }
    /// SIMD vector addition (out = a + b)
    ///
    /// # Panics
    /// Panics if `a`, `b`, and `out` do not all share the same length.
    #[inline]
    pub fn add(a: &[f32], b: &[f32], out: &mut [f32]) {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), out.len());
        let len = a.len();
        let simd_len = len - (len % 4);
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + 4 <= simd_len <= len for all three equal-length
            // slices, so the loads and the store stay in bounds.
            unsafe {
                let va = v128_load(a.as_ptr().add(i) as *const v128);
                let vb = v128_load(b.as_ptr().add(i) as *const v128);
                let sum = f32x4_add(va, vb);
                v128_store(out.as_mut_ptr().add(i) as *mut v128, sum);
            }
            i += 4;
        }
        for j in simd_len..len {
            out[j] = a[j] + b[j];
        }
    }
    /// SIMD scalar multiply (out = a * scalar)
    ///
    /// # Panics
    /// Panics if `a` and `out` differ in length.
    #[inline]
    pub fn scale(a: &[f32], scalar: f32, out: &mut [f32]) {
        assert_eq!(a.len(), out.len());
        let len = a.len();
        let simd_len = len - (len % 4);
        let vscalar = f32x4_splat(scalar);
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + 4 <= simd_len <= len for both equal-length slices.
            unsafe {
                let va = v128_load(a.as_ptr().add(i) as *const v128);
                let scaled = f32x4_mul(va, vscalar);
                v128_store(out.as_mut_ptr().add(i) as *mut v128, scaled);
            }
            i += 4;
        }
        for j in simd_len..len {
            out[j] = a[j] * scalar;
        }
    }
    /// SIMD max element
    ///
    /// Returns `NEG_INFINITY` for an empty slice.
    #[inline]
    pub fn max(v: &[f32]) -> f32 {
        if v.is_empty() {
            return f32::NEG_INFINITY;
        }
        let len = v.len();
        let simd_len = len - (len % 4);
        let mut max_vec = f32x4_splat(f32::NEG_INFINITY);
        let mut i = 0;
        while i < simd_len {
            // SAFETY: i + 4 <= simd_len <= len keeps the load in bounds.
            unsafe {
                let val = v128_load(v.as_ptr().add(i) as *const v128);
                // pmax is the lane-wise pseudo-max: `b < a ? a : b`.
                max_vec = f32x4_pmax(max_vec, val);
            }
            i += 4;
        }
        let mut result = f32x4_extract_lane::<0>(max_vec)
            .max(f32x4_extract_lane::<1>(max_vec))
            .max(f32x4_extract_lane::<2>(max_vec))
            .max(f32x4_extract_lane::<3>(max_vec));
        for j in simd_len..len {
            result = result.max(v[j]);
        }
        result
    }
    /// SIMD softmax (in-place, numerically stable)
    ///
    /// Shifts by the max before exponentiation so exp() cannot overflow.
    /// The exp pass runs scalar; only the final normalization is vectorized.
    pub fn softmax(v: &mut [f32]) {
        if v.is_empty() {
            return;
        }
        // Find max for numerical stability
        let max_val = max(v);
        // Subtract max and exp
        let len = v.len();
        let mut sum = 0.0f32;
        for x in v.iter_mut() {
            *x = (*x - max_val).exp();
            sum += *x;
        }
        // Normalize
        if sum > 1e-8 {
            let inv_sum = 1.0 / sum;
            let simd_len = len - (len % 4);
            let vinv = f32x4_splat(inv_sum);
            let mut i = 0;
            while i < simd_len {
                // SAFETY: i + 4 <= simd_len <= len; in-bounds load/store on
                // the mutably-borrowed slice.
                unsafe {
                    let ptr = v.as_mut_ptr().add(i) as *mut v128;
                    let val = v128_load(ptr as *const v128);
                    v128_store(ptr, f32x4_mul(val, vinv));
                }
                i += 4;
            }
            for j in simd_len..len {
                v[j] *= inv_sum;
            }
        }
    }
}
// ============================================
// Scalar Fallback (when SIMD not available)
// ============================================
#[cfg(not(target_feature = "simd128"))]
mod scalar_impl {
    /// Portable dot product, used when the build lacks `+simd128`.
    #[inline]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).fold(0.0, |acc, (&x, &y)| acc + x * y)
    }
    /// Portable L2 norm: sqrt of the vector's dot product with itself.
    #[inline]
    pub fn l2_norm(v: &[f32]) -> f32 {
        dot_product(v, v).sqrt()
    }
    /// Portable Euclidean distance between `a` and `b`.
    #[inline]
    pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .fold(0.0f32, |acc, (&x, &y)| {
                let diff = x - y;
                acc + diff * diff
            })
            .sqrt()
    }
    /// Portable in-place normalization; near-zero vectors stay untouched.
    #[inline]
    pub fn normalize(v: &mut [f32]) {
        let magnitude = l2_norm(v);
        if magnitude <= 1e-8 {
            return;
        }
        for component in v.iter_mut() {
            *component /= magnitude;
        }
    }
    /// Portable cosine similarity; 0.0 when either vector is near-zero.
    #[inline]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let norm_a = l2_norm(a);
        let norm_b = l2_norm(b);
        if norm_a < 1e-8 || norm_b < 1e-8 {
            return 0.0;
        }
        dot_product(a, b) / (norm_a * norm_b)
    }
    /// Portable batch scoring: fills `out[i]` with `query . vectors[i]`
    /// for as many entries as both `vectors` and `out` provide.
    #[inline]
    pub fn batch_dot_products(query: &[f32], vectors: &[&[f32]], out: &mut [f32]) {
        for (slot, candidate) in out.iter_mut().zip(vectors.iter()) {
            *slot = dot_product(query, candidate);
        }
    }
    /// Portable element-wise addition over the common prefix of all slices.
    #[inline]
    pub fn add(a: &[f32], b: &[f32], out: &mut [f32]) {
        for ((slot, &x), &y) in out.iter_mut().zip(a).zip(b) {
            *slot = x + y;
        }
    }
    /// Portable scalar multiply over the common prefix of `a` and `out`.
    #[inline]
    pub fn scale(a: &[f32], scalar: f32, out: &mut [f32]) {
        for (slot, &x) in out.iter_mut().zip(a) {
            *slot = x * scalar;
        }
    }
    /// Portable max element; `NEG_INFINITY` for an empty slice.
    #[inline]
    pub fn max(v: &[f32]) -> f32 {
        v.iter().fold(f32::NEG_INFINITY, |best, &x| best.max(x))
    }
    /// Portable numerically-stable softmax (in-place).
    pub fn softmax(v: &mut [f32]) {
        // Shift by the max so exp() cannot overflow.
        let peak = max(v);
        let mut total = 0.0f32;
        for value in v.iter_mut() {
            *value = (*value - peak).exp();
            total += *value;
        }
        // Skip normalization for a degenerate (all-zero) distribution.
        if total > 1e-8 {
            for value in v.iter_mut() {
                *value /= total;
            }
        }
    }
}
// ============================================
// Public API (auto-selects SIMD or scalar)
// ============================================
#[cfg(target_feature = "simd128")]
pub use simd_impl::*;
#[cfg(not(target_feature = "simd128"))]
pub use scalar_impl::*;
// ============================================
// iOS-Specific Optimizations
// ============================================
/// Prefetch hint for upcoming memory access (no-op in WASM, hint for future)
///
/// WebAssembly currently has no prefetch instruction; this exists so call
/// sites can be annotated today and pick up a real hint if the capability
/// lands later.
#[inline]
pub fn prefetch(_ptr: *const f32) {
    // Intentionally empty: nothing to emit on current WASM targets.
}
/// Aligned allocation hint for SIMD (16-byte alignment for v128)
///
/// Preferred byte alignment for buffers fed to the SIMD paths; a v128
/// value spans 128 bits = 16 bytes.
#[inline]
pub const fn simd_alignment() -> usize {
    16
}
/// Check if a slice is properly aligned for SIMD
///
/// True when `ptr` sits on a 16-byte boundary (the v128 width; this
/// mirrors `simd_alignment()`).
#[inline]
pub fn is_simd_aligned(ptr: *const f32) -> bool {
    // Power-of-two alignment check: the low 4 bits must be clear.
    (ptr as usize) & 0xF == 0
}
// ============================================
// Benchmarking Utilities
// ============================================
/// Benchmark a single dot product operation
///
/// FFI wrapper so the host can time `dot_product` on raw buffers; returns
/// the computed dot product so the work cannot be optimized away.
#[no_mangle]
pub extern "C" fn bench_dot_product(a_ptr: *const f32, b_ptr: *const f32, len: u32) -> f32 {
    // SAFETY: caller must supply pointers valid for reads of `len` f32s each.
    unsafe {
        let a = core::slice::from_raw_parts(a_ptr, len as usize);
        let b = core::slice::from_raw_parts(b_ptr, len as usize);
        dot_product(a, b)
    }
}
/// Benchmark L2 distance
///
/// FFI wrapper so the host can time `l2_distance` on raw buffers; returns
/// the computed distance so the work cannot be optimized away.
#[no_mangle]
pub extern "C" fn bench_l2_distance(a_ptr: *const f32, b_ptr: *const f32, len: u32) -> f32 {
    // SAFETY: caller must supply pointers valid for reads of `len` f32s each.
    unsafe {
        let a = core::slice::from_raw_parts(a_ptr, len as usize);
        let b = core::slice::from_raw_parts(b_ptr, len as usize);
        l2_distance(a, b)
    }
}
/// Get SIMD capability flag for runtime detection
///
/// Returns 1 when this module was compiled with `+simd128`, 0 otherwise
/// (the same compile-time answer as `simd_available()`, inlined here).
#[no_mangle]
pub extern "C" fn has_simd() -> i32 {
    i32::from(cfg!(target_feature = "simd128"))
}
#[cfg(test)]
mod tests {
    use super::*;
    // These tests go through the module-level re-export, so they exercise
    // whichever implementation (SIMD or scalar) the build selected.
    // 8 elements covers two full 4-lane SIMD chunks with no scalar tail.
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = dot_product(&a, &b);
        assert!((result - 36.0).abs() < 0.001);
    }
    // Classic 3-4-5 right triangle.
    #[test]
    fn test_l2_norm() {
        let v = vec![3.0, 4.0];
        let result = l2_norm(&v);
        assert!((result - 5.0).abs() < 0.001);
    }
    // Normalizing (3, 4, 0, 0) yields the (0.6, 0.8, 0, 0) unit vector.
    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0, 0.0, 0.0];
        normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 0.001);
        assert!((v[1] - 0.8).abs() < 0.001);
    }
    // Identical vectors have cosine similarity 1.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];
        let result = cosine_similarity(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }
}