Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,401 @@
//! Federation Coordinator - Cluster Management
//!
//! Manages the multi-chip cluster with self-learning optimization.
//! Integrates MicroLoRA for distributed fine-tuning.
use super::protocol::{ChipId, FederationMessage, MessageType, CommStats};
use super::{FederationConfig, FederationMode, FederationSpeedup, estimate_speedup};
use crate::optimizations::micro_lora::{MicroLoRA, LoRAConfig, LoRAStack};
/// Maximum number of chips supported in a single cluster.
pub const MAX_CLUSTER_SIZE: usize = 8;

/// Physical/logical wiring of the chip cluster.
///
/// The topology determines how messages are routed: Linear/Ring suit
/// pipeline-style stage handoff, Star centralizes traffic on chip 0,
/// and Mesh allows direct all-to-all exchange.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClusterTopology {
    /// Linear pipeline: 0 -> 1 -> 2 -> 3 -> 4
    Linear,
    /// Ring: 0 -> 1 -> 2 -> 3 -> 4 -> 0
    Ring,
    /// Star: 0 <-> all others
    Star,
    /// Mesh: all-to-all
    Mesh,
}
/// Live status of one chip in the cluster, as tracked by the coordinator.
#[derive(Debug, Clone)]
pub struct ChipStatus {
    /// Chip identifier (slot index in the cluster).
    pub id: ChipId,
    /// Whether the chip is currently considered alive (recent heartbeat).
    pub active: bool,
    /// Coordinator tick at which the last heartbeat was received.
    pub last_heartbeat: u32,
    /// Current load, 0 (idle) to 255 (saturated).
    pub load: u8,
    /// Memory in use on the chip, in KB.
    pub memory_used_kb: u16,
    /// Total tokens this chip has processed.
    pub tokens_processed: u32,
}
/// State of the self-learning (distributed LoRA) optimization loop.
#[derive(Debug, Clone)]
pub struct SelfLearningState {
    /// Learning rate for LoRA updates (small integer step size).
    pub learning_rate: i8,
    /// Number of gradient steps accumulated so far.
    pub gradient_steps: u32,
    /// Exponential moving average of the loss (fixed-point);
    /// `i32::MAX` is the "no loss observed yet" sentinel.
    pub avg_loss: i32,
    /// Lowest loss observed so far (`i32::MAX` until the first update).
    pub best_loss: i32,
    /// Whether adaptation is enabled (set by `init_distributed_lora`).
    pub enabled: bool,
}
impl Default for SelfLearningState {
fn default() -> Self {
Self {
learning_rate: 4,
gradient_steps: 0,
avg_loss: i32::MAX,
best_loss: i32::MAX,
enabled: false,
}
}
}
/// Coordinates a multi-chip federation: tracks peer liveness via heartbeats,
/// exchanges protocol messages, and drives the optional distributed-LoRA
/// self-learning loop.
pub struct FederationCoordinator {
    /// This coordinator's chip ID.
    chip_id: ChipId,
    /// Whether this node is the master coordinator.
    is_master: bool,
    /// Cluster configuration.
    config: FederationConfig,
    /// Communication topology derived from the federation mode.
    topology: ClusterTopology,
    /// Per-slot status of all chips (`None` for unused slots).
    chip_status: [Option<ChipStatus>; MAX_CLUSTER_SIZE],
    /// Counters for messages/bytes sent and received.
    comm_stats: CommStats,
    /// Self-learning (adaptive learning-rate) state.
    learning: SelfLearningState,
    /// Distributed LoRA adapters (one per layer shard), if initialized.
    lora_stack: Option<LoRAStack<4>>,
    /// Monotonic tick counter used for heartbeat timeouts.
    current_tick: u32,
    /// Sequence number for outgoing messages.
    seq_counter: u16,
}
impl FederationCoordinator {
    /// Create a new coordinator for this chip.
    ///
    /// Status slots are allocated for the first `config.num_chips` chips;
    /// only the local chip starts out marked active — peers become active
    /// once their first heartbeat is handled.
    pub fn new(config: FederationConfig, is_master: bool) -> Self {
        let chip_status = core::array::from_fn(|i| {
            if i < config.num_chips {
                Some(ChipStatus {
                    id: ChipId(i as u8),
                    // Only the local chip is known to be alive at startup.
                    active: i == config.chip_id.0 as usize,
                    last_heartbeat: 0,
                    load: 0,
                    memory_used_kb: 0,
                    tokens_processed: 0,
                })
            } else {
                None
            }
        });
        Self {
            chip_id: config.chip_id,
            is_master,
            topology: Self::optimal_topology(&config),
            config,
            chip_status,
            comm_stats: CommStats::default(),
            learning: SelfLearningState::default(),
            lora_stack: None,
            current_tick: 0,
            seq_counter: 0,
        }
    }

    /// Map a federation mode to the topology best matching its
    /// communication pattern.
    fn optimal_topology(config: &FederationConfig) -> ClusterTopology {
        match config.mode {
            FederationMode::Pipeline => ClusterTopology::Linear,
            FederationMode::TensorParallel => ClusterTopology::Star,
            FederationMode::Speculative => ClusterTopology::Star,
            FederationMode::MixtureOfExperts => ClusterTopology::Mesh,
            _ => ClusterTopology::Linear,
        }
    }

    /// Initialize distributed LoRA adapters for self-learning.
    ///
    /// One rank-1 adapter is created per locally assigned layer (capped at
    /// the stack capacity of 4). Enables the learning loop on success.
    ///
    /// # Errors
    /// Propagates failures from `MicroLoRA::new` / `LoRAStack::add_adapter`.
    pub fn init_distributed_lora(&mut self, dim: usize, seed: u32) -> crate::Result<()> {
        let lora_config = LoRAConfig {
            rank: 1, // Minimal rank keeps per-chip memory tiny
            dim,
            scale: 8,
            frozen: false,
        };
        let mut stack = LoRAStack::new();
        // Each chip gets LoRA adapters for its assigned layer shard.
        let layers_per_chip = self.config.layers_per_chip;
        for i in 0..layers_per_chip.min(4) {
            // Distinct seed per layer so the adapters don't start identical.
            let layer_seed = seed.wrapping_add(i as u32 * 1000);
            let adapter = MicroLoRA::new(lora_config, layer_seed)?;
            stack.add_adapter(i, adapter)?;
        }
        self.lora_stack = Some(stack);
        self.learning.enabled = true;
        Ok(())
    }

    /// Advance the coordinator clock by one tick (call regularly).
    ///
    /// Marks chips inactive when no heartbeat has been seen for 1000 ticks.
    /// Wrapping arithmetic keeps long-running devices panic-free across a
    /// u32 tick rollover and keeps the elapsed-tick computation correct.
    pub fn tick(&mut self) {
        self.current_tick = self.current_tick.wrapping_add(1);
        // Check for heartbeat timeouts.
        for status in self.chip_status.iter_mut().flatten() {
            // wrapping_sub yields the correct elapsed count even if
            // current_tick has rolled over since the last heartbeat.
            if self.current_tick.wrapping_sub(status.last_heartbeat) > 1000 {
                status.active = false;
            }
        }
    }

    /// Handle a received federation message, optionally producing a reply.
    ///
    /// - `Heartbeat`: refresh the sender's liveness.
    /// - `Discovery`: reply with our own heartbeat.
    /// - `Barrier`: acknowledge back to the sender.
    /// - Anything else is ignored.
    pub fn handle_message(&mut self, msg: &FederationMessage) -> Option<FederationMessage> {
        // Saturate the stats counters so they can never wrap around
        // (or panic in debug builds) on a long-running device.
        self.comm_stats.messages_received = self.comm_stats.messages_received.saturating_add(1);
        self.comm_stats.bytes_received =
            self.comm_stats.bytes_received.saturating_add(msg.payload.len() as u32);
        let msg_type = MessageType::from(msg.header.msg_type);
        match msg_type {
            MessageType::Heartbeat => {
                // Update the sender's chip status.
                let src = msg.header.src as usize;
                if let Some(status) = self.chip_status.get_mut(src).and_then(|s| s.as_mut()) {
                    status.active = true;
                    status.last_heartbeat = self.current_tick;
                }
                None
            }
            MessageType::Discovery => {
                // Respond with our status.
                Some(self.create_heartbeat())
            }
            MessageType::Barrier => {
                // Acknowledge the barrier, echoing the sender's sequence.
                Some(FederationMessage::new(
                    MessageType::Ack,
                    self.chip_id,
                    ChipId(msg.header.src),
                    msg.header.seq,
                ))
            }
            _ => None,
        }
    }

    /// Build a broadcast heartbeat carrying this chip's load and memory use.
    ///
    /// Payload layout: `[load, memory_kb_lo, memory_kb_hi]` (little-endian u16).
    pub fn create_heartbeat(&mut self) -> FederationMessage {
        // The sequence number is allowed to roll over; wrapping avoids the
        // debug-build overflow panic after 65535 messages.
        self.seq_counter = self.seq_counter.wrapping_add(1);
        let mut msg = FederationMessage::new(
            MessageType::Heartbeat,
            self.chip_id,
            ChipId::BROADCAST,
            self.seq_counter,
        );
        // Add load info to the payload.
        if let Some(status) = &self.chip_status[self.chip_id.0 as usize] {
            let _ = msg.payload.push(status.load);
            let _ = msg.payload.push((status.memory_used_kb & 0xFF) as u8);
            let _ = msg.payload.push((status.memory_used_kb >> 8) as u8);
        }
        msg.header.payload_len = msg.payload.len() as u16;
        msg.update_checksum();
        self.comm_stats.messages_sent = self.comm_stats.messages_sent.saturating_add(1);
        msg
    }

    /// Number of chips currently marked active.
    pub fn active_chip_count(&self) -> usize {
        self.chip_status
            .iter()
            .filter(|s| s.as_ref().is_some_and(|s| s.active))
            .count()
    }

    /// Estimate the speedup achievable with the currently active chips.
    pub fn current_speedup(&self) -> FederationSpeedup {
        let active = self.active_chip_count();
        let mut effective_config = self.config.clone();
        effective_config.num_chips = active;
        estimate_speedup(&effective_config)
    }

    /// Feed a new loss sample into the self-learning state.
    ///
    /// Maintains an EMA of the loss (alpha = 1/16), tracks the best loss,
    /// and every 100 steps adapts the learning rate: up when the EMA is
    /// within 10% of the best loss, down otherwise. Intermediates are
    /// widened to i64 so large fixed-point losses cannot overflow the
    /// `avg * 15` / `best * 11` products.
    pub fn update_learning(&mut self, loss: i32) {
        if !self.learning.enabled {
            return;
        }
        self.learning.gradient_steps = self.learning.gradient_steps.wrapping_add(1);
        // Exponential moving average of loss.
        if self.learning.avg_loss == i32::MAX {
            // First sample: seed the EMA directly.
            self.learning.avg_loss = loss;
        } else {
            self.learning.avg_loss =
                ((self.learning.avg_loss as i64 * 15 + loss as i64) / 16) as i32;
        }
        // Track the best loss seen.
        if loss < self.learning.best_loss {
            self.learning.best_loss = loss;
        }
        // Adaptive learning rate, re-evaluated every 100 steps.
        if self.learning.gradient_steps % 100 == 0 {
            if (self.learning.avg_loss as i64) < self.learning.best_loss as i64 * 11 / 10 {
                // Good progress: increase LR (clamped to 16).
                self.learning.learning_rate = (self.learning.learning_rate + 1).min(16);
            } else {
                // Slow progress: decrease LR (clamped to 1).
                self.learning.learning_rate = (self.learning.learning_rate - 1).max(1);
            }
        }
    }

    /// Apply a gradient update to the LoRA adapter of one layer shard.
    /// No-op when distributed LoRA has not been initialized.
    #[cfg(not(feature = "frozen"))]
    pub fn apply_lora_gradient(
        &mut self,
        layer_idx: usize,
        input: &[i8],
        grad_output: &[i32],
    ) {
        if let Some(ref mut stack) = self.lora_stack {
            if let Some(lora) = stack.get(layer_idx) {
                lora.update(input, grad_output, self.learning.learning_rate);
            }
        }
    }

    /// Get the LoRA adapter for a layer, if distributed LoRA is initialized.
    pub fn get_lora(&mut self, layer_idx: usize) -> Option<&mut MicroLoRA> {
        self.lora_stack.as_mut()?.get(layer_idx)
    }

    /// Snapshot aggregate cluster statistics.
    pub fn stats(&self) -> ClusterStats {
        let total_tokens: u32 = self.chip_status.iter()
            .filter_map(|s| s.as_ref())
            .map(|s| s.tokens_processed)
            .sum();
        let total_memory: u32 = self.chip_status.iter()
            .filter_map(|s| s.as_ref())
            .map(|s| s.memory_used_kb as u32)
            .sum();
        ClusterStats {
            active_chips: self.active_chip_count(),
            total_chips: self.config.num_chips,
            total_tokens_processed: total_tokens,
            total_memory_kb: total_memory,
            messages_sent: self.comm_stats.messages_sent,
            messages_received: self.comm_stats.messages_received,
            current_speedup: self.current_speedup(),
            learning_enabled: self.learning.enabled,
            learning_rate: self.learning.learning_rate,
            avg_loss: self.learning.avg_loss,
        }
    }

    /// Add to this chip's processed-token count (saturating, never panics).
    pub fn record_tokens(&mut self, count: u32) {
        if let Some(status) = self.chip_status.get_mut(self.chip_id.0 as usize).and_then(|s| s.as_mut()) {
            status.tokens_processed = status.tokens_processed.saturating_add(count);
        }
    }

    /// Record this chip's current memory usage (KB).
    pub fn update_memory_usage(&mut self, kb: u16) {
        if let Some(status) = self.chip_status.get_mut(self.chip_id.0 as usize).and_then(|s| s.as_mut()) {
            status.memory_used_kb = kb;
        }
    }
}
/// Aggregate statistics across the whole cluster, produced by
/// `FederationCoordinator::stats`.
#[derive(Debug, Clone)]
pub struct ClusterStats {
    /// Chips currently marked active.
    pub active_chips: usize,
    /// Total chips configured in the cluster.
    pub total_chips: usize,
    /// Total tokens processed across all chips.
    pub total_tokens_processed: u32,
    /// Total memory used across all chips (KB).
    pub total_memory_kb: u32,
    /// Messages sent by this coordinator.
    pub messages_sent: u32,
    /// Messages received by this coordinator.
    pub messages_received: u32,
    /// Current speedup estimate for the active chips.
    pub current_speedup: FederationSpeedup,
    /// Whether self-learning is enabled.
    pub learning_enabled: bool,
    /// Current learning rate.
    pub learning_rate: i8,
    /// Average (EMA) loss, fixed-point; `i32::MAX` if no loss seen yet.
    pub avg_loss: i32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created coordinator only marks itself active until
    /// heartbeats arrive from peers.
    #[test]
    fn test_coordinator_creation() {
        let config = FederationConfig::default();
        let coord = FederationCoordinator::new(config, true);
        assert_eq!(coord.active_chip_count(), 1); // Only self is active initially
    }

    /// Initializing distributed LoRA enables learning and installs an
    /// adapter for the first layer shard.
    #[test]
    fn test_distributed_lora() {
        let config = FederationConfig::default();
        let mut coord = FederationCoordinator::new(config, true);
        coord.init_distributed_lora(32, 42).unwrap();
        assert!(coord.learning.enabled);
        assert!(coord.get_lora(0).is_some());
    }

    /// The loss EMA moves toward recent samples and the best loss tracks
    /// the minimum observed.
    #[test]
    fn test_learning_update() {
        let config = FederationConfig::default();
        let mut coord = FederationCoordinator::new(config, true);
        coord.learning.enabled = true;
        coord.update_learning(1000);
        coord.update_learning(900);
        coord.update_learning(800);
        assert!(coord.learning.avg_loss < 1000);
        assert_eq!(coord.learning.best_loss, 800);
    }
}

View File

