//! Utility functions for attention mechanisms.
//!
//! This module provides common utilities like softmax, masking, and
//! numerical stability helpers used across attention implementations.

use crate::error::{AttentionError, AttentionResult};

/// Numerically stable softmax that returns a `Vec<f32>` directly (no `Result`).
///
/// Non-finite inputs are treated as masked, and degenerate inputs fall back
/// to a uniform distribution. Used by the sparse, MoE, and graph modules.
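///
/// # Examples
///
/// A minimal sketch; the `attention::utils` path below is a placeholder for
/// wherever this module actually lives, not something this file confirms:
///
/// ```ignore
/// use attention::utils::stable_softmax; // hypothetical crate/module path
///
/// let w = stable_softmax(&[1.0, 2.0, 3.0]);
/// assert!((w.iter().sum::<f32>() - 1.0).abs() < 1e-6);
/// // Degenerate input falls back to a uniform distribution.
/// assert_eq!(stable_softmax(&[f32::NEG_INFINITY; 2]), vec![0.5, 0.5]);
/// ```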
#[inline]
pub fn stable_softmax(values: &[f32]) -> Vec<f32> {
    if values.is_empty() {
        return vec![];
    }

    // Find maximum for numerical stability
    let max_val = values
        .iter()
        .copied()
        .filter(|x| x.is_finite())
        .fold(f32::NEG_INFINITY, f32::max);

    if !max_val.is_finite() {
        // All values are -inf or invalid, return uniform
        let n = values.len();
        return vec![1.0 / n as f32; n];
    }

    // Compute exp(x - max) and sum; non-finite entries get zero weight
    let mut exp_values: Vec<f32> = values
        .iter()
        .map(|&x| {
            if x.is_finite() {
                (x - max_val).exp()
            } else {
                0.0
            }
        })
        .collect();

    let sum: f32 = exp_values.iter().sum();

    if sum <= 1e-10 || !sum.is_finite() {
        // Fallback to uniform
        let n = values.len();
        return vec![1.0 / n as f32; n];
    }

    // Normalize
    let inv_sum = 1.0 / sum;
    exp_values.iter_mut().for_each(|x| *x *= inv_sum);

    exp_values
}

/// Computes softmax over a slice of values.
///
/// Uses the numerically stable variant:
/// `softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))`.
///
/// # Arguments
///
/// * `values` - Input values
///
/// # Returns
///
/// Softmax-normalized values
///
/// # Errors
///
/// Returns an error if `values` is empty, if the maximum is non-finite
/// (e.g. a `+inf` entry, or every entry `-inf`), or if the normalizing sum
/// is invalid (e.g. `NaN` propagated from a `NaN` input).
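///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::softmax; // hypothetical crate/module path
///
/// let w = softmax(&[1.0, 2.0, 3.0]).unwrap();
/// assert!((w.iter().sum::<f32>() - 1.0).abs() < 1e-6);
/// assert!(w[0] < w[1] && w[1] < w[2]);
/// ```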
#[inline]
pub fn softmax(values: &[f32]) -> AttentionResult<Vec<f32>> {
    if values.is_empty() {
        return Err(AttentionError::EmptyInput(
            "cannot compute softmax of empty slice".to_string(),
        ));
    }

    // Find maximum for numerical stability
    let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    if !max_val.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "non-finite values in softmax input".to_string(),
        ));
    }

    // Compute exp(x - max) and sum
    let mut exp_values: Vec<f32> = values.iter().map(|&x| (x - max_val).exp()).collect();

    let sum: f32 = exp_values.iter().sum();

    if sum <= 0.0 || !sum.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "invalid sum in softmax computation".to_string(),
        ));
    }

    // Normalize
    let inv_sum = 1.0 / sum;
    exp_values.iter_mut().for_each(|x| *x *= inv_sum);

    Ok(exp_values)
}

/// Computes softmax with masking support.
///
/// Masked positions are set to negative infinity before softmax,
/// resulting in zero attention weights.
///
/// # Arguments
///
/// * `values` - Input values
/// * `mask` - Optional mask (true = attend, false = mask out)
///
/// # Returns
///
/// Masked and softmax-normalized values
///
/// # Errors
///
/// Returns an error if `values` is empty, if the mask length does not match
/// `values.len()`, or if every position is masked out.
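///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::masked_softmax; // hypothetical crate/module path
///
/// let mask = [true, true, false];
/// let w = masked_softmax(&[1.0, 2.0, 3.0], Some(&mask)).unwrap();
/// assert_eq!(w[2], 0.0); // masked position gets zero weight
/// assert!((w[0] + w[1] - 1.0).abs() < 1e-6);
/// ```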
#[inline]
pub fn masked_softmax(values: &[f32], mask: Option<&[bool]>) -> AttentionResult<Vec<f32>> {
    if values.is_empty() {
        return Err(AttentionError::EmptyInput(
            "cannot compute softmax of empty slice".to_string(),
        ));
    }

    let masked_values = if let Some(m) = mask {
        if m.len() != values.len() {
            return Err(AttentionError::InvalidMask {
                expected: format!("{}", values.len()),
                actual: format!("{}", m.len()),
            });
        }

        values
            .iter()
            .zip(m.iter())
            .map(|(&v, &keep)| if keep { v } else { f32::NEG_INFINITY })
            .collect::<Vec<_>>()
    } else {
        values.to_vec()
    };

    softmax(&masked_values)
}

/// Applies causal masking to attention scores in place.
///
/// For query position `i`, only key positions `0..=i` can be attended to;
/// later positions are set to negative infinity. Query and key positions are
/// assumed to be aligned (query `i` corresponds to key `i`).
///
/// # Arguments
///
/// * `scores` - Attention scores matrix, row-major `[query_len, key_len]`
///   (modified in place)
/// * `query_len` - Number of query positions
/// * `key_len` - Number of key positions
///
/// # Errors
///
/// Returns an error if `scores.len() != query_len * key_len`.
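///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::apply_causal_mask; // hypothetical crate/module path
///
/// let mut scores = vec![0.0f32; 4]; // 2x2 matrix, row-major
/// apply_causal_mask(&mut scores, 2, 2).unwrap();
/// assert_eq!(scores[1], f32::NEG_INFINITY); // (0, 1) masked
/// assert_eq!(scores[2], 0.0);               // (1, 0) kept
/// ```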
pub fn apply_causal_mask(
    scores: &mut [f32],
    query_len: usize,
    key_len: usize,
) -> AttentionResult<()> {
    if scores.len() != query_len * key_len {
        return Err(AttentionError::InvalidMask {
            expected: format!("{}x{}", query_len, key_len),
            actual: format!("{}", scores.len()),
        });
    }

    for i in 0..query_len {
        for j in (i + 1)..key_len {
            scores[i * key_len + j] = f32::NEG_INFINITY;
        }
    }

    Ok(())
}

/// Computes the dot product of two vectors.
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
///
/// Dot product value
///
/// # Errors
///
/// Returns an error if the vectors have different lengths.
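///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::dot_product; // hypothetical crate/module path
///
/// let d = dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
/// assert_eq!(d, 32.0); // 1*4 + 2*5 + 3*6
/// ```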
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> AttentionResult<f32> {
    if a.len() != b.len() {
        return Err(AttentionError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        });
    }

    Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
}

/// Scales a vector by a scalar value.
///
/// # Arguments
///
/// * `vector` - Input vector (modified in place)
/// * `scale` - Scale factor
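///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::scale_vector; // hypothetical crate/module path
///
/// let mut v = vec![1.0f32, 2.0, 3.0];
/// scale_vector(&mut v, 2.0);
/// assert_eq!(v, vec![2.0, 4.0, 6.0]);
/// ```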
#[inline]
pub fn scale_vector(vector: &mut [f32], scale: f32) {
    vector.iter_mut().for_each(|x| *x *= scale);
}

/// Adds two vectors element-wise.
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
///
/// Sum vector
///
/// # Errors
///
/// Returns an error if the vectors have different lengths.
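///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::add_vectors; // hypothetical crate/module path
///
/// let s = add_vectors(&[1.0, 2.0], &[3.0, 4.0]).unwrap();
/// assert_eq!(s, vec![4.0, 6.0]);
/// ```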
#[inline]
pub fn add_vectors(a: &[f32], b: &[f32]) -> AttentionResult<Vec<f32>> {
    if a.len() != b.len() {
        return Err(AttentionError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        });
    }

    Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
}

/// Computes the L2 norm of a vector.
///
/// # Arguments
///
/// * `vector` - Input vector
///
/// # Returns
///
/// L2 norm value
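///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::l2_norm; // hypothetical crate/module path
///
/// assert_eq!(l2_norm(&[3.0, 4.0]), 5.0); // sqrt(9 + 16)
/// ```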
#[inline]
pub fn l2_norm(vector: &[f32]) -> f32 {
    vector.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Normalizes a vector to unit length.
///
/// # Arguments
///
/// * `vector` - Input vector (modified in place)
///
/// # Returns
///
/// Original norm before normalization
///
/// # Errors
///
/// Returns an error if the vector has a zero or non-finite norm.
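///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above):
///
/// ```ignore
/// use attention::utils::{l2_norm, normalize_vector}; // hypothetical path
///
/// let mut v = vec![3.0f32, 4.0];
/// let norm = normalize_vector(&mut v).unwrap();
/// assert_eq!(norm, 5.0);
/// assert!((l2_norm(&v) - 1.0).abs() < 1e-6);
/// ```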
pub fn normalize_vector(vector: &mut [f32]) -> AttentionResult<f32> {
    let norm = l2_norm(vector);

    if norm <= 0.0 || !norm.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "cannot normalize zero or non-finite vector".to_string(),
        ));
    }

    let inv_norm = 1.0 / norm;
    vector.iter_mut().for_each(|x| *x *= inv_norm);

    Ok(norm)
}

/// Applies inverted dropout to a vector during training.
///
/// Each element is zeroed with probability `dropout_prob`; surviving elements
/// are scaled by `1 / (1 - dropout_prob)` so the expected value is preserved.
/// This is a no-op when `training` is false or `dropout_prob` is zero.
///
/// # Arguments
///
/// * `vector` - Input vector (modified in place)
/// * `dropout_prob` - Dropout probability (0.0 to 1.0)
/// * `training` - Whether in training mode
/// * `rng` - Random number generator
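///
/// # Examples
///
/// A minimal sketch (hypothetical `attention::utils` path, as above; assumes
/// the rand 0.8-style `thread_rng`/`gen` API this function is written against):
///
/// ```ignore
/// use attention::utils::apply_dropout; // hypothetical crate/module path
///
/// let mut v = vec![1.0f32; 8];
/// let mut rng = rand::thread_rng();
/// apply_dropout(&mut v, 0.5, true, &mut rng);
/// // Every element is now either 0.0 or 2.0 (scaled by 1 / (1 - 0.5)).
/// assert!(v.iter().all(|&x| x == 0.0 || x == 2.0));
/// ```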
pub fn apply_dropout(
    vector: &mut [f32],
    dropout_prob: f32,
    training: bool,
    rng: &mut impl rand::Rng,
) {
    if !training || dropout_prob == 0.0 {
        return;
    }

    let scale = 1.0 / (1.0 - dropout_prob);
    for x in vector.iter_mut() {
        if rng.gen::<f32>() < dropout_prob {
            *x = 0.0;
        } else {
            *x *= scale;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values).unwrap();

        // Sum should be approximately 1.0
        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);

        // Values should be in ascending order
        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
    }

    #[test]
    fn test_softmax_numerical_stability() {
        let values = vec![1000.0, 1001.0, 1002.0];
        let result = softmax(&values).unwrap();

        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_masked_softmax() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![true, true, false, false];
        let result = masked_softmax(&values, Some(&mask)).unwrap();

        // Masked positions should be zero
        assert_relative_eq!(result[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(result[3], 0.0, epsilon = 1e-6);

        // Unmasked positions should sum to 1
        let sum: f32 = result[0] + result[1];
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = dot_product(&a, &b).unwrap();

        assert_relative_eq!(result, 32.0, epsilon = 1e-6);
    }

    #[test]
    fn test_scale_vector() {
        let mut vector = vec![1.0, 2.0, 3.0];
        scale_vector(&mut vector, 2.0);

        assert_relative_eq!(vector[0], 2.0);
        assert_relative_eq!(vector[1], 4.0);
        assert_relative_eq!(vector[2], 6.0);
    }

    #[test]
    fn test_normalize_vector() {
        let mut vector = vec![3.0, 4.0];
        let norm = normalize_vector(&mut vector).unwrap();

        assert_relative_eq!(norm, 5.0, epsilon = 1e-6);
        assert_relative_eq!(l2_norm(&vector), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_causal_mask() {
        let mut scores = vec![0.0; 9]; // 3x3 matrix
        apply_causal_mask(&mut scores, 3, 3).unwrap();

        // Check upper triangle is masked
        assert_eq!(scores[1], f32::NEG_INFINITY); // (0, 1)
        assert_eq!(scores[2], f32::NEG_INFINITY); // (0, 2)
        assert_eq!(scores[5], f32::NEG_INFINITY); // (1, 2)

        // Check diagonal and lower triangle are not masked
        assert_eq!(scores[0], 0.0); // (0, 0)
        assert_eq!(scores[4], 0.0); // (1, 1)
        assert_eq!(scores[8], 0.0); // (2, 2)
    }
}