// NOTE: repository-listing metadata from the original paste, converted to
// comments so the file compiles:
//   Files
//   wifi-densepose/vendor/ruvector/examples/ruvLLM/esp32-flash/src/main.rs
//   779 lines, 25 KiB, Rust
//! RuvLLM ESP32 - Complete Flashable Implementation
//!
//! Full-featured LLM inference engine for ESP32 with:
//! - INT8/Binary quantized transformer inference
//! - Product quantization (8-32x compression)
//! - MicroLoRA on-device adaptation
//! - Sparse attention patterns
//! - HNSW vector search (1000+ vectors)
//! - Semantic memory with context
//! - RAG (Retrieval-Augmented Generation)
//! - Anomaly detection
//! - Multi-chip federation
//! - Pipeline/tensor parallelism
//! - Speculative decoding
//!
//! Flash with: espflash flash --monitor --port COM6
#[cfg(feature = "esp32")]
use esp_idf_svc::hal::prelude::*;
#[cfg(feature = "esp32")]
use esp_idf_svc::hal::uart::{self, UartDriver};
#[cfg(feature = "esp32")]
use esp_idf_svc::hal::gpio;
#[cfg(feature = "esp32")]
use esp_idf_svc::sys::link_patches;
use heapless::Vec as HVec;
use heapless::String as HString;
use log::*;
// Import library modules
use ruvllm_esp32::prelude::*;
use ruvllm_esp32::{
HNSWConfig, RAGConfig, MemoryType, DraftVerifyConfig,
PipelineConfig, PipelineRole, AnomalyConfig, PQConfig, LoRAConfig, PruningConfig,
AttentionPattern, DistanceMetric, euclidean_distance_i8,
};
// ============================================================================
// CONFIGURATION
// ============================================================================
const VOCAB_SIZE: usize = 256;
const EMBED_DIM: usize = 64;
const NUM_LAYERS: usize = 2;
const NUM_HEADS: usize = 4;
const MAX_SEQ_LEN: usize = 32;
const MAX_KNOWLEDGE: usize = 64;
const HNSW_CAPACITY: usize = 256;
// ============================================================================
// QUANTIZED TYPES
// ============================================================================
/// INT8 weight table with affine quantization parameters.
///
/// Weights here are a deterministic synthetic pattern (demo firmware),
/// not trained values.
#[derive(Clone)]
struct QuantizedWeights {
    data: HVec<i8, 4096>, // quantized values, capped at the 4096-slot backing store
    scale: i32,           // fixed-point scale factor
    zero_point: i8,       // affine zero point
}
impl QuantizedWeights {
    /// Build a pseudo-random weight table of `size` entries (silently
    /// truncated to the 4096-element capacity).
    fn new(size: usize) -> Self {
        let mut data = HVec::new();
        for i in 0..size.min(4096) {
            // `wrapping_sub` fixes a debug-build overflow panic: the cast
            // can land as low as -128 (e.g. i = 6 gives 133 -> -123), and
            // plain `- 64` then underflows i8. This matches the pattern
            // already used by EmbeddingTable::new.
            let val = (((i * 17 + 31) % 256) as i8).wrapping_sub(64);
            let _ = data.push(val);
        }
        Self { data, scale: 128, zero_point: 0 }
    }
}
// ============================================================================
// EMBEDDING TABLE
// ============================================================================
/// Deterministic INT8 token-embedding table (VOCAB_SIZE x EMBED_DIM).
struct EmbeddingTable {
    embeddings: [[i8; EMBED_DIM]; VOCAB_SIZE],
}
impl EmbeddingTable {
    /// Fill the table with a fixed pseudo-random pattern derived from the
    /// token id and dimension index.
    fn new() -> Self {
        let mut embeddings = [[0i8; EMBED_DIM]; VOCAB_SIZE];
        for token in 0..VOCAB_SIZE {
            for dim in 0..EMBED_DIM {
                let mixed = (token * 31 + dim * 17) % 256;
                embeddings[token][dim] = (mixed as i8).wrapping_sub(64);
            }
        }
        Self { embeddings }
    }
    /// Return the embedding row for `token`, wrapping out-of-range ids
    /// back into the vocabulary instead of panicking.
    fn lookup(&self, token: u16) -> &[i8; EMBED_DIM] {
        &self.embeddings[usize::from(token) % VOCAB_SIZE]
    }
}
// ============================================================================
// ATTENTION WITH SPARSE PATTERNS
// ============================================================================
/// Attention projection gated by a sparse attention mask.
///
/// NOTE(review): the simplified forward pass consumes only `wq`; the
/// wk/wv/wo matrices are allocated to mirror a full attention block but
/// are not referenced yet — confirm before relying on them.
struct MicroAttention {
    wq: QuantizedWeights,
    wk: QuantizedWeights,
    wv: QuantizedWeights,
    wo: QuantizedWeights,
    sparse: SparseAttention, // supplies the per-position attention mask
    head_dim: usize,         // EMBED_DIM / NUM_HEADS, kept for future use
}
impl MicroAttention {
    /// Allocate demo-sized projection matrices plus the sparse mask helper.
    fn new(pattern: AttentionPattern) -> Self {
        Self {
            wq: QuantizedWeights::new(EMBED_DIM * EMBED_DIM),
            wk: QuantizedWeights::new(EMBED_DIM * EMBED_DIM),
            wv: QuantizedWeights::new(EMBED_DIM * EMBED_DIM),
            wo: QuantizedWeights::new(EMBED_DIM * EMBED_DIM),
            sparse: SparseAttention::new(pattern, MAX_SEQ_LEN, 8),
            head_dim: EMBED_DIM / NUM_HEADS,
        }
    }
    /// Project `input` into `output`, zeroing every lane the sparse mask
    /// blocks for this sequence position.
    fn forward(&self, input: &[i8], output: &mut [i8], seq_pos: usize) {
        // Mask row for the current sequence position.
        let mask = self.sparse.get_mask(seq_pos);
        let weights = &self.wq.data;
        for (i, (dst, src)) in output.iter_mut().zip(input.iter()).enumerate() {
            // Fixed-point (Q7) multiply when this lane may attend, else zero.
            let attended = if i < mask.len() && mask[i] {
                (i32::from(*src) * i32::from(weights[i % weights.len()])) >> 7
            } else {
                0
            };
            *dst = attended.clamp(-127, 127) as i8;
        }
    }
}
// ============================================================================
// FEED-FORWARD WITH PRUNING
// ============================================================================
/// Feed-forward block with weight pruning on the first matrix.
///
/// NOTE(review): the simplified forward pass uses only `w1`; `w2` is
/// allocated for shape realism but never read — confirm before extending.
struct FeedForward {
    w1: QuantizedWeights,
    w2: QuantizedWeights,
    pruner: LayerPruner, // decides which weight indices are treated as zero
}
impl FeedForward {
    /// Build the expansion/contraction matrices and a pruner from `config`.
    fn new(config: PruningConfig) -> Self {
        Self {
            w1: QuantizedWeights::new(EMBED_DIM * 4 * EMBED_DIM),
            w2: QuantizedWeights::new(4 * EMBED_DIM * EMBED_DIM),
            pruner: LayerPruner::new(config),
        }
    }
    /// Elementwise Q7 multiply followed by ReLU; pruned weights contribute
    /// zero to the product.
    fn forward(&self, input: &[i8], output: &mut [i8]) {
        let weights = &self.w1.data;
        for (idx, (dst, src)) in output.iter_mut().zip(input.iter()).enumerate() {
            let w_idx = idx % weights.len();
            // A pruned slot behaves exactly like a zero weight.
            let w = if self.pruner.is_pruned(w_idx) {
                0
            } else {
                i32::from(weights[w_idx])
            };
            let hidden = (i32::from(*src) * w) >> 7;
            // ReLU, then saturate into the signed INT8 range.
            *dst = hidden.max(0).clamp(-127, 127) as i8;
        }
    }
}
// ============================================================================
// TRANSFORMER LAYER WITH LORA
// ============================================================================
/// One transformer block: sparse attention plus pruned FFN, each wrapped in
/// a halved residual connection, with optional MicroLoRA adaptation of the
/// attention output.
struct TransformerLayer {
    attention: MicroAttention,
    ffn: FeedForward,
    lora: Option<MicroLoRA>, // present only when the chip has SRAM to spare
}
impl TransformerLayer {
    /// Build a layer with a fixed sliding-window attention pattern and
    /// default pruning; `lora_config` enables on-device adaptation.
    fn new(lora_config: Option<LoRAConfig>) -> Self {
        let attn_pattern = AttentionPattern::SlidingWindow { window_size: 8 };
        let prune_config = PruningConfig::default();
        Self {
            attention: MicroAttention::new(attn_pattern),
            ffn: FeedForward::new(prune_config),
            // Point-free form (clippy::redundant_closure).
            lora: lora_config.map(MicroLoRA::new),
        }
    }
    /// Run one token's hidden state through the layer.
    /// `seq_pos` selects the sparse-attention mask row.
    fn forward(&self, input: &[i8], output: &mut [i8], seq_pos: usize) {
        let mut attn_out = [0i8; EMBED_DIM];
        self.attention.forward(input, &mut attn_out, seq_pos);
        // Apply LoRA adaptation if enabled.
        if let Some(ref lora) = self.lora {
            let adapted = lora.forward(&attn_out);
            for (slot, delta) in attn_out.iter_mut().zip(adapted.iter()) {
                *slot = slot.saturating_add(*delta);
            }
        }
        // Residual connection (input halved to keep INT8 headroom). zip
        // truncates to the shorter slice instead of panicking on short input.
        for (slot, inp) in attn_out.iter_mut().zip(input.iter()) {
            *slot = slot.saturating_add(*inp / 2);
        }
        self.ffn.forward(&attn_out, output);
        // Residual connection around the FFN.
        for (slot, attn) in output.iter_mut().zip(attn_out.iter()) {
            *slot = slot.saturating_add(*attn / 2);
        }
    }
}
// ============================================================================
// TINY MODEL WITH FULL FEATURES
// ============================================================================
/// Minimal two-layer transformer decoding greedily over a byte vocabulary.
struct TinyModel {
    embeddings: EmbeddingTable,
    layers: [TransformerLayer; NUM_LAYERS],
    lm_head: QuantizedWeights,          // vocabulary projection
    binary_embed: Option<BinaryVector>, // reserved for binary-quant mode
    pq: Option<ProductQuantizer>,       // optional product quantizer
}
impl TinyModel {
    /// Assemble the model; `use_lora` / `use_pq` toggle the optional
    /// adaptation and compression components.
    fn new(use_lora: bool, use_pq: bool) -> Self {
        let lora_config = use_lora.then(|| LoRAConfig {
            rank: 2,
            alpha: 4,
            input_dim: EMBED_DIM,
            output_dim: EMBED_DIM,
        });
        let pq = use_pq.then(|| {
            ProductQuantizer::new(PQConfig {
                dim: EMBED_DIM,
                num_subspaces: 8,
                num_centroids: 16,
            })
        });
        Self {
            embeddings: EmbeddingTable::new(),
            layers: [
                TransformerLayer::new(lora_config.clone()),
                TransformerLayer::new(lora_config),
            ],
            lm_head: QuantizedWeights::new(EMBED_DIM * VOCAB_SIZE),
            binary_embed: Some(BinaryVector::new()),
            pq,
        }
    }
    /// One decoding step: embed `token`, run all layers, and return the
    /// argmax token from the LM head.
    fn forward(&self, token: u16, seq_pos: usize) -> u16 {
        let mut hidden = *self.embeddings.lookup(token);
        for layer in &self.layers {
            let mut next = [0i8; EMBED_DIM];
            layer.forward(&hidden, &mut next, seq_pos);
            hidden = next;
        }
        // Greedy argmax over the vocabulary projection. Strict `>` keeps
        // the first (lowest-id) token on ties, matching greedy decoding.
        let mut best_logit = i32::MIN;
        let mut best_token = 0u16;
        for t in 0..VOCAB_SIZE {
            let base = t * EMBED_DIM;
            let mut logit = 0i32;
            for (i, h) in hidden.iter().enumerate() {
                // Indices past the capacity-truncated table count as zero.
                if let Some(&w) = self.lm_head.data.get(base + i) {
                    logit += i32::from(*h) * i32::from(w);
                }
            }
            if logit > best_logit {
                best_logit = logit;
                best_token = t as u16;
            }
        }
        best_token
    }
}
// ============================================================================
// FULL INFERENCE ENGINE
// ============================================================================
/// Top-level inference engine bundling the quantized model with retrieval,
/// memory, and monitoring subsystems.
struct MicroEngine {
    model: TinyModel,                          // quantized transformer
    hnsw: MicroHNSW<EMBED_DIM, HNSW_CAPACITY>, // approximate nearest-neighbor index
    rag: MicroRAG<EMBED_DIM, MAX_KNOWLEDGE>,   // retrieval-augmented knowledge store
    memory: SemanticMemory<EMBED_DIM, 32>,     // semantic memory records
    anomaly: AnomalyDetector,                  // embedding-space outlier check
    speculative: Option<SpeculativeDecoder>,   // enabled only on large-SRAM chips
    tokens_generated: u32,                     // lifetime token counter
    variant: Esp32Variant,                     // chip variant the engine was sized for
}
impl MicroEngine {
    /// Construct an engine sized for `variant`.
    ///
    /// Optional features switch on only when the variant's SRAM budget
    /// allows: LoRA at >= 400 KB, PQ at >= 320 KB, speculative decoding at
    /// >= 512 KB (and only if `enable_speculative` is set).
    fn new(variant: Esp32Variant, enable_speculative: bool) -> Self {
        info!("Initializing MicroEngine for {:?}...", variant);
        info!(" Available SRAM: {} KB", variant.sram_bytes() / 1024);
        info!(" Max model RAM: {} KB", variant.max_model_ram() / 1024);
        let use_lora = variant.sram_bytes() >= 400 * 1024;
        let use_pq = variant.sram_bytes() >= 320 * 1024;
        // Wider HNSW graphs only on SIMD-capable chips; binary-distance
        // mode when there is no FPU.
        let hnsw_config = HNSWConfig {
            m: if variant.has_simd() { 8 } else { 4 },
            m_max0: if variant.has_simd() { 16 } else { 8 },
            ef_construction: 32,
            ef_search: 16,
            metric: DistanceMetric::Euclidean,
            binary_mode: !variant.has_fpu(),
        };
        let rag_config = RAGConfig::default();
        let anomaly_config = AnomalyConfig::default();
        let speculative = if enable_speculative && variant.sram_bytes() >= 512 * 1024 {
            Some(SpeculativeDecoder::new(DraftVerifyConfig {
                draft_length: 4,
                max_rejections: 2,
                temperature: 100,
                verify_all: false,
            }))
        } else {
            None
        };
        Self {
            model: TinyModel::new(use_lora, use_pq),
            hnsw: MicroHNSW::new(hnsw_config),
            rag: MicroRAG::new(rag_config),
            memory: SemanticMemory::new(),
            anomaly: AnomalyDetector::new(anomaly_config),
            speculative,
            tokens_generated: 0,
            variant,
        }
    }
    /// Autoregressively generate up to `max_tokens` tokens starting from
    /// the last token of `input`. Token 0 acts as end-of-sequence.
    fn generate(&mut self, input: &[u16], max_tokens: usize) -> HVec<u16, 64> {
        let mut output = HVec::new();
        let mut current = *input.last().unwrap_or(&1);
        let mut seq_pos = input.len();
        // NOTE(review): the decoder object is not consulted during draft or
        // verify, so `is_some()` replaces the previous unused `if let`
        // binding (silences the compiler warning without changing behavior).
        // TODO: route drafts through the SpeculativeDecoder itself.
        if self.speculative.is_some() {
            // Speculative decoding: generate drafts and verify
            while output.len() < max_tokens {
                // Draft phase
                let mut drafts = HVec::<u16, 8>::new();
                for _ in 0..4 {
                    let next = self.model.forward(current, seq_pos);
                    let _ = drafts.push(next);
                    current = next;
                    seq_pos += 1;
                }
                // Verify phase (simplified)
                for &token in drafts.iter() {
                    if output.len() < max_tokens {
                        let _ = output.push(token);
                        self.tokens_generated += 1;
                    }
                    if token == 0 { return output; }
                }
            }
        } else {
            // Standard decoding
            for _ in 0..max_tokens {
                let next = self.model.forward(current, seq_pos);
                let _ = output.push(next);
                self.tokens_generated += 1;
                current = next;
                seq_pos += 1;
                if next == 0 { break; }
            }
        }
        output
    }
    /// Embed `text` and insert it into all three stores (HNSW, RAG,
    /// semantic memory). Returns the assigned HNSW vector id.
    fn add_knowledge(&mut self, text: &str) -> Result<u32, &'static str> {
        let embedding = embed_text(text);
        // Add to HNSW index
        let mut vec_data = HVec::new();
        for &v in embedding.iter() {
            let _ = vec_data.push(v);
        }
        let vec = MicroVector { data: vec_data, id: self.hnsw.len() as u32 };
        self.hnsw.insert(&vec)?;
        // Add to RAG
        self.rag.add_knowledge(text, &embedding)?;
        // Add to semantic memory
        self.memory.add_memory(&embedding, &[], MemoryType::Factual)?;
        Ok(vec.id)
    }
    /// Retrieve up to `k` knowledge snippets relevant to `query`.
    fn query_rag(&self, query: &str, k: usize) -> HVec<HString<64>, 4> {
        let embedding = embed_text(query);
        // NOTE(review): the HNSW hits are currently unused — only the RAG
        // store's results are returned. The underscore silences the unused
        // warning; TODO: merge both result sets.
        let _hnsw_results = self.hnsw.search(&embedding, k);
        let rag_results = self.rag.retrieve(&embedding, k);
        let mut texts = HVec::new();
        for result in rag_results.iter().take(k) {
            let mut s = HString::new();
            for c in result.content.iter() {
                let _ = s.push(*c);
            }
            let _ = texts.push(s);
        }
        texts
    }
    /// Score `text` against the anomaly detector's baseline.
    fn check_anomaly(&mut self, text: &str) -> AnomalyResult {
        let embedding = embed_text(text);
        self.anomaly.check(&embedding)
    }
    /// Snapshot of engine counters for the `stats` UART command.
    fn stats(&self) -> EngineStats {
        EngineStats {
            tokens_generated: self.tokens_generated,
            knowledge_entries: self.rag.len(),
            hnsw_vectors: self.hnsw.len(),
            memory_entries: self.memory.len(),
            variant: self.variant,
            has_speculative: self.speculative.is_some(),
        }
    }
}
/// Counters reported by the `stats` UART command.
#[derive(Debug)]
struct EngineStats {
    tokens_generated: u32,    // total tokens produced since boot
    knowledge_entries: usize, // entries in the RAG store
    hnsw_vectors: usize,      // vectors indexed in HNSW
    memory_entries: usize,    // semantic-memory records
    variant: Esp32Variant,    // chip variant the engine was sized for
    has_speculative: bool,    // whether speculative decoding is active
}
// ============================================================================
// TEXT EMBEDDING
// ============================================================================
/// Hash a text string into a deterministic INT8 embedding.
///
/// Bytes are folded round-robin into EMBED_DIM lanes with saturating
/// accumulation, then the vector is rescaled so its peak magnitude is 64.
fn embed_text(text: &str) -> [i8; EMBED_DIM] {
    let mut embedding = [0i8; EMBED_DIM];
    for (i, byte) in text.bytes().enumerate() {
        let idx = i % EMBED_DIM;
        embedding[idx] = embedding[idx].saturating_add(
            ((byte as i32 * 31 + i as i32 * 17) % 256 - 128) as i8 / 4
        );
    }
    // Normalize. The peak is computed in i32 because saturating_add can
    // drive a lane to i8::MIN, and `i8::abs(-128)` panics in debug builds.
    let mut max_val = 1i32;
    for v in &embedding {
        max_val = max_val.max((*v as i32).abs());
    }
    if max_val > 1 {
        for v in &mut embedding {
            *v = (*v as i32 * 64 / max_val) as i8;
        }
    }
    embedding
}
// ============================================================================
// UART COMMAND PARSER
// ============================================================================
/// Dispatch one UART command line and build its textual response.
///
/// Commands: `gen`, `add`, `ask`, `anomaly`, `stats`, `features`, `help`.
/// Output exceeding the 512-byte response buffer is truncated silently
/// (heapless push results are intentionally discarded).
fn process_command(cmd: &str, engine: &mut MicroEngine) -> HString<512> {
    let mut response = HString::new();
    let cmd = cmd.trim();
    // `strip_prefix` replaces starts_with + manual slicing (clearer and
    // avoids byte-offset arithmetic); behavior is identical.
    if let Some(prompt) = cmd.strip_prefix("gen ") {
        // Seed generation with up to 8 prompt bytes as token ids.
        let tokens: HVec<u16, 8> = prompt.bytes().take(8).map(|b| b as u16).collect();
        let output = engine.generate(&tokens, 10);
        let _ = response.push_str("Generated: ");
        for (i, t) in output.iter().enumerate() {
            if i > 0 { let _ = response.push_str(", "); }
            // Render printable alphanumerics/spaces; everything else is '?'.
            let c = (*t as u8) as char;
            if c.is_ascii_alphanumeric() || c == ' ' {
                let _ = response.push(c);
            } else {
                let _ = response.push('?');
            }
        }
    } else if let Some(knowledge) = cmd.strip_prefix("add ") {
        match engine.add_knowledge(knowledge) {
            Ok(id) => {
                let _ = response.push_str("Added knowledge #");
                let _ = response.push_str(&format_u32(id));
            }
            Err(e) => {
                let _ = response.push_str("Error: ");
                let _ = response.push_str(e);
            }
        }
    } else if let Some(query) = cmd.strip_prefix("ask ") {
        let results = engine.query_rag(query, 2);
        if results.is_empty() {
            let _ = response.push_str("No results found");
        } else {
            let _ = response.push_str("Found: ");
            for (i, text) in results.iter().enumerate() {
                if i > 0 { let _ = response.push_str(" | "); }
                let _ = response.push_str(text.as_str());
            }
        }
    } else if let Some(text) = cmd.strip_prefix("anomaly ") {
        let result = engine.check_anomaly(text);
        let _ = response.push_str(if result.is_anomaly { "ANOMALY" } else { "NORMAL" });
        let _ = response.push_str(" (score: ");
        let _ = response.push_str(&format_i32(result.score));
        let _ = response.push_str(", threshold: ");
        let _ = response.push_str(&format_i32(result.threshold));
        let _ = response.push_str(")");
    } else if cmd == "stats" {
        let stats = engine.stats();
        let _ = response.push_str("Tokens: ");
        let _ = response.push_str(&format_u32(stats.tokens_generated));
        let _ = response.push_str(", Knowledge: ");
        let _ = response.push_str(&format_u32(stats.knowledge_entries as u32));
        let _ = response.push_str(", HNSW: ");
        let _ = response.push_str(&format_u32(stats.hnsw_vectors as u32));
        let _ = response.push_str(", Memory: ");
        let _ = response.push_str(&format_u32(stats.memory_entries as u32));
        let _ = response.push_str(", Spec: ");
        let _ = response.push_str(if stats.has_speculative { "yes" } else { "no" });
    } else if cmd == "features" {
        let _ = response.push_str("Features:\n");
        let _ = response.push_str(" - Binary quantization (32x compress)\n");
        let _ = response.push_str(" - Product quantization (8-32x)\n");
        let _ = response.push_str(" - MicroLoRA adaptation\n");
        let _ = response.push_str(" - Sparse attention\n");
        let _ = response.push_str(" - HNSW vector search\n");
        let _ = response.push_str(" - Semantic memory\n");
        let _ = response.push_str(" - RAG retrieval\n");
        let _ = response.push_str(" - Anomaly detection\n");
        if engine.speculative.is_some() {
            let _ = response.push_str(" - Speculative decoding\n");
        }
    } else if cmd == "help" {
        let _ = response.push_str("Commands:\n");
        let _ = response.push_str(" gen <text> - Generate tokens\n");
        let _ = response.push_str(" add <text> - Add to knowledge base\n");
        let _ = response.push_str(" ask <query> - Query knowledge\n");
        let _ = response.push_str(" anomaly <txt> - Check for anomaly\n");
        let _ = response.push_str(" stats - Show statistics\n");
        let _ = response.push_str(" features - List features\n");
        let _ = response.push_str(" help - This help");
    } else {
        let _ = response.push_str("Unknown command. Type 'help'");
    }
    response
}
/// Render an unsigned integer as decimal ASCII without heap allocation.
fn format_u32(n: u32) -> HString<16> {
    let mut s = HString::new();
    if n == 0 {
        let _ = s.push('0');
        return s;
    }
    // Collect digits least-significant first into a scratch buffer
    // (u32::MAX has at most 10 decimal digits), then emit them reversed.
    let mut scratch = [0u8; 10];
    let mut count = 0;
    let mut rest = n;
    while rest > 0 {
        scratch[count] = b'0' + (rest % 10) as u8;
        rest /= 10;
        count += 1;
    }
    for d in scratch[..count].iter().rev() {
        let _ = s.push(*d as char);
    }
    s
}
/// Render a signed integer as decimal ASCII.
///
/// Fixes the original bug where a negative input returned a bare "-" with
/// no digits (early return before formatting the magnitude). `unsigned_abs`
/// also keeps `i32::MIN` from overflowing on negation.
fn format_i32(n: i32) -> HString<16> {
    let mut s = HString::new();
    if n < 0 {
        let _ = s.push('-');
    }
    let _ = s.push_str(format_u32(n.unsigned_abs()).as_str());
    s
}
// ============================================================================
// MAIN
// ============================================================================
/// Firmware entry point (ESP32 builds only).
///
/// Brings up logging and UART0 (GPIO1 TX / GPIO3 RX, 115200 baud), builds
/// a `MicroEngine` sized for the chip variant, pre-loads demo knowledge,
/// then serves a blocking line-oriented command REPL over UART forever.
#[cfg(feature = "esp32")]
fn main() -> anyhow::Result<()> {
    // Must run early: applies ESP-IDF runtime patches required by esp-idf-sys.
    link_patches();
    esp_idf_svc::log::EspLogger::initialize_default();
    info!("╔══════════════════════════════════════════╗");
    info!("║ RuvLLM ESP32 - Full Feature LLM v0.2 ║");
    info!("╚══════════════════════════════════════════╝");
    // Detect ESP32 variant (default to ESP32-S3 for demo)
    // NOTE(review): hard-coded, no runtime detection yet — confirm target.
    let variant = Esp32Variant::Esp32S3;
    info!("Detected: {:?} ({} KB SRAM)", variant, variant.sram_bytes() / 1024);
    let peripherals = Peripherals::take()?;
    // Standard console pins: GPIO1 = TX, GPIO3 = RX.
    let tx = peripherals.pins.gpio1;
    let rx = peripherals.pins.gpio3;
    let config = uart::config::Config::default()
        .baudrate(Hertz(115200));
    // The two `None` pins disable RTS/CTS hardware flow control.
    let uart = UartDriver::new(
        peripherals.uart0,
        tx,
        rx,
        Option::<gpio::Gpio0>::None,
        Option::<gpio::Gpio0>::None,
        &config
    )?;
    info!("UART initialized at 115200 baud");
    // Initialize full-featured engine; speculative decoding needs >= 512 KB SRAM.
    let enable_speculative = variant.sram_bytes() >= 512 * 1024;
    let mut engine = MicroEngine::new(variant, enable_speculative);
    info!("Engine ready with all features");
    // Pre-load knowledge so `ask` has something to retrieve out of the box.
    let default_knowledge = [
        "The ESP32-S3 has 512KB SRAM and vector instructions",
        "RuvLLM uses INT8 and binary quantization for efficiency",
        "HNSW provides fast approximate nearest neighbor search",
        "MicroLoRA enables on-device model adaptation",
        "Speculative decoding achieves 2-4x speedup",
        "RAG combines retrieval with generation",
    ];
    for knowledge in &default_knowledge {
        // Best-effort: capacity errors during pre-load are ignored.
        let _ = engine.add_knowledge(knowledge);
    }
    info!("Loaded {} default knowledge entries", engine.stats().knowledge_entries);
    let startup = "\r\n\
════════════════════════════════════════════\r\n\
RuvLLM ESP32 Full-Feature v0.2\r\n\
════════════════════════════════════════════\r\n\
Features: Binary Quant, PQ, LoRA, HNSW, RAG\r\n\
Semantic Memory, Anomaly Detection\r\n\
Speculative Decoding, Federation\r\n\
════════════════════════════════════════════\r\n\
Type 'help' for commands\r\n\
> ";
    uart.write(startup.as_bytes())?;
    // Line buffer for the REPL (255 usable chars).
    let mut cmd_buffer: HVec<u8, 256> = HVec::new();
    loop {
        let mut byte = [0u8; 1];
        // Poll with a 10-tick timeout; NUL bytes are skipped as line noise.
        if uart.read(&mut byte, 10).is_ok() && byte[0] != 0 {
            let c = byte[0];
            if c == b'\r' || c == b'\n' {
                // End of line: run the buffered command and echo the response.
                if !cmd_buffer.is_empty() {
                    let cmd_str: HString<256> = cmd_buffer.iter()
                        .map(|&b| b as char)
                        .collect();
                    uart.write(b"\r\n")?;
                    let response = process_command(cmd_str.as_str(), &mut engine);
                    uart.write(response.as_bytes())?;
                    uart.write(b"\r\n> ")?;
                    cmd_buffer.clear();
                }
            } else if c == 127 || c == 8 {
                // Backspace/DEL: drop last char and erase it on the terminal.
                if !cmd_buffer.is_empty() {
                    cmd_buffer.pop();
                    uart.write(b"\x08 \x08")?;
                }
            } else if c >= 32 && c < 127 {
                // Printable ASCII: buffer and echo back.
                if cmd_buffer.len() < 255 {
                    let _ = cmd_buffer.push(c);
                    uart.write(&[c])?;
                }
            }
        }
    }
}
// Host testing main (for development)
/// Host-side smoke test: exercises the engine end-to-end without hardware.
#[cfg(all(not(feature = "esp32"), feature = "host-test"))]
fn main() {
    println!("RuvLLM ESP32 Host Test Mode");
    println!("This is for development testing only.");
    let variant = Esp32Variant::Esp32S3;
    println!("Simulating: {:?} ({} KB SRAM)", variant, variant.sram_bytes() / 1024);
    let mut engine = MicroEngine::new(variant, true);
    // Seed the knowledge stores with a couple of entries.
    let _ = engine.add_knowledge("Test knowledge entry 1");
    let _ = engine.add_knowledge("Another test entry");
    // Generate a few tokens from the prompt "Hello".
    let tokens: HVec<u16, 8> = "Hello".bytes().map(u16::from).collect();
    let output = engine.generate(&tokens, 5);
    println!("Generated {} tokens", output.len());
    println!("Stats: {:?}", engine.stats());
}
// WASM entry point
#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;
/// Report that the WASM module loaded (JavaScript-visible entry point).
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn wasm_init() -> String {
    String::from("RuvLLM ESP32 WASM Module Initialized")
}
/// Echo-style generation stub for the WASM build.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn wasm_generate(prompt: &str) -> String {
    let mut out = String::from("Generated from: ");
    out.push_str(prompt);
    out
}
// Default main for other builds
#[cfg(all(not(feature = "esp32"), not(feature = "host-test"), not(feature = "wasm")))]
fn main() {
    // Fallback binary: print build instructions for each supported target.
    let lines = [
        "RuvLLM ESP32 Flash",
        "Build with --features esp32 for ESP32 target",
        "Build with --features host-test for development",
        "Build with --features wasm for WebAssembly",
    ];
    for line in lines.iter() {
        println!("{}", line);
    }
}