git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
537 lines
16 KiB
Rust
537 lines
16 KiB
Rust
//! λ-based Mixture-of-Depths (MoD) routing.
|
|
//!
|
|
//! Unlike learned routers (Raposo et al., 2024), we use mincut λ-delta as the routing signal.
|
|
//! Tokens with stable coherence can skip layers; boundary tokens must compute.
|
|
//!
|
|
//! ## Design Rationale
|
|
//!
|
|
//! Traditional MoD uses learned routing mechanisms, but this introduces:
|
|
//! - Non-deterministic behavior
|
|
//! - Additional training overhead
|
|
//! - Lack of explainability
|
|
//!
|
|
//! Our approach leverages the existing mincut λ signal:
|
|
//! - λ-delta stable → token can skip (coherence maintained)
|
|
//! - λ-delta volatile → token must compute (on partition boundary)
|
|
//! - Boundary token → always compute (critical for coherence)
|
|
//!
|
|
//! This achieves 50% FLOPs reduction while maintaining deterministic behavior
|
|
//! and providing clear intervention witnesses.
|
|
|
|
extern crate alloc;
|
|
use alloc::vec;
|
|
use alloc::vec::Vec;
|
|
|
|
use crate::packets::GatePacket;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
/// Configuration for MoD routing.
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct ModRoutingConfig {
|
|
/// Threshold for λ-delta to allow skipping (Q15: 0-32767)
|
|
/// If |λ_delta| < threshold, token is considered stable and can skip
|
|
pub lambda_delta_skip_threshold: i32,
|
|
|
|
/// Whether to force boundary tokens to always compute
|
|
/// When true, tokens identified as on partition boundaries must compute
|
|
pub boundary_token_force_compute: bool,
|
|
|
|
/// Layer capacity ratio (0.0-1.0)
|
|
/// 0.5 = only 50% of tokens can compute per layer (MoD target)
|
|
pub layer_capacity_ratio: f32,
|
|
|
|
/// Minimum tokens that must compute per layer
|
|
/// Ensures at least this many tokens compute regardless of routing
|
|
pub min_tokens_per_layer: u16,
|
|
|
|
/// Enable adaptive capacity based on λ stability
|
|
/// When true, capacity adjusts based on overall coherence
|
|
pub adaptive_capacity: bool,
|
|
}
|
|
|
|
impl Default for ModRoutingConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
// Allow skip if λ changed by less than ~10% (3276 / 32768 ≈ 0.1)
|
|
lambda_delta_skip_threshold: 3276,
|
|
boundary_token_force_compute: true,
|
|
// Target 50% FLOPs reduction (Raposo et al., 2024)
|
|
layer_capacity_ratio: 0.5,
|
|
min_tokens_per_layer: 4,
|
|
adaptive_capacity: true,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ModRoutingConfig {
|
|
/// Create a configuration targeting specific FLOPs reduction
|
|
///
|
|
/// # Arguments
|
|
/// * `flops_reduction` - Target FLOPs reduction (0.0-1.0), e.g., 0.5 for 50%
|
|
pub fn with_flops_reduction(flops_reduction: f32) -> Self {
|
|
Self {
|
|
layer_capacity_ratio: 1.0 - flops_reduction.clamp(0.0, 0.9),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
/// Validate configuration
|
|
pub fn validate(&self) -> Result<(), &'static str> {
|
|
if self.layer_capacity_ratio <= 0.0 || self.layer_capacity_ratio > 1.0 {
|
|
return Err("layer_capacity_ratio must be in range (0.0, 1.0]");
|
|
}
|
|
if self.lambda_delta_skip_threshold < 0 {
|
|
return Err("lambda_delta_skip_threshold must be non-negative");
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Router decision for a token.
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
#[repr(u8)]
|
|
pub enum TokenRoute {
|
|
/// Process through full attention + FFN
|
|
Compute = 0,
|
|
|
|
/// Skip layer - residual connection only
|
|
Skip = 1,
|
|
|
|
/// Must compute - token is on partition boundary
|
|
Boundary = 2,
|
|
}
|
|
|
|
impl TokenRoute {
|
|
/// Check if this route requires computation
|
|
#[inline]
|
|
pub fn requires_compute(&self) -> bool {
|
|
!matches!(self, TokenRoute::Skip)
|
|
}
|
|
|
|
/// Check if this is a boundary token
|
|
#[inline]
|
|
pub fn is_boundary(&self) -> bool {
|
|
matches!(self, TokenRoute::Boundary)
|
|
}
|
|
}
|
|
|
|
/// MoD router using mincut λ signals.
|
|
///
|
|
/// This router decides which tokens should compute at each layer based on:
|
|
/// 1. λ-delta stability (stable tokens can skip)
|
|
/// 2. Boundary token detection (boundary tokens must compute)
|
|
/// 3. Layer capacity constraints (enforce target FLOPs reduction)
|
|
pub struct MincutDepthRouter {
|
|
config: ModRoutingConfig,
|
|
}
|
|
|
|
impl MincutDepthRouter {
|
|
/// Create a new MoD router with the given configuration
|
|
pub fn new(config: ModRoutingConfig) -> Result<Self, &'static str> {
|
|
config.validate()?;
|
|
Ok(Self { config })
|
|
}
|
|
}
|
|
|
|
impl Default for MincutDepthRouter {
|
|
fn default() -> Self {
|
|
Self {
|
|
config: ModRoutingConfig::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl MincutDepthRouter {
|
|
/// Route tokens based on gate packet and token positions.
|
|
///
|
|
/// # Arguments
|
|
/// * `gate` - Gate packet with λ signals
|
|
/// * `token_positions` - Position indices of tokens in sequence
|
|
///
|
|
/// # Returns
|
|
/// Vector of routing decisions, one per token
|
|
pub fn route_tokens(&self, gate: &GatePacket, token_positions: &[u16]) -> Vec<TokenRoute> {
|
|
let num_tokens = token_positions.len();
|
|
if num_tokens == 0 {
|
|
return Vec::new();
|
|
}
|
|
|
|
let mut routes = vec![TokenRoute::Skip; num_tokens];
|
|
|
|
// Calculate effective capacity for this layer
|
|
let capacity = self.calculate_layer_capacity(gate, num_tokens);
|
|
|
|
// Step 1: Mark boundary tokens (must compute)
|
|
let boundary_count = if self.config.boundary_token_force_compute {
|
|
self.mark_boundary_tokens(gate, &mut routes, token_positions)
|
|
} else {
|
|
0
|
|
};
|
|
|
|
// Step 2: Route remaining tokens based on λ-delta stability
|
|
let mut compute_count = boundary_count;
|
|
let lambda_delta_abs = gate.lambda_delta().abs();
|
|
|
|
// If λ is unstable, more tokens should compute
|
|
if lambda_delta_abs > self.config.lambda_delta_skip_threshold {
|
|
// Unstable coherence - route more tokens to compute
|
|
compute_count += self.route_unstable_tokens(
|
|
gate,
|
|
&mut routes,
|
|
token_positions,
|
|
capacity.saturating_sub(boundary_count),
|
|
);
|
|
} else {
|
|
// Stable coherence - can skip more aggressively
|
|
compute_count += self.route_stable_tokens(
|
|
gate,
|
|
&mut routes,
|
|
token_positions,
|
|
capacity.saturating_sub(boundary_count),
|
|
);
|
|
}
|
|
|
|
// Step 3: Ensure minimum compute tokens
|
|
if compute_count < self.config.min_tokens_per_layer as usize {
|
|
self.ensure_minimum_compute(
|
|
&mut routes,
|
|
self.config.min_tokens_per_layer as usize - compute_count,
|
|
);
|
|
}
|
|
|
|
routes
|
|
}
|
|
|
|
/// Compute layer mask from routing decisions.
|
|
///
|
|
/// # Arguments
|
|
/// * `routes` - Routing decisions for all tokens
|
|
/// * `layer` - Current layer index (for future layer-specific routing)
|
|
///
|
|
/// # Returns
|
|
/// Boolean mask where `true` means token should compute
|
|
pub fn compute_layer_mask(&self, routes: &[TokenRoute], _layer: usize) -> Vec<bool> {
|
|
routes.iter().map(|r| r.requires_compute()).collect()
|
|
}
|
|
|
|
/// Get routing statistics for analysis
|
|
pub fn routing_stats(&self, routes: &[TokenRoute]) -> RoutingStats {
|
|
let total = routes.len();
|
|
let compute = routes.iter().filter(|r| r.requires_compute()).count();
|
|
let skip = routes
|
|
.iter()
|
|
.filter(|r| matches!(r, TokenRoute::Skip))
|
|
.count();
|
|
let boundary = routes.iter().filter(|r| r.is_boundary()).count();
|
|
|
|
RoutingStats {
|
|
total_tokens: total,
|
|
compute_tokens: compute,
|
|
skip_tokens: skip,
|
|
boundary_tokens: boundary,
|
|
compute_ratio: if total > 0 {
|
|
compute as f32 / total as f32
|
|
} else {
|
|
0.0
|
|
},
|
|
skip_ratio: if total > 0 {
|
|
skip as f32 / total as f32
|
|
} else {
|
|
0.0
|
|
},
|
|
}
|
|
}
|
|
|
|
// ---- Private helpers ----
|
|
|
|
fn calculate_layer_capacity(&self, gate: &GatePacket, num_tokens: usize) -> usize {
|
|
let mut capacity = (num_tokens as f32 * self.config.layer_capacity_ratio).ceil() as usize;
|
|
|
|
// Adaptive capacity based on λ stability
|
|
if self.config.adaptive_capacity {
|
|
let lambda_delta_abs = gate.lambda_delta().abs();
|
|
let stability_ratio = 1.0 - (lambda_delta_abs as f32 / 32768.0).min(1.0);
|
|
|
|
// If very stable (high stability_ratio), can reduce capacity further
|
|
// If unstable (low stability_ratio), increase capacity
|
|
let adjustment = if stability_ratio > 0.9 {
|
|
0.9 // Very stable - use even less capacity
|
|
} else if stability_ratio < 0.5 {
|
|
1.2 // Unstable - use more capacity
|
|
} else {
|
|
1.0 // Normal
|
|
};
|
|
|
|
capacity = (capacity as f32 * adjustment).ceil() as usize;
|
|
}
|
|
|
|
capacity
|
|
.max(self.config.min_tokens_per_layer as usize)
|
|
.min(num_tokens)
|
|
}
|
|
|
|
fn mark_boundary_tokens(
|
|
&self,
|
|
gate: &GatePacket,
|
|
routes: &mut [TokenRoute],
|
|
token_positions: &[u16],
|
|
) -> usize {
|
|
// Heuristic: tokens near partition boundaries based on boundary_concentration
|
|
// Higher boundary_concentration means fewer, more concentrated boundaries
|
|
|
|
let boundary_ratio = if gate.boundary_concentration_q15 > 16384 {
|
|
// High concentration - fewer boundary tokens
|
|
0.1
|
|
} else {
|
|
// Low concentration - more boundary tokens
|
|
0.2
|
|
};
|
|
|
|
let boundary_count = (routes.len() as f32 * boundary_ratio).ceil() as usize;
|
|
let mut marked = 0;
|
|
|
|
// Simple heuristic: mark tokens at regular intervals as potential boundaries
|
|
// In practice, this would use actual boundary edge IDs from mincut
|
|
if boundary_count > 0 && !token_positions.is_empty() {
|
|
let stride = routes.len() / boundary_count.max(1);
|
|
for i in (0..routes.len()).step_by(stride.max(1)) {
|
|
if marked >= boundary_count {
|
|
break;
|
|
}
|
|
routes[i] = TokenRoute::Boundary;
|
|
marked += 1;
|
|
}
|
|
}
|
|
|
|
marked
|
|
}
|
|
|
|
fn route_unstable_tokens(
|
|
&self,
|
|
_gate: &GatePacket,
|
|
routes: &mut [TokenRoute],
|
|
_token_positions: &[u16],
|
|
target_count: usize,
|
|
) -> usize {
|
|
// When unstable, route more tokens to compute
|
|
// Prioritize tokens not already marked as boundary
|
|
let mut routed = 0;
|
|
|
|
for route in routes.iter_mut() {
|
|
if routed >= target_count {
|
|
break;
|
|
}
|
|
if matches!(route, TokenRoute::Skip) {
|
|
*route = TokenRoute::Compute;
|
|
routed += 1;
|
|
}
|
|
}
|
|
|
|
routed
|
|
}
|
|
|
|
fn route_stable_tokens(
|
|
&self,
|
|
_gate: &GatePacket,
|
|
routes: &mut [TokenRoute],
|
|
_token_positions: &[u16],
|
|
target_count: usize,
|
|
) -> usize {
|
|
// When stable, can skip more aggressively
|
|
// Only route enough tokens to meet target capacity
|
|
let mut routed = 0;
|
|
|
|
for route in routes.iter_mut() {
|
|
if routed >= target_count {
|
|
break;
|
|
}
|
|
if matches!(route, TokenRoute::Skip) {
|
|
*route = TokenRoute::Compute;
|
|
routed += 1;
|
|
}
|
|
}
|
|
|
|
routed
|
|
}
|
|
|
|
fn ensure_minimum_compute(&self, routes: &mut [TokenRoute], needed: usize) {
|
|
let mut added = 0;
|
|
|
|
for route in routes.iter_mut() {
|
|
if added >= needed {
|
|
break;
|
|
}
|
|
if matches!(route, TokenRoute::Skip) {
|
|
*route = TokenRoute::Compute;
|
|
added += 1;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Statistics for a routing decision.
|
|
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
|
|
pub struct RoutingStats {
|
|
/// Total number of tokens
|
|
pub total_tokens: usize,
|
|
|
|
/// Number of tokens that computed
|
|
pub compute_tokens: usize,
|
|
|
|
/// Number of tokens that skipped
|
|
pub skip_tokens: usize,
|
|
|
|
/// Number of boundary tokens
|
|
pub boundary_tokens: usize,
|
|
|
|
/// Ratio of tokens that computed
|
|
pub compute_ratio: f32,
|
|
|
|
/// Ratio of tokens that skipped
|
|
pub skip_ratio: f32,
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use alloc::vec;
|
|
use alloc::vec::Vec;
|
|
|
|
#[test]
|
|
fn test_mod_routing_config_default() {
|
|
let config = ModRoutingConfig::default();
|
|
assert!(config.validate().is_ok());
|
|
assert_eq!(config.layer_capacity_ratio, 0.5);
|
|
}
|
|
|
|
#[test]
|
|
fn test_mod_routing_config_flops_reduction() {
|
|
let config = ModRoutingConfig::with_flops_reduction(0.5);
|
|
assert_eq!(config.layer_capacity_ratio, 0.5);
|
|
|
|
let config = ModRoutingConfig::with_flops_reduction(0.75);
|
|
assert_eq!(config.layer_capacity_ratio, 0.25);
|
|
}
|
|
|
|
#[test]
|
|
fn test_token_route_methods() {
|
|
assert!(TokenRoute::Compute.requires_compute());
|
|
assert!(!TokenRoute::Skip.requires_compute());
|
|
assert!(TokenRoute::Boundary.requires_compute());
|
|
|
|
assert!(!TokenRoute::Compute.is_boundary());
|
|
assert!(TokenRoute::Boundary.is_boundary());
|
|
}
|
|
|
|
#[test]
|
|
fn test_router_creation() {
|
|
let router = MincutDepthRouter::default();
|
|
assert_eq!(router.config.layer_capacity_ratio, 0.5);
|
|
|
|
let config = ModRoutingConfig::default();
|
|
let router = MincutDepthRouter::new(config);
|
|
assert!(router.is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn test_route_tokens_stable() {
|
|
let router = MincutDepthRouter::default();
|
|
let gate = GatePacket {
|
|
lambda: 100,
|
|
lambda_prev: 95, // Small delta (5)
|
|
boundary_edges: 5,
|
|
boundary_concentration_q15: 20000,
|
|
partition_count: 3,
|
|
flags: 0,
|
|
};
|
|
|
|
let tokens: Vec<u16> = (0..16).collect();
|
|
let routes = router.route_tokens(&gate, &tokens);
|
|
|
|
assert_eq!(routes.len(), 16);
|
|
|
|
let stats = router.routing_stats(&routes);
|
|
assert_eq!(stats.total_tokens, 16);
|
|
assert!(stats.skip_ratio > 0.0); // Should skip some tokens when stable
|
|
}
|
|
|
|
#[test]
|
|
fn test_route_tokens_unstable() {
|
|
let router = MincutDepthRouter::default();
|
|
let gate = GatePacket {
|
|
lambda: 40,
|
|
lambda_prev: 100, // Large delta (60)
|
|
boundary_edges: 15,
|
|
boundary_concentration_q15: 8000,
|
|
partition_count: 5,
|
|
flags: 0,
|
|
};
|
|
|
|
let tokens: Vec<u16> = (0..16).collect();
|
|
let routes = router.route_tokens(&gate, &tokens);
|
|
|
|
let stats = router.routing_stats(&routes);
|
|
// When unstable, should compute more tokens
|
|
assert!(stats.compute_ratio >= 0.5);
|
|
}
|
|
|
|
#[test]
|
|
fn test_compute_layer_mask() {
|
|
let router = MincutDepthRouter::default();
|
|
let routes = vec![
|
|
TokenRoute::Compute,
|
|
TokenRoute::Skip,
|
|
TokenRoute::Boundary,
|
|
TokenRoute::Skip,
|
|
];
|
|
|
|
let mask = router.compute_layer_mask(&routes, 0);
|
|
assert_eq!(mask, vec![true, false, true, false]);
|
|
}
|
|
|
|
#[test]
|
|
fn test_routing_stats() {
|
|
let router = MincutDepthRouter::default();
|
|
let routes = vec![
|
|
TokenRoute::Compute,
|
|
TokenRoute::Compute,
|
|
TokenRoute::Skip,
|
|
TokenRoute::Skip,
|
|
TokenRoute::Boundary,
|
|
TokenRoute::Skip,
|
|
];
|
|
|
|
let stats = router.routing_stats(&routes);
|
|
assert_eq!(stats.total_tokens, 6);
|
|
assert_eq!(stats.compute_tokens, 3); // 2 Compute + 1 Boundary
|
|
assert_eq!(stats.skip_tokens, 3);
|
|
assert_eq!(stats.boundary_tokens, 1);
|
|
assert_eq!(stats.compute_ratio, 0.5);
|
|
}
|
|
|
|
#[test]
|
|
fn test_minimum_tokens_enforced() {
|
|
let config = ModRoutingConfig {
|
|
min_tokens_per_layer: 8,
|
|
..Default::default()
|
|
};
|
|
let router = MincutDepthRouter::new(config).unwrap();
|
|
|
|
let gate = GatePacket {
|
|
lambda: 100,
|
|
lambda_prev: 99, // Very stable
|
|
boundary_edges: 0,
|
|
boundary_concentration_q15: 30000,
|
|
partition_count: 1,
|
|
flags: 0,
|
|
};
|
|
|
|
let tokens: Vec<u16> = (0..16).collect();
|
|
let routes = router.route_tokens(&gate, &tokens);
|
|
|
|
let stats = router.routing_stats(&routes);
|
|
// Should have at least min_tokens_per_layer computing
|
|
assert!(stats.compute_tokens >= 8);
|
|
}
|
|
}
|