Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
210
crates/ruvector-attention/src/moe/router.rs
Normal file
210
crates/ruvector-attention/src/moe/router.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
//! Router implementations for MoE expert selection
|
||||
|
||||
use crate::utils::stable_softmax;
|
||||
|
||||
/// Router trait for expert selection.
///
/// Implementors must be `Send + Sync` so a router can be shared across
/// threads (for example behind an `Arc<dyn Router>`).
pub trait Router: Send + Sync {
    /// Routes the input `x` to experts.
    ///
    /// Returns `(expert_idx, weight)` pairs, one per selected expert.
    fn route(&self, x: &[f32]) -> Vec<(usize, f32)>;

    /// Returns the number of experts this router chooses among.
    fn num_experts(&self) -> usize;
}
|
||||
|
||||
/// Top-K routing decision for a single input.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct TopKRouting {
    /// Selected experts as `(expert_idx, weight)` pairs, with their
    /// normalized weights.
    pub selections: Vec<(usize, f32)>,
}
|
||||
|
||||
/// Learned router with softmax gating.
///
/// Holds one gate-weight row per expert; the dot product of an input with a
/// row is that expert's logit.
#[derive(Clone, Debug)]
pub struct LearnedRouter {
    /// Number of experts to choose among.
    num_experts: usize,
    /// Length of each gate-weight row (expected input dimension).
    dim: usize,
    /// Number of experts selected per input.
    top_k: usize,
    /// Gate weights: [num_experts x dim], stored row-major (expert `i`
    /// occupies `i * dim .. (i + 1) * dim`).
    gate_weights: Vec<f32>,
}
|
||||
|
||||
impl LearnedRouter {
|
||||
/// Create new learned router
|
||||
pub fn new(num_experts: usize, dim: usize, top_k: usize) -> Self {
|
||||
// Initialize gate weights with Xavier initialization
|
||||
let scale = (2.0 / (dim + num_experts) as f32).sqrt();
|
||||
let mut seed = 42u64;
|
||||
|
||||
let gate_weights: Vec<f32> = (0..num_experts * dim)
|
||||
.map(|_| {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let u = (seed as f32) / (u64::MAX as f32);
|
||||
(u - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
num_experts,
|
||||
dim,
|
||||
top_k: top_k.min(num_experts),
|
||||
gate_weights,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute raw gate logits
|
||||
fn compute_logits(&self, x: &[f32]) -> Vec<f32> {
|
||||
(0..self.num_experts)
|
||||
.map(|i| {
|
||||
x.iter()
|
||||
.enumerate()
|
||||
.map(|(j, &xj)| xj * self.gate_weights[i * self.dim + j])
|
||||
.sum()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute gate probabilities
|
||||
pub fn compute_gate(&self, x: &[f32]) -> Vec<f32> {
|
||||
let logits = self.compute_logits(x);
|
||||
stable_softmax(&logits)
|
||||
}
|
||||
|
||||
/// Compute load balancing loss for batch
|
||||
pub fn load_balance_loss(&self, routing_decisions: &[TopKRouting]) -> f32 {
|
||||
if routing_decisions.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let batch_size = routing_decisions.len() as f32;
|
||||
|
||||
// Count how many times each expert is used
|
||||
let mut expert_counts = vec![0.0f32; self.num_experts];
|
||||
let mut total_weight = vec![0.0f32; self.num_experts];
|
||||
|
||||
for decision in routing_decisions {
|
||||
for &(expert_idx, weight) in &decision.selections {
|
||||
expert_counts[expert_idx] += 1.0;
|
||||
total_weight[expert_idx] += weight;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute auxiliary loss: encourage uniform distribution
|
||||
let _avg_count = expert_counts.iter().sum::<f32>() / self.num_experts as f32;
|
||||
let _avg_weight = total_weight.iter().sum::<f32>() / self.num_experts as f32;
|
||||
|
||||
// CV-squared loss from Switch Transformer paper
|
||||
let count_var: f32 = expert_counts
|
||||
.iter()
|
||||
.map(|c| (c / batch_size - 1.0 / self.num_experts as f32).powi(2))
|
||||
.sum();
|
||||
|
||||
self.num_experts as f32 * count_var
|
||||
}
|
||||
|
||||
/// Update gate weights (for training)
|
||||
pub fn update_weights(&mut self, gradients: &[f32], learning_rate: f32) {
|
||||
for (w, g) in self.gate_weights.iter_mut().zip(gradients.iter()) {
|
||||
*w -= learning_rate * g;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get expert usage statistics
|
||||
pub fn expert_statistics(&self, routing_decisions: &[TopKRouting]) -> Vec<f32> {
|
||||
let mut counts = vec![0.0f32; self.num_experts];
|
||||
|
||||
for decision in routing_decisions {
|
||||
for &(expert_idx, _) in &decision.selections {
|
||||
counts[expert_idx] += 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let total: f32 = counts.iter().sum();
|
||||
if total > 0.0 {
|
||||
counts.iter_mut().for_each(|c| *c /= total);
|
||||
}
|
||||
|
||||
counts
|
||||
}
|
||||
}
|
||||
|
||||
impl Router for LearnedRouter {
    /// Selects the `top_k` experts with the highest gate probability for `x`.
    ///
    /// Returns `(expert_idx, weight)` pairs ordered by descending probability,
    /// with weights renormalized to sum to 1 over the selected experts. If the
    /// selected probability mass is not above `1e-8`, the selected experts
    /// receive equal weights instead.
    fn route(&self, x: &[f32]) -> Vec<(usize, f32)> {
        let mut ranked: Vec<(usize, f32)> =
            self.compute_gate(x).into_iter().enumerate().collect();

        // Stable descending sort: ties (and incomparable values such as NaN)
        // keep their original index order.
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(self.top_k);

        let mass: f32 = ranked.iter().map(|&(_, p)| p).sum();

        if mass > 1e-8 {
            for entry in &mut ranked {
                entry.1 /= mass;
            }
        } else {
            // Fallback: uniform over the experts actually selected, so the
            // weights sum to 1 even if fewer than `top_k` were available.
            let uniform = 1.0 / ranked.len() as f32;
            for entry in &mut ranked {
                entry.1 = uniform;
            }
        }

        ranked
    }

    /// Returns the number of experts this router chooses among.
    fn num_experts(&self) -> usize {
        self.num_experts
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Routing returns `top_k` distinct, in-range experts whose weights are
    /// ordered from largest to smallest and sum to 1.
    #[test]
    fn test_learned_router() {
        let router = LearnedRouter::new(4, 64, 2);

        let x = vec![0.5; 64];
        let routes = router.route(&x);

        assert_eq!(routes.len(), 2);
        assert!(routes.iter().all(|&(idx, _)| idx < router.num_experts()));
        assert_ne!(routes[0].0, routes[1].0);
        assert!(routes[0].1 >= routes[1].1);

        // Weights should sum to 1
        let sum: f32 = routes.iter().map(|(_, w)| w).sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    /// The load balancing loss is finite and non-negative, and zero for an
    /// empty batch.
    #[test]
    fn test_load_balance_loss() {
        let router = LearnedRouter::new(4, 32, 2);

        // Simulate routing decisions
        let decisions: Vec<TopKRouting> = (0..100)
            .map(|i| TopKRouting {
                selections: vec![(i % 4, 0.6), ((i + 1) % 4, 0.4)],
            })
            .collect();

        let loss = router.load_balance_loss(&decisions);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);

        assert_eq!(router.load_balance_loss(&[]), 0.0);
    }

    /// Usage statistics report each expert's share of all selections.
    #[test]
    fn test_expert_statistics() {
        let router = LearnedRouter::new(4, 32, 2);

        let decisions: Vec<TopKRouting> = vec![
            TopKRouting {
                selections: vec![(0, 0.6), (1, 0.4)],
            },
            TopKRouting {
                selections: vec![(0, 0.5), (2, 0.5)],
            },
        ];

        let stats = router.expert_statistics(&decisions);
        assert_eq!(stats.len(), 4);

        // Expert 0 was picked twice, experts 1 and 2 once, expert 3 never.
        assert_eq!(stats, vec![0.5, 0.25, 0.25, 0.0]);

        // Should sum to 1
        let sum: f32 = stats.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
}
|
||||
Reference in New Issue
Block a user