Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
299
crates/ruvector-attention/src/moe/expert.rs
Normal file
299
crates/ruvector-attention/src/moe/expert.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
//! Expert implementations for MoE attention
|
||||
|
||||
use crate::error::AttentionResult;
|
||||
use crate::utils::stable_softmax;
|
||||
|
||||
/// Type of expert
///
/// Tags which attention formulation an `Expert` implements so callers can
/// distinguish expert kinds at runtime (e.g. for logging or statistics).
#[derive(Clone, Debug, PartialEq)]
pub enum ExpertType {
    /// Standard scaled dot-product
    Standard,
    /// Hyperbolic attention
    Hyperbolic,
    /// Linear attention
    Linear,
}
|
||||
|
||||
/// Expert trait for attention computation
///
/// An expert computes a single-query attention output over a set of
/// key/value slices. `Send + Sync` is required so boxed experts can be
/// shared across threads inside a mixture-of-experts layer.
pub trait Expert: Send + Sync {
    /// Compute attention for this expert
    ///
    /// * `query` - the query vector
    /// * `keys` - one key slice per attended position
    /// * `values` - value slices, parallel to `keys`
    ///
    /// Returns the attended output vector on success.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>>;

    /// Get expert type
    fn expert_type(&self) -> ExpertType;

    /// Get dimension
    fn dim(&self) -> usize;
}
|
||||
|
||||
/// Standard scaled dot-product expert
|
||||
pub struct StandardExpert {
|
||||
dim: usize,
|
||||
scale: f32,
|
||||
}
|
||||
|
||||
impl StandardExpert {
|
||||
pub fn new(dim: usize) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
scale: 1.0 / (dim as f32).sqrt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Expert for StandardExpert {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
// Compute attention scores
|
||||
let scores: Vec<f32> = keys
|
||||
.iter()
|
||||
.map(|k| {
|
||||
query
|
||||
.iter()
|
||||
.zip(k.iter())
|
||||
.map(|(q, ki)| q * ki)
|
||||
.sum::<f32>()
|
||||
* self.scale
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Softmax
|
||||
let weights = stable_softmax(&scores);
|
||||
|
||||
// Weighted sum
|
||||
let mut output = vec![0.0f32; self.dim];
|
||||
for (weight, value) in weights.iter().zip(values.iter()) {
|
||||
for (o, v) in output.iter_mut().zip(value.iter()) {
|
||||
*o += weight * v;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn expert_type(&self) -> ExpertType {
|
||||
ExpertType::Standard
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperbolic expert using Poincaré distance
|
||||
pub struct HyperbolicExpert {
|
||||
dim: usize,
|
||||
curvature: f32,
|
||||
}
|
||||
|
||||
impl HyperbolicExpert {
|
||||
pub fn new(dim: usize, curvature: f32) -> Self {
|
||||
Self { dim, curvature }
|
||||
}
|
||||
|
||||
fn poincare_distance(&self, u: &[f32], v: &[f32]) -> f32 {
|
||||
let c = self.curvature.abs();
|
||||
let sqrt_c = c.sqrt();
|
||||
|
||||
let diff_sq: f32 = u.iter().zip(v.iter()).map(|(a, b)| (a - b).powi(2)).sum();
|
||||
let norm_u_sq: f32 = u.iter().map(|x| x * x).sum();
|
||||
let norm_v_sq: f32 = v.iter().map(|x| x * x).sum();
|
||||
|
||||
let denom = (1.0 - c * norm_u_sq).max(1e-7) * (1.0 - c * norm_v_sq).max(1e-7);
|
||||
let arg = 1.0 + 2.0 * c * diff_sq / denom;
|
||||
|
||||
(1.0 / sqrt_c) * arg.max(1.0).acosh()
|
||||
}
|
||||
}
|
||||
|
||||
impl Expert for HyperbolicExpert {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
// Use negative Poincaré distance as similarity
|
||||
let scores: Vec<f32> = keys
|
||||
.iter()
|
||||
.map(|k| -self.poincare_distance(query, k))
|
||||
.collect();
|
||||
|
||||
let weights = stable_softmax(&scores);
|
||||
|
||||
let mut output = vec![0.0f32; self.dim];
|
||||
for (weight, value) in weights.iter().zip(values.iter()) {
|
||||
for (o, v) in output.iter_mut().zip(value.iter()) {
|
||||
*o += weight * v;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn expert_type(&self) -> ExpertType {
|
||||
ExpertType::Hyperbolic
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear attention expert with random features
|
||||
pub struct LinearExpert {
|
||||
dim: usize,
|
||||
num_features: usize,
|
||||
random_features: Vec<f32>,
|
||||
}
|
||||
|
||||
impl LinearExpert {
|
||||
pub fn new(dim: usize, num_features: usize) -> Self {
|
||||
use std::f32::consts::PI;
|
||||
|
||||
// Generate random features
|
||||
let mut features = Vec::with_capacity(num_features * dim);
|
||||
let mut seed = 123u64;
|
||||
|
||||
for _ in 0..((num_features * dim + 1) / 2) {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let u1 = (seed as f32) / (u64::MAX as f32);
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let u2 = (seed as f32) / (u64::MAX as f32);
|
||||
|
||||
let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
|
||||
let theta = 2.0 * PI * u2;
|
||||
|
||||
features.push(r * theta.cos() / (dim as f32).sqrt());
|
||||
if features.len() < num_features * dim {
|
||||
features.push(r * theta.sin() / (dim as f32).sqrt());
|
||||
}
|
||||
}
|
||||
features.truncate(num_features * dim);
|
||||
|
||||
Self {
|
||||
dim,
|
||||
num_features,
|
||||
random_features: features,
|
||||
}
|
||||
}
|
||||
|
||||
fn feature_map(&self, x: &[f32]) -> Vec<f32> {
|
||||
(0..self.num_features)
|
||||
.map(|i| {
|
||||
let proj: f32 = x
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, &xj)| xj * self.random_features[i * self.dim + j])
|
||||
.sum();
|
||||
let norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
|
||||
(proj - norm_sq / 2.0).exp() / (self.num_features as f32).sqrt()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Expert for LinearExpert {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
let phi_q = self.feature_map(query);
|
||||
let value_dim = values.get(0).map(|v| v.len()).unwrap_or(self.dim);
|
||||
|
||||
let mut kv_sum = vec![0.0f32; self.num_features * value_dim];
|
||||
let mut k_sum = vec![0.0f32; self.num_features];
|
||||
|
||||
for (key, value) in keys.iter().zip(values.iter()) {
|
||||
let phi_k = self.feature_map(key);
|
||||
for (i, &phi_ki) in phi_k.iter().enumerate() {
|
||||
for (j, &vj) in value.iter().enumerate() {
|
||||
kv_sum[i * value_dim + j] += phi_ki * vj;
|
||||
}
|
||||
k_sum[i] += phi_ki;
|
||||
}
|
||||
}
|
||||
|
||||
let mut output = vec![0.0f32; value_dim];
|
||||
let mut normalizer = 0.0f32;
|
||||
|
||||
for (i, &phi_qi) in phi_q.iter().enumerate() {
|
||||
for (j, out_j) in output.iter_mut().enumerate() {
|
||||
*out_j += phi_qi * kv_sum[i * value_dim + j];
|
||||
}
|
||||
normalizer += phi_qi * k_sum[i];
|
||||
}
|
||||
|
||||
if normalizer.abs() > 1e-8 {
|
||||
output.iter_mut().for_each(|x| *x /= normalizer);
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn expert_type(&self) -> ExpertType {
|
||||
ExpertType::Linear
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: standard expert produces an output of the configured dim.
    #[test]
    fn test_standard_expert() {
        let expert = StandardExpert::new(64);
        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = expert.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 64);
    }

    // Smoke test: hyperbolic expert with points inside the unit ball.
    #[test]
    fn test_hyperbolic_expert() {
        let expert = HyperbolicExpert::new(32, 1.0);
        let query = vec![0.1; 32]; // Small values to stay in ball
        let keys: Vec<Vec<f32>> = vec![vec![0.1; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = expert.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 32);
    }

    // Smoke test: linear (random-feature) expert output shape.
    #[test]
    fn test_linear_expert() {
        let expert = LinearExpert::new(64, 32);
        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = expert.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 64);
    }
}
|
||||
11
crates/ruvector-attention/src/moe/mod.rs
Normal file
11
crates/ruvector-attention/src/moe/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! Mixture of Experts (MoE) attention mechanisms
|
||||
//!
|
||||
//! This module provides MoE attention where different inputs route to specialized experts.
|
||||
|
||||
pub mod expert;
|
||||
pub mod moe_attention;
|
||||
pub mod router;
|
||||
|
||||
pub use expert::{Expert, ExpertType, HyperbolicExpert, LinearExpert, StandardExpert};
|
||||
pub use moe_attention::{MoEAttention, MoEConfig};
|
||||
pub use router::{LearnedRouter, Router, TopKRouting};
|
||||
262
crates/ruvector-attention/src/moe/moe_attention.rs
Normal file
262
crates/ruvector-attention/src/moe/moe_attention.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
//! Mixture of Experts attention layer
|
||||
|
||||
use super::expert::{Expert, HyperbolicExpert, LinearExpert, StandardExpert};
|
||||
use super::router::{LearnedRouter, Router, TopKRouting};
|
||||
use crate::error::{AttentionError, AttentionResult};
|
||||
use crate::traits::Attention;
|
||||
|
||||
/// MoE configuration
///
/// Tunable parameters for a mixture-of-experts attention layer.
#[derive(Clone, Debug)]
pub struct MoEConfig {
    /// Embedding dimension.
    pub dim: usize,
    /// Total number of experts in the mixture.
    pub num_experts: usize,
    /// How many experts each query is routed to.
    pub top_k: usize,
    // NOTE(review): not consumed by the code visible in this file — confirm usage.
    pub expert_capacity: f32,
    // NOTE(review): not consumed by the code visible in this file — confirm usage.
    pub jitter_noise: f32,
}

impl Default for MoEConfig {
    /// Sensible defaults: 256-dim, 4 experts, top-2 routing, no jitter.
    fn default() -> Self {
        MoEConfig {
            num_experts: 4,
            top_k: 2,
            dim: 256,
            expert_capacity: 1.25,
            jitter_noise: 0.0,
        }
    }
}

impl MoEConfig {
    /// Start building a config from the defaults.
    pub fn builder() -> MoEConfigBuilder {
        MoEConfigBuilder::default()
    }
}

/// Fluent builder for [`MoEConfig`]; starts from `MoEConfig::default()`.
#[derive(Default)]
pub struct MoEConfigBuilder {
    config: MoEConfig,
}

impl MoEConfigBuilder {
    /// Set the embedding dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.config.dim = dim;
        self
    }

    /// Set the total number of experts.
    pub fn num_experts(mut self, n: usize) -> Self {
        self.config.num_experts = n;
        self
    }

    /// Set how many experts each query routes to.
    pub fn top_k(mut self, k: usize) -> Self {
        self.config.top_k = k;
        self
    }

    /// Set the per-expert capacity factor.
    pub fn expert_capacity(mut self, c: f32) -> Self {
        self.config.expert_capacity = c;
        self
    }

    /// Set the routing jitter noise amplitude.
    pub fn jitter_noise(mut self, j: f32) -> Self {
        self.config.jitter_noise = j;
        self
    }

    /// Finish building and return the config.
    pub fn build(self) -> MoEConfig {
        self.config
    }
}
|
||||
|
||||
/// Mixture of Experts attention
|
||||
pub struct MoEAttention {
|
||||
experts: Vec<Box<dyn Expert>>,
|
||||
router: LearnedRouter,
|
||||
config: MoEConfig,
|
||||
}
|
||||
|
||||
impl MoEAttention {
|
||||
/// Create new MoE attention
|
||||
pub fn new(config: MoEConfig) -> Self {
|
||||
// Create diverse experts
|
||||
let mut experts: Vec<Box<dyn Expert>> = Vec::new();
|
||||
|
||||
// Ensure we have at least num_experts
|
||||
let num_each = (config.num_experts + 2) / 3;
|
||||
|
||||
for _ in 0..num_each {
|
||||
experts.push(Box::new(StandardExpert::new(config.dim)));
|
||||
}
|
||||
for _ in 0..num_each {
|
||||
experts.push(Box::new(HyperbolicExpert::new(config.dim, 1.0)));
|
||||
}
|
||||
for _ in 0..num_each {
|
||||
experts.push(Box::new(LinearExpert::new(config.dim, config.dim / 4)));
|
||||
}
|
||||
|
||||
experts.truncate(config.num_experts);
|
||||
|
||||
let router = LearnedRouter::new(config.num_experts, config.dim, config.top_k);
|
||||
|
||||
Self {
|
||||
experts,
|
||||
router,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute with auxiliary load balance loss
|
||||
pub fn compute_with_loss(
|
||||
&self,
|
||||
queries: &[&[f32]],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<(Vec<Vec<f32>>, f32)> {
|
||||
let mut outputs = Vec::with_capacity(queries.len());
|
||||
let mut routing_decisions = Vec::with_capacity(queries.len());
|
||||
|
||||
for query in queries {
|
||||
let routes = self.router.route(query);
|
||||
routing_decisions.push(TopKRouting {
|
||||
selections: routes.clone(),
|
||||
});
|
||||
|
||||
let mut output = vec![0.0f32; self.config.dim];
|
||||
for (expert_idx, weight) in routes {
|
||||
let expert_output = self.experts[expert_idx].compute(query, keys, values)?;
|
||||
for (o, e) in output.iter_mut().zip(expert_output.iter()) {
|
||||
*o += weight * e;
|
||||
}
|
||||
}
|
||||
outputs.push(output);
|
||||
}
|
||||
|
||||
let loss = self.router.load_balance_loss(&routing_decisions);
|
||||
Ok((outputs, loss))
|
||||
}
|
||||
|
||||
/// Get expert usage statistics
|
||||
pub fn expert_statistics(&self, routing_decisions: &[TopKRouting]) -> Vec<f32> {
|
||||
self.router.expert_statistics(routing_decisions)
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention for MoEAttention {
|
||||
fn compute(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if keys.is_empty() {
|
||||
return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
|
||||
}
|
||||
if query.len() != self.config.dim {
|
||||
return Err(AttentionError::DimensionMismatch {
|
||||
expected: self.config.dim,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
// Route query to experts
|
||||
let routes = self.router.route(query);
|
||||
|
||||
// Compute weighted sum of expert outputs
|
||||
let mut output = vec![0.0f32; self.config.dim];
|
||||
|
||||
for (expert_idx, weight) in routes {
|
||||
let expert_output = self.experts[expert_idx].compute(query, keys, values)?;
|
||||
for (o, e) in output.iter_mut().zip(expert_output.iter()) {
|
||||
*o += weight * e;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn compute_with_mask(
|
||||
&self,
|
||||
query: &[f32],
|
||||
keys: &[&[f32]],
|
||||
values: &[&[f32]],
|
||||
mask: Option<&[bool]>,
|
||||
) -> AttentionResult<Vec<f32>> {
|
||||
if let Some(m) = mask {
|
||||
let filtered: Vec<(usize, bool)> = m
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.filter(|(_, keep)| *keep)
|
||||
.collect();
|
||||
let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
|
||||
let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
|
||||
self.compute(query, &filtered_keys, &filtered_values)
|
||||
} else {
|
||||
self.compute(query, keys, values)
|
||||
}
|
||||
}
|
||||
|
||||
fn dim(&self) -> usize {
|
||||
self.config.dim
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: a 4-expert top-2 mixture yields a dim-length output.
    #[test]
    fn test_moe_attention() {
        let config = MoEConfig::builder().dim(64).num_experts(4).top_k(2).build();

        let moe = MoEAttention::new(config);

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = moe.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 64);
    }

    // Batch path: one output per query and a non-negative auxiliary loss.
    #[test]
    fn test_moe_with_loss() {
        let config = MoEConfig::builder().dim(32).num_experts(4).top_k(2).build();

        let moe = MoEAttention::new(config);

        let queries: Vec<Vec<f32>> = (0..10).map(|_| vec![0.5; 32]).collect();
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let query_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let (outputs, loss) = moe
            .compute_with_loss(&query_refs, &keys_refs, &values_refs)
            .unwrap();

        assert_eq!(outputs.len(), 10);
        assert!(loss >= 0.0);
    }

    // Builder sets every field it is given.
    #[test]
    fn test_config_builder() {
        let config = MoEConfig::builder()
            .dim(128)
            .num_experts(8)
            .top_k(3)
            .expert_capacity(1.5)
            .jitter_noise(0.1)
            .build();

        assert_eq!(config.dim, 128);
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.top_k, 3);
    }
}
|
||||
210
crates/ruvector-attention/src/moe/router.rs
Normal file
210
crates/ruvector-attention/src/moe/router.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
//! Router implementations for MoE expert selection
|
||||
|
||||
use crate::utils::stable_softmax;
|
||||
|
||||
/// Router trait for expert selection
///
/// A router maps an input vector to a weighted subset of experts.
/// `Send + Sync` so routers can be shared across threads.
pub trait Router: Send + Sync {
    /// Route input to experts, returning (expert_idx, weight) pairs
    fn route(&self, x: &[f32]) -> Vec<(usize, f32)>;

    /// Get number of experts
    fn num_experts(&self) -> usize;
}

/// Top-K routing decision
///
/// Records which experts one input was routed to, for use in load-balance
/// loss and usage statistics.
#[derive(Clone, Debug)]
pub struct TopKRouting {
    /// Selected experts with their normalized weights
    pub selections: Vec<(usize, f32)>,
}
|
||||
|
||||
/// Learned router with softmax gating
|
||||
pub struct LearnedRouter {
|
||||
num_experts: usize,
|
||||
dim: usize,
|
||||
top_k: usize,
|
||||
/// Gate weights: [num_experts x dim]
|
||||
gate_weights: Vec<f32>,
|
||||
}
|
||||
|
||||
impl LearnedRouter {
|
||||
/// Create new learned router
|
||||
pub fn new(num_experts: usize, dim: usize, top_k: usize) -> Self {
|
||||
// Initialize gate weights with Xavier initialization
|
||||
let scale = (2.0 / (dim + num_experts) as f32).sqrt();
|
||||
let mut seed = 42u64;
|
||||
|
||||
let gate_weights: Vec<f32> = (0..num_experts * dim)
|
||||
.map(|_| {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let u = (seed as f32) / (u64::MAX as f32);
|
||||
(u - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
num_experts,
|
||||
dim,
|
||||
top_k: top_k.min(num_experts),
|
||||
gate_weights,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute raw gate logits
|
||||
fn compute_logits(&self, x: &[f32]) -> Vec<f32> {
|
||||
(0..self.num_experts)
|
||||
.map(|i| {
|
||||
x.iter()
|
||||
.enumerate()
|
||||
.map(|(j, &xj)| xj * self.gate_weights[i * self.dim + j])
|
||||
.sum()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute gate probabilities
|
||||
pub fn compute_gate(&self, x: &[f32]) -> Vec<f32> {
|
||||
let logits = self.compute_logits(x);
|
||||
stable_softmax(&logits)
|
||||
}
|
||||
|
||||
/// Compute load balancing loss for batch
|
||||
pub fn load_balance_loss(&self, routing_decisions: &[TopKRouting]) -> f32 {
|
||||
if routing_decisions.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let batch_size = routing_decisions.len() as f32;
|
||||
|
||||
// Count how many times each expert is used
|
||||
let mut expert_counts = vec![0.0f32; self.num_experts];
|
||||
let mut total_weight = vec![0.0f32; self.num_experts];
|
||||
|
||||
for decision in routing_decisions {
|
||||
for &(expert_idx, weight) in &decision.selections {
|
||||
expert_counts[expert_idx] += 1.0;
|
||||
total_weight[expert_idx] += weight;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute auxiliary loss: encourage uniform distribution
|
||||
let _avg_count = expert_counts.iter().sum::<f32>() / self.num_experts as f32;
|
||||
let _avg_weight = total_weight.iter().sum::<f32>() / self.num_experts as f32;
|
||||
|
||||
// CV-squared loss from Switch Transformer paper
|
||||
let count_var: f32 = expert_counts
|
||||
.iter()
|
||||
.map(|c| (c / batch_size - 1.0 / self.num_experts as f32).powi(2))
|
||||
.sum();
|
||||
|
||||
self.num_experts as f32 * count_var
|
||||
}
|
||||
|
||||
/// Update gate weights (for training)
|
||||
pub fn update_weights(&mut self, gradients: &[f32], learning_rate: f32) {
|
||||
for (w, g) in self.gate_weights.iter_mut().zip(gradients.iter()) {
|
||||
*w -= learning_rate * g;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get expert usage statistics
|
||||
pub fn expert_statistics(&self, routing_decisions: &[TopKRouting]) -> Vec<f32> {
|
||||
let mut counts = vec![0.0f32; self.num_experts];
|
||||
|
||||
for decision in routing_decisions {
|
||||
for &(expert_idx, _) in &decision.selections {
|
||||
counts[expert_idx] += 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let total: f32 = counts.iter().sum();
|
||||
if total > 0.0 {
|
||||
counts.iter_mut().for_each(|c| *c /= total);
|
||||
}
|
||||
|
||||
counts
|
||||
}
|
||||
}
|
||||
|
||||
impl Router for LearnedRouter {
    /// Softmax-gate `x`, keep the top-k experts, and renormalize their
    /// weights to sum to 1.
    fn route(&self, x: &[f32]) -> Vec<(usize, f32)> {
        let mut ranked: Vec<(usize, f32)> =
            self.compute_gate(x).into_iter().enumerate().collect();

        // Order experts by descending gate probability; NaNs compare equal.
        ranked.sort_by(|lhs, rhs| {
            rhs.1
                .partial_cmp(&lhs.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(self.top_k);

        let mass: f32 = ranked.iter().map(|(_, p)| p).sum();

        if mass > 1e-8 {
            // Renormalize the retained probabilities.
            ranked.into_iter().map(|(i, p)| (i, p / mass)).collect()
        } else {
            // Degenerate gate: fall back to uniform weights over the top-k.
            let uniform = 1.0 / self.top_k as f32;
            ranked.into_iter().map(|(i, _)| (i, uniform)).collect()
        }
    }

    fn num_experts(&self) -> usize {
        self.num_experts
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // route() returns exactly top_k experts with weights summing to 1.
    #[test]
    fn test_learned_router() {
        let router = LearnedRouter::new(4, 64, 2);

        let x = vec![0.5; 64];
        let routes = router.route(&x);

        assert_eq!(routes.len(), 2);

        // Weights should sum to 1
        let sum: f32 = routes.iter().map(|(_, w)| w).sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    // Load-balance loss is non-negative on a synthetic round-robin batch.
    #[test]
    fn test_load_balance_loss() {
        let router = LearnedRouter::new(4, 32, 2);

        // Simulate routing decisions
        let decisions: Vec<TopKRouting> = (0..100)
            .map(|i| TopKRouting {
                selections: vec![(i % 4, 0.6), ((i + 1) % 4, 0.4)],
            })
            .collect();

        let loss = router.load_balance_loss(&decisions);
        assert!(loss >= 0.0);
    }

    // Usage statistics cover every expert and normalize to 1.
    #[test]
    fn test_expert_statistics() {
        let router = LearnedRouter::new(4, 32, 2);

        let decisions: Vec<TopKRouting> = vec![
            TopKRouting {
                selections: vec![(0, 0.6), (1, 0.4)],
            },
            TopKRouting {
                selections: vec![(0, 0.5), (2, 0.5)],
            },
        ];

        let stats = router.expert_statistics(&decisions);
        assert_eq!(stats.len(), 4);

        // Should sum to 1
        let sum: f32 = stats.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
}
|
||||
Reference in New Issue
Block a user