//! Integration tests for ruvector-attention-unified-wasm
//!
//! Tests for unified attention mechanisms including:
//! - Multi-head self-attention
//! - Mamba SSM (Selective State Space Model)
//! - RWKV attention
//! - Flash attention approximation
//! - Hyperbolic attention

#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;
    use super::super::common::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // ========================================================================
    // Multi-Head Attention Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_multi_head_attention_basic() {
        // Setup query, keys, values
        let dim = 64;
        let num_heads = 8;
        let head_dim = dim / num_heads;
        let seq_len = 16;

        let query = random_vector(dim);
        let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: When ruvector-attention-unified-wasm is implemented:
        // let attention = MultiHeadAttention::new(dim, num_heads);
        // let output = attention.forward(&query, &keys, &values);
        //
        // // Assert output shape
        // assert_eq!(output.len(), dim);
        // assert_finite(&output);

        // Placeholder assertion
        assert_eq!(query.len(), dim);
        assert_eq!(keys.len(), seq_len);
    }

    #[wasm_bindgen_test]
    fn test_multi_head_attention_output_shape() {
        let dim = 128;
        let num_heads = 16;
        let seq_len = 32;

        let queries: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: Verify output shape matches (seq_len, dim)
        // let attention = MultiHeadAttention::new(dim, num_heads);
        // let outputs = attention.forward_batch(&queries, &keys, &values);
        // assert_eq!(outputs.len(), seq_len);
        // for output in &outputs {
        //     assert_eq!(output.len(), dim);
        //     assert_finite(output);
        // }

        assert_eq!(queries.len(), seq_len);
    }

    #[wasm_bindgen_test]
    fn test_multi_head_attention_causality() {
        // Test that causal masking works correctly
        let dim = 32;
        let seq_len = 8;

        // TODO: Verify causal attention doesn't attend to future tokens
        // let attention = MultiHeadAttention::new_causal(dim, 4);
        // let weights = attention.get_attention_weights(&queries, &keys);
        //
        // // For each position i, weights[i][j] should be 0 for j > i
        // for i in 0..seq_len {
        //     for j in (i+1)..seq_len {
        //         assert_eq!(weights[i][j], 0.0, "Causal violation at ({}, {})", i, j);
        //     }
        // }

        assert!(dim > 0);
    }
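    // ------------------------------------------------------------------------
    // Reference sketch (NOT part of the ruvector-attention-unified-wasm API):
    // a minimal single-query, single-head scaled dot-product attention, shown
    // only to illustrate what the MultiHeadAttention tests above would verify;
    // a multi-head implementation would apply this per head and concatenate.
    // The function name and signature are assumptions for illustration.
    // ------------------------------------------------------------------------
    #[allow(dead_code)]
    fn reference_scaled_dot_product_attention(
        query: &[f32],
        keys: &[Vec<f32>],
        values: &[Vec<f32>],
    ) -> Vec<f32> {
        let dim = query.len();
        let scale = 1.0 / (dim as f32).sqrt();

        // Scaled dot-product scores: q . k_i / sqrt(d)
        let scores: Vec<f32> = keys
            .iter()
            .map(|k| query.iter().zip(k).map(|(q, k)| q * k).sum::<f32>() * scale)
            .collect();

        // Numerically stable softmax (subtract the max before exponentiating)
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
        let sum: f32 = exps.iter().sum();

        // Weighted sum of the value vectors
        let mut output = vec![0.0f32; values[0].len()];
        for (w, v) in exps.iter().zip(values) {
            for (o, x) in output.iter_mut().zip(v) {
                *o += (w / sum) * x;
            }
        }
        output
    }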
    // ========================================================================
    // Mamba SSM Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_mamba_ssm_basic() {
        // Test O(n) selective scan complexity
        let dim = 64;
        let seq_len = 100;

        let input: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: When Mamba SSM is implemented:
        // let mamba = MambaSSM::new(dim);
        // let output = mamba.forward(&input);
        //
        // // Assert O(n) complexity by timing
        // let start = performance.now();
        // mamba.forward(&input);
        // let duration = performance.now() - start;
        //
        // // Doubling the input size should roughly double the time (O(n))
        // let input_2x = (0..seq_len * 2).map(|_| random_vector(dim)).collect();
        // let start_2x = performance.now();
        // mamba.forward(&input_2x);
        // let duration_2x = performance.now() - start_2x;
        //
        // assert!(duration_2x < duration * 2.5, "Should be O(n), not O(n^2)");

        assert_eq!(input.len(), seq_len);
    }

    #[wasm_bindgen_test]
    fn test_mamba_ssm_selective_scan() {
        // Test the selective scan mechanism
        let dim = 32;
        let seq_len = 50;

        let input: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: Verify selective scan produces valid outputs
        // let mamba = MambaSSM::new(dim);
        // let (output, hidden_states) = mamba.forward_with_states(&input);
        //
        // // Hidden states should evolve based on the input
        // for state in &hidden_states {
        //     assert_finite(state);
        // }

        assert!(dim > 0);
    }

    #[wasm_bindgen_test]
    fn test_mamba_ssm_state_propagation() {
        // Test that state is properly propagated across the sequence
        let dim = 16;

        // TODO: Create a simple pattern and verify the state carries information
        // let mamba = MambaSSM::new(dim);
        //
        // // Input with a spike at position 0
        // let mut input = vec![vec![0.0; dim]; 20];
        // input[0] = vec![1.0; dim];
        //
        // let output = mamba.forward(&input);
        //
        // // Later positions should still have some response to the spike
        // let response_at_5: f32 = output[5].iter().map(|x| x.abs()).sum();
        // assert!(response_at_5 > 0.01, "State should propagate forward");

        assert!(dim > 0);
    }

    // ========================================================================
    // RWKV Attention Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_rwkv_attention_basic() {
        let dim = 64;
        let seq_len = 100;

        let input: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: Test RWKV linear attention
        // let rwkv = RWKVAttention::new(dim);
        // let output = rwkv.forward(&input);
        // assert_eq!(output.len(), seq_len);

        assert_eq!(input.len(), seq_len);
    }

    #[wasm_bindgen_test]
    fn test_rwkv_linear_complexity() {
        // RWKV should be O(n) in sequence length
        let dim = 32;

        // TODO: Verify linear complexity
        // let rwkv = RWKVAttention::new(dim);
        //
        // // Time with 100 tokens
        // let input_100 = (0..100).map(|_| random_vector(dim)).collect();
        // let t1 = time_execution(|| rwkv.forward(&input_100));
        //
        // // Time with 1000 tokens
        // let input_1000 = (0..1000).map(|_| random_vector(dim)).collect();
        // let t2 = time_execution(|| rwkv.forward(&input_1000));
        //
        // // Should be roughly 10x slower, not 100x (O(n) vs O(n^2))
        // assert!(t2 < t1 * 20.0, "RWKV should be O(n)");

        assert!(dim > 0);
    }

    // ========================================================================
    // Flash Attention Approximation Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_flash_attention_approximation() {
        let dim = 64;
        let seq_len = 128;

        let queries: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: Compare flash attention to standard attention
        // let standard = StandardAttention::new(dim);
        // let flash = FlashAttention::new(dim);
        //
        // let output_standard = standard.forward(&queries, &keys, &values);
        // let output_flash = flash.forward(&queries, &keys, &values);
        //
        // // Should be numerically close
        // for (std_out, flash_out) in output_standard.iter().zip(output_flash.iter()) {
        //     assert_vectors_approx_eq(std_out, flash_out, 1e-4);
        // }

        assert_eq!(queries.len(), seq_len);
    }

    #[wasm_bindgen_test]
    fn test_flash_attention_memory_efficiency() {
        // Flash attention should use less memory for long sequences
        let dim = 64;
        let seq_len = 512;

        // TODO: Verify memory usage is O(n), not O(n^2).
        // This is harder to test in WASM, but we can at least verify it doesn't OOM.

        assert!(seq_len > 0);
    }
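    // ------------------------------------------------------------------------
    // Reference sketch (assumption, not the crate's API): a toy linear-time
    // recurrent scan in the spirit of the Mamba/RWKV tests above. It is NOT
    // Mamba's input-dependent selective scan nor RWKV's wkv recurrence; it only
    // illustrates the cost shape the complexity tests look for: each token
    // touches the running state exactly once, so work grows linearly with
    // sequence length.
    // ------------------------------------------------------------------------
    #[allow(dead_code)]
    fn reference_linear_scan(input: &[Vec<f32>], decay: f32) -> Vec<Vec<f32>> {
        let dim = input.first().map_or(0, |x| x.len());
        let mut state = vec![0.0f32; dim];
        let mut outputs = Vec::with_capacity(input.len());

        for x in input {
            // One O(dim) state update per token: exponential-decay accumulation.
            for (s, xi) in state.iter_mut().zip(x) {
                *s = decay * *s + (1.0 - decay) * xi;
            }
            outputs.push(state.clone());
        }
        outputs
    }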
    // ========================================================================
    // Hyperbolic Attention Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_hyperbolic_attention_basic() {
        let dim = 32;
        let curvature = -1.0;

        let query = random_vector(dim);
        let keys: Vec<Vec<f32>> = (0..10).map(|_| random_vector(dim)).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|_| random_vector(dim)).collect();

        // TODO: Test hyperbolic attention
        // let hyp_attn = HyperbolicAttention::new(dim, curvature);
        // let output = hyp_attn.forward(&query, &keys, &values);
        //
        // assert_eq!(output.len(), dim);
        // assert_finite(&output);

        assert!(curvature < 0.0);
    }

    #[wasm_bindgen_test]
    fn test_hyperbolic_distance_properties() {
        // Test Poincare distance metric properties
        let dim = 8;

        let u = random_vector(dim);
        let v = random_vector(dim);

        // TODO: Verify metric properties
        // let d_uv = poincare_distance(&u, &v, 1.0);
        // let d_vu = poincare_distance(&v, &u, 1.0);
        //
        // // Symmetry
        // assert!((d_uv - d_vu).abs() < 1e-6);
        //
        // // Non-negativity
        // assert!(d_uv >= 0.0);
        //
        // // Identity
        // let d_uu = poincare_distance(&u, &u, 1.0);
        // assert!(d_uu.abs() < 1e-6);

        assert!(dim > 0);
    }
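    // ------------------------------------------------------------------------
    // Reference sketch (assumption): Poincare-ball distance with curvature
    // fixed at -1,
    //     d(u, v) = arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2))),
    // illustrating the symmetry, non-negativity, and identity properties the
    // hyperbolic distance test above would check. The crate's poincare_distance
    // presumably also takes a curvature argument; this simplified form assumes
    // inputs lie strictly inside the unit ball.
    // ------------------------------------------------------------------------
    #[allow(dead_code)]
    fn reference_poincare_distance(u: &[f32], v: &[f32]) -> f32 {
        let sq_norm = |x: &[f32]| x.iter().map(|xi| xi * xi).sum::<f32>();
        let diff_sq: f32 = u.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();

        let denom = (1.0 - sq_norm(u)) * (1.0 - sq_norm(v));
        let arg = 1.0 + 2.0 * diff_sq / denom.max(1e-12);

        // arcosh(x) = ln(x + sqrt(x^2 - 1)), defined for x >= 1
        (arg + (arg * arg - 1.0).max(0.0).sqrt()).ln()
    }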
    // ========================================================================
    // Unified Interface Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_attention_mechanism_registry() {
        // Test that all mechanisms can be accessed through the unified interface

        // TODO: Test mechanism registry
        // let registry = AttentionRegistry::new();
        //
        // assert!(registry.has_mechanism("multi_head"));
        // assert!(registry.has_mechanism("mamba_ssm"));
        // assert!(registry.has_mechanism("rwkv"));
        // assert!(registry.has_mechanism("flash"));
        // assert!(registry.has_mechanism("hyperbolic"));

        assert!(true);
    }

    #[wasm_bindgen_test]
    fn test_attention_factory() {
        // Test creating different attention types through a factory

        // TODO: Test factory pattern
        // let factory = AttentionFactory::new();
        //
        // let config = AttentionConfig {
        //     dim: 64,
        //     num_heads: 8,
        //     mechanism: "multi_head",
        // };
        //
        // let attention = factory.create(&config);
        // assert!(attention.is_some());

        assert!(true);
    }

    // ========================================================================
    // Numerical Stability Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_attention_numerical_stability_large_values() {
        let dim = 32;

        // Test with large input values
        let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 100.0).collect();
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![(i as f32) * 100.0; dim]).collect();

        // TODO: Should not overflow or produce NaN
        // let attention = MultiHeadAttention::new(dim, 4);
        // let output = attention.forward(&query, &keys, &values);
        // assert_finite(&output);

        assert!(query[0].is_finite());
    }

    #[wasm_bindgen_test]
    fn test_attention_numerical_stability_small_values() {
        let dim = 32;

        // Test with very small input values
        let query: Vec<f32> = vec![1e-10; dim];
        let keys: Vec<Vec<f32>> = (0..10).map(|_| vec![1e-10; dim]).collect();

        // TODO: Should not underflow or produce NaN
        // let attention = MultiHeadAttention::new(dim, 4);
        // let output = attention.forward(&query, &keys, &values);
        // assert_finite(&output);

        assert!(query[0].is_finite());
    }

    // ========================================================================
    // Performance Constraint Tests
    // ========================================================================

    #[wasm_bindgen_test]
    fn test_attention_latency_target() {
        // Target: <100 microseconds per mechanism at 100 tokens
        let dim = 64;
        let seq_len = 100;

        let queries: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let keys: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();
        let values: Vec<Vec<f32>> = (0..seq_len).map(|_| random_vector(dim)).collect();

        // TODO: Measure latency when implemented
        // let attention = MultiHeadAttention::new(dim, 8);
        //
        // // Warm up
        // attention.forward(&queries[0], &keys, &values);
        //
        // // Measure
        // let start = performance.now();
        // for _ in 0..100 {
        //     attention.forward(&queries[0], &keys, &values);
        // }
        // let avg_latency_us = (performance.now() - start) * 10.0; // ms over 100 runs -> us per run
        //
        // assert!(avg_latency_us < 100.0, "Latency {} us exceeds 100 us target", avg_latency_us);

        assert_eq!(queries.len(), seq_len);
    }
}