397 lines
14 KiB
Rust
397 lines
14 KiB
Rust
//! Integration tests for ruvector-attention-unified-wasm
|
|
//!
|
|
//! Tests for unified attention mechanisms including:
|
|
//! - Multi-head self-attention
|
|
//! - Mamba SSM (Selective State Space Model)
|
|
//! - RWKV attention
|
|
//! - Flash attention approximation
|
|
//! - Hyperbolic attention
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use wasm_bindgen_test::*;
|
|
use super::super::common::*;
|
|
|
|
wasm_bindgen_test_configure!(run_in_browser);
|
|
|
|
// ========================================================================
|
|
// Multi-Head Attention Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_multi_head_attention_basic() {
|
|
// Setup query, keys, values
|
|
let dim = 64;
|
|
let num_heads = 8;
|
|
let head_dim = dim / num_heads;
|
|
let seq_len = 16;
|
|
|
|
let query = random_vector(dim);
|
|
let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: When ruvector-attention-unified-wasm is implemented:
|
|
// let attention = MultiHeadAttention::new(dim, num_heads);
|
|
// let output = attention.forward(&query, &keys, &values);
|
|
//
|
|
// Assert output shape
|
|
// assert_eq!(output.len(), dim);
|
|
// assert_finite(&output);
|
|
|
|
// Placeholder assertion
|
|
assert_eq!(query.len(), dim);
|
|
assert_eq!(keys.len(), seq_len);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_multi_head_attention_output_shape() {
|
|
let dim = 128;
|
|
let num_heads = 16;
|
|
let seq_len = 32;
|
|
|
|
let queries: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: Verify output shape matches (seq_len, dim)
|
|
// let attention = MultiHeadAttention::new(dim, num_heads);
|
|
// let outputs = attention.forward_batch(&queries, &keys, &values);
|
|
// assert_eq!(outputs.len(), seq_len);
|
|
// for output in &outputs {
|
|
// assert_eq!(output.len(), dim);
|
|
// assert_finite(output);
|
|
// }
|
|
|
|
assert_eq!(queries.len(), seq_len);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_multi_head_attention_causality() {
|
|
// Test that causal masking works correctly
|
|
let dim = 32;
|
|
let seq_len = 8;
|
|
|
|
// TODO: Verify causal attention doesn't attend to future tokens
|
|
// let attention = MultiHeadAttention::new_causal(dim, 4);
|
|
// let weights = attention.get_attention_weights(&queries, &keys);
|
|
//
|
|
// For each position i, weights[i][j] should be 0 for j > i
|
|
// for i in 0..seq_len {
|
|
// for j in (i+1)..seq_len {
|
|
// assert_eq!(weights[i][j], 0.0, "Causal violation at ({}, {})", i, j);
|
|
// }
|
|
// }
|
|
|
|
assert!(dim > 0);
|
|
}
|
|
|
|
// ========================================================================
|
|
// Mamba SSM Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_mamba_ssm_basic() {
|
|
// Test O(n) selective scan complexity
|
|
let dim = 64;
|
|
let seq_len = 100;
|
|
|
|
let input: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: When Mamba SSM is implemented:
|
|
// let mamba = MambaSSM::new(dim);
|
|
// let output = mamba.forward(&input);
|
|
//
|
|
// Assert O(n) complexity by timing
|
|
// let start = performance.now();
|
|
// mamba.forward(&input);
|
|
// let duration = performance.now() - start;
|
|
//
|
|
// Double input size should roughly double time (O(n))
|
|
// let input_2x = (0..seq_len*2).map(|_| random_vector(dim)).collect();
|
|
// let start_2x = performance.now();
|
|
// mamba.forward(&input_2x);
|
|
// let duration_2x = performance.now() - start_2x;
|
|
//
|
|
// assert!(duration_2x < duration * 2.5, "Should be O(n) not O(n^2)");
|
|
|
|
assert_eq!(input.len(), seq_len);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_mamba_ssm_selective_scan() {
|
|
// Test the selective scan mechanism
|
|
let dim = 32;
|
|
let seq_len = 50;
|
|
|
|
let input: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: Verify selective scan produces valid outputs
|
|
// let mamba = MambaSSM::new(dim);
|
|
// let (output, hidden_states) = mamba.forward_with_states(&input);
|
|
//
|
|
// Hidden states should evolve based on input
|
|
// for state in &hidden_states {
|
|
// assert_finite(state);
|
|
// }
|
|
|
|
assert!(dim > 0);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_mamba_ssm_state_propagation() {
|
|
// Test that state is properly propagated across sequence
|
|
let dim = 16;
|
|
|
|
// TODO: Create a simple pattern and verify state carries information
|
|
// let mamba = MambaSSM::new(dim);
|
|
//
|
|
// Input with a spike at position 0
|
|
// let mut input = vec![vec![0.0; dim]; 20];
|
|
// input[0] = vec![1.0; dim];
|
|
//
|
|
// let output = mamba.forward(&input);
|
|
//
|
|
// Later positions should still have some response to the spike
|
|
// let response_at_5: f32 = output[5].iter().map(|x| x.abs()).sum();
|
|
// assert!(response_at_5 > 0.01, "State should propagate forward");
|
|
|
|
assert!(dim > 0);
|
|
}
|
|
|
|
// ========================================================================
|
|
// RWKV Attention Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_rwkv_attention_basic() {
|
|
let dim = 64;
|
|
let seq_len = 100;
|
|
|
|
let input: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: Test RWKV linear attention
|
|
// let rwkv = RWKVAttention::new(dim);
|
|
// let output = rwkv.forward(&input);
|
|
// assert_eq!(output.len(), seq_len);
|
|
|
|
assert!(input.len() == seq_len);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_rwkv_linear_complexity() {
|
|
// RWKV should be O(n) in sequence length
|
|
let dim = 32;
|
|
|
|
// TODO: Verify linear complexity
|
|
// let rwkv = RWKVAttention::new(dim);
|
|
//
|
|
// Time with 100 tokens
|
|
// let input_100 = (0..100).map(|_| random_vector(dim)).collect();
|
|
// let t1 = time_execution(|| rwkv.forward(&input_100));
|
|
//
|
|
// Time with 1000 tokens
|
|
// let input_1000 = (0..1000).map(|_| random_vector(dim)).collect();
|
|
// let t2 = time_execution(|| rwkv.forward(&input_1000));
|
|
//
|
|
// Should be roughly 10x, not 100x (O(n) vs O(n^2))
|
|
// assert!(t2 < t1 * 20.0, "RWKV should be O(n)");
|
|
|
|
assert!(dim > 0);
|
|
}
|
|
|
|
// ========================================================================
|
|
// Flash Attention Approximation Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_flash_attention_approximation() {
|
|
let dim = 64;
|
|
let seq_len = 128;
|
|
|
|
let queries: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: Compare flash attention to standard attention
|
|
// let standard = StandardAttention::new(dim);
|
|
// let flash = FlashAttention::new(dim);
|
|
//
|
|
// let output_standard = standard.forward(&queries, &keys, &values);
|
|
// let output_flash = flash.forward(&queries, &keys, &values);
|
|
//
|
|
// Should be numerically close
|
|
// for (std_out, flash_out) in output_standard.iter().zip(output_flash.iter()) {
|
|
// assert_vectors_approx_eq(std_out, flash_out, 1e-4);
|
|
// }
|
|
|
|
assert!(queries.len() == seq_len);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_flash_attention_memory_efficiency() {
|
|
// Flash attention should use less memory for long sequences
|
|
let dim = 64;
|
|
let seq_len = 512;
|
|
|
|
// TODO: Verify memory usage is O(n) not O(n^2)
|
|
// This is harder to test in WASM, but we can verify it doesn't OOM
|
|
|
|
assert!(seq_len > 0);
|
|
}
|
|
|
|
// ========================================================================
|
|
// Hyperbolic Attention Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_hyperbolic_attention_basic() {
|
|
let dim = 32;
|
|
let curvature = -1.0;
|
|
|
|
let query = random_vector(dim);
|
|
let keys: Vec<Vec<f32>> = (0..10).map(|_| random_vector(dim)).collect();
|
|
let values: Vec<Vec<f32>> = (0..10).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: Test hyperbolic attention
|
|
// let hyp_attn = HyperbolicAttention::new(dim, curvature);
|
|
// let output = hyp_attn.forward(&query, &keys, &values);
|
|
//
|
|
// assert_eq!(output.len(), dim);
|
|
// assert_finite(&output);
|
|
|
|
assert!(curvature < 0.0);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_hyperbolic_distance_properties() {
|
|
// Test Poincare distance metric properties
|
|
let dim = 8;
|
|
|
|
let u = random_vector(dim);
|
|
let v = random_vector(dim);
|
|
|
|
// TODO: Verify metric properties
|
|
// let d_uv = poincare_distance(&u, &v, 1.0);
|
|
// let d_vu = poincare_distance(&v, &u, 1.0);
|
|
//
|
|
// Symmetry
|
|
// assert!((d_uv - d_vu).abs() < 1e-6);
|
|
//
|
|
// Non-negativity
|
|
// assert!(d_uv >= 0.0);
|
|
//
|
|
// Identity
|
|
// let d_uu = poincare_distance(&u, &u, 1.0);
|
|
// assert!(d_uu.abs() < 1e-6);
|
|
|
|
assert!(dim > 0);
|
|
}
|
|
|
|
// ========================================================================
|
|
// Unified Interface Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_attention_mechanism_registry() {
|
|
// Test that all mechanisms can be accessed through unified interface
|
|
|
|
// TODO: Test mechanism registry
|
|
// let registry = AttentionRegistry::new();
|
|
//
|
|
// assert!(registry.has_mechanism("multi_head"));
|
|
// assert!(registry.has_mechanism("mamba_ssm"));
|
|
// assert!(registry.has_mechanism("rwkv"));
|
|
// assert!(registry.has_mechanism("flash"));
|
|
// assert!(registry.has_mechanism("hyperbolic"));
|
|
|
|
assert!(true);
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_attention_factory() {
|
|
// Test creating different attention types through factory
|
|
|
|
// TODO: Test factory pattern
|
|
// let factory = AttentionFactory::new();
|
|
//
|
|
// let config = AttentionConfig {
|
|
// dim: 64,
|
|
// num_heads: 8,
|
|
// mechanism: "multi_head",
|
|
// };
|
|
//
|
|
// let attention = factory.create(&config);
|
|
// assert!(attention.is_some());
|
|
|
|
assert!(true);
|
|
}
|
|
|
|
// ========================================================================
|
|
// Numerical Stability Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_attention_numerical_stability_large_values() {
|
|
let dim = 32;
|
|
|
|
// Test with large input values
|
|
let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 100.0).collect();
|
|
let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![(i as f32) * 100.0; dim]).collect();
|
|
|
|
// TODO: Should not overflow or produce NaN
|
|
// let attention = MultiHeadAttention::new(dim, 4);
|
|
// let output = attention.forward(&query, &keys, &values);
|
|
// assert_finite(&output);
|
|
|
|
assert!(query[0].is_finite());
|
|
}
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_attention_numerical_stability_small_values() {
|
|
let dim = 32;
|
|
|
|
// Test with very small input values
|
|
let query: Vec<f32> = vec![1e-10; dim];
|
|
let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![1e-10; dim]).collect();
|
|
|
|
// TODO: Should not underflow or produce NaN
|
|
// let attention = MultiHeadAttention::new(dim, 4);
|
|
// let output = attention.forward(&query, &keys, &values);
|
|
// assert_finite(&output);
|
|
|
|
assert!(query[0].is_finite());
|
|
}
|
|
|
|
// ========================================================================
|
|
// Performance Constraint Tests
|
|
// ========================================================================
|
|
|
|
#[wasm_bindgen_test]
|
|
fn test_attention_latency_target() {
|
|
// Target: <100 microseconds per mechanism at 100 tokens
|
|
let dim = 64;
|
|
let seq_len = 100;
|
|
|
|
let queries: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
|
|
|
|
// TODO: Measure latency when implemented
|
|
// let attention = MultiHeadAttention::new(dim, 8);
|
|
//
|
|
// Warm up
|
|
// attention.forward(&queries[0], &keys, &values);
|
|
//
|
|
// Measure
|
|
// let start = performance.now();
|
|
// for _ in 0..100 {
|
|
// attention.forward(&queries[0], &keys, &values);
|
|
// }
|
|
// let avg_latency_us = (performance.now() - start) * 10.0; // 100 runs -> us
|
|
//
|
|
// assert!(avg_latency_us < 100.0, "Latency {} us exceeds 100 us target", avg_latency_us);
|
|
|
|
assert!(queries.len() == seq_len);
|
|
}
|
|
}
|