Files
wifi-densepose/examples/ruvLLM/esp32/examples/snn_gated_inference.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

506 lines
17 KiB
Rust

//! SNN-Gated Inference Example - Event-Driven LLM with Spiking Pre-Filter
//!
//! Demonstrates the optimal architecture where Spiking Neural Networks (SNN)
//! handle always-on event detection, while RuvLLM runs only when needed.
//!
//! # The Key Insight
//! ```text
//! ❌ Wrong: "SNN replaces the LLM"
//! ✅ Right: "SNN replaces expensive always-on gating, filtering, and routing"
//! ```
//!
//! # Architecture
//! ```text
//! ┌─────────────────────────────────────────────────────────────────────────┐
//! │                      SNN-GATED INFERENCE PIPELINE                       │
//! ├─────────────────────────────────────────────────────────────────────────┤
//! │                                                                         │
//! │  Sensors ──▶ SNN Front-End ──▶ Event? ──▶ RuVector ──▶ RuvLLM           │
//! │  (always on)  (μW power)         │        (query)   (only on event)     │
//! │                                  │                                      │
//! │                              No event                                   │
//! │                                  │                                      │
//! │                                SLEEP                                    │
//! │                            (99% of time)                                │
//! │                                                                         │
//! └─────────────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Benefits
//! - 10-100x energy reduction (LLM sleeps 99% of the time)
//! - Microsecond response to events (SNN reacts in μs, LLM explains later)
//! - Higher throughput (compute only on events, not silence)
#![allow(unused)]
use heapless::Vec as HVec;
use heapless::String as HString;
const EMBED_DIM: usize = 16;
const SNN_NEURONS: usize = 32;
/// One spiking neuron with a leaky membrane, a firing threshold and a short
/// refractory period, kept entirely in small integer state.
#[derive(Debug, Clone, Copy)]
struct SpikingNeuron {
    /// Membrane potential (mV scaled to i16).
    membrane: i16,
    /// Potential at or above which the neuron fires.
    threshold: i16,
    /// Number of upcoming `process` calls that will be ignored after a spike.
    refractory: u8,
    /// Amount by which the potential moves back toward zero per processed tick.
    leak: i16,
    /// The `current_time` value that was passed in when the neuron last fired.
    last_spike: u32,
}

impl SpikingNeuron {
    /// Builds a resting neuron with the given firing `threshold` and the
    /// default leak of 10 units per tick.
    fn new(threshold: i16) -> Self {
        Self {
            threshold,
            membrane: 0,
            refractory: 0,
            last_spike: 0,
            leak: 10,
        }
    }

    /// Advances the neuron by one tick and reports whether it fired.
    ///
    /// While the neuron is refractory the call only counts the period down:
    /// `input` is discarded and the membrane is left untouched. Otherwise the
    /// membrane first decays toward zero by `leak` (never crossing zero), then
    /// `input` is added with saturation. Reaching `threshold` fires the
    /// neuron, which records `current_time`, drops the membrane to -30 and
    /// starts a three-tick refractory period.
    fn process(&mut self, input: i16, current_time: u32) -> bool {
        if self.refractory != 0 {
            self.refractory -= 1;
            return false;
        }

        // Decay toward the resting potential without overshooting it.
        let decayed = match self.membrane {
            0 => 0,
            level if level > 0 => (level - self.leak).max(0),
            level => (level + self.leak).min(0),
        };
        let integrated = decayed.saturating_add(input);

        let fired = integrated >= self.threshold;
        if fired {
            self.membrane = -30; // hyperpolarization after the spike
            self.refractory = 3;
            self.last_spike = current_time;
        } else {
            self.membrane = integrated;
        }
        fired
    }

    /// Returns the neuron to rest: clears the membrane potential and any
    /// pending refractory period. `last_spike` is left as it was.
    fn reset(&mut self) {
        self.refractory = 0;
        self.membrane = 0;
    }
}
/// Kinds of event the spiking front-end can report for a sensor window.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SNNEvent {
    /// A wake word was detected.
    WakeWord,
    /// The onset of an anomaly was detected.
    AnomalyOnset,
    /// The sensor pattern looks novel.
    Novelty,
    /// A threshold was crossed.
    ThresholdCross,
    /// The rhythm of the signal changed.
    RhythmChange,
    /// Nothing of interest happened.
    None,
}

