//! Attention mechanisms for GNN models.
//!
//! This module provides attention layers and mechanisms including:
//! - Single-head attention
//! - Multi-head attention
//! - Cross-attention between two sequences
//! - Set attention (Set Transformer style)
//! - Graph-level attention readout
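//!
//! A minimal usage sketch (marked `ignore` because the public import path is an
//! assumption based on the crate layout):
//!
//! ```ignore
//! use ndarray::Array2;
//! // Assumed re-export path; adjust to wherever this module is exposed.
//! use sevensense_learning::infrastructure::attention::{AttentionLayer, MultiHeadAttention};
//!
//! // Single-head attention over 5 nodes with 8-dimensional features.
//! let layer = AttentionLayer::new(8, 16);
//! let features = Array2::from_elem((5, 8), 0.5);
//! let attended = layer.forward(&features).unwrap(); // shape [5, 16]
//!
//! // Four heads of dimension 4, projected back to the 8-dimensional input space.
//! let mha = MultiHeadAttention::new(8, 4, 4, 0.0);
//! let fused = mha.forward(&features).unwrap(); // shape [5, 8]
//! ```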
use ndarray::{Array1, Array2, Axis};
use rand_distr::{Distribution, Uniform};
/// Error type for attention operations
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
/// Dimension mismatch
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// Invalid configuration
#[error("Invalid attention configuration: {0}")]
InvalidConfig(String),
/// Computation error
#[error("Attention computation error: {0}")]
ComputationError(String),
}
/// Result type for attention operations
pub type AttentionResult<T> = Result<T, AttentionError>;
/// Single-head attention layer
#[derive(Debug, Clone)]
pub struct AttentionLayer {
/// Query weight matrix
query_weights: Array2<f32>,
/// Key weight matrix
key_weights: Array2<f32>,
/// Value weight matrix
value_weights: Array2<f32>,
/// Attention dimension
attention_dim: usize,
/// Input dimension
input_dim: usize,
/// Output dimension
output_dim: usize,
/// Scaling factor
scale: f32,
}
impl AttentionLayer {
/// Create a new attention layer
#[must_use]
pub fn new(input_dim: usize, attention_dim: usize) -> Self {
let query_weights = xavier_init(input_dim, attention_dim);
let key_weights = xavier_init(input_dim, attention_dim);
let value_weights = xavier_init(input_dim, attention_dim);
let scale = (attention_dim as f32).sqrt();
Self {
query_weights,
key_weights,
value_weights,
attention_dim,
input_dim,
output_dim: attention_dim,
scale,
}
}
/// Get the output dimension
#[must_use]
pub fn output_dim(&self) -> usize {
self.output_dim
}
/// Compute softmax-normalised attention weights from already projected queries and keys
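///
/// Positions where `mask` is `0.0` are set to `-inf` before the softmax and
/// therefore receive zero attention weight.
///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes `AttentionLayer` is in scope):
///
/// ```ignore
/// use ndarray::Array2;
/// let layer = AttentionLayer::new(4, 4);
/// let q = Array2::from_elem((3, 4), 0.5);
/// let k = q.clone();
/// // Each row of the returned weights sums to 1 after the softmax.
/// let weights = layer.compute_attention(&q, &k, None).unwrap();
/// assert_eq!(weights.shape(), &[3, 3]);
/// ```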
pub fn compute_attention(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
// Q * K^T / sqrt(d_k)
let scores = query.dot(&key.t()) / self.scale;
// Apply mask if provided (set masked positions to -inf)
let scores = if let Some(mask) = mask {
let mut masked = scores;
for i in 0..masked.nrows() {
for j in 0..masked.ncols() {
if mask[[i, j]] == 0.0 {
masked[[i, j]] = f32::NEG_INFINITY;
}
}
}
masked
} else {
scores
};
// Softmax
let attention_weights = softmax_2d(&scores);
Ok(attention_weights)
}
/// Forward pass through the attention layer
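///
/// Projects `features` into queries, keys and values, applies scaled
/// dot-product self-attention, and returns the attended values with shape
/// `[n, attention_dim]`.
///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes `AttentionLayer` is in scope):
///
/// ```ignore
/// use ndarray::Array2;
/// let layer = AttentionLayer::new(8, 16);
/// let features = Array2::from_elem((3, 8), 0.5);
/// let output = layer.forward(&features).unwrap();
/// assert_eq!(output.shape(), &[3, 16]);
/// ```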
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
if features.ncols() != self.input_dim {
return Err(AttentionError::DimensionMismatch {
expected: self.input_dim,
actual: features.ncols(),
});
}
// Compute Q, K, V
let q = features.dot(&self.query_weights);
let k = features.dot(&self.key_weights);
let v = features.dot(&self.value_weights);
// Compute attention weights
let attention_weights = self.compute_attention(&q, &k, None)?;
// Apply attention to values
let output = attention_weights.dot(&v);
Ok(output)
}
/// Forward pass with explicit Q, K, V
pub fn forward_qkv(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
value: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
// Transform Q, K, V
let q = query.dot(&self.query_weights);
let k = key.dot(&self.key_weights);
let v = value.dot(&self.value_weights);
// Compute attention weights
let attention_weights = self.compute_attention(&q, &k, mask)?;
// Apply attention to values
let output = attention_weights.dot(&v);
Ok(output)
}
/// Graph-level readout using attention
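///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes `AttentionLayer` is in scope):
///
/// ```ignore
/// use ndarray::Array2;
/// let layer = AttentionLayer::new(8, 4);
/// let node_features = Array2::from_elem((5, 8), 0.5);
/// // Self-attention over the nodes, then a mean over the attended features.
/// let graph_vec = layer.graph_readout(&node_features).unwrap();
/// assert_eq!(graph_vec.len(), 4);
/// ```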
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
// Compute attention-weighted mean of node features
let attended = self.forward(node_features)?;
// Mean over nodes; an empty node set is reported as an error rather than a panic
let mean = attended.mean_axis(Axis(0)).ok_or_else(|| {
AttentionError::ComputationError("cannot read out an empty node feature matrix".to_string())
})?;
Ok(mean)
}
/// Apply weight decay to the projection matrices.
///
/// No gradient step is performed here and the learning-rate argument is
/// currently unused; only multiplicative weight decay is applied.
pub fn update_weights(&mut self, _lr: f32, weight_decay: f32) {
self.query_weights -= &(&self.query_weights * weight_decay);
self.key_weights -= &(&self.key_weights * weight_decay);
self.value_weights -= &(&self.value_weights * weight_decay);
}
}
/// Multi-head attention layer
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
/// Individual attention heads
heads: Vec<AttentionLayer>,
/// Output projection
output_projection: Array2<f32>,
/// Number of heads
num_heads: usize,
/// Dimension per head
head_dim: usize,
/// Total output dimension
output_dim: usize,
/// Dropout probability (stored for configuration; not applied in the current forward pass)
dropout: f32,
}
impl MultiHeadAttention {
/// Create a new multi-head attention layer
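///
/// Each of the `num_heads` heads attends in its own `head_dim`-dimensional
/// subspace; the concatenated `num_heads * head_dim` output is projected back
/// to `input_dim`.
///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes `MultiHeadAttention` is in scope):
///
/// ```ignore
/// use ndarray::Array2;
/// let mha = MultiHeadAttention::new(8, 4, 4, 0.0);
/// let features = Array2::from_elem((5, 8), 0.5);
/// let output = mha.forward(&features).unwrap();
/// // The output dimension matches the input dimension, so residual connections line up.
/// assert_eq!(output.shape(), &[5, 8]);
/// ```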
#[must_use]
pub fn new(input_dim: usize, num_heads: usize, head_dim: usize, dropout: f32) -> Self {
let mut heads = Vec::with_capacity(num_heads);
for _ in 0..num_heads {
heads.push(AttentionLayer::new(input_dim, head_dim));
}
let total_dim = num_heads * head_dim;
let output_projection = xavier_init(total_dim, input_dim);
Self {
heads,
output_projection,
num_heads,
head_dim,
output_dim: input_dim,
dropout,
}
}
/// Get the number of heads
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Get the output dimension
#[must_use]
pub fn output_dim(&self) -> usize {
self.output_dim
}
/// Forward pass through multi-head attention
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
let n = features.nrows();
// Compute attention for each head
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head in &self.heads {
let output = head.forward(features)?;
head_outputs.push(output);
}
// Concatenate head outputs
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
for (h, output) in head_outputs.iter().enumerate() {
let start = h * self.head_dim;
for i in 0..n {
for j in 0..self.head_dim {
concat[[i, start + j]] = output[[i, j]];
}
}
}
// Apply output projection
let output = concat.dot(&self.output_projection);
Ok(output)
}
/// Forward pass with explicit Q, K, V
pub fn forward_qkv(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
value: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
let n = query.nrows();
// Compute attention for each head
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head in &self.heads {
let output = head.forward_qkv(query, key, value, mask)?;
head_outputs.push(output);
}
// Concatenate head outputs
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
for (h, output) in head_outputs.iter().enumerate() {
let start = h * self.head_dim;
for i in 0..n {
for j in 0..self.head_dim {
concat[[i, start + j]] = output[[i, j]];
}
}
}
// Apply output projection
let output = concat.dot(&self.output_projection);
Ok(output)
}
/// Graph-level readout using multi-head attention
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
let attended = self.forward(node_features)?;
let mean = attended.mean_axis(Axis(0)).ok_or_else(|| {
AttentionError::ComputationError("cannot read out an empty node feature matrix".to_string())
})?;
Ok(mean)
}
}
/// Cross-attention between two sequences
#[derive(Debug, Clone)]
pub struct CrossAttention {
/// Query projection for source
query_proj: Array2<f32>,
/// Key projection for target
key_proj: Array2<f32>,
/// Value projection for target
value_proj: Array2<f32>,
/// Attention dimension
attention_dim: usize,
/// Source dimension
source_dim: usize,
/// Target dimension
target_dim: usize,
}
impl CrossAttention {
/// Create a new cross-attention layer
#[must_use]
pub fn new(source_dim: usize, target_dim: usize, attention_dim: usize) -> Self {
Self {
query_proj: xavier_init(source_dim, attention_dim),
key_proj: xavier_init(target_dim, attention_dim),
value_proj: xavier_init(target_dim, attention_dim),
attention_dim,
source_dim,
target_dim,
}
}
/// Compute cross-attention between source and target
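///
/// Each source row attends over all target rows, so the result has one row
/// per source element and `attention_dim` columns.
///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes `CrossAttention` is in scope):
///
/// ```ignore
/// use ndarray::Array2;
/// let cross = CrossAttention::new(8, 16, 32);
/// let source = Array2::from_elem((3, 8), 0.5);
/// let target = Array2::from_elem((5, 16), 0.5);
/// let output = cross.forward(&source, &target).unwrap();
/// assert_eq!(output.shape(), &[3, 32]);
/// ```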
pub fn forward(
&self,
source: &Array2<f32>,
target: &Array2<f32>,
) -> AttentionResult<Array2<f32>> {
// Project source to query
let query = source.dot(&self.query_proj);
// Project target to key and value
let key = target.dot(&self.key_proj);
let value = target.dot(&self.value_proj);
// Compute attention scores
let scale = (self.attention_dim as f32).sqrt();
let scores = query.dot(&key.t()) / scale;
// Softmax
let attention_weights = softmax_2d(&scores);
// Apply attention to values
let output = attention_weights.dot(&value);
Ok(output)
}
}
/// Set attention for set-to-set operations
#[derive(Debug, Clone)]
pub struct SetAttention {
/// Multi-head attention
mha: MultiHeadAttention,
/// Layer normalization parameters
layer_norm_weight: Array1<f32>,
layer_norm_bias: Array1<f32>,
/// Feed-forward network
ffn_w1: Array2<f32>,
ffn_w2: Array2<f32>,
/// Dimensions
input_dim: usize,
hidden_dim: usize,
}
impl SetAttention {
/// Create a new set attention layer (Set Transformer style)
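///
/// `input_dim` is assumed to be divisible by `num_heads`; otherwise the
/// integer division below silently shrinks the per-head dimension.
///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes `SetAttention` is in scope):
///
/// ```ignore
/// use ndarray::Array2;
/// let set_attn = SetAttention::new(16, 4, 64);
/// let elements = Array2::from_elem((4, 16), 0.5);
/// // Permutation-equivariant transform: the shape is preserved.
/// let transformed = set_attn.forward(&elements).unwrap();
/// assert_eq!(transformed.shape(), &[4, 16]);
/// // Permutation-invariant aggregation into a single vector.
/// let pooled = set_attn.aggregate(&elements).unwrap();
/// assert_eq!(pooled.len(), 16);
/// ```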
#[must_use]
pub fn new(input_dim: usize, num_heads: usize, hidden_dim: usize) -> Self {
let head_dim = input_dim / num_heads;
let mha = MultiHeadAttention::new(input_dim, num_heads, head_dim, 0.0);
Self {
mha,
layer_norm_weight: Array1::ones(input_dim),
layer_norm_bias: Array1::zeros(input_dim),
ffn_w1: xavier_init(input_dim, hidden_dim),
ffn_w2: xavier_init(hidden_dim, input_dim),
input_dim,
hidden_dim,
}
}
/// Forward pass with self-attention and feed-forward
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
// Self-attention with residual
let attended = self.mha.forward(features)?;
let residual1 = features + &attended;
let normed1 = layer_norm(&residual1, &self.layer_norm_weight, &self.layer_norm_bias);
// Feed-forward with residual
let hidden = normed1.dot(&self.ffn_w1).mapv(|x| x.max(0.0)); // ReLU
let ffn_out = hidden.dot(&self.ffn_w2);
let residual2 = &normed1 + &ffn_out;
let output = layer_norm(&residual2, &self.layer_norm_weight, &self.layer_norm_bias);
Ok(output)
}
/// Aggregate set elements using learned attention
pub fn aggregate(&self, features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
let transformed = self.forward(features)?;
transformed.mean_axis(Axis(0)).ok_or_else(|| {
AttentionError::ComputationError("cannot aggregate an empty feature matrix".to_string())
})
}
}
// =========== Helper Functions ===========
/// Xavier/Glorot initialization
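///
/// Samples uniformly from `[-limit, limit]` with
/// `limit = sqrt(6 / (fan_in + fan_out))`, which keeps the activation variance
/// roughly constant across layers.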
fn xavier_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
let uniform = Uniform::new(-limit, limit);
let mut rng = rand::thread_rng();
Array2::from_shape_fn((fan_in, fan_out), |_| uniform.sample(&mut rng))
}
/// Softmax over 2D array (row-wise)
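///
/// Computes `exp(x_ij - max_j x_ij) / sum_j exp(x_ij - max_j x_ij)` per row;
/// non-finite entries (masked positions set to `-inf`) map to exactly zero.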
fn softmax_2d(scores: &Array2<f32>) -> Array2<f32> {
let mut result = scores.clone();
for mut row in result.rows_mut() {
// Subtract max for numerical stability
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0;
for val in row.iter_mut() {
if val.is_finite() {
*val = (*val - max).exp();
sum += *val;
} else {
*val = 0.0;
}
}
if sum > 0.0 {
row /= sum;
}
}
result
}
/// Layer normalization
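///
/// Normalises each row to zero mean and unit variance and then applies the
/// affine transform `y = (x - mean) / sqrt(var + eps) * weight + bias`
/// with `eps = 1e-5`.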
fn layer_norm(x: &Array2<f32>, weight: &Array1<f32>, bias: &Array1<f32>) -> Array2<f32> {
let eps = 1e-5;
let mut result = x.clone();
for mut row in result.rows_mut() {
let mean = row.mean().unwrap_or(0.0);
let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / row.len() as f32;
let std = (variance + eps).sqrt();
for (i, val) in row.iter_mut().enumerate() {
*val = (*val - mean) / std * weight[i] + bias[i];
}
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_attention_layer() {
let layer = AttentionLayer::new(8, 16);
assert_eq!(layer.output_dim(), 16);
let features = Array2::from_elem((3, 8), 0.5);
let output = layer.forward(&features).unwrap();
assert_eq!(output.shape(), &[3, 16]);
}
#[test]
fn test_multi_head_attention() {
let mha = MultiHeadAttention::new(8, 4, 4, 0.0);
assert_eq!(mha.num_heads(), 4);
assert_eq!(mha.output_dim(), 8);
let features = Array2::from_elem((5, 8), 0.5);
let output = mha.forward(&features).unwrap();
assert_eq!(output.shape(), &[5, 8]);
}
#[test]
fn test_cross_attention() {
let cross = CrossAttention::new(8, 16, 32);
let source = Array2::from_elem((3, 8), 0.5);
let target = Array2::from_elem((5, 16), 0.5);
let output = cross.forward(&source, &target).unwrap();
assert_eq!(output.shape(), &[3, 32]);
}
#[test]
fn test_set_attention() {
let set_attn = SetAttention::new(16, 4, 64);
let features = Array2::from_elem((4, 16), 0.5);
let output = set_attn.forward(&features).unwrap();
assert_eq!(output.shape(), &[4, 16]);
let aggregated = set_attn.aggregate(&features).unwrap();
assert_eq!(aggregated.len(), 16);
}
#[test]
fn test_softmax() {
let scores = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let probs = softmax_2d(&scores);
// Each row should sum to 1
for row in probs.rows() {
let sum: f32 = row.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
// Higher values should have higher probabilities
assert!(probs[[0, 2]] > probs[[0, 1]]);
assert!(probs[[0, 1]] > probs[[0, 0]]);
}
#[test]
fn test_layer_norm() {
let x = Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
let weight = Array1::ones(4);
let bias = Array1::zeros(4);
let normed = layer_norm(&x, &weight, &bias);
// Each row should have mean ~0 and std ~1
for row in normed.rows() {
let mean: f32 = row.iter().sum::<f32>() / row.len() as f32;
assert!(mean.abs() < 1e-5);
}
}
#[test]
fn test_graph_readout() {
let layer = AttentionLayer::new(8, 4);
let node_features = Array2::from_elem((5, 8), 0.5);
let readout = layer.graph_readout(&node_features).unwrap();
assert_eq!(readout.len(), 4);
}
#[test]
fn test_attention_with_mask() {
let layer = AttentionLayer::new(4, 4);
let features = Array2::from_elem((3, 4), 0.5);
// Mask that prevents the first query position from attending to the second and third positions
let mut mask = Array2::ones((3, 3));
mask[[0, 1]] = 0.0;
mask[[0, 2]] = 0.0;
let query = features.dot(&layer.query_weights);
let key = features.dot(&layer.key_weights);
let attn_weights = layer.compute_attention(&query, &key, Some(&mask)).unwrap();
// First row should only attend to itself
assert!(attn_weights[[0, 0]] > 0.99); // Almost all attention to self
assert!(attn_weights[[0, 1]] < 0.01);
assert!(attn_weights[[0, 2]] < 0.01);
}
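// Added sanity check: the weights returned by `compute_attention` should form
// a probability distribution in every row (non-negative, summing to 1).
#[test]
fn test_attention_weights_are_normalized() {
let layer = AttentionLayer::new(6, 6);
let features = Array2::from_shape_fn((4, 6), |(i, j)| (i + j) as f32 * 0.1);
let q = features.dot(&layer.query_weights);
let k = features.dot(&layer.key_weights);
let weights = layer.compute_attention(&q, &k, None).unwrap();
for row in weights.rows() {
let sum: f32 = row.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
assert!(row.iter().all(|&w| w >= 0.0));
}
}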
}