Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
10
vendor/ruvector/crates/ruvector-attention/src/attention/mod.rs
vendored
Normal file
10
vendor/ruvector/crates/ruvector-attention/src/attention/mod.rs
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
//! Attention mechanism implementations.
|
||||
//!
|
||||
//! This module provides concrete implementations of various attention mechanisms
|
||||
//! including scaled dot-product attention and multi-head attention.
|
||||
|
||||
pub mod multi_head;
|
||||
pub mod scaled_dot_product;
|
||||
|
||||
pub use multi_head::MultiHeadAttention;
|
||||
pub use scaled_dot_product::ScaledDotProductAttention;
|
||||
149
vendor/ruvector/crates/ruvector-attention/src/attention/multi_head.rs
vendored
Normal file
149
vendor/ruvector/crates/ruvector-attention/src/attention/multi_head.rs
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
//! Multi-head attention implementation.
|
||||
//!
|
||||
//! Implements parallel attention heads for diverse representation learning.
|
||||
|
||||
use crate::{
|
||||
error::{AttentionError, AttentionResult},
|
||||
traits::Attention,
|
||||
};
|
||||
|
||||
use super::scaled_dot_product::ScaledDotProductAttention;
|
||||
|
||||
/// Multi-head attention mechanism.
///
/// Slices the query, key, and value vectors into `num_heads` equal
/// sub-vectors, runs scaled dot-product attention independently on each
/// slice, and concatenates the per-head results back into one output of
/// length `dim`. Attending in several subspaces at once lets the mechanism
/// capture different relationships simultaneously.
pub struct MultiHeadAttention {
    // Total embedding dimension of inputs and outputs.
    dim: usize,
    // Number of parallel attention heads.
    num_heads: usize,
    // Slice length handled by each head (`dim / num_heads`).
    head_dim: usize,
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
/// Creates a new multi-head attention mechanism.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The embedding dimension
|
||||
/// * `num_heads` - Number of attention heads
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `dim` is not divisible by `num_heads`.
|
||||
pub fn new(dim: usize, num_heads: usize) -> Self {
|
||||
assert!(
|
||||
dim % num_heads == 0,
|
||||
"Dimension {} must be divisible by number of heads {}",
|
||||
dim,
|
||||
num_heads
|
||||
);
|
||||
|
||||
Self {
|
||||
dim,
|
||||
num_heads,
|
||||
head_dim: dim / num_heads,
|
||||
}
|
||||
}
|
||||
|
||||
/// Splits input into multiple heads.
|
||||
fn split_heads(&self, input: &[f32]) -> Vec<Vec<f32>> {
|
||||
(0..self.num_heads)
|
||||
.map(|h| {
|
||||
let start = h * self.head_dim;
|
||||
let end = start + self.head_dim;
|
||||
input[start..end].to_vec()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Concatenates outputs from multiple heads.
|
||||
fn concat_heads(&self, heads: Vec<Vec<f32>>) -> Vec<f32> {
|
||||
heads.into_iter().flatten().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention for MultiHeadAttention {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if query.len() != self.dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.dim,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Split query into heads
|
||||
let query_heads = self.split_heads(query);
|
||||
|
||||
// Split keys and values
|
||||
let key_heads: Vec<Vec<Vec<f32>>> = keys.iter().map(|k| self.split_heads(k)).collect();
|
||||
|
||||
let value_heads: Vec<Vec<Vec<f32>>> = values.iter().map(|v| self.split_heads(v)).collect();
|
||||
|
||||
// Compute attention for each head
|
||||
let mut head_outputs = Vec::new();
|
||||
for h in 0..self.num_heads {
|
||||
let head_attn = ScaledDotProductAttention::new(self.head_dim);
|
||||
|
||||
let head_keys: Vec<&[f32]> = key_heads.iter().map(|kh| kh[h].as_slice()).collect();
|
||||
|
||||
let head_values: Vec<&[f32]> = value_heads.iter().map(|vh| vh[h].as_slice()).collect();
|
||||
|
||||
let head_out = head_attn.compute(&query_heads[h], &head_keys, &head_values)?;
|
||||
head_outputs.push(head_out);
|
||||
}
|
||||
|
||||
// Concatenate head outputs
|
||||
Ok(self.concat_heads(head_outputs))
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
_mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
// For simplicity, delegate to compute (mask handling can be added per-head)
|
||||
self.compute(query, keys, values)
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
|
||||
fn num_heads(&self) -> usize {
|
||||
self.num_heads
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Output length must equal the full embedding dimension.
    #[test]
    fn test_multi_head() {
        let attn = MultiHeadAttention::new(8, 2);
        let query = vec![1.0_f32; 8];
        let key_a = vec![0.5_f32; 8];
        let key_b = vec![0.3_f32; 8];
        let value_a = vec![1.0_f32; 8];
        let value_b = vec![2.0_f32; 8];

        let result = attn
            .compute(
                &query,
                &[key_a.as_slice(), key_b.as_slice()],
                &[value_a.as_slice(), value_b.as_slice()],
            )
            .unwrap();
        assert_eq!(result.len(), 8);
    }

    // Construction must reject dimensions that do not split evenly.
    #[test]
    #[should_panic(expected = "divisible")]
    fn test_invalid_heads() {
        let _ = MultiHeadAttention::new(10, 3);
    }
}
|
||||
180
vendor/ruvector/crates/ruvector-attention/src/attention/scaled_dot_product.rs
vendored
Normal file
180
vendor/ruvector/crates/ruvector-attention/src/attention/scaled_dot_product.rs
vendored
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Scaled dot-product attention implementation.
|
||||
//!
|
||||
//! Implements the fundamental attention mechanism: softmax(QK^T / √d)V
|
||||
|
||||
use crate::{
|
||||
error::{AttentionError, AttentionResult},
|
||||
traits::Attention,
|
||||
};
|
||||
|
||||
/// Scaled dot-product attention: softmax(QK^T / √d)V
///
/// The core attention primitive used in transformers. Each key is scored
/// by its dot product with the query, the scores are scaled down by √d to
/// keep the softmax well-conditioned, and the resulting probability
/// weights form a convex combination of the value vectors.
pub struct ScaledDotProductAttention {
    // Expected length of the query vector (and of each key).
    dim: usize,
}
|
||||
|
||||
impl ScaledDotProductAttention {
|
||||
/// Creates a new scaled dot-product attention mechanism.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The embedding dimension
|
||||
pub fn new(dim: usize) -> Self {
|
||||
Self { dim }
|
||||
}
|
||||
|
||||
/// Computes attention scores (before softmax).
|
||||
fn compute_scores(&self, query: &[f32], keys: &[&[f32]]) -> Vec<f32> {
|
||||
let scale = (self.dim as f32).sqrt();
|
||||
keys.iter()
|
||||
.map(|key| {
|
||||
query
|
||||
.iter()
|
||||
.zip(key.iter())
|
||||
.map(|(q, k)| q * k)
|
||||
.sum::<f32>()
|
||||
/ scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Applies softmax to attention scores.
|
||||
fn softmax(&self, scores: &[f32]) -> Vec<f32> {
|
||||
let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
|
||||
let sum: f32 = exp_scores.iter().sum();
|
||||
exp_scores.iter().map(|e| e / sum).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention for ScaledDotProductAttention {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if query.len() != self.dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.dim,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
if keys.is_empty() || values.is_empty() {
|
||||
return Err(AttentionError::EmptyInput("keys or values".to_string()));
|
||||
}
|
||||
|
||||
if keys.len() != values.len() {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: keys.len(),
|
||||
actual: values.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Compute attention scores
|
||||
let scores = self.compute_scores(query, keys);
|
||||
|
||||
// Apply softmax
|
||||
let weights = self.softmax(&scores);
|
||||
|
||||
// Weight values
|
||||
let mut output = vec![0.0; self.dim];
|
||||
for (weight, value) in weights.iter().zip(values.iter()) {
|
||||
for (out, val) in output.iter_mut().zip(value.iter()) {
|
||||
*out += weight * val;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if mask.is_none() {
|
||||
return self.compute(query, keys, values);
|
||||
}
|
||||
|
||||
let mask = mask.unwrap();
|
||||
if mask.len() != keys.len() {
|
||||
return Err(AttentionError::InvalidMask {
|
||||
expected: format!("{}", keys.len()),
|
||||
actual: format!("{}", mask.len()),
|
||||
});
|
||||
}
|
||||
|
||||
// Compute scores
|
||||
let mut scores = self.compute_scores(query, keys);
|
||||
|
||||
// Apply mask (set masked positions to very negative value)
|
||||
for (score, &m) in scores.iter_mut().zip(mask.iter()) {
|
||||
if !m {
|
||||
*score = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply softmax
|
||||
let weights = self.softmax(&scores);
|
||||
|
||||
// Weight values
|
||||
let mut output = vec![0.0; self.dim];
|
||||
for (weight, value) in weights.iter().zip(values.iter()) {
|
||||
for (out, val) in output.iter_mut().zip(value.iter()) {
|
||||
*out += weight * val;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Unmasked attention returns a vector of the configured dimension.
    #[test]
    fn test_scaled_dot_product() {
        let attn = ScaledDotProductAttention::new(4);
        let query = vec![1.0_f32, 0.0, 0.0, 0.0];
        let key_a = vec![1.0_f32, 0.0, 0.0, 0.0];
        let key_b = vec![0.0_f32, 1.0, 0.0, 0.0];
        let value_a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let value_b = vec![5.0_f32, 6.0, 7.0, 8.0];

        let result = attn
            .compute(
                &query,
                &[key_a.as_slice(), key_b.as_slice()],
                &[value_a.as_slice(), value_b.as_slice()],
            )
            .unwrap();
        assert_eq!(result.len(), 4);
    }

    // A partially-false mask still yields a full-dimension output.
    #[test]
    fn test_with_mask() {
        let attn = ScaledDotProductAttention::new(4);
        let query = vec![1.0_f32; 4];
        let key_a = vec![1.0_f32; 4];
        let key_b = vec![0.5_f32; 4];
        let value_a = vec![1.0_f32; 4];
        let value_b = vec![2.0_f32; 4];
        let mask = [true, false];

        let result = attn
            .compute_with_mask(
                &query,
                &[key_a.as_slice(), key_b.as_slice()],
                &[value_a.as_slice(), value_b.as_slice()],
                Some(&mask),
            )
            .unwrap();
        assert_eq!(result.len(), 4);
    }
}
|
||||
Reference in New Issue
Block a user