@@ -0,0 +1,344 @@
//! FastGRNN-Inspired Micro Router for ESP32
//!
//! Lightweight gated routing for dynamic chip selection.
//! Adapted from ruvector's FastGRNN for minimal compute overhead.
//!
//! Key differences from full FastGRNN:
//! - INT8 weights instead of FP32
//! - Fixed-point gate computation
//! - Minimal hidden dimension (4-8)
use heapless::Vec as HVec;
use super::protocol::ChipId;
/// Maximum hidden dimension for micro router
pub const MAX_ROUTER_HIDDEN: usize = 8;
/// Maximum input features
pub const MAX_ROUTER_INPUT: usize = 16;
/// Micro FastGRNN configuration.
#[derive(Debug, Clone, Copy)]
pub struct MicroGRNNConfig {
    /// Number of input features (should be <= MAX_ROUTER_INPUT).
    pub input_dim: usize,
    /// Hidden state size (should be <= MAX_ROUTER_HIDDEN).
    pub hidden_dim: usize,
    /// Number of output classes, i.e. selectable chips.
    pub num_chips: usize,
    /// Gate scaling factor (fixed-point; divided by 16 in `step`).
    pub zeta: i8,
    /// Update scaling factor (fixed-point; divided by 16 in `step`).
    pub nu: i8,
}
impl Default for MicroGRNNConfig {
fn default() -> Self {
Self {
input_dim: 8,
hidden_dim: 4,
num_chips: 5,
zeta: 16,
nu: 16,
}
}
}
/// Micro FastGRNN cell for routing decisions.
///
/// All weights are INT8; the hidden state is kept as i32 fixed point
/// (256 representing 1.0 — see `sigmoid_fp` / `tanh_fp`).
pub struct MicroFastGRNN {
    config: MicroGRNNConfig,
    /// Gate weights: W_g [input_dim * hidden_dim] + U_g [hidden_dim * hidden_dim]
    w_gate: HVec<i8, 128>,
    u_gate: HVec<i8, 64>,
    /// Update weights: W_u, U_u (same shapes as the gate weights)
    w_update: HVec<i8, 128>,
    u_update: HVec<i8, 64>,
    /// Per-unit biases for the gate and update paths
    bias_gate: HVec<i8, MAX_ROUTER_HIDDEN>,
    bias_update: HVec<i8, MAX_ROUTER_HIDDEN>,
    /// Output projection from hidden state to chip scores
    w_output: HVec<i8, 64>,
    /// Hidden state (fixed-point i32)
    hidden: HVec<i32, MAX_ROUTER_HIDDEN>,
}
impl MicroFastGRNN {
    /// Create a new micro FastGRNN with pseudo-random INT8 weights.
    ///
    /// Weights are drawn in [-32, 31] from a small LCG seeded with `seed`,
    /// so construction is deterministic per seed.
    ///
    /// # Errors
    /// Returns `Error::BufferOverflow` if the configured dimensions exceed
    /// the fixed backing-array capacities.
    pub fn new(config: MicroGRNNConfig, seed: u32) -> crate::Result<Self> {
        let mut rng_state = seed;
        // LCG (glibc constants); take 6 bits from the high half -> [-32, 31].
        let mut next_rand = || {
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            (((rng_state >> 16) & 0x3F) as i16 - 32) as i8
        };
        // Weight matrix element counts.
        let gate_size = config.input_dim * config.hidden_dim;
        let hidden_size = config.hidden_dim * config.hidden_dim;
        let output_size = config.hidden_dim * config.num_chips;
        let mut w_gate = HVec::new();
        let mut u_gate = HVec::new();
        let mut w_update = HVec::new();
        let mut u_update = HVec::new();
        let mut w_output = HVec::new();
        let mut bias_gate = HVec::new();
        let mut bias_update = HVec::new();
        let mut hidden = HVec::new();
        for _ in 0..gate_size {
            w_gate.push(next_rand()).map_err(|_| crate::Error::BufferOverflow)?;
            w_update.push(next_rand()).map_err(|_| crate::Error::BufferOverflow)?;
        }
        for _ in 0..hidden_size {
            u_gate.push(next_rand()).map_err(|_| crate::Error::BufferOverflow)?;
            u_update.push(next_rand()).map_err(|_| crate::Error::BufferOverflow)?;
        }
        for _ in 0..output_size {
            w_output.push(next_rand()).map_err(|_| crate::Error::BufferOverflow)?;
        }
        // Biases start at zero; hidden state starts at rest.
        for _ in 0..config.hidden_dim {
            bias_gate.push(0).map_err(|_| crate::Error::BufferOverflow)?;
            bias_update.push(0).map_err(|_| crate::Error::BufferOverflow)?;
            hidden.push(0).map_err(|_| crate::Error::BufferOverflow)?;
        }
        Ok(Self {
            config,
            w_gate,
            u_gate,
            w_update,
            u_update,
            bias_gate,
            bias_update,
            w_output,
            hidden,
        })
    }

    /// Reset the hidden state to zero (start of a new sequence).
    pub fn reset(&mut self) {
        for h in self.hidden.iter_mut() {
            *h = 0;
        }
    }

    /// Fixed-point sigmoid approximation.
    /// Piecewise linear: input clamped at +/-512, output in [0, 256]
    /// where 256 represents 1.0.
    #[inline]
    fn sigmoid_fp(x: i32) -> i32 {
        if x < -512 { 0 }
        else if x > 512 { 256 }
        else { (x + 512) >> 2 }
    }

    /// Fixed-point tanh approximation.
    /// Piecewise linear: input clamped at +/-512, output in [-256, 256]
    /// where 256 represents 1.0.
    #[inline]
    fn tanh_fp(x: i32) -> i32 {
        if x < -512 { -256 }
        else if x > 512 { 256 }
        else { x >> 1 }
    }

    /// Matrix-vector multiply (INT8 weights, i32 accumulator), with the
    /// result scaled back down by 256. Rows beyond MAX_ROUTER_HIDDEN are
    /// silently dropped by the bounded push.
    fn matmul(&self, weights: &[i8], input: &[i32], rows: usize, cols: usize) -> HVec<i32, MAX_ROUTER_HIDDEN> {
        let mut output = HVec::new();
        for r in 0..rows {
            let mut sum: i32 = 0;
            for c in 0..cols {
                if c < input.len() {
                    sum += weights[r * cols + c] as i32 * input[c];
                }
            }
            let _ = output.push(sum >> 8); // Scale down
        }
        output
    }

    /// One step of FastGRNN computation:
    ///
    /// h_new = (1 - z) ⊙ h + z ⊙ tanh(W_u*x + U_u*h + b_u)
    /// where z = sigmoid(W_g*x + U_g*h + b_g)
    pub fn step(&mut self, input: &[i8]) -> crate::Result<()> {
        // Convert the INT8 input to i32, scaled up by 16.
        let input_i32: HVec<i32, MAX_ROUTER_INPUT> = input.iter()
            .take(self.config.input_dim)
            .map(|&x| x as i32 * 16)
            .collect();
        // Gate: z = sigmoid(W_g * x + U_g * h + b_g), scaled by zeta.
        let wx_gate = self.matmul(&self.w_gate, &input_i32, self.config.hidden_dim, self.config.input_dim);
        let uh_gate = self.matmul(&self.u_gate, &self.hidden, self.config.hidden_dim, self.config.hidden_dim);
        let mut gate = HVec::<i32, MAX_ROUTER_HIDDEN>::new();
        for i in 0..self.config.hidden_dim {
            let wx = wx_gate.get(i).copied().unwrap_or(0);
            let uh = uh_gate.get(i).copied().unwrap_or(0);
            let b = self.bias_gate.get(i).copied().unwrap_or(0) as i32 * 16;
            let z = Self::sigmoid_fp((wx + uh + b) * self.config.zeta as i32 / 16);
            let _ = gate.push(z);
        }
        // Candidate update: u = tanh(W_u * x + U_u * h + b_u), scaled by nu.
        let wx_update = self.matmul(&self.w_update, &input_i32, self.config.hidden_dim, self.config.input_dim);
        let uh_update = self.matmul(&self.u_update, &self.hidden, self.config.hidden_dim, self.config.hidden_dim);
        // Blend: h = (1 - z) * h + z * u, with 256 representing 1.0.
        for i in 0..self.config.hidden_dim {
            let wx = wx_update.get(i).copied().unwrap_or(0);
            let uh = uh_update.get(i).copied().unwrap_or(0);
            let b = self.bias_update.get(i).copied().unwrap_or(0) as i32 * 16;
            let u = Self::tanh_fp((wx + uh + b) * self.config.nu as i32 / 16);
            let z = gate.get(i).copied().unwrap_or(128);
            let h = self.hidden.get(i).copied().unwrap_or(0);
            // h_new = (256 - z) * h / 256 + z * u / 256
            let h_new = ((256 - z) * h + z * u) >> 8;
            self.hidden[i] = h_new;
        }
        Ok(())
    }

    /// Get the routing decision (which chip to use) via argmax over the
    /// output projection of the hidden state.
    pub fn route(&self) -> ChipId {
        // Clamp to the fixed score-array capacity (8) so a misconfigured
        // `num_chips > 8` cannot index out of bounds and panic.
        let num_chips = self.config.num_chips.min(8);
        // Output projection: scores = W_o * hidden.
        let mut scores = [0i32; 8];
        for chip in 0..num_chips {
            let mut sum: i32 = 0;
            for h in 0..self.config.hidden_dim {
                let w_idx = chip * self.config.hidden_dim + h;
                let w = self.w_output.get(w_idx).copied().unwrap_or(0) as i32;
                let hidden = self.hidden.get(h).copied().unwrap_or(0);
                sum += w * hidden;
            }
            scores[chip] = sum;
        }
        // Argmax over the valid scores.
        let mut best_chip = 0;
        let mut best_score = scores[0];
        for (i, &score) in scores[..num_chips].iter().enumerate() {
            if score > best_score {
                best_score = score;
                best_chip = i;
            }
        }
        ChipId(best_chip as u8)
    }

    /// Get routing probabilities (linear softmax-like normalization).
    /// Returned values are in [0, 255] and sum approximately to 255.
    pub fn route_probs(&self) -> HVec<u8, 8> {
        let mut probs = HVec::new();
        // Clamp to the fixed array/output capacity of 8 (same guard as `route`).
        let num_chips = self.config.num_chips.min(8);
        let mut scores = [0i32; 8];
        let mut max_score = i32::MIN;
        // Compute raw scores and track the maximum for shifting.
        for chip in 0..num_chips {
            let mut sum: i32 = 0;
            for h in 0..self.config.hidden_dim {
                let w_idx = chip * self.config.hidden_dim + h;
                let w = self.w_output.get(w_idx).copied().unwrap_or(0) as i32;
                let hidden = self.hidden.get(h).copied().unwrap_or(0);
                sum += w * hidden;
            }
            scores[chip] = sum;
            if sum > max_score {
                max_score = sum;
            }
        }
        // Simple softmax approximation: shift so the max maps to 256,
        // floor at 1 so every chip keeps nonzero mass, then normalize.
        let mut total: i32 = 0;
        for chip in 0..num_chips {
            let exp_score = (scores[chip] - max_score + 256).max(1);
            scores[chip] = exp_score;
            total += exp_score;
        }
        for chip in 0..num_chips {
            let prob = (scores[chip] * 255 / total.max(1)) as u8;
            let _ = probs.push(prob);
        }
        probs
    }

    /// Total memory footprint in bytes (INT8 weights/biases plus the
    /// 4-byte-per-unit hidden state).
    pub fn memory_size(&self) -> usize {
        self.w_gate.len() + self.u_gate.len() +
        self.w_update.len() + self.u_update.len() +
        self.w_output.len() +
        self.bias_gate.len() + self.bias_update.len() +
        self.hidden.len() * 4
    }
}
/// Input features summarizing the current state for a routing decision.
pub struct RoutingFeatures {
    /// Token embedding summary (mean).
    pub embed_mean: i8,
    /// Token embedding variance proxy.
    pub embed_var: i8,
    /// Current sequence position (normalized into i8 range).
    pub position: i8,
    /// Current load on each of the 5 chips (0-127).
    pub chip_loads: [i8; 5],
}
impl RoutingFeatures {
    /// Flatten the features into the fixed 8-wide router input vector.
    ///
    /// Layout: [embed_mean, embed_var, position, load_0 .. load_4].
    pub fn to_input(&self) -> [i8; 8] {
        let mut input = [0i8; 8];
        input[0] = self.embed_mean;
        input[1] = self.embed_var;
        input[2] = self.position;
        input[3..8].copy_from_slice(&self.chip_loads);
        input
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One step on a fresh router must produce an in-range chip selection.
    #[test]
    fn test_micro_fastgrnn() {
        let config = MicroGRNNConfig::default();
        let mut router = MicroFastGRNN::new(config, 42).unwrap();
        // Test step
        let input = [10i8, 20, 30, 40, 50, 60, 70, 80];
        router.step(&input).unwrap();
        // Should produce valid routing
        let chip = router.route();
        assert!(chip.0 < 5);
        println!("Memory: {} bytes", router.memory_size());
    }

    /// Routing probabilities must cover all chips and sum roughly to 255
    /// (integer rounding loses a little mass either way).
    #[test]
    fn test_routing_probs() {
        let config = MicroGRNNConfig::default();
        let mut router = MicroFastGRNN::new(config, 42).unwrap();
        let input = [10i8; 8];
        router.step(&input).unwrap();
        let probs = router.route_probs();
        assert_eq!(probs.len(), 5);
        // Sum should be approximately 255
        let sum: i32 = probs.iter().map(|&p| p as i32).sum();
        assert!(sum > 200 && sum < 280);
    }
}

View File

@@ -0,0 +1,705 @@
//! Massive Scale Federation - 100s to Millions of Chips
//!
//! Hierarchical coordination for extreme-scale distributed inference.
//!
//! # Topology Options
//!
//! ```text
//! Flat (≤16 chips): Hierarchical Tree (≤10K): Hypercube (≤1M):
//! ○─○─○─○─○ ┌───[Root]───┐ ○═══○
//! │ │ │ │ │ │ │ │ ╱│ │╲
//! └─┴─┴─┴─┘ [L1] [L1] [L1] ○─┼───┼─○
//! │││ │││ │││ │ ○═══○ │
//! chips chips chips ○═══════○
//! ```
//!
//! # Scaling Laws
//!
//! - **Pipeline**: O(n) throughput, O(1) latency per stage
//! - **Tree**: O(log n) coordination, O(n) compute
//! - **Hypercube**: O(log n) hops, O(n) total bandwidth
//! - **Torus**: O(√n) diameter, excellent locality
use heapless::Vec as HVec;
use super::protocol::ChipId;
/// Maximum depth for hierarchical topologies
pub const MAX_TREE_DEPTH: usize = 20; // 2^20 = 1M chips
/// Maximum children per node in tree
pub const MAX_CHILDREN: usize = 16;
/// Maximum nodes at any level
pub const MAX_LEVEL_NODES: usize = 64;
/// Large-scale topology types.
///
/// Each variant trades off diameter (hop latency), bisection bandwidth,
/// and wiring practicality; see the module docs for the scaling laws.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MassiveTopology {
    /// Flat mesh - up to ~16 chips
    FlatMesh { size: usize },
    /// Binary tree - scales to millions
    BinaryTree { depth: usize },
    /// K-ary tree with configurable fanout
    KaryTree { depth: usize, fanout: usize },
    /// Hypercube - O(log n) diameter
    Hypercube { dimensions: usize },
    /// 2D Torus - good for spatial locality
    Torus2D { width: usize, height: usize },
    /// 3D Torus - even better scaling
    Torus3D { x: usize, y: usize, z: usize },
    /// Butterfly network - FFT-like communication
    Butterfly { stages: usize },
    /// Hierarchical pipeline - practical for real deployments
    HierarchicalPipeline {
        /// Number of clusters
        clusters: usize,
        /// Chips within each cluster
        chips_per_cluster: usize,
    },
}
impl MassiveTopology {
    /// Total number of chips in the topology.
    ///
    /// `KaryTree` uses the geometric-series node count and saturates at
    /// `usize::MAX` instead of panicking when `fanout^(depth+1)` overflows
    /// (reachable on 32-bit targets with deep, wide trees).
    pub fn total_chips(&self) -> usize {
        match *self {
            Self::FlatMesh { size } => size,
            Self::BinaryTree { depth } => (1 << depth) - 1,
            Self::KaryTree { depth, fanout } => {
                // Full k-ary tree: (k^(d+1) - 1) / (k - 1)
                if fanout == 1 {
                    depth + 1
                } else {
                    match fanout.checked_pow(depth as u32 + 1) {
                        Some(p) => (p - 1) / (fanout - 1),
                        // Saturate rather than panic on overflow.
                        None => usize::MAX,
                    }
                }
            }
            Self::Hypercube { dimensions } => 1 << dimensions,
            Self::Torus2D { width, height } => width * height,
            Self::Torus3D { x, y, z } => x * y * z,
            Self::Butterfly { stages } => stages * (1 << stages),
            Self::HierarchicalPipeline { clusters, chips_per_cluster } => {
                clusters * chips_per_cluster
            }
        }
    }

