//! Learning and Attention Module for Edge-Net
//!
//! Integrates RuVector's self-learning intelligence and attention mechanisms
//! for distributed compute optimization. This module enables edge nodes to:
//!
//! - **Learn patterns** from task execution trajectories
//! - **Store knowledge** in a ReasoningBank for retrieval
//! - **Route tasks** using multi-head attention
//! - **Optimize energy** with spike-driven attention (87x more efficient)
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────┐
//! │                Learning Intelligence                │
//! ├─────────────────────────────────────────────────────┤
//! │  ┌──────────────┐  ┌──────────────┐  ┌───────────┐  │
//! │  │ ReasoningBank│  │  Trajectory  │  │  Pattern  │  │
//! │  │   Storage    │◄─┤   Tracker    │──┤ Extractor │  │
//! │  └──────────────┘  └──────────────┘  └───────────┘  │
//! ├─────────────────────────────────────────────────────┤
//! │  ┌──────────────┐  ┌──────────────┐                 │
//! │  │ Multi-Head   │  │ Spike-Driven │                 │
//! │  │ Attention    │  │  Attention   │                 │
//! │  │ (Task Route) │  │ (87x Energy) │                 │
//! │  └──────────────┘  └──────────────┘                 │
//! └─────────────────────────────────────────────────────┘
``` use wasm_bindgen::prelude::*; use serde::{Serialize, Deserialize}; use rustc_hash::FxHashMap; use std::sync::RwLock; // ============================================================================ // Learned Patterns // ============================================================================ /// A learned pattern from task execution #[derive(Clone, Debug, Serialize, Deserialize)] pub struct LearnedPattern { /// Centroid vector representing the pattern pub centroid: Vec, /// Optimal task allocation score pub optimal_allocation: f32, /// Optimal energy budget for this pattern pub optimal_energy: u64, /// Confidence score (0.0 - 1.0) pub confidence: f64, /// Number of samples in this pattern pub sample_count: usize, /// Average latency in milliseconds pub avg_latency_ms: f64, /// Average success rate pub avg_success_rate: Option, } impl LearnedPattern { /// Create a new learned pattern pub fn new( centroid: Vec, optimal_allocation: f32, optimal_energy: u64, confidence: f64, sample_count: usize, avg_latency_ms: f64, avg_success_rate: Option, ) -> Self { Self { centroid, optimal_allocation, optimal_energy, confidence, sample_count, avg_latency_ms, avg_success_rate, } } /// Calculate cosine similarity to a query vector pub fn similarity(&self, query: &[f32]) -> f64 { if query.len() != self.centroid.len() { return 0.0; } let dot: f32 = query.iter().zip(&self.centroid).map(|(a, b)| a * b).sum(); let norm_q: f32 = query.iter().map(|x| x * x).sum::().sqrt(); let norm_c: f32 = self.centroid.iter().map(|x| x * x).sum::().sqrt(); if norm_q == 0.0 || norm_c == 0.0 { return 0.0; } (dot / (norm_q * norm_c)) as f64 } } // ============================================================================ // Task Trajectory // ============================================================================ /// A single task execution trajectory #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TaskTrajectory { /// Task feature vector pub task_vector: Vec, /// Execution latency 
in milliseconds pub latency_ms: u64, /// Energy consumed (rUv) pub energy_spent: u64, /// Energy earned (rUv) pub energy_earned: u64, /// Task success flag pub success: bool, /// Node that executed the task pub executor_id: String, /// Timestamp (ms since epoch) pub timestamp: u64, } impl TaskTrajectory { /// Create a new task trajectory pub fn new( task_vector: Vec, latency_ms: u64, energy_spent: u64, energy_earned: u64, success: bool, executor_id: String, ) -> Self { Self { task_vector, latency_ms, energy_spent, energy_earned, success, executor_id, timestamp: js_sys::Date::now() as u64, } } /// Calculate efficiency ratio (earned/spent) pub fn efficiency(&self) -> f64 { if self.energy_spent == 0 { return 0.0; } self.energy_earned as f64 / self.energy_spent as f64 } } // ============================================================================ // Trajectory Tracker // ============================================================================ /// Ring buffer tracker for task trajectories #[wasm_bindgen] pub struct TrajectoryTracker { /// Ring buffer of trajectories trajectories: RwLock>, /// Maximum size max_size: usize, /// Current write position write_pos: RwLock, } #[wasm_bindgen] impl TrajectoryTracker { /// Create a new trajectory tracker #[wasm_bindgen(constructor)] pub fn new(max_size: usize) -> Self { Self { trajectories: RwLock::new(Vec::with_capacity(max_size)), max_size, write_pos: RwLock::new(0), } } /// Record a new trajectory #[wasm_bindgen] pub fn record(&self, trajectory_json: &str) -> bool { let trajectory: TaskTrajectory = match serde_json::from_str(trajectory_json) { Ok(t) => t, Err(_) => return false, }; let mut trajectories = self.trajectories.write().unwrap(); let mut pos = self.write_pos.write().unwrap(); if trajectories.len() < self.max_size { trajectories.push(trajectory); } else { trajectories[*pos] = trajectory; } *pos = (*pos + 1) % self.max_size; true } /// Get statistics as JSON #[wasm_bindgen(js_name = getStats)] pub fn 
get_stats(&self) -> String { let trajectories = self.trajectories.read().unwrap(); if trajectories.is_empty() { return r#"{"total":0}"#.to_string(); } let total = trajectories.len(); let successful = trajectories.iter().filter(|t| t.success).count(); let avg_latency = trajectories.iter().map(|t| t.latency_ms).sum::() as f64 / total as f64; let avg_efficiency = trajectories.iter().map(|t| t.efficiency()).sum::() / total as f64; format!( r#"{{"total":{},"successful":{},"success_rate":{:.4},"avg_latency_ms":{:.2},"avg_efficiency":{:.4}}}"#, total, successful, successful as f64 / total as f64, avg_latency, avg_efficiency ) } /// Get count of trajectories #[wasm_bindgen] pub fn count(&self) -> usize { self.trajectories.read().unwrap().len() } } // ============================================================================ // Reasoning Bank // ============================================================================ /// Pattern entry with usage tracking #[derive(Clone)] struct PatternEntry { pattern: LearnedPattern, usage_count: usize, last_used: u64, } /// Spatial bucket for fast approximate nearest neighbor search struct SpatialBucket { pattern_ids: Vec, } /// ReasoningBank for storing and retrieving learned patterns /// Optimized with spatial indexing for O(1) approximate lookups #[wasm_bindgen] pub struct ReasoningBank { /// Stored patterns indexed by ID patterns: RwLock>, /// Next pattern ID next_id: RwLock, /// Spatial index for fast approximate nearest neighbor /// Maps quantized vector hash to pattern IDs spatial_index: RwLock>, } #[wasm_bindgen] impl ReasoningBank { /// Create a new ReasoningBank #[wasm_bindgen(constructor)] pub fn new() -> ReasoningBank { ReasoningBank { patterns: RwLock::new(FxHashMap::default()), next_id: RwLock::new(0), spatial_index: RwLock::new(FxHashMap::default()), } } /// Hash a vector into a spatial bucket (locality-sensitive hashing) fn spatial_hash(vector: &[f32]) -> u64 { // Simple grid-based quantization for fast approximate 
matching // Quantize each dimension to 8 levels (3 bits) let mut hash = 0u64; for (i, &val) in vector.iter().take(20).enumerate() { // Normalize to [0, 7] range let quantized = ((val + 1.0) * 3.5).clamp(0.0, 7.0) as u64; hash |= quantized << (i * 3); } hash } /// Store a new pattern (JSON format) #[wasm_bindgen] pub fn store(&self, pattern_json: &str) -> i32 { let pattern: LearnedPattern = match serde_json::from_str(pattern_json) { Ok(p) => p, Err(_) => return -1, }; // Compute spatial hash for indexing let hash = Self::spatial_hash(&pattern.centroid); let mut next_id = self.next_id.write().unwrap(); let id = *next_id; *next_id += 1; let entry = PatternEntry { pattern, usage_count: 0, last_used: js_sys::Date::now() as u64, }; self.patterns.write().unwrap().insert(id, entry); // Add to spatial index let mut index = self.spatial_index.write().unwrap(); index.entry(hash) .or_insert_with(|| SpatialBucket { pattern_ids: Vec::with_capacity(10) }) .pattern_ids.push(id); id as i32 } /// Lookup most similar patterns (OPTIMIZED with spatial indexing) #[wasm_bindgen] pub fn lookup(&self, query_json: &str, k: usize) -> String { let query: Vec = match serde_json::from_str(query_json) { Ok(q) => q, Err(_) => return "[]".to_string(), }; let query_hash = Self::spatial_hash(&query); let now = js_sys::Date::now() as u64; // Step 1: Fast approximate search using spatial index let index = self.spatial_index.read().unwrap(); let mut candidate_ids = Vec::with_capacity(k * 3); // Pre-allocate // Get patterns from same bucket if let Some(bucket) = index.get(&query_hash) { candidate_ids.extend_from_slice(&bucket.pattern_ids); } // Check neighboring buckets (increase recall) // Flip 1-2 bits in hash to find nearby buckets for bit_flip in 0..6 { let neighbor_hash = query_hash ^ (1u64 << (bit_flip * 3)); if let Some(bucket) = index.get(&neighbor_hash) { candidate_ids.extend_from_slice(&bucket.pattern_ids); } } // Fallback: if too few candidates, scan more buckets if candidate_ids.len() < k * 
2 { for bucket in index.values().take(10) { candidate_ids.extend_from_slice(&bucket.pattern_ids); if candidate_ids.len() >= k * 3 { break; } } } // Step 2: Exact similarity computation only for candidates let mut patterns = self.patterns.write().unwrap(); let mut similarities = Vec::with_capacity(candidate_ids.len()); for &id in &candidate_ids { if let Some(entry) = patterns.get_mut(&id) { let similarity = entry.pattern.similarity(&query); entry.usage_count += 1; entry.last_used = now; similarities.push((id, entry.pattern.clone(), similarity)); } } // Sort by weighted score (similarity * confidence) similarities.sort_unstable_by(|a, b| { let score_a = a.2 * a.1.confidence; let score_b = b.2 * b.1.confidence; score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal) }); similarities.truncate(k); // Pre-allocate string with estimated capacity let mut result = String::with_capacity(k * 120); result.push('['); for (i, (id, pattern, sim)) in similarities.iter().enumerate() { if i > 0 { result.push(','); } use std::fmt::Write; let _ = write!( result, r#"{{"id":{},"similarity":{:.4},"confidence":{:.4},"optimal_allocation":{:.4},"optimal_energy":{}}}"#, id, sim, pattern.confidence, pattern.optimal_allocation, pattern.optimal_energy ); } result.push(']'); result } /// Prune low-quality patterns #[wasm_bindgen] pub fn prune(&self, min_usage: usize, min_confidence: f64) -> usize { let mut patterns = self.patterns.write().unwrap(); let before = patterns.len(); patterns.retain(|_, entry| { entry.usage_count >= min_usage && entry.pattern.confidence >= min_confidence }); before - patterns.len() } /// Get total pattern count #[wasm_bindgen] pub fn count(&self) -> usize { self.patterns.read().unwrap().len() } /// Get bank statistics #[wasm_bindgen(js_name = getStats)] pub fn get_stats(&self) -> String { let patterns = self.patterns.read().unwrap(); if patterns.is_empty() { return r#"{"total":0}"#.to_string(); } let total = patterns.len(); let total_samples: usize = 
patterns.values().map(|e| e.pattern.sample_count).sum(); let avg_confidence: f64 = patterns.values().map(|e| e.pattern.confidence).sum::() / total as f64; let total_usage: usize = patterns.values().map(|e| e.usage_count).sum(); format!( r#"{{"total_patterns":{},"total_samples":{},"avg_confidence":{:.4},"total_usage":{}}}"#, total, total_samples, avg_confidence, total_usage ) } } impl Default for ReasoningBank { fn default() -> Self { Self::new() } } // ============================================================================ // Spike Train for Energy-Efficient Attention // ============================================================================ /// Spike train representation for temporal coding #[derive(Clone, Debug, Default)] pub struct SpikeTrain { /// Spike times within temporal window pub times: Vec, /// Spike polarities: +1 for positive, -1 for negative pub polarities: Vec, } impl SpikeTrain { /// Create empty spike train pub fn new() -> Self { Self { times: Vec::new(), polarities: Vec::new(), } } /// Create spike train with pre-allocated capacity pub fn with_capacity(capacity: usize) -> Self { Self { times: Vec::with_capacity(capacity), polarities: Vec::with_capacity(capacity), } } /// Add a spike at given time with polarity pub fn add_spike(&mut self, time: u8, polarity: i8) { self.times.push(time); self.polarities.push(polarity); } /// Number of spikes pub fn len(&self) -> usize { self.times.len() } /// Check if empty pub fn is_empty(&self) -> bool { self.times.is_empty() } } // ============================================================================ // Spike-Driven Attention // ============================================================================ /// Configuration for spike-driven attention #[derive(Clone, Debug)] pub struct SpikeDrivenConfig { /// Spike threshold in Q15 fixed-point pub spike_threshold_q15: u16, /// Number of temporal coding steps pub temporal_coding_steps: u8, /// Use binary quantization pub binary_qkv: bool, /// 
Refractory period after spike pub refractory_period: u8, } impl Default for SpikeDrivenConfig { fn default() -> Self { Self { spike_threshold_q15: 16384, // 0.5 in Q15 temporal_coding_steps: 8, binary_qkv: true, refractory_period: 2, } } } /// Spike-driven attention for energy-efficient compute (87x savings) #[wasm_bindgen] pub struct SpikeDrivenAttention { config: SpikeDrivenConfig, } #[wasm_bindgen] impl SpikeDrivenAttention { /// Create new spike-driven attention with default config #[wasm_bindgen(constructor)] pub fn new() -> Self { Self { config: SpikeDrivenConfig::default(), } } /// Create with custom parameters #[wasm_bindgen(js_name = withConfig)] pub fn with_config(threshold: u16, steps: u8, refractory: u8) -> Self { Self { config: SpikeDrivenConfig { spike_threshold_q15: threshold, temporal_coding_steps: steps, binary_qkv: true, refractory_period: refractory, }, } } /// Estimate energy savings ratio compared to standard attention #[wasm_bindgen(js_name = energyRatio)] pub fn energy_ratio(&self, seq_len: usize, hidden_dim: usize) -> f32 { if seq_len == 0 || hidden_dim == 0 { return 1.0; } // Standard attention operations (multiplications) let standard_mults = 2 * seq_len * seq_len * hidden_dim; // Spike-driven operations (additions only) let avg_spikes_per_neuron = (self.config.temporal_coding_steps as f32) * 0.3; let spike_adds = (seq_len as f32) * avg_spikes_per_neuron * (hidden_dim as f32); // Energy ratio (multiplication ~3.7x more expensive than addition) let mult_energy_factor = 3.7; let standard_energy = (standard_mults as f32) * mult_energy_factor; let spike_energy = spike_adds; if spike_energy == 0.0 { return 1.0; } standard_energy / spike_energy } } impl Default for SpikeDrivenAttention { fn default() -> Self { Self::new() } } impl SpikeDrivenAttention { /// Encode values to spike trains using rate coding (OPTIMIZED with pre-allocation) pub fn encode_spikes(&self, values: &[i8]) -> Vec { let steps = self.config.temporal_coding_steps as usize; let 
mut trains = Vec::with_capacity(values.len()); for &value in values { // Pre-allocate spike train capacity (max possible spikes) let mut train = SpikeTrain::with_capacity(steps); let abs_val = if value == i8::MIN { 128u16 } else { value.abs() as u16 }; let polarity = value.signum(); if abs_val == 0 { trains.push(train); continue; } // Rate coding: spike frequency proportional to magnitude let rate_q15 = ((abs_val as u32) * 32768 / 128) as u16; let mut refractory_counter = 0u8; let mut membrane_potential = 0u32; for step in 0..steps { if refractory_counter > 0 { refractory_counter -= 1; continue; } membrane_potential = membrane_potential.saturating_add(rate_q15 as u32); if membrane_potential >= self.config.spike_threshold_q15 as u32 { train.add_spike(step as u8, polarity); membrane_potential = 0; refractory_counter = self.config.refractory_period; } } trains.push(train); } trains } /// Compute spike-driven attention (no multiplications) pub fn attention( &self, q_spikes: &[SpikeTrain], k_spikes: &[SpikeTrain], v_spikes: &[SpikeTrain], ) -> Vec { let seq_len = q_spikes.len().min(k_spikes.len()); let hidden_dim = v_spikes.len(); let mut output = vec![0i32; hidden_dim]; if seq_len == 0 || hidden_dim == 0 { return output; } for q_idx in 0..seq_len { let q_train = &q_spikes[q_idx]; // Compute attention weights via spike coincidence for k_idx in 0..=q_idx.min(seq_len - 1) { let k_train = &k_spikes[k_idx]; let mut coincidence_score = 0i32; for (&q_time, &q_pol) in q_train.times.iter().zip(q_train.polarities.iter()) { for (&k_time, &k_pol) in k_train.times.iter().zip(k_train.polarities.iter()) { if q_time == k_time { coincidence_score += (q_pol as i32) * (k_pol as i32); } } } if coincidence_score != 0 { for (d, v_train) in v_spikes.iter().enumerate().take(hidden_dim) { let value_contrib: i32 = v_train.polarities.iter() .map(|&p| (p as i32).saturating_mul(coincidence_score)) .sum(); output[d] += value_contrib; } } } } output } } // 
============================================================================ // Multi-Head Attention for Task Routing // ============================================================================ /// Multi-head attention for distributed task routing #[wasm_bindgen] pub struct MultiHeadAttention { dim: usize, num_heads: usize, head_dim: usize, } #[wasm_bindgen] impl MultiHeadAttention { /// Create new multi-head attention #[wasm_bindgen(constructor)] pub fn new(dim: usize, num_heads: usize) -> Self { let head_dim = dim / num_heads; Self { dim, num_heads, head_dim } } /// Get embedding dimension #[wasm_bindgen] pub fn dim(&self) -> usize { self.dim } /// Get number of heads #[wasm_bindgen(js_name = numHeads)] pub fn num_heads(&self) -> usize { self.num_heads } } impl MultiHeadAttention { /// Split input into multiple heads fn split_heads(&self, input: &[f32]) -> Vec> { (0..self.num_heads) .map(|h| { let start = h * self.head_dim; let end = start + self.head_dim; input[start..end].to_vec() }) .collect() } /// Compute scaled dot-product attention for a single head fn scaled_dot_product(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec { let scale = (self.head_dim as f32).sqrt(); // Compute attention scores let scores: Vec = keys.iter() .map(|k| { let dot: f32 = query.iter().zip(*k).map(|(q, k)| q * k).sum(); dot / scale }) .collect(); // Softmax let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let exp_scores: Vec = scores.iter().map(|s| (s - max_score).exp()).collect(); let sum_exp: f32 = exp_scores.iter().sum(); let attention_weights: Vec = exp_scores.iter().map(|e| e / sum_exp).collect(); // Weighted sum of values let mut output = vec![0.0f32; self.head_dim]; for (weight, value) in attention_weights.iter().zip(values.iter()) { for (o, v) in output.iter_mut().zip(value.iter()) { *o += weight * v; } } output } /// Compute multi-head attention pub fn compute(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec { if 
query.len() != self.dim { return vec![0.0; self.dim]; } // Split query into heads let query_heads = self.split_heads(query); // Split keys and values let key_heads: Vec>> = keys.iter().map(|k| self.split_heads(k)).collect(); let value_heads: Vec>> = values.iter().map(|v| self.split_heads(v)).collect(); // Compute attention for each head let mut head_outputs = Vec::new(); for h in 0..self.num_heads { let head_keys: Vec<&[f32]> = key_heads.iter().map(|kh| kh[h].as_slice()).collect(); let head_values: Vec<&[f32]> = value_heads.iter().map(|vh| vh[h].as_slice()).collect(); let head_out = self.scaled_dot_product(&query_heads[h], &head_keys, &head_values); head_outputs.push(head_out); } // Concatenate head outputs head_outputs.into_iter().flatten().collect() } } // ============================================================================ // Network Learning Intelligence // ============================================================================ /// Unified learning intelligence for edge-net nodes #[wasm_bindgen] pub struct NetworkLearning { /// Pattern storage reasoning_bank: ReasoningBank, /// Trajectory tracking trajectory_tracker: TrajectoryTracker, /// Spike-driven attention for energy efficiency spike_attention: SpikeDrivenAttention, /// Multi-head attention for task routing multi_head: MultiHeadAttention, /// Learning rate for online updates learning_rate: f32, } #[wasm_bindgen] impl NetworkLearning { /// Create new network learning intelligence #[wasm_bindgen(constructor)] pub fn new() -> Self { Self { reasoning_bank: ReasoningBank::new(), trajectory_tracker: TrajectoryTracker::new(1000), spike_attention: SpikeDrivenAttention::new(), multi_head: MultiHeadAttention::new(64, 4), // 64-dim, 4 heads learning_rate: 0.01, } } /// Record a task execution trajectory #[wasm_bindgen(js_name = recordTrajectory)] pub fn record_trajectory(&self, trajectory_json: &str) -> bool { self.trajectory_tracker.record(trajectory_json) } /// Store a learned pattern 
#[wasm_bindgen(js_name = storePattern)] pub fn store_pattern(&self, pattern_json: &str) -> i32 { self.reasoning_bank.store(pattern_json) } /// Look up similar patterns #[wasm_bindgen(js_name = lookupPatterns)] pub fn lookup_patterns(&self, query_json: &str, k: usize) -> String { self.reasoning_bank.lookup(query_json, k) } /// Get energy savings ratio for spike-driven attention #[wasm_bindgen(js_name = getEnergyRatio)] pub fn get_energy_ratio(&self, seq_len: usize, hidden_dim: usize) -> f32 { self.spike_attention.energy_ratio(seq_len, hidden_dim) } /// Get combined statistics #[wasm_bindgen(js_name = getStats)] pub fn get_stats(&self) -> String { let bank_stats = self.reasoning_bank.get_stats(); let traj_stats = self.trajectory_tracker.get_stats(); let energy_ratio = self.spike_attention.energy_ratio(64, 256); format!( r#"{{"reasoning_bank":{},"trajectories":{},"spike_energy_ratio":{:.2},"learning_rate":{}}}"#, bank_stats, traj_stats, energy_ratio, self.learning_rate ) } /// Prune low-quality patterns #[wasm_bindgen] pub fn prune(&self, min_usage: usize, min_confidence: f64) -> usize { self.reasoning_bank.prune(min_usage, min_confidence) } /// Get trajectory count #[wasm_bindgen(js_name = trajectoryCount)] pub fn trajectory_count(&self) -> usize { self.trajectory_tracker.count() } /// Get pattern count #[wasm_bindgen(js_name = patternCount)] pub fn pattern_count(&self) -> usize { self.reasoning_bank.count() } } impl Default for NetworkLearning { fn default() -> Self { Self::new() } } // ============================================================================ // Tests // ============================================================================ #[cfg(test)] mod tests { use super::*; #[test] fn test_learned_pattern_similarity() { let pattern = LearnedPattern::new( vec![1.0, 0.0, 0.0], 0.8, 100, 0.9, 10, 50.0, Some(0.95), ); let query_same = vec![1.0, 0.0, 0.0]; let query_perp = vec![0.0, 1.0, 0.0]; assert!((pattern.similarity(&query_same) - 1.0).abs() < 0.001); 
assert!((pattern.similarity(&query_perp) - 0.0).abs() < 0.001); } #[test] fn test_task_trajectory_efficiency() { let traj = TaskTrajectory { task_vector: vec![1.0, 2.0], latency_ms: 100, energy_spent: 50, energy_earned: 100, success: true, executor_id: "node-1".to_string(), timestamp: 0, }; assert!((traj.efficiency() - 2.0).abs() < 0.001); } #[test] fn test_spike_train() { let mut train = SpikeTrain::new(); assert!(train.is_empty()); train.add_spike(0, 1); train.add_spike(3, -1); assert_eq!(train.len(), 2); assert_eq!(train.times, vec![0, 3]); assert_eq!(train.polarities, vec![1, -1]); } #[test] fn test_spike_encoding() { let attn = SpikeDrivenAttention::new(); let values = vec![64i8, 0, -64]; let trains = attn.encode_spikes(&values); assert_eq!(trains.len(), 3); assert!(trains[0].len() > 0); // High positive assert!(trains[1].is_empty()); // Zero assert!(trains[2].len() > 0); // High negative assert!(trains[2].polarities.iter().all(|&p| p == -1)); } #[test] fn test_multi_head_attention() { let attn = MultiHeadAttention::new(8, 2); let query = vec![1.0_f32; 8]; let key1 = vec![0.5_f32; 8]; let val1 = vec![1.0_f32; 8]; let keys: Vec<&[f32]> = vec![key1.as_slice()]; let values: Vec<&[f32]> = vec![val1.as_slice()]; let result = attn.compute(&query, &keys, &values); assert_eq!(result.len(), 8); } #[test] fn test_energy_ratio() { let attn = SpikeDrivenAttention::new(); let ratio = attn.energy_ratio(64, 256); // Should show significant energy savings assert!(ratio > 10.0); assert!(ratio < 200.0); } }