Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
388
vendor/ruvector/examples/spiking-network/src/encoding/mod.rs
vendored
Normal file
388
vendor/ruvector/examples/spiking-network/src/encoding/mod.rs
vendored
Normal file
@@ -0,0 +1,388 @@
|
||||
//! Spike encoding and decoding utilities.
|
||||
//!
|
||||
//! This module provides methods to convert between analog signals and sparse spike trains.
|
||||
//!
|
||||
//! ## Encoding Schemes
|
||||
//!
|
||||
//! - **Rate coding**: Spike frequency encodes magnitude
|
||||
//! - **Temporal coding**: Spike timing encodes information
|
||||
//! - **Population coding**: Distributed representation across neurons
|
||||
//! - **Delta modulation**: Spikes encode changes only
|
||||
//!
|
||||
//! ## ASIC Benefits
|
||||
//!
|
||||
//! Sparse spike representations dramatically reduce:
|
||||
//! - Memory bandwidth (only store/transmit active spikes)
|
||||
//! - Computation (skip silent neurons entirely)
|
||||
//! - Power consumption (event-driven processing)
|
||||
|
||||
use bitvec::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
/// A single spike event with source and timing.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SpikeEvent {
|
||||
/// Source neuron index
|
||||
pub source: u32,
|
||||
/// Timestamp in simulation time units
|
||||
pub time: f32,
|
||||
/// Optional payload (for routing or plasticity)
|
||||
pub payload: u8,
|
||||
}
|
||||
|
||||
impl SpikeEvent {
|
||||
/// Create a new spike event.
|
||||
pub fn new(source: u32, time: f32) -> Self {
|
||||
Self {
|
||||
source,
|
||||
time,
|
||||
payload: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create spike with payload.
|
||||
pub fn with_payload(source: u32, time: f32, payload: u8) -> Self {
|
||||
Self { source, time, payload }
|
||||
}
|
||||
}
|
||||
|
||||
/// A train of spikes from a single neuron over time.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SpikeTrain {
|
||||
/// Neuron ID
|
||||
pub neuron_id: u32,
|
||||
/// Spike times (sorted ascending)
|
||||
pub times: Vec<f32>,
|
||||
}
|
||||
|
||||
impl SpikeTrain {
|
||||
/// Create empty spike train for a neuron.
|
||||
pub fn new(neuron_id: u32) -> Self {
|
||||
Self {
|
||||
neuron_id,
|
||||
times: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a spike at given time.
|
||||
pub fn add_spike(&mut self, time: f32) {
|
||||
self.times.push(time);
|
||||
}
|
||||
|
||||
/// Get spike count.
|
||||
pub fn spike_count(&self) -> usize {
|
||||
self.times.len()
|
||||
}
|
||||
|
||||
/// Calculate firing rate over duration.
|
||||
pub fn firing_rate(&self, duration: f32) -> f32 {
|
||||
if duration <= 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.times.len() as f32 / duration * 1000.0 // Hz
|
||||
}
|
||||
|
||||
/// Get inter-spike intervals.
|
||||
pub fn isis(&self) -> Vec<f32> {
|
||||
if self.times.len() < 2 {
|
||||
return Vec::new();
|
||||
}
|
||||
self.times.windows(2).map(|w| w[1] - w[0]).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sparse spike matrix for population activity.
|
||||
///
|
||||
/// Uses compressed representation - only stores active spikes.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SparseSpikes {
|
||||
/// Number of neurons in population
|
||||
pub num_neurons: u32,
|
||||
/// Number of time bins
|
||||
pub num_timesteps: u32,
|
||||
/// Spike events (sorted by time)
|
||||
pub events: Vec<SpikeEvent>,
|
||||
}
|
||||
|
||||
impl SparseSpikes {
|
||||
/// Create empty spike matrix.
|
||||
pub fn new(num_neurons: u32, num_timesteps: u32) -> Self {
|
||||
Self {
|
||||
num_neurons,
|
||||
num_timesteps,
|
||||
events: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a spike event.
|
||||
pub fn add_spike(&mut self, neuron: u32, timestep: u32) {
|
||||
if neuron < self.num_neurons && timestep < self.num_timesteps {
|
||||
self.events.push(SpikeEvent::new(neuron, timestep as f32));
|
||||
}
|
||||
}
|
||||
|
||||
/// Get sparsity (fraction of silent entries).
|
||||
pub fn sparsity(&self) -> f32 {
|
||||
let total = self.num_neurons as f64 * self.num_timesteps as f64;
|
||||
if total == 0.0 {
|
||||
return 1.0;
|
||||
}
|
||||
1.0 - (self.events.len() as f64 / total) as f32
|
||||
}
|
||||
|
||||
/// Get neurons that spiked at given timestep.
|
||||
pub fn spikes_at(&self, timestep: u32) -> SmallVec<[u32; 8]> {
|
||||
self.events
|
||||
.iter()
|
||||
.filter(|e| e.time as u32 == timestep)
|
||||
.map(|e| e.source)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get spike count.
|
||||
pub fn spike_count(&self) -> usize {
|
||||
self.events.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoder for converting analog values to spike trains.
|
||||
pub struct SpikeEncoder;
|
||||
|
||||
impl SpikeEncoder {
|
||||
/// Rate coding: Convert value to spike probability per timestep.
|
||||
///
|
||||
/// Higher values produce more frequent spikes.
|
||||
/// Returns sparse bit vector of spike times.
|
||||
pub fn rate_encode(value: f32, duration_ms: f32, dt: f32, max_rate_hz: f32) -> BitVec {
|
||||
let num_steps = (duration_ms / dt) as usize;
|
||||
let mut spikes = bitvec![0; num_steps];
|
||||
|
||||
// Clamp value to [0, 1]
|
||||
let normalized = value.clamp(0.0, 1.0);
|
||||
|
||||
// Convert to spike probability per timestep
|
||||
let prob_per_step = normalized * max_rate_hz * dt / 1000.0;
|
||||
|
||||
// Generate spikes stochastically
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for i in 0..num_steps {
|
||||
if rng.gen::<f32>() < prob_per_step {
|
||||
spikes.set(i, true);
|
||||
}
|
||||
}
|
||||
|
||||
spikes
|
||||
}
|
||||
|
||||
/// Temporal coding: First spike time encodes value.
|
||||
///
|
||||
/// Lower values spike earlier (inverse temporal coding).
|
||||
pub fn temporal_encode(value: f32, max_latency_ms: f32) -> f32 {
|
||||
let normalized = value.clamp(0.0, 1.0);
|
||||
// Invert: high value = early spike
|
||||
(1.0 - normalized) * max_latency_ms
|
||||
}
|
||||
|
||||
/// Delta modulation: Spike on significant change.
|
||||
///
|
||||
/// Returns (+1, 0, -1) for increase, no change, decrease.
|
||||
pub fn delta_encode(current: f32, previous: f32, threshold: f32) -> i8 {
|
||||
let delta = current - previous;
|
||||
if delta > threshold {
|
||||
1 // Positive spike
|
||||
} else if delta < -threshold {
|
||||
-1 // Negative spike
|
||||
} else {
|
||||
0 // No spike
|
||||
}
|
||||
}
|
||||
|
||||
/// Population coding: Distribute value across multiple neurons.
|
||||
///
|
||||
/// Returns spike pattern across `num_neurons` with Gaussian tuning curves.
|
||||
pub fn population_encode(value: f32, num_neurons: usize, sigma: f32) -> Vec<f32> {
|
||||
let mut activities = vec![0.0; num_neurons];
|
||||
let centers: Vec<f32> = (0..num_neurons)
|
||||
.map(|i| i as f32 / (num_neurons - 1).max(1) as f32)
|
||||
.collect();
|
||||
|
||||
for (i, ¢er) in centers.iter().enumerate() {
|
||||
let diff = value - center;
|
||||
activities[i] = (-diff * diff / (2.0 * sigma * sigma)).exp();
|
||||
}
|
||||
|
||||
activities
|
||||
}
|
||||
|
||||
/// Convert image patch to spike-based representation.
|
||||
///
|
||||
/// Uses difference-of-Gaussians for edge detection,
|
||||
/// then temporal coding for spike generation.
|
||||
pub fn encode_image_patch(
|
||||
patch: &[f32],
|
||||
width: usize,
|
||||
height: usize,
|
||||
) -> SparseSpikes {
|
||||
let mut spikes = SparseSpikes::new((width * height) as u32, 100);
|
||||
|
||||
// Simple intensity-based encoding
|
||||
for (i, &pixel) in patch.iter().enumerate() {
|
||||
if i >= width * height {
|
||||
break;
|
||||
}
|
||||
// Higher intensity = earlier spike
|
||||
let spike_time = ((1.0 - pixel.clamp(0.0, 1.0)) * 99.0) as u32;
|
||||
if pixel > 0.1 {
|
||||
// Threshold
|
||||
spikes.add_spike(i as u32, spike_time);
|
||||
}
|
||||
}
|
||||
|
||||
spikes
|
||||
}
|
||||
}
|
||||
|
||||
/// Decoder for converting spike trains back to analog values.
|
||||
pub struct SpikeDecoder;
|
||||
|
||||
impl SpikeDecoder {
|
||||
/// Decode rate-coded spikes to value.
|
||||
pub fn rate_decode(spikes: &BitVec, dt: f32, max_rate_hz: f32) -> f32 {
|
||||
let spike_count = spikes.count_ones();
|
||||
let duration_ms = spikes.len() as f32 * dt;
|
||||
let rate_hz = spike_count as f32 / duration_ms * 1000.0;
|
||||
(rate_hz / max_rate_hz).clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Decode temporally-coded spike time to value.
|
||||
pub fn temporal_decode(spike_time: f32, max_latency_ms: f32) -> f32 {
|
||||
1.0 - (spike_time / max_latency_ms).clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
/// Decode population activity to value using center-of-mass.
|
||||
pub fn population_decode(activities: &[f32]) -> f32 {
|
||||
let num_neurons = activities.len();
|
||||
if num_neurons == 0 {
|
||||
return 0.5;
|
||||
}
|
||||
|
||||
let mut weighted_sum = 0.0;
|
||||
let mut total_weight = 0.0;
|
||||
|
||||
for (i, &activity) in activities.iter().enumerate() {
|
||||
let center = i as f32 / (num_neurons - 1).max(1) as f32;
|
||||
weighted_sum += center * activity;
|
||||
total_weight += activity;
|
||||
}
|
||||
|
||||
if total_weight > 0.0 {
|
||||
weighted_sum / total_weight
|
||||
} else {
|
||||
0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_spike_event_creation() {
|
||||
let event = SpikeEvent::new(42, 10.5);
|
||||
assert_eq!(event.source, 42);
|
||||
assert_eq!(event.time, 10.5);
|
||||
assert_eq!(event.payload, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spike_train() {
|
||||
let mut train = SpikeTrain::new(0);
|
||||
train.add_spike(10.0);
|
||||
train.add_spike(30.0);
|
||||
train.add_spike(50.0);
|
||||
|
||||
assert_eq!(train.spike_count(), 3);
|
||||
assert!((train.firing_rate(100.0) - 30.0).abs() < 0.1);
|
||||
|
||||
let isis = train.isis();
|
||||
assert_eq!(isis, vec![20.0, 20.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sparse_spikes() {
|
||||
let mut spikes = SparseSpikes::new(100, 1000);
|
||||
spikes.add_spike(0, 10);
|
||||
spikes.add_spike(50, 10);
|
||||
spikes.add_spike(99, 500);
|
||||
|
||||
assert_eq!(spikes.spike_count(), 3);
|
||||
assert!(spikes.sparsity() > 0.99); // Very sparse
|
||||
|
||||
let at_10 = spikes.spikes_at(10);
|
||||
assert_eq!(at_10.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_encoding() {
|
||||
// High value should produce more spikes
|
||||
let high_spikes = SpikeEncoder::rate_encode(0.9, 100.0, 1.0, 100.0);
|
||||
let low_spikes = SpikeEncoder::rate_encode(0.1, 100.0, 1.0, 100.0);
|
||||
|
||||
// Statistical test - not deterministic
|
||||
assert!(high_spikes.count_ones() > low_spikes.count_ones() / 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temporal_encoding() {
|
||||
let early = SpikeEncoder::temporal_encode(0.9, 100.0);
|
||||
let late = SpikeEncoder::temporal_encode(0.1, 100.0);
|
||||
|
||||
assert!(early < late); // High value = early spike
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_encoding() {
|
||||
assert_eq!(SpikeEncoder::delta_encode(1.0, 0.0, 0.5), 1);
|
||||
assert_eq!(SpikeEncoder::delta_encode(0.0, 1.0, 0.5), -1);
|
||||
assert_eq!(SpikeEncoder::delta_encode(0.5, 0.5, 0.5), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_population_encoding() {
|
||||
let activities = SpikeEncoder::population_encode(0.5, 10, 0.2);
|
||||
assert_eq!(activities.len(), 10);
|
||||
|
||||
// Middle neuron should have highest activity
|
||||
let max_idx = activities
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(i, _)| i)
|
||||
.unwrap();
|
||||
assert_eq!(max_idx, 4); // Close to middle
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_decode() {
|
||||
let mut spikes = bitvec![0; 100];
|
||||
// 10 spikes in 100 timesteps at dt=1ms = 100 Hz
|
||||
for i in (0..100).step_by(10) {
|
||||
spikes.set(i, true);
|
||||
}
|
||||
|
||||
let decoded = SpikeDecoder::rate_decode(&spikes, 1.0, 100.0);
|
||||
assert!((decoded - 1.0).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_population_decode() {
|
||||
// Peak at middle
|
||||
let activities = vec![0.0, 0.1, 0.5, 1.0, 0.5, 0.1, 0.0];
|
||||
let decoded = SpikeDecoder::population_decode(&activities);
|
||||
assert!((decoded - 0.5).abs() < 0.1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user