impl SNNEvent {
    /// Relative urgency of the event on a 0-100 scale; a larger value means
    /// a more urgent event, and `None` maps to 0.
    fn priority(&self) -> u8 {
        match *self {
            SNNEvent::None => 0,
            SNNEvent::Novelty => 40,
            SNNEvent::RhythmChange => 50,
            SNNEvent::ThresholdCross => 70,
            SNNEvent::WakeWord => 90,
            SNNEvent::AnomalyOnset => 100,
        }
    }
}
/// SNN Front-End for Event Detection
/// Runs continuously at μW power, gates LLM invocation
struct SNNEventDetector {
/// Neurons for different event types
neurons: [SpikingNeuron; SNN_NEURONS],
/// Current simulation time
current_time: u32,
/// Spike history (for pattern detection)
spike_history: HVec<(u8, u32), 64>, // (neuron_id, time)
/// Event counters
events_detected: u32,
/// False positives (estimated)
false_positives: u32,
/// Baseline adaptation
baseline: [i16; 8],
}
impl SNNEventDetector {
fn new() -> Self {
let mut neurons = [SpikingNeuron::new(100); SNN_NEURONS];
// Different thresholds for different event types
// Wake word neurons (sensitive)
for i in 0..4 {
neurons[i].threshold = 80;
}
// Anomaly neurons (balanced)
for i in 4..12 {
neurons[i].threshold = 100;
}
// Novelty neurons (less sensitive)
for i in 12..20 {
neurons[i].threshold = 120;
}
// Rhythm neurons (pattern-based)
for i in 20..SNN_NEURONS {
neurons[i].threshold = 90;
neurons[i].leak = 5; // Slower decay for temporal integration
}
Self {
neurons,
current_time: 0,
spike_history: HVec::new(),
events_detected: 0,
false_positives: 0,
baseline: [0; 8],
}
}
/// Process sensor input and detect events
fn process(&mut self, sensor_data: &[i16]) -> SNNEvent {
self.current_time += 1;
// Adapt baseline (slow moving average)
for (i, &val) in sensor_data.iter().take(8).enumerate() {
self.baseline[i] = ((self.baseline[i] as i32 * 95 + val as i32 * 5) / 100) as i16;
}
let mut spikes = 0u32;
let mut spike_pattern = [false; SNN_NEURONS];
// Process through SNN
for (neuron_idx, neuron) in self.neurons.iter_mut().enumerate() {
// Map sensor data to neurons
let input_idx = neuron_idx % sensor_data.len().max(1);
let raw_input = sensor_data.get(input_idx).copied().unwrap_or(0);
// Subtract baseline for adaptive threshold
let input = raw_input - self.baseline.get(input_idx).copied().unwrap_or(0);
if neuron.process(input, self.current_time) {
spikes |= 1 << neuron_idx;
spike_pattern[neuron_idx] = true;
// Record spike
if self.spike_history.len() >= 64 {
self.spike_history.remove(0);
}
let _ = self.spike_history.push((neuron_idx as u8, self.current_time));
}
}
// Decode events from spike patterns
let event = self.decode_spikes(&spike_pattern);
if event != SNNEvent::None {
self.events_detected += 1;
}
event
}
/// Decode spike pattern into event type
fn decode_spikes(&self, spikes: &[bool; SNN_NEURONS]) -> SNNEvent {
// Wake word: neurons 0-3 fire together
let wake_spikes: u8 = spikes[0..4].iter().filter(|&&s| s).count() as u8;
if wake_spikes >= 3 {
return SNNEvent::WakeWord;
}
// Anomaly: multiple neurons in 4-11 fire
let anomaly_spikes: u8 = spikes[4..12].iter().filter(|&&s| s).count() as u8;
if anomaly_spikes >= 4 {
return SNNEvent::AnomalyOnset;
}
// Threshold crossing: any single strong spike in 4-11
if spikes[4..12].iter().any(|&s| s) {
return SNNEvent::ThresholdCross;
}
// Novelty: neurons 12-19
let novelty_spikes: u8 = spikes[12..20].iter().filter(|&&s| s).count() as u8;
if novelty_spikes >= 2 {
return SNNEvent::Novelty;
}
// Rhythm change: check for pattern in 20-31
let rhythm_spikes: u8 = spikes[20..].iter().filter(|&&s| s).count() as u8;
if rhythm_spikes >= 2 {
// Check if this breaks expected rhythm
let recent_rhythm = self.spike_history.iter()
.rev()
.take(10)
.filter(|(id, _)| *id >= 20)
.count();
if recent_rhythm > 5 {
return SNNEvent::RhythmChange;
}
}
SNNEvent::None
}
/// Get spike rate (for monitoring)
fn spike_rate(&self) -> f32 {
let recent_spikes = self.spike_history.iter()
.filter(|(_, t)| self.current_time - *t < 100)
.count();
recent_spikes as f32 / 100.0 * SNN_NEURONS as f32
}
/// Reset all neurons
fn reset(&mut self) {
for neuron in self.neurons.iter_mut() {
neuron.reset();
}
self.spike_history.clear();
}
}
/// What the pipeline should do in response to an SNN event.
#[derive(Debug, Clone, Copy)]
enum RouteDecision {
    /// Stay asleep; nothing needs to be done.
    Sleep,
    /// Answer locally and quickly, without the LLM.
    LocalResponse,
    /// Look the situation up in RuVector memory.
    FetchMemory,
    /// Wake RuvLLM and generate a response.
    RunLLM,
    /// Hand the request to a bigger model.
    Escalate,
    /// Ask a human to confirm before acting.
    RequireConfirmation,
}
/// SNN-based router: turns a detected event plus a confidence score into a
/// [`RouteDecision`] and keeps count of how often the LLM was or was not run.
struct SNNRouter {
    /// Confidence threshold for local response. Set by `new`; `route` does
    /// not read it.
    local_threshold: u8,
    /// Number of routed windows that invoked a model.
    llm_invocations: u32,
    /// Number of routed windows handled without a model (energy saved).
    skipped_invocations: u32,
}