    /// Network diameter (max hops between any two nodes).
    ///
    /// Uses saturating subtraction so a degenerate zero-sized mesh yields
    /// 0 instead of underflowing and panicking.
    pub fn diameter(&self) -> usize {
        match *self {
            Self::FlatMesh { size } => size.saturating_sub(1),
            Self::BinaryTree { depth } => 2 * depth,
            Self::KaryTree { depth, .. } => 2 * depth,
            Self::Hypercube { dimensions } => dimensions,
            Self::Torus2D { width, height } => width / 2 + height / 2,
            Self::Torus3D { x, y, z } => x / 2 + y / 2 + z / 2,
            Self::Butterfly { stages } => stages,
            Self::HierarchicalPipeline { chips_per_cluster, .. } => {
                chips_per_cluster + 2 // Within cluster + up + down
            }
        }
    }

    /// Bisection bandwidth (edges crossing the middle cut).
    ///
    /// `(1 << n) / 2` is `2^(n-1)` without the shift-by-(n-1) underflow
    /// for zero-dimension hypercubes / zero-stage butterflies.
    pub fn bisection_bandwidth(&self) -> usize {
        match *self {
            Self::FlatMesh { .. } => 1,
            Self::BinaryTree { .. } => 1, // Root is the bottleneck
            Self::KaryTree { fanout, .. } => fanout,
            Self::Hypercube { dimensions } => (1 << dimensions) / 2,
            Self::Torus2D { width, height } => 2 * width.min(height),
            Self::Torus3D { x, y, z } => 2 * x.min(y).min(z) * x.min(y).min(z),
            Self::Butterfly { stages } => (1 << stages) / 2,
            Self::HierarchicalPipeline { clusters, .. } => clusters,
        }
    }

    /// Recommended topology for a given chip count: flat mesh for small
    /// clusters, hierarchical pipeline for medium scale, hypercube for
    /// large scale, and a 3D torus beyond a million chips.
    pub fn recommended(chip_count: usize) -> Self {
        match chip_count {
            0..=16 => Self::FlatMesh { size: chip_count },
            17..=256 => Self::HierarchicalPipeline {
                clusters: (chip_count as f64).sqrt().ceil() as usize,
                chips_per_cluster: (chip_count as f64).sqrt().ceil() as usize,
            },
            257..=10_000 => {
                // Use a hierarchical pipeline for medium scale; round the
                // per-cluster count up so all chips are covered.
                let clusters = (chip_count as f64).sqrt().ceil() as usize;
                let per_cluster = (chip_count + clusters - 1) / clusters;
                Self::HierarchicalPipeline {
                    clusters,
                    chips_per_cluster: per_cluster,
                }
            }
            10_001..=1_000_000 => {
                // Hypercube for large scale.
                let dims = (chip_count as f64).log2().ceil() as usize;
                Self::Hypercube { dimensions: dims }
            }
            _ => {
                // Millions+ : 3D torus with roughly cubic sides.
                let side = (chip_count as f64).cbrt().ceil() as usize;
                Self::Torus3D { x: side, y: side, z: side }
            }
        }
    }
}
/// Scaling configuration for massive clusters, consumed by
/// `MassiveScaleSimulator` to project performance.
#[derive(Debug, Clone)]
pub struct MassiveScaleConfig {
    /// Topology type.
    pub topology: MassiveTopology,
    /// Total number of model layers.
    pub total_layers: usize,
    /// Embedding dimension.
    pub embed_dim: usize,
    /// Communication latency per hop (microseconds).
    pub hop_latency_us: usize,
    /// Bandwidth per link (bytes/sec).
    pub link_bandwidth: usize,
    /// Computation time per layer (microseconds).
    pub layer_compute_us: usize,
    /// Enable speculative execution.
    pub speculative: bool,
    /// Speculation depth (tokens to draft per round).
    pub spec_depth: usize,
    /// Enable gradient checkpointing to trade compute for memory.
    pub gradient_checkpointing: bool,
    /// Fault tolerance level (0=none, 1=retry, 2=redundancy).
    pub fault_tolerance: u8,
}
impl Default for MassiveScaleConfig {
fn default() -> Self {
Self {
topology: MassiveTopology::HierarchicalPipeline {
clusters: 10,
chips_per_cluster: 10,
},
total_layers: 32,
embed_dim: 64,
hop_latency_us: 10, // SPI latency
link_bandwidth: 10_000_000, // 10 MB/s
layer_compute_us: 4000, // 4ms per layer on ESP32
speculative: true,
spec_depth: 4,
gradient_checkpointing: false,
fault_tolerance: 1,
}
}
}
/// Analytical performance projection for a massive-scale configuration,
/// produced by `MassiveScaleSimulator::project`.
#[derive(Debug, Clone)]
pub struct ScaleProjection {
    /// Total chips in the projected topology.
    pub total_chips: usize,
    /// Projected throughput in tokens/sec.
    pub throughput_tokens_sec: f64,
    /// Projected latency per token in milliseconds.
    pub latency_ms: f64,
    /// Projected memory footprint per chip in KB.
    pub memory_per_chip_kb: f64,
    /// Total model parameters supportable across all chips.
    pub max_parameters: usize,
    /// Efficiency relative to ideal linear scaling (0..1).
    pub efficiency: f64,
    /// Communication overhead as a percentage of total time.
    pub comm_overhead_pct: f64,
    /// Estimated total power draw in watts.
    pub power_watts: f64,
    /// Estimated hardware cost in USD.
    pub cost_usd: f64,
}
/// Analytical simulator that projects throughput/latency/cost for a
/// `MassiveScaleConfig` without running any actual workload.
pub struct MassiveScaleSimulator {
    /// Configuration being analyzed (topology, timings, model size).
    config: MassiveScaleConfig,
}
impl MassiveScaleSimulator {
    /// Create a simulator over the given configuration.
    pub fn new(config: MassiveScaleConfig) -> Self {
        Self { config }
    }

    /// Project performance for the current configuration.
    ///
    /// This is a closed-form model (no simulation loop): per-chip compute,
    /// communication along the topology diameter, pipeline bubble overhead,
    /// and an optional speculative-decoding multiplier.
    pub fn project(&self) -> ScaleProjection {
        let chips = self.config.topology.total_chips();
        let diameter = self.config.topology.diameter();
        // Compute distribution across chips.
        let layers_per_chip = (self.config.total_layers as f64 / chips as f64).max(0.1);
        let compute_per_chip_us = layers_per_chip * self.config.layer_compute_us as f64;
        // Communication cost: one activation transfer per hop of the diameter.
        let activation_size = self.config.embed_dim * 4; // INT8 with some overhead
        let comm_time_us = (activation_size as f64 / self.config.link_bandwidth as f64)
            * 1_000_000.0
            * diameter as f64;
        // Pipeline efficiency; clamp to >= 1 stage so a degenerate zero-chip
        // topology cannot underflow `pipeline_stages - 1` or divide by zero.
        let pipeline_stages = chips.min(self.config.total_layers).max(1);
        let bubble_overhead = (pipeline_stages - 1) as f64 / pipeline_stages as f64;
        // Speculative decoding multiplier (assumes ~70% draft acceptance).
        let spec_multiplier = if self.config.speculative {
            1.0 + (self.config.spec_depth as f64 - 1.0) * 0.7
        } else {
            1.0
        };
        // Throughput calculation.
        let base_throughput = 1_000_000.0 / compute_per_chip_us.max(1.0);
        let comm_factor = 1.0 / (1.0 + comm_time_us / compute_per_chip_us.max(1.0));
        let efficiency = (1.0 - bubble_overhead * 0.15) * comm_factor;
        let throughput = base_throughput * pipeline_stages as f64 * efficiency * spec_multiplier;
        // Latency: traverse every pipeline stage once, plus communication.
        let latency_us = compute_per_chip_us * pipeline_stages as f64 + comm_time_us;
        let latency_ms = latency_us / 1000.0;
        // Per-chip memory shrinks as the model is sharded.
        let base_memory_kb = 119.0; // Single chip baseline
        let memory_per_chip = base_memory_kb / (chips as f64).sqrt().max(1.0);
        // Max parameters: 70% of memory assumed available for weights.
        let params_per_chip = (memory_per_chip * 1024.0 * 0.7) as usize;
        let max_parameters = params_per_chip * chips;
        // Communication overhead as a fraction of total time.
        let comm_overhead = comm_time_us / (compute_per_chip_us + comm_time_us) * 100.0;
        // Power and cost estimates.
        let power_per_chip = 0.5; // 500mW per ESP32
        let cost_per_chip = 4.0; // $4 per ESP32
        ScaleProjection {
            total_chips: chips,
            throughput_tokens_sec: throughput,
            latency_ms,
            memory_per_chip_kb: memory_per_chip,
            max_parameters,
            efficiency,
            comm_overhead_pct: comm_overhead,
            power_watts: power_per_chip * chips as f64,
            cost_usd: cost_per_chip * chips as f64,
        }
    }

    /// Run a scaling study: project the recommended topology at each of the
    /// requested chip counts (at most 32 results are kept).
    pub fn scaling_study(&self, chip_counts: &[usize]) -> HVec<ScaleProjection, 32> {
        let mut results = HVec::new();
        for &count in chip_counts {
            let topology = MassiveTopology::recommended(count);
            let config = MassiveScaleConfig {
                topology,
                ..self.config.clone()
            };
            let sim = MassiveScaleSimulator::new(config);
            let _ = results.push(sim.project());
        }
        results
    }

    /// Find the most chip-efficient configuration that reaches the target
    /// throughput, scanning power-of-two cluster sizes and three topology
    /// families. Falls back to the current config when nothing qualifies.
    pub fn optimize_for_throughput(&self, target_tokens_sec: f64) -> MassiveScaleConfig {
        let mut best_config = self.config.clone();
        let mut best_efficiency = 0.0;
        // Try chip counts 4 .. 2^20.
        for power in 2..=20 {
            let chips = 1 << power;
            for &topology in &[
                MassiveTopology::KaryTree { depth: power, fanout: 4 },
                MassiveTopology::Hypercube { dimensions: power },
                MassiveTopology::HierarchicalPipeline {
                    clusters: 1 << (power / 2),
                    chips_per_cluster: 1 << (power - power / 2),
                },
            ] {
                if topology.total_chips() < 4 { continue; }
                let config = MassiveScaleConfig {
                    topology,
                    ..self.config.clone()
                };
                let sim = MassiveScaleSimulator::new(config.clone());
                let proj = sim.project();
                if proj.throughput_tokens_sec >= target_tokens_sec {
                    // Efficiency = throughput per chip; maximize it.
                    let efficiency = proj.throughput_tokens_sec / (proj.total_chips as f64);
                    if efficiency > best_efficiency {
                        best_efficiency = efficiency;
                        best_config = config;
                    }
                }
            }
        }
        best_config
    }
}
/// Per-node coordinator for massive scale: knows this node's parent,
/// children, and siblings in the chosen topology and tracks local progress.
pub struct DistributedCoordinator {
    /// This node's ID.
    node_id: u32,
    /// Parent node (None if this node is the root).
    parent: Option<u32>,
    /// Child nodes (broadcast targets).
    children: HVec<u32, MAX_CHILDREN>,
    /// Sibling nodes at the same level.
    siblings: HVec<u32, MAX_CHILDREN>,
    /// This node's level in the hierarchy.
    level: u8,
    /// Total number of levels in the hierarchy.
    total_levels: u8,
    /// Local runtime state (tokens, load, heartbeat, errors).
    local_state: NodeState,
}
/// Runtime state of one node in the distributed system; also used as the
/// unit of aggregation when reducing statistics up the tree.
#[derive(Debug, Clone, Default)]
pub struct NodeState {
    /// Tokens processed by this node (or subtree, after aggregation).
    pub tokens_processed: u64,
    /// Current load (0-255).
    pub load: u8,
    /// Last heartbeat counter value (ticks; wraps intentionally).
    pub last_heartbeat: u32,
    /// Whether the node is active.
    pub active: bool,
    /// Current sequence position being processed.
    pub seq_position: u32,
    /// Error count (saturating when aggregated).
    pub errors: u16,
}
impl DistributedCoordinator {
    /// Create a coordinator for `node_id`'s position in the given topology.
    pub fn new(node_id: u32, total_nodes: usize, topology: MassiveTopology) -> Self {
        let (parent, children, siblings, level, total_levels) =
            Self::compute_neighbors(node_id, total_nodes, topology);
        Self {
            node_id,
            parent,
            children,
            siblings,
            level,
            total_levels,
            local_state: NodeState { active: true, ..Default::default() },
        }
    }

    /// Derive (parent, children, siblings, level, total_levels) for a node.
    ///
    /// Topologies without an explicit arm fall back to a linear chain.
    fn compute_neighbors(
        node_id: u32,
        total_nodes: usize,
        topology: MassiveTopology
    ) -> (Option<u32>, HVec<u32, MAX_CHILDREN>, HVec<u32, MAX_CHILDREN>, u8, u8) {
        let mut children = HVec::new();
        let mut siblings = HVec::new();
        match topology {
            MassiveTopology::BinaryTree { depth } |
            MassiveTopology::KaryTree { depth, fanout: 2 } => {
                // Implicit array layout: children of node i are 2i+1 and 2i+2.
                let level = (node_id + 1).ilog2() as u8;
                let parent = if node_id == 0 { None } else { Some((node_id - 1) / 2) };
                let left = 2 * node_id + 1;
                let right = 2 * node_id + 2;
                if (left as usize) < total_nodes {
                    let _ = children.push(left);
                }
                if (right as usize) < total_nodes {
                    let _ = children.push(right);
                }
                // The single sibling shares this node's parent.
                if node_id > 0 {
                    let sib = if node_id % 2 == 1 { node_id + 1 } else { node_id - 1 };
                    if (sib as usize) < total_nodes {
                        let _ = siblings.push(sib);
                    }
                }
                (parent, children, siblings, level, depth as u8)
            }
            MassiveTopology::Hypercube { dimensions } => {
                // Hypercube neighbors differ in exactly one address bit.
                // There is no tree parent; all neighbors are siblings.
                let level = node_id.count_ones() as u8;
                for d in 0..dimensions {
                    let neighbor = node_id ^ (1 << d);
                    if (neighbor as usize) < total_nodes {
                        let _ = siblings.push(neighbor);
                    }
                }
                (None, children, siblings, level, dimensions as u8)
            }
            MassiveTopology::HierarchicalPipeline { clusters, chips_per_cluster } => {
                let cluster_id = node_id as usize / chips_per_cluster;
                let local_id = node_id as usize % chips_per_cluster;
                let level = local_id as u8;
                // Parent is the previous pipeline stage, crossing cluster
                // boundaries to the last node of the previous cluster.
                let parent = if local_id > 0 {
                    Some(node_id - 1)
                } else if cluster_id > 0 {
                    Some((cluster_id * chips_per_cluster - 1) as u32)
                } else {
                    None
                };
                // Child is the next pipeline stage (possibly in the next cluster).
                if local_id + 1 < chips_per_cluster {
                    let _ = children.push(node_id + 1);
                } else if cluster_id + 1 < clusters {
                    let _ = children.push(((cluster_id + 1) * chips_per_cluster) as u32);
                }
                (parent, children, siblings, level, chips_per_cluster as u8)
            }
            _ => {
                // Default: linear chain 0 -> 1 -> ... -> n-1.
                let parent = if node_id > 0 { Some(node_id - 1) } else { None };
                if ((node_id + 1) as usize) < total_nodes {
                    let _ = children.push(node_id + 1);
                }
                (parent, children, siblings, node_id as u8, total_nodes as u8)
            }
        }
    }

    /// True if this node has no parent (tree root / pipeline head).
    pub fn is_root(&self) -> bool {
        self.parent.is_none()
    }

    /// True if this node has no children (leaf / pipeline tail).
    pub fn is_leaf(&self) -> bool {
        self.children.is_empty()
    }

    /// Nodes to forward a broadcast to (this node's children).
    pub fn broadcast_targets(&self) -> &[u32] {
        &self.children
    }

    /// Node to send aggregated results to (this node's parent), if any.
    pub fn reduce_target(&self) -> Option<u32> {
        self.parent
    }

    /// Record local progress; the heartbeat counter wraps intentionally.
    pub fn update_state(&mut self, tokens: u64, load: u8) {
        self.local_state.tokens_processed = tokens;
        self.local_state.load = load;
        self.local_state.last_heartbeat = self.local_state.last_heartbeat.wrapping_add(1);
    }

