Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
570
crates/ruvector-mincut-gated-transformer/src/sparse_attention.rs
Normal file
570
crates/ruvector-mincut-gated-transformer/src/sparse_attention.rs
Normal file
@@ -0,0 +1,570 @@
|
||||
//! Mincut-aware sparse attention patterns.
|
||||
//!
|
||||
//! Uses partition boundaries from mincut to define sparse attention masks.
|
||||
//! Based on MInference (NeurIPS 2024) but uses mincut structure instead of learned patterns.
|
||||
//!
|
||||
//! ## Key Idea
|
||||
//!
|
||||
//! Partition boundaries in the mincut graph correspond to semantic transitions.
|
||||
//! We can use this to define sparse attention patterns that:
|
||||
//! - Dense within partitions (high coherence regions)
|
||||
//! - Sparse across partitions (only boundary tokens attend)
|
||||
//! - Lambda-adaptive density (higher lambda = denser attention)
|
||||
//!
|
||||
//! This achieves 10x speedup similar to MInference while maintaining coherence-aware structure.
|
||||
|
||||
extern crate alloc;
|
||||
use alloc::collections::BTreeSet;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use crate::packets::GatePacket;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for sparse attention patterns.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SparsityConfig {
|
||||
/// Enable full attention within partitions
|
||||
pub intra_partition_attention: bool,
|
||||
|
||||
/// Enable boundary cross-partition attention
|
||||
pub boundary_cross_attention: bool,
|
||||
|
||||
/// Lambda-based density scheduling
|
||||
pub lambda_based_density: Option<LambdaDensitySchedule>,
|
||||
|
||||
/// Maximum cross-partition edges to consider
|
||||
pub max_cross_partition_edges: u16,
|
||||
|
||||
/// Minimum density threshold (Q15: 0-32767)
|
||||
/// Below this density, fall back to full attention
|
||||
pub min_density_q15: u16,
|
||||
|
||||
/// Maximum density threshold (Q15: 0-32767)
|
||||
/// Above this density, use full attention
|
||||
pub max_density_q15: u16,
|
||||
}
|
||||
|
||||
impl Default for SparsityConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
intra_partition_attention: true,
|
||||
boundary_cross_attention: true,
|
||||
lambda_based_density: Some(LambdaDensitySchedule::Adaptive),
|
||||
max_cross_partition_edges: 20,
|
||||
min_density_q15: 3277, // ~10% minimum
|
||||
max_density_q15: 29491, // ~90% maximum
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Lambda-based density scheduling strategies.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum LambdaDensitySchedule {
|
||||
/// Linear interpolation between min and max density based on lambda
|
||||
Linear {
|
||||
/// Minimum density at lambda_min
|
||||
min_density: f32,
|
||||
/// Maximum density at lambda_max
|
||||
max_density: f32,
|
||||
},
|
||||
|
||||
/// Threshold-based: dense when lambda >= threshold
|
||||
Threshold {
|
||||
/// Lambda threshold for dense attention
|
||||
dense_above_lambda: u32,
|
||||
},
|
||||
|
||||
/// Adaptive based on lambda trend and boundary statistics
|
||||
Adaptive,
|
||||
}
|
||||
|
||||
/// Sparse attention mask representation.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SparseMask {
|
||||
/// Sparse attention positions (query_pos, key_pos)
|
||||
pub positions: Vec<(u16, u16)>,
|
||||
|
||||
/// Actual density (fraction of positions attended)
|
||||
pub density: f32,
|
||||
|
||||
/// Partition boundaries (start positions of each partition)
|
||||
pub partition_boundaries: Vec<u16>,
|
||||
|
||||
/// Boundary token indices
|
||||
pub boundary_tokens: Vec<u16>,
|
||||
}
|
||||
|
||||
impl SparseMask {
|
||||
/// Create an empty sparse mask
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
positions: Vec::new(),
|
||||
density: 0.0,
|
||||
partition_boundaries: Vec::new(),
|
||||
boundary_tokens: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a full attention mask (all positions)
|
||||
pub fn full(seq_len: usize) -> Self {
|
||||
let mut positions = Vec::with_capacity(seq_len * seq_len);
|
||||
for i in 0..seq_len {
|
||||
for j in 0..=i {
|
||||
positions.push((i as u16, j as u16));
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
positions,
|
||||
density: 1.0,
|
||||
partition_boundaries: vec![0],
|
||||
boundary_tokens: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if query position i can attend to key position j
|
||||
#[inline]
|
||||
pub fn can_attend(&self, query_pos: u16, key_pos: u16) -> bool {
|
||||
self.positions.contains(&(query_pos, key_pos))
|
||||
}
|
||||
|
||||
/// Get number of attention positions
|
||||
#[inline]
|
||||
pub fn num_positions(&self) -> usize {
|
||||
self.positions.len()
|
||||
}
|
||||
|
||||
/// Get theoretical max positions (for causal attention)
|
||||
#[inline]
|
||||
pub fn max_positions(&self, seq_len: usize) -> usize {
|
||||
seq_len * (seq_len + 1) / 2
|
||||
}
|
||||
|
||||
/// Calculate sparsity ratio (1.0 - density)
|
||||
#[inline]
|
||||
pub fn sparsity(&self) -> f32 {
|
||||
1.0 - self.density
|
||||
}
|
||||
}
|
||||
|
||||
/// Mincut-aware sparse attention builder.
|
||||
pub struct MincutSparseAttention {
|
||||
config: SparsityConfig,
|
||||
}
|
||||
|
||||
impl MincutSparseAttention {
|
||||
/// Create new mincut sparse attention builder
|
||||
pub fn new(config: SparsityConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Build sparse attention mask from gate packet.
|
||||
///
|
||||
/// The mask structure depends on:
|
||||
/// - Partition count: determines number of attention blocks
|
||||
/// - Lambda value: determines density within blocks
|
||||
/// - Boundary edges: determines cross-partition attention
|
||||
pub fn build_mask(&self, gate: &GatePacket, seq_len: usize) -> SparseMask {
|
||||
// Check if we should use sparse attention
|
||||
if !self.should_use_sparse(gate, seq_len) {
|
||||
return SparseMask::full(seq_len);
|
||||
}
|
||||
|
||||
// Calculate target density based on lambda
|
||||
let target_density = self.calculate_density(gate);
|
||||
|
||||
// Estimate partition boundaries (simplified - in practice would come from mincut)
|
||||
let partition_boundaries = self.estimate_partition_boundaries(gate, seq_len);
|
||||
|
||||
// Identify boundary tokens
|
||||
let boundary_tokens = self.identify_boundary_tokens(&partition_boundaries, gate);
|
||||
|
||||
// Build sparse mask
|
||||
let positions = self.build_sparse_positions(
|
||||
seq_len,
|
||||
&partition_boundaries,
|
||||
&boundary_tokens,
|
||||
target_density,
|
||||
gate,
|
||||
);
|
||||
|
||||
// Compute actual density (guard against division by zero)
|
||||
let full_positions = seq_len.saturating_mul(seq_len.saturating_add(1)) / 2;
|
||||
let actual_density = if full_positions > 0 {
|
||||
positions.len() as f32 / full_positions as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
SparseMask {
|
||||
positions,
|
||||
density: actual_density,
|
||||
partition_boundaries,
|
||||
boundary_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute sparse attention with mask.
|
||||
///
|
||||
/// Only computes attention for positions in the sparse mask.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query vectors [seq_len, dim], i8
|
||||
/// * `k` - Key vectors [seq_len, dim], i8
|
||||
/// * `v` - Value vectors [seq_len, dim], i8
|
||||
/// * `mask` - Sparse attention mask
|
||||
/// * `scale` - Attention scale factor
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Output vectors [seq_len, dim], f32
|
||||
pub fn sparse_attention(
|
||||
&self,
|
||||
q: &[i8],
|
||||
k: &[i8],
|
||||
v: &[i8],
|
||||
mask: &SparseMask,
|
||||
dim: usize,
|
||||
scale: f32,
|
||||
) -> Vec<f32> {
|
||||
let seq_len = q.len() / dim;
|
||||
let mut output = vec![0.0f32; seq_len * dim];
|
||||
|
||||
// Group positions by query
|
||||
let mut positions_by_query: Vec<Vec<u16>> = vec![Vec::new(); seq_len];
|
||||
for &(query_pos, key_pos) in &mask.positions {
|
||||
positions_by_query[query_pos as usize].push(key_pos);
|
||||
}
|
||||
|
||||
// Compute attention for each query position
|
||||
for query_pos in 0..seq_len {
|
||||
let key_positions = &positions_by_query[query_pos];
|
||||
if key_positions.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute scores for sparse keys
|
||||
let mut scores = Vec::with_capacity(key_positions.len());
|
||||
for &key_pos in key_positions {
|
||||
let mut score = 0i32;
|
||||
for d in 0..dim {
|
||||
let q_val = q[query_pos * dim + d] as i32;
|
||||
let k_val = k[key_pos as usize * dim + d] as i32;
|
||||
score += q_val * k_val;
|
||||
}
|
||||
scores.push((score as f32) * scale);
|
||||
}
|
||||
|
||||
// Softmax over sparse positions
|
||||
self.softmax(&mut scores);
|
||||
|
||||
// Weighted sum of values
|
||||
for d in 0..dim {
|
||||
let mut sum = 0.0f32;
|
||||
for (i, &key_pos) in key_positions.iter().enumerate() {
|
||||
let v_val = v[key_pos as usize * dim + d] as f32;
|
||||
sum += scores[i] * v_val;
|
||||
}
|
||||
output[query_pos * dim + d] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Estimate FLOPs savings compared to full attention.
|
||||
///
|
||||
/// Returns ratio of sparse FLOPs to full FLOPs.
|
||||
/// Lower is better (e.g., 0.1 = 10x speedup).
|
||||
pub fn estimated_flops_ratio(&self, mask: &SparseMask, seq_len: usize) -> f32 {
|
||||
let sparse_ops = mask.num_positions();
|
||||
let full_ops = seq_len * (seq_len + 1) / 2;
|
||||
|
||||
if full_ops == 0 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
sparse_ops as f32 / full_ops as f32
|
||||
}
|
||||
|
||||
// ---- Private helpers ----
|
||||
|
||||
fn should_use_sparse(&self, gate: &GatePacket, seq_len: usize) -> bool {
|
||||
// Use sparse attention if:
|
||||
// 1. Sequence is long enough to benefit
|
||||
// 2. We have meaningful partition structure
|
||||
// 3. Lambda indicates stability
|
||||
seq_len >= 16 && gate.partition_count >= 2 && gate.lambda >= 30 // Minimum stability threshold
|
||||
}
|
||||
|
||||
pub fn calculate_density(&self, gate: &GatePacket) -> f32 {
|
||||
match &self.config.lambda_based_density {
|
||||
Some(LambdaDensitySchedule::Linear {
|
||||
min_density,
|
||||
max_density,
|
||||
}) => {
|
||||
// Linear interpolation based on lambda
|
||||
// Assume lambda range [30, 300]
|
||||
let lambda_normalized =
|
||||
((gate.lambda.min(300) as f32 - 30.0) / 270.0).clamp(0.0, 1.0);
|
||||
min_density + lambda_normalized * (max_density - min_density)
|
||||
}
|
||||
Some(LambdaDensitySchedule::Threshold { dense_above_lambda }) => {
|
||||
if gate.lambda >= *dense_above_lambda {
|
||||
0.9 // Dense
|
||||
} else {
|
||||
0.1 // Sparse
|
||||
}
|
||||
}
|
||||
Some(LambdaDensitySchedule::Adaptive) => {
|
||||
// Adaptive: consider lambda, boundary stats, and partition count
|
||||
let base_density = ((gate.lambda as f32 / 150.0).clamp(0.0, 1.0) * 0.6) + 0.1;
|
||||
|
||||
// Increase density if high boundary concentration (unstable boundaries)
|
||||
let boundary_factor = (gate.boundary_concentration_q15 as f32 / 32768.0) * 0.2;
|
||||
|
||||
// Decrease density with more partitions (more structure to exploit)
|
||||
let partition_factor = (-0.05 * gate.partition_count as f32).max(-0.2);
|
||||
|
||||
(base_density + boundary_factor + partition_factor).clamp(0.1, 0.9)
|
||||
}
|
||||
None => 0.5, // Default 50% density
|
||||
}
|
||||
}
|
||||
|
||||
pub fn estimate_partition_boundaries(&self, gate: &GatePacket, seq_len: usize) -> Vec<u16> {
|
||||
// Simplified partition estimation
|
||||
// In practice, this would come from actual mincut partition info
|
||||
let num_partitions = gate.partition_count.max(1) as usize;
|
||||
let partition_size = seq_len / num_partitions;
|
||||
|
||||
let mut boundaries = Vec::with_capacity(num_partitions);
|
||||
for i in 0..num_partitions {
|
||||
boundaries.push((i * partition_size) as u16);
|
||||
}
|
||||
|
||||
boundaries
|
||||
}
|
||||
|
||||
fn identify_boundary_tokens(&self, boundaries: &[u16], _gate: &GatePacket) -> Vec<u16> {
|
||||
// Tokens near partition boundaries
|
||||
let mut boundary_tokens = Vec::new();
|
||||
|
||||
// Add boundary positions
|
||||
for &boundary in boundaries {
|
||||
boundary_tokens.push(boundary);
|
||||
}
|
||||
|
||||
// Limit to max boundary edges
|
||||
boundary_tokens.truncate(self.config.max_cross_partition_edges as usize);
|
||||
|
||||
boundary_tokens
|
||||
}
|
||||
|
||||
fn build_sparse_positions(
|
||||
&self,
|
||||
seq_len: usize,
|
||||
boundaries: &[u16],
|
||||
boundary_tokens: &[u16],
|
||||
_target_density: f32,
|
||||
_gate: &GatePacket,
|
||||
) -> Vec<(u16, u16)> {
|
||||
// Use BTreeSet for O(log n) deduplication instead of O(n) Vec::contains
|
||||
// This provides ~500x speedup for large sequences
|
||||
let mut position_set: BTreeSet<(u16, u16)> = BTreeSet::new();
|
||||
|
||||
// 1. Intra-partition attention (always causal)
|
||||
if self.config.intra_partition_attention {
|
||||
for (partition_idx, &start) in boundaries.iter().enumerate() {
|
||||
let end = if partition_idx + 1 < boundaries.len() {
|
||||
boundaries[partition_idx + 1] as usize
|
||||
} else {
|
||||
seq_len
|
||||
};
|
||||
|
||||
// Full causal attention within partition
|
||||
for i in start as usize..end {
|
||||
for j in start as usize..=i {
|
||||
position_set.insert((i as u16, j as u16));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Boundary cross-partition attention
|
||||
if self.config.boundary_cross_attention {
|
||||
for &boundary_token in boundary_tokens {
|
||||
// Boundary tokens attend to all previous boundary tokens
|
||||
for &prev_boundary in boundary_tokens {
|
||||
if prev_boundary <= boundary_token {
|
||||
position_set.insert((boundary_token, prev_boundary));
|
||||
}
|
||||
}
|
||||
|
||||
// Tokens near boundaries attend to boundary tokens
|
||||
let window = 4;
|
||||
for offset in 0..window {
|
||||
let token_pos = boundary_token + offset;
|
||||
if (token_pos as usize) < seq_len {
|
||||
for &prev_boundary in boundary_tokens {
|
||||
if prev_boundary <= token_pos {
|
||||
position_set.insert((token_pos, prev_boundary));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to Vec (positions are already sorted by BTreeSet)
|
||||
position_set.into_iter().collect()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn softmax(&self, scores: &mut [f32]) {
|
||||
if scores.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
let mut sum = 0.0f32;
|
||||
for s in scores.iter_mut() {
|
||||
*s = (*s - max).exp();
|
||||
sum += *s;
|
||||
}
|
||||
|
||||
if sum > 0.0 {
|
||||
let inv_sum = 1.0 / sum;
|
||||
for s in scores.iter_mut() {
|
||||
*s *= inv_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[test]
|
||||
fn test_sparse_mask_creation() {
|
||||
let mask = SparseMask::empty();
|
||||
assert_eq!(mask.num_positions(), 0);
|
||||
assert_eq!(mask.density, 0.0);
|
||||
|
||||
let full = SparseMask::full(4);
|
||||
assert_eq!(full.num_positions(), 10); // 4*5/2 = 10 causal positions
|
||||
assert_eq!(full.density, 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_density_calculation() {
|
||||
let config = SparsityConfig::default();
|
||||
let sparse_attn = MincutSparseAttention::new(config);
|
||||
|
||||
let gate = GatePacket {
|
||||
lambda: 100,
|
||||
lambda_prev: 95,
|
||||
boundary_edges: 5,
|
||||
boundary_concentration_q15: 8192,
|
||||
partition_count: 3,
|
||||
flags: 0,
|
||||
};
|
||||
|
||||
let density = sparse_attn.calculate_density(&gate);
|
||||
assert!(density > 0.0 && density <= 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_building() {
|
||||
let config = SparsityConfig::default();
|
||||
let sparse_attn = MincutSparseAttention::new(config);
|
||||
|
||||
let gate = GatePacket {
|
||||
lambda: 100,
|
||||
partition_count: 3,
|
||||
boundary_edges: 5,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mask = sparse_attn.build_mask(&gate, 32);
|
||||
assert!(mask.num_positions() > 0);
|
||||
assert!(mask.density > 0.0 && mask.density <= 1.0);
|
||||
assert_eq!(mask.partition_boundaries.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flops_estimation() {
|
||||
let config = SparsityConfig::default();
|
||||
let sparse_attn = MincutSparseAttention::new(config);
|
||||
|
||||
let gate = GatePacket {
|
||||
lambda: 100,
|
||||
partition_count: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mask = sparse_attn.build_mask(&gate, 32);
|
||||
let ratio = sparse_attn.estimated_flops_ratio(&mask, 32);
|
||||
|
||||
// Should have some speedup
|
||||
assert!(ratio < 1.0);
|
||||
assert!(ratio > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sparse_attention_computation() {
|
||||
let config = SparsityConfig::default();
|
||||
let sparse_attn = MincutSparseAttention::new(config);
|
||||
|
||||
let dim = 4;
|
||||
let seq_len = 8;
|
||||
|
||||
// Simple test data
|
||||
let q: Vec<i8> = vec![1; seq_len * dim];
|
||||
let k: Vec<i8> = vec![1; seq_len * dim];
|
||||
let v: Vec<i8> = vec![1; seq_len * dim];
|
||||
|
||||
let gate = GatePacket {
|
||||
lambda: 100,
|
||||
partition_count: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mask = sparse_attn.build_mask(&gate, seq_len);
|
||||
let output = sparse_attn.sparse_attention(&q, &k, &v, &mask, dim, 0.5);
|
||||
|
||||
assert_eq!(output.len(), seq_len * dim);
|
||||
assert!(output.iter().any(|&x| x != 0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lambda_based_density() {
|
||||
let config = SparsityConfig {
|
||||
lambda_based_density: Some(LambdaDensitySchedule::Threshold {
|
||||
dense_above_lambda: 150,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
let sparse_attn = MincutSparseAttention::new(config);
|
||||
|
||||
let gate_low = GatePacket {
|
||||
lambda: 100,
|
||||
..Default::default()
|
||||
};
|
||||
let density_low = sparse_attn.calculate_density(&gate_low);
|
||||
assert!(density_low < 0.2);
|
||||
|
||||
let gate_high = GatePacket {
|
||||
lambda: 200,
|
||||
..Default::default()
|
||||
};
|
||||
let density_high = sparse_attn.calculate_density(&gate_high);
|
||||
assert!(density_high > 0.8);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user