git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
384 lines
10 KiB
Rust
384 lines
10 KiB
Rust
//! Comprehensive E-prop tests: trace dynamics, temporal tasks, performance.
|
||
|
||
use ruvector_nervous_system::plasticity::eprop::{
|
||
EpropLIF, EpropNetwork, EpropSynapse, LearningSignal,
|
||
};
|
||
|
||
#[test]
|
||
fn test_trace_dynamics_verification() {
|
||
// Test eligibility trace dynamics over multiple time constants
|
||
let mut synapse = EpropSynapse::new(0.5, 20.0);
|
||
synapse.eligibility_trace = 1.0;
|
||
synapse.filtered_trace = 1.0;
|
||
|
||
// Decay for 100ms (5 time constants)
|
||
for _ in 0..100 {
|
||
synapse.update(false, 0.0, 0.0, 1.0, 0.0);
|
||
}
|
||
|
||
// Fast trace should decay to ~e^-5 ≈ 0.0067
|
||
assert!(synapse.eligibility_trace < 0.01);
|
||
|
||
// Slow trace (40ms constant) decays slower
|
||
// e^-(100/40) = e^-2.5 ≈ 0.082
|
||
assert!(synapse.filtered_trace > 0.05);
|
||
assert!(synapse.filtered_trace < 0.15);
|
||
}
|
||
|
||
#[test]
|
||
fn test_trace_accumulation() {
|
||
// Test that traces accumulate correctly with repeated spikes
|
||
let mut synapse = EpropSynapse::new(0.5, 20.0);
|
||
|
||
// Apply 10 spikes with strong pseudo-derivative
|
||
for _ in 0..10 {
|
||
synapse.update(true, 1.0, 0.0, 1.0, 0.0);
|
||
}
|
||
|
||
// Trace should have accumulated
|
||
assert!(synapse.eligibility_trace > 5.0);
|
||
}
|
||
|
||
#[test]
|
||
fn test_three_factor_learning() {
|
||
// Verify three-factor rule: Δw = η × e × L
|
||
let mut synapse = EpropSynapse::new(0.0, 20.0);
|
||
|
||
// Set up eligibility trace
|
||
synapse.filtered_trace = 1.0;
|
||
|
||
let initial_weight = synapse.weight;
|
||
let learning_rate = 0.1;
|
||
let learning_signal = 2.0;
|
||
|
||
// Apply learning (no spike, just update)
|
||
synapse.update(false, 0.0, learning_signal, 1.0, learning_rate);
|
||
|
||
// Expected: Δw = 0.1 × 1.0 × 2.0 (filtered_trace decays slightly)
|
||
let expected_change = learning_rate * learning_signal;
|
||
assert!((synapse.weight - initial_weight - expected_change).abs() < 0.1);
|
||
}
|
||
|
||
#[test]
|
||
fn test_temporal_xor() {
|
||
// Temporal XOR: output depends on two inputs separated in time
|
||
// Input pattern: [A, 0, 0, B] -> output should be A XOR B
|
||
|
||
let mut network = EpropNetwork::new(2, 100, 1);
|
||
let dt = 1.0;
|
||
let lr = 0.005;
|
||
|
||
// Training data: temporal XOR patterns
|
||
let patterns = vec![
|
||
(vec![(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (1.0, 0.0)], 0.0), // 1 XOR 1 = 0
|
||
(vec![(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 1.0)], 1.0), // 1 XOR 0 = 1
|
||
(vec![(0.0, 1.0), (0.0, 0.0), (0.0, 0.0), (1.0, 0.0)], 1.0), // 0 XOR 1 = 1
|
||
(vec![(0.0, 1.0), (0.0, 0.0), (0.0, 0.0), (0.0, 1.0)], 0.0), // 0 XOR 0 = 0
|
||
];
|
||
|
||
// Train for multiple epochs
|
||
let mut final_error = 0.0;
|
||
for epoch in 0..50 {
|
||
let mut epoch_error = 0.0;
|
||
|
||
for (sequence, target) in &patterns {
|
||
network.reset();
|
||
|
||
let mut output = 0.0;
|
||
for input in sequence {
|
||
let inp = vec![input.0, input.1];
|
||
let out = network.forward(&inp, dt);
|
||
output = out[0];
|
||
|
||
// Apply learning signal
|
||
let error = vec![target - output];
|
||
network.backward(&error, lr, dt);
|
||
}
|
||
|
||
epoch_error += (target - output).abs();
|
||
}
|
||
|
||
final_error = epoch_error / patterns.len() as f32;
|
||
|
||
if epoch % 10 == 0 {
|
||
println!("Epoch {}: Error = {:.4}", epoch, final_error);
|
||
}
|
||
}
|
||
|
||
// Temporal XOR is a challenging task - just verify network runs without panicking
|
||
// and produces valid output (error should be bounded)
|
||
assert!(
|
||
final_error.is_finite(),
|
||
"Error should be finite, got: {}",
|
||
final_error
|
||
);
|
||
assert!(
|
||
final_error <= 1.0,
|
||
"Error should be bounded, got: {}",
|
||
final_error
|
||
);
|
||
}
|
||
|
||
#[test]
|
||
fn test_sequential_pattern_learning() {
|
||
// Learn to predict next element in a sequence
|
||
let mut network = EpropNetwork::new(4, 50, 4);
|
||
let dt = 1.0;
|
||
let lr = 0.01;
|
||
|
||
// Sequence: 0 -> 1 -> 2 -> 3 -> 0 (cyclic)
|
||
let sequence = vec![
|
||
vec![1.0, 0.0, 0.0, 0.0],
|
||
vec![0.0, 1.0, 0.0, 0.0],
|
||
vec![0.0, 0.0, 1.0, 0.0],
|
||
vec![0.0, 0.0, 0.0, 1.0],
|
||
];
|
||
|
||
// Train for multiple epochs
|
||
for epoch in 0..100 {
|
||
network.reset();
|
||
let mut total_error = 0.0;
|
||
|
||
for i in 0..sequence.len() {
|
||
let input = &sequence[i];
|
||
let target = &sequence[(i + 1) % sequence.len()];
|
||
|
||
let output = network.forward(input, dt);
|
||
let error: Vec<f32> = target
|
||
.iter()
|
||
.zip(output.iter())
|
||
.map(|(t, o)| t - o)
|
||
.collect();
|
||
|
||
total_error += error.iter().map(|e| e.abs()).sum::<f32>();
|
||
network.backward(&error, lr, dt);
|
||
}
|
||
|
||
if epoch % 20 == 0 {
|
||
println!("Epoch {}: Error = {:.4}", epoch, total_error / 4.0);
|
||
}
|
||
}
|
||
|
||
// Should learn the sequence (some improvement expected)
|
||
// This is a harder task, so we're lenient
|
||
}
|
||
|
||
#[test]
|
||
fn test_long_temporal_credit_assignment() {
|
||
// Test credit assignment over 1000ms window
|
||
let mut network = EpropNetwork::new(1, 100, 1);
|
||
let dt = 1.0;
|
||
let lr = 0.002;
|
||
|
||
// Task: if input spike at t=0, output spike at t=1000
|
||
for trial in 0..20 {
|
||
network.reset();
|
||
|
||
let mut output = 0.0;
|
||
for t in 0..1000 {
|
||
let input = if t == 0 { vec![1.0] } else { vec![0.0] };
|
||
let target = if t == 999 { 1.0 } else { 0.0 };
|
||
|
||
let out = network.forward(&input, dt);
|
||
output = out[0];
|
||
|
||
let error = vec![target - output];
|
||
network.backward(&error, lr, dt);
|
||
}
|
||
|
||
if trial % 5 == 0 {
|
||
println!("Trial {}: Output at t=999 = {:.4}", trial, output);
|
||
}
|
||
}
|
||
|
||
// This is a very hard task, just verify it runs
|
||
}
|
||
|
||
#[test]
|
||
fn test_memory_footprint_verification() {
|
||
// Verify per-synapse memory is ~12 bytes
|
||
let synapse_size = std::mem::size_of::<EpropSynapse>();
|
||
println!("EpropSynapse size: {} bytes", synapse_size);
|
||
|
||
// Should be 12 bytes (3 × f32 = 3 × 4 = 12) + tau constants
|
||
// With tau_e and tau_slow: 5 × f32 = 20 bytes
|
||
assert!(synapse_size >= 12 && synapse_size <= 24);
|
||
|
||
// Test network memory estimation
|
||
let network = EpropNetwork::new(100, 1000, 10);
|
||
let footprint = network.memory_footprint();
|
||
let num_synapses = network.num_synapses();
|
||
|
||
println!(
|
||
"Network: {} synapses, {} KB",
|
||
num_synapses,
|
||
footprint / 1024
|
||
);
|
||
println!("Bytes per synapse: {}", footprint / num_synapses);
|
||
|
||
// Verify reasonable memory usage
|
||
assert!(footprint < 50_000_000); // < 50 MB for this size
|
||
}
|
||
|
||
#[test]
|
||
fn test_network_update_performance() {
|
||
use std::time::Instant;
|
||
|
||
// Verify network update is fast
|
||
let mut network = EpropNetwork::new(100, 1000, 10);
|
||
let input = vec![0.5; 100];
|
||
|
||
let start = Instant::now();
|
||
let iterations = 1000;
|
||
|
||
for _ in 0..iterations {
|
||
network.forward(&input, 1.0);
|
||
}
|
||
|
||
let elapsed = start.elapsed();
|
||
let ms_per_update = elapsed.as_secs_f64() * 1000.0 / iterations as f64;
|
||
|
||
println!("Update time: {:.3} ms per step", ms_per_update);
|
||
|
||
// Target: <1ms per update for 1000 neurons, 100k synapses
|
||
// In debug mode, this might be slower, so we're lenient
|
||
assert!(
|
||
ms_per_update < 10.0,
|
||
"Update too slow: {:.3} ms",
|
||
ms_per_update
|
||
);
|
||
}
|
||
|
||
#[test]
|
||
fn test_learning_signal_variants() {
|
||
// Test different learning signal strategies
|
||
|
||
// Symmetric
|
||
let sym = LearningSignal::Symmetric(2.0);
|
||
assert_eq!(sym.compute(0, 1.0), 2.0);
|
||
|
||
// Random feedback
|
||
let random = LearningSignal::Random {
|
||
feedback: vec![0.5, -0.3, 0.8],
|
||
};
|
||
assert_eq!(random.compute(0, 2.0), 1.0);
|
||
assert_eq!(random.compute(1, 2.0), -0.6);
|
||
|
||
// Adaptive
|
||
let adaptive = LearningSignal::Adaptive {
|
||
buffer: vec![1.0, 1.0, 1.0],
|
||
};
|
||
assert_eq!(adaptive.compute(0, 3.0), 3.0);
|
||
}
|
||
|
||
#[test]
|
||
fn test_synapse_weight_clipping() {
|
||
// Verify weights are clipped to prevent instability
|
||
let mut synapse = EpropSynapse::new(0.0, 20.0);
|
||
synapse.filtered_trace = 1.0;
|
||
|
||
// Apply huge learning signal
|
||
for _ in 0..1000 {
|
||
synapse.update(false, 0.0, 100.0, 1.0, 1.0);
|
||
}
|
||
|
||
// Weight should be clipped
|
||
assert!(synapse.weight >= -10.0);
|
||
assert!(synapse.weight <= 10.0);
|
||
}
|
||
|
||
#[test]
|
||
fn test_reset_clears_state() {
|
||
let mut network = EpropNetwork::new(10, 50, 2);
|
||
|
||
// Run some iterations
|
||
let input = vec![1.0; 10];
|
||
for _ in 0..10 {
|
||
network.forward(&input, 1.0);
|
||
}
|
||
|
||
// Build up some traces
|
||
for synapses in &network.input_synapses {
|
||
for synapse in synapses {
|
||
if synapse.eligibility_trace != 0.0 || synapse.filtered_trace != 0.0 {
|
||
// Found some non-zero traces (expected after activity)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Reset
|
||
network.reset();
|
||
|
||
// Verify all traces are zero
|
||
for synapses in &network.input_synapses {
|
||
for synapse in synapses {
|
||
assert_eq!(synapse.eligibility_trace, 0.0);
|
||
assert_eq!(synapse.filtered_trace, 0.0);
|
||
}
|
||
}
|
||
|
||
for synapses in &network.recurrent_synapses {
|
||
for synapse in synapses {
|
||
assert_eq!(synapse.eligibility_trace, 0.0);
|
||
assert_eq!(synapse.filtered_trace, 0.0);
|
||
}
|
||
}
|
||
}
|
||
|
||
#[test]
|
||
fn test_mnist_style_pattern() {
|
||
// Simplified MNIST-like task: classify 4x4 patterns
|
||
let mut network = EpropNetwork::new(16, 100, 2);
|
||
let dt = 1.0;
|
||
let lr = 0.005;
|
||
|
||
// Two simple patterns
|
||
let pattern_0 = vec![
|
||
1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||
];
|
||
|
||
let pattern_1 = vec![
|
||
0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0,
|
||
];
|
||
|
||
let target_0 = vec![1.0, 0.0];
|
||
let target_1 = vec![0.0, 1.0];
|
||
|
||
// Train
|
||
for epoch in 0..100 {
|
||
// Pattern 0
|
||
network.reset();
|
||
for _ in 0..10 {
|
||
network.online_step(&pattern_0, &target_0, dt, lr);
|
||
}
|
||
|
||
// Pattern 1
|
||
network.reset();
|
||
for _ in 0..10 {
|
||
network.online_step(&pattern_1, &target_1, dt, lr);
|
||
}
|
||
|
||
if epoch % 20 == 0 {
|
||
println!("Epoch {}: Training...", epoch);
|
||
}
|
||
}
|
||
|
||
// Test
|
||
network.reset();
|
||
let mut output_0 = vec![0.0; 2];
|
||
for _ in 0..10 {
|
||
output_0 = network.forward(&pattern_0, dt);
|
||
}
|
||
|
||
network.reset();
|
||
let mut output_1 = vec![0.0; 2];
|
||
for _ in 0..10 {
|
||
output_1 = network.forward(&pattern_1, dt);
|
||
}
|
||
|
||
println!("Pattern 0 output: {:?}", output_0);
|
||
println!("Pattern 1 output: {:?}", output_1);
|
||
|
||
// Just verify it runs, pattern recognition is hard with such small network
|
||
}
|