Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
581
crates/ruvector-mincut/src/snn/synapse.rs
Normal file
581
crates/ruvector-mincut/src/snn/synapse.rs
Normal file
@@ -0,0 +1,581 @@
|
||||
//! # Synaptic Connections with STDP Learning
|
||||
//!
|
||||
//! Implements spike-timing dependent plasticity (STDP) for synaptic weight updates.
|
||||
//!
|
||||
//! ## STDP Learning Rule
|
||||
//!
|
||||
//! ```text
|
||||
//! ΔW = A+ * exp(-Δt/τ+) if Δt > 0 (pre before post → LTP)
|
||||
//! ΔW = A- * exp(Δt/τ-) if Δt < 0 (post before pre → LTD)
|
||||
//! ```
|
||||
//!
|
||||
//! Where Δt = t_post - t_pre
|
||||
//!
|
||||
//! ## Integration with MinCut
|
||||
//!
|
||||
//! Synaptic weights directly map to graph edge weights:
|
||||
//! - Strong synapse → strong edge → less likely in mincut
|
||||
//! - STDP learning → edge weight evolution → dynamic mincut
|
||||
|
||||
use super::{SimTime, Spike};
|
||||
use crate::graph::{DynamicGraph, VertexId, Weight};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for STDP (spike-timing dependent plasticity) learning.
///
/// Weight updates computed from these parameters are always clamped into
/// `[w_min, w_max]`.
#[derive(Debug, Clone)]
pub struct STDPConfig {
    /// LTP amplitude (potentiation), applied when pre fires before post
    pub a_plus: f64,
    /// LTD amplitude (depression), applied when post fires before (or with) pre
    pub a_minus: f64,
    /// LTP time constant (ms)
    pub tau_plus: f64,
    /// LTD time constant (ms)
    pub tau_minus: f64,
    /// Minimum weight (lower clamp bound)
    pub w_min: f64,
    /// Maximum weight (upper clamp bound)
    pub w_max: f64,
    /// Learning rate; scales each STDP delta before it is applied
    pub learning_rate: f64,
    /// Eligibility trace time constant (ms), used for reward-modulated STDP
    pub tau_eligibility: f64,
}
|
||||
|
||||
impl Default for STDPConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
a_plus: 0.01,
|
||||
a_minus: 0.012,
|
||||
tau_plus: 20.0,
|
||||
tau_minus: 20.0,
|
||||
w_min: 0.0,
|
||||
w_max: 1.0,
|
||||
learning_rate: 1.0,
|
||||
tau_eligibility: 1000.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single synapse between two neurons
#[derive(Debug, Clone)]
pub struct Synapse {
    /// Pre-synaptic neuron ID
    pub pre: usize,
    /// Post-synaptic neuron ID
    pub post: usize,
    /// Synaptic weight; STDP updates clamp it to the config's [w_min, w_max]
    pub weight: f64,
    /// Transmission delay (ms)
    // NOTE(review): stored but never read by this file's STDP logic — confirm
    // it is consumed by the simulation layer.
    pub delay: f64,
    /// Eligibility trace for reward-modulated STDP (R-STDP)
    pub eligibility: f64,
    /// Last update time
    // NOTE(review): initialized to 0.0 and never written by this module's
    // update methods — verify callers maintain it.
    pub last_update: SimTime,
}
|
||||
|
||||
impl Synapse {
|
||||
/// Create a new synapse
|
||||
pub fn new(pre: usize, post: usize, weight: f64) -> Self {
|
||||
Self {
|
||||
pre,
|
||||
post,
|
||||
weight,
|
||||
delay: 1.0,
|
||||
eligibility: 0.0,
|
||||
last_update: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create synapse with delay
|
||||
pub fn with_delay(pre: usize, post: usize, weight: f64, delay: f64) -> Self {
|
||||
Self {
|
||||
pre,
|
||||
post,
|
||||
weight,
|
||||
delay,
|
||||
eligibility: 0.0,
|
||||
last_update: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute STDP weight change
|
||||
pub fn stdp_update(&mut self, t_pre: SimTime, t_post: SimTime, config: &STDPConfig) -> f64 {
|
||||
let dt = t_post - t_pre;
|
||||
|
||||
let dw = if dt > 0.0 {
|
||||
// Pre before post → LTP
|
||||
config.a_plus * (-dt / config.tau_plus).exp()
|
||||
} else {
|
||||
// Post before pre → LTD
|
||||
-config.a_minus * (dt / config.tau_minus).exp()
|
||||
};
|
||||
|
||||
// Apply learning rate and clip
|
||||
let delta = config.learning_rate * dw;
|
||||
self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
|
||||
|
||||
// Update eligibility trace
|
||||
self.eligibility += dw;
|
||||
|
||||
delta
|
||||
}
|
||||
|
||||
/// Decay eligibility trace
|
||||
pub fn decay_eligibility(&mut self, dt: f64, tau: f64) {
|
||||
self.eligibility *= (-dt / tau).exp();
|
||||
}
|
||||
|
||||
/// Apply reward-modulated update (R-STDP)
|
||||
pub fn reward_modulated_update(&mut self, reward: f64, config: &STDPConfig) {
|
||||
let delta = reward * self.eligibility * config.learning_rate;
|
||||
self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
|
||||
// Reset eligibility after reward
|
||||
self.eligibility *= 0.5;
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix of synaptic connections
///
/// Sparse (pre, post) → `Synapse` storage, plus the last spike time of each
/// pre- and post-synaptic neuron so spike events can drive STDP updates.
#[derive(Debug, Clone)]
pub struct SynapseMatrix {
    /// Number of pre-synaptic neurons
    pub n_pre: usize,
    /// Number of post-synaptic neurons
    pub n_post: usize,
    /// Synapses indexed by (pre, post); absent entries behave as weight 0.0
    synapses: HashMap<(usize, usize), Synapse>,
    /// STDP configuration
    pub config: STDPConfig,
    /// Last spike time per pre-synaptic neuron (NEG_INFINITY = never spiked)
    pre_spike_times: Vec<SimTime>,
    /// Last spike time per post-synaptic neuron (NEG_INFINITY = never spiked)
    post_spike_times: Vec<SimTime>,
}
|
||||
|
||||
impl SynapseMatrix {
|
||||
/// Create a new synapse matrix
|
||||
pub fn new(n_pre: usize, n_post: usize) -> Self {
|
||||
Self {
|
||||
n_pre,
|
||||
n_post,
|
||||
synapses: HashMap::new(),
|
||||
config: STDPConfig::default(),
|
||||
pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
|
||||
post_spike_times: vec![f64::NEG_INFINITY; n_post],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom STDP config
|
||||
pub fn with_config(n_pre: usize, n_post: usize, config: STDPConfig) -> Self {
|
||||
Self {
|
||||
n_pre,
|
||||
n_post,
|
||||
synapses: HashMap::new(),
|
||||
config,
|
||||
pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
|
||||
post_spike_times: vec![f64::NEG_INFINITY; n_post],
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a synapse
|
||||
pub fn add_synapse(&mut self, pre: usize, post: usize, weight: f64) {
|
||||
if pre < self.n_pre && post < self.n_post {
|
||||
self.synapses
|
||||
.insert((pre, post), Synapse::new(pre, post, weight));
|
||||
}
|
||||
}
|
||||
|
||||
/// Get synapse if it exists
|
||||
pub fn get_synapse(&self, pre: usize, post: usize) -> Option<&Synapse> {
|
||||
self.synapses.get(&(pre, post))
|
||||
}
|
||||
|
||||
/// Get mutable synapse if it exists
|
||||
pub fn get_synapse_mut(&mut self, pre: usize, post: usize) -> Option<&mut Synapse> {
|
||||
self.synapses.get_mut(&(pre, post))
|
||||
}
|
||||
|
||||
/// Get weight of a synapse (0 if doesn't exist)
|
||||
pub fn weight(&self, pre: usize, post: usize) -> f64 {
|
||||
self.get_synapse(pre, post).map(|s| s.weight).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// Compute weighted sum for all post-synaptic neurons given pre-synaptic activations
|
||||
///
|
||||
/// This is optimized to iterate only over existing synapses, avoiding O(n²) lookups.
|
||||
/// pre_activations[i] is the activation of pre-synaptic neuron i.
|
||||
/// Returns vector of weighted sums for each post-synaptic neuron.
|
||||
#[inline]
|
||||
pub fn compute_weighted_sums(&self, pre_activations: &[f64]) -> Vec<f64> {
|
||||
let mut sums = vec![0.0; self.n_post];
|
||||
|
||||
// Iterate only over existing synapses (sparse operation)
|
||||
for (&(pre, post), synapse) in &self.synapses {
|
||||
if pre < pre_activations.len() {
|
||||
sums[post] += synapse.weight * pre_activations[pre];
|
||||
}
|
||||
}
|
||||
|
||||
sums
|
||||
}
|
||||
|
||||
/// Compute weighted sum for a single post-synaptic neuron
|
||||
#[inline]
|
||||
pub fn weighted_sum_for_post(&self, post: usize, pre_activations: &[f64]) -> f64 {
|
||||
let mut sum = 0.0;
|
||||
for pre in 0..self.n_pre.min(pre_activations.len()) {
|
||||
if let Some(synapse) = self.synapses.get(&(pre, post)) {
|
||||
sum += synapse.weight * pre_activations[pre];
|
||||
}
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
/// Set weight of a synapse (creates if doesn't exist)
|
||||
pub fn set_weight(&mut self, pre: usize, post: usize, weight: f64) {
|
||||
if let Some(synapse) = self.get_synapse_mut(pre, post) {
|
||||
synapse.weight = weight;
|
||||
} else {
|
||||
self.add_synapse(pre, post, weight);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a pre-synaptic spike and perform STDP updates
|
||||
pub fn on_pre_spike(&mut self, pre: usize, time: SimTime) {
|
||||
if pre >= self.n_pre {
|
||||
return;
|
||||
}
|
||||
|
||||
self.pre_spike_times[pre] = time;
|
||||
|
||||
// LTD: pre spike after recent post spikes
|
||||
for post in 0..self.n_post {
|
||||
if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
|
||||
let t_post = self.post_spike_times[post];
|
||||
if t_post > f64::NEG_INFINITY {
|
||||
synapse.stdp_update(time, t_post, &self.config);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a post-synaptic spike and perform STDP updates
|
||||
pub fn on_post_spike(&mut self, post: usize, time: SimTime) {
|
||||
if post >= self.n_post {
|
||||
return;
|
||||
}
|
||||
|
||||
self.post_spike_times[post] = time;
|
||||
|
||||
// LTP: post spike after recent pre spikes
|
||||
for pre in 0..self.n_pre {
|
||||
if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
|
||||
let t_pre = self.pre_spike_times[pre];
|
||||
if t_pre > f64::NEG_INFINITY {
|
||||
synapse.stdp_update(t_pre, time, &self.config);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process multiple spikes with STDP
|
||||
pub fn process_spikes(&mut self, spikes: &[Spike]) {
|
||||
for spike in spikes {
|
||||
// Assume neuron IDs map directly
|
||||
// Pre-synaptic: lower half, Post-synaptic: upper half (example mapping)
|
||||
if spike.neuron_id < self.n_pre {
|
||||
self.on_pre_spike(spike.neuron_id, spike.time);
|
||||
}
|
||||
if spike.neuron_id < self.n_post {
|
||||
self.on_post_spike(spike.neuron_id, spike.time);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decay all eligibility traces
|
||||
pub fn decay_eligibility(&mut self, dt: f64) {
|
||||
for synapse in self.synapses.values_mut() {
|
||||
synapse.decay_eligibility(dt, self.config.tau_eligibility);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply reward-modulated learning to all synapses
|
||||
pub fn apply_reward(&mut self, reward: f64) {
|
||||
for synapse in self.synapses.values_mut() {
|
||||
synapse.reward_modulated_update(reward, &self.config);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all synapses as an iterator
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&(usize, usize), &Synapse)> {
|
||||
self.synapses.iter()
|
||||
}
|
||||
|
||||
/// Get number of synapses
|
||||
pub fn num_synapses(&self) -> usize {
|
||||
self.synapses.len()
|
||||
}
|
||||
|
||||
/// Compute total synaptic input to a post-synaptic neuron
|
||||
pub fn input_to(&self, post: usize, pre_activities: &[f64]) -> f64 {
|
||||
let mut total = 0.0;
|
||||
for pre in 0..self.n_pre.min(pre_activities.len()) {
|
||||
total += self.weight(pre, post) * pre_activities[pre];
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
/// Create dense weight matrix
|
||||
pub fn to_dense(&self) -> Vec<Vec<f64>> {
|
||||
let mut matrix = vec![vec![0.0; self.n_post]; self.n_pre];
|
||||
for ((pre, post), synapse) in &self.synapses {
|
||||
matrix[*pre][*post] = synapse.weight;
|
||||
}
|
||||
matrix
|
||||
}
|
||||
|
||||
/// Initialize from dense matrix
|
||||
pub fn from_dense(matrix: &[Vec<f64>]) -> Self {
|
||||
let n_pre = matrix.len();
|
||||
let n_post = matrix.first().map(|r| r.len()).unwrap_or(0);
|
||||
|
||||
let mut sm = Self::new(n_pre, n_post);
|
||||
|
||||
for (pre, row) in matrix.iter().enumerate() {
|
||||
for (post, &weight) in row.iter().enumerate() {
|
||||
if weight != 0.0 {
|
||||
sm.add_synapse(pre, post, weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sm
|
||||
}
|
||||
|
||||
/// Synchronize weights with a DynamicGraph
|
||||
/// Maps neurons to vertices via a mapping function
|
||||
pub fn sync_to_graph<F>(&self, graph: &mut DynamicGraph, neuron_to_vertex: F)
|
||||
where
|
||||
F: Fn(usize) -> VertexId,
|
||||
{
|
||||
for ((pre, post), synapse) in &self.synapses {
|
||||
let u = neuron_to_vertex(*pre);
|
||||
let v = neuron_to_vertex(*post);
|
||||
|
||||
if graph.has_edge(u, v) {
|
||||
let _ = graph.update_edge_weight(u, v, synapse.weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load weights from a DynamicGraph
|
||||
pub fn sync_from_graph<F>(&mut self, graph: &DynamicGraph, vertex_to_neuron: F)
|
||||
where
|
||||
F: Fn(VertexId) -> usize,
|
||||
{
|
||||
for edge in graph.edges() {
|
||||
let pre = vertex_to_neuron(edge.source);
|
||||
let post = vertex_to_neuron(edge.target);
|
||||
|
||||
if pre < self.n_pre && post < self.n_post {
|
||||
self.set_weight(pre, post, edge.weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get high-correlation pairs (synapses with weight above threshold)
|
||||
pub fn high_correlation_pairs(&self, threshold: f64) -> Vec<(usize, usize)> {
|
||||
self.synapses
|
||||
.iter()
|
||||
.filter(|(_, s)| s.weight >= threshold)
|
||||
.map(|((pre, post), _)| (*pre, *post))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Asymmetric STDP for causal relationship encoding
///
/// Unlike [`STDPConfig`], the forward (causal, pre→post) and backward
/// directions use independent amplitudes and time constants.
#[derive(Debug, Clone)]
pub struct AsymmetricSTDP {
    /// Forward (causal) time constant (ms)
    pub tau_forward: f64,
    /// Backward (anti-causal) time constant (ms)
    pub tau_backward: f64,
    /// Forward amplitude (typically larger, to favor causality)
    pub a_forward: f64,
    /// Backward amplitude
    pub a_backward: f64,
}
|
||||
|
||||
impl Default for AsymmetricSTDP {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tau_forward: 15.0,
|
||||
tau_backward: 30.0, // Longer backward window
|
||||
a_forward: 0.015, // Stronger forward (causal)
|
||||
a_backward: 0.008, // Weaker backward
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsymmetricSTDP {
|
||||
/// Compute weight change for causal relationship encoding
|
||||
/// Positive Δt (pre→post) is weighted more heavily
|
||||
pub fn compute_dw(&self, dt: f64) -> f64 {
|
||||
if dt > 0.0 {
|
||||
// Pre before post → causal relationship
|
||||
self.a_forward * (-dt / self.tau_forward).exp()
|
||||
} else {
|
||||
// Post before pre → anti-causal
|
||||
-self.a_backward * (dt / self.tau_backward).exp()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update weight matrix for causal discovery
|
||||
pub fn update_weights(&self, matrix: &mut SynapseMatrix, neuron_id: usize, time: SimTime) {
|
||||
let w_min = matrix.config.w_min;
|
||||
let w_max = matrix.config.w_max;
|
||||
let n_pre = matrix.n_pre;
|
||||
let n_post = matrix.n_post;
|
||||
|
||||
// Collect pre-spike times first to avoid borrow conflicts
|
||||
let pre_times: Vec<_> = (0..n_pre)
|
||||
.map(|pre| {
|
||||
matrix
|
||||
.pre_spike_times
|
||||
.get(pre)
|
||||
.copied()
|
||||
.unwrap_or(f64::NEG_INFINITY)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// This neuron just spiked - update all synapses involving it (incoming)
|
||||
for pre in 0..n_pre {
|
||||
let t_pre = pre_times[pre];
|
||||
if t_pre > f64::NEG_INFINITY {
|
||||
let dt = time - t_pre;
|
||||
let dw = self.compute_dw(dt);
|
||||
if let Some(synapse) = matrix.get_synapse_mut(pre, neuron_id) {
|
||||
synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect post-spike times
|
||||
let post_times: Vec<_> = (0..n_post)
|
||||
.map(|post| {
|
||||
matrix
|
||||
.post_spike_times
|
||||
.get(post)
|
||||
.copied()
|
||||
.unwrap_or(f64::NEG_INFINITY)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for post in 0..n_post {
|
||||
let t_post = post_times[post];
|
||||
if t_post > f64::NEG_INFINITY {
|
||||
let dt = t_post - time; // Reversed for outgoing
|
||||
let dw = self.compute_dw(dt);
|
||||
if let Some(synapse) = matrix.get_synapse_mut(neuron_id, post) {
|
||||
synapse.weight = (synapse.weight + dw).clamp(w_min, w_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point weight comparisons.
    const EPS: f64 = 0.001;

    #[test]
    fn test_synapse_creation() {
        let s = Synapse::new(0, 1, 0.5);
        assert_eq!((s.pre, s.post), (0, 1));
        assert_eq!(s.weight, 0.5);
    }

    #[test]
    fn test_stdp_ltp() {
        let cfg = STDPConfig::default();
        let mut syn = Synapse::new(0, 1, 0.5);

        // Pre at t=10, post at t=15: pre-before-post must potentiate.
        let dw = syn.stdp_update(10.0, 15.0, &cfg);
        assert!(dw > 0.0);
        assert!(syn.weight > 0.5);
    }

    #[test]
    fn test_stdp_ltd() {
        let cfg = STDPConfig::default();
        let mut syn = Synapse::new(0, 1, 0.5);

        // Post at t=10, pre at t=15: post-before-pre must depress.
        let dw = syn.stdp_update(15.0, 10.0, &cfg);
        assert!(dw < 0.0);
        assert!(syn.weight < 0.5);
    }

    #[test]
    fn test_synapse_matrix() {
        let mut m = SynapseMatrix::new(10, 10);
        m.add_synapse(0, 1, 0.5);
        m.add_synapse(1, 2, 0.3);

        assert_eq!(m.num_synapses(), 2);
        assert!((m.weight(0, 1) - 0.5).abs() < EPS);
        assert!((m.weight(1, 2) - 0.3).abs() < EPS);
        // A pair without a synapse reads back as weight 0.
        assert_eq!(m.weight(2, 3), 0.0);
    }

    #[test]
    fn test_spike_processing() {
        let mut m = SynapseMatrix::new(5, 5);

        // Fully connect every distinct (i, j) pair.
        for i in 0..5 {
            for j in 0..5 {
                if i != j {
                    m.add_synapse(i, j, 0.5);
                }
            }
        }

        // A pre spike followed by a post spike must strengthen 0 → 1.
        m.on_pre_spike(0, 10.0);
        m.on_post_spike(1, 15.0);
        assert!(m.weight(0, 1) > 0.5);
    }

    #[test]
    fn test_asymmetric_stdp() {
        let rule = AsymmetricSTDP::default();

        let causal = rule.compute_dw(5.0);
        let anti_causal = rule.compute_dw(-5.0);

        assert!(causal > 0.0);
        assert!(anti_causal < 0.0);
        // The causal direction must carry the larger magnitude.
        assert!(causal.abs() > anti_causal.abs());
    }

    #[test]
    fn test_dense_conversion() {
        let mut m = SynapseMatrix::new(3, 3);
        m.add_synapse(0, 1, 0.5);
        m.add_synapse(1, 2, 0.7);

        let dense = m.to_dense();
        assert_eq!(dense.len(), 3);
        assert!((dense[0][1] - 0.5).abs() < EPS);
        assert!((dense[1][2] - 0.7).abs() < EPS);

        // Round-trip: zero entries must not become synapses.
        let back = SynapseMatrix::from_dense(&dense);
        assert_eq!(back.num_synapses(), 2);
    }
}
|
||||
Reference in New Issue
Block a user