Files
wifi-densepose/vendor/ruvector/crates/ruvector-nervous-system-wasm/src/btsp.rs

311 lines
8.9 KiB
Rust

//! BTSP (Behavioral Timescale Synaptic Plasticity) WASM bindings
//!
//! One-shot learning for immediate pattern-target associations.
//! Based on Bittner et al. 2017 hippocampal place field formation.
use wasm_bindgen::prelude::*;
/// BTSP synapse with eligibility trace and bidirectional plasticity
#[wasm_bindgen]
#[derive(Clone)]
pub struct BTSPSynapse {
weight: f32,
eligibility_trace: f32,
tau_btsp: f32,
min_weight: f32,
max_weight: f32,
ltp_rate: f32,
ltd_rate: f32,
}
#[wasm_bindgen]
impl BTSPSynapse {
/// Create a new BTSP synapse
///
/// # Arguments
/// * `initial_weight` - Starting weight (0.0 to 1.0)
/// * `tau_btsp` - Time constant in milliseconds (1000-3000ms recommended)
#[wasm_bindgen(constructor)]
pub fn new(initial_weight: f32, tau_btsp: f32) -> Result<BTSPSynapse, JsValue> {
if !(0.0..=1.0).contains(&initial_weight) {
return Err(JsValue::from_str(&format!(
"Invalid weight: {} (must be 0.0-1.0)",
initial_weight
)));
}
if tau_btsp <= 0.0 {
return Err(JsValue::from_str(&format!(
"Invalid time constant: {} (must be > 0)",
tau_btsp
)));
}
Ok(Self {
weight: initial_weight,
eligibility_trace: 0.0,
tau_btsp,
min_weight: 0.0,
max_weight: 1.0,
ltp_rate: 0.1,
ltd_rate: 0.05,
})
}
/// Update synapse based on activity and plateau signal
///
/// # Arguments
/// * `presynaptic_active` - Is presynaptic neuron firing?
/// * `plateau_signal` - Dendritic plateau potential detected?
/// * `dt` - Time step in milliseconds
#[wasm_bindgen]
pub fn update(&mut self, presynaptic_active: bool, plateau_signal: bool, dt: f32) {
// Decay eligibility trace exponentially
self.eligibility_trace *= (-dt / self.tau_btsp).exp();
// Accumulate trace when presynaptic neuron fires
if presynaptic_active {
self.eligibility_trace += 1.0;
}
// Bidirectional plasticity gated by plateau potential
if plateau_signal && self.eligibility_trace > 0.01 {
let delta = if self.weight < 0.5 {
self.ltp_rate // Potentiation
} else {
-self.ltd_rate // Depression
};
self.weight += delta * self.eligibility_trace;
self.weight = self.weight.clamp(self.min_weight, self.max_weight);
}
}
/// Get current weight
#[wasm_bindgen(getter)]
pub fn weight(&self) -> f32 {
self.weight
}
/// Get eligibility trace
#[wasm_bindgen(getter)]
pub fn eligibility_trace(&self) -> f32 {
self.eligibility_trace
}
/// Compute synaptic output
#[wasm_bindgen]
pub fn forward(&self, input: f32) -> f32 {
self.weight * input
}
}
/// BTSP Layer for one-shot learning
///
/// # Performance
/// - One-shot learning: immediate, no iteration
/// - Forward pass: <10us for 10K synapses
#[wasm_bindgen]
pub struct BTSPLayer {
weights: Vec<f32>,
eligibility_traces: Vec<f32>,
#[allow(dead_code)]
tau_btsp: f32,
#[allow(dead_code)]
plateau_threshold: f32,
}
#[wasm_bindgen]
impl BTSPLayer {
/// Create a new BTSP layer
///
/// # Arguments
/// * `size` - Number of synapses (input dimension)
/// * `tau` - Time constant in milliseconds (2000ms default)
#[wasm_bindgen(constructor)]
pub fn new(size: usize, tau: f32) -> BTSPLayer {
use rand::Rng;
let mut rng = rand::thread_rng();
let weights: Vec<f32> = (0..size).map(|_| rng.gen_range(0.0..0.1)).collect();
let eligibility_traces = vec![0.0; size];
Self {
weights,
eligibility_traces,
tau_btsp: tau,
plateau_threshold: 0.7,
}
}
/// Forward pass: compute layer output
#[wasm_bindgen]
pub fn forward(&self, input: &[f32]) -> Result<f32, JsValue> {
if input.len() != self.weights.len() {
return Err(JsValue::from_str(&format!(
"Input size mismatch: expected {}, got {}",
self.weights.len(),
input.len()
)));
}
Ok(self
.weights
.iter()
.zip(input.iter())
.map(|(&w, &x)| w * x)
.sum())
}
/// One-shot association: learn pattern -> target in single step
///
/// This is the key BTSP capability: immediate learning without iteration.
/// Uses gradient normalization for single-step convergence.
#[wasm_bindgen]
pub fn one_shot_associate(&mut self, pattern: &[f32], target: f32) -> Result<(), JsValue> {
if pattern.len() != self.weights.len() {
return Err(JsValue::from_str(&format!(
"Pattern size mismatch: expected {}, got {}",
self.weights.len(),
pattern.len()
)));
}
// Current output
let current: f32 = self
.weights
.iter()
.zip(pattern.iter())
.map(|(&w, &x)| w * x)
.sum();
// Compute required weight change
let error = target - current;
// Compute sum of squared inputs for gradient normalization
let sum_squared: f32 = pattern.iter().map(|&x| x * x).sum();
if sum_squared < 1e-8 {
return Ok(()); // No active inputs
}
// Set eligibility traces and update weights
for (i, &input_val) in pattern.iter().enumerate() {
if input_val.abs() > 0.01 {
// Set trace proportional to input
self.eligibility_traces[i] = input_val;
// Direct weight update: delta = error * x / sum(x^2)
let delta = error * input_val / sum_squared;
self.weights[i] += delta;
self.weights[i] = self.weights[i].clamp(0.0, 1.0);
}
}
Ok(())
}
/// Get number of synapses
#[wasm_bindgen(getter)]
pub fn size(&self) -> usize {
self.weights.len()
}
/// Get weights as Float32Array
#[wasm_bindgen]
pub fn get_weights(&self) -> js_sys::Float32Array {
js_sys::Float32Array::from(self.weights.as_slice())
}
/// Reset layer to initial state
#[wasm_bindgen]
pub fn reset(&mut self) {
use rand::Rng;
let mut rng = rand::thread_rng();
for w in &mut self.weights {
*w = rng.gen_range(0.0..0.1);
}
self.eligibility_traces.fill(0.0);
}
}
/// Associative memory using BTSP for key-value storage
#[wasm_bindgen]
pub struct BTSPAssociativeMemory {
layers: Vec<BTSPLayer>,
input_size: usize,
output_size: usize,
}
#[wasm_bindgen]
impl BTSPAssociativeMemory {
/// Create new associative memory
///
/// # Arguments
/// * `input_size` - Dimension of key vectors
/// * `output_size` - Dimension of value vectors
#[wasm_bindgen(constructor)]
pub fn new(input_size: usize, output_size: usize) -> BTSPAssociativeMemory {
let tau = 2000.0;
let layers = (0..output_size)
.map(|_| BTSPLayer::new(input_size, tau))
.collect();
Self {
layers,
input_size,
output_size,
}
}
/// Store key-value association in one shot
#[wasm_bindgen]
pub fn store_one_shot(&mut self, key: &[f32], value: &[f32]) -> Result<(), JsValue> {
if key.len() != self.input_size {
return Err(JsValue::from_str(&format!(
"Key size mismatch: expected {}, got {}",
self.input_size,
key.len()
)));
}
if value.len() != self.output_size {
return Err(JsValue::from_str(&format!(
"Value size mismatch: expected {}, got {}",
self.output_size,
value.len()
)));
}
for (layer, &target) in self.layers.iter_mut().zip(value.iter()) {
layer.one_shot_associate(key, target)?;
}
Ok(())
}
/// Retrieve value from key
#[wasm_bindgen]
pub fn retrieve(&self, query: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
if query.len() != self.input_size {
return Err(JsValue::from_str(&format!(
"Query size mismatch: expected {}, got {}",
self.input_size,
query.len()
)));
}
let output: Vec<f32> = self
.layers
.iter()
.map(|layer| layer.forward(query).unwrap_or(0.0))
.collect();
Ok(js_sys::Float32Array::from(output.as_slice()))
}
/// Get memory dimensions
#[wasm_bindgen]
pub fn dimensions(&self) -> JsValue {
let dims = serde_wasm_bindgen::to_value(&(self.input_size, self.output_size));
dims.unwrap_or(JsValue::NULL)
}
}