git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
474 lines
13 KiB
Rust
474 lines
13 KiB
Rust
//! Integration tests for spike-driven attention.
|
|
//!
|
|
//! Tests the complete spike-driven attention pipeline including:
|
|
//! - Spike encoding from quantized values
|
|
//! - Spike-based attention computation
|
|
//! - Energy efficiency metrics
|
|
//! - Integration with existing transformer components
|
|
|
|
#![cfg(feature = "spike_attention")]
|
|
|
|
use ruvector_mincut_gated_transformer::{SpikeDrivenAttention, SpikeDrivenConfig, SpikeTrain};
|
|
|
|
#[test]
|
|
fn test_spike_train_basic_operations() {
|
|
let mut train = SpikeTrain::new();
|
|
|
|
assert!(train.is_empty());
|
|
assert_eq!(train.len(), 0);
|
|
|
|
train.add_spike(0, 1);
|
|
train.add_spike(2, 1);
|
|
train.add_spike(5, -1);
|
|
|
|
assert_eq!(train.len(), 3);
|
|
assert_eq!(train.times.len(), 3);
|
|
assert_eq!(train.polarities.len(), 3);
|
|
|
|
train.clear();
|
|
assert!(train.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_spike_encoding_range() {
|
|
let config = SpikeDrivenConfig {
|
|
spike_threshold_q15: 16384,
|
|
temporal_coding_steps: 10,
|
|
binary_qkv: true,
|
|
refractory_period: 2,
|
|
};
|
|
let attn = SpikeDrivenAttention::new(config);
|
|
|
|
// Test full i8 range
|
|
let values: Vec<i8> = vec![-128, -64, -32, 0, 32, 64, 127];
|
|
let trains = attn.encode_spikes(&values);
|
|
|
|
assert_eq!(trains.len(), 7);
|
|
|
|
// Zero produces no spikes
|
|
assert_eq!(trains[3].len(), 0);
|
|
|
|
// Negative values should have negative polarity
|
|
for train in &trains[0..3] {
|
|
if !train.is_empty() {
|
|
assert!(train.polarities.iter().all(|&p| p == -1));
|
|
}
|
|
}
|
|
|
|
// Positive values should have positive polarity
|
|
for train in &trains[4..7] {
|
|
if !train.is_empty() {
|
|
assert!(train.polarities.iter().all(|&p| p == 1));
|
|
}
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_spike_encoding_proportionality() {
|
|
let config = SpikeDrivenConfig::default();
|
|
let attn = SpikeDrivenAttention::new(config);
|
|
|
|
// Higher magnitude should produce more spikes
|
|
let values = vec![127i8, 64, 32, 16, 8];
|
|
let trains = attn.encode_spikes(&values);
|
|
|
|
// Verify descending spike counts
|
|
for i in 0..trains.len() - 1 {
|
|
assert!(
|
|
trains[i].len() >= trains[i + 1].len(),
|
|
"Higher values should produce more spikes: {} vs {}",
|
|
trains[i].len(),
|
|
trains[i + 1].len()
|
|
);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_refractory_period_enforcement() {
|
|
let config = SpikeDrivenConfig {
|
|
spike_threshold_q15: 4096, // Low threshold for more spikes
|
|
temporal_coding_steps: 20,
|
|
binary_qkv: true,
|
|
refractory_period: 5,
|
|
};
|
|
let refractory_period = config.refractory_period;
|
|
let attn = SpikeDrivenAttention::new(config);
|
|
|
|
let values = vec![127i8]; // Maximum value
|
|
let trains = attn.encode_spikes(&values);
|
|
|
|
if trains[0].len() > 1 {
|
|
// Verify refractory period between consecutive spikes
|
|
for i in 1..trains[0].times.len() {
|
|
let time_diff = trains[0].times[i] - trains[0].times[i - 1];
|
|
assert!(
|
|
time_diff > refractory_period,
|
|
"Spikes should respect refractory period: diff={}, period={}",
|
|
time_diff,
|
|
refractory_period
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_attention_output_shape() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Create simple spike trains
|
|
let seq_len = 4;
|
|
let hidden_dim = 8;
|
|
|
|
let mut q_spikes = Vec::with_capacity(seq_len);
|
|
let mut k_spikes = Vec::with_capacity(seq_len);
|
|
let mut v_spikes = Vec::with_capacity(hidden_dim);
|
|
|
|
for _ in 0..seq_len {
|
|
let mut train = SpikeTrain::new();
|
|
train.add_spike(0, 1);
|
|
q_spikes.push(train.clone());
|
|
k_spikes.push(train);
|
|
}
|
|
|
|
for _ in 0..hidden_dim {
|
|
let mut train = SpikeTrain::new();
|
|
train.add_spike(1, 1);
|
|
v_spikes.push(train);
|
|
}
|
|
|
|
let output = attn.attention(&q_spikes, &k_spikes, &v_spikes);
|
|
|
|
assert_eq!(output.len(), hidden_dim);
|
|
}
|
|
|
|
#[test]
|
|
fn test_attention_causal_masking() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Create spike trains where future positions have different patterns
|
|
let mut q_spikes = vec![];
|
|
let mut k_spikes = vec![];
|
|
let mut v_spikes = vec![];
|
|
|
|
// Position 0 query
|
|
let mut q0 = SpikeTrain::new();
|
|
q0.add_spike(0, 1);
|
|
q_spikes.push(q0);
|
|
|
|
// Position 0 key (should match)
|
|
let mut k0 = SpikeTrain::new();
|
|
k0.add_spike(0, 1);
|
|
k_spikes.push(k0);
|
|
|
|
// Position 1 key (should not affect position 0's attention)
|
|
let mut k1 = SpikeTrain::new();
|
|
k1.add_spike(0, 1);
|
|
k1.add_spike(1, 1);
|
|
k1.add_spike(2, 1); // Much stronger signal
|
|
k_spikes.push(k1);
|
|
|
|
let mut v0 = SpikeTrain::new();
|
|
v0.add_spike(0, 1);
|
|
v_spikes.push(v0);
|
|
|
|
// Compute attention for position 0
|
|
// It should only see k0, not k1 (causal masking)
|
|
let output = attn.attention(&q_spikes, &k_spikes, &v_spikes);
|
|
|
|
assert_eq!(output.len(), 1);
|
|
// Output should be non-zero due to coincidence at position 0
|
|
assert_ne!(output[0], 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_coincidence_detection() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Test spike coincidence scoring
|
|
let mut q_train = SpikeTrain::new();
|
|
q_train.add_spike(0, 1);
|
|
q_train.add_spike(5, 1);
|
|
|
|
let mut k_coincident = SpikeTrain::new();
|
|
k_coincident.add_spike(0, 1); // Matches q at time 0
|
|
k_coincident.add_spike(5, 1); // Matches q at time 5
|
|
|
|
let mut k_no_match = SpikeTrain::new();
|
|
k_no_match.add_spike(1, 1); // No match
|
|
k_no_match.add_spike(3, 1); // No match
|
|
|
|
let mut v_train = SpikeTrain::new();
|
|
v_train.add_spike(0, 1);
|
|
|
|
let q_spikes = vec![q_train];
|
|
let k_spikes = vec![k_coincident, k_no_match];
|
|
let v_spikes = vec![v_train];
|
|
|
|
let output = attn.attention(&q_spikes, &k_spikes, &v_spikes);
|
|
|
|
// Should have stronger output due to coincident k0
|
|
assert_ne!(output[0], 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_polarity_interaction() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Test opposite polarities
|
|
let mut q_pos = SpikeTrain::new();
|
|
q_pos.add_spike(0, 1);
|
|
|
|
let mut k_pos = SpikeTrain::new();
|
|
k_pos.add_spike(0, 1); // Same polarity
|
|
|
|
let mut k_neg = SpikeTrain::new();
|
|
k_neg.add_spike(0, -1); // Opposite polarity
|
|
|
|
let mut v_train = SpikeTrain::new();
|
|
v_train.add_spike(0, 1);
|
|
|
|
// Test with same polarity
|
|
let q_spikes = vec![q_pos.clone()];
|
|
let k_spikes = vec![k_pos];
|
|
let v_spikes = vec![v_train.clone()];
|
|
|
|
let output_same = attn.attention(&q_spikes, &k_spikes, &v_spikes);
|
|
|
|
// Test with opposite polarity
|
|
let k_spikes_opp = vec![k_neg];
|
|
let output_opp = attn.attention(&q_spikes, &k_spikes_opp, &v_spikes);
|
|
|
|
// Opposite polarities should produce negative contribution
|
|
assert!(output_same[0] > 0);
|
|
assert!(output_opp[0] < 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_sparse_attention_top_k() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Create scenario with clear top-k positions
|
|
let mut q_spikes = vec![];
|
|
let mut k_spikes = vec![];
|
|
let mut v_spikes = vec![];
|
|
|
|
let mut q = SpikeTrain::new();
|
|
q.add_spike(0, 1);
|
|
q.add_spike(1, 1);
|
|
q_spikes.push(q);
|
|
|
|
// k0: strong match (2 coincidences)
|
|
let mut k0 = SpikeTrain::new();
|
|
k0.add_spike(0, 1);
|
|
k0.add_spike(1, 1);
|
|
k_spikes.push(k0);
|
|
|
|
// k1: weak match (1 coincidence)
|
|
let mut k1 = SpikeTrain::new();
|
|
k1.add_spike(0, 1);
|
|
k_spikes.push(k1);
|
|
|
|
// k2: no match
|
|
let mut k2 = SpikeTrain::new();
|
|
k2.add_spike(5, 1);
|
|
k_spikes.push(k2);
|
|
|
|
let mut v = SpikeTrain::new();
|
|
v.add_spike(0, 1);
|
|
v_spikes.push(v);
|
|
|
|
// Top-1 should only use strongest match (k0)
|
|
let output_top1 = attn.sparse_attention(&q_spikes, &k_spikes, &v_spikes, 1);
|
|
|
|
// Top-2 should use both k0 and k1
|
|
let output_top2 = attn.sparse_attention(&q_spikes, &k_spikes, &v_spikes, 2);
|
|
|
|
assert_eq!(output_top1.len(), 1);
|
|
assert_eq!(output_top2.len(), 1);
|
|
|
|
// Top-2 should have higher magnitude (more contributions)
|
|
assert!(output_top2[0].abs() >= output_top1[0].abs());
|
|
}
|
|
|
|
#[test]
|
|
fn test_energy_ratio_estimation() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Test various sequence lengths
|
|
let test_cases = vec![
|
|
(16, 64), // Small
|
|
(64, 256), // Medium
|
|
(128, 512), // Large
|
|
];
|
|
|
|
for (seq_len, hidden_dim) in test_cases {
|
|
let ratio = attn.energy_ratio(seq_len, hidden_dim);
|
|
|
|
// Should show significant energy savings
|
|
assert!(
|
|
ratio > 5.0,
|
|
"Energy ratio should be > 5x for ({}, {})",
|
|
seq_len,
|
|
hidden_dim
|
|
);
|
|
|
|
// Should be finite and positive
|
|
assert!(ratio.is_finite());
|
|
assert!(ratio > 0.0);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_energy_ratio_scaling() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
// Energy ratio should improve with larger sequences
|
|
// (more multiplications avoided)
|
|
let ratio_small = attn.energy_ratio(16, 64);
|
|
let ratio_large = attn.energy_ratio(128, 512);
|
|
|
|
assert!(
|
|
ratio_large > ratio_small,
|
|
"Energy savings should increase with size: small={}, large={}",
|
|
ratio_small,
|
|
ratio_large
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_binarization() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
let values = vec![-127, -50, -1, 0, 1, 50, 127];
|
|
let binary = attn.binarize(&values);
|
|
|
|
assert_eq!(binary.len(), values.len());
|
|
|
|
// All values should be in {-1, 0, 1}
|
|
for &b in &binary {
|
|
assert!(b >= -1 && b <= 1);
|
|
}
|
|
|
|
// Check specific mappings
|
|
assert_eq!(binary[0], -1); // negative -> -1
|
|
assert_eq!(binary[3], 0); // zero -> 0
|
|
assert_eq!(binary[6], 1); // positive -> 1
|
|
}
|
|
|
|
#[test]
|
|
fn test_end_to_end_encoding_and_attention() {
|
|
let config = SpikeDrivenConfig {
|
|
spike_threshold_q15: 16384,
|
|
temporal_coding_steps: 8,
|
|
binary_qkv: true,
|
|
refractory_period: 2,
|
|
};
|
|
let attn = SpikeDrivenAttention::new(config);
|
|
|
|
// Simulate simple sequence: [high, medium, low, zero]
|
|
let q_values = vec![100i8, 50, 25, 0];
|
|
let k_values = vec![100i8, 50, 25, 0];
|
|
let v_values = vec![64i8, 32, 16, 8];
|
|
|
|
// Encode to spike trains
|
|
let q_spikes = attn.encode_spikes(&q_values);
|
|
let k_spikes = attn.encode_spikes(&k_values);
|
|
let v_spikes = attn.encode_spikes(&v_values);
|
|
|
|
// Compute attention
|
|
let output = attn.attention(&q_spikes, &k_spikes, &v_spikes);
|
|
|
|
assert_eq!(output.len(), v_values.len());
|
|
|
|
// Output should be non-zero for positions with spike activity
|
|
assert!(output.iter().any(|&x| x != 0));
|
|
}
|
|
|
|
#[test]
|
|
fn test_zero_length_sequences() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
let empty_spikes: Vec<SpikeTrain> = vec![];
|
|
let output = attn.attention(&empty_spikes, &empty_spikes, &empty_spikes);
|
|
|
|
assert_eq!(output.len(), 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_mismatched_dimensions() {
|
|
let attn = SpikeDrivenAttention::default();
|
|
|
|
let mut q_spikes = vec![SpikeTrain::new()];
|
|
let k_spikes = vec![SpikeTrain::new(), SpikeTrain::new()];
|
|
let v_spikes = vec![SpikeTrain::new()];
|
|
|
|
q_spikes[0].add_spike(0, 1);
|
|
|
|
// Should handle mismatched dimensions gracefully
|
|
let output = attn.attention(&q_spikes, &k_spikes, &v_spikes);
|
|
|
|
assert_eq!(output.len(), 1);
|
|
}
|
|
|
|
#[test]
|
|
fn test_high_temporal_resolution() {
|
|
let config = SpikeDrivenConfig {
|
|
spike_threshold_q15: 8192,
|
|
temporal_coding_steps: 32, // High temporal resolution
|
|
binary_qkv: false,
|
|
refractory_period: 1,
|
|
};
|
|
let temporal_coding_steps = config.temporal_coding_steps;
|
|
let attn = SpikeDrivenAttention::new(config);
|
|
|
|
let values = vec![127i8];
|
|
let trains = attn.encode_spikes(&values);
|
|
|
|
// Should produce more spikes with higher temporal resolution
|
|
assert!(trains[0].len() > 0);
|
|
|
|
// All spike times should be within temporal window
|
|
for &time in &trains[0].times {
|
|
assert!(time < temporal_coding_steps);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_config_variations() {
|
|
// Test different configurations
|
|
let configs = vec![
|
|
SpikeDrivenConfig {
|
|
spike_threshold_q15: 8192,
|
|
temporal_coding_steps: 4,
|
|
binary_qkv: true,
|
|
refractory_period: 1,
|
|
},
|
|
SpikeDrivenConfig {
|
|
spike_threshold_q15: 24576,
|
|
temporal_coding_steps: 16,
|
|
binary_qkv: false,
|
|
refractory_period: 3,
|
|
},
|
|
];
|
|
|
|
for config in configs {
|
|
let temporal_coding_steps = config.temporal_coding_steps;
|
|
let attn = SpikeDrivenAttention::new(config);
|
|
|
|
let values = vec![64i8, -64];
|
|
let trains = attn.encode_spikes(&values);
|
|
|
|
assert_eq!(trains.len(), 2);
|
|
|
|
// Basic sanity checks
|
|
for train in &trains {
|
|
for &time in &train.times {
|
|
assert!(time < temporal_coding_steps);
|
|
}
|
|
for &polarity in &train.polarities {
|
|
assert!(polarity == 1 || polarity == -1);
|
|
}
|
|
}
|
|
}
|
|
}
|