Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
146
examples/ruvLLM/esp32-flash/src/federation/speculative.rs
Normal file
146
examples/ruvLLM/esp32-flash/src/federation/speculative.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
//! Speculative Decoding - Draft and Verify
|
||||
|
||||
use heapless::Vec as HVec;
|
||||
use super::protocol::{ChipId, FederationMessage};
|
||||
|
||||
pub const MAX_DRAFT_TOKENS: usize = 8;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DraftVerifyConfig {
|
||||
pub draft_length: usize,
|
||||
pub acceptance_threshold: f32,
|
||||
pub draft_chip: ChipId,
|
||||
pub verify_chips: HVec<ChipId, 4>,
|
||||
pub adaptive: bool,
|
||||
}
|
||||
|
||||
impl Default for DraftVerifyConfig {
|
||||
fn default() -> Self {
|
||||
Self { draft_length: 4, acceptance_threshold: 0.9, draft_chip: ChipId(0), verify_chips: HVec::new(), adaptive: true }
|
||||
}
|
||||
}
|
||||
|
||||
impl DraftVerifyConfig {
|
||||
pub fn for_five_chips() -> Self {
|
||||
let mut verify_chips = HVec::new();
|
||||
for i in 1..5 { let _ = verify_chips.push(ChipId(i)); }
|
||||
Self { draft_length: 4, acceptance_threshold: 0.9, draft_chip: ChipId(0), verify_chips, adaptive: true }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DraftResult {
|
||||
pub tokens: HVec<u16, MAX_DRAFT_TOKENS>,
|
||||
pub probs: HVec<u8, MAX_DRAFT_TOKENS>,
|
||||
pub start_pos: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VerifyResult {
|
||||
pub accepted_count: usize,
|
||||
pub correction: Option<u16>,
|
||||
pub verify_probs: HVec<u8, MAX_DRAFT_TOKENS>,
|
||||
}
|
||||
|
||||
pub struct SpeculativeDecoder {
|
||||
config: DraftVerifyConfig,
|
||||
is_draft_chip: bool,
|
||||
acceptance_rate: f32,
|
||||
pending_draft: Option<DraftResult>,
|
||||
stats: SpecStats,
|
||||
}
|
||||
|
||||
impl SpeculativeDecoder {
|
||||
pub fn new(config: DraftVerifyConfig, chip_id: ChipId) -> Self {
|
||||
let is_draft = chip_id == config.draft_chip;
|
||||
Self { config, is_draft_chip: is_draft, acceptance_rate: 0.9, pending_draft: None, stats: SpecStats::default() }
|
||||
}
|
||||
|
||||
pub fn is_drafter(&self) -> bool { self.is_draft_chip }
|
||||
|
||||
pub fn submit_draft(&mut self, draft: DraftResult) -> crate::Result<FederationMessage> {
|
||||
if !self.is_draft_chip { return Err(crate::Error::UnsupportedFeature("Not draft chip")); }
|
||||
let tokens: heapless::Vec<u16, MAX_DRAFT_TOKENS> = draft.tokens.iter().cloned().collect();
|
||||
let msg = FederationMessage::draft_tokens(self.config.draft_chip, ChipId::BROADCAST, draft.start_pos, &tokens)?;
|
||||
self.pending_draft = Some(draft);
|
||||
self.stats.drafts_sent += 1;
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
pub fn verify_draft<F>(&mut self, draft: &DraftResult, mut get_prob: F) -> VerifyResult
|
||||
where F: FnMut(u16, u16) -> u8
|
||||
{
|
||||
let mut accepted = 0;
|
||||
let mut correction = None;
|
||||
let mut verify_probs = HVec::new();
|
||||
|
||||
for (i, &token) in draft.tokens.iter().enumerate() {
|
||||
let pos = draft.start_pos + i as u16;
|
||||
let verify_prob = get_prob(pos, token);
|
||||
let _ = verify_probs.push(verify_prob);
|
||||
let draft_prob = draft.probs.get(i).copied().unwrap_or(128);
|
||||
let threshold = (draft_prob as f32 * self.config.acceptance_threshold) as u8;
|
||||
|
||||
if verify_prob >= threshold {
|
||||
accepted += 1;
|
||||
} else {
|
||||
correction = Some(token.wrapping_add(1));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
VerifyResult { accepted_count: accepted, correction, verify_probs }
|
||||
}
|
||||
|
||||
pub fn process_verification(&mut self, result: &VerifyResult) -> HVec<u16, MAX_DRAFT_TOKENS> {
|
||||
let mut accepted_tokens = HVec::new();
|
||||
|
||||
if let Some(ref draft) = self.pending_draft {
|
||||
for i in 0..result.accepted_count {
|
||||
if let Some(&token) = draft.tokens.get(i) {
|
||||
let _ = accepted_tokens.push(token);
|
||||
}
|
||||
}
|
||||
if let Some(correct) = result.correction {
|
||||
let _ = accepted_tokens.push(correct);
|
||||
}
|
||||
|
||||
self.stats.tokens_accepted += result.accepted_count;
|
||||
self.stats.tokens_rejected += draft.tokens.len() - result.accepted_count;
|
||||
let rate = result.accepted_count as f32 / draft.tokens.len() as f32;
|
||||
self.acceptance_rate = 0.9 * self.acceptance_rate + 0.1 * rate;
|
||||
}
|
||||
|
||||
self.pending_draft = None;
|
||||
accepted_tokens
|
||||
}
|
||||
|
||||
pub fn adaptive_draft_length(&self) -> usize {
|
||||
if !self.config.adaptive { return self.config.draft_length; }
|
||||
if self.acceptance_rate > 0.95 { (self.config.draft_length + 2).min(MAX_DRAFT_TOKENS) }
|
||||
else if self.acceptance_rate > 0.8 { self.config.draft_length }
|
||||
else if self.acceptance_rate > 0.5 { (self.config.draft_length - 1).max(1) }
|
||||
else { 1 }
|
||||
}
|
||||
|
||||
pub fn estimated_speedup(&self) -> f32 {
|
||||
let avg = self.acceptance_rate * self.adaptive_draft_length() as f32;
|
||||
avg / 1.2
|
||||
}
|
||||
|
||||
pub fn stats(&self) -> &SpecStats { &self.stats }
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct SpecStats {
|
||||
pub drafts_sent: usize,
|
||||
pub tokens_accepted: usize,
|
||||
pub tokens_rejected: usize,
|
||||
}
|
||||
|
||||
impl SpecStats {
|
||||
pub fn acceptance_rate(&self) -> f32 {
|
||||
let total = self.tokens_accepted + self.tokens_rejected;
|
||||
if total == 0 { 0.0 } else { self.tokens_accepted as f32 / total as f32 }
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user