//! Configuration types for attention mechanisms.
//!
//! This module provides configuration structs and builders for various
//! attention mechanisms including standard, graph, and sparse attention.
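//!
//! # Examples
//!
//! A minimal sketch of the builder flow. The crate and module paths here
//! are assumptions, so the block is not compiled as a doc test:
//!
//! ```ignore
//! use attention::config::AttentionConfig;
//!
//! let config = AttentionConfig::builder()
//!     .dim(512)
//!     .num_heads(8)
//!     .dropout(0.1)
//!     .causal(true)
//!     .build()
//!     .unwrap();
//!
//! assert_eq!(config.head_dim(), 64);
//! assert_eq!(config.effective_scale(), 0.125);
//! ```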

use serde::{Deserialize, Serialize};

use crate::error::{AttentionError, AttentionResult};

/// Configuration for standard attention mechanisms.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Model dimension (d_model)
    pub dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dropout probability (0.0 to 1.0)
    pub dropout: f32,
    /// Scaling factor (default: 1/sqrt(d_k))
    pub scale: Option<f32>,
    /// Whether to use causal masking
    pub causal: bool,
}

impl AttentionConfig {
    /// Creates a new builder for AttentionConfig.
    pub fn builder() -> AttentionConfigBuilder {
        AttentionConfigBuilder::default()
    }

    /// Validates the configuration.
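    ///
    /// # Examples
    ///
    /// A sketch of a failing check (512 is not divisible by 7), mirroring
    /// `test_config_validation` at the bottom of this module:
    ///
    /// ```ignore
    /// let result = AttentionConfig::builder().dim(512).num_heads(7).build();
    /// assert!(result.is_err());
    /// ```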
    pub fn validate(&self) -> AttentionResult<()> {
        if self.dim == 0 {
            return Err(AttentionError::InvalidConfig(
                "dimension must be greater than 0".to_string(),
            ));
        }

        if self.num_heads == 0 {
            return Err(AttentionError::InvalidConfig(
                "num_heads must be greater than 0".to_string(),
            ));
        }

        if self.dim % self.num_heads != 0 {
            return Err(AttentionError::InvalidHeadCount {
                dim: self.dim,
                num_heads: self.num_heads,
            });
        }

        // The range check via `contains` also rejects NaN dropout values.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(AttentionError::InvalidConfig(
                "dropout must be in range [0.0, 1.0]".to_string(),
            ));
        }

        if let Some(scale) = self.scale {
            if !scale.is_finite() || scale <= 0.0 {
                return Err(AttentionError::InvalidConfig(
                    "scale must be positive and finite".to_string(),
                ));
            }
        }

        Ok(())
    }

    /// Returns the dimension per head (d_k).
    #[inline]
    pub fn head_dim(&self) -> usize {
        self.dim / self.num_heads
    }

    /// Returns the effective scale factor.
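    ///
    /// With `dim = 512` and `num_heads = 8`, `head_dim` is 64, so the
    /// default scale is `1.0 / 64.0_f32.sqrt() = 0.125`. A sketch:
    ///
    /// ```ignore
    /// let config = AttentionConfig::builder().dim(512).num_heads(8).build().unwrap();
    /// assert_eq!(config.effective_scale(), 0.125);
    /// ```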
    #[inline]
    pub fn effective_scale(&self) -> f32 {
        self.scale
            .unwrap_or_else(|| 1.0 / (self.head_dim() as f32).sqrt())
    }
}

/// Builder for AttentionConfig.
#[derive(Default)]
pub struct AttentionConfigBuilder {
    dim: Option<usize>,
    num_heads: Option<usize>,
    dropout: f32,
    scale: Option<f32>,
    causal: bool,
}

impl AttentionConfigBuilder {
    /// Sets the model dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.dim = Some(dim);
        self
    }

    /// Sets the number of attention heads.
    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.num_heads = Some(num_heads);
        self
    }

    /// Sets the dropout probability.
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Sets a custom scale factor.
    pub fn scale(mut self, scale: f32) -> Self {
        self.scale = Some(scale);
        self
    }

    /// Enables or disables causal masking.
    pub fn causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Builds the AttentionConfig.
    pub fn build(self) -> AttentionResult<AttentionConfig> {
        let config = AttentionConfig {
            dim: self.dim.ok_or_else(|| {
                AttentionError::InvalidConfig("dimension must be specified".to_string())
            })?,
            num_heads: self.num_heads.ok_or_else(|| {
                AttentionError::InvalidConfig("num_heads must be specified".to_string())
            })?,
            dropout: self.dropout,
            scale: self.scale,
            causal: self.causal,
        };

        config.validate()?;
        Ok(config)
    }
}

/// Configuration for graph attention networks.
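///
/// # Examples
///
/// A minimal sketch via the builder; it mirrors `test_graph_attention_config`
/// below, and relies on the unset negative slope defaulting to 0.2:
///
/// ```ignore
/// let config = GraphAttentionConfig::builder()
///     .dim(256)
///     .num_heads(4)
///     .edge_dim(16)
///     .concat_heads(true)
///     .build()
///     .unwrap();
///
/// assert_eq!(config.base.head_dim(), 64);
/// assert_eq!(config.negative_slope, 0.2);
/// ```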
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GraphAttentionConfig {
    /// Base attention configuration
    pub base: AttentionConfig,
    /// Edge feature dimension (if using edge features)
    pub edge_dim: Option<usize>,
    /// Negative slope for LeakyReLU
    pub negative_slope: f32,
    /// Whether to concatenate multi-head outputs (vs averaging)
    pub concat_heads: bool,
}

impl GraphAttentionConfig {
    /// Creates a new builder for GraphAttentionConfig.
    pub fn builder() -> GraphAttentionConfigBuilder {
        GraphAttentionConfigBuilder::default()
    }

    /// Validates the configuration.
    pub fn validate(&self) -> AttentionResult<()> {
        self.base.validate()?;

        if self.negative_slope <= 0.0 || !self.negative_slope.is_finite() {
            return Err(AttentionError::InvalidConfig(
                "negative_slope must be positive and finite".to_string(),
            ));
        }

        if let Some(edge_dim) = self.edge_dim {
            if edge_dim == 0 {
                return Err(AttentionError::InvalidConfig(
                    "edge_dim must be greater than 0".to_string(),
                ));
            }
        }

        Ok(())
    }
}

/// Builder for GraphAttentionConfig.
#[derive(Default)]
pub struct GraphAttentionConfigBuilder {
    base_builder: AttentionConfigBuilder,
    edge_dim: Option<usize>,
    // `None` means "use the default of 0.2"; an `Option` keeps an explicit
    // setting distinct from the unset state.
    negative_slope: Option<f32>,
    concat_heads: bool,
}

impl GraphAttentionConfigBuilder {
    /// Sets the model dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.base_builder = self.base_builder.dim(dim);
        self
    }

    /// Sets the number of attention heads.
    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.base_builder = self.base_builder.num_heads(num_heads);
        self
    }

    /// Sets the edge feature dimension.
    pub fn edge_dim(mut self, edge_dim: usize) -> Self {
        self.edge_dim = Some(edge_dim);
        self
    }

    /// Sets the negative slope for LeakyReLU (default: 0.2).
    pub fn negative_slope(mut self, slope: f32) -> Self {
        self.negative_slope = Some(slope);
        self
    }

    /// Sets whether to concatenate multi-head outputs.
    pub fn concat_heads(mut self, concat: bool) -> Self {
        self.concat_heads = concat;
        self
    }

    /// Builds the GraphAttentionConfig.
    pub fn build(self) -> AttentionResult<GraphAttentionConfig> {
        let config = GraphAttentionConfig {
            base: self.base_builder.build()?,
            edge_dim: self.edge_dim,
            // Fall back to the conventional GAT default of 0.2 when unset.
            negative_slope: self.negative_slope.unwrap_or(0.2),
            concat_heads: self.concat_heads,
        };

        config.validate()?;
        Ok(config)
    }
}

/// Configuration for sparse attention mechanisms.
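///
/// # Examples
///
/// A minimal sketch via the builder; it mirrors `test_sparse_attention_config`
/// below, and relies on the unset block size defaulting to 64:
///
/// ```ignore
/// let config = SparseAttentionConfig::builder()
///     .dim(512)
///     .num_heads(8)
///     .num_random_blocks(3)
///     .num_global_tokens(64)
///     .build()
///     .unwrap();
///
/// assert_eq!(config.block_size, 64);
/// ```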
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SparseAttentionConfig {
    /// Base attention configuration
    pub base: AttentionConfig,
    /// Block size for block-sparse attention
    pub block_size: usize,
    /// Number of random blocks per query
    pub num_random_blocks: usize,
    /// Number of global tokens
    pub num_global_tokens: usize,
}

impl SparseAttentionConfig {
    /// Creates a new builder for SparseAttentionConfig.
    pub fn builder() -> SparseAttentionConfigBuilder {
        SparseAttentionConfigBuilder::default()
    }

    /// Validates the configuration.
    pub fn validate(&self) -> AttentionResult<()> {
        self.base.validate()?;

        if self.block_size == 0 {
            return Err(AttentionError::InvalidConfig(
                "block_size must be greater than 0".to_string(),
            ));
        }

        Ok(())
    }
}

/// Builder for SparseAttentionConfig.
#[derive(Default)]
pub struct SparseAttentionConfigBuilder {
    base_builder: AttentionConfigBuilder,
    // `None` means "use the default of 64"; an `Option` keeps an explicit
    // setting distinct from the unset state.
    block_size: Option<usize>,
    num_random_blocks: usize,
    num_global_tokens: usize,
}

impl SparseAttentionConfigBuilder {
    /// Sets the model dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.base_builder = self.base_builder.dim(dim);
        self
    }

    /// Sets the number of attention heads.
    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.base_builder = self.base_builder.num_heads(num_heads);
        self
    }

    /// Sets the block size (default: 64).
    pub fn block_size(mut self, block_size: usize) -> Self {
        self.block_size = Some(block_size);
        self
    }

    /// Sets the number of random blocks.
    pub fn num_random_blocks(mut self, num_random_blocks: usize) -> Self {
        self.num_random_blocks = num_random_blocks;
        self
    }

    /// Sets the number of global tokens.
    pub fn num_global_tokens(mut self, num_global_tokens: usize) -> Self {
        self.num_global_tokens = num_global_tokens;
        self
    }

    /// Builds the SparseAttentionConfig.
    pub fn build(self) -> AttentionResult<SparseAttentionConfig> {
        let config = SparseAttentionConfig {
            base: self.base_builder.build()?,
            // Fall back to a block size of 64 when unset.
            block_size: self.block_size.unwrap_or(64),
            num_random_blocks: self.num_random_blocks,
            num_global_tokens: self.num_global_tokens,
        };

        config.validate()?;
        Ok(config)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_builder() {
        let config = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .dropout(0.1)
            .causal(true)
            .build()
            .unwrap();

        assert_eq!(config.dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.dropout, 0.1);
        assert!(config.causal);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_config_validation() {
        let result = AttentionConfig::builder()
            .dim(512)
            .num_heads(7) // 512 is not divisible by 7
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn test_graph_attention_config() {
        let config = GraphAttentionConfig::builder()
            .dim(256)
            .num_heads(4)
            .edge_dim(16)
            .negative_slope(0.2)
            .concat_heads(true)
            .build()
            .unwrap();

        assert_eq!(config.base.dim, 256);
        assert_eq!(config.edge_dim, Some(16));
        assert!(config.concat_heads);
    }

    #[test]
    fn test_sparse_attention_config() {
        let config = SparseAttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .block_size(64)
            .num_random_blocks(3)
            .num_global_tokens(64)
            .build()
            .unwrap();

        assert_eq!(config.base.dim, 512);
        assert_eq!(config.block_size, 64);
        assert_eq!(config.num_random_blocks, 3);
    }
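
    // A small companion to the doc examples above: when left unset, the
    // graph builder's negative slope falls back to 0.2 and the sparse
    // builder's block size to 64.
    #[test]
    fn test_builder_defaults() {
        let graph = GraphAttentionConfig::builder()
            .dim(256)
            .num_heads(4)
            .build()
            .unwrap();
        assert_eq!(graph.negative_slope, 0.2);

        let sparse = SparseAttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .build()
            .unwrap();
        assert_eq!(sparse.block_size, 64);
    }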
}