// Files
// wifi-densepose/examples/ruvLLM/esp32/src/attention.rs
// ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
// git-subtree-dir: vendor/ruvector
// git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
// 2026-02-28 14:39:40 -05:00
//
// 328 lines
// 9.3 KiB
// Rust

//! Attention mechanisms for ESP32
//!
//! Implements simplified attention patterns optimized for microcontrollers.
// Quantized operations for attention
/// Simplified single-head attention for ESP32
///
/// This is a memory-efficient attention that processes one head at a time
/// to minimize activation memory. All arithmetic is integer-only so it runs
/// on cores without an FPU.
pub struct MicroAttention {
    /// Head dimension
    head_dim: usize,
    /// Number of heads
    num_heads: usize,
    /// Cached attention scaling factor (1/sqrt(head_dim) as fixed-point)
    scale_shift: u8,
}
impl MicroAttention {
    /// Create new attention module.
    ///
    /// `embed_dim` is the total embedding width; `num_heads` must be non-zero
    /// and should divide `embed_dim` evenly (any remainder is truncated).
    ///
    /// # Panics
    /// Panics if `num_heads` is zero.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        assert!(num_heads > 0, "num_heads must be non-zero");
        let head_dim = embed_dim / num_heads;
        // Approximate 1/sqrt(head_dim) as a right shift:
        //   sqrt(64) = 8 -> shift 3; sqrt(16) = 4 -> shift 2.
        // (The original match had two duplicate arms for >= 64 and >= 32,
        // both yielding 3; they are merged here with identical behavior.)
        let scale_shift = match head_dim {
            d if d >= 32 => 3,
            d if d >= 16 => 2,
            _ => 1,
        };
        Self {
            head_dim,
            num_heads,
            scale_shift,
        }
    }
    /// Compute attention scores between query and keys.
    ///
    /// `query` and every key must hold at least `head_dim` elements and
    /// `scores` must have one slot per key; shorter slices panic, as before.
    /// Returns scores in i32 format, scaled by ~1/sqrt(d_k) via `scale_shift`.
    #[inline]
    pub fn compute_scores(
        &self,
        query: &[i8],       // [head_dim]
        keys: &[&[i8]],     // [seq_len, head_dim]
        scores: &mut [i32], // [seq_len]
    ) {
        // Hoist the query subslice so the bounds check happens once.
        let q = &query[..self.head_dim];
        for (i, key) in keys.iter().enumerate() {
            // INT8 dot product accumulated in i32; for the head sizes used
            // here (<= 64) the sum cannot overflow i32.
            let dot: i32 = q
                .iter()
                .zip(&key[..self.head_dim])
                .map(|(&a, &b)| a as i32 * b as i32)
                .sum();
            // Scale by 1/sqrt(d_k)
            scores[i] = dot >> self.scale_shift;
        }
    }
    /// Apply causal mask (set future positions to minimum).
    ///
    /// Positions strictly after `current_pos` are masked out.
    #[inline]
    pub fn apply_causal_mask(&self, scores: &mut [i32], current_pos: usize) {
        for score in scores.iter_mut().skip(current_pos + 1) {
            *score = i32::MIN / 2; // MIN/2 rather than MIN: avoids overflow in softmax
        }
    }
    /// Fixed-point softmax optimized for ESP32.
    ///
    /// Uses integer arithmetic only, suitable for chips without FPU.
    /// Output is scaled by 256 (i.e., 256 = 1.0).
    #[inline]
    pub fn softmax_fixed(&self, scores: &mut [i32]) {
        if scores.is_empty() {
            return;
        }
        // Find maximum for numerical stability (subtracting it keeps the
        // exp argument non-positive).
        let max_score = scores.iter().copied().max().unwrap_or(0);
        // Compute exp approximation and sum.
        // exp(x) ≈ 1 + x + x²/2 for small x; we use the simpler linear form
        // exp(x) ≈ 256 + x/2, clamped so the result stays in [1, 256].
        let mut sum: i64 = 0;
        for score in scores.iter_mut() {
            let x = *score - max_score;
            // Clamp to prevent overflow
            let x_clamped = x.clamp(-512, 0);
            // Linear approximation of exp; floor of 1 keeps every weight alive
            *score = (256 + x_clamped / 2).max(1);
            sum += *score as i64;
        }
        // Normalize: output[i] = score[i] * 256 / sum
        if sum > 0 {
            for score in scores.iter_mut() {
                *score = ((*score as i64 * 256) / sum) as i32;
            }
        }
    }
    /// Compute weighted sum of values.
    ///
    /// output = sum(attention_weights[i] * values[i]), then descaled by 256
    /// to undo the fixed-point weight scaling.
    #[inline]
    pub fn weighted_sum(
        &self,
        weights: &[i32],    // [seq_len], scaled by 256
        values: &[&[i8]],   // [seq_len, head_dim]
        output: &mut [i32], // [head_dim]
    ) {
        // Clear output (the whole slice, matching the descale loop below)
        output.fill(0);
        // Accumulate weighted values
        for (&weight, value) in weights.iter().zip(values.iter()) {
            let v = &value[..self.head_dim];
            for (o, &x) in output[..self.head_dim].iter_mut().zip(v) {
                *o += weight * x as i32;
            }
        }
        // Descale (weights were scaled by 256)
        for o in output.iter_mut() {
            *o >>= 8;
        }
    }
}
/// Linear attention approximation for very long sequences
///
/// Uses kernel feature maps to achieve O(n) complexity instead of O(n²):
/// instead of softmax(QK^T)V it computes φ(Q)(φ(K)^T V) / (φ(Q)·Σφ(K)).
pub struct LinearAttention {
    /// Feature dimension for kernel
    /// NOTE(review): currently unused by `forward`, which derives the
    /// dimension from `query.len()` — kept for interface compatibility.
    feature_dim: usize,
}
impl LinearAttention {
    /// Create a new linear-attention module with the given feature dimension.
    pub fn new(feature_dim: usize) -> Self {
        Self { feature_dim }
    }
    /// ELU-based feature map: φ(x) = elu(x) + 1
    /// For INT8: approximate as max(x, 0) + 1 (always >= 1, so sums stay positive)
    #[inline]
    pub fn feature_map(&self, x: i8) -> i16 {
        (x.max(0) as i16) + 1
    }
    /// Compute linear attention.
    ///
    /// Dimensions are clamped to 64 to fit the fixed embedded buffers;
    /// `output` must hold at least `query.len().min(64)` slots.
    ///
    /// BUG FIX: the numerator previously weighted `kv_cache[j][i]` by
    /// φ(query[i]) for every `j`, instead of φ(query[j]); the contraction
    /// over the feature axis was therefore wrong for non-uniform queries.
    pub fn forward(
        &self,
        query: &[i8],       // [dim]
        keys: &[&[i8]],     // [seq_len, dim]
        values: &[&[i8]],   // [seq_len, dim]
        output: &mut [i32], // [dim]
    ) {
        let d = query.len().min(64);
        // Precompute φ(Q) once (previously recomputed per use).
        let mut phi_q = [0i32; 64];
        for i in 0..d {
            phi_q[i] = self.feature_map(query[i]) as i32;
        }
        // Accumulate φ(K)^T V and Σφ(K) in a single pass over the sequence.
        // This is O(n * dim²) but can be incrementally updated.
        // NOTE(review): 16 KiB of stack — confirm this fits the target's task stack.
        let mut kv_cache = [[0i32; 64]; 64]; // Fixed size for embedded
        let mut k_sum = [0i32; 64];
        for (key, value) in keys.iter().zip(values.iter()) {
            for i in 0..d {
                let phi_k = self.feature_map(key[i]) as i32;
                k_sum[i] += phi_k;
                for j in 0..d {
                    kv_cache[i][j] += phi_k * value[j] as i32;
                }
            }
        }
        // Numerator: output[i] = Σ_j φ(q)_j * kv_cache[j][i], descaled by 256
        for i in 0..d {
            let mut sum: i32 = 0;
            for j in 0..d {
                sum += phi_q[j] * kv_cache[j][i];
            }
            output[i] = sum >> 8;
        }
        // Denominator: φ(Q) · Σφ(K)
        let mut denom: i32 = 0;
        for i in 0..d {
            denom += phi_q[i] * k_sum[i];
        }
        // Normalize (denom > 0 whenever the sequence is non-empty, since φ >= 1)
        if denom > 0 {
            for o in output.iter_mut() {
                *o = (*o << 8) / denom;
            }
        }
    }
}
/// Sliding window attention for memory efficiency
///
/// Only attends to the last N tokens, reducing memory from O(n²) to O(n*window)
pub struct SlidingWindowAttention {
    /// Number of most-recent tokens attended to; must be <= 32 to fit the
    /// on-stack score buffer in `forward`.
    window_size: usize,
    /// Per-head dimension; must be <= 64 (the `[i8; 64]` ring-buffer row size).
    head_dim: usize,
}
impl SlidingWindowAttention {
    /// Create a new sliding-window attention module.
    ///
    /// Debug builds assert the buffer-size invariants that `forward` relies
    /// on; release builds keep the original unchecked behavior.
    pub fn new(window_size: usize, head_dim: usize) -> Self {
        debug_assert!(window_size <= 32, "window_size must be <= 32 (score buffer)");
        debug_assert!(head_dim <= 64, "head_dim must be <= 64 (ring-buffer rows)");
        Self { window_size, head_dim }
    }
    /// Compute attention with sliding window.
    ///
    /// `keys`/`values` are ring buffers of `window_size` slots indexed by
    /// absolute position modulo `window_size`; `cache_len` is the number of
    /// tokens seen so far. `output[..head_dim]` receives the result.
    pub fn forward(
        &self,
        query: &[i8],
        keys: &[[i8; 64]],   // Ring buffer of keys
        values: &[[i8; 64]], // Ring buffer of values
        cache_len: usize,
        output: &mut [i32],
    ) {
        let window_start = cache_len.saturating_sub(self.window_size);
        let window_len = cache_len - window_start;
        let mut scores = [0i32; 32]; // Max window size
        // Raw attention scores for each cached position in the window.
        for i in window_start..cache_len {
            let key = &keys[i % self.window_size];
            let mut dot: i32 = 0;
            for j in 0..self.head_dim {
                dot += query[j] as i32 * key[j] as i32;
            }
            // Fixed >>3 scale assumes head_dim ≈ 64 (1/sqrt(64) = 1/8).
            scores[i - window_start] = dot >> 3;
        }
        // Linear-approximation softmax over the window:
        // exp(x) ≈ 256 + x/2 for x <= 0, floored at 1 (256 = 1.0 fixed-point).
        let scores = &mut scores[..window_len];
        let max = scores.iter().copied().max().unwrap_or(0);
        let mut sum: i32 = 0;
        for s in scores.iter_mut() {
            *s = (256 + (*s - max) / 2).max(1);
            sum += *s;
        }
        // Weighted sum of values, normalized by `sum` (max(1) guards the
        // empty-window case where sum == 0).
        for o in output[..self.head_dim].iter_mut() {
            *o = 0;
        }
        for (i, &s) in scores.iter().enumerate() {
            let weight = (s * 256) / sum.max(1);
            let value = &values[(window_start + i) % self.window_size];
            for j in 0..self.head_dim {
                output[j] += weight * value[j] as i32;
            }
        }
        // Descale the Q8.8 weights.
        for o in output[..self.head_dim].iter_mut() {
            *o >>= 8;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_micro_attention() {
        let attn = MicroAttention::new(64, 4);
        let query = [10i8; 16];
        let matching_key = [10i8; 16];
        let weaker_key = [5i8; 16];
        let keys: [&[i8]; 2] = [&matching_key, &weaker_key];
        let mut scores = [0i32; 2];
        attn.compute_scores(&query, &keys, &mut scores);
        // The key identical to the query must score highest.
        assert!(scores[0] > scores[1]);
    }

    #[test]
    fn test_softmax_fixed() {
        let attn = MicroAttention::new(64, 4);
        let mut scores = [100i32, 50, 0, -50];
        attn.softmax_fixed(&mut scores);
        // Weights should sum to roughly 1.0 in fixed-point (~256).
        let total: i32 = scores.iter().sum();
        assert!((total - 256).abs() < 10);
        // Strictly decreasing inputs must stay strictly decreasing.
        assert!(scores.windows(2).all(|pair| pair[0] > pair[1]));
    }

    #[test]
    fn test_linear_attention() {
        let attn = LinearAttention::new(16);
        let query = [10i8; 16];
        let key = [10i8; 16];
        let value = [5i8; 16];
        let keys: [&[i8]; 1] = [&key];
        let values: [&[i8]; 1] = [&value];
        let mut output = [0i32; 16];
        attn.forward(&query, &keys, &values, &mut output);
        // A non-trivial input must produce at least one non-zero output.
        assert!(output.iter().any(|&x| x != 0));
    }
}