impl SNNRouter {
    /// Creates a router with zeroed counters and a local threshold of 80.
    fn new() -> Self {
        Self {
            local_threshold: 80,
            llm_invocations: 0,
            skipped_invocations: 0,
        }
    }

    /// Routes based on the SNN event and its confidence.
    ///
    /// Every call updates exactly one counter: decisions that run a model
    /// (`RunLLM`, `Escalate`) count as an LLM invocation, every other
    /// decision counts as skipped. Accounting is derived from the decision
    /// itself so that the same decision is always counted the same way,
    /// whichever event produced it.
    fn route(&mut self, event: SNNEvent, confidence: u8) -> RouteDecision {
        let decision = match event {
            SNNEvent::None => RouteDecision::Sleep,

            SNNEvent::WakeWord if confidence >= 90 => RouteDecision::RunLLM,
            SNNEvent::WakeWord => RouteDecision::LocalResponse,

            SNNEvent::AnomalyOnset if confidence >= 95 => RouteDecision::RequireConfirmation,
            SNNEvent::AnomalyOnset if confidence >= 70 => RouteDecision::RunLLM,
            SNNEvent::AnomalyOnset => RouteDecision::FetchMemory,

            SNNEvent::ThresholdCross => RouteDecision::LocalResponse,

            SNNEvent::Novelty => RouteDecision::FetchMemory,

            SNNEvent::RhythmChange if confidence >= 80 => RouteDecision::RunLLM,
            SNNEvent::RhythmChange => RouteDecision::FetchMemory,
        };

        // Saturating so a long-running router never panics on a full counter.
        match decision {
            RouteDecision::RunLLM | RouteDecision::Escalate => {
                self.llm_invocations = self.llm_invocations.saturating_add(1);
            }
            RouteDecision::Sleep
            | RouteDecision::LocalResponse
            | RouteDecision::FetchMemory
            | RouteDecision::RequireConfirmation => {
                self.skipped_invocations = self.skipped_invocations.saturating_add(1);
            }
        }

        decision
    }

