wifi-densepose/docs/research/latent-space/implementation-plans/agents/05-moe-attention.md
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00


Agent 5: Mixture of Experts (MoE) Adaptive Attention

Overview

This implementation provides a flexible Mixture of Experts attention mechanism that dynamically routes queries to specialized attention experts based on learned gating functions. The system supports multiple expert types (standard multi-head, hyperbolic, linear, edge-featured) with load balancing and efficient top-k routing.

Architecture Components

1. Expert Types Enum

use ndarray::{Array2, Array3, ArrayView2, ArrayView3};
use std::sync::Arc;

/// Types of attention experts available in the mixture
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpertType {
    /// Standard multi-head attention
    Standard,
    /// Hyperbolic geometry attention
    Hyperbolic,
    /// Linear complexity attention (e.g., Performer)
    Linear,
    /// Edge-featured attention for graphs
    EdgeFeatured,
}

impl ExpertType {
    pub fn name(&self) -> &str {
        match self {
            ExpertType::Standard => "standard",
            ExpertType::Hyperbolic => "hyperbolic",
            ExpertType::Linear => "linear",
            ExpertType::EdgeFeatured => "edge_featured",
        }
    }
}

2. Attention Expert Trait

/// Trait that all attention experts must implement
pub trait AttentionExpert: Send + Sync {
    /// Forward pass through the expert
    ///
    /// # Arguments
    /// * `queries` - Query embeddings [batch, seq_len, dim]
    /// * `keys` - Key embeddings [batch, seq_len, dim]
    /// * `values` - Value embeddings [batch, seq_len, dim]
    /// * `edge_features` - Optional edge features for graph attention
    ///
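    /// # Returns
    /// * Attended output [batch, seq_len, dim]
    /// * Attention weights; with the `Array3` return type the placeholder experts
    ///   return [batch, num_heads, seq_len] (a full [batch, num_heads, seq_len, seq_len]
    ///   tensor would need `Array4`)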
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        keys: ArrayView3<f32>,
        values: ArrayView3<f32>,
        edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>);

    /// Get the expert type
    fn expert_type(&self) -> ExpertType;

    /// Get output dimension
    fn output_dim(&self) -> usize;

    /// Get number of parameters
    fn num_parameters(&self) -> usize;
}

3. Learned Routing Network

use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;

/// Learned routing network for expert selection
pub struct LearnedRouter {
    /// Input dimension
    input_dim: usize,
    /// Number of experts
    num_experts: usize,
    /// Hidden dimension for routing network
    hidden_dim: usize,
    /// Temperature for softmax (higher = more uniform)
    temperature: f32,

    // Network parameters
    /// First layer weights [input_dim, hidden_dim]
    w1: Array2<f32>,
    /// First layer bias [hidden_dim]
    b1: Array1<f32>,
    /// Second layer weights [hidden_dim, num_experts]
    w2: Array2<f32>,
    /// Second layer bias [num_experts]
    b2: Array1<f32>,

    /// Load balancing coefficient
    load_balance_loss_coef: f32,
}

impl LearnedRouter {
    pub fn new(
        input_dim: usize,
        num_experts: usize,
        hidden_dim: usize,
        temperature: f32,
    ) -> Self {
        // Initialize with Xavier/Glorot uniform
        let limit1 = (6.0 / (input_dim + hidden_dim) as f32).sqrt();
        let limit2 = (6.0 / (hidden_dim + num_experts) as f32).sqrt();

        Self {
            input_dim,
            num_experts,
            hidden_dim,
            temperature,
            w1: Array2::random((input_dim, hidden_dim), Uniform::new(-limit1, limit1)),
            b1: Array1::zeros(hidden_dim),
            w2: Array2::random((hidden_dim, num_experts), Uniform::new(-limit2, limit2)),
            b2: Array1::zeros(num_experts),
            load_balance_loss_coef: 0.01,
        }
    }

    /// Compute routing scores for each query
    ///
    /// # Arguments
    /// * `queries` - Query embeddings [batch, seq_len, input_dim]
    ///
    /// # Returns
    /// * Routing logits [batch, seq_len, num_experts]
    pub fn route(&self, queries: ArrayView3<f32>) -> Array3<f32> {
        let (batch_size, seq_len, _) = queries.dim();

        // Reshape for matrix multiply: [batch * seq_len, input_dim]
        let queries_flat = queries
            .to_owned()
            .into_shape((batch_size * seq_len, self.input_dim))
            .unwrap();

        // First layer: [batch * seq_len, hidden_dim]
        let hidden = queries_flat.dot(&self.w1) + &self.b1;
        let hidden = hidden.mapv(|x| x.max(0.0)); // ReLU

        // Second layer: [batch * seq_len, num_experts]
        let logits = hidden.dot(&self.w2) + &self.b2;

        // Apply temperature scaling
        let logits = logits / self.temperature;

        // Reshape back: [batch, seq_len, num_experts]
        logits.into_shape((batch_size, seq_len, self.num_experts)).unwrap()
    }

    /// Compute top-k gating with softmax normalization
    ///
    /// # Arguments
    /// * `logits` - Routing logits [batch, seq_len, num_experts]
    /// * `k` - Number of experts to select
    ///
    /// # Returns
    /// * Gating weights [batch, seq_len, num_experts] (sparse, only top-k nonzero)
    /// * Expert indices [batch, seq_len, k]
    pub fn top_k_gating(&self, logits: ArrayView3<f32>, k: usize) -> (Array3<f32>, Array3<usize>) {
        let (batch_size, seq_len, num_experts) = logits.dim();
        assert!(k <= num_experts, "k must be <= num_experts");

        let mut gates = Array3::<f32>::zeros((batch_size, seq_len, num_experts));
        let mut indices = Array3::<usize>::zeros((batch_size, seq_len, k));

        for b in 0..batch_size {
            for s in 0..seq_len {
                let row = logits.slice(s![b, s, ..]);

                // Get top-k indices
                let mut indexed: Vec<(usize, f32)> = row
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| (i, v))
                    .collect();
                // Descending order; `total_cmp` cannot panic on NaN logits
                indexed.sort_by(|a, b| b.1.total_cmp(&a.1));

                let top_k_indices: Vec<usize> = indexed.iter().take(k).map(|(i, _)| *i).collect();
                let top_k_logits: Vec<f32> = indexed.iter().take(k).map(|(_, v)| *v).collect();

                // Softmax over top-k
                let max_logit = top_k_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let exp_sum: f32 = top_k_logits.iter().map(|&x| (x - max_logit).exp()).sum();

                for (idx, &expert_idx) in top_k_indices.iter().enumerate() {
                    let gate = ((top_k_logits[idx] - max_logit).exp()) / exp_sum;
                    gates[[b, s, expert_idx]] = gate;
                    indices[[b, s, idx]] = expert_idx;
                }
            }
        }

        (gates, indices)
    }

    /// Compute load balancing loss to encourage uniform expert usage
    ///
    /// # Arguments
    /// * `gates` - Gating weights [batch, seq_len, num_experts]
    ///
    /// # Returns
    /// * Load balancing auxiliary loss (scalar)
    pub fn load_balancing_loss(&self, gates: ArrayView3<f32>) -> f32 {
        let (batch_size, seq_len, num_experts) = gates.dim();
        let total_tokens = (batch_size * seq_len) as f32;

        // Fraction of tokens routed to each expert (nonzero gate after top-k)
        let mut expert_usage = Array1::<f32>::zeros(num_experts);
        for e in 0..num_experts {
            let routed = gates.slice(s![.., .., e]).iter().filter(|&&g| g > 0.0).count();
            expert_usage[e] = routed as f32 / total_tokens;
        }

        // Importance: mean gate value of each expert over all tokens
        let mut expert_importance = Array1::<f32>::zeros(num_experts);
        for e in 0..num_experts {
            let importance: f32 = gates.slice(s![.., .., e]).mean().unwrap_or(0.0);
            expert_importance[e] = importance;
        }

        // Load balancing loss = num_experts * sum(usage[i] * importance[i]).
        // The sum is smallest when tokens and gate mass are spread evenly across
        // experts, and grows as routing collapses onto a few experts.
        let loss: f32 = expert_usage.iter()
            .zip(expert_importance.iter())
            .map(|(&u, &i)| u * i)
            .sum::<f32>() * num_experts as f32;

        loss * self.load_balance_loss_coef
    }
}

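// Imports used by the router code above (`use` items are order-independent
// within a module, though they conventionally sit at the top)
use ndarray::{s, Array1};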
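
The gating step in `top_k_gating` can be checked by hand on a single token. The following is a minimal, dependency-free sketch that uses plain slices instead of `ndarray`; `top_k_gates` is a hypothetical helper name, not part of the module above, and it assumes the logits have already been divided by the temperature.

```rust
/// Top-k gating for a single token: softmax over the k largest logits,
/// zero for every other expert. Hypothetical helper for illustration only.
pub fn top_k_gates(logits: &[f32], k: usize) -> Vec<f32> {
    assert!(k >= 1 && k <= logits.len(), "k must be in 1..=num_experts");

    // Rank expert indices by descending logit. `total_cmp` defines an order
    // even for NaN, so the sort cannot panic.
    let mut order: Vec<usize> = (0..logits.len()).collect();
    order.sort_by(|&a, &b| logits[b].total_cmp(&logits[a]));
    let top = &order[..k];

    // Numerically stable softmax restricted to the selected experts.
    let max = logits[top[0]];
    let exps: Vec<f32> = top.iter().map(|&i| (logits[i] - max).exp()).collect();
    let sum: f32 = exps.iter().sum();

    // Scatter the normalized gates back to their expert positions.
    let mut gates = vec![0.0; logits.len()];
    for (&i, e) in top.iter().zip(&exps) {
        gates[i] = e / sum;
    }
    gates
}
```

For logits `[2.0, 1.0, 0.5, -1.0]` and `k = 2`, only the first two experts receive nonzero gates, and those two gates sum to 1.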

4. MoE Attention Configuration

/// Configuration for MoE Attention
#[derive(Debug, Clone)]
pub struct MoEAttentionConfig {
    /// Input/output dimension
    pub dim: usize,
    /// Number of attention heads per expert
    pub num_heads: usize,
    /// Number of experts in the mixture
    pub num_experts: usize,
    /// Number of experts to activate per query (top-k)
    pub top_k: usize,
    /// Expert types to include
    pub expert_types: Vec<ExpertType>,
    /// Hidden dimension for routing network
    pub router_hidden_dim: usize,
    /// Temperature for routing softmax
    pub router_temperature: f32,
    /// Load balancing loss coefficient
    pub load_balance_coef: f32,
    /// Dropout rate
    pub dropout: f32,
}

impl Default for MoEAttentionConfig {
    fn default() -> Self {
        Self {
            dim: 512,
            num_heads: 8,
            num_experts: 4,
            top_k: 2,
            expert_types: vec![
                ExpertType::Standard,
                ExpertType::Hyperbolic,
                ExpertType::Linear,
                ExpertType::EdgeFeatured,
            ],
            router_hidden_dim: 256,
            router_temperature: 1.0,
            load_balance_coef: 0.01,
            dropout: 0.1,
        }
    }
}

impl MoEAttentionConfig {
    pub fn builder() -> MoEAttentionConfigBuilder {
        MoEAttentionConfigBuilder::default()
    }
}

/// Builder for MoEAttentionConfig
#[derive(Default)]
pub struct MoEAttentionConfigBuilder {
    dim: Option<usize>,
    num_heads: Option<usize>,
    num_experts: Option<usize>,
    top_k: Option<usize>,
    expert_types: Option<Vec<ExpertType>>,
    router_hidden_dim: Option<usize>,
    router_temperature: Option<f32>,
    load_balance_coef: Option<f32>,
    dropout: Option<f32>,
}

impl MoEAttentionConfigBuilder {
    pub fn dim(mut self, dim: usize) -> Self {
        self.dim = Some(dim);
        self
    }

    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.num_heads = Some(num_heads);
        self
    }

    pub fn num_experts(mut self, num_experts: usize) -> Self {
        self.num_experts = Some(num_experts);
        self
    }

    pub fn top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    pub fn expert_types(mut self, expert_types: Vec<ExpertType>) -> Self {
        self.expert_types = Some(expert_types);
        self
    }

    pub fn router_hidden_dim(mut self, dim: usize) -> Self {
        self.router_hidden_dim = Some(dim);
        self
    }

    pub fn router_temperature(mut self, temp: f32) -> Self {
        self.router_temperature = Some(temp);
        self
    }

    pub fn load_balance_coef(mut self, coef: f32) -> Self {
        self.load_balance_coef = Some(coef);
        self
    }

    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = Some(dropout);
        self
    }

    pub fn build(self) -> MoEAttentionConfig {
        let default = MoEAttentionConfig::default();

        MoEAttentionConfig {
            dim: self.dim.unwrap_or(default.dim),
            num_heads: self.num_heads.unwrap_or(default.num_heads),
            num_experts: self.num_experts.unwrap_or(default.num_experts),
            top_k: self.top_k.unwrap_or(default.top_k),
            expert_types: self.expert_types.unwrap_or(default.expert_types),
            router_hidden_dim: self.router_hidden_dim.unwrap_or(default.router_hidden_dim),
            router_temperature: self.router_temperature.unwrap_or(default.router_temperature),
            load_balance_coef: self.load_balance_coef.unwrap_or(default.load_balance_coef),
            dropout: self.dropout.unwrap_or(default.dropout),
        }
    }
}
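
The builder accepts any combination of values, and `MoEAttention::new` only asserts two of the constraints. A validation pass that reports every violated constraint at once might look like the sketch below. `MoeShape` and `validate` are hypothetical names introduced here; the struct mirrors only the fields being checked so that the example stays dependency-free.

```rust
/// Hypothetical mirror of the config fields that carry constraints.
#[derive(Debug, Clone)]
pub struct MoeShape {
    pub dim: usize,
    pub num_heads: usize,
    pub num_experts: usize,
    pub top_k: usize,
    pub num_expert_types: usize,
    pub router_temperature: f32,
    pub dropout: f32,
}

/// Collects every violated constraint instead of panicking on the first one.
pub fn validate(cfg: &MoeShape) -> Result<(), Vec<String>> {
    let mut errors = Vec::new();
    // The zero check short-circuits, so the modulo never divides by zero.
    if cfg.num_heads == 0 || cfg.dim % cfg.num_heads != 0 {
        errors.push(format!(
            "dim ({}) must be divisible by num_heads ({})",
            cfg.dim, cfg.num_heads
        ));
    }
    if cfg.top_k == 0 || cfg.top_k > cfg.num_experts {
        errors.push(format!(
            "top_k ({}) must be in 1..={}",
            cfg.top_k, cfg.num_experts
        ));
    }
    if cfg.num_expert_types != cfg.num_experts {
        errors.push(format!(
            "expert_types has {} entries but num_experts is {}",
            cfg.num_expert_types, cfg.num_experts
        ));
    }
    if !(cfg.router_temperature.is_finite() && cfg.router_temperature > 0.0) {
        errors.push("router_temperature must be finite and positive".to_string());
    }
    if !(0.0..1.0).contains(&cfg.dropout) {
        errors.push("dropout must be in [0, 1)".to_string());
    }
    if errors.is_empty() { Ok(()) } else { Err(errors) }
}
```

Returning a list rather than panicking lets a caller surface all configuration problems in one pass. Whether `build()` should return a `Result` is a design choice the plan above leaves open.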

5. MoE Attention Implementation

/// Main Mixture of Experts Attention module
pub struct MoEAttention {
    config: MoEAttentionConfig,

    /// Routing network
    router: LearnedRouter,

    /// Pool of attention experts
    experts: Vec<Box<dyn AttentionExpert>>,

    /// Output projection
    output_projection: Array2<f32>,
    output_bias: Array1<f32>,

    /// Training mode flag
    training: bool,

    /// Cached auxiliary losses
    aux_loss: f32,
}

impl MoEAttention {
    pub fn new(config: MoEAttentionConfig) -> Self {
        assert_eq!(
            config.expert_types.len(),
            config.num_experts,
            "Number of expert types must match num_experts"
        );
        assert!(
            config.top_k <= config.num_experts,
            "top_k must be <= num_experts"
        );

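        // Create router and forward the configured load-balancing coefficient
        // (otherwise the router keeps its built-in 0.01 and the config value is ignored)
        let mut router = LearnedRouter::new(
            config.dim,
            config.num_experts,
            config.router_hidden_dim,
            config.router_temperature,
        );
        router.load_balance_loss_coef = config.load_balance_coef;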

        // Create experts (placeholder - would instantiate actual expert implementations)
        let experts: Vec<Box<dyn AttentionExpert>> = config
            .expert_types
            .iter()
            .map(|&expert_type| {
                create_expert(expert_type, config.dim, config.num_heads, config.dropout)
            })
            .collect();

        // Output projection
        let limit = (6.0 / (config.dim + config.dim) as f32).sqrt();
        let output_projection = Array2::random(
            (config.dim, config.dim),
            Uniform::new(-limit, limit),
        );
        let output_bias = Array1::zeros(config.dim);

        Self {
            config,
            router,
            experts,
            output_projection,
            output_bias,
            training: true,
            aux_loss: 0.0,
        }
    }

    /// Forward pass through MoE attention
    ///
    /// # Arguments
    /// * `queries` - Query embeddings [batch, seq_len, dim]
    /// * `keys` - Key embeddings [batch, seq_len, dim]
    /// * `values` - Value embeddings [batch, seq_len, dim]
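    /// * `edge_features` - Optional edge features; the `ArrayView3` parameter holds three
    ///   axes, so a [batch, seq_len, seq_len, edge_dim] tensor must be flattened first
    ///   (or the parameter type widened)
    ///
    /// # Returns
    /// A `MoEAttentionOutput` containing:
    /// * Output embeddings [batch, seq_len, dim]
    /// * One attention-weight array per expert that received at least one token
    /// * Gating weights, top-k expert indices, and the auxiliary loss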
    pub fn forward(
        &mut self,
        queries: ArrayView3<f32>,
        keys: ArrayView3<f32>,
        values: ArrayView3<f32>,
        edge_features: Option<ArrayView3<f32>>,
    ) -> MoEAttentionOutput {
        let (batch_size, seq_len, dim) = queries.dim();
        assert_eq!(dim, self.config.dim);

        // 1. Compute routing scores
        let routing_logits = self.router.route(queries);

        // 2. Get top-k gating
        let (gates, expert_indices) = self.router.top_k_gating(
            routing_logits.view(),
            self.config.top_k,
        );

        // 3. Compute load balancing loss (only in training)
        if self.training {
            self.aux_loss = self.router.load_balancing_loss(gates.view());
        }

        // 4. Initialize output accumulator
        let mut output = Array3::<f32>::zeros((batch_size, seq_len, dim));
        let mut all_attention_weights = Vec::new();

        // 5. Process each expert
        for expert_idx in 0..self.config.num_experts {
            // Get the expert
            let expert = &self.experts[expert_idx];

            // Find tokens routed to this expert
            let expert_mask = gates.slice(s![.., .., expert_idx]);

            // Skip if no tokens assigned
            if expert_mask.iter().all(|&x| x == 0.0) {
                continue;
            }

            // Run expert on all queries (could be optimized to only process assigned tokens)
            let (expert_output, expert_attn) = expert.forward(
                queries,
                keys,
                values,
                edge_features,
            );

            // Weight and accumulate expert output
            for b in 0..batch_size {
                for s in 0..seq_len {
                    let gate = expert_mask[[b, s]];
                    if gate > 0.0 {
                        let weighted_output = expert_output.slice(s![b, s, ..]).mapv(|x| x * gate);
                        let mut out_slice = output.slice_mut(s![b, s, ..]);
                        out_slice += &weighted_output;
                    }
                }
            }

            all_attention_weights.push(expert_attn);
        }

        // 6. Apply output projection
        // `output` is already owned, so it can be reshaped without an extra copy
        let output_flat = output
            .into_shape((batch_size * seq_len, dim))
            .unwrap();
        let projected = output_flat.dot(&self.output_projection) + &self.output_bias;
        let output = projected.into_shape((batch_size, seq_len, dim)).unwrap();

        MoEAttentionOutput {
            output,
            attention_weights: all_attention_weights,
            gates,
            expert_indices,
            aux_loss: self.aux_loss,
        }
    }

    /// Set training mode
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Set evaluation mode
    pub fn eval(&mut self) {
        self.training = false;
    }

    /// Get auxiliary loss (for backprop)
    pub fn get_aux_loss(&self) -> f32 {
        self.aux_loss
    }

    /// Get number of parameters
    pub fn num_parameters(&self) -> usize {
        let router_params = self.router.w1.len() + self.router.b1.len()
            + self.router.w2.len() + self.router.b2.len();
        let expert_params: usize = self.experts.iter().map(|e| e.num_parameters()).sum();
        let output_params = self.output_projection.len() + self.output_bias.len();

        router_params + expert_params + output_params
    }
}

/// Output from MoE attention forward pass
pub struct MoEAttentionOutput {
    /// Main output [batch, seq_len, dim]
    pub output: Array3<f32>,
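    /// Attention weights, one entry per expert that received at least one token
    /// (experts skipped during routing are omitted, so positions need not match expert indices)
    pub attention_weights: Vec<Array3<f32>>,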
    /// Gating weights [batch, seq_len, num_experts]
    pub gates: Array3<f32>,
    /// Top-k expert indices [batch, seq_len, k]
    pub expert_indices: Array3<usize>,
    /// Auxiliary load balancing loss
    pub aux_loss: f32,
}
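
The forward pass above runs each active expert on every query and then masks by gate value. The comment in step 5 notes that this could be restricted to assigned tokens. A minimal sketch of that bookkeeping follows, using plain `Vec`s rather than `ndarray`. `dispatch_tokens` and `combine` are hypothetical helpers, and `gates[t][e]` is assumed to hold the gate of flattened token `t` for expert `e`. For attention experts, only the query rows can be restricted this way, since keys and values from the whole sequence are still needed.

```rust
/// For each expert, list the token indices whose gate is nonzero.
pub fn dispatch_tokens(gates: &[Vec<f32>], num_experts: usize) -> Vec<Vec<usize>> {
    let mut assigned = vec![Vec::new(); num_experts];
    for (token, row) in gates.iter().enumerate() {
        assert_eq!(row.len(), num_experts, "one gate per expert expected");
        for (expert, &gate) in row.iter().enumerate() {
            if gate > 0.0 {
                assigned[expert].push(token);
            }
        }
    }
    assigned
}

/// Gate-weighted sum of expert outputs for a single token:
/// out = sum_e gates[e] * expert_outputs[e].
pub fn combine(gates: &[f32], expert_outputs: &[Vec<f32>]) -> Vec<f32> {
    assert_eq!(gates.len(), expert_outputs.len());
    let dim = expert_outputs[0].len();
    let mut out = vec![0.0; dim];
    for (&gate, y) in gates.iter().zip(expert_outputs) {
        if gate > 0.0 {
            for (o, &v) in out.iter_mut().zip(y) {
                *o += gate * v;
            }
        }
    }
    out
}
```

With this split, an expert whose list is empty can be skipped, which matches the early `continue` in the loop above.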

6. Expert Factory and Implementations

/// Factory function to create experts
fn create_expert(
    expert_type: ExpertType,
    dim: usize,
    num_heads: usize,
    dropout: f32,
) -> Box<dyn AttentionExpert> {
    match expert_type {
        ExpertType::Standard => Box::new(StandardAttentionExpert::new(dim, num_heads, dropout)),
        ExpertType::Hyperbolic => Box::new(HyperbolicAttentionExpert::new(dim, num_heads, dropout)),
        ExpertType::Linear => Box::new(LinearAttentionExpert::new(dim, num_heads, dropout)),
        ExpertType::EdgeFeatured => Box::new(EdgeAttentionExpert::new(dim, num_heads, dropout)),
    }
}

// Placeholder implementations (would be fully implemented in actual code)

/// Standard multi-head attention expert
pub struct StandardAttentionExpert {
    dim: usize,
    num_heads: usize,
    head_dim: usize,
    dropout: f32,
    // Projection matrices would be here
    wq: Array2<f32>,
    wk: Array2<f32>,
    wv: Array2<f32>,
}

impl StandardAttentionExpert {
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        let head_dim = dim / num_heads;
        let limit = (6.0 / (dim + dim) as f32).sqrt();

        Self {
            dim,
            num_heads,
            head_dim,
            dropout,
            wq: Array2::random((dim, dim), Uniform::new(-limit, limit)),
            wk: Array2::random((dim, dim), Uniform::new(-limit, limit)),
            wv: Array2::random((dim, dim), Uniform::new(-limit, limit)),
        }
    }
}

impl AttentionExpert for StandardAttentionExpert {
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        keys: ArrayView3<f32>,
        values: ArrayView3<f32>,
        _edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        let (batch_size, seq_len, _) = queries.dim();

        // Standard scaled dot-product attention
        // (Simplified - full implementation would reshape for multi-head)
        let q_flat = queries.to_owned().into_shape((batch_size * seq_len, self.dim)).unwrap();
        let k_flat = keys.to_owned().into_shape((batch_size * seq_len, self.dim)).unwrap();
        let v_flat = values.to_owned().into_shape((batch_size * seq_len, self.dim)).unwrap();

        let q_proj = q_flat.dot(&self.wq);
        let k_proj = k_flat.dot(&self.wk);
        let v_proj = v_flat.dot(&self.wv);

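        // Placeholder: return the projected values unchanged. A full implementation
        // would split q_proj/k_proj/v_proj into heads and apply
        // softmax(Q K^T / sqrt(head_dim)) V per head.
        let _ = (&q_proj, &k_proj); // unused until the attention computation is implemented
        let output = v_proj.into_shape((batch_size, seq_len, self.dim)).unwrap();

        // Placeholder weights; the Array3 return type only has room for
        // [batch, num_heads, seq_len]
        let attn_weights = Array3::ones((batch_size, self.num_heads, seq_len));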

        (output, attn_weights)
    }

    fn expert_type(&self) -> ExpertType {
        ExpertType::Standard
    }

    fn output_dim(&self) -> usize {
        self.dim
    }

    fn num_parameters(&self) -> usize {
        self.wq.len() + self.wk.len() + self.wv.len()
    }
}

/// Hyperbolic attention expert
pub struct HyperbolicAttentionExpert {
    dim: usize,
    num_heads: usize,
    dropout: f32,
    curvature: f32,
}

impl HyperbolicAttentionExpert {
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        Self {
            dim,
            num_heads,
            dropout,
            curvature: -1.0,
        }
    }
}

impl AttentionExpert for HyperbolicAttentionExpert {
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        _keys: ArrayView3<f32>,
        _values: ArrayView3<f32>,
        _edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        // Placeholder - would implement hyperbolic geometry attention
        let (batch_size, seq_len, _) = queries.dim();
        (
            Array3::zeros((batch_size, seq_len, self.dim)),
            Array3::zeros((batch_size, self.num_heads, seq_len)),
        )
    }

    fn expert_type(&self) -> ExpertType {
        ExpertType::Hyperbolic
    }

    fn output_dim(&self) -> usize {
        self.dim
    }

    fn num_parameters(&self) -> usize {
        0 // Placeholder
    }
}

/// Linear complexity attention expert (e.g., Performer-style)
pub struct LinearAttentionExpert {
    dim: usize,
    num_heads: usize,
    dropout: f32,
}

impl LinearAttentionExpert {
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        Self { dim, num_heads, dropout }
    }
}

impl AttentionExpert for LinearAttentionExpert {
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        _keys: ArrayView3<f32>,
        _values: ArrayView3<f32>,
        _edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        // Placeholder - would implement kernel-based linear attention
        let (batch_size, seq_len, _) = queries.dim();
        (
            Array3::zeros((batch_size, seq_len, self.dim)),
            Array3::zeros((batch_size, self.num_heads, seq_len)),
        )
    }

    fn expert_type(&self) -> ExpertType {
        ExpertType::Linear
    }

    fn output_dim(&self) -> usize {
        self.dim
    }

    fn num_parameters(&self) -> usize {
        0 // Placeholder
    }
}

/// Edge-featured attention expert for graphs
pub struct EdgeAttentionExpert {
    dim: usize,
    num_heads: usize,
    dropout: f32,
}

impl EdgeAttentionExpert {
    pub fn new(dim: usize, num_heads: usize, dropout: f32) -> Self {
        Self { dim, num_heads, dropout }
    }
}

impl AttentionExpert for EdgeAttentionExpert {
    fn forward(
        &self,
        queries: ArrayView3<f32>,
        _keys: ArrayView3<f32>,
        _values: ArrayView3<f32>,
        edge_features: Option<ArrayView3<f32>>,
    ) -> (Array3<f32>, Array3<f32>) {
        // Placeholder - would incorporate edge features into attention
        let (batch_size, seq_len, _) = queries.dim();
        let _ = edge_features; // Would use this
        (
            Array3::zeros((batch_size, seq_len, self.dim)),
            Array3::zeros((batch_size, self.num_heads, seq_len)),
        )
    }

    fn expert_type(&self) -> ExpertType {
        ExpertType::EdgeFeatured
    }

    fn output_dim(&self) -> usize {
        self.dim
    }

    fn num_parameters(&self) -> usize {
        0 // Placeholder
    }
}
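
The expert bodies above are placeholders. As a reference for what the standard expert would compute per head, here is a minimal single-head scaled dot-product attention over one sequence, written with plain `Vec`s so it has no dependencies. The function name and the row-major `[seq_len][dim]` layout are assumptions for illustration; projections, multiple heads, dropout, and batching are left out.

```rust
/// Single-head scaled dot-product attention for one sequence.
/// `q`, `k`, `v` are row-major [seq_len][dim]; returns (output, weights).
pub fn scaled_dot_product_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    assert!(!k.is_empty() && k.len() == v.len(), "need one value row per key row");
    let scale = 1.0 / (k[0].len() as f32).sqrt();

    let mut outputs = Vec::with_capacity(q.len());
    let mut weights = Vec::with_capacity(q.len());
    for query in q {
        // Scaled similarity of this query against every key.
        let scores: Vec<f32> = k
            .iter()
            .map(|key| query.iter().zip(key).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();

        // Numerically stable softmax over the key axis.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        let row: Vec<f32> = exps.iter().map(|e| e / sum).collect();

        // Weighted average of the value rows.
        let mut out = vec![0.0; v[0].len()];
        for (w, value) in row.iter().zip(v) {
            for (o, x) in out.iter_mut().zip(value) {
                *o += w * x;
            }
        }
        outputs.push(out);
        weights.push(row);
    }
    (outputs, weights)
}
```

When all keys are identical, every query attends uniformly and each output row is the mean of the value rows, which gives a quick sanity check.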

Training Considerations

1. Loss Function

/// Complete loss for training MoE Attention
pub struct MoEAttentionLoss {
    /// Main task loss (e.g., classification, regression)
    task_loss_fn: Box<dyn Fn(&Array2<f32>, &Array2<f32>) -> f32>,
    /// Load balancing loss coefficient
    load_balance_coef: f32,
}

impl MoEAttentionLoss {
    pub fn compute(
        &self,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
        aux_loss: f32,
    ) -> f32 {
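        let task_loss = (self.task_loss_fn)(predictions, targets);
        // NOTE: `LearnedRouter::load_balancing_loss` already multiplies by its own
        // coefficient, so set `load_balance_coef` to 1.0 here unless a second
        // scaling factor is intended.
        task_loss + self.load_balance_coef * aux_loss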
    }
}

2. Training Loop Integration

/// Training step for MoE Attention
pub fn training_step(
    model: &mut MoEAttention,
    queries: ArrayView3<f32>,
    keys: ArrayView3<f32>,
    values: ArrayView3<f32>,
    targets: &Array2<f32>,
    loss_fn: &MoEAttentionLoss,
    learning_rate: f32,
) -> f32 {
    // Forward pass
    model.train();
    let output = model.forward(queries, keys, values, None);

    // Compute loss
    let predictions = output.output.slice(s![.., 0, ..]).to_owned(); // Simplified
    let total_loss = loss_fn.compute(&predictions, targets, output.aux_loss);

    // Backward pass (would use autograd in real implementation)
    // gradients = compute_gradients(total_loss);
    // update_parameters(model, gradients, learning_rate);

    total_loss
}

3. Optimization Strategies

/// Optimizer configuration for MoE training
pub struct MoEOptimizer {
    /// Learning rate for routing network
    router_lr: f32,
    /// Learning rate for experts
    expert_lr: f32,
    /// Learning rate for output projection
    output_lr: f32,
    /// Weight decay
    weight_decay: f32,
    /// Gradient clipping threshold
    grad_clip: f32,
}

impl Default for MoEOptimizer {
    fn default() -> Self {
        Self {
            router_lr: 1e-3,
            expert_lr: 5e-4,
            output_lr: 5e-4,
            weight_decay: 1e-5,
            grad_clip: 1.0,
        }
    }
}
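
`MoEOptimizer` only stores hyperparameters; the update rule itself is not specified above. One plausible reading is plain SGD with L2 weight decay and global-norm gradient clipping, applied once per parameter group (router, experts, output projection) with that group's learning rate. The sketch below assumes this and operates on flat parameter slices. `clip_grad_norm` and `sgd_step` are hypothetical names.

```rust
/// Scales `grads` in place so their global L2 norm is at most `max_norm`.
/// Returns the norm measured before clipping.
pub fn clip_grad_norm(grads: &mut [f32], max_norm: f32) -> f32 {
    let norm = grads.iter().map(|g| g * g).sum::<f32>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
    norm
}

/// One SGD step with L2 weight decay: p <- p - lr * (g + weight_decay * p).
pub fn sgd_step(params: &mut [f32], grads: &[f32], lr: f32, weight_decay: f32) {
    assert_eq!(params.len(), grads.len(), "one gradient per parameter expected");
    for (p, &g) in params.iter_mut().zip(grads) {
        *p -= lr * (g + weight_decay * *p);
    }
}
```

With the defaults above, the router group would be stepped with `lr = 1e-3` and the expert and output groups with `lr = 5e-4`, after clipping to `grad_clip = 1.0`.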

4. Expert Specialization Monitoring

/// Monitor expert specialization during training
pub struct ExpertSpecializationMonitor {
    /// Expert usage statistics [num_experts]
    usage_counts: Array1<usize>,
    /// Average gating weights per expert [num_experts]
    avg_gates: Array1<f32>,
    /// Number of steps
    num_steps: usize,
}

impl ExpertSpecializationMonitor {
    pub fn new(num_experts: usize) -> Self {
        Self {
            usage_counts: Array1::zeros(num_experts),
            avg_gates: Array1::zeros(num_experts),
            num_steps: 0,
        }
    }

    pub fn update(&mut self, gates: ArrayView3<f32>) {
        let num_experts = gates.dim().2;

        for e in 0..num_experts {
            let expert_gates = gates.slice(s![.., .., e]);
            let usage = expert_gates.iter().filter(|&&x| x > 0.0).count();
            let avg_gate = expert_gates.mean().unwrap_or(0.0);

            self.usage_counts[e] += usage;
            self.avg_gates[e] += avg_gate;
        }

        self.num_steps += 1;
    }

    pub fn get_statistics(&self) -> ExpertStats {
        // Guard against division by zero when no update has been recorded yet
        let steps = self.num_steps.max(1) as f32;
        let avg_usage = self.usage_counts.mapv(|x| x as f32 / steps);
        let avg_gates = self.avg_gates.mapv(|x| x / steps);

        ExpertStats {
            avg_usage,
            avg_gates,
        }
    }
}

pub struct ExpertStats {
    pub avg_usage: Array1<f32>,
    pub avg_gates: Array1<f32>,
}
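
The raw usage vector from the monitor is easier to track as a single number. One option, sketched here as an assumption rather than part of the plan, is the entropy of the usage distribution normalized by its maximum `ln(num_experts)`: 1.0 means perfectly balanced routing, and values near 0.0 indicate collapse onto one expert. `normalized_usage_entropy` is a hypothetical helper taking any non-negative usage counts or averages.

```rust
/// Normalized entropy of expert usage in [0, 1].
/// 1.0 = perfectly balanced, 0.0 = a single expert receives all tokens.
pub fn normalized_usage_entropy(usage: &[f32]) -> f32 {
    let total: f32 = usage.iter().sum();
    // Entropy is undefined for fewer than two experts or for empty statistics.
    if usage.len() < 2 || total <= 0.0 {
        return 0.0;
    }
    let entropy: f32 = usage
        .iter()
        .map(|&u| u / total)
        .filter(|&p| p > 0.0) // treat 0 * ln(0) as 0
        .map(|p| -p * p.ln())
        .sum();
    entropy / (usage.len() as f32).ln()
}
```

Logging this value alongside `avg_usage` gives a quick signal for when to raise the load-balancing coefficient.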

Usage Examples

Basic Usage

fn example_basic_usage() {
    // Create configuration
    let config = MoEAttentionConfig::builder()
        .dim(512)
        .num_heads(8)
        .num_experts(4)
        .top_k(2)
        .expert_types(vec![
            ExpertType::Standard,
            ExpertType::Hyperbolic,
            ExpertType::Linear,
            ExpertType::EdgeFeatured,
        ])
        .router_temperature(1.0)
        .load_balance_coef(0.01)
        .build();

    // Create MoE attention module
    let mut moe_attn = MoEAttention::new(config);

    // Prepare inputs
    let batch_size = 32;
    let seq_len = 128;
    let dim = 512;

    let queries = Array3::<f32>::zeros((batch_size, seq_len, dim));
    let keys = Array3::<f32>::zeros((batch_size, seq_len, dim));
    let values = Array3::<f32>::zeros((batch_size, seq_len, dim));

    // Forward pass
    moe_attn.train();
    let output = moe_attn.forward(
        queries.view(),
        keys.view(),
        values.view(),
        None,
    );

    println!("Output shape: {:?}", output.output.dim());
    println!("Auxiliary loss: {:.6}", output.aux_loss);
    println!("Number of parameters: {}", moe_attn.num_parameters());
}

Advanced Training Loop

fn example_training_loop() {
    let config = MoEAttentionConfig::default();
    // Size the monitor from the config before the config is moved into the model
    let mut monitor = ExpertSpecializationMonitor::new(config.num_experts);
    let mut model = MoEAttention::new(config);

    let num_epochs = 10;
    let batch_size = 32;
    let seq_len = 128;
    let dim = 512;

    for epoch in 0..num_epochs {
        let mut epoch_loss = 0.0;
        let num_batches = 100;

        for _batch in 0..num_batches {
            // Generate dummy data
            let queries = Array3::<f32>::random((batch_size, seq_len, dim), Uniform::new(0.0, 1.0));
            let keys = Array3::<f32>::random((batch_size, seq_len, dim), Uniform::new(0.0, 1.0));
            let values = Array3::<f32>::random((batch_size, seq_len, dim), Uniform::new(0.0, 1.0));

            // Forward pass
            let output = model.forward(queries.view(), keys.view(), values.view(), None);

            // Track expert usage
            monitor.update(output.gates.view());

            // Compute loss (simplified)
            let task_loss = output.output.mapv(|x| x * x).sum() / (batch_size * seq_len) as f32;
            let total_loss = task_loss + output.aux_loss;

            epoch_loss += total_loss;

            // Backward pass and optimization would go here
        }

        println!("Epoch {}: Loss = {:.6}", epoch, epoch_loss / num_batches as f32);

        // Print expert statistics
        if epoch % 5 == 0 {
            let stats = monitor.get_statistics();
            println!("Expert usage: {:?}", stats.avg_usage);
            println!("Expert gates: {:?}", stats.avg_gates);
        }
    }
}

Integration with Graph Neural Networks

/// Integrate MoE Attention into GNN layer
pub struct MoEGNNLayer {
    moe_attention: MoEAttention,
    node_transform: Array2<f32>,
    edge_encoder: Array2<f32>,
}

impl MoEGNNLayer {
    pub fn forward(
        &mut self,
        node_features: ArrayView3<f32>,
        edge_features: ArrayView3<f32>,
        adjacency: ArrayView2<f32>,
    ) -> Array3<f32> {
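        // Shape and adjacency are unused in this sketch; a full layer would mask
        // attention to graph neighbours and apply `node_transform` / `edge_encoder`.
        let (_batch_size, _num_nodes, _node_dim) = node_features.dim();
        let _ = adjacency;

        // Self-attention over nodes: queries, keys, and values share the node features
        let queries = node_features;
        let keys = node_features;
        let values = node_features;

        // Encode edge features
        // edge_encoded = edge_features.dot(&self.edge_encoder)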

        // Apply MoE attention
        let output = self.moe_attention.forward(
            queries,
            keys,
            values,
            Some(edge_features),
        );

        output.output
    }
}
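
The `adjacency` argument is accepted by the layer but not used in the sketch above. A common way to apply graph structure in attention is to mask the scores of non-neighbours before the softmax. The helper below illustrates that on one row of scores. It is an assumption about how the layer could use `adjacency`, not something the plan specifies, and `masked_softmax` is a hypothetical name.

```rust
/// Softmax over one row of attention scores, restricted to graph neighbours.
/// `adjacency_row[j] > 0.0` marks an edge from the current node to node `j`.
pub fn masked_softmax(scores: &[f32], adjacency_row: &[f32]) -> Vec<f32> {
    assert_eq!(scores.len(), adjacency_row.len());

    // Non-edges become -inf, so they contribute exp(-inf) = 0 to the softmax.
    let masked: Vec<f32> = scores
        .iter()
        .zip(adjacency_row)
        .map(|(&s, &a)| if a > 0.0 { s } else { f32::NEG_INFINITY })
        .collect();

    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    if max == f32::NEG_INFINITY {
        // Isolated node: no neighbours, so return all-zero weights rather than NaN.
        return vec![0.0; scores.len()];
    }

    let exps: Vec<f32> = masked.iter().map(|&m| (m - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

The explicit isolated-node branch matters because a row of all `-inf` would otherwise produce `NaN` weights that propagate through the layer.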

Performance Characteristics

Computational Complexity

  • Routing: O(B × S × (D × H + H × E)) where B=batch, S=sequence, D=dim, H=hidden, E=experts
  • Expert Forward: O(B × S² × D) per expert (for standard attention)
  • Total: O(B × S² × D × k) where k=top-k experts activated
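
To make these orders of magnitude concrete, the sketch below evaluates the leading terms for the default configuration (B=32, S=128, D=512, H=256, E=4, k=2). The function names are hypothetical, constant factors and projection costs are ignored, and the counts are multiply-accumulates rather than measured runtime. Note that the forward loop shown earlier runs every expert that receives any token on the full input, so the effective multiplier there can be as high as E rather than k.

```rust
/// Multiply-accumulates of the two-layer router: B * S * (D*H + H*E).
pub fn routing_macs(b: u64, s: u64, d: u64, h: u64, e: u64) -> u64 {
    b * s * (d * h + h * e)
}

/// Leading-order cost of one standard attention expert: B * S^2 * D.
pub fn attention_macs(b: u64, s: u64, d: u64) -> u64 {
    b * s * s * d
}

/// Routing plus `k` standard attention experts.
pub fn moe_macs(b: u64, s: u64, d: u64, h: u64, e: u64, k: u64) -> u64 {
    routing_macs(b, s, d, h, e) + k * attention_macs(b, s, d)
}
```

Under these assumptions the router costs about as much as two attention experts at S=128, so routing overhead is not negligible for short sequences.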

Memory Usage

  • Router: (D × H + H × E) parameters
  • Experts: ~(3 × D² + D) parameters per expert × E experts
  • Activations: O(B × S × D × k) during forward pass
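
The same default configuration gives concrete parameter and activation counts. The sketch below follows the code in this document rather than the approximate formulas in the bullets: it counts router biases, and it uses the three `D × D` projections of `StandardAttentionExpert` without biases. The function names are hypothetical.

```rust
/// Router parameters including biases: D*H + H + H*E + E.
pub fn router_params(d: usize, h: usize, e: usize) -> usize {
    d * h + h + h * e + e
}

/// Parameters of the standard expert as sketched above: wq, wk, wv of shape D x D.
pub fn standard_expert_params(d: usize) -> usize {
    3 * d * d
}

/// Output projection weight plus bias: D*D + D.
pub fn output_projection_params(d: usize) -> usize {
    d * d + d
}

/// Number of f32 activations held for the gated expert outputs: B * S * D * k.
pub fn activation_floats(b: usize, s: usize, d: usize, k: usize) -> usize {
    b * s * d * k
}
```

For D=512, H=256, E=4 this gives 132,356 router parameters, 786,432 per standard expert, and 262,656 for the output projection. With B=32, S=128, k=2 the gated activations come to 4,194,304 floats, or 16 MiB as `f32`.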

Load Balancing Benefits

  • Prevents expert collapse (all tokens routed to one expert)
  • Encourages specialization while maintaining coverage
  • Typical coefficient: 0.01-0.1 depending on task
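
A small worked example shows how the auxiliary loss separates balanced routing from collapse. The sketch computes `E × Σ fraction_routed[e] × mean_gate[e]` on plain `Vec`s, before any coefficient is applied. `load_balance_loss` is a hypothetical standalone helper, and `gates[t][e]` is assumed to be the gate of token `t` for expert `e`.

```rust
/// Unscaled load-balancing loss: E * sum_e (fraction routed to e) * (mean gate of e).
pub fn load_balance_loss(gates: &[Vec<f32>]) -> f32 {
    assert!(!gates.is_empty(), "at least one token expected");
    let tokens = gates.len() as f32;
    let num_experts = gates[0].len();

    let mut loss = 0.0;
    for e in 0..num_experts {
        // Fraction of tokens with a nonzero gate for this expert.
        let routed = gates.iter().filter(|row| row[e] > 0.0).count() as f32 / tokens;
        // Mean gate value of this expert over all tokens.
        let mean_gate = gates.iter().map(|row| row[e]).sum::<f32>() / tokens;
        loss += routed * mean_gate;
    }
    loss * num_experts as f32
}
```

With four tokens, four experts, and top-1 routing, sending each token to a different expert gives a loss of 1.0, while sending all four tokens to the same expert gives 4.0.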

Best Practices

  1. Expert Selection

    • Start with k=2 for most tasks
    • Increase k for more complex tasks requiring multiple perspectives
    • Monitor expert usage to ensure balanced routing

  2. Temperature Tuning

    • Start with temperature=1.0
    • Decrease (0.5-0.8) for sharper routing (more specialization)
    • Increase (1.2-2.0) for softer routing (more collaboration)

  3. Load Balancing

    • Use coefficient 0.01 as baseline
    • Increase if experts are underutilized
    • Decrease if routing is too uniform

  4. Expert Types

    • Include diverse expert types for different inductive biases
    • Standard for general patterns
    • Hyperbolic for hierarchical structures
    • Linear for long sequences
    • EdgeFeatured for relational data
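
The temperature guidance can be seen directly in how the routing distribution changes. The sketch below divides the logits by the temperature before a softmax, as `LearnedRouter::route` does, and returns the resulting probabilities. `softmax_with_temperature` is a hypothetical standalone helper.

```rust
/// Softmax of `logits / temperature`; lower temperature gives a sharper distribution.
pub fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    assert!(temperature > 0.0, "temperature must be positive");
    let scaled: Vec<f32> = logits.iter().map(|l| l / temperature).collect();

    // Subtract the maximum for numerical stability.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

For logits `[2.0, 1.0, 0.0]`, the top expert's probability is roughly 0.87 at temperature 0.5, 0.67 at 1.0, and 0.51 at 2.0, which matches the "sharper" versus "softer" routing described above.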
