wifi-densepose/examples/ruvLLM/esp32/src/federation/tensor_parallel.rs

//! Tensor Parallelism - Distributed Attention Heads
//!
//! Splits attention heads across chips for parallel computation.
//! Each chip handles a subset of heads, then results are combined.
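//!
//! Illustrative rustdoc sketch (added for clarity, not from the original
//! source): with 8 heads split across 2 chips, chip 0 owns the even heads.
//!
//! ```ignore
//! let cfg = TPConfig::distribute_heads(0, 2, 8, 16);
//! assert_eq!(cfg.my_heads.as_slice(), &[0, 2, 4, 6]);
//! ```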
use heapless::Vec as HVec;

use super::protocol::{ChipId, FederationMessage};

/// Maximum heads per chip
pub const MAX_HEADS_PER_CHIP: usize = 4;

/// Tensor parallel configuration
#[derive(Debug, Clone)]
pub struct TPConfig {
    /// Number of chips
    pub num_chips: usize,
    /// This chip's ID
    pub chip_id: ChipId,
    /// Total attention heads
    pub total_heads: usize,
    /// Heads handled by this chip
    pub my_heads: HVec<usize, MAX_HEADS_PER_CHIP>,
    /// Embedding dimension per head
    pub head_dim: usize,
}

impl TPConfig {
    /// Create a config distributing heads round-robin across chips.
    ///
    /// Note: the per-chip buffer holds at most `MAX_HEADS_PER_CHIP` heads;
    /// any further heads assigned to this chip are silently dropped.
    pub fn distribute_heads(
        chip_id: usize,
        num_chips: usize,
        total_heads: usize,
        head_dim: usize,
    ) -> Self {
        debug_assert!(num_chips > 0, "num_chips must be nonzero");
        let mut my_heads = HVec::new();
        // Assign heads round-robin: chip c owns heads c, c + num_chips, ...
        for h in 0..total_heads {
            if h % num_chips == chip_id {
                // Fixed-capacity push; overflow beyond MAX_HEADS_PER_CHIP is dropped
                let _ = my_heads.push(h);
            }
        }
        Self {
            num_chips,
            chip_id: ChipId(chip_id as u8),
            total_heads,
            my_heads,
            head_dim,
        }
    }
}

/// Tensor parallel attention node
pub struct TensorParallelNode {
    config: TPConfig,
    /// Partial attention outputs from each head this chip owns
    partial_outputs: HVec<HVec<i32, 64>, MAX_HEADS_PER_CHIP>,
    /// Combined output buffer (filled as remote partials are merged)
    output_buffer: HVec<i32, 256>,
}

impl TensorParallelNode {
    pub fn new(config: TPConfig) -> Self {
        Self {
            config,
            partial_outputs: HVec::new(),
            output_buffer: HVec::new(),
        }
    }

    /// Get heads this chip handles
    pub fn my_heads(&self) -> &[usize] {
        &self.config.my_heads
    }

    /// Compute partial attention for the heads assigned to this chip.
    ///
    /// Simplified placeholder: `query` and `keys` are accepted for the
    /// eventual Q @ K^T computation but are not used yet; each assigned
    /// head's slice of the first value vector is copied through instead.
    pub fn compute_partial_attention(
        &mut self,
        _query: &[i8],
        _keys: &[&[i8]],
        values: &[&[i8]],
    ) -> crate::Result<()> {
        self.partial_outputs.clear();
        // Only the first value vector is consulted in this simplified version
        let v0: &[i8] = values.first().copied().unwrap_or(&[]);
        for &head_idx in &self.config.my_heads {
            let mut head_output = HVec::new();
            // Slice out this head's region, clamping both bounds so an
            // out-of-range head index cannot panic
            let head_start = (head_idx * self.config.head_dim).min(v0.len());
            let head_end = (head_start + self.config.head_dim).min(v0.len());
            for &val in &v0[head_start..head_end] {
                head_output
                    .push(val as i32)
                    .map_err(|_| crate::Error::BufferOverflow)?;
            }
            self.partial_outputs
                .push(head_output)
                .map_err(|_| crate::Error::BufferOverflow)?;
        }
        Ok(())
    }

    /// Create a message carrying this chip's partial results
    pub fn create_partial_result_message(
        &self,
        dst: ChipId,
        seq: u16,
    ) -> crate::Result<FederationMessage> {
        // heapless buffer: MAX_HEADS_PER_CHIP heads x 64 values (no_std, no alloc)
        let mut data: HVec<i8, 256> = HVec::new();
        for partial in &self.partial_outputs {
            for &val in partial {
                // Scale i32 accumulators back down to the i8 wire format
                data.push((val >> 8) as i8)
                    .map_err(|_| crate::Error::BufferOverflow)?;
            }
        }
        FederationMessage::activation(
            self.config.chip_id,
            dst,
            seq,
            0, // layer: unused here, results are head-based rather than layer-based
            0,
            &data,
        )
    }

    /// Memory reduction factor vs. a single chip holding all heads:
    /// splitting heads evenly across N chips cuts per-chip attention
    /// state by roughly a factor of N.
    pub fn memory_reduction(&self) -> f32 {
        self.config.num_chips as f32
    }
}
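
// The module doc says per-chip results are combined, but the merge step is
// not defined in this file. Below is a minimal sketch of what it could look
// like; `merge_remote_partial` is a hypothetical helper, not part of the
// original API, and assumes the `>> 8` scaling used above.
impl TensorParallelNode {
    /// Append a remote chip's partial head outputs into the combined
    /// buffer, widening wire-format i8 values back to i32 accumulators.
    pub fn merge_remote_partial(&mut self, payload: &[i8]) -> crate::Result<()> {
        for &val in payload {
            // Reverse the `>> 8` scaling applied when the message was built
            self.output_buffer
                .push((val as i32) << 8)
                .map_err(|_| crate::Error::BufferOverflow)?;
        }
        Ok(())
    }
}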

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_head_distribution() {
        // 4 heads across 5 chips: round-robin gives each chip at most one head
        let config0 = TPConfig::distribute_heads(0, 5, 4, 16);
        let config1 = TPConfig::distribute_heads(1, 5, 4, 16);
        // Chip 0 gets head 0, chip 1 gets head 1, etc.
        assert_eq!(config0.my_heads.as_slice(), &[0]);
        assert_eq!(config1.my_heads.as_slice(), &[1]);
    }
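
    // Illustrative additions (not in the original tests): an even split,
    // and a smoke test of the simplified partial-attention copy-through.
    #[test]
    fn test_even_head_distribution() {
        // 8 heads across 2 chips fills each chip to MAX_HEADS_PER_CHIP
        let config0 = TPConfig::distribute_heads(0, 2, 8, 16);
        let config1 = TPConfig::distribute_heads(1, 2, 8, 16);
        assert_eq!(config0.my_heads.as_slice(), &[0, 2, 4, 6]);
        assert_eq!(config1.my_heads.as_slice(), &[1, 3, 5, 7]);
    }

    #[test]
    fn test_partial_attention_copies_head_values() {
        // Chip 0 of 2, 2 heads total, head_dim 4 -> owns head 0 (values[0..4])
        let config = TPConfig::distribute_heads(0, 2, 2, 4);
        let mut node = TensorParallelNode::new(config);
        let values: [i8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let v: &[i8] = &values;
        assert!(node.compute_partial_attention(&[], &[], &[v]).is_ok());
        assert_eq!(node.partial_outputs[0].as_slice(), &[1, 2, 3, 4]);
    }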
}