Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
310
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/btsp.rs
vendored
Normal file
310
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/btsp.rs
vendored
Normal file
@@ -0,0 +1,310 @@
|
||||
//! BTSP (Behavioral Timescale Synaptic Plasticity) WASM bindings
|
||||
//!
|
||||
//! One-shot learning for immediate pattern-target associations.
|
||||
//! Based on Bittner et al. 2017 hippocampal place field formation.
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// BTSP synapse with eligibility trace and bidirectional plasticity
|
||||
#[wasm_bindgen]
|
||||
#[derive(Clone)]
|
||||
pub struct BTSPSynapse {
|
||||
weight: f32,
|
||||
eligibility_trace: f32,
|
||||
tau_btsp: f32,
|
||||
min_weight: f32,
|
||||
max_weight: f32,
|
||||
ltp_rate: f32,
|
||||
ltd_rate: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BTSPSynapse {
|
||||
/// Create a new BTSP synapse
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_weight` - Starting weight (0.0 to 1.0)
|
||||
/// * `tau_btsp` - Time constant in milliseconds (1000-3000ms recommended)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(initial_weight: f32, tau_btsp: f32) -> Result<BTSPSynapse, JsValue> {
|
||||
if !(0.0..=1.0).contains(&initial_weight) {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Invalid weight: {} (must be 0.0-1.0)",
|
||||
initial_weight
|
||||
)));
|
||||
}
|
||||
if tau_btsp <= 0.0 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Invalid time constant: {} (must be > 0)",
|
||||
tau_btsp
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
weight: initial_weight,
|
||||
eligibility_trace: 0.0,
|
||||
tau_btsp,
|
||||
min_weight: 0.0,
|
||||
max_weight: 1.0,
|
||||
ltp_rate: 0.1,
|
||||
ltd_rate: 0.05,
|
||||
})
|
||||
}
|
||||
|
||||
/// Update synapse based on activity and plateau signal
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `presynaptic_active` - Is presynaptic neuron firing?
|
||||
/// * `plateau_signal` - Dendritic plateau potential detected?
|
||||
/// * `dt` - Time step in milliseconds
|
||||
#[wasm_bindgen]
|
||||
pub fn update(&mut self, presynaptic_active: bool, plateau_signal: bool, dt: f32) {
|
||||
// Decay eligibility trace exponentially
|
||||
self.eligibility_trace *= (-dt / self.tau_btsp).exp();
|
||||
|
||||
// Accumulate trace when presynaptic neuron fires
|
||||
if presynaptic_active {
|
||||
self.eligibility_trace += 1.0;
|
||||
}
|
||||
|
||||
// Bidirectional plasticity gated by plateau potential
|
||||
if plateau_signal && self.eligibility_trace > 0.01 {
|
||||
let delta = if self.weight < 0.5 {
|
||||
self.ltp_rate // Potentiation
|
||||
} else {
|
||||
-self.ltd_rate // Depression
|
||||
};
|
||||
|
||||
self.weight += delta * self.eligibility_trace;
|
||||
self.weight = self.weight.clamp(self.min_weight, self.max_weight);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current weight
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn weight(&self) -> f32 {
|
||||
self.weight
|
||||
}
|
||||
|
||||
/// Get eligibility trace
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn eligibility_trace(&self) -> f32 {
|
||||
self.eligibility_trace
|
||||
}
|
||||
|
||||
/// Compute synaptic output
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: f32) -> f32 {
|
||||
self.weight * input
|
||||
}
|
||||
}
|
||||
|
||||
/// BTSP Layer for one-shot learning
|
||||
///
|
||||
/// # Performance
|
||||
/// - One-shot learning: immediate, no iteration
|
||||
/// - Forward pass: <10us for 10K synapses
|
||||
#[wasm_bindgen]
|
||||
pub struct BTSPLayer {
|
||||
weights: Vec<f32>,
|
||||
eligibility_traces: Vec<f32>,
|
||||
#[allow(dead_code)]
|
||||
tau_btsp: f32,
|
||||
#[allow(dead_code)]
|
||||
plateau_threshold: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BTSPLayer {
|
||||
/// Create a new BTSP layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `size` - Number of synapses (input dimension)
|
||||
/// * `tau` - Time constant in milliseconds (2000ms default)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(size: usize, tau: f32) -> BTSPLayer {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let weights: Vec<f32> = (0..size).map(|_| rng.gen_range(0.0..0.1)).collect();
|
||||
let eligibility_traces = vec![0.0; size];
|
||||
|
||||
Self {
|
||||
weights,
|
||||
eligibility_traces,
|
||||
tau_btsp: tau,
|
||||
plateau_threshold: 0.7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass: compute layer output
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: &[f32]) -> Result<f32, JsValue> {
|
||||
if input.len() != self.weights.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.weights.len(),
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(self
|
||||
.weights
|
||||
.iter()
|
||||
.zip(input.iter())
|
||||
.map(|(&w, &x)| w * x)
|
||||
.sum())
|
||||
}
|
||||
|
||||
/// One-shot association: learn pattern -> target in single step
|
||||
///
|
||||
/// This is the key BTSP capability: immediate learning without iteration.
|
||||
/// Uses gradient normalization for single-step convergence.
|
||||
#[wasm_bindgen]
|
||||
pub fn one_shot_associate(&mut self, pattern: &[f32], target: f32) -> Result<(), JsValue> {
|
||||
if pattern.len() != self.weights.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Pattern size mismatch: expected {}, got {}",
|
||||
self.weights.len(),
|
||||
pattern.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Current output
|
||||
let current: f32 = self
|
||||
.weights
|
||||
.iter()
|
||||
.zip(pattern.iter())
|
||||
.map(|(&w, &x)| w * x)
|
||||
.sum();
|
||||
|
||||
// Compute required weight change
|
||||
let error = target - current;
|
||||
|
||||
// Compute sum of squared inputs for gradient normalization
|
||||
let sum_squared: f32 = pattern.iter().map(|&x| x * x).sum();
|
||||
if sum_squared < 1e-8 {
|
||||
return Ok(()); // No active inputs
|
||||
}
|
||||
|
||||
// Set eligibility traces and update weights
|
||||
for (i, &input_val) in pattern.iter().enumerate() {
|
||||
if input_val.abs() > 0.01 {
|
||||
// Set trace proportional to input
|
||||
self.eligibility_traces[i] = input_val;
|
||||
|
||||
// Direct weight update: delta = error * x / sum(x^2)
|
||||
let delta = error * input_val / sum_squared;
|
||||
self.weights[i] += delta;
|
||||
self.weights[i] = self.weights[i].clamp(0.0, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get number of synapses
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.weights.len()
|
||||
}
|
||||
|
||||
/// Get weights as Float32Array
|
||||
#[wasm_bindgen]
|
||||
pub fn get_weights(&self) -> js_sys::Float32Array {
|
||||
js_sys::Float32Array::from(self.weights.as_slice())
|
||||
}
|
||||
|
||||
/// Reset layer to initial state
|
||||
#[wasm_bindgen]
|
||||
pub fn reset(&mut self) {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
for w in &mut self.weights {
|
||||
*w = rng.gen_range(0.0..0.1);
|
||||
}
|
||||
self.eligibility_traces.fill(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Associative memory using BTSP for key-value storage
|
||||
#[wasm_bindgen]
|
||||
pub struct BTSPAssociativeMemory {
|
||||
layers: Vec<BTSPLayer>,
|
||||
input_size: usize,
|
||||
output_size: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BTSPAssociativeMemory {
|
||||
/// Create new associative memory
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_size` - Dimension of key vectors
|
||||
/// * `output_size` - Dimension of value vectors
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(input_size: usize, output_size: usize) -> BTSPAssociativeMemory {
|
||||
let tau = 2000.0;
|
||||
let layers = (0..output_size)
|
||||
.map(|_| BTSPLayer::new(input_size, tau))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
layers,
|
||||
input_size,
|
||||
output_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store key-value association in one shot
|
||||
#[wasm_bindgen]
|
||||
pub fn store_one_shot(&mut self, key: &[f32], value: &[f32]) -> Result<(), JsValue> {
|
||||
if key.len() != self.input_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Key size mismatch: expected {}, got {}",
|
||||
self.input_size,
|
||||
key.len()
|
||||
)));
|
||||
}
|
||||
if value.len() != self.output_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Value size mismatch: expected {}, got {}",
|
||||
self.output_size,
|
||||
value.len()
|
||||
)));
|
||||
}
|
||||
|
||||
for (layer, &target) in self.layers.iter_mut().zip(value.iter()) {
|
||||
layer.one_shot_associate(key, target)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve value from key
|
||||
#[wasm_bindgen]
|
||||
pub fn retrieve(&self, query: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
|
||||
if query.len() != self.input_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Query size mismatch: expected {}, got {}",
|
||||
self.input_size,
|
||||
query.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let output: Vec<f32> = self
|
||||
.layers
|
||||
.iter()
|
||||
.map(|layer| layer.forward(query).unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
Ok(js_sys::Float32Array::from(output.as_slice()))
|
||||
}
|
||||
|
||||
/// Get memory dimensions
|
||||
#[wasm_bindgen]
|
||||
pub fn dimensions(&self) -> JsValue {
|
||||
let dims = serde_wasm_bindgen::to_value(&(self.input_size, self.output_size));
|
||||
dims.unwrap_or(JsValue::NULL)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user