git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
211 lines
6.0 KiB
Rust
211 lines
6.0 KiB
Rust
//! Router implementations for MoE expert selection
|
|
|
|
use crate::utils::stable_softmax;
|
|
|
|
/// Router trait for expert selection
|
|
pub trait Router: Send + Sync {
|
|
/// Route input to experts, returning (expert_idx, weight) pairs
|
|
fn route(&self, x: &[f32]) -> Vec<(usize, f32)>;
|
|
|
|
/// Get number of experts
|
|
fn num_experts(&self) -> usize;
|
|
}
|
|
|
|
/// Top-K routing decision
|
|
#[derive(Clone, Debug)]
|
|
pub struct TopKRouting {
|
|
/// Selected experts with their normalized weights
|
|
pub selections: Vec<(usize, f32)>,
|
|
}
|
|
|
|
/// Learned router with softmax gating
|
|
pub struct LearnedRouter {
|
|
num_experts: usize,
|
|
dim: usize,
|
|
top_k: usize,
|
|
/// Gate weights: [num_experts x dim]
|
|
gate_weights: Vec<f32>,
|
|
}
|
|
|
|
impl LearnedRouter {
|
|
/// Create new learned router
|
|
pub fn new(num_experts: usize, dim: usize, top_k: usize) -> Self {
|
|
// Initialize gate weights with Xavier initialization
|
|
let scale = (2.0 / (dim + num_experts) as f32).sqrt();
|
|
let mut seed = 42u64;
|
|
|
|
let gate_weights: Vec<f32> = (0..num_experts * dim)
|
|
.map(|_| {
|
|
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
|
let u = (seed as f32) / (u64::MAX as f32);
|
|
(u - 0.5) * 2.0 * scale
|
|
})
|
|
.collect();
|
|
|
|
Self {
|
|
num_experts,
|
|
dim,
|
|
top_k: top_k.min(num_experts),
|
|
gate_weights,
|
|
}
|
|
}
|
|
|
|
/// Compute raw gate logits
|
|
fn compute_logits(&self, x: &[f32]) -> Vec<f32> {
|
|
(0..self.num_experts)
|
|
.map(|i| {
|
|
x.iter()
|
|
.enumerate()
|
|
.map(|(j, &xj)| xj * self.gate_weights[i * self.dim + j])
|
|
.sum()
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Compute gate probabilities
|
|
pub fn compute_gate(&self, x: &[f32]) -> Vec<f32> {
|
|
let logits = self.compute_logits(x);
|
|
stable_softmax(&logits)
|
|
}
|
|
|
|
/// Compute load balancing loss for batch
|
|
pub fn load_balance_loss(&self, routing_decisions: &[TopKRouting]) -> f32 {
|
|
if routing_decisions.is_empty() {
|
|
return 0.0;
|
|
}
|
|
|
|
let batch_size = routing_decisions.len() as f32;
|
|
|
|
// Count how many times each expert is used
|
|
let mut expert_counts = vec![0.0f32; self.num_experts];
|
|
let mut total_weight = vec![0.0f32; self.num_experts];
|
|
|
|
for decision in routing_decisions {
|
|
for &(expert_idx, weight) in &decision.selections {
|
|
expert_counts[expert_idx] += 1.0;
|
|
total_weight[expert_idx] += weight;
|
|
}
|
|
}
|
|
|
|
// Compute auxiliary loss: encourage uniform distribution
|
|
let _avg_count = expert_counts.iter().sum::<f32>() / self.num_experts as f32;
|
|
let _avg_weight = total_weight.iter().sum::<f32>() / self.num_experts as f32;
|
|
|
|
// CV-squared loss from Switch Transformer paper
|
|
let count_var: f32 = expert_counts
|
|
.iter()
|
|
.map(|c| (c / batch_size - 1.0 / self.num_experts as f32).powi(2))
|
|
.sum();
|
|
|
|
self.num_experts as f32 * count_var
|
|
}
|
|
|
|
/// Update gate weights (for training)
|
|
pub fn update_weights(&mut self, gradients: &[f32], learning_rate: f32) {
|
|
for (w, g) in self.gate_weights.iter_mut().zip(gradients.iter()) {
|
|
*w -= learning_rate * g;
|
|
}
|
|
}
|
|
|
|
/// Get expert usage statistics
|
|
pub fn expert_statistics(&self, routing_decisions: &[TopKRouting]) -> Vec<f32> {
|
|
let mut counts = vec![0.0f32; self.num_experts];
|
|
|
|
for decision in routing_decisions {
|
|
for &(expert_idx, _) in &decision.selections {
|
|
counts[expert_idx] += 1.0;
|
|
}
|
|
}
|
|
|
|
let total: f32 = counts.iter().sum();
|
|
if total > 0.0 {
|
|
counts.iter_mut().for_each(|c| *c /= total);
|
|
}
|
|
|
|
counts
|
|
}
|
|
}
|
|
|
|
impl Router for LearnedRouter {
|
|
fn route(&self, x: &[f32]) -> Vec<(usize, f32)> {
|
|
let probs = self.compute_gate(x);
|
|
|
|
// Get top-k indices
|
|
let mut indexed: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
|
|
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
// Take top-k and renormalize
|
|
let top_k: Vec<(usize, f32)> = indexed.into_iter().take(self.top_k).collect();
|
|
let sum: f32 = top_k.iter().map(|(_, p)| p).sum();
|
|
|
|
if sum > 1e-8 {
|
|
top_k.into_iter().map(|(i, p)| (i, p / sum)).collect()
|
|
} else {
|
|
// Fallback: uniform over top-k
|
|
top_k
|
|
.into_iter()
|
|
.map(|(i, _)| (i, 1.0 / self.top_k as f32))
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
fn num_experts(&self) -> usize {
|
|
self.num_experts
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_learned_router() {
|
|
let router = LearnedRouter::new(4, 64, 2);
|
|
|
|
let x = vec![0.5; 64];
|
|
let routes = router.route(&x);
|
|
|
|
assert_eq!(routes.len(), 2);
|
|
|
|
// Weights should sum to 1
|
|
let sum: f32 = routes.iter().map(|(_, w)| w).sum();
|
|
assert!((sum - 1.0).abs() < 1e-5);
|
|
}
|
|
|
|
#[test]
|
|
fn test_load_balance_loss() {
|
|
let router = LearnedRouter::new(4, 32, 2);
|
|
|
|
// Simulate routing decisions
|
|
let decisions: Vec<TopKRouting> = (0..100)
|
|
.map(|i| TopKRouting {
|
|
selections: vec![(i % 4, 0.6), ((i + 1) % 4, 0.4)],
|
|
})
|
|
.collect();
|
|
|
|
let loss = router.load_balance_loss(&decisions);
|
|
assert!(loss >= 0.0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_expert_statistics() {
|
|
let router = LearnedRouter::new(4, 32, 2);
|
|
|
|
let decisions: Vec<TopKRouting> = vec![
|
|
TopKRouting {
|
|
selections: vec![(0, 0.6), (1, 0.4)],
|
|
},
|
|
TopKRouting {
|
|
selections: vec![(0, 0.5), (2, 0.5)],
|
|
},
|
|
];
|
|
|
|
let stats = router.expert_statistics(&decisions);
|
|
assert_eq!(stats.len(), 4);
|
|
|
|
// Should sum to 1
|
|
let sum: f32 = stats.iter().sum();
|
|
assert!((sum - 1.0).abs() < 1e-5);
|
|
}
|
|
}
|