Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
960
vendor/ruvector/examples/edge-net/src/learning/mod.rs
vendored
Normal file
960
vendor/ruvector/examples/edge-net/src/learning/mod.rs
vendored
Normal file
@@ -0,0 +1,960 @@
|
||||
//! Learning and Attention Module for Edge-Net
|
||||
//!
|
||||
//! Integrates RuVector's self-learning intelligence and attention mechanisms
|
||||
//! for distributed compute optimization. This module enables edge nodes to:
|
||||
//!
|
||||
//! - **Learn patterns** from task execution trajectories
|
||||
//! - **Store knowledge** in a ReasoningBank for retrieval
|
||||
//! - **Route tasks** using multi-head attention
|
||||
//! - **Optimize energy** with spike-driven attention (87x more efficient)
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌─────────────────────────────────────────────────────┐
|
||||
//! │ Learning Intelligence │
|
||||
//! ├─────────────────────────────────────────────────────┤
|
||||
//! │ ┌──────────────┐ ┌──────────────┐ ┌───────────┐ │
|
||||
//! │ │ ReasoningBank│ │ Trajectory │ │ Pattern │ │
|
||||
//! │ │ Storage │◄─┤ Tracker │──┤ Extractor │ │
|
||||
//! │ └──────────────┘ └──────────────┘ └───────────┘ │
|
||||
//! ├─────────────────────────────────────────────────────┤
|
||||
//! │ ┌──────────────┐ ┌──────────────┐ │
|
||||
//! │ │ Multi-Head │ │ Spike-Driven │ │
|
||||
//! │ │ Attention │ │ Attention │ │
|
||||
//! │ │ (Task Route) │ │ (87x Energy) │ │
|
||||
//! │ └──────────────┘ └──────────────┘ │
|
||||
//! └─────────────────────────────────────────────────────┘
|
||||
//! ```
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
// ============================================================================
|
||||
// Learned Patterns
|
||||
// ============================================================================
|
||||
|
||||
/// A learned pattern from task execution
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct LearnedPattern {
|
||||
/// Centroid vector representing the pattern
|
||||
pub centroid: Vec<f32>,
|
||||
/// Optimal task allocation score
|
||||
pub optimal_allocation: f32,
|
||||
/// Optimal energy budget for this pattern
|
||||
pub optimal_energy: u64,
|
||||
/// Confidence score (0.0 - 1.0)
|
||||
pub confidence: f64,
|
||||
/// Number of samples in this pattern
|
||||
pub sample_count: usize,
|
||||
/// Average latency in milliseconds
|
||||
pub avg_latency_ms: f64,
|
||||
/// Average success rate
|
||||
pub avg_success_rate: Option<f64>,
|
||||
}
|
||||
|
||||
impl LearnedPattern {
|
||||
/// Create a new learned pattern
|
||||
pub fn new(
|
||||
centroid: Vec<f32>,
|
||||
optimal_allocation: f32,
|
||||
optimal_energy: u64,
|
||||
confidence: f64,
|
||||
sample_count: usize,
|
||||
avg_latency_ms: f64,
|
||||
avg_success_rate: Option<f64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
centroid,
|
||||
optimal_allocation,
|
||||
optimal_energy,
|
||||
confidence,
|
||||
sample_count,
|
||||
avg_latency_ms,
|
||||
avg_success_rate,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate cosine similarity to a query vector
|
||||
pub fn similarity(&self, query: &[f32]) -> f64 {
|
||||
if query.len() != self.centroid.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dot: f32 = query.iter().zip(&self.centroid).map(|(a, b)| a * b).sum();
|
||||
let norm_q: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_c: f32 = self.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_q == 0.0 || norm_c == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(dot / (norm_q * norm_c)) as f64
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Task Trajectory
|
||||
// ============================================================================
|
||||
|
||||
/// A single task execution trajectory
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TaskTrajectory {
|
||||
/// Task feature vector
|
||||
pub task_vector: Vec<f32>,
|
||||
/// Execution latency in milliseconds
|
||||
pub latency_ms: u64,
|
||||
/// Energy consumed (rUv)
|
||||
pub energy_spent: u64,
|
||||
/// Energy earned (rUv)
|
||||
pub energy_earned: u64,
|
||||
/// Task success flag
|
||||
pub success: bool,
|
||||
/// Node that executed the task
|
||||
pub executor_id: String,
|
||||
/// Timestamp (ms since epoch)
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
impl TaskTrajectory {
|
||||
/// Create a new task trajectory
|
||||
pub fn new(
|
||||
task_vector: Vec<f32>,
|
||||
latency_ms: u64,
|
||||
energy_spent: u64,
|
||||
energy_earned: u64,
|
||||
success: bool,
|
||||
executor_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_vector,
|
||||
latency_ms,
|
||||
energy_spent,
|
||||
energy_earned,
|
||||
success,
|
||||
executor_id,
|
||||
timestamp: js_sys::Date::now() as u64,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate efficiency ratio (earned/spent)
|
||||
pub fn efficiency(&self) -> f64 {
|
||||
if self.energy_spent == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.energy_earned as f64 / self.energy_spent as f64
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trajectory Tracker
|
||||
// ============================================================================
|
||||
|
||||
/// Ring buffer tracker for task trajectories
|
||||
#[wasm_bindgen]
|
||||
pub struct TrajectoryTracker {
|
||||
/// Ring buffer of trajectories
|
||||
trajectories: RwLock<Vec<TaskTrajectory>>,
|
||||
/// Maximum size
|
||||
max_size: usize,
|
||||
/// Current write position
|
||||
write_pos: RwLock<usize>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl TrajectoryTracker {
|
||||
/// Create a new trajectory tracker
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(max_size: usize) -> Self {
|
||||
Self {
|
||||
trajectories: RwLock::new(Vec::with_capacity(max_size)),
|
||||
max_size,
|
||||
write_pos: RwLock::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new trajectory
|
||||
#[wasm_bindgen]
|
||||
pub fn record(&self, trajectory_json: &str) -> bool {
|
||||
let trajectory: TaskTrajectory = match serde_json::from_str(trajectory_json) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
let mut trajectories = self.trajectories.write().unwrap();
|
||||
let mut pos = self.write_pos.write().unwrap();
|
||||
|
||||
if trajectories.len() < self.max_size {
|
||||
trajectories.push(trajectory);
|
||||
} else {
|
||||
trajectories[*pos] = trajectory;
|
||||
}
|
||||
|
||||
*pos = (*pos + 1) % self.max_size;
|
||||
true
|
||||
}
|
||||
|
||||
/// Get statistics as JSON
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats(&self) -> String {
|
||||
let trajectories = self.trajectories.read().unwrap();
|
||||
|
||||
if trajectories.is_empty() {
|
||||
return r#"{"total":0}"#.to_string();
|
||||
}
|
||||
|
||||
let total = trajectories.len();
|
||||
let successful = trajectories.iter().filter(|t| t.success).count();
|
||||
let avg_latency = trajectories.iter().map(|t| t.latency_ms).sum::<u64>() as f64 / total as f64;
|
||||
let avg_efficiency = trajectories.iter().map(|t| t.efficiency()).sum::<f64>() / total as f64;
|
||||
|
||||
format!(
|
||||
r#"{{"total":{},"successful":{},"success_rate":{:.4},"avg_latency_ms":{:.2},"avg_efficiency":{:.4}}}"#,
|
||||
total,
|
||||
successful,
|
||||
successful as f64 / total as f64,
|
||||
avg_latency,
|
||||
avg_efficiency
|
||||
)
|
||||
}
|
||||
|
||||
/// Get count of trajectories
|
||||
#[wasm_bindgen]
|
||||
pub fn count(&self) -> usize {
|
||||
self.trajectories.read().unwrap().len()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Reasoning Bank
|
||||
// ============================================================================
|
||||
|
||||
/// Pattern entry with usage tracking
|
||||
#[derive(Clone)]
|
||||
struct PatternEntry {
|
||||
pattern: LearnedPattern,
|
||||
usage_count: usize,
|
||||
last_used: u64,
|
||||
}
|
||||
|
||||
/// Spatial bucket for fast approximate nearest neighbor search
|
||||
struct SpatialBucket {
|
||||
pattern_ids: Vec<usize>,
|
||||
}
|
||||
|
||||
/// ReasoningBank for storing and retrieving learned patterns
|
||||
/// Optimized with spatial indexing for O(1) approximate lookups
|
||||
#[wasm_bindgen]
|
||||
pub struct ReasoningBank {
|
||||
/// Stored patterns indexed by ID
|
||||
patterns: RwLock<FxHashMap<usize, PatternEntry>>,
|
||||
/// Next pattern ID
|
||||
next_id: RwLock<usize>,
|
||||
/// Spatial index for fast approximate nearest neighbor
|
||||
/// Maps quantized vector hash to pattern IDs
|
||||
spatial_index: RwLock<FxHashMap<u64, SpatialBucket>>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ReasoningBank {
|
||||
/// Create a new ReasoningBank
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> ReasoningBank {
|
||||
ReasoningBank {
|
||||
patterns: RwLock::new(FxHashMap::default()),
|
||||
next_id: RwLock::new(0),
|
||||
spatial_index: RwLock::new(FxHashMap::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash a vector into a spatial bucket (locality-sensitive hashing)
|
||||
fn spatial_hash(vector: &[f32]) -> u64 {
|
||||
// Simple grid-based quantization for fast approximate matching
|
||||
// Quantize each dimension to 8 levels (3 bits)
|
||||
let mut hash = 0u64;
|
||||
for (i, &val) in vector.iter().take(20).enumerate() {
|
||||
// Normalize to [0, 7] range
|
||||
let quantized = ((val + 1.0) * 3.5).clamp(0.0, 7.0) as u64;
|
||||
hash |= quantized << (i * 3);
|
||||
}
|
||||
hash
|
||||
}
|
||||
|
||||
/// Store a new pattern (JSON format)
|
||||
#[wasm_bindgen]
|
||||
pub fn store(&self, pattern_json: &str) -> i32 {
|
||||
let pattern: LearnedPattern = match serde_json::from_str(pattern_json) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return -1,
|
||||
};
|
||||
|
||||
// Compute spatial hash for indexing
|
||||
let hash = Self::spatial_hash(&pattern.centroid);
|
||||
|
||||
let mut next_id = self.next_id.write().unwrap();
|
||||
let id = *next_id;
|
||||
*next_id += 1;
|
||||
|
||||
let entry = PatternEntry {
|
||||
pattern,
|
||||
usage_count: 0,
|
||||
last_used: js_sys::Date::now() as u64,
|
||||
};
|
||||
|
||||
self.patterns.write().unwrap().insert(id, entry);
|
||||
|
||||
// Add to spatial index
|
||||
let mut index = self.spatial_index.write().unwrap();
|
||||
index.entry(hash)
|
||||
.or_insert_with(|| SpatialBucket { pattern_ids: Vec::with_capacity(10) })
|
||||
.pattern_ids.push(id);
|
||||
|
||||
id as i32
|
||||
}
|
||||
|
||||
/// Lookup most similar patterns (OPTIMIZED with spatial indexing)
|
||||
#[wasm_bindgen]
|
||||
pub fn lookup(&self, query_json: &str, k: usize) -> String {
|
||||
let query: Vec<f32> = match serde_json::from_str(query_json) {
|
||||
Ok(q) => q,
|
||||
Err(_) => return "[]".to_string(),
|
||||
};
|
||||
|
||||
let query_hash = Self::spatial_hash(&query);
|
||||
let now = js_sys::Date::now() as u64;
|
||||
|
||||
// Step 1: Fast approximate search using spatial index
|
||||
let index = self.spatial_index.read().unwrap();
|
||||
let mut candidate_ids = Vec::with_capacity(k * 3); // Pre-allocate
|
||||
|
||||
// Get patterns from same bucket
|
||||
if let Some(bucket) = index.get(&query_hash) {
|
||||
candidate_ids.extend_from_slice(&bucket.pattern_ids);
|
||||
}
|
||||
|
||||
// Check neighboring buckets (increase recall)
|
||||
// Flip 1-2 bits in hash to find nearby buckets
|
||||
for bit_flip in 0..6 {
|
||||
let neighbor_hash = query_hash ^ (1u64 << (bit_flip * 3));
|
||||
if let Some(bucket) = index.get(&neighbor_hash) {
|
||||
candidate_ids.extend_from_slice(&bucket.pattern_ids);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: if too few candidates, scan more buckets
|
||||
if candidate_ids.len() < k * 2 {
|
||||
for bucket in index.values().take(10) {
|
||||
candidate_ids.extend_from_slice(&bucket.pattern_ids);
|
||||
if candidate_ids.len() >= k * 3 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Exact similarity computation only for candidates
|
||||
let mut patterns = self.patterns.write().unwrap();
|
||||
let mut similarities = Vec::with_capacity(candidate_ids.len());
|
||||
|
||||
for &id in &candidate_ids {
|
||||
if let Some(entry) = patterns.get_mut(&id) {
|
||||
let similarity = entry.pattern.similarity(&query);
|
||||
entry.usage_count += 1;
|
||||
entry.last_used = now;
|
||||
similarities.push((id, entry.pattern.clone(), similarity));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by weighted score (similarity * confidence)
|
||||
similarities.sort_unstable_by(|a, b| {
|
||||
let score_a = a.2 * a.1.confidence;
|
||||
let score_b = b.2 * b.1.confidence;
|
||||
score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
similarities.truncate(k);
|
||||
|
||||
// Pre-allocate string with estimated capacity
|
||||
let mut result = String::with_capacity(k * 120);
|
||||
result.push('[');
|
||||
|
||||
for (i, (id, pattern, sim)) in similarities.iter().enumerate() {
|
||||
if i > 0 {
|
||||
result.push(',');
|
||||
}
|
||||
use std::fmt::Write;
|
||||
let _ = write!(
|
||||
result,
|
||||
r#"{{"id":{},"similarity":{:.4},"confidence":{:.4},"optimal_allocation":{:.4},"optimal_energy":{}}}"#,
|
||||
id, sim, pattern.confidence, pattern.optimal_allocation, pattern.optimal_energy
|
||||
);
|
||||
}
|
||||
|
||||
result.push(']');
|
||||
result
|
||||
}
|
||||
|
||||
/// Prune low-quality patterns
|
||||
#[wasm_bindgen]
|
||||
pub fn prune(&self, min_usage: usize, min_confidence: f64) -> usize {
|
||||
let mut patterns = self.patterns.write().unwrap();
|
||||
let before = patterns.len();
|
||||
|
||||
patterns.retain(|_, entry| {
|
||||
entry.usage_count >= min_usage && entry.pattern.confidence >= min_confidence
|
||||
});
|
||||
|
||||
before - patterns.len()
|
||||
}
|
||||
|
||||
/// Get total pattern count
|
||||
#[wasm_bindgen]
|
||||
pub fn count(&self) -> usize {
|
||||
self.patterns.read().unwrap().len()
|
||||
}
|
||||
|
||||
/// Get bank statistics
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats(&self) -> String {
|
||||
let patterns = self.patterns.read().unwrap();
|
||||
|
||||
if patterns.is_empty() {
|
||||
return r#"{"total":0}"#.to_string();
|
||||
}
|
||||
|
||||
let total = patterns.len();
|
||||
let total_samples: usize = patterns.values().map(|e| e.pattern.sample_count).sum();
|
||||
let avg_confidence: f64 = patterns.values().map(|e| e.pattern.confidence).sum::<f64>() / total as f64;
|
||||
let total_usage: usize = patterns.values().map(|e| e.usage_count).sum();
|
||||
|
||||
format!(
|
||||
r#"{{"total_patterns":{},"total_samples":{},"avg_confidence":{:.4},"total_usage":{}}}"#,
|
||||
total, total_samples, avg_confidence, total_usage
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ReasoningBank {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Spike Train for Energy-Efficient Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Spike train representation for temporal coding
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct SpikeTrain {
|
||||
/// Spike times within temporal window
|
||||
pub times: Vec<u8>,
|
||||
/// Spike polarities: +1 for positive, -1 for negative
|
||||
pub polarities: Vec<i8>,
|
||||
}
|
||||
|
||||
impl SpikeTrain {
|
||||
/// Create empty spike train
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
times: Vec::new(),
|
||||
polarities: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create spike train with pre-allocated capacity
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
times: Vec::with_capacity(capacity),
|
||||
polarities: Vec::with_capacity(capacity),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a spike at given time with polarity
|
||||
pub fn add_spike(&mut self, time: u8, polarity: i8) {
|
||||
self.times.push(time);
|
||||
self.polarities.push(polarity);
|
||||
}
|
||||
|
||||
/// Number of spikes
|
||||
pub fn len(&self) -> usize {
|
||||
self.times.len()
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.times.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Spike-Driven Attention
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for spike-driven attention
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SpikeDrivenConfig {
|
||||
/// Spike threshold in Q15 fixed-point
|
||||
pub spike_threshold_q15: u16,
|
||||
/// Number of temporal coding steps
|
||||
pub temporal_coding_steps: u8,
|
||||
/// Use binary quantization
|
||||
pub binary_qkv: bool,
|
||||
/// Refractory period after spike
|
||||
pub refractory_period: u8,
|
||||
}
|
||||
|
||||
impl Default for SpikeDrivenConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
spike_threshold_q15: 16384, // 0.5 in Q15
|
||||
temporal_coding_steps: 8,
|
||||
binary_qkv: true,
|
||||
refractory_period: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Spike-driven attention for energy-efficient compute (87x savings)
|
||||
#[wasm_bindgen]
|
||||
pub struct SpikeDrivenAttention {
|
||||
config: SpikeDrivenConfig,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SpikeDrivenAttention {
|
||||
/// Create new spike-driven attention with default config
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: SpikeDrivenConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom parameters
|
||||
#[wasm_bindgen(js_name = withConfig)]
|
||||
pub fn with_config(threshold: u16, steps: u8, refractory: u8) -> Self {
|
||||
Self {
|
||||
config: SpikeDrivenConfig {
|
||||
spike_threshold_q15: threshold,
|
||||
temporal_coding_steps: steps,
|
||||
binary_qkv: true,
|
||||
refractory_period: refractory,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate energy savings ratio compared to standard attention
|
||||
#[wasm_bindgen(js_name = energyRatio)]
|
||||
pub fn energy_ratio(&self, seq_len: usize, hidden_dim: usize) -> f32 {
|
||||
if seq_len == 0 || hidden_dim == 0 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
// Standard attention operations (multiplications)
|
||||
let standard_mults = 2 * seq_len * seq_len * hidden_dim;
|
||||
|
||||
// Spike-driven operations (additions only)
|
||||
let avg_spikes_per_neuron = (self.config.temporal_coding_steps as f32) * 0.3;
|
||||
let spike_adds = (seq_len as f32) * avg_spikes_per_neuron * (hidden_dim as f32);
|
||||
|
||||
// Energy ratio (multiplication ~3.7x more expensive than addition)
|
||||
let mult_energy_factor = 3.7;
|
||||
|
||||
let standard_energy = (standard_mults as f32) * mult_energy_factor;
|
||||
let spike_energy = spike_adds;
|
||||
|
||||
if spike_energy == 0.0 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
standard_energy / spike_energy
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SpikeDrivenAttention {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SpikeDrivenAttention {
|
||||
/// Encode values to spike trains using rate coding (OPTIMIZED with pre-allocation)
|
||||
pub fn encode_spikes(&self, values: &[i8]) -> Vec<SpikeTrain> {
|
||||
let steps = self.config.temporal_coding_steps as usize;
|
||||
let mut trains = Vec::with_capacity(values.len());
|
||||
|
||||
for &value in values {
|
||||
// Pre-allocate spike train capacity (max possible spikes)
|
||||
let mut train = SpikeTrain::with_capacity(steps);
|
||||
|
||||
let abs_val = if value == i8::MIN { 128u16 } else { value.abs() as u16 };
|
||||
let polarity = value.signum();
|
||||
|
||||
if abs_val == 0 {
|
||||
trains.push(train);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Rate coding: spike frequency proportional to magnitude
|
||||
let rate_q15 = ((abs_val as u32) * 32768 / 128) as u16;
|
||||
|
||||
let mut refractory_counter = 0u8;
|
||||
let mut membrane_potential = 0u32;
|
||||
|
||||
for step in 0..steps {
|
||||
if refractory_counter > 0 {
|
||||
refractory_counter -= 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
membrane_potential = membrane_potential.saturating_add(rate_q15 as u32);
|
||||
|
||||
if membrane_potential >= self.config.spike_threshold_q15 as u32 {
|
||||
train.add_spike(step as u8, polarity);
|
||||
membrane_potential = 0;
|
||||
refractory_counter = self.config.refractory_period;
|
||||
}
|
||||
}
|
||||
|
||||
trains.push(train);
|
||||
}
|
||||
|
||||
trains
|
||||
}
|
||||
|
||||
/// Compute spike-driven attention (no multiplications)
|
||||
pub fn attention(
|
||||
&self,
|
||||
q_spikes: &[SpikeTrain],
|
||||
k_spikes: &[SpikeTrain],
|
||||
v_spikes: &[SpikeTrain],
|
||||
) -> Vec<i32> {
|
||||
let seq_len = q_spikes.len().min(k_spikes.len());
|
||||
let hidden_dim = v_spikes.len();
|
||||
let mut output = vec![0i32; hidden_dim];
|
||||
|
||||
if seq_len == 0 || hidden_dim == 0 {
|
||||
return output;
|
||||
}
|
||||
|
||||
for q_idx in 0..seq_len {
|
||||
let q_train = &q_spikes[q_idx];
|
||||
|
||||
// Compute attention weights via spike coincidence
|
||||
for k_idx in 0..=q_idx.min(seq_len - 1) {
|
||||
let k_train = &k_spikes[k_idx];
|
||||
|
||||
let mut coincidence_score = 0i32;
|
||||
for (&q_time, &q_pol) in q_train.times.iter().zip(q_train.polarities.iter()) {
|
||||
for (&k_time, &k_pol) in k_train.times.iter().zip(k_train.polarities.iter()) {
|
||||
if q_time == k_time {
|
||||
coincidence_score += (q_pol as i32) * (k_pol as i32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if coincidence_score != 0 {
|
||||
for (d, v_train) in v_spikes.iter().enumerate().take(hidden_dim) {
|
||||
let value_contrib: i32 = v_train.polarities.iter()
|
||||
.map(|&p| (p as i32).saturating_mul(coincidence_score))
|
||||
.sum();
|
||||
output[d] += value_contrib;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Multi-Head Attention for Task Routing
|
||||
// ============================================================================
|
||||
|
||||
/// Multi-head attention for distributed task routing
|
||||
#[wasm_bindgen]
|
||||
pub struct MultiHeadAttention {
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl MultiHeadAttention {
|
||||
/// Create new multi-head attention
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(dim: usize, num_heads: usize) -> Self {
|
||||
let head_dim = dim / num_heads;
|
||||
Self { dim, num_heads, head_dim }
|
||||
}
|
||||
|
||||
/// Get embedding dimension
|
||||
#[wasm_bindgen]
|
||||
pub fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
|
||||
/// Get number of heads
|
||||
#[wasm_bindgen(js_name = numHeads)]
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_heads
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
/// Split input into multiple heads
|
||||
fn split_heads(&self, input: &[f32]) -> Vec<Vec<f32>> {
|
||||
(0..self.num_heads)
|
||||
.map(|h| {
|
||||
let start = h * self.head_dim;
|
||||
let end = start + self.head_dim;
|
||||
input[start..end].to_vec()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute scaled dot-product attention for a single head
|
||||
fn scaled_dot_product(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec<f32> {
|
||||
let scale = (self.head_dim as f32).sqrt();
|
||||
|
||||
// Compute attention scores
|
||||
let scores: Vec<f32> = keys.iter()
|
||||
.map(|k| {
|
||||
let dot: f32 = query.iter().zip(*k).map(|(q, k)| q * k).sum();
|
||||
dot / scale
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Softmax
|
||||
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
|
||||
let sum_exp: f32 = exp_scores.iter().sum();
|
||||
let attention_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
|
||||
|
||||
// Weighted sum of values
|
||||
let mut output = vec![0.0f32; self.head_dim];
|
||||
for (weight, value) in attention_weights.iter().zip(values.iter()) {
|
||||
for (o, v) in output.iter_mut().zip(value.iter()) {
|
||||
*o += weight * v;
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Compute multi-head attention
|
||||
pub fn compute(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec<f32> {
|
||||
if query.len() != self.dim {
|
||||
return vec![0.0; self.dim];
|
||||
}
|
||||
|
||||
// Split query into heads
|
||||
let query_heads = self.split_heads(query);
|
||||
|
||||
// Split keys and values
|
||||
let key_heads: Vec<Vec<Vec<f32>>> = keys.iter().map(|k| self.split_heads(k)).collect();
|
||||
let value_heads: Vec<Vec<Vec<f32>>> = values.iter().map(|v| self.split_heads(v)).collect();
|
||||
|
||||
// Compute attention for each head
|
||||
let mut head_outputs = Vec::new();
|
||||
for h in 0..self.num_heads {
|
||||
let head_keys: Vec<&[f32]> = key_heads.iter().map(|kh| kh[h].as_slice()).collect();
|
||||
let head_values: Vec<&[f32]> = value_heads.iter().map(|vh| vh[h].as_slice()).collect();
|
||||
let head_out = self.scaled_dot_product(&query_heads[h], &head_keys, &head_values);
|
||||
head_outputs.push(head_out);
|
||||
}
|
||||
|
||||
// Concatenate head outputs
|
||||
head_outputs.into_iter().flatten().collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Network Learning Intelligence
|
||||
// ============================================================================
|
||||
|
||||
/// Unified learning intelligence for edge-net nodes
|
||||
#[wasm_bindgen]
|
||||
pub struct NetworkLearning {
|
||||
/// Pattern storage
|
||||
reasoning_bank: ReasoningBank,
|
||||
/// Trajectory tracking
|
||||
trajectory_tracker: TrajectoryTracker,
|
||||
/// Spike-driven attention for energy efficiency
|
||||
spike_attention: SpikeDrivenAttention,
|
||||
/// Multi-head attention for task routing
|
||||
multi_head: MultiHeadAttention,
|
||||
/// Learning rate for online updates
|
||||
learning_rate: f32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl NetworkLearning {
|
||||
/// Create new network learning intelligence
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
reasoning_bank: ReasoningBank::new(),
|
||||
trajectory_tracker: TrajectoryTracker::new(1000),
|
||||
spike_attention: SpikeDrivenAttention::new(),
|
||||
multi_head: MultiHeadAttention::new(64, 4), // 64-dim, 4 heads
|
||||
learning_rate: 0.01,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a task execution trajectory
|
||||
#[wasm_bindgen(js_name = recordTrajectory)]
|
||||
pub fn record_trajectory(&self, trajectory_json: &str) -> bool {
|
||||
self.trajectory_tracker.record(trajectory_json)
|
||||
}
|
||||
|
||||
/// Store a learned pattern
|
||||
#[wasm_bindgen(js_name = storePattern)]
|
||||
pub fn store_pattern(&self, pattern_json: &str) -> i32 {
|
||||
self.reasoning_bank.store(pattern_json)
|
||||
}
|
||||
|
||||
/// Look up similar patterns
|
||||
#[wasm_bindgen(js_name = lookupPatterns)]
|
||||
pub fn lookup_patterns(&self, query_json: &str, k: usize) -> String {
|
||||
self.reasoning_bank.lookup(query_json, k)
|
||||
}
|
||||
|
||||
/// Get energy savings ratio for spike-driven attention
|
||||
#[wasm_bindgen(js_name = getEnergyRatio)]
|
||||
pub fn get_energy_ratio(&self, seq_len: usize, hidden_dim: usize) -> f32 {
|
||||
self.spike_attention.energy_ratio(seq_len, hidden_dim)
|
||||
}
|
||||
|
||||
/// Get combined statistics
|
||||
#[wasm_bindgen(js_name = getStats)]
|
||||
pub fn get_stats(&self) -> String {
|
||||
let bank_stats = self.reasoning_bank.get_stats();
|
||||
let traj_stats = self.trajectory_tracker.get_stats();
|
||||
let energy_ratio = self.spike_attention.energy_ratio(64, 256);
|
||||
|
||||
format!(
|
||||
r#"{{"reasoning_bank":{},"trajectories":{},"spike_energy_ratio":{:.2},"learning_rate":{}}}"#,
|
||||
bank_stats, traj_stats, energy_ratio, self.learning_rate
|
||||
)
|
||||
}
|
||||
|
||||
/// Prune low-quality patterns
|
||||
#[wasm_bindgen]
|
||||
pub fn prune(&self, min_usage: usize, min_confidence: f64) -> usize {
|
||||
self.reasoning_bank.prune(min_usage, min_confidence)
|
||||
}
|
||||
|
||||
/// Get trajectory count
|
||||
#[wasm_bindgen(js_name = trajectoryCount)]
|
||||
pub fn trajectory_count(&self) -> usize {
|
||||
self.trajectory_tracker.count()
|
||||
}
|
||||
|
||||
/// Get pattern count
|
||||
#[wasm_bindgen(js_name = patternCount)]
|
||||
pub fn pattern_count(&self) -> usize {
|
||||
self.reasoning_bank.count()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NetworkLearning {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_learned_pattern_similarity() {
|
||||
let pattern = LearnedPattern::new(
|
||||
vec![1.0, 0.0, 0.0],
|
||||
0.8,
|
||||
100,
|
||||
0.9,
|
||||
10,
|
||||
50.0,
|
||||
Some(0.95),
|
||||
);
|
||||
|
||||
let query_same = vec![1.0, 0.0, 0.0];
|
||||
let query_perp = vec![0.0, 1.0, 0.0];
|
||||
|
||||
assert!((pattern.similarity(&query_same) - 1.0).abs() < 0.001);
|
||||
assert!((pattern.similarity(&query_perp) - 0.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_task_trajectory_efficiency() {
|
||||
let traj = TaskTrajectory {
|
||||
task_vector: vec![1.0, 2.0],
|
||||
latency_ms: 100,
|
||||
energy_spent: 50,
|
||||
energy_earned: 100,
|
||||
success: true,
|
||||
executor_id: "node-1".to_string(),
|
||||
timestamp: 0,
|
||||
};
|
||||
|
||||
assert!((traj.efficiency() - 2.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spike_train() {
|
||||
let mut train = SpikeTrain::new();
|
||||
assert!(train.is_empty());
|
||||
|
||||
train.add_spike(0, 1);
|
||||
train.add_spike(3, -1);
|
||||
|
||||
assert_eq!(train.len(), 2);
|
||||
assert_eq!(train.times, vec![0, 3]);
|
||||
assert_eq!(train.polarities, vec![1, -1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spike_encoding() {
|
||||
let attn = SpikeDrivenAttention::new();
|
||||
let values = vec![64i8, 0, -64];
|
||||
let trains = attn.encode_spikes(&values);
|
||||
|
||||
assert_eq!(trains.len(), 3);
|
||||
assert!(trains[0].len() > 0); // High positive
|
||||
assert!(trains[1].is_empty()); // Zero
|
||||
assert!(trains[2].len() > 0); // High negative
|
||||
assert!(trains[2].polarities.iter().all(|&p| p == -1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_head_attention() {
|
||||
let attn = MultiHeadAttention::new(8, 2);
|
||||
let query = vec![1.0_f32; 8];
|
||||
let key1 = vec![0.5_f32; 8];
|
||||
let val1 = vec![1.0_f32; 8];
|
||||
let keys: Vec<&[f32]> = vec![key1.as_slice()];
|
||||
let values: Vec<&[f32]> = vec![val1.as_slice()];
|
||||
|
||||
let result = attn.compute(&query, &keys, &values);
|
||||
assert_eq!(result.len(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_energy_ratio() {
|
||||
let attn = SpikeDrivenAttention::new();
|
||||
let ratio = attn.energy_ratio(64, 256);
|
||||
|
||||
// Should show significant energy savings
|
||||
assert!(ratio > 10.0);
|
||||
assert!(ratio < 200.0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user