Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
310
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/btsp.rs
vendored
Normal file
310
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/btsp.rs
vendored
Normal file
@@ -0,0 +1,310 @@
|
||||
//! BTSP (Behavioral Timescale Synaptic Plasticity) WASM bindings
|
||||
//!
|
||||
//! One-shot learning for immediate pattern-target associations.
|
||||
//! Based on Bittner et al. 2017 hippocampal place field formation.
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// BTSP synapse with eligibility trace and bidirectional plasticity
|
||||
#[wasm_bindgen]
|
||||
#[derive(Clone)]
|
||||
pub struct BTSPSynapse {
|
||||
weight: f32,
|
||||
eligibility_trace: f32,
|
||||
tau_btsp: f32,
|
||||
min_weight: f32,
|
||||
max_weight: f32,
|
||||
ltp_rate: f32,
|
||||
ltd_rate: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BTSPSynapse {
|
||||
/// Create a new BTSP synapse
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_weight` - Starting weight (0.0 to 1.0)
|
||||
/// * `tau_btsp` - Time constant in milliseconds (1000-3000ms recommended)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(initial_weight: f32, tau_btsp: f32) -> Result<BTSPSynapse, JsValue> {
|
||||
if !(0.0..=1.0).contains(&initial_weight) {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Invalid weight: {} (must be 0.0-1.0)",
|
||||
initial_weight
|
||||
)));
|
||||
}
|
||||
if tau_btsp <= 0.0 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Invalid time constant: {} (must be > 0)",
|
||||
tau_btsp
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
weight: initial_weight,
|
||||
eligibility_trace: 0.0,
|
||||
tau_btsp,
|
||||
min_weight: 0.0,
|
||||
max_weight: 1.0,
|
||||
ltp_rate: 0.1,
|
||||
ltd_rate: 0.05,
|
||||
})
|
||||
}
|
||||
|
||||
/// Update synapse based on activity and plateau signal
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `presynaptic_active` - Is presynaptic neuron firing?
|
||||
/// * `plateau_signal` - Dendritic plateau potential detected?
|
||||
/// * `dt` - Time step in milliseconds
|
||||
#[wasm_bindgen]
|
||||
pub fn update(&mut self, presynaptic_active: bool, plateau_signal: bool, dt: f32) {
|
||||
// Decay eligibility trace exponentially
|
||||
self.eligibility_trace *= (-dt / self.tau_btsp).exp();
|
||||
|
||||
// Accumulate trace when presynaptic neuron fires
|
||||
if presynaptic_active {
|
||||
self.eligibility_trace += 1.0;
|
||||
}
|
||||
|
||||
// Bidirectional plasticity gated by plateau potential
|
||||
if plateau_signal && self.eligibility_trace > 0.01 {
|
||||
let delta = if self.weight < 0.5 {
|
||||
self.ltp_rate // Potentiation
|
||||
} else {
|
||||
-self.ltd_rate // Depression
|
||||
};
|
||||
|
||||
self.weight += delta * self.eligibility_trace;
|
||||
self.weight = self.weight.clamp(self.min_weight, self.max_weight);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current weight
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn weight(&self) -> f32 {
|
||||
self.weight
|
||||
}
|
||||
|
||||
/// Get eligibility trace
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn eligibility_trace(&self) -> f32 {
|
||||
self.eligibility_trace
|
||||
}
|
||||
|
||||
/// Compute synaptic output
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: f32) -> f32 {
|
||||
self.weight * input
|
||||
}
|
||||
}
|
||||
|
||||
/// BTSP Layer for one-shot learning
|
||||
///
|
||||
/// # Performance
|
||||
/// - One-shot learning: immediate, no iteration
|
||||
/// - Forward pass: <10us for 10K synapses
|
||||
#[wasm_bindgen]
|
||||
pub struct BTSPLayer {
|
||||
weights: Vec<f32>,
|
||||
eligibility_traces: Vec<f32>,
|
||||
#[allow(dead_code)]
|
||||
tau_btsp: f32,
|
||||
#[allow(dead_code)]
|
||||
plateau_threshold: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BTSPLayer {
|
||||
/// Create a new BTSP layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `size` - Number of synapses (input dimension)
|
||||
/// * `tau` - Time constant in milliseconds (2000ms default)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(size: usize, tau: f32) -> BTSPLayer {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let weights: Vec<f32> = (0..size).map(|_| rng.gen_range(0.0..0.1)).collect();
|
||||
let eligibility_traces = vec![0.0; size];
|
||||
|
||||
Self {
|
||||
weights,
|
||||
eligibility_traces,
|
||||
tau_btsp: tau,
|
||||
plateau_threshold: 0.7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass: compute layer output
|
||||
#[wasm_bindgen]
|
||||
pub fn forward(&self, input: &[f32]) -> Result<f32, JsValue> {
|
||||
if input.len() != self.weights.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.weights.len(),
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(self
|
||||
.weights
|
||||
.iter()
|
||||
.zip(input.iter())
|
||||
.map(|(&w, &x)| w * x)
|
||||
.sum())
|
||||
}
|
||||
|
||||
/// One-shot association: learn pattern -> target in single step
|
||||
///
|
||||
/// This is the key BTSP capability: immediate learning without iteration.
|
||||
/// Uses gradient normalization for single-step convergence.
|
||||
#[wasm_bindgen]
|
||||
pub fn one_shot_associate(&mut self, pattern: &[f32], target: f32) -> Result<(), JsValue> {
|
||||
if pattern.len() != self.weights.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Pattern size mismatch: expected {}, got {}",
|
||||
self.weights.len(),
|
||||
pattern.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Current output
|
||||
let current: f32 = self
|
||||
.weights
|
||||
.iter()
|
||||
.zip(pattern.iter())
|
||||
.map(|(&w, &x)| w * x)
|
||||
.sum();
|
||||
|
||||
// Compute required weight change
|
||||
let error = target - current;
|
||||
|
||||
// Compute sum of squared inputs for gradient normalization
|
||||
let sum_squared: f32 = pattern.iter().map(|&x| x * x).sum();
|
||||
if sum_squared < 1e-8 {
|
||||
return Ok(()); // No active inputs
|
||||
}
|
||||
|
||||
// Set eligibility traces and update weights
|
||||
for (i, &input_val) in pattern.iter().enumerate() {
|
||||
if input_val.abs() > 0.01 {
|
||||
// Set trace proportional to input
|
||||
self.eligibility_traces[i] = input_val;
|
||||
|
||||
// Direct weight update: delta = error * x / sum(x^2)
|
||||
let delta = error * input_val / sum_squared;
|
||||
self.weights[i] += delta;
|
||||
self.weights[i] = self.weights[i].clamp(0.0, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get number of synapses
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.weights.len()
|
||||
}
|
||||
|
||||
/// Get weights as Float32Array
|
||||
#[wasm_bindgen]
|
||||
pub fn get_weights(&self) -> js_sys::Float32Array {
|
||||
js_sys::Float32Array::from(self.weights.as_slice())
|
||||
}
|
||||
|
||||
/// Reset layer to initial state
|
||||
#[wasm_bindgen]
|
||||
pub fn reset(&mut self) {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
for w in &mut self.weights {
|
||||
*w = rng.gen_range(0.0..0.1);
|
||||
}
|
||||
self.eligibility_traces.fill(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Associative memory using BTSP for key-value storage
|
||||
#[wasm_bindgen]
|
||||
pub struct BTSPAssociativeMemory {
|
||||
layers: Vec<BTSPLayer>,
|
||||
input_size: usize,
|
||||
output_size: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl BTSPAssociativeMemory {
|
||||
/// Create new associative memory
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_size` - Dimension of key vectors
|
||||
/// * `output_size` - Dimension of value vectors
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(input_size: usize, output_size: usize) -> BTSPAssociativeMemory {
|
||||
let tau = 2000.0;
|
||||
let layers = (0..output_size)
|
||||
.map(|_| BTSPLayer::new(input_size, tau))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
layers,
|
||||
input_size,
|
||||
output_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Store key-value association in one shot
|
||||
#[wasm_bindgen]
|
||||
pub fn store_one_shot(&mut self, key: &[f32], value: &[f32]) -> Result<(), JsValue> {
|
||||
if key.len() != self.input_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Key size mismatch: expected {}, got {}",
|
||||
self.input_size,
|
||||
key.len()
|
||||
)));
|
||||
}
|
||||
if value.len() != self.output_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Value size mismatch: expected {}, got {}",
|
||||
self.output_size,
|
||||
value.len()
|
||||
)));
|
||||
}
|
||||
|
||||
for (layer, &target) in self.layers.iter_mut().zip(value.iter()) {
|
||||
layer.one_shot_associate(key, target)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve value from key
|
||||
#[wasm_bindgen]
|
||||
pub fn retrieve(&self, query: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
|
||||
if query.len() != self.input_size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Query size mismatch: expected {}, got {}",
|
||||
self.input_size,
|
||||
query.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let output: Vec<f32> = self
|
||||
.layers
|
||||
.iter()
|
||||
.map(|layer| layer.forward(query).unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
Ok(js_sys::Float32Array::from(output.as_slice()))
|
||||
}
|
||||
|
||||
/// Get memory dimensions
|
||||
#[wasm_bindgen]
|
||||
pub fn dimensions(&self) -> JsValue {
|
||||
let dims = serde_wasm_bindgen::to_value(&(self.input_size, self.output_size));
|
||||
dims.unwrap_or(JsValue::NULL)
|
||||
}
|
||||
}
|
||||
272
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/hdc.rs
vendored
Normal file
272
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/hdc.rs
vendored
Normal file
@@ -0,0 +1,272 @@
|
||||
//! Hyperdimensional Computing (HDC) WASM bindings
|
||||
//!
|
||||
//! 10,000-bit binary hypervectors with ultra-fast operations:
|
||||
//! - XOR binding: <50ns
|
||||
//! - Hamming similarity: <100ns via SIMD
|
||||
//! - 10^40 representational capacity
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Number of bits in a hypervector
|
||||
const HYPERVECTOR_BITS: usize = 10_000;
|
||||
|
||||
/// Number of u64 words needed (ceil(10000/64) = 157)
|
||||
const HYPERVECTOR_U64_LEN: usize = 157;
|
||||
|
||||
/// A binary hypervector with 10,000 bits
|
||||
///
|
||||
/// # Performance
|
||||
/// - Memory: 1,248 bytes per vector
|
||||
/// - XOR binding: <50ns
|
||||
/// - Similarity: <100ns with SIMD popcount
|
||||
#[wasm_bindgen]
|
||||
pub struct Hypervector {
|
||||
bits: Vec<u64>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl Hypervector {
|
||||
/// Create a zero hypervector
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Hypervector {
|
||||
Self {
|
||||
bits: vec![0u64; HYPERVECTOR_U64_LEN],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a random hypervector with ~50% bits set
|
||||
#[wasm_bindgen]
|
||||
pub fn random() -> Hypervector {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let bits: Vec<u64> = (0..HYPERVECTOR_U64_LEN).map(|_| rng.gen()).collect();
|
||||
Self { bits }
|
||||
}
|
||||
|
||||
/// Create a hypervector from a seed for reproducibility
|
||||
#[wasm_bindgen]
|
||||
pub fn from_seed(seed: u64) -> Hypervector {
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
let bits: Vec<u64> = (0..HYPERVECTOR_U64_LEN).map(|_| rng.gen()).collect();
|
||||
Self { bits }
|
||||
}
|
||||
|
||||
/// Bind two hypervectors using XOR
|
||||
///
|
||||
/// Binding is associative, commutative, and self-inverse:
|
||||
/// - a.bind(b) == b.bind(a)
|
||||
/// - a.bind(b).bind(b) == a
|
||||
#[wasm_bindgen]
|
||||
pub fn bind(&self, other: &Hypervector) -> Hypervector {
|
||||
let bits: Vec<u64> = self
|
||||
.bits
|
||||
.iter()
|
||||
.zip(other.bits.iter())
|
||||
.map(|(&a, &b)| a ^ b)
|
||||
.collect();
|
||||
Self { bits }
|
||||
}
|
||||
|
||||
/// Compute similarity between two hypervectors
|
||||
///
|
||||
/// Returns a value in [-1.0, 1.0] where:
|
||||
/// - 1.0 = identical vectors
|
||||
/// - 0.0 = random/orthogonal vectors
|
||||
/// - -1.0 = completely opposite vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn similarity(&self, other: &Hypervector) -> f32 {
|
||||
let hamming = self.hamming_distance(other);
|
||||
1.0 - (2.0 * hamming as f32 / HYPERVECTOR_BITS as f32)
|
||||
}
|
||||
|
||||
/// Compute Hamming distance (number of differing bits)
|
||||
#[wasm_bindgen]
|
||||
pub fn hamming_distance(&self, other: &Hypervector) -> u32 {
|
||||
// Unrolled loop for better instruction-level parallelism
|
||||
let mut d0 = 0u32;
|
||||
let mut d1 = 0u32;
|
||||
let mut d2 = 0u32;
|
||||
let mut d3 = 0u32;
|
||||
|
||||
let chunks = HYPERVECTOR_U64_LEN / 4;
|
||||
let remainder = HYPERVECTOR_U64_LEN % 4;
|
||||
|
||||
for i in 0..chunks {
|
||||
let base = i * 4;
|
||||
d0 += (self.bits[base] ^ other.bits[base]).count_ones();
|
||||
d1 += (self.bits[base + 1] ^ other.bits[base + 1]).count_ones();
|
||||
d2 += (self.bits[base + 2] ^ other.bits[base + 2]).count_ones();
|
||||
d3 += (self.bits[base + 3] ^ other.bits[base + 3]).count_ones();
|
||||
}
|
||||
|
||||
let base = chunks * 4;
|
||||
for i in 0..remainder {
|
||||
d0 += (self.bits[base + i] ^ other.bits[base + i]).count_ones();
|
||||
}
|
||||
|
||||
d0 + d1 + d2 + d3
|
||||
}
|
||||
|
||||
/// Count the number of set bits (population count)
|
||||
#[wasm_bindgen]
|
||||
pub fn popcount(&self) -> u32 {
|
||||
self.bits.iter().map(|&w| w.count_ones()).sum()
|
||||
}
|
||||
|
||||
/// Bundle multiple vectors by majority voting on each bit
|
||||
#[wasm_bindgen]
|
||||
pub fn bundle_3(a: &Hypervector, b: &Hypervector, c: &Hypervector) -> Hypervector {
|
||||
// Majority of 3 bits: (a & b) | (b & c) | (a & c)
|
||||
let bits: Vec<u64> = (0..HYPERVECTOR_U64_LEN)
|
||||
.map(|i| {
|
||||
let wa = a.bits[i];
|
||||
let wb = b.bits[i];
|
||||
let wc = c.bits[i];
|
||||
(wa & wb) | (wb & wc) | (wa & wc)
|
||||
})
|
||||
.collect();
|
||||
Self { bits }
|
||||
}
|
||||
|
||||
/// Get the raw bits as Uint8Array (for serialization)
|
||||
#[wasm_bindgen]
|
||||
pub fn to_bytes(&self) -> js_sys::Uint8Array {
|
||||
let bytes: Vec<u8> = self.bits.iter().flat_map(|&w| w.to_le_bytes()).collect();
|
||||
js_sys::Uint8Array::from(bytes.as_slice())
|
||||
}
|
||||
|
||||
/// Create from raw bytes
|
||||
#[wasm_bindgen]
|
||||
pub fn from_bytes(bytes: &[u8]) -> Result<Hypervector, JsValue> {
|
||||
if bytes.len() != HYPERVECTOR_U64_LEN * 8 {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Invalid byte length: expected {}, got {}",
|
||||
HYPERVECTOR_U64_LEN * 8,
|
||||
bytes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let bits: Vec<u64> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
|
||||
.collect();
|
||||
|
||||
Ok(Self { bits })
|
||||
}
|
||||
|
||||
/// Get number of bits
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn dimension(&self) -> usize {
|
||||
HYPERVECTOR_BITS
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Hypervector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// HDC Memory for storing and retrieving hypervectors by label
|
||||
#[wasm_bindgen]
|
||||
pub struct HdcMemory {
|
||||
labels: Vec<String>,
|
||||
vectors: Vec<Hypervector>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl HdcMemory {
|
||||
/// Create a new empty HDC memory
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> HdcMemory {
|
||||
Self {
|
||||
labels: Vec::new(),
|
||||
vectors: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Store a hypervector with a label
|
||||
#[wasm_bindgen]
|
||||
pub fn store(&mut self, label: &str, vector: Hypervector) {
|
||||
// Check if label exists
|
||||
if let Some(idx) = self.labels.iter().position(|l| l == label) {
|
||||
self.vectors[idx] = vector;
|
||||
} else {
|
||||
self.labels.push(label.to_string());
|
||||
self.vectors.push(vector);
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve vectors similar to query above threshold
|
||||
///
|
||||
/// Returns array of [label, similarity] pairs
|
||||
#[wasm_bindgen]
|
||||
pub fn retrieve(&self, query: &Hypervector, threshold: f32) -> JsValue {
|
||||
let mut results: Vec<(String, f32)> = Vec::new();
|
||||
|
||||
for (label, vector) in self.labels.iter().zip(self.vectors.iter()) {
|
||||
let sim = query.similarity(vector);
|
||||
if sim >= threshold {
|
||||
results.push((label.clone(), sim));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by similarity descending
|
||||
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
serde_wasm_bindgen::to_value(&results).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
/// Find the k most similar vectors to query
|
||||
#[wasm_bindgen]
|
||||
pub fn top_k(&self, query: &Hypervector, k: usize) -> JsValue {
|
||||
let mut similarities: Vec<(String, f32)> = self
|
||||
.labels
|
||||
.iter()
|
||||
.zip(self.vectors.iter())
|
||||
.map(|(label, vector)| (label.clone(), query.similarity(vector)))
|
||||
.collect();
|
||||
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
similarities.truncate(k);
|
||||
|
||||
serde_wasm_bindgen::to_value(&similarities).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
/// Get number of stored vectors
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.vectors.len()
|
||||
}
|
||||
|
||||
/// Clear all stored vectors
|
||||
#[wasm_bindgen]
|
||||
pub fn clear(&mut self) {
|
||||
self.labels.clear();
|
||||
self.vectors.clear();
|
||||
}
|
||||
|
||||
/// Check if a label exists
|
||||
#[wasm_bindgen]
|
||||
pub fn has(&self, label: &str) -> bool {
|
||||
self.labels.iter().any(|l| l == label)
|
||||
}
|
||||
|
||||
/// Get a vector by label
|
||||
#[wasm_bindgen]
|
||||
pub fn get(&self, label: &str) -> Option<Hypervector> {
|
||||
self.labels
|
||||
.iter()
|
||||
.position(|l| l == label)
|
||||
.map(|idx| Hypervector {
|
||||
bits: self.vectors[idx].bits.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HdcMemory {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
156
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/lib.rs
vendored
Normal file
156
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/lib.rs
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
//! # RuVector Nervous System WASM
|
||||
//!
|
||||
//! Bio-inspired neural system components for browser execution.
|
||||
//!
|
||||
//! ## Components
|
||||
//!
|
||||
//! - **BTSP** (Behavioral Timescale Synaptic Plasticity) - One-shot learning
|
||||
//! - **HDC** (Hyperdimensional Computing) - 10,000-bit binary hypervectors
|
||||
//! - **WTA** (Winner-Take-All) - <1us instant decisions
|
||||
//! - **Global Workspace** - 4-7 item attention bottleneck
|
||||
//!
|
||||
//! ## Performance Targets
|
||||
//!
|
||||
//! | Component | Target | Method |
|
||||
//! |-----------|--------|--------|
|
||||
//! | BTSP one_shot_associate | Immediate | Gradient normalization |
|
||||
//! | HDC bind | <50ns | XOR operation |
|
||||
//! | HDC similarity | <100ns | Hamming distance + SIMD |
|
||||
//! | WTA compete | <1us | Single-pass argmax |
|
||||
//! | K-WTA select | <10us | Partial sort |
|
||||
//! | Workspace broadcast | <10us | Competition |
|
||||
//!
|
||||
//! ## Bundle Size
|
||||
//!
|
||||
//! Target: <100KB with all bio-inspired mechanisms.
|
||||
//!
|
||||
//! ## Example Usage (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import init, {
|
||||
//! BTSPLayer,
|
||||
//! Hypervector,
|
||||
//! HdcMemory,
|
||||
//! WTALayer,
|
||||
//! KWTALayer,
|
||||
//! GlobalWorkspace,
|
||||
//! WorkspaceItem,
|
||||
//! } from 'ruvector-nervous-system-wasm';
|
||||
//!
|
||||
//! await init();
|
||||
//!
|
||||
//! // One-shot learning with BTSP
|
||||
//! const btsp = new BTSPLayer(100, 2000.0);
|
||||
//! const pattern = new Float32Array(100).fill(0.1);
|
||||
//! btsp.one_shot_associate(pattern, 1.0);
|
||||
//! const output = btsp.forward(pattern);
|
||||
//!
|
||||
//! // Hyperdimensional computing
|
||||
//! const apple = Hypervector.random();
|
||||
//! const orange = Hypervector.random();
|
||||
//! const fruit = apple.bind(orange);
|
||||
//! const similarity = apple.similarity(orange);
|
||||
//!
|
||||
//! const memory = new HdcMemory();
|
||||
//! memory.store("apple", apple);
|
||||
//! const results = memory.retrieve(apple, 0.9);
|
||||
//!
|
||||
//! // Instant decisions with WTA
|
||||
//! const wta = new WTALayer(1000, 0.5, 0.8);
|
||||
//! const activations = new Float32Array(1000);
|
||||
//! const winner = wta.compete(activations);
|
||||
//!
|
||||
//! // Sparse coding with K-WTA
|
||||
//! const kwta = new KWTALayer(1000, 50);
|
||||
//! const winners = kwta.select(activations);
|
||||
//!
|
||||
//! // Attention bottleneck with Global Workspace
|
||||
//! const workspace = new GlobalWorkspace(7); // Miller's Law: 7 +/- 2
|
||||
//! const item = new WorkspaceItem(new Float32Array([1, 2, 3]), 0.9, 1, Date.now());
|
||||
//! workspace.broadcast(item);
|
||||
//! ```
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub mod btsp;
|
||||
pub mod hdc;
|
||||
pub mod workspace;
|
||||
pub mod wta;
|
||||
|
||||
// Re-export all public types
|
||||
pub use btsp::{BTSPAssociativeMemory, BTSPLayer, BTSPSynapse};
|
||||
pub use hdc::{HdcMemory, Hypervector};
|
||||
pub use workspace::{GlobalWorkspace, WorkspaceItem};
|
||||
pub use wta::{KWTALayer, WTALayer};
|
||||
|
||||
/// Initialize the WASM module with panic hook
|
||||
#[wasm_bindgen(start)]
|
||||
pub fn init() {
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
/// Get the version of the crate
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get information about available bio-inspired mechanisms
|
||||
#[wasm_bindgen]
|
||||
pub fn available_mechanisms() -> JsValue {
|
||||
let mechanisms = vec![
|
||||
(
|
||||
"btsp",
|
||||
"Behavioral Timescale Synaptic Plasticity - One-shot learning",
|
||||
),
|
||||
("hdc", "Hyperdimensional Computing - 10,000-bit vectors"),
|
||||
("wta", "Winner-Take-All - <1us decisions"),
|
||||
("kwta", "K-Winner-Take-All - Sparse distributed coding"),
|
||||
("workspace", "Global Workspace - 4-7 item attention"),
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&mechanisms).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
/// Get performance targets for each mechanism
|
||||
#[wasm_bindgen]
|
||||
pub fn performance_targets() -> JsValue {
|
||||
let targets = vec![
|
||||
("btsp_one_shot", "Immediate (no iteration)"),
|
||||
("hdc_bind", "<50ns"),
|
||||
("hdc_similarity", "<100ns"),
|
||||
("wta_compete", "<1us"),
|
||||
("kwta_select", "<10us (k=50, n=1000)"),
|
||||
("workspace_broadcast", "<10us"),
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&targets).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
/// Get biological references for the mechanisms
|
||||
#[wasm_bindgen]
|
||||
pub fn biological_references() -> JsValue {
|
||||
let refs = vec![
|
||||
("BTSP", "Bittner et al. 2017 - Hippocampal place fields"),
|
||||
(
|
||||
"HDC",
|
||||
"Kanerva 1988, Plate 2003 - Hyperdimensional computing",
|
||||
),
|
||||
("WTA", "Cortical microcircuits - Lateral inhibition"),
|
||||
(
|
||||
"Global Workspace",
|
||||
"Baars 1988, Dehaene 2014 - Consciousness",
|
||||
),
|
||||
];
|
||||
serde_wasm_bindgen::to_value(&refs).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_version() {
|
||||
let v = version();
|
||||
assert!(!v.is_empty());
|
||||
}
|
||||
}
|
||||
343
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/workspace.rs
vendored
Normal file
343
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/workspace.rs
vendored
Normal file
@@ -0,0 +1,343 @@
|
||||
//! Global Workspace WASM bindings
|
||||
//!
|
||||
//! Based on Global Workspace Theory (Baars, Dehaene):
|
||||
//! - 4-7 item capacity (Miller's law)
|
||||
//! - Broadcast/compete architecture
|
||||
//! - Relevance-based ignition
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Item in the global workspace
|
||||
#[wasm_bindgen]
|
||||
#[derive(Clone)]
|
||||
pub struct WorkspaceItem {
|
||||
content: Vec<f32>,
|
||||
salience: f32,
|
||||
source_module: u16,
|
||||
timestamp: u64,
|
||||
decay_rate: f32,
|
||||
lifetime: u64,
|
||||
id: u64,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WorkspaceItem {
|
||||
/// Create a new workspace item
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(
|
||||
content: &[f32],
|
||||
salience: f32,
|
||||
source_module: u16,
|
||||
timestamp: u64,
|
||||
) -> WorkspaceItem {
|
||||
Self {
|
||||
content: content.to_vec(),
|
||||
salience,
|
||||
source_module,
|
||||
timestamp,
|
||||
decay_rate: 0.95,
|
||||
lifetime: 1000,
|
||||
id: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom decay and lifetime
|
||||
#[wasm_bindgen]
|
||||
pub fn with_decay(
|
||||
content: &[f32],
|
||||
salience: f32,
|
||||
source_module: u16,
|
||||
timestamp: u64,
|
||||
decay_rate: f32,
|
||||
lifetime: u64,
|
||||
) -> WorkspaceItem {
|
||||
Self {
|
||||
content: content.to_vec(),
|
||||
salience,
|
||||
source_module,
|
||||
timestamp,
|
||||
decay_rate,
|
||||
lifetime,
|
||||
id: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get content as Float32Array
|
||||
#[wasm_bindgen]
|
||||
pub fn get_content(&self) -> js_sys::Float32Array {
|
||||
js_sys::Float32Array::from(self.content.as_slice())
|
||||
}
|
||||
|
||||
/// Get salience
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn salience(&self) -> f32 {
|
||||
self.salience
|
||||
}
|
||||
|
||||
/// Get source module
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn source_module(&self) -> u16 {
|
||||
self.source_module
|
||||
}
|
||||
|
||||
/// Get timestamp
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn timestamp(&self) -> u64 {
|
||||
self.timestamp
|
||||
}
|
||||
|
||||
/// Get ID
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn id(&self) -> u64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Compute content magnitude (L2 norm)
|
||||
#[wasm_bindgen]
|
||||
pub fn magnitude(&self) -> f32 {
|
||||
self.content.iter().map(|x| x * x).sum::<f32>().sqrt()
|
||||
}
|
||||
|
||||
/// Update salience
|
||||
#[wasm_bindgen]
|
||||
pub fn update_salience(&mut self, new_salience: f32) {
|
||||
self.salience = new_salience.max(0.0);
|
||||
}
|
||||
|
||||
/// Apply temporal decay
|
||||
#[wasm_bindgen]
|
||||
pub fn apply_decay(&mut self, dt: f32) {
|
||||
self.salience *= self.decay_rate.powf(dt);
|
||||
}
|
||||
|
||||
/// Check if expired
|
||||
#[wasm_bindgen]
|
||||
pub fn is_expired(&self, current_time: u64) -> bool {
|
||||
current_time.saturating_sub(self.timestamp) > self.lifetime
|
||||
}
|
||||
}
|
||||
|
||||
/// Global workspace with limited capacity and competitive dynamics
|
||||
///
|
||||
/// Implements attention and conscious access mechanisms based on
|
||||
/// Global Workspace Theory.
|
||||
#[wasm_bindgen]
|
||||
pub struct GlobalWorkspace {
|
||||
buffer: Vec<WorkspaceItem>,
|
||||
capacity: usize,
|
||||
salience_threshold: f32,
|
||||
timestamp: u64,
|
||||
salience_decay: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl GlobalWorkspace {
|
||||
/// Create a new global workspace
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `capacity` - Maximum number of representations (typically 4-7)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(capacity: usize) -> GlobalWorkspace {
|
||||
Self {
|
||||
buffer: Vec::with_capacity(capacity),
|
||||
capacity,
|
||||
salience_threshold: 0.1,
|
||||
timestamp: 0,
|
||||
salience_decay: 0.95,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom threshold
|
||||
#[wasm_bindgen]
|
||||
pub fn with_threshold(capacity: usize, threshold: f32) -> GlobalWorkspace {
|
||||
Self {
|
||||
buffer: Vec::with_capacity(capacity),
|
||||
capacity,
|
||||
salience_threshold: threshold,
|
||||
timestamp: 0,
|
||||
salience_decay: 0.95,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set salience decay rate
|
||||
#[wasm_bindgen]
|
||||
pub fn set_decay_rate(&mut self, decay: f32) {
|
||||
self.salience_decay = decay.clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
/// Broadcast a representation to the workspace
|
||||
///
|
||||
/// Returns true if accepted, false if rejected.
|
||||
#[wasm_bindgen]
|
||||
pub fn broadcast(&mut self, item: WorkspaceItem) -> bool {
|
||||
self.timestamp += 1;
|
||||
let mut item = item;
|
||||
item.timestamp = self.timestamp;
|
||||
|
||||
// Reject if below threshold
|
||||
if item.salience < self.salience_threshold {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If workspace not full, add directly
|
||||
if self.buffer.len() < self.capacity {
|
||||
self.buffer.push(item);
|
||||
return true;
|
||||
}
|
||||
|
||||
// If full, compete with weakest item
|
||||
if let Some(min_idx) = self.find_weakest() {
|
||||
if self.buffer[min_idx].salience < item.salience {
|
||||
self.buffer.swap_remove(min_idx);
|
||||
self.buffer.push(item);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Run competitive dynamics (salience decay and pruning)
|
||||
#[wasm_bindgen]
|
||||
pub fn compete(&mut self) {
|
||||
// Apply salience decay
|
||||
for item in self.buffer.iter_mut() {
|
||||
item.salience *= self.salience_decay;
|
||||
}
|
||||
|
||||
// Remove items below threshold
|
||||
self.buffer
|
||||
.retain(|item| item.salience >= self.salience_threshold);
|
||||
}
|
||||
|
||||
/// Retrieve all current representations as JSON
|
||||
#[wasm_bindgen]
|
||||
pub fn retrieve(&self) -> JsValue {
|
||||
let items: Vec<_> = self
|
||||
.buffer
|
||||
.iter()
|
||||
.map(|item| {
|
||||
serde_json::json!({
|
||||
"content": item.content,
|
||||
"salience": item.salience,
|
||||
"source_module": item.source_module,
|
||||
"timestamp": item.timestamp,
|
||||
"id": item.id
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_wasm_bindgen::to_value(&items).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
/// Retrieve top-k most salient representations
|
||||
#[wasm_bindgen]
|
||||
pub fn retrieve_top_k(&self, k: usize) -> JsValue {
|
||||
let mut items: Vec<_> = self.buffer.iter().collect();
|
||||
items.sort_by(|a, b| {
|
||||
b.salience
|
||||
.partial_cmp(&a.salience)
|
||||
.unwrap_or(std::cmp::Ordering::Less)
|
||||
});
|
||||
items.truncate(k);
|
||||
|
||||
let result: Vec<_> = items
|
||||
.iter()
|
||||
.map(|item| {
|
||||
serde_json::json!({
|
||||
"content": item.content,
|
||||
"salience": item.salience,
|
||||
"source_module": item.source_module,
|
||||
"timestamp": item.timestamp,
|
||||
"id": item.id
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_wasm_bindgen::to_value(&result).unwrap_or(JsValue::NULL)
|
||||
}
|
||||
|
||||
/// Get most salient item
|
||||
#[wasm_bindgen]
|
||||
pub fn most_salient(&self) -> Option<WorkspaceItem> {
|
||||
self.buffer
|
||||
.iter()
|
||||
.max_by(|a, b| {
|
||||
a.salience
|
||||
.partial_cmp(&b.salience)
|
||||
.unwrap_or(std::cmp::Ordering::Less)
|
||||
})
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Check if workspace is at capacity
|
||||
#[wasm_bindgen]
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.buffer.len() >= self.capacity
|
||||
}
|
||||
|
||||
/// Check if workspace is empty
|
||||
#[wasm_bindgen]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Get current number of representations
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
/// Get workspace capacity
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
|
||||
/// Clear all representations
|
||||
#[wasm_bindgen]
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
/// Get average salience
|
||||
#[wasm_bindgen]
|
||||
pub fn average_salience(&self) -> f32 {
|
||||
if self.buffer.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let sum: f32 = self.buffer.iter().map(|r| r.salience).sum();
|
||||
sum / self.buffer.len() as f32
|
||||
}
|
||||
|
||||
/// Get available slots
|
||||
#[wasm_bindgen]
|
||||
pub fn available_slots(&self) -> usize {
|
||||
self.capacity.saturating_sub(self.buffer.len())
|
||||
}
|
||||
|
||||
/// Get current load (0.0 to 1.0)
|
||||
#[wasm_bindgen]
|
||||
pub fn current_load(&self) -> f32 {
|
||||
self.buffer.len() as f32 / self.capacity as f32
|
||||
}
|
||||
|
||||
/// Find index of weakest representation
|
||||
fn find_weakest(&self) -> Option<usize> {
|
||||
if self.buffer.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut min_idx = 0;
|
||||
let mut min_salience = self.buffer[0].salience;
|
||||
|
||||
for (i, item) in self.buffer.iter().enumerate().skip(1) {
|
||||
if item.salience < min_salience {
|
||||
min_salience = item.salience;
|
||||
min_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
Some(min_idx)
|
||||
}
|
||||
}
|
||||
334
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/wta.rs
vendored
Normal file
334
vendor/ruvector/crates/ruvector-nervous-system-wasm/src/wta.rs
vendored
Normal file
@@ -0,0 +1,334 @@
|
||||
//! Winner-Take-All (WTA) WASM bindings
|
||||
//!
|
||||
//! Instant decisions via neural competition:
|
||||
//! - Single winner: <1us for 1000 neurons
|
||||
//! - K-WTA: <10us for k=50
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Winner-Take-All competition layer
|
||||
///
|
||||
/// Implements neural competition where the highest-activation neuron
|
||||
/// wins and suppresses others through lateral inhibition.
|
||||
///
|
||||
/// # Performance
|
||||
/// - <1us winner selection for 1000 neurons
|
||||
#[wasm_bindgen]
|
||||
pub struct WTALayer {
|
||||
membranes: Vec<f32>,
|
||||
threshold: f32,
|
||||
inhibition_strength: f32,
|
||||
refractory_period: u32,
|
||||
refractory_counters: Vec<u32>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl WTALayer {
|
||||
/// Create a new WTA layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `size` - Number of competing neurons
|
||||
/// * `threshold` - Activation threshold for firing
|
||||
/// * `inhibition` - Lateral inhibition strength (0.0-1.0)
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(size: usize, threshold: f32, inhibition: f32) -> Result<WTALayer, JsValue> {
|
||||
if size == 0 {
|
||||
return Err(JsValue::from_str("Size must be > 0"));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
membranes: vec![0.0; size],
|
||||
threshold,
|
||||
inhibition_strength: inhibition.clamp(0.0, 1.0),
|
||||
refractory_period: 10,
|
||||
refractory_counters: vec![0; size],
|
||||
})
|
||||
}
|
||||
|
||||
/// Run winner-take-all competition
|
||||
///
|
||||
/// Returns the index of the winning neuron, or -1 if no neuron exceeds threshold.
|
||||
#[wasm_bindgen]
|
||||
pub fn compete(&mut self, inputs: &[f32]) -> Result<i32, JsValue> {
|
||||
if inputs.len() != self.membranes.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.membranes.len(),
|
||||
inputs.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Single-pass: update membrane potentials and find max
|
||||
let mut best_idx: Option<usize> = None;
|
||||
let mut best_val = f32::NEG_INFINITY;
|
||||
|
||||
for (i, &input) in inputs.iter().enumerate() {
|
||||
if self.refractory_counters[i] == 0 {
|
||||
self.membranes[i] = input;
|
||||
if input > best_val {
|
||||
best_val = input;
|
||||
best_idx = Some(i);
|
||||
}
|
||||
} else {
|
||||
self.refractory_counters[i] = self.refractory_counters[i].saturating_sub(1);
|
||||
}
|
||||
}
|
||||
|
||||
let winner_idx = match best_idx {
|
||||
Some(idx) => idx,
|
||||
None => return Ok(-1),
|
||||
};
|
||||
|
||||
// Check if winner exceeds threshold
|
||||
if best_val < self.threshold {
|
||||
return Ok(-1);
|
||||
}
|
||||
|
||||
// Apply lateral inhibition
|
||||
for (i, membrane) in self.membranes.iter_mut().enumerate() {
|
||||
if i != winner_idx {
|
||||
*membrane *= 1.0 - self.inhibition_strength;
|
||||
}
|
||||
}
|
||||
|
||||
// Set refractory period for winner
|
||||
self.refractory_counters[winner_idx] = self.refractory_period;
|
||||
|
||||
Ok(winner_idx as i32)
|
||||
}
|
||||
|
||||
/// Soft competition with normalized activations
|
||||
///
|
||||
/// Returns activation levels for all neurons after softmax-like normalization.
|
||||
#[wasm_bindgen]
|
||||
pub fn compete_soft(&mut self, inputs: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
|
||||
if inputs.len() != self.membranes.len() {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.membranes.len(),
|
||||
inputs.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Update membrane potentials
|
||||
self.membranes.copy_from_slice(inputs);
|
||||
|
||||
// Find max for numerical stability
|
||||
let max_val = self
|
||||
.membranes
|
||||
.iter()
|
||||
.copied()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Softmax with temperature
|
||||
let temperature = 1.0 / (1.0 + self.inhibition_strength);
|
||||
let mut activations: Vec<f32> = self
|
||||
.membranes
|
||||
.iter()
|
||||
.map(|&x| ((x - max_val) / temperature).exp())
|
||||
.collect();
|
||||
|
||||
// Normalize
|
||||
let sum: f32 = activations.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for a in &mut activations {
|
||||
*a /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(js_sys::Float32Array::from(activations.as_slice()))
|
||||
}
|
||||
|
||||
/// Reset layer state
|
||||
#[wasm_bindgen]
|
||||
pub fn reset(&mut self) {
|
||||
self.membranes.fill(0.0);
|
||||
self.refractory_counters.fill(0);
|
||||
}
|
||||
|
||||
/// Get current membrane potentials
|
||||
#[wasm_bindgen]
|
||||
pub fn get_membranes(&self) -> js_sys::Float32Array {
|
||||
js_sys::Float32Array::from(self.membranes.as_slice())
|
||||
}
|
||||
|
||||
/// Set refractory period
|
||||
#[wasm_bindgen]
|
||||
pub fn set_refractory_period(&mut self, period: u32) {
|
||||
self.refractory_period = period;
|
||||
}
|
||||
|
||||
/// Get layer size
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.membranes.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// K-Winner-Take-All layer for sparse distributed coding
|
||||
///
|
||||
/// Selects top-k neurons with highest activations.
|
||||
///
|
||||
/// # Performance
|
||||
/// - O(n + k log k) using partial sorting
|
||||
/// - <10us for 1000 neurons, k=50
|
||||
#[wasm_bindgen]
|
||||
pub struct KWTALayer {
|
||||
size: usize,
|
||||
k: usize,
|
||||
threshold: Option<f32>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl KWTALayer {
|
||||
/// Create a new K-WTA layer
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `size` - Total number of neurons
|
||||
/// * `k` - Number of winners to select
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(size: usize, k: usize) -> Result<KWTALayer, JsValue> {
|
||||
if k == 0 {
|
||||
return Err(JsValue::from_str("k must be > 0"));
|
||||
}
|
||||
if k > size {
|
||||
return Err(JsValue::from_str("k cannot exceed layer size"));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
size,
|
||||
k,
|
||||
threshold: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set activation threshold
|
||||
#[wasm_bindgen]
|
||||
pub fn with_threshold(&mut self, threshold: f32) {
|
||||
self.threshold = Some(threshold);
|
||||
}
|
||||
|
||||
/// Select top-k neurons
|
||||
///
|
||||
/// Returns indices of k neurons with highest activations, sorted descending.
|
||||
#[wasm_bindgen]
|
||||
pub fn select(&self, inputs: &[f32]) -> Result<js_sys::Uint32Array, JsValue> {
|
||||
if inputs.len() != self.size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.size,
|
||||
inputs.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Create (index, value) pairs
|
||||
let mut indexed: Vec<(usize, f32)> =
|
||||
inputs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
|
||||
// Filter by threshold if set
|
||||
if let Some(threshold) = self.threshold {
|
||||
indexed.retain(|(_, v)| *v >= threshold);
|
||||
}
|
||||
|
||||
if indexed.is_empty() {
|
||||
return Ok(js_sys::Uint32Array::new_with_length(0));
|
||||
}
|
||||
|
||||
// Partial sort to get top-k
|
||||
let k_actual = self.k.min(indexed.len());
|
||||
indexed.select_nth_unstable_by(k_actual - 1, |a, b| {
|
||||
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Take top k and sort descending
|
||||
let mut winners: Vec<(usize, f32)> = indexed[..k_actual].to_vec();
|
||||
winners.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Return only indices as u32
|
||||
let indices: Vec<u32> = winners.into_iter().map(|(i, _)| i as u32).collect();
|
||||
Ok(js_sys::Uint32Array::from(indices.as_slice()))
|
||||
}
|
||||
|
||||
/// Select top-k neurons with their activation values
|
||||
///
|
||||
/// Returns array of [index, value] pairs.
|
||||
#[wasm_bindgen]
|
||||
pub fn select_with_values(&self, inputs: &[f32]) -> Result<JsValue, JsValue> {
|
||||
if inputs.len() != self.size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.size,
|
||||
inputs.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut indexed: Vec<(usize, f32)> =
|
||||
inputs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
|
||||
if let Some(threshold) = self.threshold {
|
||||
indexed.retain(|(_, v)| *v >= threshold);
|
||||
}
|
||||
|
||||
if indexed.is_empty() {
|
||||
return serde_wasm_bindgen::to_value(&Vec::<(usize, f32)>::new())
|
||||
.map_err(|e| JsValue::from_str(&e.to_string()));
|
||||
}
|
||||
|
||||
let k_actual = self.k.min(indexed.len());
|
||||
indexed.select_nth_unstable_by(k_actual - 1, |a, b| {
|
||||
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut winners: Vec<(usize, f32)> = indexed[..k_actual].to_vec();
|
||||
winners.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
serde_wasm_bindgen::to_value(&winners).map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
|
||||
/// Create sparse activation vector (only top-k preserved)
|
||||
#[wasm_bindgen]
|
||||
pub fn sparse_activations(&self, inputs: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
|
||||
if inputs.len() != self.size {
|
||||
return Err(JsValue::from_str(&format!(
|
||||
"Input size mismatch: expected {}, got {}",
|
||||
self.size,
|
||||
inputs.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut indexed: Vec<(usize, f32)> =
|
||||
inputs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
||||
|
||||
if let Some(threshold) = self.threshold {
|
||||
indexed.retain(|(_, v)| *v >= threshold);
|
||||
}
|
||||
|
||||
let mut sparse = vec![0.0; self.size];
|
||||
|
||||
if !indexed.is_empty() {
|
||||
let k_actual = self.k.min(indexed.len());
|
||||
indexed.select_nth_unstable_by(k_actual - 1, |a, b| {
|
||||
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for (idx, value) in &indexed[..k_actual] {
|
||||
sparse[*idx] = *value;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(js_sys::Float32Array::from(sparse.as_slice()))
|
||||
}
|
||||
|
||||
/// Get number of winners
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
|
||||
/// Get layer size
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user