    /// Aggregate this node's state with its children's (reduce-to-root).
    ///
    /// Token and error counts use saturating adds so a hot subtree cannot
    /// overflow (and panic in debug builds); load is averaged over children.
    pub fn aggregate_stats(&self, child_stats: &[NodeState]) -> NodeState {
        let mut agg = self.local_state.clone();
        for child in child_stats {
            agg.tokens_processed = agg.tokens_processed.saturating_add(child.tokens_processed);
            // NOTE(review): `len() as u8` truncates above 255 children; the
            // `.max(1)` keeps the division safe, and MAX_CHILDREN is 16 so
            // this cannot trigger in practice.
            agg.load = agg.load.saturating_add(child.load / (child_stats.len() as u8).max(1));
            agg.errors = agg.errors.saturating_add(child.errors);
        }
        agg
    }
}
/// Gossip protocol for state synchronization at massive scale
///
/// Keeps a bounded sample of peer states instead of a full cluster view,
/// so memory use stays constant regardless of cluster size.
pub struct GossipProtocol {
    /// Known node states (sampled; capacity-bounded at 64 entries)
    known_states: HVec<(u32, NodeState), 64>,
    /// Fanout for gossip (peers contacted per round; capped at 8 when selecting)
    fanout: usize,
    /// Round number
    /// NOTE(review): initialized to 0 but never advanced by any method shown
    /// here — confirm whether round tracking is done elsewhere.
    round: u32,
}
impl GossipProtocol {
    /// Create a gossip instance that contacts up to `fanout` peers per round.
    pub fn new(fanout: usize) -> Self {
        Self {
            known_states: HVec::new(),
            fanout,
            round: 0,
        }
    }
    /// Select pseudo-random gossip targets, excluding this node itself and
    /// avoiding duplicates.
    ///
    /// Deterministic for a given `(seed, my_id)` pair (simple LCG), so a
    /// round can be reproduced. Returns an empty list for an empty cluster
    /// instead of panicking.
    pub fn select_gossip_targets(&self, my_id: u32, total_nodes: usize, seed: u32) -> HVec<u32, 8> {
        let mut targets = HVec::new();
        // Guard: `rng % 0` panics; an empty cluster has nobody to gossip to.
        if total_nodes == 0 {
            return targets;
        }
        let mut rng = seed.wrapping_mul(1103515245).wrapping_add(my_id);
        for _ in 0..self.fanout.min(8) {
            rng = rng.wrapping_mul(1103515245).wrapping_add(12345);
            let target = rng % total_nodes as u32;
            if target != my_id && !targets.contains(&target) {
                let _ = targets.push(target);
            }
        }
        targets
    }
    /// Merge received state: update an existing entry in place, otherwise
    /// insert; when the table is full the slot at index 0 is overwritten.
    pub fn merge_state(&mut self, node_id: u32, state: NodeState) {
        // Update in place if this node is already tracked.
        for (id, s) in self.known_states.iter_mut() {
            if *id == node_id {
                *s = state;
                return;
            }
        }
        // Insert new
        if self.known_states.len() < 64 {
            let _ = self.known_states.push((node_id, state));
        } else {
            // NOTE(review): always evicts slot 0 regardless of recency —
            // this is constant-slot replacement, not a true LRU.
            self.known_states[0] = (node_id, state);
        }
    }
    /// Estimated cluster health: fraction of sampled nodes marked active.
    /// Optimistically returns 1.0 when no states have been observed yet.
    pub fn cluster_health(&self) -> f32 {
        if self.known_states.is_empty() {
            return 1.0;
        }
        let active = self.known_states.iter().filter(|(_, s)| s.active).count();
        active as f32 / self.known_states.len() as f32
    }
}
/// Fault tolerance manager
///
/// Tracks failed nodes and static primary -> backup assignments. Both
/// tables are capacity-bounded, so entries beyond the limits are dropped.
pub struct FaultTolerance {
    /// Redundancy level (1 = no redundancy, 2 = pairs, 3 = triples)
    redundancy: u8,
    /// Failed node IDs (at most 64 tracked)
    failed_nodes: HVec<u32, 64>,
    /// Backup assignments (primary -> backup, at most 32 pairs)
    backups: HVec<(u32, u32), 32>,
}
impl FaultTolerance {
    /// Create a fault-tolerance manager; a redundancy below 1 is clamped to 1.
    pub fn new(redundancy: u8) -> Self {
        Self {
            redundancy: redundancy.max(1),
            failed_nodes: HVec::new(),
            backups: HVec::new(),
        }
    }
    /// Mark node as failed (idempotent). Failures beyond the fixed
    /// 64-entry capacity are silently dropped.
    pub fn mark_failed(&mut self, node_id: u32) {
        if !self.failed_nodes.contains(&node_id) {
            let _ = self.failed_nodes.push(node_id);
        }
    }
    /// Get backup for failed node, if one was assigned.
    pub fn get_backup(&self, failed_id: u32) -> Option<u32> {
        self.backups.iter()
            .find(|(primary, _)| *primary == failed_id)
            .map(|(_, backup)| *backup)
    }
    /// Assign each node a backup half the cluster away. No-op unless
    /// redundancy >= 2; only the first 32 assignments fit.
    pub fn assign_backups(&mut self, total_nodes: usize) {
        if self.redundancy < 2 { return; }
        for i in 0..total_nodes {
            let backup = (i + total_nodes / 2) % total_nodes;
            // A node cannot be its own backup (happens when total_nodes == 1).
            if backup == i {
                continue;
            }
            if self.backups.len() < 32 {
                let _ = self.backups.push((i as u32, backup as u32));
            }
        }
    }
    /// Check if node is available (not marked failed).
    pub fn is_available(&self, node_id: u32) -> bool {
        !self.failed_nodes.contains(&node_id)
    }
    /// Fraction of nodes marked failed. Returns 0.0 for an empty cluster
    /// (the previous `0.0 / 0.0` produced NaN).
    pub fn failure_rate(&self, total_nodes: usize) -> f32 {
        if total_nodes == 0 {
            return 0.0;
        }
        self.failed_nodes.len() as f32 / total_nodes as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Closed-form chip counts for each topology shape.
    #[test]
    fn test_topology_sizing() {
        assert_eq!(MassiveTopology::BinaryTree { depth: 10 }.total_chips(), 1023);
        assert_eq!(MassiveTopology::Hypercube { dimensions: 10 }.total_chips(), 1024);
        assert_eq!(MassiveTopology::Torus2D { width: 100, height: 100 }.total_chips(), 10_000);
    }
    // A 10x10 hierarchical pipeline should project useful throughput
    // at better than 50% scaling efficiency.
    #[test]
    fn test_scaling_projection() {
        let config = MassiveScaleConfig {
            topology: MassiveTopology::HierarchicalPipeline {
                clusters: 10,
                chips_per_cluster: 10,
            },
            ..Default::default()
        };
        let sim = MassiveScaleSimulator::new(config);
        let proj = sim.project();
        assert_eq!(proj.total_chips, 100);
        assert!(proj.throughput_tokens_sec > 1000.0);
        assert!(proj.efficiency > 0.5);
        println!("100 chips: {:.0} tok/s, {:.1}% efficiency",
            proj.throughput_tokens_sec, proj.efficiency * 100.0);
    }
    // Smoke test: the recommended topology projects without panicking
    // across six orders of magnitude of cluster size.
    #[test]
    fn test_massive_scale() {
        let chip_counts = [5, 100, 1000, 10_000, 100_000, 1_000_000];
        for &count in &chip_counts {
            let topology = MassiveTopology::recommended(count);
            let config = MassiveScaleConfig {
                topology,
                ..Default::default()
            };
            let sim = MassiveScaleSimulator::new(config);
            let proj = sim.project();
            println!("{:>10} chips: {:>12.0} tok/s, {:>6.1}% eff, ${:.0}",
                count, proj.throughput_tokens_sec, proj.efficiency * 100.0, proj.cost_usd);
        }
    }
    // Node 5 in a depth-7 binary tree is an interior node, never the root.
    #[test]
    fn test_distributed_coordinator() {
        let coord = DistributedCoordinator::new(
            5,
            100,
            MassiveTopology::BinaryTree { depth: 7 }
        );
        assert!(!coord.is_root());
        println!("Node 5: parent={:?}, children={:?}", coord.parent, coord.children);
    }
    // Target selection must exclude self; a single active sample yields
    // 100% estimated health.
    #[test]
    fn test_gossip_protocol() {
        let mut gossip = GossipProtocol::new(3);
        let targets = gossip.select_gossip_targets(5, 1000, 42);
        assert!(!targets.is_empty());
        assert!(!targets.contains(&5)); // Shouldn't include self
        gossip.merge_state(10, NodeState { active: true, ..Default::default() });
        assert_eq!(gossip.cluster_health(), 1.0);
    }
}

View File

@@ -0,0 +1,420 @@
//! Medium Scale Federation - 100 to 500 Chip Clusters
//!
//! This is the "sweet spot" for ESP32 federation:
//! - High efficiency (40-70%)
//! - Practical throughput (50K-100K tokens/sec)
//! - Manageable communication overhead
//! - Affordable cost ($400-$2,000)
//!
//! # Why 100-500 Chips?
//!
//! ```text
//! Performance vs Chip Count:
//!
//! 100K ┤ ┌─────────────────────── Communication-bound
//! │ ____/│ Sweet Spot
//! 80K ┤ / │ 100-500 chips
//! │ / │
//! 60K ┤ / │ • 40-70% efficiency
//! │ │ │ • Low communication overhead
//! 40K ┤ │ │ • Best $/performance
//! ││ └─────────────────────────────────
//! 20K ┤│
//! │
//! 0 ┼──────────────────────────────────────────────────
//! 5 50 100 200 500 1K 5K 10K 100K 1M
//! ▲ ▲
//! │ │
//! Good start Best value
//! ```
//!
//! # Topology Recommendations
//!
//! | Chips | Best Topology | Clusters × Chips | Efficiency |
//! |-------|---------------|------------------|------------|
//! | 100 | 10×10 Grid | 10 × 10 | ~70% |
//! | 144 | 12×12 Grid | 12 × 12 | ~65% |
//! | 256 | 16×16 Grid | 16 × 16 | ~55% |
//! | 400 | 20×20 Grid | 20 × 20 | ~45% |
//! | 500 | 25×20 Grid | 25 × 20 | ~40% |
use super::massive_scale::{MassiveTopology, MassiveScaleConfig, MassiveScaleSimulator, ScaleProjection};
use heapless::Vec as HVec;
/// Medium-scale cluster sizes (sweet spot)
/// Inclusive lower bound; smaller requests are clamped up by `optimal_for`.
pub const MEDIUM_SCALE_MIN: usize = 100;
/// Inclusive upper bound; larger requests are clamped down by `optimal_for`.
pub const MEDIUM_SCALE_MAX: usize = 500;
pub const MEDIUM_SCALE_OPTIMAL: usize = 256; // Best efficiency/throughput balance
/// Pre-optimized cluster configurations
///
/// Produced by [`MediumClusterConfig::optimal_for`], which simulates a
/// hierarchical-pipeline topology to fill in the projection fields.
#[derive(Debug, Clone, Copy)]
pub struct MediumClusterConfig {
    /// Total chips in cluster
    /// (clusters * chips_per_cluster; may slightly exceed the requested
    /// count due to rounding into whole clusters)
    pub total_chips: usize,
    /// Number of clusters (groups)
    pub clusters: usize,
    /// Chips per cluster
    pub chips_per_cluster: usize,
    /// Expected throughput (tokens/sec)
    pub expected_throughput: f64,
    /// Expected efficiency (fraction; displayed as a percentage elsewhere)
    pub expected_efficiency: f64,
    /// Estimated cost USD
    pub cost_usd: f64,
    /// Power consumption watts
    pub power_watts: f64,
    /// Max model parameters supportable
    pub max_params: usize,
}
impl MediumClusterConfig {
    /// Build the best configuration for roughly `chip_count` chips.
    ///
    /// The count is clamped into the medium-scale range, arranged as a
    /// near-square `clusters x chips_per_cluster` layout, and the resulting
    /// hierarchical-pipeline topology is simulated to fill in throughput,
    /// efficiency, cost, power and capacity projections.
    pub fn optimal_for(chip_count: usize) -> Self {
        let chips = chip_count.clamp(MEDIUM_SCALE_MIN, MEDIUM_SCALE_MAX);
        // Near-square layout: ceil(sqrt) clusters, chips ceil-divided among them.
        let side = (chips as f64).sqrt();
        let clusters = side.ceil() as usize;
        let per_cluster = (chips + clusters - 1) / clusters;
        let actual_chips = clusters * per_cluster;
        // Run the massive-scale simulator for accurate projections.
        let sim = MassiveScaleSimulator::new(MassiveScaleConfig {
            topology: MassiveTopology::HierarchicalPipeline {
                clusters,
                chips_per_cluster: per_cluster,
            },
            total_layers: 32,
            embed_dim: 64,
            hop_latency_us: 10,
            link_bandwidth: 10_000_000,
            layer_compute_us: 4000,
            speculative: true,
            spec_depth: 4,
            gradient_checkpointing: false,
            fault_tolerance: 1,
        });
        let projection = sim.project();
        Self {
            total_chips: actual_chips,
            clusters,
            chips_per_cluster: per_cluster,
            expected_throughput: projection.throughput_tokens_sec,
            expected_efficiency: projection.efficiency,
            cost_usd: projection.cost_usd,
            power_watts: projection.power_watts,
            max_params: projection.max_parameters,
        }
    }
    /// The five reference cluster sizes used throughout the docs and tests.
    pub fn standard_configs() -> [Self; 5] {
        [100, 144, 256, 400, 500].map(Self::optimal_for)
    }
}
/// Comparison with smaller clusters
///
/// All three projections are simulated with the same workload parameters
/// so the multipliers isolate the effect of cluster size alone.
#[derive(Debug, Clone)]
pub struct ScaleComparison {
    /// Single chip baseline
    pub single_chip: ScaleProjection,
    /// 5-chip small cluster
    pub small_cluster: ScaleProjection,
    /// Medium cluster (specified)
    pub medium_cluster: ScaleProjection,
    /// Throughput multiplier vs single
    pub throughput_multiplier: f64,
    /// Throughput multiplier vs 5-chip
    pub vs_small_multiplier: f64,
    /// Cost per 1K tokens/sec
    /// (hardware cost in USD divided by throughput in thousands of tok/s)
    pub cost_per_1k_tokens: f64,
}
impl ScaleComparison {
    /// Compare a medium cluster of `chip_count` chips against a single chip
    /// and a 5-chip cluster, all simulated under the same model workload.
    pub fn analyze(chip_count: usize) -> Self {
        // Shared workload parameters; the topology is filled in per scenario.
        let base_config = MassiveScaleConfig {
            total_layers: 32,
            embed_dim: 64,
            hop_latency_us: 10,
            link_bandwidth: 10_000_000,
            layer_compute_us: 4000,
            speculative: true,
            spec_depth: 4,
            ..Default::default()
        };
        // Project a given topology under the shared workload.
        let project = |topology| {
            MassiveScaleSimulator::new(MassiveScaleConfig {
                topology,
                ..base_config.clone()
            })
            .project()
        };
        let single = project(MassiveTopology::FlatMesh { size: 1 });
        let small = project(MassiveTopology::FlatMesh { size: 5 });
        let medium = project(MassiveTopology::recommended(chip_count));
        Self {
            throughput_multiplier: medium.throughput_tokens_sec / single.throughput_tokens_sec,
            vs_small_multiplier: medium.throughput_tokens_sec / small.throughput_tokens_sec,
            cost_per_1k_tokens: medium.cost_usd / (medium.throughput_tokens_sec / 1000.0),
            single_chip: single,
            small_cluster: small,
            medium_cluster: medium,
        }
    }
}
/// Model categories that can run at different scales
///
/// Ordered smallest to largest; `for_chip_count` maps a cluster size to the
/// largest category it can host, and `min_chips` is its inverse.
#[derive(Debug, Clone, Copy)]
pub enum ModelCategory {
    /// 50K-500K params, minimal memory
    Nano,
    /// 500K-5M params, basic tasks
    Micro,
    /// 5M-20M params, good general use
    Small,
    /// 20M-100M params, high quality
    Base,
    /// 100M-500M params, needs large clusters
    Large,
}
impl ModelCategory {
    /// Minimum cluster size able to host this model category.
    pub fn min_chips(&self) -> usize {
        match self {
            Self::Nano => 1,
            Self::Micro => 5,
            Self::Small => 50,
            Self::Base => 200,
            Self::Large => 500,
        }
    }
    /// `(min, max)` parameter counts for this category.
    pub fn param_range(&self) -> (usize, usize) {
        match self {
            Self::Nano => (50_000, 500_000),
            Self::Micro => (500_000, 5_000_000),
            Self::Small => (5_000_000, 20_000_000),
            Self::Base => (20_000_000, 100_000_000),
            Self::Large => (100_000_000, 500_000_000),
        }
    }
    /// Representative public models in this size class.
    pub fn examples(&self) -> &'static str {
        match self {
            Self::Nano => "TinyBERT-nano, Custom embeddings",
            Self::Micro => "DistilBERT-tiny, MiniLM",
            Self::Small => "TinyLlama, Phi-nano",
            Self::Base => "Phi-1, GPT-2-Small",
            Self::Large => "Phi-2, LLaMA-7B (quantized)",
        }
    }
    /// Largest category a cluster of `chips` chips can support
    /// (the inverse of `min_chips`).
    pub fn for_chip_count(chips: usize) -> Self {
        if chips < 5 {
            Self::Nano
        } else if chips < 50 {
            Self::Micro
        } else if chips < 200 {
            Self::Small
        } else if chips < 500 {
            Self::Base
        } else {
            Self::Large
        }
    }
}
/// Hardware configuration for physical deployment
///
/// Produced by [`HardwareConfig::for_cluster`]; describes how a cluster of
/// a given size is laid out onto boards, buses and power supplies.
#[derive(Debug, Clone)]
pub struct HardwareConfig {
    /// Chips per PCB (physical board)
    pub chips_per_board: usize,
    /// Number of PCBs
    pub num_boards: usize,
    /// Communication bus
    pub bus_type: BusType,
    /// Power supply requirement (watts)
    /// (budgeted at ~0.5 W per chip plus a fixed overhead per tier)
    pub power_supply_watts: f64,
    /// Recommended form factor
    pub form_factor: &'static str,
}
/// Inter-chip communication bus options, ordered roughly by bandwidth.
#[derive(Debug, Clone, Copy)]
pub enum BusType {
    /// SPI - up to 40MHz, simple
    Spi,
    /// I2C - 400kHz standard, lower bandwidth
    I2c,
    /// UART mesh - flexible, medium speed
    Uart,
    /// Custom high-speed interconnect
    HighSpeed,
}
impl BusType {
    /// Sustained bandwidth estimate for each bus, in bytes per second.
    /// NOTE(review): these are engineering estimates (see inline comments),
    /// not measured figures — confirm against the target hardware.
    pub fn bandwidth_bytes_sec(&self) -> usize {
        match self {
            Self::Spi => 5_000_000, // 5 MB/s typical
            Self::I2c => 50_000, // 50 KB/s
            Self::Uart => 1_000_000, // 1 MB/s at 10Mbaud
            Self::HighSpeed => 50_000_000, // Custom FPGA/ASIC
        }
    }
}
impl HardwareConfig {
    /// Recommended physical deployment for a cluster of `chip_count` chips.
    ///
    /// Larger clusters move from SPI to UART to a custom high-speed bus,
    /// and from a single PCB up to rack-scale form factors. Power is
    /// budgeted at ~0.5 W per chip plus a fixed overhead per tier.
    pub fn for_cluster(chip_count: usize) -> Self {
        // Ceil-divide chips across boards of the given size.
        let boards = |per_board: usize| (chip_count + per_board - 1) / per_board;
        // Per-chip draw plus the tier's fixed overhead.
        let power = |overhead: f64| chip_count as f64 * 0.5 + overhead;
        if chip_count <= 25 {
            Self {
                chips_per_board: chip_count.min(10),
                num_boards: boards(10),
                bus_type: BusType::Spi,
                power_supply_watts: power(10.0),
                form_factor: "Single PCB or small rack",
            }
        } else if chip_count <= 100 {
            Self {
                chips_per_board: 10,
                num_boards: boards(10),
                bus_type: BusType::Spi,
                power_supply_watts: power(25.0),
                form_factor: "1U rack mount (10 boards)",
            }
        } else if chip_count <= 256 {
            Self {
                chips_per_board: 16,
                num_boards: boards(16),
                bus_type: BusType::Uart,
                power_supply_watts: power(50.0),
                form_factor: "2U-4U rack mount",
            }
        } else if chip_count <= 500 {
            Self {
                chips_per_board: 20,
                num_boards: boards(20),
                bus_type: BusType::Uart,
                power_supply_watts: power(75.0),
                form_factor: "Full rack unit",
            }
        } else {
            Self {
                chips_per_board: 25,
                num_boards: boards(25),
                bus_type: BusType::HighSpeed,
                power_supply_watts: power(100.0),
                form_factor: "Multi-rack datacenter",
            }
        }
    }
}
/// Run complete analysis for 100-500 chip clusters
///
/// Stateless namespace type: all functionality lives in associated functions.
pub struct MediumScaleAnalyzer;
impl MediumScaleAnalyzer {
    /// Compare all standard medium-scale configurations (perfect squares
    /// plus the range endpoints), skipping anything above the medium range.
    pub fn full_analysis() -> HVec<(MediumClusterConfig, ScaleComparison), 8> {
        let mut results = HVec::new();
        for chips in [100, 144, 196, 256, 324, 400, 484, 500] {
            if chips > MEDIUM_SCALE_MAX {
                continue;
            }
            let config = MediumClusterConfig::optimal_for(chips);
            let comparison = ScaleComparison::analyze(chips);
            let _ = results.push((config, comparison));
        }
        results
    }
    /// Smallest medium-scale configuration meeting `target_tokens_sec`,
    /// or None if even the largest cannot reach it.
    ///
    /// NOTE(review): the binary search assumes projected throughput is
    /// monotonically non-decreasing in chip count — verify against the
    /// simulator model.
    pub fn optimize_for_throughput(target_tokens_sec: f64) -> Option<MediumClusterConfig> {
        let (mut low, mut high) = (MEDIUM_SCALE_MIN, MEDIUM_SCALE_MAX);
        let mut best: Option<MediumClusterConfig> = None;
        while low <= high {
            let mid = low + (high - low) / 2;
            let candidate = MediumClusterConfig::optimal_for(mid);
            if candidate.expected_throughput >= target_tokens_sec {
                best = Some(candidate);
                high = mid.saturating_sub(1);
            } else {
                low = mid + 1;
            }
        }
        best
    }
    /// Largest medium-scale configuration affordable at ~$4 per chip,
    /// clamped into the medium range.
    pub fn optimize_for_budget(budget_usd: f64) -> MediumClusterConfig {
        let affordable = (budget_usd / 4.0) as usize; // $4 per chip
        MediumClusterConfig::optimal_for(affordable.clamp(MEDIUM_SCALE_MIN, MEDIUM_SCALE_MAX))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // 100 chips should lay out as a 10x10 grid with useful throughput.
    #[test]
    fn test_optimal_config_100() {
        let config = MediumClusterConfig::optimal_for(100);
        assert_eq!(config.clusters, 10);
        assert_eq!(config.chips_per_cluster, 10);
        assert!(config.expected_throughput > 40000.0); // 40K+ tok/s
        assert!(config.expected_efficiency > 0.5); // 50%+ efficiency
    }
    // 256 chips should lay out as a 16x16 grid.
    #[test]
    fn test_optimal_config_256() {
        let config = MediumClusterConfig::optimal_for(256);
        assert_eq!(config.clusters, 16);
        assert_eq!(config.chips_per_cluster, 16);
        assert!(config.expected_throughput > 60000.0); // 60K+ tok/s
    }
    // A 256-chip cluster must dominate both baselines by a wide margin.
    #[test]
    fn test_scale_comparison() {
        let comparison = ScaleComparison::analyze(256);
        assert!(comparison.throughput_multiplier > 50.0); // 50x+ vs single chip
        assert!(comparison.vs_small_multiplier > 10.0); // 10x+ vs 5 chips
    }
    // for_chip_count / min_chips are inverses at the category boundaries.
    #[test]
    fn test_model_categories() {
        assert_eq!(ModelCategory::for_chip_count(50).min_chips(), 50);
        assert_eq!(ModelCategory::for_chip_count(256).min_chips(), 200);
    }
    // 256 chips -> 16 boards of 16 chips each.
    #[test]
    fn test_hardware_config() {
        let hw = HardwareConfig::for_cluster(256);
        assert_eq!(hw.chips_per_board, 16);
        assert_eq!(hw.num_boards, 16);
        assert!(hw.power_supply_watts > 100.0);
    }
}

View File

@@ -0,0 +1,280 @@
//! Federation Module for Multi-ESP32 Distributed Inference
//!
//! Enables running larger models across multiple ESP32 chips:
//! - Pipeline parallelism: Each chip handles different layers
//! - Tensor parallelism: Split attention heads across chips
//! - Model sharding: Distribute embeddings/weights
//! - Speculative decoding: Draft on one chip, verify on others
//!
//! # Architecture Options
//!
//! ```text
//! 5-Chip Pipeline (recommended for latency):
//! ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
//! │ ESP32-0 │───▶│ ESP32-1 │───▶│ ESP32-2 │───▶│ ESP32-3 │───▶│ ESP32-4 │
//! │ Embed + │ │ Layer 1 │ │ Layer 2 │ │ Layer 3 │ │ Layer 4 │
//! │ Layer 0 │ │ │ │ │ │ │ │ + Head │
//! └─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘
//!
//! 5-Chip Tensor Parallel (for throughput):
//! ┌─────────┐
//! │ ESP32-0 │ ◀──┐
//! │ Head 0 │ │
//! └─────────┘ │
//! ┌─────────┐ │ ┌─────────┐
//! │ ESP32-1 │ ◀──┼────│ ESP32-4 │
//! │ Head 1 │ │ │ Coord │
//! └─────────┘ │ └─────────┘
//! ┌─────────┐ │
//! │ ESP32-2 │ ◀──┤
//! │ Head 2 │ │
//! └─────────┘ │
//! ┌─────────┐ │
//! │ ESP32-3 │ ◀──┘
//! │ Head 3 │
//! └─────────┘
//! ```
pub mod pipeline;
pub mod tensor_parallel;
pub mod sharding;
pub mod speculative;
pub mod protocol;
pub mod coordinator;
pub mod fastgrnn_router;
pub mod massive_scale;
pub mod medium_scale;
// Re-exports
pub use pipeline::{PipelineNode, PipelineConfig, PipelineRole};
pub use tensor_parallel::{TensorParallelNode, TPConfig};
pub use sharding::{ShardedEmbedding, ShardConfig};
pub use speculative::{SpeculativeDecoder, DraftVerifyConfig};
pub use protocol::{FederationMessage, MessageType, ChipId};
pub use coordinator::{FederationCoordinator, ClusterTopology};
pub use fastgrnn_router::{MicroFastGRNN, MicroGRNNConfig, RoutingFeatures};
pub use massive_scale::{
MassiveTopology, MassiveScaleConfig, MassiveScaleSimulator, ScaleProjection,
DistributedCoordinator, GossipProtocol, FaultTolerance,
};
pub use medium_scale::{
MediumClusterConfig, ScaleComparison, MediumScaleAnalyzer,
ModelCategory, HardwareConfig, BusType,
MEDIUM_SCALE_MIN, MEDIUM_SCALE_MAX, MEDIUM_SCALE_OPTIMAL,
};
/// Maximum chips in small federation
pub const MAX_FEDERATION_SIZE: usize = 8;
/// Maximum chips in massive scale (theoretical)
pub const MAX_MASSIVE_SCALE: usize = 1_000_000;
/// Federation mode
///
/// Selected automatically by `calculate_optimal_config` or set manually
/// in `FederationConfig`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FederationMode {
    /// Single chip (no federation)
    Standalone,
    /// Pipeline parallelism - each chip handles different layers
    Pipeline,
    /// Tensor parallelism - split heads across chips
    TensorParallel,
    /// Hybrid: pipeline + tensor parallel
    Hybrid,
    /// Speculative decoding with draft/verify
    Speculative,
    /// Mixture of Experts - each chip is an expert
    MixtureOfExperts,
}
/// Federation cluster configuration
///
/// Each chip carries its own copy with its own `chip_id`; the remaining
/// fields describe the shared cluster layout.
#[derive(Debug, Clone)]
pub struct FederationConfig {
    /// Number of chips in cluster
    pub num_chips: usize,
    /// This chip's ID (0-indexed)
    pub chip_id: ChipId,
    /// Federation mode
    pub mode: FederationMode,
    /// Communication bus type
    pub bus: CommunicationBus,
    /// Layers per chip (for pipeline mode)
    pub layers_per_chip: usize,
    /// Heads per chip (for tensor parallel mode)
    pub heads_per_chip: usize,
    /// Enable pipelining (process next token while current finishes)
    pub enable_pipelining: bool,
}
impl Default for FederationConfig {
    /// Recommended starting point: a 5-chip SPI pipeline with 2 layers per
    /// chip and pipelining enabled, acting as chip 0.
    fn default() -> Self {
        Self {
            num_chips: 5,
            chip_id: ChipId(0),
            mode: FederationMode::Pipeline,
            bus: CommunicationBus::Spi,
            layers_per_chip: 2,
            heads_per_chip: 1,
            enable_pipelining: true,
        }
    }
}
/// Communication bus between chips
///
/// See `bandwidth_bytes_per_sec` / `latency_us` for the performance model
/// attached to each option.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CommunicationBus {
    /// SPI bus (fastest, 10-80 MHz)
    Spi,
    /// I2C bus (slower, 400 kHz - 1 MHz)
    I2c,
    /// UART (flexible, up to 5 Mbps)
    Uart,
    /// ESP-NOW (wireless, ~1 Mbps)
    EspNow,
    /// Custom parallel bus
    Parallel,
}
impl CommunicationBus {
    /// Estimated bandwidth in bytes/second
    /// NOTE(review): engineering estimates (see inline comments), not
    /// measured figures — confirm on the target hardware.
    pub const fn bandwidth_bytes_per_sec(&self) -> usize {
        match self {
            Self::Spi => 10_000_000, // 10 MB/s at 80 MHz
            Self::I2c => 100_000, // 100 KB/s at 1 MHz
            Self::Uart => 500_000, // 500 KB/s at 5 Mbps
            Self::EspNow => 125_000, // ~1 Mbps
            Self::Parallel => 20_000_000, // Custom 8-bit parallel
        }
    }
    /// Latency overhead in microseconds
    /// (fixed per-message setup cost, independent of payload size)
    pub const fn latency_us(&self) -> usize {
        match self {
            Self::Spi => 10,
            Self::I2c => 50,
            Self::Uart => 20,
            Self::EspNow => 500, // Wireless overhead
            Self::Parallel => 5,
        }
    }
}
/// Calculate optimal federation configuration for given model
///
/// Chooses pipeline parallelism when each chip's shard of the model fits
/// in `per_chip_ram`, otherwise falls back to tensor parallelism.
/// The returned `chip_id` is always 0; each chip overrides it locally.
/// A `num_chips` of 0 is treated as a single-chip deployment instead of
/// dividing by zero.
pub fn calculate_optimal_config(
    model_size_bytes: usize,
    num_layers: usize,
    num_heads: usize,
    num_chips: usize,
    per_chip_ram: usize,
) -> FederationConfig {
    // Guard: zero chips would panic on the division below.
    let num_chips = num_chips.max(1);
    let model_per_chip = model_size_bytes / num_chips;
    // Check if model fits with pipeline parallelism
    if model_per_chip <= per_chip_ram {
        let layers_per_chip = (num_layers + num_chips - 1) / num_chips; // ceil
        return FederationConfig {
            num_chips,
            chip_id: ChipId(0),
            mode: FederationMode::Pipeline,
            bus: CommunicationBus::Spi,
            layers_per_chip,
            heads_per_chip: num_heads,
            enable_pipelining: true,
        };
    }
    // Otherwise split attention heads across chips (tensor parallelism).
    let heads_per_chip = (num_heads + num_chips - 1) / num_chips; // ceil
    FederationConfig {
        num_chips,
        chip_id: ChipId(0),
        mode: FederationMode::TensorParallel,
        bus: CommunicationBus::Spi,
        layers_per_chip: num_layers,
        heads_per_chip,
        enable_pipelining: false,
    }
}
/// Estimate performance improvement from federation
///
/// Returns multipliers relative to a single standalone chip (1.0 = parity).
/// These are heuristic scaling models, not measurements.
pub fn estimate_speedup(config: &FederationConfig) -> FederationSpeedup {
    let n = config.num_chips as f32;
    // (throughput multiplier, latency reduction, per-chip memory reduction)
    let (throughput, latency, memory) = match config.mode {
        // No federation: everything at parity.
        FederationMode::Standalone => (1.0, 1.0, 1.0),
        // Pipeline: ~n-way throughput at 85% efficiency (bubble overhead);
        // latency slightly worse; memory divides evenly across chips.
        FederationMode::Pipeline => (n * 0.85, 1.0 / (1.0 + 0.1 * (n - 1.0)), n),
        // Tensor parallel: near-linear on attention, capped by communication
        // overhead; some weight duplication limits memory savings.
        FederationMode::TensorParallel => (n * 0.7, n * 0.7, n * 0.8),
        // Hybrid pipeline + tensor parallel.
        FederationMode::Hybrid => (n * 0.75, (n / 2.0) * 0.8, n * 0.9),
        // Speculative: 2-4x typical; draft chip still holds a full model.
        FederationMode::Speculative => (2.5, 2.0, 1.0),
        // Mixture of experts: excellent scaling, one expert per chip.
        FederationMode::MixtureOfExperts => (n * 0.9, 1.5, n),
    };
    FederationSpeedup {
        throughput_multiplier: throughput,
        latency_reduction: latency,
        memory_per_chip_reduction: memory,
    }
}
/// Performance improvement estimates
///
/// All fields are multipliers relative to standalone operation
/// (1.0 = no change, larger = better).
#[derive(Debug, Clone)]
pub struct FederationSpeedup {
    /// Throughput improvement (tokens/sec multiplier)
    pub throughput_multiplier: f32,
    /// Latency reduction (time per token)
    pub latency_reduction: f32,
    /// Memory reduction per chip
    pub memory_per_chip_reduction: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    // 500 KB over 5 chips is 100 KB each, which fits 120 KB RAM,
    // so pipeline mode with ceil(10/5) = 2 layers per chip is chosen.
    #[test]
    fn test_optimal_config() {
        let config = calculate_optimal_config(
            500 * 1024, // 500 KB model
            10, // 10 layers
            4, // 4 heads
            5, // 5 chips
            120 * 1024, // 120 KB per chip
        );
        assert_eq!(config.mode, FederationMode::Pipeline);
        assert_eq!(config.layers_per_chip, 2);
    }
    // 5-chip pipeline: 5 * 0.85 = 4.25x throughput, 5x memory reduction.
    #[test]
    fn test_speedup_estimate() {
        let config = FederationConfig {
            num_chips: 5,
            mode: FederationMode::Pipeline,
            ..Default::default()
        };
        let speedup = estimate_speedup(&config);
        assert!(speedup.throughput_multiplier > 4.0);
        assert!(speedup.memory_per_chip_reduction >= 5.0);
    }
}

