git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
33 KiB
33 KiB
Agent 5: Mixture of Experts (MoE) Adaptive Attention
Overview
This implementation provides a flexible Mixture of Experts attention mechanism that dynamically routes queries to specialized attention experts based on learned gating functions. The system supports multiple expert types (standard multi-head, hyperbolic, linear, edge-featured) with load balancing and efficient top-k routing.
Architecture Components
1. Expert Types Enum
use ndarray::{Array2, Array3, ArrayView2, ArrayView3};
use std::sync::Arc;
/// Types of attention experts available in the mixture.
///
/// Each variant selects one concrete `AttentionExpert` implementation
/// (see `create_expert`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpertType {
    /// Standard multi-head attention
    Standard,
    /// Hyperbolic geometry attention
    Hyperbolic,
    /// Linear complexity attention (e.g., Performer)
    Linear,
    /// Edge-featured attention for graphs
    EdgeFeatured,
}
impl ExpertType {
    /// Stable snake_case identifier for this expert type, suitable for
    /// logging and configuration keys.
    ///
    /// Returns `&'static str` instead of a borrow tied to `&self`
    /// (backward compatible — callers expecting `&str` still work),
    /// so the name can be stored or returned freely.
    pub fn name(&self) -> &'static str {
        match self {
            ExpertType::Standard => "standard",
            ExpertType::Hyperbolic => "hyperbolic",
            ExpertType::Linear => "linear",
            ExpertType::EdgeFeatured => "edge_featured",
        }
    }
}
2. Attention Expert Trait
/// Trait that all attention experts must implement.
///
/// Implementations must be `Send + Sync` so the expert pool
/// (`Vec<Box<dyn AttentionExpert>>`) can be shared across threads.
pub trait AttentionExpert: Send + Sync {
    /// Forward pass through the expert
    ///
    /// # Arguments
    /// * `queries` - Query embeddings [batch, seq_len, dim]
    /// * `keys` - Key embeddings [batch, seq_len, dim]
    /// * `values` - Value embeddings [batch, seq_len, dim]
    /// * `edge_features` - Optional edge features for graph attention
    ///
    /// # Returns
    /// * Attended output [batch, seq_len, dim]
    /// * Attention weights. NOTE(review): the original doc claimed a 4-D
    ///   shape [batch, num_heads, seq_len, seq_len], but `Array3` holds
    ///   exactly 3 axes; the implementations in this file return
    ///   [batch, num_heads, seq_len]. Confirm the intended convention.
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        keys: ArrayView3<f32>,
        values: ArrayView3<f32>,
        edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>);
    /// Get the expert type
    fn expert_type(&self) -> ExpertType;
    /// Get output dimension (embedding size of the returned output)
    fn output_dim(&self) -> usize;
    /// Get number of learnable parameters
    fn num_parameters(&self) -> usize;
}
3. Learned Routing Network
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
/// Learned routing network for expert selection.
///
/// A two-layer MLP (ReLU hidden layer) mapping each query vector to one
/// logit per expert; logits are divided by `temperature` before softmax.
pub struct LearnedRouter {
    /// Input dimension
    input_dim: usize,
    /// Number of experts
    num_experts: usize,
    /// Hidden dimension for routing network
    hidden_dim: usize,
    /// Temperature for softmax (higher = more uniform routing)
    temperature: f32,
    // Network parameters
    /// First layer weights [input_dim, hidden_dim]
    w1: Array2<f32>,
    /// First layer bias [hidden_dim]
    b1: Array1<f32>,
    /// Second layer weights [hidden_dim, num_experts]
    w2: Array2<f32>,
    /// Second layer bias [num_experts]
    b2: Array1<f32>,
    /// Load balancing coefficient applied to the auxiliary loss
    load_balance_loss_coef: f32,
}
impl LearnedRouter {
    /// Create a router with Xavier/Glorot-uniform initialized weights and
    /// zero biases.
    ///
    /// # Panics
    /// If `temperature` is not strictly positive (logits are divided by it).
    pub fn new(
        input_dim: usize,
        num_experts: usize,
        hidden_dim: usize,
        temperature: f32,
    ) -> Self {
        assert!(temperature > 0.0, "temperature must be positive");
        // Xavier/Glorot uniform: limit = sqrt(6 / (fan_in + fan_out))
        let limit1 = (6.0 / (input_dim + hidden_dim) as f32).sqrt();
        let limit2 = (6.0 / (hidden_dim + num_experts) as f32).sqrt();
        Self {
            input_dim,
            num_experts,
            hidden_dim,
            temperature,
            w1: Array2::random((input_dim, hidden_dim), Uniform::new(-limit1, limit1)),
            b1: Array1::zeros(hidden_dim),
            w2: Array2::random((hidden_dim, num_experts), Uniform::new(-limit2, limit2)),
            b2: Array1::zeros(num_experts),
            load_balance_loss_coef: 0.01,
        }
    }
    /// Compute routing logits for each query.
    ///
    /// # Arguments
    /// * `queries` - Query embeddings [batch, seq_len, input_dim]
    ///
    /// # Returns
    /// * Routing logits [batch, seq_len, num_experts]
    pub fn route(&self, queries: ArrayView3<f32>) -> Array3<f32> {
        let (batch_size, seq_len, _) = queries.dim();
        // Flatten to [batch * seq_len, input_dim] for the matrix multiplies.
        let queries_flat = queries
            .to_owned()
            .into_shape((batch_size * seq_len, self.input_dim))
            .unwrap();
        // Layer 1 + ReLU: [batch * seq_len, hidden_dim]
        let hidden = (queries_flat.dot(&self.w1) + &self.b1).mapv(|x| x.max(0.0));
        // Layer 2 with temperature scaling: [batch * seq_len, num_experts]
        let logits = (hidden.dot(&self.w2) + &self.b2) / self.temperature;
        // Restore [batch, seq_len, num_experts]
        logits.into_shape((batch_size, seq_len, self.num_experts)).unwrap()
    }
    /// Compute top-k gating with softmax normalization over the selected experts.
    ///
    /// # Arguments
    /// * `logits` - Routing logits [batch, seq_len, num_experts]
    /// * `k` - Number of experts to select per token
    ///
    /// # Returns
    /// * Gating weights [batch, seq_len, num_experts] (sparse; only top-k nonzero)
    /// * Expert indices [batch, seq_len, k], ordered by descending logit
    ///
    /// # Panics
    /// If `k > num_experts`.
    pub fn top_k_gating(&self, logits: ArrayView3<f32>, k: usize) -> (Array3<f32>, Array3<usize>) {
        let (batch_size, seq_len, num_experts) = logits.dim();
        assert!(k <= num_experts, "k must be <= num_experts");
        let mut gates = Array3::<f32>::zeros((batch_size, seq_len, num_experts));
        let mut indices = Array3::<usize>::zeros((batch_size, seq_len, k));
        for b in 0..batch_size {
            for t in 0..seq_len {
                let row = logits.slice(s![b, t, ..]);
                // Rank experts by logit, descending. `total_cmp` is NaN-safe;
                // the previous `partial_cmp(..).unwrap()` panicked on NaN logits.
                let mut indexed: Vec<(usize, f32)> =
                    row.iter().enumerate().map(|(i, &v)| (i, v)).collect();
                indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
                // Numerically-stable softmax restricted to the top-k logits.
                let max_logit = indexed[..k]
                    .iter()
                    .map(|&(_, v)| v)
                    .fold(f32::NEG_INFINITY, f32::max);
                let exp_sum: f32 = indexed[..k]
                    .iter()
                    .map(|&(_, v)| (v - max_logit).exp())
                    .sum();
                for (rank, &(expert_idx, logit)) in indexed[..k].iter().enumerate() {
                    gates[[b, t, expert_idx]] = (logit - max_logit).exp() / exp_sum;
                    indices[[b, t, rank]] = expert_idx;
                }
            }
        }
        (gates, indices)
    }
    /// Compute the load-balancing auxiliary loss (Switch-Transformer style).
    ///
    /// `usage[e]` is the fraction of tokens *dispatched* to expert `e`
    /// (gate > 0) and `importance[e]` is the mean gate mass it receives.
    /// The loss is `num_experts * Σ usage[e] * importance[e]`, scaled by the
    /// coefficient; it is minimized when routing is uniform. The previous
    /// version computed both factors as the identical mean-gate statistic,
    /// degenerating into a sum of squares that ignored hard dispatch counts.
    ///
    /// # Arguments
    /// * `gates` - Gating weights [batch, seq_len, num_experts]
    ///
    /// # Returns
    /// * Load balancing auxiliary loss (scalar; 0.0 for empty input)
    pub fn load_balancing_loss(&self, gates: ArrayView3<f32>) -> f32 {
        let (batch_size, seq_len, num_experts) = gates.dim();
        let total_tokens = (batch_size * seq_len) as f32;
        if total_tokens == 0.0 {
            return 0.0; // avoid 0/0 -> NaN on empty batches
        }
        let mut loss = 0.0f32;
        for e in 0..num_experts {
            let expert_gates = gates.slice(s![.., .., e]);
            // Hard dispatch fraction: tokens whose top-k includes this expert.
            let usage =
                expert_gates.iter().filter(|&&g| g > 0.0).count() as f32 / total_tokens;
            // Soft importance: mean gate mass routed to this expert.
            let importance = expert_gates.sum() / total_tokens;
            loss += usage * importance;
        }
        loss * num_experts as f32 * self.load_balance_loss_coef
    }
}
use ndarray::Array1;
use ndarray::s;
4. MoE Attention Configuration
/// Configuration for MoE Attention.
#[derive(Debug, Clone)]
pub struct MoEAttentionConfig {
    /// Input/output dimension
    pub dim: usize,
    /// Number of attention heads per expert
    pub num_heads: usize,
    /// Number of experts in the mixture (must equal `expert_types.len()`)
    pub num_experts: usize,
    /// Number of experts to activate per query (top-k; must be <= num_experts)
    pub top_k: usize,
    /// Expert types to include, one entry per expert
    pub expert_types: Vec<ExpertType>,
    /// Hidden dimension for routing network
    pub router_hidden_dim: usize,
    /// Temperature for routing softmax (higher = more uniform routing)
    pub router_temperature: f32,
    /// Load balancing loss coefficient
    pub load_balance_coef: f32,
    /// Dropout rate passed through to each expert
    pub dropout: f32,
}
impl Default for MoEAttentionConfig {
    /// Sensible defaults: 512-dim embeddings, 8 heads, one expert of each
    /// type with top-2 routing.
    fn default() -> Self {
        let expert_types = vec![
            ExpertType::Standard,
            ExpertType::Hyperbolic,
            ExpertType::Linear,
            ExpertType::EdgeFeatured,
        ];
        Self {
            dim: 512,
            num_heads: 8,
            num_experts: expert_types.len(),
            top_k: 2,
            expert_types,
            router_hidden_dim: 256,
            router_temperature: 1.0,
            load_balance_coef: 0.01,
            dropout: 0.1,
        }
    }
}
impl MoEAttentionConfig {
    /// Start building a config; any field left unset falls back to its
    /// `Default` value at `build` time.
    pub fn builder() -> MoEAttentionConfigBuilder {
        MoEAttentionConfigBuilder::default()
    }
}
/// Builder for MoEAttentionConfig.
///
/// Each `None` field falls back to the corresponding
/// `MoEAttentionConfig::default()` value when `build` is called.
#[derive(Default)]
pub struct MoEAttentionConfigBuilder {
    dim: Option<usize>,
    num_heads: Option<usize>,
    num_experts: Option<usize>,
    top_k: Option<usize>,
    expert_types: Option<Vec<ExpertType>>,
    router_hidden_dim: Option<usize>,
    router_temperature: Option<f32>,
    load_balance_coef: Option<f32>,
    dropout: Option<f32>,
}
impl MoEAttentionConfigBuilder {
    /// Set the input/output dimension.
    pub fn dim(self, dim: usize) -> Self {
        Self { dim: Some(dim), ..self }
    }
    /// Set the number of attention heads per expert.
    pub fn num_heads(self, num_heads: usize) -> Self {
        Self { num_heads: Some(num_heads), ..self }
    }
    /// Set the number of experts in the mixture.
    pub fn num_experts(self, num_experts: usize) -> Self {
        Self { num_experts: Some(num_experts), ..self }
    }
    /// Set how many experts each query activates.
    pub fn top_k(self, top_k: usize) -> Self {
        Self { top_k: Some(top_k), ..self }
    }
    /// Set the expert types to instantiate.
    pub fn expert_types(self, expert_types: Vec<ExpertType>) -> Self {
        Self { expert_types: Some(expert_types), ..self }
    }
    /// Set the router MLP's hidden dimension.
    pub fn router_hidden_dim(self, dim: usize) -> Self {
        Self { router_hidden_dim: Some(dim), ..self }
    }
    /// Set the routing softmax temperature.
    pub fn router_temperature(self, temp: f32) -> Self {
        Self { router_temperature: Some(temp), ..self }
    }
    /// Set the load-balancing loss coefficient.
    pub fn load_balance_coef(self, coef: f32) -> Self {
        Self { load_balance_coef: Some(coef), ..self }
    }
    /// Set the dropout rate.
    pub fn dropout(self, dropout: f32) -> Self {
        Self { dropout: Some(dropout), ..self }
    }
    /// Finalize the configuration, filling every unset field from `Default`.
    pub fn build(self) -> MoEAttentionConfig {
        let MoEAttentionConfig {
            dim,
            num_heads,
            num_experts,
            top_k,
            expert_types,
            router_hidden_dim,
            router_temperature,
            load_balance_coef,
            dropout,
        } = MoEAttentionConfig::default();
        MoEAttentionConfig {
            dim: self.dim.unwrap_or(dim),
            num_heads: self.num_heads.unwrap_or(num_heads),
            num_experts: self.num_experts.unwrap_or(num_experts),
            top_k: self.top_k.unwrap_or(top_k),
            expert_types: self.expert_types.unwrap_or(expert_types),
            router_hidden_dim: self.router_hidden_dim.unwrap_or(router_hidden_dim),
            router_temperature: self.router_temperature.unwrap_or(router_temperature),
            load_balance_coef: self.load_balance_coef.unwrap_or(load_balance_coef),
            dropout: self.dropout.unwrap_or(dropout),
        }
    }
}
5. MoE Attention Implementation
/// Main Mixture of Experts Attention module.
///
/// Routes each token to its top-k experts via a learned router, mixes the
/// gate-weighted expert outputs, and applies a final linear projection.
pub struct MoEAttention {
    config: MoEAttentionConfig,
    /// Routing network
    router: LearnedRouter,
    /// Pool of attention experts (one per `config.expert_types` entry)
    experts: Vec<Box<dyn AttentionExpert>>,
    /// Output projection [dim, dim]
    output_projection: Array2<f32>,
    /// Output projection bias [dim]
    output_bias: Array1<f32>,
    /// Training mode flag; the aux loss is only recomputed while true
    training: bool,
    /// Cached auxiliary load-balancing loss from the last training forward
    aux_loss: f32,
}
impl MoEAttention {
    /// Build a MoE attention module from `config`.
    ///
    /// # Panics
    /// If `expert_types.len() != num_experts` or `top_k > num_experts`.
    pub fn new(config: MoEAttentionConfig) -> Self {
        assert_eq!(
            config.expert_types.len(),
            config.num_experts,
            "Number of expert types must match num_experts"
        );
        assert!(
            config.top_k <= config.num_experts,
            "top_k must be <= num_experts"
        );
        // Routing network maps each query vector to per-expert logits.
        let router = LearnedRouter::new(
            config.dim,
            config.num_experts,
            config.router_hidden_dim,
            config.router_temperature,
        );
        // One expert instance per configured type (placeholder implementations).
        let experts: Vec<Box<dyn AttentionExpert>> = config
            .expert_types
            .iter()
            .map(|&expert_type| {
                create_expert(expert_type, config.dim, config.num_heads, config.dropout)
            })
            .collect();
        // Output projection, Xavier/Glorot-uniform initialized.
        let limit = (6.0 / (config.dim + config.dim) as f32).sqrt();
        let output_projection = Array2::random(
            (config.dim, config.dim),
            Uniform::new(-limit, limit),
        );
        let output_bias = Array1::zeros(config.dim);
        Self {
            config,
            router,
            experts,
            output_projection,
            output_bias,
            training: true,
            aux_loss: 0.0,
        }
    }
    /// Forward pass through MoE attention: route every token to its top-k
    /// experts, run those experts, gate-weight and sum their outputs, then
    /// apply the output projection.
    ///
    /// # Arguments
    /// * `queries` - Query embeddings [batch, seq_len, dim]
    /// * `keys` - Key embeddings [batch, seq_len, dim]
    /// * `values` - Value embeddings [batch, seq_len, dim]
    /// * `edge_features` - Optional edge features, forwarded to each expert
    ///
    /// # Returns
    /// * `MoEAttentionOutput` with the projected output, per-expert attention
    ///   weights (only experts that received tokens), gates, indices, and the
    ///   auxiliary loss (0.0 in eval mode)
    ///
    /// # Panics
    /// If the last axis of `queries` differs from `config.dim`.
    pub fn forward(
        &mut self,
        queries: ArrayView3<f32>,
        keys: ArrayView3<f32>,
        values: ArrayView3<f32>,
        edge_features: Option<ArrayView3<f32>>,
    ) -> MoEAttentionOutput {
        let (batch_size, seq_len, dim) = queries.dim();
        assert_eq!(dim, self.config.dim);
        // 1. Per-token routing logits [batch, seq_len, num_experts].
        let routing_logits = self.router.route(queries);
        // 2. Sparse top-k gates and the selected expert indices.
        let (gates, expert_indices) = self.router.top_k_gating(
            routing_logits.view(),
            self.config.top_k,
        );
        // 3. Load-balancing loss is a training-only quantity; reset it in
        //    eval mode so a stale value from a previous training step is
        //    never reported (the old code left it untouched in eval).
        self.aux_loss = if self.training {
            self.router.load_balancing_loss(gates.view())
        } else {
            0.0
        };
        // 4. Accumulator for the gate-weighted mixture of expert outputs.
        let mut output = Array3::<f32>::zeros((batch_size, seq_len, dim));
        // May end up shorter than num_experts: experts with no assigned
        // tokens are skipped. Entries appear in ascending expert order.
        let mut all_attention_weights = Vec::with_capacity(self.config.num_experts);
        // 5. Run each expert that received at least one token.
        for expert_idx in 0..self.config.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = gates.slice(s![.., .., expert_idx]);
            // Skip experts with all-zero gates.
            if expert_mask.iter().all(|&x| x == 0.0) {
                continue;
            }
            // Runs on every token for simplicity; a production version would
            // gather only the tokens routed to this expert.
            let (expert_output, expert_attn) = expert.forward(
                queries,
                keys,
                values,
                edge_features,
            );
            // Accumulate gate * expert_output in place (scaled_add avoids the
            // per-token temporary array the old code allocated).
            for b in 0..batch_size {
                for t in 0..seq_len {
                    let gate = expert_mask[[b, t]];
                    if gate > 0.0 {
                        output
                            .slice_mut(s![b, t, ..])
                            .scaled_add(gate, &expert_output.slice(s![b, t, ..]));
                    }
                }
            }
            all_attention_weights.push(expert_attn);
        }
        // 6. Final linear projection over the mixed output.
        //    (`output` is owned, so reshape it directly — the old
        //    `.to_owned()` cloned the whole tensor for nothing.)
        let output_flat = output
            .into_shape((batch_size * seq_len, dim))
            .unwrap();
        let projected = output_flat.dot(&self.output_projection) + &self.output_bias;
        let output = projected.into_shape((batch_size, seq_len, dim)).unwrap();
        MoEAttentionOutput {
            output,
            attention_weights: all_attention_weights,
            gates,
            expert_indices,
            aux_loss: self.aux_loss,
        }
    }
    /// Enable training mode (aux loss is computed during forward).
    pub fn train(&mut self) {
        self.training = true;
    }
    /// Enable evaluation mode (aux loss is reported as 0).
    pub fn eval(&mut self) {
        self.training = false;
    }
    /// Last computed auxiliary load-balancing loss (for backprop).
    pub fn get_aux_loss(&self) -> f32 {
        self.aux_loss
    }
    /// Total number of learnable parameters: router + experts + projection.
    pub fn num_parameters(&self) -> usize {
        let router_params = self.router.w1.len() + self.router.b1.len()
            + self.router.w2.len() + self.router.b2.len();
        let expert_params: usize = self.experts.iter().map(|e| e.num_parameters()).sum();
        let output_params = self.output_projection.len() + self.output_bias.len();
        router_params + expert_params + output_params
    }
}
/// Output from MoE attention forward pass.
pub struct MoEAttentionOutput {
    /// Main output [batch, seq_len, dim]
    pub output: Array3<f32>,
    /// Attention weights from each expert that received at least one token;
    /// experts with all-zero gates are skipped, so this can be shorter than
    /// num_experts (entries appear in ascending expert order)
    pub attention_weights: Vec<Array3<f32>>,
    /// Gating weights [batch, seq_len, num_experts] (sparse; only top-k nonzero)
    pub gates: Array3<f32>,
    /// Top-k expert indices [batch, seq_len, k]
    pub expert_indices: Array3<usize>,
    /// Auxiliary load balancing loss (recomputed only in training mode)
    pub aux_loss: f32,
}
6. Expert Factory and Implementations
/// Factory function to create experts
fn create_expert(
expert_type: ExpertType,
dim: usize,
num_heads: usize,
dropout: f32,
) -> Box<dyn AttentionExpert> {
match expert_type {
ExpertType::Standard => Box::new(StandardAttentionExpert::new(dim, num_heads, dropout)),
ExpertType::Hyperbolic => Box::new(HyperbolicAttentionExpert::new(dim, num_heads, dropout)),
ExpertType::Linear => Box::new(LinearAttentionExpert::new(dim, num_heads, dropout)),
ExpertType::EdgeFeatured => Box::new(EdgeAttentionExpert::new(dim, num_heads, dropout)),
}
}
// Placeholder implementations (would be fully implemented in actual code)
/// Standard multi-head attention expert.
pub struct StandardAttentionExpert {
    /// Model/embedding dimension
    dim: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Per-head dimension (dim / num_heads)
    head_dim: usize,
    /// Dropout rate (not applied by the placeholder forward pass)
    dropout: f32,
    // Projection matrices
    /// Query projection [dim, dim]
    wq: Array2<f32>,
    /// Key projection [dim, dim]
    wk: Array2<f32>,
    /// Value projection [dim, dim]
    wv: Array2<f32>,
}
impl StandardAttentionExpert {
    /// Create a standard multi-head attention expert with Xavier-initialized
    /// Q/K/V projections.
    ///
    /// # Panics
    /// If `num_heads` is 0 or `dim` is not divisible by `num_heads` — the
    /// old code silently truncated `head_dim` in that case.
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        assert!(
            num_heads > 0 && dim % num_heads == 0,
            "dim must be divisible by num_heads"
        );
        let head_dim = dim / num_heads;
        // Xavier/Glorot uniform for square [dim, dim] projections.
        let limit = (6.0 / (dim + dim) as f32).sqrt();
        Self {
            dim,
            num_heads,
            head_dim,
            dropout,
            wq: Array2::random((dim, dim), Uniform::new(-limit, limit)),
            wk: Array2::random((dim, dim), Uniform::new(-limit, limit)),
            wv: Array2::random((dim, dim), Uniform::new(-limit, limit)),
        }
    }
}
impl AttentionExpert for StandardAttentionExpert {
    /// Placeholder forward pass: projects Q/K/V but returns the projected
    /// values directly rather than computing real scaled-dot-product
    /// attention.
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        keys: ArrayView3<f32>,
        values: ArrayView3<f32>,
        _edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        let (batch_size, seq_len, _) = queries.dim();
        // Flatten [batch, seq, dim] -> [batch*seq, dim] for the projections.
        let q_flat = queries.to_owned().into_shape((batch_size * seq_len, self.dim)).unwrap();
        let k_flat = keys.to_owned().into_shape((batch_size * seq_len, self.dim)).unwrap();
        let v_flat = values.to_owned().into_shape((batch_size * seq_len, self.dim)).unwrap();
        // Q/K projections are computed but unused until real attention is
        // implemented; underscore-prefixed to document that intentionally.
        let _q_proj = q_flat.dot(&self.wq);
        let _k_proj = k_flat.dot(&self.wk);
        let v_proj = v_flat.dot(&self.wv);
        let output = v_proj.into_shape((batch_size, seq_len, self.dim)).unwrap();
        // Placeholder weights. The old code built a 4-D shape
        // (batch, heads, seq, seq) for an `Array3` and then reshaped it into
        // a smaller 3-D array — a type/size mismatch that could not work.
        // Use the same [batch, num_heads, seq_len] convention as the other
        // experts in this file.
        let attn_weights = Array3::ones((batch_size, self.num_heads, seq_len));
        (output, attn_weights)
    }
    fn expert_type(&self) -> ExpertType {
        ExpertType::Standard
    }
    /// Output embedding dimension (same as input dimension).
    fn output_dim(&self) -> usize {
        self.dim
    }
    /// Parameters in the Q, K, and V projection matrices.
    fn num_parameters(&self) -> usize {
        self.wq.len() + self.wk.len() + self.wv.len()
    }
}
/// Hyperbolic attention expert.
///
/// Placeholder: `curvature` is stored but not yet used by `forward`.
pub struct HyperbolicAttentionExpert {
    /// Model/embedding dimension
    dim: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Dropout rate (unused by the placeholder)
    dropout: f32,
    /// Curvature of the hyperbolic space
    curvature: f32,
}
impl HyperbolicAttentionExpert {
    /// Create a hyperbolic expert with curvature fixed at -1.0.
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        let curvature = -1.0;
        Self { dim, num_heads, dropout, curvature }
    }
}
impl AttentionExpert for HyperbolicAttentionExpert {
    /// Placeholder: returns all-zero outputs and weights until
    /// hyperbolic-geometry attention is implemented.
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        _keys: ArrayView3<f32>,
        _values: ArrayView3<f32>,
        _edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        let (batch_size, seq_len, _) = queries.dim();
        let output = Array3::zeros((batch_size, seq_len, self.dim));
        let weights = Array3::zeros((batch_size, self.num_heads, seq_len));
        (output, weights)
    }
    fn expert_type(&self) -> ExpertType {
        ExpertType::Hyperbolic
    }
    fn output_dim(&self) -> usize {
        self.dim
    }
    /// Placeholder: no learnable parameters yet.
    fn num_parameters(&self) -> usize {
        0
    }
}
/// Linear complexity attention expert (e.g., Performer-style).
pub struct LinearAttentionExpert {
    /// Model/embedding dimension
    dim: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Dropout rate (unused by the placeholder)
    dropout: f32,
}
impl LinearAttentionExpert {
    /// Create a linear-attention expert.
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        Self {
            dim,
            num_heads,
            dropout,
        }
    }
}
impl AttentionExpert for LinearAttentionExpert {
    /// Placeholder: returns all-zero outputs and weights until kernel-based
    /// linear attention is implemented.
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        _keys: ArrayView3<f32>,
        _values: ArrayView3<f32>,
        _edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        let (batch_size, seq_len, _) = queries.dim();
        let output = Array3::zeros((batch_size, seq_len, self.dim));
        let weights = Array3::zeros((batch_size, self.num_heads, seq_len));
        (output, weights)
    }
    fn expert_type(&self) -> ExpertType {
        ExpertType::Linear
    }
    fn output_dim(&self) -> usize {
        self.dim
    }
    /// Placeholder: no learnable parameters yet.
    fn num_parameters(&self) -> usize {
        0
    }
}
/// Edge-featured attention expert for graphs.
pub struct EdgeAttentionExpert {
    /// Model/embedding dimension
    dim: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Dropout rate (unused by the placeholder)
    dropout: f32,
}
impl EdgeAttentionExpert {
    /// Create an edge-featured attention expert.
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        Self {
            dim,
            num_heads,
            dropout,
        }
    }
}
impl AttentionExpert for EdgeAttentionExpert {
    /// Placeholder: accepts `edge_features` but ignores them for now,
    /// returning all-zero outputs and weights.
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        _keys: ArrayView3<f32>,
        _values: ArrayView3<f32>,
        edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        let (batch_size, seq_len, _) = queries.dim();
        // Edge features would condition the attention in a full implementation.
        let _ = edge_features;
        let output = Array3::zeros((batch_size, seq_len, self.dim));
        let weights = Array3::zeros((batch_size, self.num_heads, seq_len));
        (output, weights)
    }
    fn expert_type(&self) -> ExpertType {
        ExpertType::EdgeFeatured
    }
    fn output_dim(&self) -> usize {
        self.dim
    }
    /// Placeholder: no learnable parameters yet.
    fn num_parameters(&self) -> usize {
        0
    }
}
Training Considerations
1. Loss Function
/// Complete loss for training MoE Attention:
/// `total = task_loss + load_balance_coef * aux_loss`.
pub struct MoEAttentionLoss {
    /// Main task loss (e.g., classification, regression); takes
    /// (predictions, targets) and returns a scalar
    task_loss_fn: Box<dyn Fn(&Array2<f32>, &Array2<f32>) -> f32>,
    /// Load balancing loss coefficient
    load_balance_coef: f32,
}
impl MoEAttentionLoss {
    /// Combine the task loss with the weighted auxiliary loss:
    /// `task_loss(predictions, targets) + load_balance_coef * aux_loss`.
    pub fn compute(
        &self,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
        aux_loss: f32,
    ) -> f32 {
        (self.task_loss_fn)(predictions, targets) + self.load_balance_coef * aux_loss
    }
}
2. Training Loop Integration
/// One (forward-only) training step for MoE Attention.
///
/// The backward pass and parameter update are stubbed out; `learning_rate`
/// will be consumed once autograd is wired in. Returns the total loss.
pub fn training_step(
    model: &mut MoEAttention,
    queries: ArrayView3<f32>,
    keys: ArrayView3<f32>,
    values: ArrayView3<f32>,
    targets: &Array2<f32>,
    loss_fn: &MoEAttentionLoss,
    learning_rate: f32,
) -> f32 {
    model.train();
    let output = model.forward(queries, keys, values, None);
    // Simplified: score only the first sequence position.
    let predictions = output.output.slice(s![.., 0, ..]).to_owned();
    // Backward pass and parameter update would go here, using `learning_rate`:
    //   gradients = compute_gradients(total_loss);
    //   update_parameters(model, gradients, learning_rate);
    let _ = learning_rate;
    loss_fn.compute(&predictions, targets, output.aux_loss)
}
3. Optimization Strategies
/// Optimizer configuration for MoE training.
///
/// NOTE(review): these hyperparameters are not consumed anywhere in this
/// file yet; a real optimizer would read them in its update step.
pub struct MoEOptimizer {
    /// Learning rate for routing network
    router_lr: f32,
    /// Learning rate for experts
    expert_lr: f32,
    /// Learning rate for output projection
    output_lr: f32,
    /// Weight decay
    weight_decay: f32,
    /// Gradient clipping threshold
    grad_clip: f32,
}
impl Default for MoEOptimizer {
fn default() -> Self {
Self {
router_lr: 1e-3,
expert_lr: 5e-4,
output_lr: 5e-4,
weight_decay: 1e-5,
grad_clip: 1.0,
}
}
}
4. Expert Specialization Monitoring
/// Monitor expert specialization during training.
///
/// Accumulates per-expert statistics across calls to `update`;
/// `get_statistics` averages them over the number of recorded steps.
pub struct ExpertSpecializationMonitor {
    /// Cumulative count of tokens with a positive gate, per expert [num_experts]
    usage_counts: Array1<usize>,
    /// Sum of per-step mean gate values, per expert [num_experts]
    avg_gates: Array1<f32>,
    /// Number of `update` calls recorded
    num_steps: usize,
}
impl ExpertSpecializationMonitor {
    /// Create a monitor tracking `num_experts` experts.
    pub fn new(num_experts: usize) -> Self {
        Self {
            usage_counts: Array1::zeros(num_experts),
            avg_gates: Array1::zeros(num_experts),
            num_steps: 0,
        }
    }
    /// Accumulate statistics from one forward pass.
    ///
    /// `gates` is [batch, seq_len, num_experts]; a token "uses" an expert
    /// when its gate is strictly positive.
    pub fn update(&mut self, gates: ArrayView3<f32>) {
        let num_experts = gates.dim().2;
        for e in 0..num_experts {
            let expert_gates = gates.slice(s![.., .., e]);
            // Hard usage count plus the step's mean gate value.
            let usage = expert_gates.iter().filter(|&&x| x > 0.0).count();
            let avg_gate = expert_gates.mean().unwrap_or(0.0);
            self.usage_counts[e] += usage;
            self.avg_gates[e] += avg_gate;
        }
        self.num_steps += 1;
    }
    /// Per-expert statistics averaged over all recorded steps.
    ///
    /// Returns all-zero statistics before the first `update` call instead
    /// of dividing by zero, which previously produced NaN/inf averages.
    pub fn get_statistics(&self) -> ExpertStats {
        if self.num_steps == 0 {
            return ExpertStats {
                avg_usage: Array1::zeros(self.usage_counts.len()),
                avg_gates: Array1::zeros(self.avg_gates.len()),
            };
        }
        let steps = self.num_steps as f32;
        ExpertStats {
            avg_usage: self.usage_counts.mapv(|x| x as f32 / steps),
            avg_gates: self.avg_gates.mapv(|x| x / steps),
        }
    }
}
/// Averaged per-expert statistics produced by `ExpertSpecializationMonitor`.
pub struct ExpertStats {
    /// Mean number of tokens routed to each expert per recorded step
    pub avg_usage: Array1<f32>,
    /// Mean gate value per expert per recorded step
    pub avg_gates: Array1<f32>,
}
Usage Examples
Basic Usage
/// Minimal end-to-end example: build a config, run one forward pass,
/// and print the output shape, aux loss, and parameter count.
fn example_basic_usage() {
    // Configure a 4-expert mixture with top-2 routing.
    let config = MoEAttentionConfig::builder()
        .dim(512)
        .num_heads(8)
        .num_experts(4)
        .top_k(2)
        .expert_types(vec![
            ExpertType::Standard,
            ExpertType::Hyperbolic,
            ExpertType::Linear,
            ExpertType::EdgeFeatured,
        ])
        .router_temperature(1.0)
        .load_balance_coef(0.01)
        .build();
    let mut moe_attn = MoEAttention::new(config);
    // Zero-filled dummy inputs of shape [batch, seq_len, dim].
    let (batch_size, seq_len, dim) = (32, 128, 512);
    let queries = Array3::<f32>::zeros((batch_size, seq_len, dim));
    let keys = Array3::<f32>::zeros((batch_size, seq_len, dim));
    let values = Array3::<f32>::zeros((batch_size, seq_len, dim));
    // Forward pass in training mode.
    moe_attn.train();
    let output = moe_attn.forward(queries.view(), keys.view(), values.view(), None);
    println!("Output shape: {:?}", output.output.dim());
    println!("Auxiliary loss: {:.6}", output.aux_loss);
    println!("Number of parameters: {}", moe_attn.num_parameters());
}
Advanced Training Loop
/// Example multi-epoch training loop with expert-usage monitoring.
fn example_training_loop() {
    let config = MoEAttentionConfig::default();
    let mut model = MoEAttention::new(config);
    let mut monitor = ExpertSpecializationMonitor::new(4);
    let (num_epochs, batch_size, seq_len, dim) = (10, 32, 128, 512);
    for epoch in 0..num_epochs {
        let num_batches = 100;
        let mut epoch_loss = 0.0;
        for _batch in 0..num_batches {
            // Random dummy data in [0, 1).
            let queries = Array3::<f32>::random((batch_size, seq_len, dim), Uniform::new(0.0, 1.0));
            let keys = Array3::<f32>::random((batch_size, seq_len, dim), Uniform::new(0.0, 1.0));
            let values = Array3::<f32>::random((batch_size, seq_len, dim), Uniform::new(0.0, 1.0));
            let output = model.forward(queries.view(), keys.view(), values.view(), None);
            // Track per-expert routing statistics.
            monitor.update(output.gates.view());
            // Simplified task loss: mean squared activation.
            let task_loss =
                output.output.mapv(|x| x * x).sum() / (batch_size * seq_len) as f32;
            epoch_loss += task_loss + output.aux_loss;
            // Backward pass and optimization would go here.
        }
        println!("Epoch {}: Loss = {:.6}", epoch, epoch_loss / num_batches as f32);
        // Periodically report specialization statistics.
        if epoch % 5 == 0 {
            let stats = monitor.get_statistics();
            println!("Expert usage: {:?}", stats.avg_usage);
            println!("Expert gates: {:?}", stats.avg_gates);
        }
    }
}
Integration with Graph Neural Networks
/// Integrate MoE Attention into GNN layer
pub struct MoEGNNLayer {
moe_attention: MoEAttention,
node_transform: Array2<f32>,
edge_encoder: Array2<f32>,
}
impl MoEGNNLayer {
    /// Apply MoE attention over node features (self-attention on nodes).
    ///
    /// Placeholder: `node_transform`, `edge_encoder`, and `adjacency` are
    /// not consumed yet; raw node features serve as queries, keys, and
    /// values alike, and edge features are passed straight through.
    pub fn forward(
        &mut self,
        node_features: ArrayView3<f32>,
        edge_features: ArrayView3<f32>,
        adjacency: ArrayView2<f32>,
    ) -> Array3<f32> {
        let (_batch_size, _num_nodes, _node_dim) = node_features.dim();
        // Would mask attention by graph structure in a full implementation.
        let _ = adjacency;
        // Node transform would be applied here:
        //   transformed = node_features.dot(&self.node_transform)
        // Edge encoding would be applied here:
        //   edge_encoded = edge_features.dot(&self.edge_encoder)
        let result = self.moe_attention.forward(
            node_features,
            node_features,
            node_features,
            Some(edge_features),
        );
        result.output
    }
}
Performance Characteristics
Computational Complexity
- Routing: O(B × S × (D × H + H × E)) where B=batch, S=sequence, D=dim, H=hidden, E=experts
- Expert Forward: O(B × S² × D) per expert (for standard attention)
- Total: O(B × S² × D × k) where k=top-k experts activated
Memory Usage
- Router: (D × H + H × E) parameters
- Experts: ~(3 × D² + D) parameters per expert × E experts
- Activations: O(B × S × D × k) during forward pass
Load Balancing Benefits
- Prevents expert collapse (all tokens routed to one expert)
- Encourages specialization while maintaining coverage
- Typical coefficient: 0.01-0.1 depending on task
Best Practices
1. Expert Selection
- Start with k=2 for most tasks
- Increase k for more complex tasks requiring multiple perspectives
- Monitor expert usage to ensure balanced routing
2. Temperature Tuning
- Start with temperature=1.0
- Decrease (0.5-0.8) for sharper routing (more specialization)
- Increase (1.2-2.0) for softer routing (more collaboration)
3. Load Balancing
- Use coefficient 0.01 as baseline
- Increase if experts are underutilized
- Decrease if routing is too uniform
4. Expert Types
- Include diverse expert types for different inductive biases
- Standard for general patterns
- Hyperbolic for hierarchical structures
- Linear for long sequences
- EdgeFeatured for relational data
References
- Switch Transformers: Scaling to Trillion Parameter Models (Fedus et al., 2021)
- GShard: Scaling Giant Models with Conditional Computation (Lepikhin et al., 2020)
- Mixture-of-Experts with Expert Choice Routing (Zhou et al., 2022)
- ST-MoE: Designing Stable and Transferable Sparse Expert Models (Zoph et al., 2022)