//! Attention Mechanism Module for iOS WASM
//!
//! Lightweight self-attention for content ranking and sequence modeling.
//! Optimized for minimal memory footprint on mobile devices.
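//!
//! # Example
//!
//! A minimal usage sketch (illustrative only: the dimensions, seed, and
//! embedding values below are placeholders, not values required by this
//! module).
//!
//! ```ignore
//! // Rank three 64-dimensional content embeddings against a user-context query.
//! let ranker = AttentionRanker::new(64, 4);
//! let query = vec![0.5_f32; 64];
//! let items = vec![vec![0.3_f32; 64], vec![0.7_f32; 64], vec![0.1_f32; 64]];
//! let ranked = ranker.rank(&query, &items); // (index, score) pairs, best first
//! ```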

/// Maximum sequence length for attention
const MAX_SEQ_LEN: usize = 64;

/// Single attention head
pub struct AttentionHead {
    /// Dimension of key/query/value
    dim: usize,
    /// Query projection weights
    w_query: Vec<f32>,
    /// Key projection weights
    w_key: Vec<f32>,
    /// Value projection weights
    w_value: Vec<f32>,
    /// Scaling factor (1/sqrt(dim))
    scale: f32,
}
impl AttentionHead {
    /// Create a new attention head with random initialization
    pub fn new(input_dim: usize, head_dim: usize, seed: u32) -> Self {
        let dim = head_dim;
        let weight_size = input_dim * dim;

        // Xavier initialization with deterministic pseudo-random
        let std_dev = (2.0 / (input_dim + dim) as f32).sqrt();

        let w_query = Self::init_weights(weight_size, seed, std_dev);
        let w_key = Self::init_weights(weight_size, seed.wrapping_add(1), std_dev);
        let w_value = Self::init_weights(weight_size, seed.wrapping_add(2), std_dev);

        Self {
            dim,
            w_query,
            w_key,
            w_value,
            scale: 1.0 / (dim as f32).sqrt(),
        }
    }

    /// Initialize weights with pseudo-random values
    fn init_weights(size: usize, seed: u32, std_dev: f32) -> Vec<f32> {
        let mut weights = Vec::with_capacity(size);
        let mut s = seed;

        for _ in 0..size {
            s = s.wrapping_mul(1103515245).wrapping_add(12345);
            let uniform = ((s >> 16) as f32 / 32768.0) - 1.0;
            weights.push(uniform * std_dev);
        }

        weights
    }

    /// Project input to query/key/value space
    #[inline]
    fn project(&self, input: &[f32], weights: &[f32]) -> Vec<f32> {
        let input_dim = self.w_query.len() / self.dim;
        let mut output = vec![0.0; self.dim];

        for (i, o) in output.iter_mut().enumerate() {
            for (j, &inp) in input.iter().take(input_dim).enumerate() {
                let idx = j * self.dim + i;
                if idx < weights.len() {
                    *o += inp * weights[idx];
                }
            }
        }

        output
    }

    /// Compute attention scores between query and key
    #[inline]
    fn attention_score(&self, query: &[f32], key: &[f32]) -> f32 {
        let dot: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
        dot * self.scale
    }

    /// Apply softmax to attention scores
    fn softmax(scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }

        // Numerical stability: subtract max
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let mut sum = 0.0;
        for s in scores.iter_mut() {
            *s = (*s - max_score).exp();
            sum += *s;
        }

        if sum > 1e-8 {
            for s in scores.iter_mut() {
                *s /= sum;
            }
        }
    }

    /// Compute self-attention over a sequence
    pub fn forward(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let seq_len = sequence.len().min(MAX_SEQ_LEN);
        if seq_len == 0 {
            return vec![];
        }

        // Project to Q, K, V
        let queries: Vec<Vec<f32>> = sequence.iter().take(seq_len)
            .map(|x| self.project(x, &self.w_query))
            .collect();
        let keys: Vec<Vec<f32>> = sequence.iter().take(seq_len)
            .map(|x| self.project(x, &self.w_key))
            .collect();
        let values: Vec<Vec<f32>> = sequence.iter().take(seq_len)
            .map(|x| self.project(x, &self.w_value))
            .collect();

        // Compute attention for each position
        let mut outputs = Vec::with_capacity(seq_len);

        for q in &queries {
            // Compute attention scores
            let mut scores: Vec<f32> = keys.iter()
                .map(|k| self.attention_score(q, k))
                .collect();

            Self::softmax(&mut scores);

            // Weighted sum of values
            let mut output = vec![0.0; self.dim];
            for (score, value) in scores.iter().zip(values.iter()) {
                for (o, v) in output.iter_mut().zip(value.iter()) {
                    *o += score * v;
                }
            }

            outputs.push(output);
        }

        outputs
    }

    /// Get output dimension
    pub fn dim(&self) -> usize {
        self.dim
    }
}

/// Multi-head attention layer
pub struct MultiHeadAttention {
    heads: Vec<AttentionHead>,
    /// Output projection weights
    w_out: Vec<f32>,
    output_dim: usize,
}

impl MultiHeadAttention {
    /// Create new multi-head attention
    pub fn new(input_dim: usize, num_heads: usize, head_dim: usize, seed: u32) -> Self {
        let heads: Vec<AttentionHead> = (0..num_heads)
            .map(|i| AttentionHead::new(input_dim, head_dim, seed.wrapping_add(i as u32 * 10)))
            .collect();

        let concat_dim = num_heads * head_dim;
        let output_dim = input_dim;
        let w_out = AttentionHead::init_weights(
            concat_dim * output_dim,
            seed.wrapping_add(1000),
            (2.0 / (concat_dim + output_dim) as f32).sqrt(),
        );

        Self {
            heads,
            w_out,
            output_dim,
        }
    }

    /// Forward pass through multi-head attention
    pub fn forward(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
        // Guard against an empty sequence and against a layer configured
        // with zero heads, which would otherwise panic below.
        if sequence.is_empty() || self.heads.is_empty() {
            return vec![];
        }

        // Get outputs from all heads
        let head_outputs: Vec<Vec<Vec<f32>>> = self.heads.iter()
            .map(|head| head.forward(sequence))
            .collect();

        // Concatenate and project
        let seq_len = head_outputs[0].len();
        let head_dim = self.heads[0].dim();
        let concat_dim = self.heads.len() * head_dim;

        let mut outputs = Vec::with_capacity(seq_len);

        for pos in 0..seq_len {
            // Concatenate heads
            let mut concat = Vec::with_capacity(concat_dim);
            for head_out in &head_outputs {
                concat.extend_from_slice(&head_out[pos]);
            }

            // Output projection
            let mut output = vec![0.0; self.output_dim];
            for (i, o) in output.iter_mut().enumerate() {
                for (j, &c) in concat.iter().enumerate() {
                    let idx = j * self.output_dim + i;
                    if idx < self.w_out.len() {
                        *o += c * self.w_out[idx];
                    }
                }
            }

            outputs.push(output);
        }

        outputs
    }

    /// Apply attention pooling to get single output
    pub fn pool(&self, sequence: &[Vec<f32>]) -> Vec<f32> {
        let attended = self.forward(sequence);

        if attended.is_empty() {
            return vec![0.0; self.output_dim];
        }

        // Mean pooling over sequence
        let mut pooled = vec![0.0; self.output_dim];
        for item in &attended {
            for (p, v) in pooled.iter_mut().zip(item.iter()) {
                *p += v;
            }
        }

        let n = attended.len() as f32;
        for p in &mut pooled {
            *p /= n;
        }

        pooled
    }
}

/// Context-aware content ranker using attention
pub struct AttentionRanker {
    attention: MultiHeadAttention,
    /// Query transformation weights
    w_query_transform: Vec<f32>,
    dim: usize,
}

impl AttentionRanker {
    /// Create new attention-based ranker
    pub fn new(dim: usize, num_heads: usize) -> Self {
        let head_dim = dim / num_heads.max(1);
        let attention = MultiHeadAttention::new(dim, num_heads, head_dim, 54321);

        let w_query_transform = AttentionHead::init_weights(
            dim * dim,
            99999,
            (2.0 / (dim * 2) as f32).sqrt(),
        );

        Self {
            attention,
            w_query_transform,
            dim,
        }
    }

    /// Rank content items based on user context
    ///
    /// Returns `(index, score)` pairs sorted by descending relevance score.
    pub fn rank(&self, query: &[f32], items: &[Vec<f32>]) -> Vec<(usize, f32)> {
        if items.is_empty() || query.len() != self.dim {
            return vec![];
        }

        // Transform query
        let mut transformed_query = vec![0.0; self.dim];
        for (i, tq) in transformed_query.iter_mut().enumerate() {
            for (j, &q) in query.iter().enumerate() {
                let idx = j * self.dim + i;
                if idx < self.w_query_transform.len() {
                    *tq += q * self.w_query_transform[idx];
                }
            }
        }

        // Create sequence with query prepended
        let mut sequence = vec![transformed_query];
        sequence.extend(items.iter().cloned());

        // Apply attention
        let attended = self.attention.forward(&sequence);
        if attended.is_empty() {
            return vec![];
        }

        // Score each item by similarity to attended query
        let query_attended = &attended[0];
        let mut scores: Vec<(usize, f32)> = attended[1..].iter()
            .enumerate()
            .map(|(i, item)| {
                let sim: f32 = query_attended.iter()
                    .zip(item.iter())
                    .map(|(q, v)| q * v)
                    .sum();
                (i, sim)
            })
            .collect();

        // Sort by score descending
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));

        scores
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_head() {
        let head = AttentionHead::new(64, 16, 12345);
        let sequence = vec![vec![0.5; 64]; 5];

        let output = head.forward(&sequence);
        assert_eq!(output.len(), 5);
        assert_eq!(output[0].len(), 16);
    }
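
    // Sketch tests added for illustration: they check two documented
    // behaviours (softmax normalization and MAX_SEQ_LEN truncation) using
    // placeholder dimensions and seeds, not values required by the module.
    #[test]
    fn test_softmax_normalizes() {
        let mut scores = vec![1.0, 2.0, 3.0, 4.0];
        AttentionHead::softmax(&mut scores);

        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(scores.iter().all(|&s| s >= 0.0));
    }

    #[test]
    fn test_forward_truncates_to_max_seq_len() {
        let head = AttentionHead::new(8, 4, 1);
        let sequence = vec![vec![0.1; 8]; MAX_SEQ_LEN + 8];

        let output = head.forward(&sequence);
        assert_eq!(output.len(), MAX_SEQ_LEN);
    }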

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(64, 4, 16, 12345);
        let sequence = vec![vec![0.5; 64]; 5];

        let output = mha.forward(&sequence);
        assert_eq!(output.len(), 5);
        assert_eq!(output[0].len(), 64);
    }
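
    // Illustrative sketch: pooling should reduce any sequence to a single
    // vector of the layer's output dimension (32 here is a placeholder),
    // and an empty sequence should fall back to a zero vector.
    #[test]
    fn test_attention_pooling() {
        let mha = MultiHeadAttention::new(32, 4, 8, 7);
        let sequence = vec![vec![0.25; 32]; 3];

        let pooled = mha.pool(&sequence);
        assert_eq!(pooled.len(), 32);

        let empty = mha.pool(&[]);
        assert_eq!(empty.len(), 32);
        assert!(empty.iter().all(|&v| v == 0.0));
    }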

    #[test]
    fn test_attention_ranker() {
        let ranker = AttentionRanker::new(64, 4);
        let query = vec![0.5; 64];
        let items = vec![vec![0.3; 64], vec![0.7; 64], vec![0.1; 64]];

        let ranked = ranker.rank(&query, &items);
        assert_eq!(ranked.len(), 3);
    }
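
    // Illustrative sketch: rank() returns its (index, score) pairs already
    // sorted best-first, and a query of the wrong dimension yields an empty
    // ranking. The dimensions below are placeholders.
    #[test]
    fn test_ranker_scores_sorted_descending() {
        let ranker = AttentionRanker::new(16, 2);
        let query = vec![0.5; 16];
        let items = vec![vec![0.2; 16], vec![0.9; 16], vec![0.4; 16], vec![0.6; 16]];

        let ranked = ranker.rank(&query, &items);
        assert_eq!(ranked.len(), 4);
        assert!(ranked.windows(2).all(|pair| pair[0].1 >= pair[1].1));

        let wrong_dim_query = vec![0.5; 8];
        assert!(ranker.rank(&wrong_dim_query, &items).is_empty());
    }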
}