View File

@@ -0,0 +1,387 @@
//! Pipeline Parallelism for Multi-ESP32 Inference
//!
//! Distributes layers across chips for linear scaling with model size.
//! Each chip processes its assigned layers and passes activations to the next.
//!
//! # 5-Chip Pipeline Example
//!
//! ```text
//! Token 0: [C0:embed+L0] → [C1:L1-2] → [C2:L3-4] → [C3:L5-6] → [C4:L7+head]
//! Token 1: idle [C0:embed] [C1:L1-2] [C2:L3-4] [C3:L5-6]
//! Token 2: idle idle [C0:embed] [C1:L1-2] [C2:L3-4]
//! ...
//! ```
use heapless::Vec as HVec;
use super::protocol::{ChipId, FederationMessage};
/// Maximum layers per chip
pub const MAX_LAYERS_PER_CHIP: usize = 4;
/// Pipeline depth (tokens in flight)
pub const MAX_PIPELINE_DEPTH: usize = 8;
/// Role in the pipeline
///
/// Derived from a chip's position via `PipelineConfig::role`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PipelineRole {
    /// First chip: handles embedding + first layers
    Head,
    /// Middle chip: processes middle layers
    Middle,
    /// Last chip: final layers + output head
    Tail,
    /// Single chip mode (no pipeline; handles both embedding and output)
    Standalone,
}
/// Pipeline configuration
///
/// Describes one chip's slice of the model; build with
/// `PipelineConfig::for_chip`.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Total chips in pipeline
    pub num_chips: usize,
    /// This chip's position (0 = head)
    pub position: usize,
    /// First layer index assigned to this chip
    pub layer_start: usize,
    /// Number of layers on this chip
    pub layer_count: usize,
    /// Total layers in model
    pub total_layers: usize,
    /// Embedding dimension
    pub embed_dim: usize,
    /// Enable micro-batching
    pub micro_batch_size: usize,
}
impl PipelineConfig {
    /// Create config for a specific chip in the pipeline.
    ///
    /// Layers are split as evenly as possible (ceiling division), so the
    /// last chip(s) may hold fewer layers — or zero, when there are more
    /// chips than layers.
    pub fn for_chip(
        chip_pos: usize,
        num_chips: usize,
        total_layers: usize,
        embed_dim: usize,
    ) -> Self {
        let layers_per_chip = (total_layers + num_chips - 1) / num_chips; // ceil
        let layer_start = chip_pos * layers_per_chip;
        // saturating_sub: with more chips than layers, layer_start can
        // exceed total_layers and the plain subtraction would underflow
        // (panicking in debug builds). Such chips get layer_count = 0.
        let layer_count = layers_per_chip.min(total_layers.saturating_sub(layer_start));
        Self {
            num_chips,
            position: chip_pos,
            layer_start,
            layer_count,
            total_layers,
            embed_dim,
            micro_batch_size: 1,
        }
    }
    /// Get role of this chip, derived from its position.
    pub fn role(&self) -> PipelineRole {
        if self.num_chips == 1 {
            PipelineRole::Standalone
        } else if self.position == 0 {
            PipelineRole::Head
        } else if self.position == self.num_chips - 1 {
            PipelineRole::Tail
        } else {
            PipelineRole::Middle
        }
    }
    /// Previous chip in pipeline (None on the head chip).
    pub fn prev_chip(&self) -> Option<ChipId> {
        if self.position > 0 {
            Some(ChipId((self.position - 1) as u8))
        } else {
            None
        }
    }
    /// Next chip in pipeline (None on the tail chip).
    pub fn next_chip(&self) -> Option<ChipId> {
        if self.position + 1 < self.num_chips {
            Some(ChipId((self.position + 1) as u8))
        } else {
            None
        }
    }
}
/// Pipeline state for a chip
///
/// Updated as a side effect of `PipelineNode` operations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PipelineState {
    /// Waiting for input from previous chip
    WaitingInput,
    /// Processing layers
    Processing,
    /// Waiting to send output
    WaitingSend,
    /// Idle (pipeline bubble)
    Idle,
}
/// In-flight token tracking
///
/// One entry per token currently traversing this chip's layer range.
#[derive(Debug, Clone)]
pub struct InFlightToken {
    /// Sequence position
    pub seq_pos: u16,
    /// Token ID (only meaningful on the head chip; 0 elsewhere)
    pub token_id: u16,
    /// Current layer being processed (absolute layer index)
    pub current_layer: u8,
    /// Activation data (INT8, at most 128 elements)
    pub activation: HVec<i8, 128>,
}
/// Pipeline node managing this chip's portion
///
/// Owns the in-flight and output token queues (both FIFO, bounded by
/// `MAX_PIPELINE_DEPTH`) for one chip's slice of the model.
pub struct PipelineNode {
    /// Configuration
    config: PipelineConfig,
    /// Current state
    state: PipelineState,
    /// Chip ID (derived from `config.position`)
    chip_id: ChipId,
    /// Sequence counter
    seq_counter: u16,
    /// Tokens in flight in the pipeline
    in_flight: HVec<InFlightToken, MAX_PIPELINE_DEPTH>,
    /// Completed tokens waiting to send
    output_queue: HVec<InFlightToken, MAX_PIPELINE_DEPTH>,
    /// Input buffer for receiving activations
    /// NOTE(review): not read or written by any method shown here —
    /// confirm whether it is still needed.
    input_buffer: HVec<i8, 256>,
    /// Barrier counter for synchronization
    barrier_counter: u16,
}
impl PipelineNode {
/// Create new pipeline node
pub fn new(config: PipelineConfig) -> Self {
Self {
chip_id: ChipId(config.position as u8),
config,
state: PipelineState::Idle,
seq_counter: 0,
in_flight: HVec::new(),
output_queue: HVec::new(),
input_buffer: HVec::new(),
barrier_counter: 0,
}
}
/// Get current pipeline state
pub fn state(&self) -> PipelineState {
self.state
}
/// Check if this chip should handle embedding
pub fn handles_embedding(&self) -> bool {
self.config.role() == PipelineRole::Head ||
self.config.role() == PipelineRole::Standalone
}
/// Check if this chip should handle output head
pub fn handles_output(&self) -> bool {
self.config.role() == PipelineRole::Tail ||
self.config.role() == PipelineRole::Standalone
}
/// Start processing a new token (head chip only)
pub fn start_token(&mut self, token_id: u16) -> crate::Result<()> {
if !self.handles_embedding() {
return Err(crate::Error::UnsupportedFeature("Not head chip"));
}
if self.in_flight.len() >= MAX_PIPELINE_DEPTH {
return Err(crate::Error::BufferOverflow);
}
let token = InFlightToken {
seq_pos: self.seq_counter,
token_id,
current_layer: 0,
activation: HVec::new(),
};
self.in_flight.push(token).map_err(|_| crate::Error::BufferOverflow)?;
self.seq_counter += 1;
self.state = PipelineState::Processing;
Ok(())
}
/// Receive activation from previous chip
pub fn receive_activation(&mut self, msg: &FederationMessage) -> crate::Result<()> {
let (layer_idx, position, data) = msg.get_activation_data()
.ok_or(crate::Error::InvalidModel("Invalid activation message"))?;
// Create in-flight token from received data
let mut activation = HVec::new();
for &d in data {
activation.push(d as i8).map_err(|_| crate::Error::BufferOverflow)?;
}
let token = InFlightToken {
seq_pos: position,
token_id: 0, // Not needed for middle/tail chips
current_layer: layer_idx,
activation,
};
self.in_flight.push(token).map_err(|_| crate::Error::BufferOverflow)?;
self.state = PipelineState::Processing;
Ok(())
}
/// Process one step (one layer for one token)
/// Returns true if there's work to do
pub fn process_step<F>(&mut self, mut layer_fn: F) -> crate::Result<bool>
where
F: FnMut(usize, &mut [i8]) -> crate::Result<()>,
{
if self.in_flight.is_empty() {
self.state = PipelineState::WaitingInput;
return Ok(false);
}
// Process first token in queue
let token = &mut self.in_flight[0];
// Determine which layer to process
let relative_layer = token.current_layer as usize - self.config.layer_start;
if relative_layer < self.config.layer_count {
// Process this layer
let layer_idx = self.config.layer_start + relative_layer;
layer_fn(layer_idx, &mut token.activation)?;
token.current_layer += 1;
}
// Check if done with this chip's layers
let next_layer = token.current_layer as usize;
if next_layer >= self.config.layer_start + self.config.layer_count {
// Move to output queue
if let Some(completed) = self.in_flight.pop() {
self.output_queue.push(completed).map_err(|_| crate::Error::BufferOverflow)?;
}
self.state = PipelineState::WaitingSend;
}
Ok(true)
}
/// Get activation to send to next chip
pub fn get_output(&mut self) -> Option<FederationMessage> {
if self.output_queue.is_empty() {
return None;
}
let token = self.output_queue.pop()?;
let next_chip = self.config.next_chip()?;
// Convert activation to bytes
let data: Vec<i8> = token.activation.iter().cloned().collect();
FederationMessage::activation(
self.chip_id,
next_chip,
token.seq_pos,
token.current_layer,
token.seq_pos,
&data,
).ok()
}
/// Check if output is available (for tail chip)
pub fn has_final_output(&self) -> bool {
self.handles_output() && !self.output_queue.is_empty()
}
/// Get final output logits (tail chip only)
pub fn get_final_output(&mut self) -> Option<HVec<i8, 128>> {
if !self.handles_output() {
return None;
}
let token = self.output_queue.pop()?;
Some(token.activation)
}
/// Get pipeline statistics
pub fn stats(&self) -> PipelineStats {
PipelineStats {
in_flight_count: self.in_flight.len(),
output_queue_len: self.output_queue.len(),
tokens_processed: self.seq_counter as usize,
current_state: self.state,
}
}
/// Create synchronization barrier
pub fn create_barrier(&mut self) -> FederationMessage {
self.barrier_counter += 1;
FederationMessage::barrier(self.chip_id, self.barrier_counter)
}
}
/// Pipeline statistics
///
/// Point-in-time snapshot produced by `PipelineNode::stats`.
#[derive(Debug, Clone)]
pub struct PipelineStats {
    /// Tokens currently in pipeline
    pub in_flight_count: usize,
    /// Tokens waiting to send
    pub output_queue_len: usize,
    /// Total tokens processed (head-chip sequence counter)
    pub tokens_processed: usize,
    /// Current state
    pub current_state: PipelineState,
}
/// Calculate pipeline efficiency
///
/// Efficiency is useful work divided by total work. With `num_chips`
/// stages the first tokens pay the fill ("bubble") cost, so efficiency
/// starts at `1 / num_chips` and approaches 1.0 as generation continues.
///
/// Returns 0.0 for degenerate inputs: `tokens_generated == 0` previously
/// produced NaN (0/0) and `num_chips == 0` underflowed `num_chips - 1`.
pub fn calculate_pipeline_efficiency(
    num_chips: usize,
    tokens_generated: usize,
) -> f32 {
    if tokens_generated == 0 || num_chips == 0 {
        return 0.0;
    }
    // Pipeline efficiency = useful work / total work
    // With N chips, first N-1 tokens have bubble overhead
    if tokens_generated <= num_chips {
        // During warmup every token still pays the full pipeline depth:
        // t / (N * t) == 1 / N; kept in the original form for clarity.
        tokens_generated as f32 / (num_chips as f32 * tokens_generated as f32)
    } else {
        // After warmup, efficiency approaches 100%
        let warmup_overhead = (num_chips - 1) as f32;
        let useful_work = tokens_generated as f32;
        useful_work / (useful_work + warmup_overhead)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_pipeline_config() {
        // 5 chips, 10 layers -> 2 layers per chip; chip k owns layers
        // [2k, 2k+2). Chip 0 is the head, chip 4 the tail.
        let config = PipelineConfig::for_chip(0, 5, 10, 64);
        assert_eq!(config.role(), PipelineRole::Head);
        assert_eq!(config.layer_start, 0);
        assert_eq!(config.layer_count, 2);
        let config = PipelineConfig::for_chip(2, 5, 10, 64);
        assert_eq!(config.role(), PipelineRole::Middle);
        assert_eq!(config.layer_start, 4);
        let config = PipelineConfig::for_chip(4, 5, 10, 64);
        assert_eq!(config.role(), PipelineRole::Tail);
    }
    #[test]
    fn test_pipeline_efficiency() {
        // After 100 tokens, efficiency should be high
        let eff = calculate_pipeline_efficiency(5, 100);
        assert!(eff > 0.95);
        // During warmup, efficiency is lower (1 / num_chips = 0.2)
        let eff_warmup = calculate_pipeline_efficiency(5, 5);
        assert!(eff_warmup < 0.5);
    }
}

View File

@@ -0,0 +1,414 @@
//! Inter-Chip Communication Protocol
//!
//! Defines the message format for ESP32-to-ESP32 communication.
//! Designed for low overhead on SPI/I2C/UART buses.
use heapless::Vec as HVec;
/// Maximum activation size (bytes) that can be sent in one message
pub const MAX_ACTIVATION_SIZE: usize = 256;
/// Maximum message payload size in bytes
pub const MAX_PAYLOAD_SIZE: usize = 512;
/// Wire protocol version carried in every message header
pub const PROTOCOL_VERSION: u8 = 1;
/// Chip identifier in the federation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ChipId(pub u8);
impl ChipId {
    /// Reserved id addressing every chip in the cluster at once.
    pub const BROADCAST: ChipId = ChipId(0xFF);
    /// True when this id is the reserved broadcast address.
    pub fn is_broadcast(&self) -> bool {
        *self == Self::BROADCAST
    }
}
/// Message types for federation protocol
///
/// Discriminants are the on-wire byte values and must stay stable:
/// 0x0x control, 0x1x forward-pass data, 0x2x token flow, 0x3x
/// speculative decoding, 0x4x synchronization, 0xFF error.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum MessageType {
    /// Heartbeat / keep-alive
    Heartbeat = 0x00,
    /// Cluster discovery
    Discovery = 0x01,
    /// Ready signal
    Ready = 0x02,
    /// Forward pass activation data
    Activation = 0x10,
    /// Attention K/V cache update
    KVCache = 0x11,
    /// Gradient (for future training)
    Gradient = 0x12,
    /// Token embedding request
    EmbedRequest = 0x20,
    /// Token embedding response
    EmbedResponse = 0x21,
    /// Output logits
    Logits = 0x22,
    /// Sampled token
    Token = 0x23,
    /// Speculative draft tokens
    DraftTokens = 0x30,
    /// Verification result
    VerifyResult = 0x31,
    /// Synchronization barrier
    Barrier = 0x40,
    /// Acknowledgment
    Ack = 0x41,
    /// Error
    Error = 0xFF,
}
impl From<u8> for MessageType {
    /// Decode a wire byte; any unrecognized value maps to `Error`.
    fn from(v: u8) -> Self {
        match v {
            0x00 => MessageType::Heartbeat,
            0x01 => MessageType::Discovery,
            0x02 => MessageType::Ready,
            0x10 => MessageType::Activation,
            0x11 => MessageType::KVCache,
            0x12 => MessageType::Gradient,
            0x20 => MessageType::EmbedRequest,
            0x21 => MessageType::EmbedResponse,
            0x22 => MessageType::Logits,
            0x23 => MessageType::Token,
            0x30 => MessageType::DraftTokens,
            0x31 => MessageType::VerifyResult,
            0x40 => MessageType::Barrier,
            0x41 => MessageType::Ack,
            _ => MessageType::Error,
        }
    }
}
/// Message header (8 bytes)
///
/// `repr(C, packed)` pins the 8-byte wire size. Beware: taking a reference
/// to a `u16` field of a packed struct is undefined behavior on unaligned
/// fields — copy fields to locals first (as the unit tests do).
/// Multi-byte fields are serialized little-endian by `to_bytes`.
#[derive(Debug, Clone, Copy)]
#[repr(C, packed)]
pub struct MessageHeader {
    /// Protocol version
    pub version: u8,
    /// Message type
    pub msg_type: u8,
    /// Source chip ID
    pub src: u8,
    /// Destination chip ID
    pub dst: u8,
    /// Sequence number (for ordering)
    pub seq: u16,
    /// Payload length
    pub payload_len: u16,
}
impl MessageHeader {
    /// Serialized header size in bytes.
    pub const SIZE: usize = 8;
    /// Build a header; `payload_len` is the number of payload bytes that
    /// will follow the header on the wire.
    pub fn new(msg_type: MessageType, src: ChipId, dst: ChipId, seq: u16, payload_len: u16) -> Self {
        Self {
            version: PROTOCOL_VERSION,
            msg_type: msg_type as u8,
            src: src.0,
            dst: dst.0,
            seq,
            payload_len,
        }
    }
    /// Serialize to bytes (multi-byte fields little-endian).
    pub fn to_bytes(&self) -> [u8; 8] {
        // Copy packed fields by value first; references into a
        // `repr(packed)` struct would be UB on unaligned fields.
        let seq = self.seq;
        let payload_len = self.payload_len;
        let seq = seq.to_le_bytes();
        let len = payload_len.to_le_bytes();
        [
            self.version,
            self.msg_type,
            self.src,
            self.dst,
            seq[0],
            seq[1],
            len[0],
            len[1],
        ]
    }
    /// Deserialize from bytes; `None` when fewer than 8 bytes are supplied.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < Self::SIZE {
            return None;
        }
        Some(Self {
            version: bytes[0],
            msg_type: bytes[1],
            src: bytes[2],
            dst: bytes[3],
            seq: u16::from_le_bytes([bytes[4], bytes[5]]),
            payload_len: u16::from_le_bytes([bytes[6], bytes[7]]),
        })
    }
    /// Simple additive (wrapping) checksum over the 8 serialized bytes.
    pub fn checksum(&self) -> u8 {
        self.to_bytes().iter().fold(0u8, |acc, &b| acc.wrapping_add(b))
    }
}
/// Complete federation message
///
/// Wire layout (see `to_bytes`): 8-byte header, `payload_len` payload
/// bytes, then one additive checksum byte covering header and payload.
#[derive(Debug, Clone)]
pub struct FederationMessage {
    /// Message header
    pub header: MessageHeader,
    /// Payload data
    pub payload: HVec<u8, MAX_PAYLOAD_SIZE>,
    /// Checksum
    pub checksum: u8,
}
impl FederationMessage {
    /// Create new message
    ///
    /// `payload_len` starts at 0; builders that append payload bytes must
    /// update it and call `update_checksum` before sending.
    pub fn new(msg_type: MessageType, src: ChipId, dst: ChipId, seq: u16) -> Self {
        Self {
            header: MessageHeader::new(msg_type, src, dst, seq, 0),
            payload: HVec::new(),
            checksum: 0,
        }
    }
    /// Create activation message with INT8 data
    ///
    /// Fails with `BufferOverflow` when `3 + data.len()` exceeds
    /// MAX_PAYLOAD_SIZE. Each i8 sample is reinterpreted as a raw byte.
    pub fn activation(
        src: ChipId,
        dst: ChipId,
        seq: u16,
        layer_idx: u8,
        position: u16,
        data: &[i8],
    ) -> crate::Result<Self> {
        let mut msg = Self::new(MessageType::Activation, src, dst, seq);
        // Payload format: [layer_idx:1][position:2 little-endian][data:N]
        msg.payload.push(layer_idx).map_err(|_| crate::Error::BufferOverflow)?;
        msg.payload.push((position & 0xFF) as u8).map_err(|_| crate::Error::BufferOverflow)?;
        msg.payload.push((position >> 8) as u8).map_err(|_| crate::Error::BufferOverflow)?;
        for &d in data {
            msg.payload.push(d as u8).map_err(|_| crate::Error::BufferOverflow)?;
        }
        msg.header.payload_len = msg.payload.len() as u16;
        msg.update_checksum();
        Ok(msg)
    }
    /// Create token message
    ///
    /// Payload format: [token_id:2 little-endian]. Capacity is guaranteed
    /// (2 <= MAX_PAYLOAD_SIZE), so push results are safely ignored.
    pub fn token(src: ChipId, dst: ChipId, seq: u16, token_id: u16) -> Self {
        let mut msg = Self::new(MessageType::Token, src, dst, seq);
        let _ = msg.payload.push((token_id & 0xFF) as u8);
        let _ = msg.payload.push((token_id >> 8) as u8);
        msg.header.payload_len = 2;
        msg.update_checksum();
        msg
    }
    /// Create draft tokens message for speculative decoding
    ///
    /// Payload format: [count:1][token_0:2 LE]...[token_{count-1}:2 LE].
    /// NOTE(review): `tokens.len() as u8` truncates above 255 entries —
    /// callers are expected to stay within MAX_DRAFT_TOKENS.
    pub fn draft_tokens(src: ChipId, dst: ChipId, seq: u16, tokens: &[u16]) -> crate::Result<Self> {
        let mut msg = Self::new(MessageType::DraftTokens, src, dst, seq);
        msg.payload.push(tokens.len() as u8).map_err(|_| crate::Error::BufferOverflow)?;
        for &t in tokens {
            msg.payload.push((t & 0xFF) as u8).map_err(|_| crate::Error::BufferOverflow)?;
            msg.payload.push((t >> 8) as u8).map_err(|_| crate::Error::BufferOverflow)?;
        }
        msg.header.payload_len = msg.payload.len() as u16;
        msg.update_checksum();
        Ok(msg)
    }
    /// Create barrier synchronization message
    ///
    /// Broadcast to all chips; payload is [barrier_id:2 little-endian].
    pub fn barrier(src: ChipId, barrier_id: u16) -> Self {
        let mut msg = Self::new(MessageType::Barrier, src, ChipId::BROADCAST, 0);
        let _ = msg.payload.push((barrier_id & 0xFF) as u8);
        let _ = msg.payload.push((barrier_id >> 8) as u8);
        msg.header.payload_len = 2;
        msg.update_checksum();
        msg
    }
    /// Update checksum
    ///
    /// Wrapping byte-sum of the serialized header plus all payload bytes.
    pub fn update_checksum(&mut self) {
        let mut sum = self.header.checksum();
        for &b in &self.payload {
            sum = sum.wrapping_add(b);
        }
        self.checksum = sum;
    }
    /// Verify checksum
    pub fn verify_checksum(&self) -> bool {
        let mut sum = self.header.checksum();
        for &b in &self.payload {
            sum = sum.wrapping_add(b);
        }
        sum == self.checksum
    }
    /// Serialize to bytes
    ///
    /// Buffer capacity (MAX_PAYLOAD_SIZE + 16) leaves headroom for the
    /// 8-byte header and 1-byte trailing checksum, so pushes cannot fail.
    pub fn to_bytes(&self) -> HVec<u8, { MAX_PAYLOAD_SIZE + 16 }> {
        let mut bytes = HVec::new();
        // Header
        for b in self.header.to_bytes() {
            let _ = bytes.push(b);
        }
        // Payload
        for &b in &self.payload {
            let _ = bytes.push(b);
        }
        // Checksum
        let _ = bytes.push(self.checksum);
        bytes
    }
    /// Deserialize from bytes
    ///
    /// Validates length against the header's `payload_len` and then the
    /// checksum; any failure yields an `InvalidModel` error.
    pub fn from_bytes(bytes: &[u8]) -> crate::Result<Self> {
        if bytes.len() < MessageHeader::SIZE + 1 {
            return Err(crate::Error::InvalidModel("Message too short"));
        }
        let header = MessageHeader::from_bytes(bytes)
            .ok_or(crate::Error::InvalidModel("Invalid header"))?;
        let payload_end = MessageHeader::SIZE + header.payload_len as usize;
        if bytes.len() < payload_end + 1 {
            return Err(crate::Error::InvalidModel("Payload incomplete"));
        }
        let mut payload = HVec::new();
        for &b in &bytes[MessageHeader::SIZE..payload_end] {
            payload.push(b).map_err(|_| crate::Error::BufferOverflow)?;
        }
        let checksum = bytes[payload_end];
        let msg = Self {
            header,
            payload,
            checksum,
        };
        if !msg.verify_checksum() {
            return Err(crate::Error::InvalidModel("Checksum mismatch"));
        }
        Ok(msg)
    }
    /// Extract activation data from payload
    ///
    /// Returns (layer_idx, position, raw data bytes) or `None` when the
    /// message is not an Activation or the payload is truncated. The data
    /// bytes are the i8 samples as written by `activation`.
    pub fn get_activation_data(&self) -> Option<(u8, u16, &[u8])> {
        if self.header.msg_type != MessageType::Activation as u8 {
            return None;
        }
        if self.payload.len() < 3 {
            return None;
        }
        let layer_idx = self.payload[0];
        let position = (self.payload[1] as u16) | ((self.payload[2] as u16) << 8);
        let data = &self.payload[3..];
        Some((layer_idx, position, data))
    }
    /// Extract token from payload
    ///
    /// `None` when the message is not a Token or the payload is truncated.
    pub fn get_token(&self) -> Option<u16> {
        if self.header.msg_type != MessageType::Token as u8 {
            return None;
        }
        if self.payload.len() < 2 {
            return None;
        }
        Some((self.payload[0] as u16) | ((self.payload[1] as u16) << 8))
    }
}
/// Communication statistics
///
/// Monotonically increasing counters; byte counts include header and
/// checksum bytes only insofar as the transport layer reports them.
#[derive(Debug, Default, Clone)]
pub struct CommStats {
    /// Messages sent
    pub messages_sent: u32,
    /// Messages received
    pub messages_received: u32,
    /// Bytes sent
    pub bytes_sent: u32,
    /// Bytes received
    pub bytes_received: u32,
    /// Checksum errors
    pub checksum_errors: u32,
    /// Timeouts
    pub timeouts: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_message_header() {
        // Header round-trip: serialize then parse, all fields preserved.
        let header = MessageHeader::new(
            MessageType::Activation,
            ChipId(0),
            ChipId(1),
            42,
            100,
        );
        let bytes = header.to_bytes();
        let decoded = MessageHeader::from_bytes(&bytes).unwrap();
        assert_eq!(decoded.msg_type, MessageType::Activation as u8);
        assert_eq!(decoded.src, 0);
        assert_eq!(decoded.dst, 1);
        // Copy packed fields to avoid UB from unaligned references
        let seq = decoded.seq;
        let payload_len = decoded.payload_len;
        assert_eq!(seq, 42);
        assert_eq!(payload_len, 100);
    }
    #[test]
    fn test_activation_message() {
        // Full message round-trip: build -> bytes -> parse -> extract.
        let data: [i8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let msg = FederationMessage::activation(
            ChipId(0),
            ChipId(1),
            1,
            0,
            10,
            &data,
        ).unwrap();
        let bytes = msg.to_bytes();
        let decoded = FederationMessage::from_bytes(&bytes).unwrap();
        let (layer, pos, act_data) = decoded.get_activation_data().unwrap();
        assert_eq!(layer, 0);
        assert_eq!(pos, 10);
        assert_eq!(act_data.len(), 8);
    }
    #[test]
    fn test_token_message() {
        // Token id survives serialization and checksum verification.
        let msg = FederationMessage::token(ChipId(4), ChipId(0), 100, 12345);
        let bytes = msg.to_bytes();
        let decoded = FederationMessage::from_bytes(&bytes).unwrap();
        assert_eq!(decoded.get_token(), Some(12345));
    }
}

View File

@@ -0,0 +1,143 @@
//! Embedding Sharding - Distribute Vocabulary Across Chips
//!
//! For large vocabularies, shard embeddings across chips.
//! Each chip holds a portion of the embedding table.
use heapless::Vec as HVec;
use super::protocol::ChipId;
/// Sharding configuration
///
/// Describes one contiguous slice of the vocabulary owned by a single
/// chip. Rows are split with ceiling division so every shard except
/// possibly the last owns the same number of entries.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Total vocabulary size
    pub vocab_size: usize,
    /// Number of shards (chips)
    pub num_shards: usize,
    /// This chip's shard ID
    pub shard_id: usize,
    /// Embedding dimension
    pub embed_dim: usize,
    /// Vocab range for this shard
    pub vocab_start: usize,
    pub vocab_end: usize,
}
impl ShardConfig {
    /// Build the configuration for shard `shard_id` of `num_shards`.
    pub fn for_shard(
        shard_id: usize,
        num_shards: usize,
        vocab_size: usize,
        embed_dim: usize,
    ) -> Self {
        let per_shard = (vocab_size + num_shards - 1) / num_shards;
        let start = shard_id * per_shard;
        Self {
            vocab_size,
            num_shards,
            shard_id,
            embed_dim,
            vocab_start: start,
            vocab_end: (start + per_shard).min(vocab_size),
        }
    }
    /// True when `token_id` falls inside this shard's vocabulary range.
    pub fn handles_token(&self, token_id: u16) -> bool {
        (self.vocab_start..self.vocab_end).contains(&(token_id as usize))
    }
    /// Index of the shard responsible for `token_id`.
    pub fn shard_for_token(token_id: u16, num_shards: usize, vocab_size: usize) -> usize {
        let per_shard = (vocab_size + num_shards - 1) / num_shards;
        token_id as usize / per_shard
    }
    /// Number of vocabulary entries owned by this shard.
    pub fn shard_vocab_size(&self) -> usize {
        self.vocab_end - self.vocab_start
    }
}
/// Sharded embedding table
///
/// NOTE(review): the `MAX_VOCAB` and `DIM` const parameters are not used
/// by the visible implementation — sizing comes from the fixed 8192-byte
/// buffer and `ShardConfig::embed_dim`; confirm their intent.
pub struct ShardedEmbedding<const MAX_VOCAB: usize, const DIM: usize> {
    config: ShardConfig,
    /// Local embedding weights (only our shard)
    weights: HVec<i8, 8192>, // Max 8KB per shard
}
impl<const MAX_VOCAB: usize, const DIM: usize> ShardedEmbedding<MAX_VOCAB, DIM> {
    /// Create sharded embedding
    ///
    /// Weights are pseudo-randomly initialized from `seed`, offset per
    /// shard so shards do not repeat values. Fails with `BufferOverflow`
    /// when shard_vocab_size * embed_dim exceeds the 8192-entry buffer.
    pub fn new(config: ShardConfig, seed: u32) -> crate::Result<Self> {
        let shard_size = config.shard_vocab_size() * config.embed_dim;
        let mut weights = HVec::new();
        let mut rng_state = seed.wrapping_add(config.shard_id as u32 * 12345);
        for _ in 0..shard_size {
            // LCG step (classic C rand constants); bits 16..24 give a
            // roughly uniform byte, shifted into the signed -128..=127 range.
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            let val = (((rng_state >> 16) & 0xFF) as i16 - 128) as i8;
            weights.push(val).map_err(|_| crate::Error::BufferOverflow)?;
        }
        Ok(Self { config, weights })
    }
    /// Lookup embedding (only works if we have the token)
    ///
    /// Returns Ok(false) when the token belongs to another shard, Ok(true)
    /// after copying `embed_dim` values into `output`, or `BufferOverflow`
    /// when `output` is too small / the weight table is inconsistent.
    pub fn lookup(&self, token_id: u16, output: &mut [i8]) -> crate::Result<bool> {
        if !self.config.handles_token(token_id) {
            return Ok(false);
        }
        let local_idx = token_id as usize - self.config.vocab_start;
        let start = local_idx * self.config.embed_dim;
        let end = start + self.config.embed_dim;
        if end > self.weights.len() || output.len() < self.config.embed_dim {
            return Err(crate::Error::BufferOverflow);
        }
        output[..self.config.embed_dim].copy_from_slice(&self.weights[start..end]);
        Ok(true)
    }
    /// Memory per shard vs full embedding
    ///
    /// Returns the reduction *factor* (Nx smaller), not a byte count.
    pub fn memory_saved(&self) -> f32 {
        self.config.num_shards as f32
    }
    /// Get responsible chip for a token
    ///
    /// Assumes shard indices map 1:1 to chip ids and fit in a u8.
    pub fn responsible_chip(&self, token_id: u16) -> ChipId {
        let shard = ShardConfig::shard_for_token(
            token_id,
            self.config.num_shards,
            self.config.vocab_size,
        );
        ChipId(shard as u8)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_sharding() {
        // 1000 vocab, 5 shards -> 200 entries each; shard 2 owns [400, 600).
        let config = ShardConfig::for_shard(2, 5, 1000, 32);
        assert_eq!(config.vocab_start, 400);
        assert_eq!(config.vocab_end, 600);
        assert!(config.handles_token(450));
        assert!(!config.handles_token(300));
    }
    #[test]
    fn test_shard_lookup() {
        // Token 450 falls in shard 2's [400, 600) range.
        let shard = ShardConfig::shard_for_token(450, 5, 1000);
        assert_eq!(shard, 2);
    }
}

View File

@@ -0,0 +1,294 @@
//! Speculative Decoding - Draft and Verify
//!
//! Use a smaller/faster model to draft tokens, verify with larger model.
//! Perfect for federated setup: one chip drafts, others verify in parallel.
//!
//! # Benefits
//! - 2-4x speedup for autoregressive generation
//! - Maintains exact output quality
//! - Natural fit for multi-chip setup
use heapless::Vec as HVec;
use super::protocol::{ChipId, FederationMessage};
/// Maximum draft tokens per batch
pub const MAX_DRAFT_TOKENS: usize = 8;
/// Speculative decoding configuration
#[derive(Debug, Clone)]
pub struct DraftVerifyConfig {
    /// Number of draft tokens to generate
    /// (also the baseline for the adaptive length; keep <= MAX_DRAFT_TOKENS)
    pub draft_length: usize,
    /// Acceptance threshold (0.0-1.0)
    /// A draft token is accepted when verify_prob >= draft_prob * threshold.
    pub acceptance_threshold: f32,
    /// Draft chip ID (usually chip 0)
    pub draft_chip: ChipId,
    /// Verify chips (all others)
    pub verify_chips: HVec<ChipId, 4>,
    /// Enable adaptive draft length
    pub adaptive: bool,
}
impl Default for DraftVerifyConfig {
    /// Defaults: 4 draft tokens, 0.9 acceptance threshold, chip 0 drafts,
    /// no verify chips registered, adaptive draft length enabled.
    fn default() -> Self {
        Self {
            draft_chip: ChipId(0),
            verify_chips: HVec::new(),
            draft_length: 4,
            acceptance_threshold: 0.9,
            adaptive: true,
        }
    }
}
impl DraftVerifyConfig {
    /// Configuration for the canonical 5-chip cluster: chip 0 drafts,
    /// chips 1..=4 verify. All other settings match `Default`.
    pub fn for_five_chips() -> Self {
        let mut config = Self::default();
        for chip in 1..5u8 {
            let _ = config.verify_chips.push(ChipId(chip));
        }
        config
    }
}
/// Draft result from drafting chip
#[derive(Debug, Clone)]
pub struct DraftResult {
    /// Draft token IDs
    pub tokens: HVec<u16, MAX_DRAFT_TOKENS>,
    /// Draft token probabilities (fixed-point, 0-255), parallel to `tokens`.
    /// Missing entries default to 128 (the 0.5 midpoint) during verification.
    pub probs: HVec<u8, MAX_DRAFT_TOKENS>,
    /// Starting position (sequence index of the first draft token)
    pub start_pos: u16,
}
/// Verification result from verifying chip
#[derive(Debug, Clone)]
pub struct VerifyResult {
    /// Number of accepted tokens (prefix length of the draft batch)
    pub accepted_count: usize,
    /// Correct token for first rejection (if any)
    pub correction: Option<u16>,
    /// Verification probabilities, one per examined draft token
    pub verify_probs: HVec<u8, MAX_DRAFT_TOKENS>,
}
/// Speculative decoder
///
/// Each chip owns one decoder; only the chip matching the config's
/// `draft_chip` acts as drafter, all others act as verifiers.
pub struct SpeculativeDecoder {
    config: DraftVerifyConfig,
    /// Is this the draft chip?
    is_draft_chip: bool,
    /// Current acceptance rate (for adaptive), EMA updated per batch
    acceptance_rate: f32,
    /// Draft tokens waiting for verification
    pending_draft: Option<DraftResult>,
    /// Statistics
    stats: SpecStats,
}
impl SpeculativeDecoder {
    /// Create a decoder for `chip_id`; drafting is enabled only on the
    /// chip named in `config.draft_chip`.
    pub fn new(config: DraftVerifyConfig, chip_id: ChipId) -> Self {
        let is_draft_chip = chip_id == config.draft_chip;
        Self {
            config,
            is_draft_chip,
            // Optimistic prior; refined by an EMA as batches are verified.
            acceptance_rate: 0.9,
            pending_draft: None,
            stats: SpecStats::default(),
        }
    }
    /// Check if this is the drafting chip
    pub fn is_drafter(&self) -> bool {
        self.is_draft_chip
    }
    /// Submit draft tokens (drafter only)
    ///
    /// Builds a broadcast message for the verify chips and parks the draft
    /// until `process_verification` consumes the result.
    pub fn submit_draft(&mut self, draft: DraftResult) -> crate::Result<FederationMessage> {
        if !self.is_draft_chip {
            return Err(crate::Error::UnsupportedFeature("Not draft chip"));
        }
        // heapless Vec derefs to a slice, so the tokens can be handed to
        // the message builder without an intermediate heap allocation.
        let msg = FederationMessage::draft_tokens(
            self.config.draft_chip,
            ChipId::BROADCAST,
            draft.start_pos,
            &draft.tokens,
        )?;
        self.pending_draft = Some(draft);
        self.stats.drafts_sent += 1;
        Ok(msg)
    }
    /// Verify draft tokens (verifier only)
    ///
    /// `get_prob` maps (position, token) to this chip's probability
    /// (fixed-point 0-255). Tokens are accepted in order until the first
    /// rejection; the returned `correction` is a placeholder, not a real
    /// resample from the verify distribution.
    pub fn verify_draft<F>(
        &mut self,
        draft: &DraftResult,
        mut get_prob: F,
    ) -> VerifyResult
    where
        F: FnMut(u16, u16) -> u8, // (position, token) -> probability
    {
        let mut accepted_count = 0;
        let mut correction = None;
        let mut verify_probs = HVec::new();
        for (i, &token) in draft.tokens.iter().enumerate() {
            let pos = draft.start_pos + i as u16;
            let verify_prob = get_prob(pos, token);
            let _ = verify_probs.push(verify_prob);
            // Missing draft probs default to the 0.5 midpoint (128).
            let draft_prob = draft.probs.get(i).copied().unwrap_or(128);
            // Acceptance criterion: verify_prob >= draft_prob * threshold.
            // The f32 -> u8 cast saturates, so oversized products clamp to 255.
            let threshold = (draft_prob as f32 * self.config.acceptance_threshold) as u8;
            if verify_prob >= threshold {
                accepted_count += 1;
            } else {
                // Rejection - sample correct token
                // In real impl, would sample from verify distribution
                correction = Some(token.wrapping_add(1)); // Placeholder
                break;
            }
        }
        VerifyResult {
            accepted_count,
            correction,
            verify_probs,
        }
    }
    /// Process verification result (drafter)
    ///
    /// Returns the tokens to commit: the accepted prefix of the pending
    /// draft plus the verifier's correction token (if any). Clears the
    /// pending draft and updates acceptance statistics.
    pub fn process_verification(&mut self, result: &VerifyResult) -> HVec<u16, MAX_DRAFT_TOKENS> {
        let mut accepted_tokens = HVec::new();
        if let Some(ref draft) = self.pending_draft {
            // Accept tokens up to rejection point
            for i in 0..result.accepted_count {
                if let Some(&token) = draft.tokens.get(i) {
                    let _ = accepted_tokens.push(token);
                }
            }
            // Add correction if any
            if let Some(correct_token) = result.correction {
                let _ = accepted_tokens.push(correct_token);
            }
            self.stats.tokens_accepted += result.accepted_count;
            // saturating_sub guards against a faulty verifier reporting
            // more accepted tokens than were drafted (usize underflow).
            self.stats.tokens_rejected += draft.tokens.len().saturating_sub(result.accepted_count);
            // Update acceptance-rate EMA; skip empty drafts so a 0/0 NaN
            // cannot poison the running average.
            if !draft.tokens.is_empty() {
                let batch_rate = result.accepted_count as f32 / draft.tokens.len() as f32;
                self.acceptance_rate = 0.9 * self.acceptance_rate + 0.1 * batch_rate;
            }
        }
        self.pending_draft = None;
        accepted_tokens
    }
    /// Get adaptive draft length based on acceptance rate
    pub fn adaptive_draft_length(&self) -> usize {
        if !self.config.adaptive {
            return self.config.draft_length;
        }
        // Higher acceptance -> longer drafts
        if self.acceptance_rate > 0.95 {
            (self.config.draft_length + 2).min(MAX_DRAFT_TOKENS)
        } else if self.acceptance_rate > 0.8 {
            self.config.draft_length
        } else if self.acceptance_rate > 0.5 {
            // saturating_sub: a configured draft_length of 0 must not
            // underflow usize (the previous plain `- 1` would panic).
            self.config.draft_length.saturating_sub(1).max(1)
        } else {
            1 // Fall back to no speculation
        }
    }
    /// Get speedup estimate
    ///
    /// Rough model: expected accepted tokens per batch divided by one
    /// verification pass plus fixed overhead.
    pub fn estimated_speedup(&self) -> f32 {
        let avg_accepted = self.acceptance_rate * self.adaptive_draft_length() as f32;
        let verify_overhead = 0.2; // Verification overhead
        avg_accepted / (1.0 + verify_overhead)
    }
    /// Get statistics
    pub fn stats(&self) -> &SpecStats {
        &self.stats
    }
}
/// Speculative decoding statistics
#[derive(Debug, Default, Clone)]
pub struct SpecStats {
    /// Total draft batches sent
    pub drafts_sent: usize,
    /// Total tokens accepted
    pub tokens_accepted: usize,
    /// Total tokens rejected
    pub tokens_rejected: usize,
}
impl SpecStats {
    /// Fraction of drafted tokens that verification accepted; 0.0 before
    /// any tokens have been judged.
    pub fn acceptance_rate(&self) -> f32 {
        let total = self.tokens_accepted + self.tokens_rejected;
        match total {
            0 => 0.0,
            _ => self.tokens_accepted as f32 / total as f32,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_speculative_config() {
        // 5-chip preset: chip 0 drafts, chips 1..=4 verify.
        let config = DraftVerifyConfig::for_five_chips();
        assert_eq!(config.draft_chip, ChipId(0));
        assert_eq!(config.verify_chips.len(), 4);
    }
    #[test]
    fn test_verify_draft() {
        // Decoder on chip 1 acts as a verifier (draft chip is 0).
        let config = DraftVerifyConfig::default();
        let mut decoder = SpeculativeDecoder::new(config, ChipId(1));
        let mut draft = DraftResult {
            tokens: HVec::new(),
            probs: HVec::new(),
            start_pos: 0,
        };
        let _ = draft.tokens.push(100);
        let _ = draft.tokens.push(101);
        let _ = draft.probs.push(200);
        let _ = draft.probs.push(200);
        let result = decoder.verify_draft(&draft, |_pos, _token| 190);
        // Both should be accepted (190 >= 200 * 0.9 = 180)
        assert_eq!(result.accepted_count, 2);
        assert!(result.correction.is_none());
    }
}

View File

@@ -0,0 +1,144 @@
//! Tensor Parallelism - Distributed Attention Heads
//!
//! Splits attention heads across chips for parallel computation.
//! Each chip handles a subset of heads, then results are combined.
use heapless::Vec as HVec;
use super::protocol::{ChipId, FederationMessage};
/// Maximum heads per chip
pub const MAX_HEADS_PER_CHIP: usize = 4;
/// Tensor parallel configuration
#[derive(Debug, Clone)]
pub struct TPConfig {
    /// Number of chips
    pub num_chips: usize,
    /// This chip's ID
    pub chip_id: ChipId,
    /// Total attention heads
    pub total_heads: usize,
    /// Heads handled by this chip (round-robin subset, at most
    /// MAX_HEADS_PER_CHIP entries)
    pub my_heads: HVec<usize, MAX_HEADS_PER_CHIP>,
    /// Embedding dimension per head
    pub head_dim: usize,
}
impl TPConfig {
    /// Distribute `total_heads` attention heads round-robin over
    /// `num_chips` chips, keeping the subset assigned to `chip_id`.
    ///
    /// At most `MAX_HEADS_PER_CHIP` heads are retained; any further
    /// assignments are silently dropped by the fixed-capacity buffer.
    pub fn distribute_heads(
        chip_id: usize,
        num_chips: usize,
        total_heads: usize,
        head_dim: usize,
    ) -> Self {
        let mut my_heads = HVec::new();
        // Round-robin assignment: chip k owns heads k, k+N, k+2N, ...
        for head in (0..total_heads).filter(|h| h % num_chips == chip_id) {
            let _ = my_heads.push(head);
        }
        Self {
            chip_id: ChipId(chip_id as u8),
            num_chips,
            total_heads,
            head_dim,
            my_heads,
        }
    }
}
/// Tensor parallel attention node
pub struct TensorParallelNode {
    config: TPConfig,
    /// Partial attention outputs, one buffer per locally owned head
    partial_outputs: HVec<HVec<i32, 64>, MAX_HEADS_PER_CHIP>,
    /// Combined output buffer
    /// NOTE(review): never written by any method in this file — presumably
    /// reserved for the combine/all-reduce step; confirm before removal.
    output_buffer: HVec<i32, 256>,
}
impl TensorParallelNode {
pub fn new(config: TPConfig) -> Self {
Self {
config,
partial_outputs: HVec::new(),
output_buffer: HVec::new(),
}
}
/// Get heads this chip handles
pub fn my_heads(&self) -> &[usize] {
&self.config.my_heads
}
/// Compute partial attention for assigned heads
pub fn compute_partial_attention(
&mut self,
query: &[i8],
keys: &[&[i8]],
values: &[&[i8]],
) -> crate::Result<()> {
self.partial_outputs.clear();
for &head_idx in &self.config.my_heads {
let mut head_output = HVec::new();
// Compute Q @ K^T for this head
let head_start = head_idx * self.config.head_dim;
let head_end = head_start + self.config.head_dim;
// Simplified attention: just dot product for now
for &val in &values[0][head_start..head_end.min(values[0].len())] {
head_output.push(val as i32).map_err(|_| crate::Error::BufferOverflow)?;
}
self.partial_outputs.push(head_output).map_err(|_| crate::Error::BufferOverflow)?;
}
Ok(())
}
/// Create message with partial results
pub fn create_partial_result_message(&self, dst: ChipId, seq: u16) -> crate::Result<FederationMessage> {
let mut data: Vec<i8> = Vec::new();
for partial in &self.partial_outputs {
for &val in partial {
data.push((val >> 8) as i8); // Scale down
}
}
FederationMessage::activation(
self.config.chip_id,
dst,
seq,
0, // Not layer-based
0,
&data,
)
}
/// Memory saved vs single-chip
pub fn memory_reduction(&self) -> f32 {
self.config.num_chips as f32
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_head_distribution() {
// 4 heads across 5 chips
let config0 = TPConfig::distribute_heads(0, 5, 4, 16);
let config1 = TPConfig::distribute_heads(1, 5, 4, 16);
// Chip 0 gets head 0, chip 1 gets head 1, etc.
assert_eq!(config0.my_heads.as_slice(), &[0]);
assert_eq!(config1.my_heads.as_slice(), &[1]);
}
}