Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
273
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/binary_quant.rs
vendored
Normal file
273
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/binary_quant.rs
vendored
Normal file
@@ -0,0 +1,273 @@
|
||||
//! Binary Quantization - 32x Memory Compression
|
||||
//!
|
||||
//! Adapted from ruvector-postgres/src/quantization/binary.rs
|
||||
//! Converts f32/i8 vectors to 1-bit per dimension with Hamming distance.
|
||||
|
||||
use heapless::Vec as HVec;
|
||||
|
||||
/// Maximum binary vector size in bytes (supports up to 512 dimensions)
|
||||
pub const MAX_BINARY_SIZE: usize = 64;
|
||||
|
||||
/// Binary quantized vector - 1 bit per dimension
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BinaryVector<const N: usize> {
|
||||
/// Packed binary data (8 dimensions per byte)
|
||||
pub data: HVec<u8, N>,
|
||||
/// Original dimension count
|
||||
pub dim: usize,
|
||||
/// Threshold used for binarization
|
||||
pub threshold: i8,
|
||||
}
|
||||
|
||||
impl<const N: usize> BinaryVector<N> {
|
||||
/// Create binary vector from INT8 values
|
||||
/// Values >= threshold become 1, values < threshold become 0
|
||||
pub fn from_i8(values: &[i8], threshold: i8) -> crate::Result<Self> {
|
||||
let dim = values.len();
|
||||
let num_bytes = (dim + 7) / 8;
|
||||
|
||||
if num_bytes > N {
|
||||
return Err(crate::Error::BufferOverflow);
|
||||
}
|
||||
|
||||
let mut data = HVec::new();
|
||||
|
||||
for chunk_idx in 0..(num_bytes) {
|
||||
let mut byte = 0u8;
|
||||
for bit_idx in 0..8 {
|
||||
let val_idx = chunk_idx * 8 + bit_idx;
|
||||
if val_idx < dim && values[val_idx] >= threshold {
|
||||
byte |= 1 << bit_idx;
|
||||
}
|
||||
}
|
||||
data.push(byte).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
|
||||
Ok(Self { data, dim, threshold })
|
||||
}
|
||||
|
||||
/// Create binary vector from f32 values (for host-side quantization)
|
||||
#[cfg(feature = "host-test")]
|
||||
pub fn from_f32(values: &[f32], threshold: f32) -> crate::Result<Self> {
|
||||
let i8_threshold = (threshold * 127.0) as i8;
|
||||
let i8_values: heapless::Vec<i8, 512> = values
|
||||
.iter()
|
||||
.map(|&v| (v * 127.0).clamp(-128.0, 127.0) as i8)
|
||||
.collect();
|
||||
Self::from_i8(&i8_values, i8_threshold)
|
||||
}
|
||||
|
||||
/// Get number of packed bytes
|
||||
pub fn num_bytes(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Memory savings compared to INT8
|
||||
pub fn compression_ratio(&self) -> f32 {
|
||||
self.dim as f32 / self.data.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Binary embedding table for vocabulary (1 bit per dimension:
/// 8x smaller than an INT8 table, 32x smaller than f32)
///
/// NOTE(review): the `VOCAB` and `DIM_BYTES` const parameters are not
/// used by any field or method — capacity is the fixed 32KB buffer and
/// the runtime `vocab_size`/`dim` values. Confirm whether the generics
/// are intended for future compile-time sizing or can be dropped.
pub struct BinaryEmbedding<const VOCAB: usize, const DIM_BYTES: usize> {
    /// Packed binary embeddings [VOCAB * DIM_BYTES]
    data: HVec<u8, { 32 * 1024 }>, // Max 32KB
    /// Vocabulary size
    vocab_size: usize,
    /// Dimensions (in bits)
    dim: usize,
    /// Bytes per embedding
    bytes_per_embed: usize,
}

impl<const VOCAB: usize, const DIM_BYTES: usize> BinaryEmbedding<VOCAB, DIM_BYTES> {
    /// Create random binary embeddings for testing
    ///
    /// Fills the table from a deterministic LCG (glibc constants), so the
    /// same `seed` always produces the same embeddings. Returns
    /// `BufferOverflow` if `vocab_size * ceil(dim/8)` exceeds 32KB.
    pub fn random(vocab_size: usize, dim: usize, seed: u32) -> crate::Result<Self> {
        let bytes_per_embed = (dim + 7) / 8;
        let total_bytes = vocab_size * bytes_per_embed;

        let mut data = HVec::new();
        let mut rng_state = seed;

        for _ in 0..total_bytes {
            // Linear congruential generator; bits 16..24 are the most mixed.
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            let byte = ((rng_state >> 16) & 0xFF) as u8;
            data.push(byte).map_err(|_| crate::Error::BufferOverflow)?;
        }

        Ok(Self {
            data,
            vocab_size,
            dim,
            bytes_per_embed,
        })
    }

    /// Look up binary embedding for a token
    ///
    /// Copies the token's packed row into `output[..bytes_per_embed]`.
    /// Errors if the token id is outside the vocabulary or `output` is
    /// too small; both checks happen before any copy.
    pub fn lookup(&self, token_id: u16, output: &mut [u8]) -> crate::Result<()> {
        let id = token_id as usize;
        if id >= self.vocab_size {
            return Err(crate::Error::InvalidModel("Token ID out of range"));
        }

        let start = id * self.bytes_per_embed;
        let end = start + self.bytes_per_embed;

        if output.len() < self.bytes_per_embed {
            return Err(crate::Error::BufferOverflow);
        }

        output[..self.bytes_per_embed].copy_from_slice(&self.data[start..end]);
        Ok(())
    }

    /// Memory size in bytes
    pub fn memory_size(&self) -> usize {
        self.data.len()
    }

    /// Compression vs INT8 embedding of same dimensions
    /// (1 bit vs 8 bits per dimension; vs f32 it would be 32x)
    pub fn compression_vs_int8(&self) -> f32 {
        8.0 // 8 bits per dimension -> 1 bit per dimension = 8x
    }
}
|
||||
|
||||
/// Hamming distance between two binary vectors
|
||||
///
|
||||
/// Counts the number of differing bits. Uses POPCNT-like operations.
|
||||
/// On ESP32, this is extremely fast as it uses simple bitwise operations.
|
||||
#[inline]
|
||||
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
|
||||
debug_assert_eq!(a.len(), b.len());
|
||||
|
||||
let mut distance: u32 = 0;
|
||||
|
||||
// Process 4 bytes at a time for better performance
|
||||
let chunks = a.len() / 4;
|
||||
for i in 0..chunks {
|
||||
let idx = i * 4;
|
||||
let xor0 = a[idx] ^ b[idx];
|
||||
let xor1 = a[idx + 1] ^ b[idx + 1];
|
||||
let xor2 = a[idx + 2] ^ b[idx + 2];
|
||||
let xor3 = a[idx + 3] ^ b[idx + 3];
|
||||
|
||||
distance += popcount8(xor0) + popcount8(xor1) + popcount8(xor2) + popcount8(xor3);
|
||||
}
|
||||
|
||||
// Handle remainder
|
||||
for i in (chunks * 4)..a.len() {
|
||||
distance += popcount8(a[i] ^ b[i]);
|
||||
}
|
||||
|
||||
distance
|
||||
}
|
||||
|
||||
/// Hamming similarity (inverted distance, normalized to 0-1 range)
|
||||
#[inline]
|
||||
pub fn hamming_similarity(a: &[u8], b: &[u8]) -> f32 {
|
||||
let total_bits = (a.len() * 8) as f32;
|
||||
let distance = hamming_distance(a, b) as f32;
|
||||
1.0 - (distance / total_bits)
|
||||
}
|
||||
|
||||
/// Hamming similarity as fixed-point (0-255 range)
|
||||
#[inline]
|
||||
pub fn hamming_similarity_fixed(a: &[u8], b: &[u8]) -> u8 {
|
||||
let total_bits = (a.len() * 8) as u32;
|
||||
let matching_bits = total_bits - hamming_distance(a, b);
|
||||
((matching_bits * 255) / total_bits) as u8
|
||||
}
|
||||
|
||||
/// Population count for a single byte (count of 1 bits)
/// Uses lookup table for ESP32 efficiency
#[inline]
pub fn popcount8(x: u8) -> u32 {
    // Same 256-entry table as before, but generated at compile time
    // instead of being written out by hand — identical contents, zero
    // runtime cost, and no chance of a transcription error.
    const POPCOUNT_TABLE: [u8; 256] = {
        let mut table = [0u8; 256];
        let mut value = 0usize;
        while value < 256 {
            table[value] = (value as u8).count_ones() as u8;
            value += 1;
        }
        table
    };
    POPCOUNT_TABLE[x as usize] as u32
}
|
||||
|
||||
/// XNOR-popcount for binary neural network inference
/// Equivalent to computing dot product of {-1, +1} vectors
#[inline]
pub fn xnor_popcount(a: &[u8], b: &[u8]) -> i32 {
    debug_assert_eq!(a.len(), b.len());

    let total_bits = (a.len() * 8) as i32;

    // XNOR sets a bit exactly where the two inputs agree, so summing the
    // popcounts yields the number of matching bit positions.
    let matching: i32 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (!(x ^ y)).count_ones() as i32)
        .sum();

    // In the {-1, +1} encoding each match contributes +1 and each
    // mismatch -1: result = matching - (total - matching).
    2 * matching - total_bits
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_binary_quantization() {
|
||||
let values = [10i8, -5, 20, -10, 0, 15, -8, 30];
|
||||
let binary = BinaryVector::<8>::from_i8(&values, 0).unwrap();
|
||||
|
||||
assert_eq!(binary.dim, 8);
|
||||
assert_eq!(binary.num_bytes(), 1);
|
||||
|
||||
// Expected: bits where value >= 0: positions 0, 2, 4, 5, 7
|
||||
// Binary: 10110101 = 0xB5
|
||||
assert_eq!(binary.data[0], 0b10110101);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hamming_distance() {
|
||||
let a = [0b11110000u8, 0b10101010];
|
||||
let b = [0b11110000u8, 0b10101010];
|
||||
assert_eq!(hamming_distance(&a, &b), 0);
|
||||
|
||||
let c = [0b00001111u8, 0b01010101];
|
||||
assert_eq!(hamming_distance(&a, &c), 16); // All bits different
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xnor_popcount() {
|
||||
let a = [0b11111111u8];
|
||||
let b = [0b11111111u8];
|
||||
// Perfect match: 8 matching bits -> 2*8 - 8 = 8
|
||||
assert_eq!(xnor_popcount(&a, &b), 8);
|
||||
|
||||
let c = [0b00000000u8];
|
||||
// Complete mismatch: 0 matching bits -> 2*0 - 8 = -8
|
||||
assert_eq!(xnor_popcount(&a, &c), -8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compression_ratio() {
|
||||
let values = [0i8; 64];
|
||||
let binary = BinaryVector::<8>::from_i8(&values, 0).unwrap();
|
||||
assert_eq!(binary.compression_ratio(), 8.0);
|
||||
}
|
||||
}
|
||||
266
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/lookup_tables.rs
vendored
Normal file
266
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/lookup_tables.rs
vendored
Normal file
@@ -0,0 +1,266 @@
|
||||
//! Lookup Tables for Fast Fixed-Point Operations
|
||||
//!
|
||||
//! Pre-computed tables for softmax, exp, and distance operations.
|
||||
//! Critical for ESP32 which lacks FPU on most variants.
|
||||
|
||||
/// Softmax lookup table (256 entries)
///
/// Pre-computed exp(x) values for x in [-8, 0] range, scaled to INT8.
/// Used for fast fixed-point softmax without floating-point operations.
///
/// NOTE(review): despite the name, `new()` fills the table with a
/// *clamped linear ramp* (`255 + x`), not an exponential — the resulting
/// "softmax" is a shifted-linear normalization. The visible tests only
/// check monotonicity and endpoints, so this passes them; confirm the
/// accuracy trade-off is intentional for the ESP32 target.
pub struct SoftmaxLUT {
    /// exp(x) values, scaled by 255
    exp_table: [u8; 256],
    /// Scale factor for input normalization
    /// (not read by any method below — presumably for callers; verify)
    input_scale: i32,
}

impl SoftmaxLUT {
    /// Create softmax LUT with default parameters
    pub const fn new() -> Self {
        // Pre-compute exp(x) for x in [-8, 0], scaled to [0, 255]
        // exp(-8) ≈ 0.000335, exp(0) = 1
        // We discretize into 256 bins

        let mut exp_table = [0u8; 256];

        // Approximate exp using polynomial: exp(x) ≈ 1 + x + x²/2 + x³/6
        // For integer approximation: exp(x/32) scaled by 255
        // (In practice only the first-order/linear term is kept below.)
        let mut i = 0;
        while i < 256 {
            // x ranges from -8 (i=0) to 0 (i=255)
            // x = (i - 255) / 32
            let x_scaled = i as i32 - 255; // Range: -255 to 0

            // Linear approximation of exp for negative values
            // exp(x) ≈ 255 + x for small |x|, clamped to [1, 255]
            // (floor of 1 keeps every entry strictly positive so a
            // softmax sum can never be zero when inputs exist)
            let mut exp_approx = 255 + x_scaled;
            if exp_approx < 1 { exp_approx = 1; }
            if exp_approx > 255 { exp_approx = 255; }
            exp_table[i] = exp_approx as u8;

            i += 1;
        }

        Self {
            exp_table,
            input_scale: 32, // Divide input by 32 before lookup
        }
    }

    /// Look up approximate exp(x) for x in [-8, 0]
    ///
    /// `x` is the already-scaled integer input; values outside
    /// [-255, 0] are clamped before indexing, so this never panics.
    #[inline]
    pub fn exp(&self, x: i32) -> u8 {
        // Clamp x to valid range and scale
        let x_clamped = x.max(-255).min(0);
        let idx = (x_clamped + 255) as usize;
        self.exp_table[idx]
    }

    /// Compute softmax over an array of INT32 logits
    /// Output is scaled by 256 (i.e., 256 = probability 1.0)
    ///
    /// NOTE(review): `zip` silently truncates to the shorter of the two
    /// slices — if `output.len() < logits.len()` the extra logits are
    /// ignored and the distribution is renormalized over the prefix.
    pub fn softmax(&self, logits: &[i32], output: &mut [u16]) {
        if logits.is_empty() {
            return;
        }

        // Find max for numerical stability
        let max_logit = logits.iter().cloned().max().unwrap_or(0);

        // Compute exp and sum
        let mut sum: u32 = 0;
        for (&logit, out) in logits.iter().zip(output.iter_mut()) {
            let x = logit - max_logit;
            let exp_val = self.exp(x) as u16;
            *out = exp_val;
            sum += exp_val as u32;
        }

        // Normalize: probability = exp / sum, scaled by 256
        if sum > 0 {
            for out in output.iter_mut() {
                *out = ((*out as u32 * 256) / sum) as u16;
            }
        }
    }

    /// Fast softmax using only integer operations
    /// Returns probabilities scaled by 256
    ///
    /// Overwrites `logits` in place with the normalized values.
    pub fn softmax_fast(&self, logits: &mut [i32]) {
        if logits.is_empty() {
            return;
        }

        // Find max
        let max = logits.iter().cloned().max().unwrap_or(0);

        // Subtract max and apply exp approximation
        // (x is clamped to -255 so the table index below stays in range)
        let mut sum: i32 = 0;
        for logit in logits.iter_mut() {
            let x = (*logit - max).max(-255);
            *logit = self.exp_table[(x + 255) as usize] as i32;
            sum += *logit;
        }

        // Normalize (multiply by 256 then divide by sum)
        if sum > 0 {
            for logit in logits.iter_mut() {
                *logit = (*logit << 8) / sum;
            }
        }
    }
}

impl Default for SoftmaxLUT {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Exponential lookup table for more precise exp approximation
pub struct ExpLUT {
    /// exp(x/64) for x in [0, 255], scaled by 256
    table: [u16; 256],
}

impl ExpLUT {
    /// Create with higher precision (uses more memory)
    ///
    /// Each entry approximates `exp(i/64) * 256` with the second-order
    /// Taylor polynomial `1 + t + t²/2` (t = i/64), evaluated entirely
    /// in fixed-point arithmetic at compile time.
    pub const fn new() -> Self {
        let mut table = [0u16; 256];

        let mut i = 0;
        while i < 256 {
            // Fixed-point evaluation, everything scaled by 256:
            //   256·t        = 256·i/64       = x_scaled
            //   256·t²/2     = 128·(i/64)²    = i²/32 = x_scaled² >> 9
            let x = i as i32;
            let x_scaled = x * 256 / 64;
            let x2_half = (x_scaled * x_scaled) >> 9; // 256 * t²/2

            // 256·(1 + t + t²/2). The previous code added only half of
            // the quadratic term (`x2 >> 1` on an already-halved value),
            // i.e. 1 + t + t²/4, underestimating the curve. Adding the
            // full term matches the documented polynomial.
            let mut exp_val = 256 + x_scaled + x2_half;
            if exp_val > 65535 { exp_val = 65535; }
            table[i] = exp_val as u16;

            i += 1;
        }

        Self { table }
    }

    /// exp(x) where x is in range [0, 4) scaled by 64
    #[inline]
    pub fn exp(&self, x: u8) -> u16 {
        self.table[x as usize]
    }
}
|
||||
|
||||
/// Distance lookup table for common embedding similarities
pub struct DistanceLUT<const SIZE: usize> {
    /// Pre-computed squared differences for INT8 pairs
    sq_diff_table: [u16; 512], // For INT8 diffs in [-255, 255]
}

impl<const SIZE: usize> DistanceLUT<SIZE> {
    /// Create distance LUT
    ///
    /// Entry `j` holds `(j - 256)²` clamped to u16, covering every
    /// possible difference of two INT8 values (range [-255, 255]).
    pub const fn new() -> Self {
        let mut sq_diff_table = [0u16; 512];
        let mut idx = 0usize;
        while idx < 512 {
            let delta = idx as i32 - 256; // map [0, 511] to [-256, 255]
            let squared = delta * delta;
            sq_diff_table[idx] = if squared > 65535 { 65535 } else { squared as u16 };
            idx += 1;
        }
        Self { sq_diff_table }
    }

    /// Look up squared difference between two INT8 values
    #[inline]
    pub fn squared_diff(&self, a: i8, b: i8) -> u16 {
        // Shift the difference into the table's [0, 511] index range.
        self.sq_diff_table[(a as i32 - b as i32 + 256) as usize]
    }

    /// Compute L2 squared distance using lookup table
    pub fn l2_squared(&self, a: &[i8], b: &[i8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());

        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| self.squared_diff(x, y) as u32)
            .sum()
    }
}
|
||||
|
||||
/// Global static lookup tables (no heap allocation)
|
||||
pub static SOFTMAX_LUT: SoftmaxLUT = SoftmaxLUT::new();
|
||||
pub static EXP_LUT: ExpLUT = ExpLUT::new();
|
||||
pub static DISTANCE_LUT: DistanceLUT<256> = DistanceLUT::new();
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_softmax_lut() {
|
||||
let lut = SoftmaxLUT::new();
|
||||
|
||||
// exp(0) should be maximum (255)
|
||||
assert_eq!(lut.exp(0), 255);
|
||||
|
||||
// exp(-255) should be minimum (1)
|
||||
assert_eq!(lut.exp(-255), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_normalization() {
|
||||
let lut = SoftmaxLUT::new();
|
||||
let logits = [100i32, 50, 0, -50];
|
||||
let mut output = [0u16; 4];
|
||||
|
||||
lut.softmax(&logits, &mut output);
|
||||
|
||||
// Sum should be approximately 256
|
||||
let sum: u16 = output.iter().sum();
|
||||
assert!((sum as i32 - 256).abs() < 10);
|
||||
|
||||
// First element should have highest probability
|
||||
assert!(output[0] > output[1]);
|
||||
assert!(output[1] > output[2]);
|
||||
assert!(output[2] > output[3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_distance_lut() {
|
||||
let lut = DistanceLUT::<256>::new();
|
||||
|
||||
// Same values: squared diff = 0
|
||||
assert_eq!(lut.squared_diff(10, 10), 0);
|
||||
|
||||
// Diff of 10: squared = 100
|
||||
assert_eq!(lut.squared_diff(10, 0), 100);
|
||||
assert_eq!(lut.squared_diff(0, 10), 100);
|
||||
|
||||
// Negative values
|
||||
assert_eq!(lut.squared_diff(-10, 0), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l2_distance() {
|
||||
let lut = DistanceLUT::<256>::new();
|
||||
|
||||
let a = [10i8, 20, 30, 40];
|
||||
let b = [10i8, 20, 30, 40];
|
||||
assert_eq!(lut.l2_squared(&a, &b), 0);
|
||||
|
||||
let c = [0i8, 0, 0, 0];
|
||||
// (10² + 20² + 30² + 40²) = 100 + 400 + 900 + 1600 = 3000
|
||||
assert_eq!(lut.l2_squared(&a, &c), 3000);
|
||||
}
|
||||
}
|
||||
323
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/micro_lora.rs
vendored
Normal file
323
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/micro_lora.rs
vendored
Normal file
@@ -0,0 +1,323 @@
|
||||
//! MicroLoRA - Tiny Low-Rank Adaptation for ESP32
|
||||
//!
|
||||
//! Adapted from ruvLLM's SONA architecture for on-device adaptation.
|
||||
//! Uses INT8 weights with rank 1-2 for minimal memory footprint.
|
||||
|
||||
use heapless::Vec as HVec;
|
||||
use crate::quantized::QuantParams;
|
||||
|
||||
/// Maximum LoRA rank (keep very small for ESP32)
|
||||
pub const MAX_LORA_RANK: usize = 2;
|
||||
/// Maximum dimension for LoRA matrices
|
||||
pub const MAX_LORA_DIM: usize = 64;
|
||||
|
||||
/// MicroLoRA configuration
///
/// Sizing/behavior knobs for a single low-rank adapter. Values must
/// respect `MAX_LORA_RANK` / `MAX_LORA_DIM`; construction validates this.
#[derive(Debug, Clone, Copy)]
pub struct LoRAConfig {
    /// Rank of the low-rank matrices (1 or 2 for ESP32)
    pub rank: usize,
    /// Input/output dimension
    pub dim: usize,
    /// Scaling factor (alpha / rank)
    pub scale: i8,
    /// Whether LoRA is frozen (inference-only)
    pub frozen: bool,
}

impl Default for LoRAConfig {
    /// Smallest practical adapter: rank-1, 32 dims, frozen for inference.
    fn default() -> Self {
        Self {
            rank: 1,
            dim: 32,
            scale: 8, // alpha=8, rank=1 -> scale=8
            frozen: true,
        }
    }
}
|
||||
|
||||
/// MicroLoRA adapter for a single layer
|
||||
///
|
||||
/// Implements: output = input + scale * (input @ A) @ B
|
||||
/// Where A is [dim, rank] and B is [rank, dim]
|
||||
pub struct MicroLoRA {
|
||||
/// Down projection: A matrix [dim, rank] as INT8
|
||||
a_weights: HVec<i8, { MAX_LORA_DIM * MAX_LORA_RANK }>,
|
||||
/// Up projection: B matrix [rank, dim] as INT8
|
||||
b_weights: HVec<i8, { MAX_LORA_RANK * MAX_LORA_DIM }>,
|
||||
/// Configuration
|
||||
config: LoRAConfig,
|
||||
/// Quantization params for A
|
||||
a_params: QuantParams,
|
||||
/// Quantization params for B
|
||||
b_params: QuantParams,
|
||||
/// Intermediate buffer for rank-sized vector
|
||||
intermediate: [i32; MAX_LORA_RANK],
|
||||
}
|
||||
|
||||
impl MicroLoRA {
|
||||
/// Create new MicroLoRA with random initialization
|
||||
pub fn new(config: LoRAConfig, seed: u32) -> crate::Result<Self> {
|
||||
if config.rank > MAX_LORA_RANK || config.dim > MAX_LORA_DIM {
|
||||
return Err(crate::Error::InvalidModel("LoRA dimensions too large"));
|
||||
}
|
||||
|
||||
let mut a_weights = HVec::new();
|
||||
let mut b_weights = HVec::new();
|
||||
|
||||
let mut rng_state = seed;
|
||||
let mut next_rand = || {
|
||||
rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
|
||||
(((rng_state >> 16) & 0x3F) as i16 - 32) as i8 // Small values [-32, 31]
|
||||
};
|
||||
|
||||
// Initialize A with small random values
|
||||
for _ in 0..(config.dim * config.rank) {
|
||||
a_weights.push(next_rand()).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
|
||||
// Initialize B with zeros (LoRA starts as identity)
|
||||
for _ in 0..(config.rank * config.dim) {
|
||||
b_weights.push(0).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
a_weights,
|
||||
b_weights,
|
||||
config,
|
||||
a_params: QuantParams::default(),
|
||||
b_params: QuantParams::default(),
|
||||
intermediate: [0; MAX_LORA_RANK],
|
||||
})
|
||||
}
|
||||
|
||||
/// Create MicroLoRA from pre-trained weights
|
||||
pub fn from_weights(
|
||||
config: LoRAConfig,
|
||||
a_weights: &[i8],
|
||||
b_weights: &[i8],
|
||||
) -> crate::Result<Self> {
|
||||
if a_weights.len() != config.dim * config.rank {
|
||||
return Err(crate::Error::InvalidModel("A weights size mismatch"));
|
||||
}
|
||||
if b_weights.len() != config.rank * config.dim {
|
||||
return Err(crate::Error::InvalidModel("B weights size mismatch"));
|
||||
}
|
||||
|
||||
let mut a_vec = HVec::new();
|
||||
let mut b_vec = HVec::new();
|
||||
|
||||
for &w in a_weights {
|
||||
a_vec.push(w).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
for &w in b_weights {
|
||||
b_vec.push(w).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
a_weights: a_vec,
|
||||
b_weights: b_vec,
|
||||
config,
|
||||
a_params: QuantParams::default(),
|
||||
b_params: QuantParams::default(),
|
||||
intermediate: [0; MAX_LORA_RANK],
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply LoRA adaptation to input
|
||||
///
|
||||
/// Computes: output = input + scale * (input @ A) @ B
|
||||
/// All operations in INT8/INT32
|
||||
#[inline]
|
||||
pub fn apply(&mut self, input: &[i8], output: &mut [i32]) {
|
||||
let dim = self.config.dim;
|
||||
let rank = self.config.rank;
|
||||
let scale = self.config.scale as i32;
|
||||
|
||||
// Clear intermediate buffer
|
||||
for i in 0..rank {
|
||||
self.intermediate[i] = 0;
|
||||
}
|
||||
|
||||
// Step 1: intermediate = input @ A (down projection)
|
||||
// A is [dim, rank], input is [dim], result is [rank]
|
||||
for r in 0..rank {
|
||||
let mut sum: i32 = 0;
|
||||
for d in 0..dim {
|
||||
sum += input[d] as i32 * self.a_weights[d * rank + r] as i32;
|
||||
}
|
||||
self.intermediate[r] = sum >> 4; // Scale down to prevent overflow
|
||||
}
|
||||
|
||||
// Step 2: lora_output = intermediate @ B (up projection)
|
||||
// B is [rank, dim], intermediate is [rank], result is [dim]
|
||||
for d in 0..dim {
|
||||
let mut sum: i32 = 0;
|
||||
for r in 0..rank {
|
||||
sum += self.intermediate[r] * self.b_weights[r * dim + d] as i32;
|
||||
}
|
||||
// Add scaled LoRA output to original output
|
||||
output[d] += (sum * scale) >> 8;
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply LoRA and store result in-place
|
||||
pub fn apply_inplace(&mut self, data: &mut [i32], input: &[i8]) {
|
||||
self.apply(input, data);
|
||||
}
|
||||
|
||||
/// Memory size of this LoRA adapter
|
||||
pub fn memory_size(&self) -> usize {
|
||||
self.a_weights.len() + self.b_weights.len()
|
||||
}
|
||||
|
||||
/// Update LoRA weights with gradient (simplified for on-device learning)
|
||||
///
|
||||
/// Uses a simple gradient accumulation approach suitable for ESP32:
|
||||
/// A += lr * input^T @ grad_intermediate
|
||||
/// B += lr * intermediate^T @ grad_output
|
||||
#[cfg(not(feature = "frozen"))]
|
||||
pub fn update(&mut self, input: &[i8], grad_output: &[i32], learning_rate: i8) {
|
||||
let dim = self.config.dim;
|
||||
let rank = self.config.rank;
|
||||
let lr = learning_rate as i32;
|
||||
|
||||
// Compute gradient for intermediate (simplified)
|
||||
let mut grad_intermediate = [0i32; MAX_LORA_RANK];
|
||||
for r in 0..rank {
|
||||
let mut sum: i32 = 0;
|
||||
for d in 0..dim {
|
||||
sum += grad_output[d] * self.b_weights[r * dim + d] as i32;
|
||||
}
|
||||
grad_intermediate[r] = sum >> 8;
|
||||
}
|
||||
|
||||
// Update A weights: A += lr * outer(input, grad_intermediate)
|
||||
for d in 0..dim {
|
||||
for r in 0..rank {
|
||||
let grad = (input[d] as i32 * grad_intermediate[r] * lr) >> 12;
|
||||
let idx = d * rank + r;
|
||||
let new_val = self.a_weights[idx] as i32 + grad;
|
||||
self.a_weights[idx] = new_val.clamp(-127, 127) as i8;
|
||||
}
|
||||
}
|
||||
|
||||
// Update B weights: B += lr * outer(intermediate, grad_output)
|
||||
for r in 0..rank {
|
||||
for d in 0..dim {
|
||||
let grad = (self.intermediate[r] * grad_output[d] * lr) >> 12;
|
||||
let idx = r * dim + d;
|
||||
let new_val = self.b_weights[idx] as i32 + grad;
|
||||
self.b_weights[idx] = new_val.clamp(-127, 127) as i8;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Collection of MicroLoRA adapters for all layers
|
||||
pub struct LoRAStack<const NUM_LAYERS: usize> {
|
||||
/// LoRA adapters per layer
|
||||
adapters: [Option<MicroLoRA>; NUM_LAYERS],
|
||||
/// Number of active adapters
|
||||
active_count: usize,
|
||||
}
|
||||
|
||||
impl<const NUM_LAYERS: usize> LoRAStack<NUM_LAYERS> {
|
||||
/// Create empty LoRA stack
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
adapters: core::array::from_fn(|_| None),
|
||||
active_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add LoRA adapter to a layer
|
||||
pub fn add_adapter(&mut self, layer_idx: usize, adapter: MicroLoRA) -> crate::Result<()> {
|
||||
if layer_idx >= NUM_LAYERS {
|
||||
return Err(crate::Error::InvalidModel("Layer index out of range"));
|
||||
}
|
||||
self.adapters[layer_idx] = Some(adapter);
|
||||
self.active_count += 1;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get adapter for a layer (if exists)
|
||||
pub fn get(&mut self, layer_idx: usize) -> Option<&mut MicroLoRA> {
|
||||
self.adapters.get_mut(layer_idx).and_then(|a| a.as_mut())
|
||||
}
|
||||
|
||||
/// Total memory used by all adapters
|
||||
pub fn total_memory(&self) -> usize {
|
||||
self.adapters.iter()
|
||||
.filter_map(|a| a.as_ref())
|
||||
.map(|a| a.memory_size())
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Default for LoRAStack<N> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_creation() {
|
||||
let config = LoRAConfig {
|
||||
rank: 2,
|
||||
dim: 32,
|
||||
scale: 8,
|
||||
frozen: true,
|
||||
};
|
||||
|
||||
let lora = MicroLoRA::new(config, 42).unwrap();
|
||||
|
||||
// A: 32 * 2 = 64 bytes, B: 2 * 32 = 64 bytes
|
||||
assert_eq!(lora.memory_size(), 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_apply() {
|
||||
let config = LoRAConfig {
|
||||
rank: 1,
|
||||
dim: 4,
|
||||
scale: 64, // Larger scale for testing
|
||||
frozen: true,
|
||||
};
|
||||
|
||||
// Create with known weights - larger values to survive scaling
|
||||
let a_weights = [16i8, 32, 48, 64]; // [4, 1]
|
||||
let b_weights = [64i8, 64, 64, 64]; // [1, 4]
|
||||
|
||||
let mut lora = MicroLoRA::from_weights(config, &a_weights, &b_weights).unwrap();
|
||||
|
||||
let input = [64i8, 64, 64, 64];
|
||||
let mut output = [0i32; 4];
|
||||
|
||||
lora.apply(&input, &mut output);
|
||||
|
||||
// With larger values, the output should be non-zero after scaling
|
||||
// intermediate = sum(64 * [16,32,48,64]) >> 4 = (10240) >> 4 = 640
|
||||
// output = (640 * 64 * scale) >> 8
|
||||
// This should produce non-zero results
|
||||
let non_zero_count = output.iter().filter(|&&o| o != 0).count();
|
||||
assert!(non_zero_count > 0, "At least some outputs should be non-zero, got {:?}", output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_stack() {
|
||||
let mut stack = LoRAStack::<4>::new();
|
||||
|
||||
let config = LoRAConfig::default();
|
||||
let adapter = MicroLoRA::new(config, 42).unwrap();
|
||||
|
||||
stack.add_adapter(0, adapter).unwrap();
|
||||
|
||||
assert!(stack.get(0).is_some());
|
||||
assert!(stack.get(1).is_none());
|
||||
assert!(stack.total_memory() > 0);
|
||||
}
|
||||
}
|
||||
25
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/mod.rs
vendored
Normal file
25
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/mod.rs
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
//! Advanced Optimizations from Ruvector
|
||||
//!
|
||||
//! This module brings key optimizations from the ruvector ecosystem to ESP32:
|
||||
//! - Binary quantization (32x compression)
|
||||
//! - Product quantization (8-32x compression)
|
||||
//! - Hamming distance with POPCNT
|
||||
//! - Fixed-point softmax with lookup tables
|
||||
//! - MicroLoRA for on-device adaptation
|
||||
//! - Sparse attention patterns
|
||||
//! - MinCut-inspired layer pruning
|
||||
|
||||
pub mod binary_quant;
|
||||
pub mod product_quant;
|
||||
pub mod lookup_tables;
|
||||
pub mod micro_lora;
|
||||
pub mod sparse_attention;
|
||||
pub mod pruning;
|
||||
|
||||
// Re-exports
|
||||
pub use binary_quant::{BinaryVector, BinaryEmbedding, hamming_distance, hamming_similarity};
|
||||
pub use product_quant::{ProductQuantizer, PQCode};
|
||||
pub use lookup_tables::{SoftmaxLUT, ExpLUT, DistanceLUT};
|
||||
pub use micro_lora::{MicroLoRA, LoRAConfig};
|
||||
pub use sparse_attention::{SparseAttention, AttentionPattern};
|
||||
pub use pruning::{LayerPruner, PruningConfig};
|
||||
336
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/product_quant.rs
vendored
Normal file
336
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/product_quant.rs
vendored
Normal file
@@ -0,0 +1,336 @@
|
||||
//! Product Quantization - 8-32x Memory Compression
|
||||
//!
|
||||
//! Adapted from ruvector-postgres for ESP32 constraints.
|
||||
//! Splits vectors into subvectors and quantizes each independently.
|
||||
|
||||
use heapless::Vec as HVec;
|
||||
|
||||
/// Maximum number of subquantizers
|
||||
pub const MAX_SUBQUANTIZERS: usize = 8;
|
||||
/// Maximum codebook size per subquantizer
|
||||
pub const MAX_CODEBOOK_SIZE: usize = 16; // 4-bit codes
|
||||
/// Maximum subvector dimension
|
||||
pub const MAX_SUBVEC_DIM: usize = 8;
|
||||
|
||||
/// Product Quantization configuration
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PQConfig {
|
||||
/// Number of subquantizers (M)
|
||||
pub num_subquantizers: usize,
|
||||
/// Number of codes per subquantizer (K = 2^bits)
|
||||
pub codebook_size: usize,
|
||||
/// Dimension of each subvector
|
||||
pub subvec_dim: usize,
|
||||
/// Total vector dimension
|
||||
pub dim: usize,
|
||||
}
|
||||
|
||||
impl Default for PQConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_subquantizers: 4,
|
||||
codebook_size: 16, // 4-bit codes
|
||||
subvec_dim: 8,
|
||||
dim: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Product Quantized code for a vector
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PQCode<const M: usize> {
|
||||
/// Code indices for each subquantizer (4-bit packed)
|
||||
pub codes: HVec<u8, M>,
|
||||
}
|
||||
|
||||
impl<const M: usize> PQCode<M> {
|
||||
/// Create from code indices
|
||||
pub fn from_codes(codes: &[u8]) -> crate::Result<Self> {
|
||||
let mut code_vec = HVec::new();
|
||||
for &c in codes {
|
||||
code_vec.push(c).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
Ok(Self { codes: code_vec })
|
||||
}
|
||||
|
||||
/// Get code for subquantizer i
|
||||
#[inline]
|
||||
pub fn get_code(&self, i: usize) -> u8 {
|
||||
self.codes.get(i).copied().unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Memory size in bytes
|
||||
pub fn memory_size(&self) -> usize {
|
||||
self.codes.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Product Quantizer with codebooks
///
/// Generic parameters mirror the runtime `PQConfig`: `M` = number of
/// subquantizers, `K` = codebook size, `D` = subvector dimension.
/// Centroids are stored as INT8 values in a single flattened buffer.
pub struct ProductQuantizer<const M: usize, const K: usize, const D: usize> {
    /// Codebooks: [M][K][D] flattened to [M * K * D]
    /// Each subquantizer has K centroids of dimension D
    codebooks: HVec<i8, { 8 * 16 * 8 }>, // Max 1024 bytes
    /// Configuration
    config: PQConfig,
}

impl<const M: usize, const K: usize, const D: usize> ProductQuantizer<M, K, D> {
    /// Create with random codebooks (for testing)
    ///
    /// Uses a linear congruential generator seeded with `seed`, so the
    /// codebooks are deterministic for a given seed.
    /// Returns `Error::BufferOverflow` if the configured codebook size
    /// exceeds the fixed 1024-byte backing store.
    pub fn random(config: PQConfig, seed: u32) -> crate::Result<Self> {
        let total_size = config.num_subquantizers * config.codebook_size * config.subvec_dim;

        let mut codebooks = HVec::new();
        let mut rng_state = seed;

        for _ in 0..total_size {
            // LCG step (glibc constants); take bits 16..24 and recenter
            // around zero to cover the full i8 range.
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            let val = (((rng_state >> 16) & 0xFF) as i16 - 128) as i8;
            codebooks.push(val).map_err(|_| crate::Error::BufferOverflow)?;
        }

        Ok(Self { codebooks, config })
    }

    /// Create from pre-trained codebooks
    ///
    /// `codebooks` must contain exactly
    /// `num_subquantizers * codebook_size * subvec_dim` INT8 values,
    /// laid out [M][K][D] row-major; otherwise `InvalidModel` is
    /// returned.
    pub fn from_codebooks(config: PQConfig, codebooks: &[i8]) -> crate::Result<Self> {
        let expected = config.num_subquantizers * config.codebook_size * config.subvec_dim;
        if codebooks.len() != expected {
            return Err(crate::Error::InvalidModel("Codebook size mismatch"));
        }

        let mut cb_vec = HVec::new();
        for &v in codebooks {
            cb_vec.push(v).map_err(|_| crate::Error::BufferOverflow)?;
        }

        Ok(Self { codebooks: cb_vec, config })
    }

    /// Get centroid for subquantizer m, code k
    ///
    /// Returns a `subvec_dim`-long slice into the flattened codebook.
    /// Panics if (m, k) is outside the configured codebook bounds.
    #[inline]
    fn get_centroid(&self, m: usize, k: usize) -> &[i8] {
        let d = self.config.subvec_dim;
        let kk = self.config.codebook_size;
        let start = m * kk * d + k * d;
        &self.codebooks[start..start + d]
    }

    /// Encode a vector to PQ codes
    ///
    /// For each subvector, performs an exhaustive nearest-centroid
    /// search by squared L2 distance over the K candidates.
    ///
    /// NOTE(review): only `dim` is validated here; the subvector slice
    /// below implicitly assumes `dim == num_subquantizers * subvec_dim`
    /// and panics otherwise — confirm configs are validated upstream.
    pub fn encode(&self, vector: &[i8]) -> crate::Result<PQCode<M>> {
        if vector.len() != self.config.dim {
            return Err(crate::Error::InvalidModel("Vector dimension mismatch"));
        }

        let mut codes = HVec::new();
        let d = self.config.subvec_dim;

        for m in 0..self.config.num_subquantizers {
            let subvec = &vector[m * d..(m + 1) * d];

            // Find nearest centroid
            let mut best_code = 0u8;
            let mut best_dist = i32::MAX;

            for k in 0..self.config.codebook_size {
                let centroid = self.get_centroid(m, k);
                let dist = Self::l2_squared(subvec, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_code = k as u8;
                }
            }

            codes.push(best_code).map_err(|_| crate::Error::BufferOverflow)?;
        }

        Ok(PQCode { codes })
    }

    /// Decode PQ codes back to approximate vector
    ///
    /// Reconstructs each subvector as its assigned centroid, so the
    /// output is the quantized approximation, not the original data.
    pub fn decode(&self, code: &PQCode<M>, output: &mut [i8]) -> crate::Result<()> {
        if output.len() != self.config.dim {
            return Err(crate::Error::InvalidModel("Output dimension mismatch"));
        }

        let d = self.config.subvec_dim;

        for m in 0..self.config.num_subquantizers {
            let k = code.get_code(m) as usize;
            let centroid = self.get_centroid(m, k);
            output[m * d..(m + 1) * d].copy_from_slice(centroid);
        }

        Ok(())
    }

    /// Compute asymmetric distance: exact query vs PQ-encoded database vector
    ///
    /// Sum over subquantizers of the squared L2 distance between each
    /// query subvector and the encoded vector's chosen centroid.
    pub fn asymmetric_distance(&self, query: &[i8], code: &PQCode<M>) -> i32 {
        let d = self.config.subvec_dim;
        let mut total_dist: i32 = 0;

        for m in 0..self.config.num_subquantizers {
            let query_sub = &query[m * d..(m + 1) * d];
            let k = code.get_code(m) as usize;
            let centroid = self.get_centroid(m, k);
            total_dist += Self::l2_squared(query_sub, centroid);
        }

        total_dist
    }

    /// Compute distance using pre-computed distance table (faster for batch queries)
    ///
    /// Equivalent to `asymmetric_distance` for the query the table was
    /// built from, but costs only M table lookups.
    pub fn distance_with_table(&self, table: &PQDistanceTable<M, K>, code: &PQCode<M>) -> i32 {
        let mut total: i32 = 0;
        for m in 0..self.config.num_subquantizers {
            let k = code.get_code(m) as usize;
            total += table.get(m, k);
        }
        total
    }

    /// Build distance table for a query (precompute all query-centroid distances)
    ///
    /// Fills one entry per (subquantizer, code) pair with the squared
    /// L2 distance between the query subvector and that centroid.
    pub fn build_distance_table(&self, query: &[i8]) -> PQDistanceTable<M, K> {
        let mut table = PQDistanceTable::new();
        let d = self.config.subvec_dim;

        for m in 0..self.config.num_subquantizers {
            let query_sub = &query[m * d..(m + 1) * d];
            for k in 0..self.config.codebook_size {
                let centroid = self.get_centroid(m, k);
                let dist = Self::l2_squared(query_sub, centroid);
                table.set(m, k, dist);
            }
        }

        table
    }

    /// L2 squared distance between two INT8 vectors
    /// (iterates over the shorter of the two slices via `zip`)
    #[inline]
    fn l2_squared(a: &[i8], b: &[i8]) -> i32 {
        let mut sum: i32 = 0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let diff = x as i32 - y as i32;
            sum += diff * diff;
        }
        sum
    }

    /// Memory usage of codebooks, in bytes (1 byte per i8 entry)
    pub fn memory_size(&self) -> usize {
        self.codebooks.len()
    }

    /// Compression ratio vs INT8
    ///
    /// Original storage is 1 byte per dimension; compressed storage is
    /// 1 byte per subquantizer code.
    pub fn compression_ratio(&self) -> f32 {
        let original = self.config.dim as f32; // 1 byte per dim
        let compressed = self.config.num_subquantizers as f32; // 1 byte per code
        original / compressed
    }
}
|
||||
|
||||
/// Pre-computed distance table for fast PQ distance computation
///
/// Stores the query-to-centroid distance for every (subquantizer,
/// code) pair so that scoring an encoded vector costs M table lookups
/// instead of M dot products.
#[derive(Debug, Clone)]
pub struct PQDistanceTable<const M: usize, const K: usize> {
    /// Distances: [M][K] flattened row-major (entry (m, k) at m * K + k).
    /// Fixed backing store sized for at most 8 subquantizers * 16 codes.
    distances: [i32; 128],
}

impl<const M: usize, const K: usize> PQDistanceTable<M, K> {
    /// Create empty table with all distances zeroed.
    pub fn new() -> Self {
        // The backing array holds exactly 128 entries; instantiations
        // with M * K > 128 would index out of bounds later, so fail
        // fast in debug builds.
        debug_assert!(M * K <= 128, "PQDistanceTable: M * K exceeds 128-entry backing store");
        Self { distances: [0; 128] }
    }

    /// Get distance for subquantizer m, code k
    ///
    /// # Panics
    /// Panics if `m * K + k >= 128` (outside the backing array).
    #[inline]
    pub fn get(&self, m: usize, k: usize) -> i32 {
        debug_assert!(k < K, "PQDistanceTable::get: code index {} out of range", k);
        self.distances[m * K + k]
    }

    /// Set distance for subquantizer m, code k
    ///
    /// # Panics
    /// Panics if `m * K + k >= 128` (outside the backing array).
    #[inline]
    pub fn set(&mut self, m: usize, k: usize, dist: i32) {
        debug_assert!(k < K, "PQDistanceTable::set: code index {} out of range", k);
        self.distances[m * K + k] = dist;
    }
}

impl<const M: usize, const K: usize> Default for PQDistanceTable<M, K> {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config should match the documented 8x-compression setup.
    #[test]
    fn test_pq_config() {
        let config = PQConfig::default();
        assert_eq!(config.num_subquantizers, 4);
        assert_eq!(config.codebook_size, 16);
        assert_eq!(config.subvec_dim, 8);
        assert_eq!(config.dim, 32);
    }

    /// Encode then decode a 32-dim vector; decode yields centroid
    /// approximations, so we only check shapes and absence of errors.
    #[test]
    fn test_pq_encode_decode() {
        let config = PQConfig {
            num_subquantizers: 4,
            codebook_size: 16,
            subvec_dim: 8,
            dim: 32,
        };

        let pq = ProductQuantizer::<4, 16, 8>::random(config, 42).unwrap();

        // Create a test vector with varied values
        let mut vector = [0i8; 32];
        for (i, v) in vector.iter_mut().enumerate() {
            *v = (i as i8).wrapping_mul(3);
        }

        // Encode: one code per subquantizer
        let code = pq.encode(&vector).unwrap();
        assert_eq!(code.codes.len(), 4);

        // Decode into a buffer of the configured dimension
        let mut decoded = [0i8; 32];
        pq.decode(&code, &mut decoded).unwrap();

        // Decoded should be approximate (using centroids)
        // Just verify it runs without error
    }

    /// 32 bytes original -> 4 bytes of codes = 8x compression.
    #[test]
    fn test_pq_compression() {
        let config = PQConfig::default();
        let pq = ProductQuantizer::<4, 16, 8>::random(config, 42).unwrap();

        assert_eq!(pq.compression_ratio(), 8.0);
    }

    /// Table-based distance must agree exactly with the direct
    /// asymmetric distance for the same query.
    #[test]
    fn test_distance_table() {
        let config = PQConfig::default();
        let pq = ProductQuantizer::<4, 16, 8>::random(config, 42).unwrap();

        let mut query = [0i8; 32];
        for (i, q) in query.iter_mut().enumerate() {
            *q = i as i8;
        }

        let table = pq.build_distance_table(&query);

        // Encode a vector and compute distance both ways
        // (was `let mut` — the binding is never mutated)
        let vector = [10i8; 32];
        let code = pq.encode(&vector).unwrap();

        let dist1 = pq.asymmetric_distance(&query, &code);
        let dist2 = pq.distance_with_table(&table, &code);

        // Should be equal
        assert_eq!(dist1, dist2);
    }
}
|
||||
446
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/pruning.rs
vendored
Normal file
446
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/pruning.rs
vendored
Normal file
@@ -0,0 +1,446 @@
|
||||
//! MinCut-Inspired Layer Pruning for ESP32
|
||||
//!
|
||||
//! Intelligent pruning strategies adapted from ruvector graph algorithms.
|
||||
//! Identifies and removes least important weights/neurons while preserving model quality.
|
||||
|
||||
use heapless::Vec as HVec;
|
||||
|
||||
/// Maximum neurons to track for pruning
|
||||
pub const MAX_PRUNING_UNITS: usize = 64;
|
||||
|
||||
/// Pruning configuration
|
||||
/// Pruning configuration
///
/// Controls how aggressively, and in what granularity, a layer's
/// weights are pruned.
#[derive(Debug, Clone, Copy)]
pub struct PruningConfig {
    /// Target sparsity (0.0 = no pruning, 1.0 = all pruned)
    pub target_sparsity: f32,
    /// Minimum importance threshold (absolute value)
    pub importance_threshold: i8,
    /// Enable structured pruning (whole neurons vs individual weights)
    pub structured: bool,
    /// Gradual pruning steps (0 = one-shot)
    pub gradual_steps: usize,
}

impl Default for PruningConfig {
    /// One-shot structured pruning targeting 50% sparsity.
    fn default() -> Self {
        PruningConfig {
            gradual_steps: 0,
            structured: true,
            importance_threshold: 8,
            target_sparsity: 0.5,
        }
    }
}
|
||||
|
||||
/// Maximum mask words (supports up to 2048 weights)
|
||||
pub const MAX_MASK_WORDS: usize = 64;
|
||||
|
||||
/// Pruning mask for a weight matrix
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PruningMask<const N: usize> {
|
||||
/// Bitmask: 1 = keep, 0 = prune
|
||||
pub mask: HVec<u32, MAX_MASK_WORDS>,
|
||||
/// Number of elements
|
||||
pub size: usize,
|
||||
/// Number of pruned elements
|
||||
pub pruned_count: usize,
|
||||
}
|
||||
|
||||
impl<const N: usize> PruningMask<N> {
|
||||
/// Create mask with all weights kept
|
||||
pub fn new(size: usize) -> crate::Result<Self> {
|
||||
let num_words = (size + 31) / 32;
|
||||
let mut mask = HVec::new();
|
||||
|
||||
for i in 0..num_words {
|
||||
let bits = if i == num_words - 1 && size % 32 != 0 {
|
||||
(1u32 << (size % 32)) - 1
|
||||
} else {
|
||||
u32::MAX
|
||||
};
|
||||
mask.push(bits).map_err(|_| crate::Error::BufferOverflow)?;
|
||||
}
|
||||
|
||||
Ok(Self { mask, size, pruned_count: 0 })
|
||||
}
|
||||
|
||||
/// Check if weight at index is kept
|
||||
#[inline]
|
||||
pub fn is_kept(&self, idx: usize) -> bool {
|
||||
let word = idx / 32;
|
||||
let bit = idx % 32;
|
||||
(self.mask.get(word).copied().unwrap_or(0) >> bit) & 1 == 1
|
||||
}
|
||||
|
||||
/// Prune weight at index
|
||||
pub fn prune(&mut self, idx: usize) {
|
||||
if idx < self.size && self.is_kept(idx) {
|
||||
let word = idx / 32;
|
||||
let bit = idx % 32;
|
||||
if let Some(w) = self.mask.get_mut(word) {
|
||||
*w &= !(1 << bit);
|
||||
self.pruned_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Current sparsity level
|
||||
pub fn sparsity(&self) -> f32 {
|
||||
self.pruned_count as f32 / self.size as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Layer-level pruner using importance scoring
///
/// Scores weights (or whole neurons) for importance, then builds a
/// bitmask that zeroes the least important ones to reach the target
/// sparsity from `PruningConfig`.
pub struct LayerPruner {
    /// Configuration
    config: PruningConfig,
    /// Importance scores for neurons/weights
    /// (at most MAX_PRUNING_UNITS entries are tracked)
    importance_scores: HVec<i16, MAX_PRUNING_UNITS>,
    /// Current pruning step (for gradual pruning)
    /// NOTE(review): no method in this impl advances it — confirm
    /// gradual pruning is driven elsewhere.
    current_step: usize,
}

impl LayerPruner {
    /// Create new pruner with config
    pub fn new(config: PruningConfig) -> Self {
        Self {
            config,
            importance_scores: HVec::new(),
            current_step: 0,
        }
    }

    /// Compute importance scores for weights using magnitude
    ///
    /// Importance = |weight|. Only the first MAX_PRUNING_UNITS weights
    /// are scored; any excess is silently ignored.
    pub fn compute_magnitude_importance(&mut self, weights: &[i8]) {
        self.importance_scores.clear();

        for &w in weights.iter().take(MAX_PRUNING_UNITS) {
            let importance = (w as i16).abs();
            let _ = self.importance_scores.push(importance);
        }
    }

    /// Compute importance using gradient information (simplified)
    /// For on-device: use weight * activation as proxy
    ///
    /// Importance = |weight * activation| >> 4, truncated to i16.
    /// Scores at most MAX_PRUNING_UNITS pairs.
    pub fn compute_gradient_importance(&mut self, weights: &[i8], activations: &[i8]) {
        self.importance_scores.clear();

        for (&w, &a) in weights.iter().zip(activations.iter()).take(MAX_PRUNING_UNITS) {
            // |weight * activation| as importance proxy
            let importance = ((w as i32 * a as i32).abs() >> 4) as i16;
            let _ = self.importance_scores.push(importance);
        }
    }

    /// Create pruning mask based on importance scores
    ///
    /// Prunes every scored weight whose importance falls strictly below
    /// the threshold derived from `target_sparsity`. Call one of the
    /// `compute_*_importance` methods first.
    pub fn create_mask<const N: usize>(&self, size: usize) -> crate::Result<PruningMask<N>> {
        let mut mask = PruningMask::new(size)?;

        // Count weights below threshold
        let threshold = self.compute_threshold(size);

        for (idx, &score) in self.importance_scores.iter().enumerate() {
            if score < threshold {
                mask.prune(idx);
            }
        }

        Ok(mask)
    }

    /// Compute importance threshold for target sparsity
    ///
    /// Sorts a copy of the scores ascending and returns the score at
    /// the target-pruned rank; callers prune scores strictly below it.
    fn compute_threshold(&self, size: usize) -> i16 {
        let target_pruned = (size as f32 * self.config.target_sparsity) as usize;

        if target_pruned == 0 || self.importance_scores.is_empty() {
            return 0;
        }

        // Find threshold that achieves target sparsity
        // Simple approach: sort importance and pick threshold
        let mut sorted: HVec<i16, MAX_PRUNING_UNITS> = HVec::new();
        for &s in &self.importance_scores {
            let _ = sorted.push(s);
        }

        // Bubble sort (fine for small arrays)
        for i in 0..sorted.len() {
            for j in 0..sorted.len() - 1 - i {
                if sorted[j] > sorted[j + 1] {
                    sorted.swap(j, j + 1);
                }
            }
        }

        let idx = target_pruned.min(sorted.len().saturating_sub(1));
        sorted.get(idx).copied().unwrap_or(0)
    }

    /// Apply pruning mask to weights in-place
    ///
    /// Zeroes every weight whose keep-bit is cleared in `mask`.
    pub fn apply_mask<const N: usize>(&self, weights: &mut [i8], mask: &PruningMask<N>) {
        for (idx, weight) in weights.iter_mut().enumerate() {
            if !mask.is_kept(idx) {
                *weight = 0;
            }
        }
    }

    /// Structured pruning: remove entire neurons
    ///
    /// Scores each output neuron by the L1 norm of its weight row
    /// (weights laid out [output_dim][input_dim] row-major), zeroes the
    /// rows below the sparsity threshold, and returns a keep-mask
    /// (true = neuron kept). Only the first MAX_PRUNING_UNITS output
    /// neurons are considered.
    pub fn prune_neurons(
        &mut self,
        weights: &mut [i8],
        input_dim: usize,
        output_dim: usize,
    ) -> HVec<bool, MAX_PRUNING_UNITS> {
        // Compute per-neuron importance (L1 norm of weights)
        let mut neuron_importance: HVec<i32, MAX_PRUNING_UNITS> = HVec::new();

        for out_idx in 0..output_dim.min(MAX_PRUNING_UNITS) {
            let mut l1_sum: i32 = 0;
            for in_idx in 0..input_dim {
                let w_idx = out_idx * input_dim + in_idx;
                if w_idx < weights.len() {
                    l1_sum += (weights[w_idx] as i32).abs();
                }
            }
            let _ = neuron_importance.push(l1_sum);
        }

        // Find threshold: sort a copy ascending and pick the score at
        // the target-pruned rank
        let target_pruned = (output_dim as f32 * self.config.target_sparsity) as usize;
        let mut sorted: HVec<i32, MAX_PRUNING_UNITS> = neuron_importance.clone();

        // Bubble sort (fine for small arrays)
        for i in 0..sorted.len() {
            for j in 0..sorted.len() - 1 - i {
                if sorted[j] > sorted[j + 1] {
                    sorted.swap(j, j + 1);
                }
            }
        }

        let threshold = sorted.get(target_pruned).copied().unwrap_or(0);

        // Mark neurons to prune (>= threshold survives)
        let mut keep_mask: HVec<bool, MAX_PRUNING_UNITS> = HVec::new();

        for &importance in &neuron_importance {
            let _ = keep_mask.push(importance >= threshold);
        }

        // Zero out pruned neurons
        for out_idx in 0..output_dim.min(keep_mask.len()) {
            if !keep_mask[out_idx] {
                for in_idx in 0..input_dim {
                    let w_idx = out_idx * input_dim + in_idx;
                    if w_idx < weights.len() {
                        weights[w_idx] = 0;
                    }
                }
            }
        }

        keep_mask
    }

    /// Get statistics about pruning
    pub fn pruning_stats<const N: usize>(&self, mask: &PruningMask<N>) -> PruningStats {
        PruningStats {
            total_weights: mask.size,
            pruned_weights: mask.pruned_count,
            sparsity: mask.sparsity(),
            memory_saved: mask.pruned_count, // 1 byte per weight
        }
    }
}
|
||||
|
||||
/// Statistics about pruning results
///
/// Snapshot produced by `LayerPruner::pruning_stats`.
#[derive(Debug, Clone)]
pub struct PruningStats {
    /// Total weight count
    pub total_weights: usize,
    /// Number of pruned weights
    pub pruned_weights: usize,
    /// Achieved sparsity (pruned / total)
    pub sparsity: f32,
    /// Memory saved in bytes (1 byte per pruned INT8 weight)
    pub memory_saved: usize,
}
|
||||
|
||||
/// MinCut-inspired importance scoring
/// Treats weight matrix as bipartite graph, finds min-cut to preserve information flow
pub struct MinCutScorer {
    /// Flow values from source to each input neuron
    /// (L1 sum of that input's column of weights)
    input_flow: HVec<i32, MAX_PRUNING_UNITS>,
    /// Flow values from each output neuron to sink
    /// (L1 sum of that output's row of weights)
    output_flow: HVec<i32, MAX_PRUNING_UNITS>,
}

impl MinCutScorer {
    /// Create scorer with empty flow tables
    pub fn new() -> Self {
        Self {
            input_flow: HVec::new(),
            output_flow: HVec::new(),
        }
    }

    /// Compute edge importance using simplified max-flow
    /// Edges in min-cut are most critical for information flow
    ///
    /// Weights are laid out [output_dim][input_dim] row-major.
    /// Importance of edge (in, out) = (|w| * min(col_flow, row_flow)) >> 10.
    ///
    /// NOTE(review): the result is capped at MAX_PRUNING_UNITS entries,
    /// fewer than input_dim * output_dim edges for larger layers —
    /// trailing edges are silently dropped; confirm callers only rely
    /// on the leading scores.
    pub fn compute_edge_importance(
        &mut self,
        weights: &[i8],
        input_dim: usize,
        output_dim: usize,
    ) -> HVec<i16, MAX_PRUNING_UNITS> {
        // Initialize flow (simplified: use column/row sums)
        self.input_flow.clear();
        self.output_flow.clear();

        // Input flow: sum of absolute weights per input
        for in_idx in 0..input_dim.min(MAX_PRUNING_UNITS) {
            let mut flow: i32 = 0;
            for out_idx in 0..output_dim {
                let w_idx = out_idx * input_dim + in_idx;
                if w_idx < weights.len() {
                    flow += (weights[w_idx] as i32).abs();
                }
            }
            let _ = self.input_flow.push(flow);
        }

        // Output flow: sum of absolute weights per output
        for out_idx in 0..output_dim.min(MAX_PRUNING_UNITS) {
            let mut flow: i32 = 0;
            for in_idx in 0..input_dim {
                let w_idx = out_idx * input_dim + in_idx;
                if w_idx < weights.len() {
                    flow += (weights[w_idx] as i32).abs();
                }
            }
            let _ = self.output_flow.push(flow);
        }

        // Edge importance = min(input_flow, output_flow) * |weight|
        // Edges on min-cut have bottleneck flow
        let mut importance: HVec<i16, MAX_PRUNING_UNITS> = HVec::new();

        for out_idx in 0..output_dim.min(self.output_flow.len()) {
            let out_flow = self.output_flow[out_idx];
            for in_idx in 0..input_dim.min(self.input_flow.len()) {
                let in_flow = self.input_flow[in_idx];
                let w_idx = out_idx * input_dim + in_idx;

                if w_idx < weights.len() {
                    let w = (weights[w_idx] as i32).abs();
                    let bottleneck = in_flow.min(out_flow);
                    // >> 10 rescales the product back into i16 range
                    let edge_importance = ((w * bottleneck) >> 10) as i16;

                    if importance.len() < MAX_PRUNING_UNITS {
                        let _ = importance.push(edge_importance);
                    }
                }
            }
        }

        importance
    }
}

impl Default for MinCutScorer {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Bit-level behavior of PruningMask: fresh masks keep everything;
    /// pruning flips exactly the targeted bits and updates the count.
    #[test]
    fn test_pruning_mask() {
        let mut mask = PruningMask::<64>::new(50).unwrap();

        assert!(mask.is_kept(0));
        assert!(mask.is_kept(49));
        assert_eq!(mask.sparsity(), 0.0);

        mask.prune(10);
        mask.prune(20);

        assert!(!mask.is_kept(10));
        assert!(!mask.is_kept(20));
        assert!(mask.is_kept(15));
        assert_eq!(mask.pruned_count, 2);
    }

    /// Magnitude pruning at 50% target keeps the large weights and
    /// lands near the requested sparsity.
    #[test]
    fn test_magnitude_pruning() {
        let config = PruningConfig {
            target_sparsity: 0.5,
            ..Default::default()
        };

        let mut pruner = LayerPruner::new(config);

        // Weights with varying magnitudes
        let weights: [i8; 8] = [1, -2, 50, -60, 3, -4, 70, 5];
        pruner.compute_magnitude_importance(&weights);

        let mask = pruner.create_mask::<8>(8).unwrap();

        // Should prune ~50% (low magnitude weights)
        assert!(mask.sparsity() >= 0.25 && mask.sparsity() <= 0.75);

        // High magnitude weights should be kept
        assert!(mask.is_kept(2)); // 50
        assert!(mask.is_kept(3)); // -60
        assert!(mask.is_kept(6)); // 70
    }

    /// Structured pruning keeps the high-L1 neurons and zeroes the
    /// weight rows of any neuron it drops.
    #[test]
    fn test_structured_pruning() {
        let config = PruningConfig {
            target_sparsity: 0.5,
            structured: true,
            ..Default::default()
        };

        let mut pruner = LayerPruner::new(config);

        // 4x4 weight matrix (rows are output neurons)
        let mut weights: [i8; 16] = [
            10, 10, 10, 10, // High importance neuron
            1, 1, 1, 1, // Low importance
            20, 20, 20, 20, // High importance
            2, 2, 2, 2, // Low importance
        ];

        let keep_mask = pruner.prune_neurons(&mut weights, 4, 4);

        // Should keep high importance neurons
        assert!(keep_mask[0]); // First neuron kept
        assert!(keep_mask[2]); // Third neuron kept

        // Low importance neurons should be zeroed
        if !keep_mask[1] {
            assert_eq!(weights[4], 0);
            assert_eq!(weights[5], 0);
        }
    }

    /// Smoke test: edge importance is produced for a small 3x3 matrix.
    #[test]
    fn test_mincut_scorer() {
        let mut scorer = MinCutScorer::new();

        let weights: [i8; 9] = [
            10, 20, 30,
            5, 10, 15,
            1, 2, 3,
        ];

        let importance = scorer.compute_edge_importance(&weights, 3, 3);

        // Should have computed importance for edges
        assert!(!importance.is_empty());
    }
}
|
||||
298
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/sparse_attention.rs
vendored
Normal file
298
vendor/ruvector/examples/ruvLLM/esp32/src/optimizations/sparse_attention.rs
vendored
Normal file
@@ -0,0 +1,298 @@
|
||||
//! Sparse Attention Patterns for ESP32
|
||||
//!
|
||||
//! Reduces attention complexity from O(n²) to O(n) using:
|
||||
//! - Sliding window attention
|
||||
//! - Strided patterns
|
||||
//! - Block-sparse attention
|
||||
|
||||
use heapless::Vec as HVec;
|
||||
|
||||
/// Maximum sequence length for sparse patterns
|
||||
pub const MAX_SPARSE_SEQ: usize = 32;
|
||||
/// Maximum window size
|
||||
pub const MAX_WINDOW_SIZE: usize = 8;
|
||||
|
||||
/// Attention pattern types
///
/// Each variant decides which (query, key) pairs are evaluated; the
/// causal restriction (no attending to future positions) is applied
/// separately when the mask is built.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AttentionPattern {
    /// Full attention (O(n²)) - baseline
    Full,
    /// Sliding window attention (O(n * w))
    SlidingWindow { window_size: usize },
    /// Strided attention (O(n * n/s))
    Strided { stride: usize },
    /// Combined window + stride
    Longformer { window_size: usize, stride: usize },
    /// Block diagonal attention
    BlockDiagonal { block_size: usize },
    /// Local + global tokens
    BigBird { window_size: usize, global_tokens: usize },
}

impl Default for AttentionPattern {
    /// A 4-token sliding window: the cheapest pattern that still
    /// works well for tiny on-device models.
    fn default() -> Self {
        AttentionPattern::SlidingWindow { window_size: 4 }
    }
}
|
||||
|
||||
/// Sparse attention implementation
///
/// Precomputes a causal sparse attention mask — one u32 bit-row per
/// query position — for sequences up to MAX_SPARSE_SEQ tokens.
pub struct SparseAttention {
    /// Pattern type
    pattern: AttentionPattern,
    /// Attention mask (true = attend, false = skip)
    /// Stored as bitmask for memory efficiency
    /// (row i = query position i; bit j = key position j)
    mask_data: HVec<u32, MAX_SPARSE_SEQ>,
    /// Sequence length
    seq_len: usize,
}

impl SparseAttention {
    /// Create sparse attention with given pattern
    ///
    /// Returns `Error::BufferOverflow` when `seq_len` exceeds
    /// MAX_SPARSE_SEQ (each mask row is a u32, so at most 32 keys).
    pub fn new(pattern: AttentionPattern, seq_len: usize) -> crate::Result<Self> {
        if seq_len > MAX_SPARSE_SEQ {
            return Err(crate::Error::BufferOverflow);
        }

        let mut sa = Self {
            pattern,
            mask_data: HVec::new(),
            seq_len,
        };

        sa.build_mask()?;
        Ok(sa)
    }

    /// Build attention mask based on pattern
    ///
    /// Causality is enforced here (`j <= i`); the pattern only decides
    /// which of the past positions survive.
    fn build_mask(&mut self) -> crate::Result<()> {
        self.mask_data.clear();

        for i in 0..self.seq_len {
            let mut row_mask: u32 = 0;

            for j in 0..self.seq_len {
                if j <= i && self.should_attend(i, j) {
                    row_mask |= 1 << j;
                }
            }

            self.mask_data.push(row_mask).map_err(|_| crate::Error::BufferOverflow)?;
        }

        Ok(())
    }

    /// Check if position i should attend to position j
    /// (pattern rule only; causality is handled by build_mask)
    fn should_attend(&self, i: usize, j: usize) -> bool {
        match self.pattern {
            AttentionPattern::Full => true,

            AttentionPattern::SlidingWindow { window_size } => {
                // j within the last `window_size` positions before i
                i.saturating_sub(window_size) <= j
            }

            AttentionPattern::Strided { stride } => {
                // stride anchors plus the immediately local positions
                j % stride == 0 || i.saturating_sub(1) <= j
            }

            AttentionPattern::Longformer { window_size, stride } => {
                // Local window OR strided global
                i.saturating_sub(window_size) <= j || j % stride == 0
            }

            AttentionPattern::BlockDiagonal { block_size } => {
                // Same block
                i / block_size == j / block_size
            }

            AttentionPattern::BigBird { window_size, global_tokens } => {
                // Local window OR global tokens (first N positions)
                i.saturating_sub(window_size) <= j || j < global_tokens
            }
        }
    }

    /// Check if query position i should attend to key position j
    /// (out-of-range positions never attend)
    #[inline]
    pub fn should_attend_at(&self, i: usize, j: usize) -> bool {
        if i >= self.seq_len || j >= self.seq_len {
            return false;
        }
        (self.mask_data[i] >> j) & 1 == 1
    }

    /// Get mask row for position i (for vectorized attention)
    /// Returns 0 (attend nothing) for out-of-range rows.
    #[inline]
    pub fn get_mask_row(&self, i: usize) -> u32 {
        self.mask_data.get(i).copied().unwrap_or(0)
    }

    /// Apply sparse attention: scores = Q @ K^T, masked
    /// Only computes necessary positions
    ///
    /// Masked-out positions receive `i32::MIN` so a downstream softmax
    /// drives them to zero weight.
    /// NOTE(review): assumes `scores.len() >= keys.len()` and that each
    /// key slice matches `query`'s length — confirm at call sites.
    pub fn sparse_qk(
        &self,
        query: &[i8],   // [dim]
        keys: &[&[i8]], // [seq_len][dim]
        scores: &mut [i32], // [seq_len]
        query_pos: usize,
    ) {
        let mask = self.get_mask_row(query_pos);

        for (j, key) in keys.iter().enumerate() {
            if (mask >> j) & 1 == 1 {
                // Compute dot product
                let mut sum: i32 = 0;
                for (&q, &k) in query.iter().zip(key.iter()) {
                    sum += q as i32 * k as i32;
                }
                scores[j] = sum;
            } else {
                scores[j] = i32::MIN; // Will be zeroed by softmax
            }
        }
    }

    /// Count active attention positions (set bits across all rows)
    pub fn active_positions(&self) -> usize {
        self.mask_data.iter().map(|m| m.count_ones() as usize).sum()
    }

    /// Theoretical vs actual computation ratio
    /// (active positions over the full causal lower triangle)
    pub fn sparsity_ratio(&self) -> f32 {
        let full = self.seq_len * (self.seq_len + 1) / 2; // Lower triangular
        let sparse = self.active_positions();
        sparse as f32 / full as f32
    }

    /// Memory savings description (human-readable, per pattern)
    pub fn memory_savings(&self) -> &'static str {
        match self.pattern {
            AttentionPattern::Full => "None (O(n²))",
            AttentionPattern::SlidingWindow { .. } => "O(n) - linear",
            AttentionPattern::Strided { .. } => "O(n) - linear",
            AttentionPattern::Longformer { .. } => "O(n) - linear",
            AttentionPattern::BlockDiagonal { .. } => "O(n) - block-linear",
            AttentionPattern::BigBird { .. } => "O(n) - linear",
        }
    }
}
|
||||
|
||||
/// Precomputed attention patterns for different sequence lengths
|
||||
pub struct AttentionPatternCache {
|
||||
/// Cached patterns for common lengths
|
||||
patterns: [Option<SparseAttention>; 4],
|
||||
}
|
||||
|
||||
impl AttentionPatternCache {
|
||||
/// Create cache with sliding window patterns
|
||||
pub fn new_sliding(window_size: usize) -> Self {
|
||||
let pattern = AttentionPattern::SlidingWindow { window_size };
|
||||
|
||||
Self {
|
||||
patterns: [
|
||||
SparseAttention::new(pattern, 8).ok(),
|
||||
SparseAttention::new(pattern, 16).ok(),
|
||||
SparseAttention::new(pattern, 24).ok(),
|
||||
SparseAttention::new(pattern, 32).ok(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get pattern for sequence length
|
||||
pub fn get(&self, seq_len: usize) -> Option<&SparseAttention> {
|
||||
let idx = match seq_len {
|
||||
1..=8 => 0,
|
||||
9..=16 => 1,
|
||||
17..=24 => 2,
|
||||
25..=32 => 3,
|
||||
_ => return None,
|
||||
};
|
||||
self.patterns[idx].as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sliding window: only the last `window_size` positions (plus
    /// self) are attended, never the future.
    #[test]
    fn test_sliding_window() {
        let sa = SparseAttention::new(
            AttentionPattern::SlidingWindow { window_size: 2 },
            8,
        ).unwrap();

        // Position 0: should only attend to 0
        assert!(sa.should_attend_at(0, 0));
        assert!(!sa.should_attend_at(0, 1));

        // Position 4: should attend to 2, 3, 4
        assert!(!sa.should_attend_at(4, 1));
        assert!(sa.should_attend_at(4, 2));
        assert!(sa.should_attend_at(4, 3));
        assert!(sa.should_attend_at(4, 4));
        assert!(!sa.should_attend_at(4, 5)); // Future
    }

    /// Strided: attends to multiples of the stride plus the local
    /// neighborhood.
    #[test]
    fn test_strided() {
        let sa = SparseAttention::new(
            AttentionPattern::Strided { stride: 4 },
            16,
        ).unwrap();

        // Position 10: attends to 0, 4, 8, 9, 10
        assert!(sa.should_attend_at(10, 0)); // stride
        assert!(sa.should_attend_at(10, 4)); // stride
        assert!(sa.should_attend_at(10, 8)); // stride
        assert!(sa.should_attend_at(10, 9)); // local
        assert!(sa.should_attend_at(10, 10)); // self
        assert!(!sa.should_attend_at(10, 1)); // not stride, not local
    }

    /// Sparsity ratio: a windowed pattern must be strictly sparser
    /// than full causal attention.
    #[test]
    fn test_sparsity() {
        let full = SparseAttention::new(AttentionPattern::Full, 16).unwrap();
        let sparse = SparseAttention::new(
            AttentionPattern::SlidingWindow { window_size: 4 },
            16,
        ).unwrap();

        // Full should have all positions
        assert!(full.sparsity_ratio() > 0.99);

        // Sparse should save computation
        assert!(sparse.sparsity_ratio() < full.sparsity_ratio());
    }

    /// Block diagonal: attention never crosses block boundaries and
    /// stays causal within a block.
    #[test]
    fn test_block_diagonal() {
        let sa = SparseAttention::new(
            AttentionPattern::BlockDiagonal { block_size: 4 },
            16,
        ).unwrap();

        // Position 5 (block 1): attends to 4, 5 only
        assert!(!sa.should_attend_at(5, 3)); // Block 0
        assert!(sa.should_attend_at(5, 4)); // Block 1
        assert!(sa.should_attend_at(5, 5)); // Block 1, self
        assert!(!sa.should_attend_at(5, 6)); // Block 1, future
        assert!(!sa.should_attend_at(5, 8)); // Block 2
    }

    /// BigBird: global tokens are always visible; everything else
    /// follows the local window.
    #[test]
    fn test_bigbird() {
        let sa = SparseAttention::new(
            AttentionPattern::BigBird { window_size: 2, global_tokens: 2 },
            16,
        ).unwrap();

        // Position 10: attends to 0, 1 (global), 8, 9, 10 (window)
        assert!(sa.should_attend_at(10, 0)); // global
        assert!(sa.should_attend_at(10, 1)); // global
        assert!(!sa.should_attend_at(10, 5)); // neither
        assert!(sa.should_attend_at(10, 8)); // window
        assert!(sa.should_attend_at(10, 10)); // self
    }
}
|
||||
Reference in New Issue
Block a user