    /// Fraction of routed windows that did not invoke a model, in `0.0..=1.0`.
    /// Returns 0.0 before anything has been routed.
    fn energy_savings_ratio(&self) -> f32 {
        // Sum in u64: two u32 counters can exceed u32::MAX together.
        let total = u64::from(self.llm_invocations) + u64::from(self.skipped_invocations);
        if total == 0 {
            return 0.0;
        }
        self.skipped_invocations as f32 / total as f32
    }
}
/// Simulated power draw, in microwatts, of carrying out `route` for one step.
fn estimate_power(route: RouteDecision) -> u32 {
    const MILLIWATT: u32 = 1_000; // expressed in μW

    match route {
        RouteDecision::Escalate => 100 * MILLIWATT, // external model
        RouteDecision::RunLLM => 50 * MILLIWATT,    // full LLM
        RouteDecision::RequireConfirmation => 5 * MILLIWATT, // alert
        RouteDecision::FetchMemory => 2 * MILLIWATT, // memory access
        RouteDecision::LocalResponse => 500,        // quick compute
        RouteDecision::Sleep => 10,                 // deep sleep
    }
}
fn main() {
println!("⚡ SNN-Gated Inference Example");
println!("==============================\n");
println!("Key Insight:");
println!(" ❌ Wrong: SNN replaces the LLM");
println!(" ✅ Right: SNN replaces expensive always-on gating\n");
let mut snn = SNNEventDetector::new();
let mut router = SNNRouter::new();
// Simulate 1000 time steps of sensor data
println!("🔄 Running simulation (1000 time steps)...\n");
let mut total_power_uw = 0u64;
let mut events: HVec<(u32, SNNEvent, RouteDecision), 64> = HVec::new();
for t in 0..1000 {
// Generate sensor data
// 99% of the time: normal background noise
// 1% of the time: actual events
let sensor_data: [i16; 8] = if t % 100 == 42 {
// Anomaly spike
[200, 180, 150, 120, 100, 90, 80, 70]
} else if t % 200 == 150 {
// Wake word pattern
[150, 160, 155, 145, 30, 25, 20, 15]
} else if t % 300 == 250 {
// Novelty
[50, 100, 50, 100, 50, 100, 50, 100]
} else {
// Normal noise
let noise = ((t * 7) % 40) as i16 - 20;
[noise, noise + 5, noise - 3, noise + 2, noise - 1, noise + 4, noise - 2, noise + 1]
};
// SNN processes (always on, μW power)
let event = snn.process(&sensor_data);
// Calculate confidence from spike history
let confidence = if event != SNNEvent::None {
85 + (snn.spike_history.len() % 15) as u8
} else {
0
};
// Route decision
let route = router.route(event, confidence);
// Accumulate power
total_power_uw += estimate_power(route) as u64;
// Record interesting events
if event != SNNEvent::None {
if events.len() < 64 {
let _ = events.push((t, event, route));
}
}
}
// Results
println!("📊 Simulation Results:\n");
println!("Events Detected:");
for (time, event, route) in events.iter().take(10) {
println!(" t={:4}: {:?}{:?}", time, event, route);
}
if events.len() > 10 {
println!(" ... and {} more events", events.len() - 10);
}
println!("\n📈 Statistics:");
println!(" Total events detected: {}", snn.events_detected);
println!(" LLM invocations: {}", router.llm_invocations);
println!(" Skipped invocations: {}", router.skipped_invocations);
println!(" Energy savings ratio: {:.1}%", router.energy_savings_ratio() * 100.0);
println!("\n⚡ Power Analysis:");
let avg_power_uw = total_power_uw / 1000;
println!(" Total energy: {} μJ (1000 steps)", total_power_uw);
println!(" Average power: {} μW", avg_power_uw);
// Compare to always-on LLM
let always_on_power = 50000u64 * 1000; // 50mW * 1000 steps
let savings = (always_on_power - total_power_uw) as f64 / always_on_power as f64 * 100.0;
println!("\n vs Always-On LLM:");
println!(" Always-on: {} μJ", always_on_power);
println!(" SNN-gated: {} μJ", total_power_uw);
println!(" Savings: {:.1}%", savings);
println!(" Reduction: {:.0}x", always_on_power as f64 / total_power_uw.max(1) as f64);
// Three-stage benchmark comparison
println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("📊 Three-Stage Benchmark (as suggested):\n");
println!("Stage A - Baseline (LLM on every window):");
println!(" Power: 50,000 μW constant");
println!(" LLM calls: 1000");
println!(" Energy: 50,000,000 μJ\n");
println!("Stage B - SNN Gate (LLM only on spikes):");
println!(" Power: {} μW average", avg_power_uw);
println!(" LLM calls: {}", router.llm_invocations);
println!(" Energy: {} μJ", total_power_uw);
println!(" Improvement: {:.0}x\n", 50_000_000f64 / total_power_uw as f64);
println!("Stage C - SNN + Coherence (conservative on low coherence):");
println!(" [Would add min-cut gating for additional safety]");
println!(" Expected: Additional 20-30% reduction in false positives");
println!("\n✨ SNN-Gated Inference Demo Complete!");
println!("\n💡 Key Takeaways:");
println!(" - SNN runs at μW, LLM runs at mW");
println!(" - 99% of sensor data is silence → 99% sleep time");
println!(" - SNN detects in μs, LLM explains later");
println!(" - Perfect for: wearables, industrial, home hubs, swarm nodes");
}