Files
wifi-densepose/docs/research/latent-space/implementation-plans/agents/11-unit-tests.md
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

69 KiB

Agent 11: Comprehensive Unit Testing Suite

Agent: QA Testing Specialist Status: Implementation Ready Dependencies: All attention mechanism implementations (Agents 1-6) Target Directory: tests/unit/

Overview

This document provides complete unit test specifications for all attention mechanisms in the ruvector-attention crate. Tests cover functionality, edge cases, numerical stability, gradient correctness, and serialization, with property-based testing for mathematical invariants.

Test Organization Structure

tests/
├── unit/
│   ├── mod.rs                          # Test module aggregator
│   ├── scaled_dot_product_tests.rs     # Agent 1 tests
│   ├── multi_head_tests.rs             # Agent 3 tests
│   ├── hyperbolic_tests.rs             # Agent 2 tests
│   ├── sparse_tests.rs                 # Agent 4 tests
│   ├── graph_tests.rs                  # Agent 6 tests
│   ├── moe_tests.rs                    # Agent 5 tests
│   └── test_utils.rs                   # Shared utilities
└── fixtures/
    ├── sample_data.rs
    └── numerical_gradients.rs

1. Test Utilities and Fixtures

1.1 Test Utilities (tests/unit/test_utils.rs)

//! Common test utilities for attention mechanisms

use ndarray::{Array1, Array2, ArrayView1};
use approx::AbsDiffEq;

pub const EPSILON: f32 = 1e-6;
pub const GRAD_EPSILON: f32 = 1e-4;

/// Generate random array with deterministic seed for reproducibility
pub fn random_array1(size: usize, seed: u64) -> Array1<f32> {
    use rand::{SeedableRng, Rng};
    use rand::rngs::StdRng;

    let mut rng = StdRng::seed_from_u64(seed);
    Array1::from_shape_fn(size, |_| rng.gen_range(-1.0..1.0))
}

pub fn random_array2(shape: (usize, usize), seed: u64) -> Array2<f32> {
    use rand::{SeedableRng, Rng};
    use rand::rngs::StdRng;

    let mut rng = StdRng::seed_from_u64(seed);
    Array2::from_shape_fn(shape, |_| rng.gen_range(-1.0..1.0))
}

/// Numerical gradient computation for gradient checking
pub fn numerical_gradient<F>(
    f: F,
    x: &Array1<f32>,
    idx: usize,
) -> f32
where
    F: Fn(&Array1<f32>) -> f32,
{
    let mut x_plus = x.clone();
    let mut x_minus = x.clone();

    x_plus[idx] += GRAD_EPSILON;
    x_minus[idx] -= GRAD_EPSILON;

    (f(&x_plus) - f(&x_minus)) / (2.0 * GRAD_EPSILON)
}

/// Check if two arrays are approximately equal
pub fn assert_arrays_close(a: &Array1<f32>, b: &Array1<f32>, tolerance: f32) {
    assert_eq!(a.len(), b.len(), "Array lengths must match");

    for (i, (a_val, b_val)) in a.iter().zip(b.iter()).enumerate() {
        assert!(
            (a_val - b_val).abs() < tolerance,
            "Arrays differ at index {}: {} vs {} (diff: {})",
            i, a_val, b_val, (a_val - b_val).abs()
        );
    }
}

/// Verify attention weights sum to 1.0
pub fn verify_attention_weights(weights: &Array1<f32>, tolerance: f32) {
    let sum: f32 = weights.iter().sum();
    assert!(
        (sum - 1.0).abs() < tolerance,
        "Attention weights should sum to 1.0, got: {}",
        sum
    );
}

/// Verify all weights are non-negative
pub fn verify_non_negative(weights: &Array1<f32>) {
    for (i, w) in weights.iter().enumerate() {
        assert!(
            *w >= 0.0,
            "Weight at index {} is negative: {}",
            i, w
        );
    }
}

/// Generate normalized random vector
pub fn random_normalized_vector(size: usize, seed: u64) -> Array1<f32> {
    let vec = random_array1(size, seed);
    let norm = vec.mapv(|x| x * x).sum().sqrt();
    vec / norm
}

/// Create one-hot encoded vector
pub fn one_hot(size: usize, idx: usize) -> Array1<f32> {
    let mut vec = Array1::zeros(size);
    vec[idx] = 1.0;
    vec
}

/// Relative error between two values
pub fn relative_error(a: f32, b: f32) -> f32 {
    if a.abs() < EPSILON && b.abs() < EPSILON {
        0.0
    } else {
        (a - b).abs() / (a.abs() + b.abs()).max(EPSILON)
    }
}

1.2 Test Fixtures (tests/fixtures/sample_data.rs)

//! Sample data fixtures for testing

use ndarray::{Array1, Array2};

pub struct AttentionTestData {
    pub query: Array1<f32>,
    pub keys: Array2<f32>,
    pub values: Array2<f32>,
    pub expected_weights: Option<Array1<f32>>,
}

impl AttentionTestData {
    /// Simple test case: query matches first key exactly
    pub fn exact_match() -> Self {
        let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let keys = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0,  // Exact match
                0.0, 1.0, 0.0, 0.0,  // Orthogonal
                0.0, 0.0, 1.0, 0.0,  // Orthogonal
            ],
        ).unwrap();
        let values = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0,
                0.0, 1.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0,
            ],
        ).unwrap();

        Self { query, keys, values, expected_weights: None }
    }

    /// Uniform case: all keys equally similar to query
    pub fn uniform() -> Self {
        let dim = 4;
        let n_keys = 3;
        let query = Array1::from_elem(dim, 0.5);
        let keys = Array2::from_elem((n_keys, dim), 0.5);
        let values = Array2::from_elem((n_keys, dim), 1.0);

        Self { query, keys, values, expected_weights: None }
    }

    /// Empty case: no keys
    pub fn empty() -> Self {
        let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let keys = Array2::zeros((0, 4));
        let values = Array2::zeros((0, 4));

        Self { query, keys, values, expected_weights: None }
    }

    /// Single key
    pub fn single_key() -> Self {
        let query = Array1::from_vec(vec![1.0, 0.5, 0.0, 0.0]);
        let keys = Array2::from_shape_vec((1, 4), vec![0.5, 1.0, 0.0, 0.0]).unwrap();
        let values = Array2::from_shape_vec((1, 4), vec![2.0, 3.0, 0.0, 0.0]).unwrap();

        Self { query, keys, values, expected_weights: Some(Array1::from_vec(vec![1.0])) }
    }
}

2. Scaled Dot-Product Attention Tests

File: tests/unit/scaled_dot_product_tests.rs

use ruvector_attention::scaled_dot_product::ScaledDotProduct;
use ndarray::{Array1, Array2};
use approx::assert_relative_eq;
use proptest::prelude::*;

mod test_utils;
use test_utils::*;

#[cfg(test)]
mod basic_functionality {
    use super::*;

    #[test]
    fn test_exact_match_attention() {
        let attention = ScaledDotProduct::new(4);
        let data = AttentionTestData::exact_match();

        let output = attention.forward(&data.query, &data.keys, &data.values)
            .expect("Forward pass failed");

        // Output should be dominated by first value (exact match)
        assert!(output[0] > 0.8, "First dimension should dominate: {}", output[0]);
    }

    #[test]
    fn test_attention_weights_sum_to_one() {
        let attention = ScaledDotProduct::new(4);
        let query = random_array1(4, 42);
        let keys = random_array2((5, 4), 123);

        let weights = attention.compute_weights(&query, &keys)
            .expect("Weight computation failed");

        verify_attention_weights(&weights, EPSILON);
        verify_non_negative(&weights);
    }

    #[test]
    fn test_dimension_preservation() {
        let dim = 8;
        let n_keys = 10;
        let attention = ScaledDotProduct::new(dim);

        let query = random_array1(dim, 1);
        let keys = random_array2((n_keys, dim), 2);
        let values = random_array2((n_keys, dim), 3);

        let output = attention.forward(&query, &keys, &values)
            .expect("Forward pass failed");

        assert_eq!(output.len(), dim, "Output dimension mismatch");
    }

    #[test]
    fn test_orthogonal_queries() {
        let attention = ScaledDotProduct::new(4);

        // Query orthogonal to all keys should give uniform attention
        let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let keys = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 1.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0,
                0.0, 0.0, 0.0, 1.0,
            ],
        ).unwrap();

        let weights = attention.compute_weights(&query, &keys)
            .expect("Weight computation failed");

        // All weights should be approximately equal (1/3)
        for w in weights.iter() {
            assert_relative_eq!(*w, 1.0/3.0, epsilon = 0.01);
        }
    }
}

#[cfg(test)]
mod edge_cases {
    use super::*;

    #[test]
    fn test_empty_keys() {
        let attention = ScaledDotProduct::new(4);
        let data = AttentionTestData::empty();

        let result = attention.forward(&data.query, &data.keys, &data.values);

        // Should return error or zero vector
        match result {
            Ok(output) => assert_eq!(output.len(), 4),
            Err(_) => (), // Error is acceptable
        }
    }

    #[test]
    fn test_single_key() {
        let attention = ScaledDotProduct::new(4);
        let data = AttentionTestData::single_key();

        let weights = attention.compute_weights(&data.query, &data.keys)
            .expect("Single key should work");

        assert_relative_eq!(weights[0], 1.0, epsilon = EPSILON);
    }

    #[test]
    fn test_zero_query() {
        let attention = ScaledDotProduct::new(4);
        let query = Array1::zeros(4);
        let keys = random_array2((3, 4), 42);
        let values = random_array2((3, 4), 43);

        let output = attention.forward(&query, &keys, &values)
            .expect("Zero query should work");

        // Should give uniform attention
        assert_eq!(output.len(), 4);
    }

    #[test]
    fn test_zero_keys() {
        let attention = ScaledDotProduct::new(4);
        let query = random_array1(4, 42);
        let keys = Array2::zeros((3, 4));
        let values = random_array2((3, 4), 43);

        let weights = attention.compute_weights(&query, &keys)
            .expect("Zero keys should work");

        // Uniform attention when all keys are identical
        verify_attention_weights(&weights, EPSILON);
    }

    #[test]
    fn test_very_large_values() {
        let attention = ScaledDotProduct::new(4);
        let query = Array1::from_elem(4, 1000.0);
        let keys = Array2::from_elem((3, 4), 1000.0);
        let values = Array2::from_elem((3, 4), 1.0);

        let output = attention.forward(&query, &keys, &values)
            .expect("Large values should not overflow");

        assert!(output.iter().all(|&x| x.is_finite()));
    }
}

#[cfg(test)]
mod numerical_stability {
    use super::*;

    #[test]
    fn test_softmax_numerical_stability() {
        let attention = ScaledDotProduct::new(4);

        // Very large scores that could cause overflow
        let query = Array1::from_elem(4, 100.0);
        let keys = Array2::from_elem((3, 4), 100.0);

        let weights = attention.compute_weights(&query, &keys)
            .expect("Should handle large scores");

        verify_attention_weights(&weights, EPSILON);
        assert!(weights.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_very_small_values() {
        let attention = ScaledDotProduct::new(4);
        let query = Array1::from_elem(4, 1e-10);
        let keys = Array2::from_elem((3, 4), 1e-10);
        let values = Array2::from_elem((3, 4), 1.0);

        let output = attention.forward(&query, &keys, &values)
            .expect("Small values should not underflow");

        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_mixed_magnitude_scores() {
        let attention = ScaledDotProduct::new(4);

        let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let keys = Array2::from_shape_vec(
            (3, 4),
            vec![
                100.0, 0.0, 0.0, 0.0,  // Very high score
                0.001, 0.0, 0.0, 0.0,  // Very low score
                1.0, 0.0, 0.0, 0.0,    // Medium score
            ],
        ).unwrap();

        let weights = attention.compute_weights(&query, &keys)
            .expect("Mixed magnitudes should work");

        verify_attention_weights(&weights, EPSILON);
    }

    #[test]
    fn test_scaling_factor_effectiveness() {
        let dim = 64; // Larger dimension to test scaling
        let attention = ScaledDotProduct::new(dim);

        let query = random_normalized_vector(dim, 42);
        let keys = random_array2((10, dim), 123);

        let weights = attention.compute_weights(&query, &keys)
            .expect("Scaling should prevent saturation");

        // No weight should be extremely close to 1 or 0 with random data
        let max_weight = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        assert!(max_weight < 0.95, "Scaling should prevent saturation");
    }
}

#[cfg(test)]
mod gradient_correctness {
    use super::*;

    #[test]
    fn test_gradient_numerical_check() {
        let attention = ScaledDotProduct::new(4);

        let query = random_array1(4, 42);
        let keys = random_array2((3, 4), 123);
        let values = random_array2((3, 4), 456);

        // Define loss function (mean of output)
        let loss_fn = |q: &Array1<f32>| -> f32 {
            let output = attention.forward(q, &keys, &values).unwrap();
            output.mean().unwrap()
        };

        // Compute analytical gradient (via autograd or manual)
        let analytical_grad = attention.backward(&query, &keys, &values)
            .expect("Backward pass failed");

        // Compute numerical gradient
        let mut numerical_grad = Array1::zeros(4);
        for i in 0..4 {
            numerical_grad[i] = numerical_gradient(&loss_fn, &query, i);
        }

        // Compare
        for i in 0..4 {
            let rel_err = relative_error(analytical_grad[i], numerical_grad[i]);
            assert!(
                rel_err < 0.01,
                "Gradient mismatch at index {}: analytical={}, numerical={}, rel_err={}",
                i, analytical_grad[i], numerical_grad[i], rel_err
            );
        }
    }
}

#[cfg(test)]
mod serialization {
    use super::*;
    use serde_json;

    #[test]
    fn test_serialization_roundtrip() {
        let attention = ScaledDotProduct::new(4);

        // Serialize
        let serialized = serde_json::to_string(&attention)
            .expect("Serialization failed");

        // Deserialize
        let deserialized: ScaledDotProduct = serde_json::from_str(&serialized)
            .expect("Deserialization failed");

        // Verify behavior is identical
        let query = random_array1(4, 42);
        let keys = random_array2((3, 4), 123);
        let values = random_array2((3, 4), 456);

        let original_output = attention.forward(&query, &keys, &values).unwrap();
        let deserialized_output = deserialized.forward(&query, &keys, &values).unwrap();

        assert_arrays_close(&original_output, &deserialized_output, EPSILON);
    }
}

#[cfg(test)]
mod property_based_tests {
    use super::*;

    proptest! {
        #[test]
        fn prop_dimension_preservation(
            dim in 2usize..32,
            n_keys in 1usize..20,
            seed in 0u64..1000
        ) {
            let attention = ScaledDotProduct::new(dim);
            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);
            let values = random_array2((n_keys, dim), seed + 2);

            let output = attention.forward(&query, &keys, &values).unwrap();
            prop_assert_eq!(output.len(), dim);
        }

        #[test]
        fn prop_weights_sum_to_one(
            dim in 2usize..16,
            n_keys in 1usize..10,
            seed in 0u64..100
        ) {
            let attention = ScaledDotProduct::new(dim);
            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);

            let weights = attention.compute_weights(&query, &keys).unwrap();
            let sum: f32 = weights.iter().sum();

            prop_assert!((sum - 1.0).abs() < EPSILON);
        }

        #[test]
        fn prop_non_negative_weights(
            dim in 2usize..16,
            n_keys in 1usize..10,
            seed in 0u64..100
        ) {
            let attention = ScaledDotProduct::new(dim);
            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);

            let weights = attention.compute_weights(&query, &keys).unwrap();

            for w in weights.iter() {
                prop_assert!(*w >= 0.0);
            }
        }

        #[test]
        fn prop_output_bounded(
            dim in 2usize..16,
            n_keys in 1usize..10,
            seed in 0u64..100
        ) {
            let attention = ScaledDotProduct::new(dim);
            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);

            // Values bounded in [-1, 1]
            let values = random_array2((n_keys, dim), seed + 2);

            let output = attention.forward(&query, &keys, &values).unwrap();

            // Output should be convex combination, thus also bounded
            for &val in output.iter() {
                prop_assert!(val.abs() <= 2.0); // Allow some margin
            }
        }

        #[test]
        fn prop_permutation_invariance_values(
            dim in 2usize..8,
            seed in 0u64..50
        ) {
            // Swapping rows of values with same keys shouldn't change
            // which keys get attention (weights should be same)
            let attention = ScaledDotProduct::new(dim);
            let query = random_array1(dim, seed);
            let keys = random_array2((3, dim), seed + 1);

            let weights = attention.compute_weights(&query, &keys).unwrap();

            // Permute keys and check weights permute accordingly
            let mut keys_perm = Array2::zeros((3, dim));
            keys_perm.row_mut(0).assign(&keys.row(2));
            keys_perm.row_mut(1).assign(&keys.row(0));
            keys_perm.row_mut(2).assign(&keys.row(1));

            let weights_perm = attention.compute_weights(&query, &keys_perm).unwrap();

            // weights_perm should be permutation of weights
            prop_assert_relative_eq!(weights[0], weights_perm[1], epsilon = EPSILON);
            prop_assert_relative_eq!(weights[1], weights_perm[2], epsilon = EPSILON);
            prop_assert_relative_eq!(weights[2], weights_perm[0], epsilon = EPSILON);
        }
    }
}

3. Multi-Head Attention Tests

File: tests/unit/multi_head_tests.rs

use ruvector_attention::multi_head::MultiHeadAttention;
use ndarray::{Array1, Array2, Array3};
use approx::assert_relative_eq;
use proptest::prelude::*;

mod test_utils;
use test_utils::*;

#[cfg(test)]
mod basic_functionality {
    use super::*;

    #[test]
    fn test_multi_head_initialization() {
        let mha = MultiHeadAttention::new(8, 4); // 8 dim, 4 heads

        assert_eq!(mha.num_heads(), 4);
        assert_eq!(mha.model_dim(), 8);
        assert_eq!(mha.head_dim(), 2); // 8/4 = 2
    }

    #[test]
    fn test_head_dimension_must_divide_model_dim() {
        // Should panic or return error if dimensions incompatible
        let result = std::panic::catch_unwind(|| {
            MultiHeadAttention::new(7, 4); // 7 not divisible by 4
        });

        assert!(result.is_err(), "Should fail when dimensions don't divide");
    }

    #[test]
    fn test_multi_head_forward() {
        let mha = MultiHeadAttention::new(8, 4);

        let query = random_array1(8, 42);
        let keys = random_array2((5, 8), 123);
        let values = random_array2((5, 8), 456);

        let output = mha.forward(&query, &keys, &values)
            .expect("Multi-head forward failed");

        assert_eq!(output.len(), 8, "Output dimension should match model_dim");
    }

    #[test]
    fn test_independent_heads() {
        // Each head should process different subspaces
        let mha = MultiHeadAttention::new(16, 4);

        let query = random_array1(16, 1);
        let keys = random_array2((3, 16), 2);
        let values = random_array2((3, 16), 3);

        let head_outputs = mha.get_head_outputs(&query, &keys, &values)
            .expect("Failed to get head outputs");

        assert_eq!(head_outputs.len(), 4, "Should have 4 head outputs");

        // Each head output should have dimension model_dim/num_heads
        for head_output in head_outputs.iter() {
            assert_eq!(head_output.len(), 4); // 16/4
        }
    }

    #[test]
    fn test_concat_projection() {
        let mha = MultiHeadAttention::new(12, 3);

        let query = random_array1(12, 42);
        let keys = random_array2((4, 12), 123);
        let values = random_array2((4, 12), 456);

        // Get individual head outputs
        let head_outputs = mha.get_head_outputs(&query, &keys, &values).unwrap();

        // Concatenate manually
        let manual_concat = concatenate_heads(&head_outputs);

        // Get model output
        let model_output = mha.forward(&query, &keys, &values).unwrap();

        // After projection, dimensions should match
        assert_eq!(model_output.len(), 12);
    }
}

#[cfg(test)]
mod edge_cases {
    use super::*;

    #[test]
    fn test_single_head() {
        // Single head should behave like standard attention
        let mha = MultiHeadAttention::new(8, 1);

        let query = random_array1(8, 42);
        let keys = random_array2((3, 8), 123);
        let values = random_array2((3, 8), 456);

        let output = mha.forward(&query, &keys, &values)
            .expect("Single head should work");

        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_maximum_heads() {
        // Each dimension is its own head
        let dim = 16;
        let mha = MultiHeadAttention::new(dim, dim);

        let query = random_array1(dim, 1);
        let keys = random_array2((3, dim), 2);
        let values = random_array2((3, dim), 3);

        let output = mha.forward(&query, &keys, &values)
            .expect("Max heads should work");

        assert_eq!(output.len(), dim);
    }

    #[test]
    fn test_empty_keys_all_heads() {
        let mha = MultiHeadAttention::new(8, 4);

        let query = random_array1(8, 42);
        let keys = Array2::zeros((0, 8));
        let values = Array2::zeros((0, 8));

        let result = mha.forward(&query, &keys, &values);

        // Should handle gracefully (error or zero output)
        assert!(result.is_ok() || result.is_err());
    }
}

#[cfg(test)]
mod numerical_stability {
    use super::*;

    #[test]
    fn test_large_number_of_heads() {
        let mha = MultiHeadAttention::new(64, 16);

        let query = random_array1(64, 42);
        let keys = random_array2((10, 64), 123);
        let values = random_array2((10, 64), 456);

        let output = mha.forward(&query, &keys, &values)
            .expect("Large number of heads should work");

        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_head_independence_no_interference() {
        // Changing input in one head's subspace shouldn't affect others
        let mha = MultiHeadAttention::new(8, 2);

        let mut query1 = Array1::zeros(8);
        query1[0] = 1.0; // First head's subspace

        let mut query2 = query1.clone();
        query2[4] = 1.0; // Second head's subspace

        let keys = random_array2((3, 8), 123);
        let values = random_array2((3, 8), 456);

        let output1 = mha.forward(&query1, &keys, &values).unwrap();
        let output2 = mha.forward(&query2, &keys, &values).unwrap();

        // Outputs should differ (heads process different subspaces)
        let diff = (&output1 - &output2).mapv(|x| x.abs()).sum();
        assert!(diff > EPSILON, "Outputs should differ when heads differ");
    }
}

#[cfg(test)]
mod gradient_correctness {
    use super::*;

    #[test]
    fn test_multi_head_gradient_check() {
        let mha = MultiHeadAttention::new(8, 2);

        let query = random_array1(8, 42);
        let keys = random_array2((3, 8), 123);
        let values = random_array2((3, 8), 456);

        let loss_fn = |q: &Array1<f32>| -> f32 {
            mha.forward(q, &keys, &values).unwrap().mean().unwrap()
        };

        let analytical_grad = mha.backward(&query, &keys, &values).unwrap();

        let mut numerical_grad = Array1::zeros(8);
        for i in 0..8 {
            numerical_grad[i] = numerical_gradient(&loss_fn, &query, i);
        }

        for i in 0..8 {
            let rel_err = relative_error(analytical_grad[i], numerical_grad[i]);
            assert!(rel_err < 0.01, "Gradient error at {}: {}", i, rel_err);
        }
    }
}

#[cfg(test)]
mod serialization {
    use super::*;
    use serde_json;

    #[test]
    fn test_multi_head_serialization() {
        let mha = MultiHeadAttention::new(12, 3);

        let serialized = serde_json::to_string(&mha).unwrap();
        let deserialized: MultiHeadAttention = serde_json::from_str(&serialized).unwrap();

        let query = random_array1(12, 42);
        let keys = random_array2((4, 12), 123);
        let values = random_array2((4, 12), 456);

        let original = mha.forward(&query, &keys, &values).unwrap();
        let restored = deserialized.forward(&query, &keys, &values).unwrap();

        assert_arrays_close(&original, &restored, EPSILON);
    }
}

#[cfg(test)]
mod property_based_tests {
    use super::*;

    proptest! {
        #[test]
        fn prop_valid_head_dimensions(
            model_dim in vec![8usize, 12, 16, 24, 32, 64],
            num_heads in 1usize..8
        ) {
            if model_dim % num_heads == 0 {
                let mha = MultiHeadAttention::new(model_dim, num_heads);
                prop_assert_eq!(mha.head_dim(), model_dim / num_heads);
            }
        }

        #[test]
        fn prop_output_dimension_preserved(
            heads in 1usize..5,
            head_dim in 2usize..8,
            n_keys in 1usize..10,
            seed in 0u64..100
        ) {
            let model_dim = heads * head_dim;
            let mha = MultiHeadAttention::new(model_dim, heads);

            let query = random_array1(model_dim, seed);
            let keys = random_array2((n_keys, model_dim), seed + 1);
            let values = random_array2((n_keys, model_dim), seed + 2);

            let output = mha.forward(&query, &keys, &values).unwrap();
            prop_assert_eq!(output.len(), model_dim);
        }

        #[test]
        fn prop_finite_outputs(
            heads in 1usize..4,
            head_dim in 2usize..6,
            n_keys in 1usize..8,
            seed in 0u64..50
        ) {
            let model_dim = heads * head_dim;
            let mha = MultiHeadAttention::new(model_dim, heads);

            let query = random_array1(model_dim, seed);
            let keys = random_array2((n_keys, model_dim), seed + 1);
            let values = random_array2((n_keys, model_dim), seed + 2);

            let output = mha.forward(&query, &keys, &values).unwrap();

            for &val in output.iter() {
                prop_assert!(val.is_finite());
            }
        }
    }
}

// Helper function
fn concatenate_heads(heads: &[Array1<f32>]) -> Array1<f32> {
    let total_dim: usize = heads.iter().map(|h| h.len()).sum();
    let mut result = Array1::zeros(total_dim);
    let mut offset = 0;

    for head in heads {
        result.slice_mut(s![offset..offset + head.len()]).assign(head);
        offset += head.len();
    }

    result
}

4. Hyperbolic Attention Tests

File: tests/unit/hyperbolic_tests.rs

use ruvector_attention::hyperbolic::HyperbolicAttention;
use ndarray::{Array1, Array2};
use approx::assert_relative_eq;
use proptest::prelude::*;

mod test_utils;
use test_utils::*;

#[cfg(test)]
mod poincare_operations {
    use super::*;

    #[test]
    fn test_poincare_distance_identity() {
        let x = random_normalized_vector(4, 42) * 0.5; // Stay in ball
        let curvature = 1.0;

        let dist = poincare_distance(&x, &x, curvature);

        assert_relative_eq!(dist, 0.0, epsilon = EPSILON);
    }

    #[test]
    fn test_poincare_distance_symmetry() {
        let x = random_normalized_vector(4, 42) * 0.3;
        let y = random_normalized_vector(4, 123) * 0.3;
        let curvature = 1.0;

        let dist_xy = poincare_distance(&x, &y, curvature);
        let dist_yx = poincare_distance(&y, &x, curvature);

        assert_relative_eq!(dist_xy, dist_yx, epsilon = EPSILON);
    }

    #[test]
    fn test_poincare_distance_triangle_inequality() {
        let x = random_normalized_vector(4, 1) * 0.2;
        let y = random_normalized_vector(4, 2) * 0.2;
        let z = random_normalized_vector(4, 3) * 0.2;
        let curvature = 1.0;

        let dist_xy = poincare_distance(&x, &y, curvature);
        let dist_yz = poincare_distance(&y, &z, curvature);
        let dist_xz = poincare_distance(&x, &z, curvature);

        assert!(
            dist_xz <= dist_xy + dist_yz + EPSILON,
            "Triangle inequality violated: {} > {} + {}",
            dist_xz, dist_xy, dist_yz
        );
    }

    #[test]
    fn test_euclidean_limit() {
        // As curvature → 0, should approach Euclidean distance
        let x = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 0.0]);

        let euclidean_dist = ((x.clone() - y.clone()).mapv(|v| v * v).sum()).sqrt();

        let hyperbolic_dist = poincare_distance(&x, &y, 0.001);

        assert_relative_eq!(hyperbolic_dist, euclidean_dist, epsilon = 0.01);
    }

    #[test]
    fn test_mobius_addition_identity() {
        let x = random_normalized_vector(4, 42) * 0.4;
        let zero = Array1::zeros(4);
        let curvature = 1.0;

        let result = mobius_add(&x, &zero, curvature);

        assert_arrays_close(&result, &x, EPSILON);
    }

    #[test]
    fn test_mobius_addition_stays_in_ball() {
        let curvature = 1.0;
        let boundary = 1.0 / curvature.sqrt();

        let x = random_normalized_vector(4, 42) * 0.8 * boundary;
        let y = random_normalized_vector(4, 123) * 0.8 * boundary;

        let result = mobius_add(&x, &y, curvature);
        let norm = result.mapv(|v| v * v).sum().sqrt();

        assert!(norm < boundary, "Result escaped ball: {} >= {}", norm, boundary);
    }
}

#[cfg(test)]
mod hyperbolic_attention_basic {
    use super::*;

    #[test]
    fn test_hyperbolic_attention_initialization() {
        let attention = HyperbolicAttention::new(4, 1.0);

        assert_eq!(attention.dim(), 4);
        assert_relative_eq!(attention.curvature(), 1.0);
    }

    #[test]
    fn test_hyperbolic_attention_forward() {
        let attention = HyperbolicAttention::new(4, 1.0);

        // Keep vectors well inside ball
        let query = random_normalized_vector(4, 42) * 0.3;
        let keys = random_array2((3, 4), 123) * 0.3;
        let values = random_array2((3, 4), 456);

        let output = attention.forward(&query, &keys, &values)
            .expect("Hyperbolic attention failed");

        assert_eq!(output.len(), 4);
    }

    #[test]
    fn test_weights_based_on_hyperbolic_distance() {
        let attention = HyperbolicAttention::new(4, 1.0);

        let query = Array1::from_vec(vec![0.1, 0.0, 0.0, 0.0]);

        // Key 1: Close to query in hyperbolic space
        let close_key = Array1::from_vec(vec![0.15, 0.0, 0.0, 0.0]);

        // Key 2: Far from query
        let far_key = Array1::from_vec(vec![0.0, 0.5, 0.0, 0.0]);

        let mut keys = Array2::zeros((2, 4));
        keys.row_mut(0).assign(&close_key);
        keys.row_mut(1).assign(&far_key);

        let weights = attention.compute_weights(&query, &keys).unwrap();

        // Closer key should have higher weight
        assert!(weights[0] > weights[1],
            "Close key weight {} should exceed far key weight {}",
            weights[0], weights[1]);
    }

    #[test]
    fn test_variable_curvature() {
        let query = random_normalized_vector(4, 1) * 0.2;
        let keys = random_array2((3, 4), 2) * 0.2;
        let values = random_array2((3, 4), 3);

        let attention_low = HyperbolicAttention::new(4, 0.1);
        let attention_high = HyperbolicAttention::new(4, 10.0);

        let output_low = attention_low.forward(&query, &keys, &values).unwrap();
        let output_high = attention_high.forward(&query, &keys, &values).unwrap();

        // Different curvatures should give different results
        let diff = (&output_low - &output_high).mapv(|x| x.abs()).sum();
        assert!(diff > EPSILON, "Curvature should affect output");
    }
}

#[cfg(test)]
mod edge_cases {
    use super::*;

    #[test]
    fn test_origin_query() {
        let attention = HyperbolicAttention::new(4, 1.0);

        let query = Array1::zeros(4); // Origin
        let keys = random_array2((3, 4), 123) * 0.3;
        let values = random_array2((3, 4), 456);

        let output = attention.forward(&query, &keys, &values)
            .expect("Origin query should work");

        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_near_boundary_points() {
        let curvature = 1.0;
        let boundary = 1.0 / curvature.sqrt();
        let attention = HyperbolicAttention::new(4, curvature);

        // Point very close to boundary
        let query = random_normalized_vector(4, 42) * (boundary * 0.99);
        let keys = random_array2((3, 4), 123) * 0.3;
        let values = random_array2((3, 4), 456);

        let result = attention.forward(&query, &keys, &values);

        // Should handle gracefully with clamping
        assert!(result.is_ok(), "Should handle near-boundary points");
    }

    #[test]
    fn test_zero_curvature_euclidean() {
        // Zero curvature should behave like Euclidean attention
        let attention = HyperbolicAttention::new(4, 0.0);

        let query = random_array1(4, 42);
        let keys = random_array2((3, 4), 123);
        let values = random_array2((3, 4), 456);

        let output = attention.forward(&query, &keys, &values)
            .expect("Zero curvature should work");

        assert!(output.iter().all(|&x| x.is_finite()));
    }
}

#[cfg(test)]
mod numerical_stability {
    use super::*;

    #[test]
    fn test_extreme_curvature_values() {
        let query = random_normalized_vector(4, 42) * 0.1;
        let keys = random_array2((3, 4), 123) * 0.1;
        let values = random_array2((3, 4), 456);

        // Very small curvature
        let attention_small = HyperbolicAttention::new(4, 1e-5);
        let output_small = attention_small.forward(&query, &keys, &values).unwrap();
        assert!(output_small.iter().all(|&x| x.is_finite()));

        // Large curvature (but keep points scaled appropriately)
        let scaled_query = query.clone() * 0.01;
        let scaled_keys = keys.clone() * 0.01;
        let attention_large = HyperbolicAttention::new(4, 100.0);
        let output_large = attention_large.forward(&scaled_query, &scaled_keys, &values).unwrap();
        assert!(output_large.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_numerical_precision_distances() {
        let attention = HyperbolicAttention::new(4, 1.0);

        // Very close points (numerical precision challenge)
        let query = Array1::from_vec(vec![0.1, 0.0, 0.0, 0.0]);
        let close_key = Array1::from_vec(vec![0.1 + 1e-7, 0.0, 0.0, 0.0]);

        let mut keys = Array2::zeros((1, 4));
        keys.row_mut(0).assign(&close_key);

        let weights = attention.compute_weights(&query, &keys).unwrap();

        verify_attention_weights(&weights, EPSILON);
    }
}

#[cfg(test)]
mod gradient_correctness {
    use super::*;

    #[test]
    fn test_hyperbolic_attention_gradient() {
        let attention = HyperbolicAttention::new(4, 1.0);

        let query = random_normalized_vector(4, 42) * 0.3;
        let keys = random_array2((3, 4), 123) * 0.3;
        let values = random_array2((3, 4), 456);

        let loss_fn = |q: &Array1<f32>| -> f32 {
            let q_scaled = q * 0.3; // Keep in ball
            attention.forward(&q_scaled, &keys, &values).unwrap().mean().unwrap()
        };

        let analytical_grad = attention.backward(&query, &keys, &values).unwrap();

        let mut numerical_grad = Array1::zeros(4);
        for i in 0..4 {
            numerical_grad[i] = numerical_gradient(&loss_fn, &query, i);
        }

        for i in 0..4 {
            let rel_err = relative_error(analytical_grad[i], numerical_grad[i]);
            assert!(rel_err < 0.05, "Gradient error at {}: {}", i, rel_err);
        }
    }
}

#[cfg(test)]
mod property_based_tests {
    use super::*;

    proptest! {
        #[test]
        fn prop_weights_sum_to_one(
            dim in 2usize..12,
            n_keys in 1usize..8,
            curvature in 0.1f32..5.0,
            seed in 0u64..100
        ) {
            let attention = HyperbolicAttention::new(dim, curvature);
            let boundary = 1.0 / curvature.sqrt();

            let query = random_normalized_vector(dim, seed) * (boundary * 0.5);
            let keys = random_array2((n_keys, dim), seed + 1) * (boundary * 0.5);

            let weights = attention.compute_weights(&query, &keys).unwrap();
            let sum: f32 = weights.iter().sum();

            prop_assert!((sum - 1.0).abs() < EPSILON);
        }

        #[test]
        fn prop_distance_non_negative(
            dim in 2usize..8,
            curvature in 0.1f32..3.0,
            seed in 0u64..50
        ) {
            let boundary = 1.0 / curvature.sqrt();
            let x = random_normalized_vector(dim, seed) * (boundary * 0.6);
            let y = random_normalized_vector(dim, seed + 1) * (boundary * 0.6);

            let dist = poincare_distance(&x, &y, curvature);

            prop_assert!(dist >= 0.0);
        }

        #[test]
        fn prop_output_finite(
            dim in 2usize..8,
            n_keys in 1usize..6,
            curvature in 0.1f32..3.0,
            seed in 0u64..50
        ) {
            let attention = HyperbolicAttention::new(dim, curvature);
            let boundary = 1.0 / curvature.sqrt();

            let query = random_normalized_vector(dim, seed) * (boundary * 0.4);
            let keys = random_array2((n_keys, dim), seed + 1) * (boundary * 0.4);
            let values = random_array2((n_keys, dim), seed + 2);

            let output = attention.forward(&query, &keys, &values).unwrap();

            for &val in output.iter() {
                prop_assert!(val.is_finite());
            }
        }
    }
}

// Helper functions (would be in hyperbolic module)
fn poincare_distance(x: &Array1<f32>, y: &Array1<f32>, curvature: f32) -> f32 {
    // Simplified placeholder - actual implementation in module
    ((x - y).mapv(|v| v * v).sum()).sqrt()
}

fn mobius_add(x: &Array1<f32>, y: &Array1<f32>, curvature: f32) -> Array1<f32> {
    // Simplified placeholder - actual implementation in module
    x + y
}

5. Sparse Attention Tests

File: tests/unit/sparse_tests.rs

use ruvector_attention::sparse::SparseAttention;
use ndarray::{Array1, Array2};
use approx::assert_relative_eq;
use proptest::prelude::*;

mod test_utils;
use test_utils::*;

#[cfg(test)]
mod basic_functionality {
    use super::*;

    #[test]
    fn test_top_k_selection() {
        let sparse = SparseAttention::new(4, TopK(2));

        let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let keys = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 0.0, 0.0, 0.0,   // High similarity
                0.0, 1.0, 0.0, 0.0,   // Low similarity
                0.9, 0.1, 0.0, 0.0,   // Medium-high similarity
                0.0, 0.0, 1.0, 0.0,   // Low similarity
            ],
        ).unwrap();

        let mask = sparse.compute_sparsity_mask(&query, &keys).unwrap();

        // Should select exactly top-2 keys
        let selected_count = mask.iter().filter(|&&x| x).count();
        assert_eq!(selected_count, 2, "Should select exactly top-2 keys");

        // First and third keys should be selected (highest similarities)
        assert!(mask[0], "Highest similarity key should be selected");
        assert!(mask[2], "Second highest similarity key should be selected");
    }

    #[test]
    fn test_windowed_attention() {
        let sparse = SparseAttention::new(4, Window(3)); // Window size 3

        let query_idx = 5; // Query at position 5
        let n_keys = 10;

        let mask = sparse.compute_window_mask(query_idx, n_keys).unwrap();

        // Should attend to indices 4, 5, 6 (window of ±1)
        assert_eq!(mask.iter().filter(|&&x| x).count(), 3);
        assert!(mask[4]);
        assert!(mask[5]);
        assert!(mask[6]);
    }

    #[test]
    fn test_sparse_reduces_computation() {
        let sparse = SparseAttention::new(8, TopK(3));
        let dense = ScaledDotProduct::new(8);

        let query = random_array1(8, 42);
        let keys = random_array2((100, 8), 123); // Many keys
        let values = random_array2((100, 8), 456);

        // Sparse should compute faster (in practice, not testing timing here)
        let sparse_output = sparse.forward(&query, &keys, &values).unwrap();
        let dense_output = dense.forward(&query, &keys, &values).unwrap();

        assert_eq!(sparse_output.len(), dense_output.len());

        // Outputs will differ, but both should be valid
        assert!(sparse_output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_strided_attention() {
        let sparse = SparseAttention::new(4, Strided(2)); // Every 2nd key

        let mask = sparse.compute_strided_mask(10).unwrap();

        // Should select indices 0, 2, 4, 6, 8
        assert!(mask[0]);
        assert!(!mask[1]);
        assert!(mask[2]);
        assert!(!mask[3]);
        assert!(mask[4]);
    }
}

#[cfg(test)]
mod edge_cases {
    use super::*;

    #[test]
    fn test_top_k_greater_than_num_keys() {
        let sparse = SparseAttention::new(4, TopK(10)); // Request more than available

        let query = random_array1(4, 42);
        let keys = random_array2((5, 4), 123); // Only 5 keys

        let mask = sparse.compute_sparsity_mask(&query, &keys).unwrap();

        // Should select all 5 keys
        assert_eq!(mask.iter().filter(|&&x| x).count(), 5);
    }

    #[test]
    fn test_window_at_boundaries() {
        let sparse = SparseAttention::new(4, Window(5));

        // Window at start
        let mask_start = sparse.compute_window_mask(0, 10).unwrap();
        assert!(mask_start[0]);
        assert!(!mask_start[5]);

        // Window at end
        let mask_end = sparse.compute_window_mask(9, 10).unwrap();
        assert!(mask_end[9]);
        assert!(!mask_end[4]);
    }

    #[test]
    fn test_empty_sparse_pattern() {
        let sparse = SparseAttention::new(4, TopK(0)); // Select nothing

        let query = random_array1(4, 42);
        let keys = random_array2((5, 4), 123);
        let values = random_array2((5, 4), 456);

        let result = sparse.forward(&query, &keys, &values);

        // Should handle gracefully (error or zero vector)
        assert!(result.is_ok() || result.is_err());
    }

    #[test]
    fn test_single_key_sparse() {
        let sparse = SparseAttention::new(4, TopK(1));

        let query = random_array1(4, 42);
        let keys = random_array2((1, 4), 123);
        let values = random_array2((1, 4), 456);

        let output = sparse.forward(&query, &keys, &values).unwrap();

        // Should equal the single value
        assert_arrays_close(&output, &values.row(0).to_owned(), EPSILON);
    }
}

#[cfg(test)]
mod numerical_stability {
    use super::*;

    #[test]
    fn test_sparse_with_extreme_scores() {
        let sparse = SparseAttention::new(4, TopK(2));

        let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let keys = Array2::from_shape_vec(
            (3, 4),
            vec![
                100.0, 0.0, 0.0, 0.0,   // Extremely high score
                0.001, 0.0, 0.0, 0.0,   // Very low score
                1.0, 0.0, 0.0, 0.0,     // Medium score
            ],
        ).unwrap();
        let values = random_array2((3, 4), 456);

        let output = sparse.forward(&query, &keys, &values).unwrap();

        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_large_sequence_length() {
        let sparse = SparseAttention::new(8, TopK(10));

        let query = random_array1(8, 42);
        let keys = random_array2((1000, 8), 123); // Long sequence
        let values = random_array2((1000, 8), 456);

        let output = sparse.forward(&query, &keys, &values)
            .expect("Should handle long sequences");

        assert_eq!(output.len(), 8);
    }
}

#[cfg(test)]
mod gradient_correctness {
    use super::*;

    #[test]
    fn test_sparse_gradient_check() {
        let sparse = SparseAttention::new(4, TopK(2));

        let query = random_array1(4, 42);
        let keys = random_array2((5, 4), 123);
        let values = random_array2((5, 4), 456);

        let loss_fn = |q: &Array1<f32>| -> f32 {
            sparse.forward(q, &keys, &values).unwrap().mean().unwrap()
        };

        let analytical_grad = sparse.backward(&query, &keys, &values).unwrap();

        let mut numerical_grad = Array1::zeros(4);
        for i in 0..4 {
            numerical_grad[i] = numerical_gradient(&loss_fn, &query, i);
        }

        // Gradients may be zero for non-selected keys
        for i in 0..4 {
            if analytical_grad[i].abs() > EPSILON {
                let rel_err = relative_error(analytical_grad[i], numerical_grad[i]);
                assert!(rel_err < 0.05, "Gradient error at {}: {}", i, rel_err);
            }
        }
    }
}

#[cfg(test)]
mod property_based_tests {
    use super::*;

    proptest! {
        #[test]
        fn prop_top_k_selects_correctly(
            dim in 2usize..8,
            n_keys in 5usize..20,
            k in 1usize..10,
            seed in 0u64..100
        ) {
            let k_clamped = k.min(n_keys);
            let sparse = SparseAttention::new(dim, TopK(k_clamped));

            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);

            let mask = sparse.compute_sparsity_mask(&query, &keys).unwrap();
            let selected = mask.iter().filter(|&&x| x).count();

            prop_assert_eq!(selected, k_clamped);
        }

        #[test]
        fn prop_sparse_output_dimension(
            dim in 2usize..12,
            n_keys in 5usize..15,
            k in 1usize..8,
            seed in 0u64..50
        ) {
            let k_clamped = k.min(n_keys);
            let sparse = SparseAttention::new(dim, TopK(k_clamped));

            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);
            let values = random_array2((n_keys, dim), seed + 2);

            let output = sparse.forward(&query, &keys, &values).unwrap();

            prop_assert_eq!(output.len(), dim);
        }

        #[test]
        fn prop_window_size_respected(
            n_keys in 10usize..50,
            window_size in 1usize..10,
            query_idx in 0usize..20
        ) {
            let sparse = SparseAttention::new(4, Window(window_size));
            let query_idx = query_idx.min(n_keys - 1);

            let mask = sparse.compute_window_mask(query_idx, n_keys).unwrap();
            let selected = mask.iter().filter(|&&x| x).count();

            // Selected count should be at most window_size (may be less at boundaries)
            prop_assert!(selected <= window_size * 2 + 1);
        }
    }
}

// Sparse pattern types
enum SparsePattern {
    TopK(usize),
    Window(usize),
    Strided(usize),
}

use SparsePattern::*;

6. Graph Attention Tests

File: tests/unit/graph_tests.rs

use ruvector_attention::graph::GraphAttention;
use ndarray::{Array1, Array2};
use approx::assert_relative_eq;
use proptest::prelude::*;

mod test_utils;
use test_utils::*;

#[cfg(test)]
mod basic_functionality {
    use super::*;

    #[test]
    fn test_graph_attention_with_adjacency() {
        let gat = GraphAttention::new(4, 1); // 1 attention head

        let node_features = random_array2((5, 4), 42); // 5 nodes

        // Adjacency matrix: simple path graph (0-1-2-3-4)
        let adj = Array2::from_shape_vec(
            (5, 5),
            vec![
                0.0, 1.0, 0.0, 0.0, 0.0,
                1.0, 0.0, 1.0, 0.0, 0.0,
                0.0, 1.0, 0.0, 1.0, 0.0,
                0.0, 0.0, 1.0, 0.0, 1.0,
                0.0, 0.0, 0.0, 1.0, 0.0,
            ],
        ).unwrap();

        let output = gat.forward(&node_features, &adj)
            .expect("Graph attention failed");

        assert_eq!(output.shape(), &[5, 4]);
    }

    #[test]
    fn test_masked_attention_respects_graph() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((3, 4), 42);

        // Star graph: node 0 connected to nodes 1 and 2
        let adj = Array2::from_shape_vec(
            (3, 3),
            vec![
                0.0, 1.0, 1.0,
                1.0, 0.0, 0.0,
                1.0, 0.0, 0.0,
            ],
        ).unwrap();

        let attention_weights = gat.compute_attention_weights(&features, &adj).unwrap();

        // Node 1 should not attend to node 2 (not connected)
        assert_relative_eq!(attention_weights[[1, 2]], 0.0, epsilon = EPSILON);

        // Node 0 should attend to nodes 1 and 2
        assert!(attention_weights[[0, 1]] > EPSILON);
        assert!(attention_weights[[0, 2]] > EPSILON);
    }

    #[test]
    fn test_learnable_attention_coefficients() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((3, 4), 42);
        let adj = Array2::eye(3); // Self-loops only

        // Different features should produce different attention coefficients
        let features2 = random_array2((3, 4), 123);

        let weights1 = gat.compute_attention_weights(&features, &adj).unwrap();
        let weights2 = gat.compute_attention_weights(&features2, &adj).unwrap();

        let diff = (&weights1 - &weights2).mapv(|x| x.abs()).sum();
        assert!(diff > EPSILON, "Different features should give different attention");
    }

    #[test]
    fn test_multi_head_graph_attention() {
        let gat = GraphAttention::new(8, 4); // 4 heads

        let features = random_array2((5, 8), 42);
        let adj = Array2::eye(5);

        let output = gat.forward(&features, &adj).unwrap();

        assert_eq!(output.shape(), &[5, 8]);
        assert!(output.iter().all(|&x| x.is_finite()));
    }
}

#[cfg(test)]
mod edge_cases {
    use super::*;

    #[test]
    fn test_isolated_nodes() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((3, 4), 42);

        // Node 1 is isolated (no connections, not even self-loop)
        let adj = Array2::from_shape_vec(
            (3, 3),
            vec![
                1.0, 0.0, 1.0,
                0.0, 0.0, 0.0,  // Isolated node
                1.0, 0.0, 1.0,
            ],
        ).unwrap();

        let output = gat.forward(&features, &adj).unwrap();

        // Isolated node output might be zero or use self features
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_complete_graph() {
        let gat = GraphAttention::new(4, 1);

        let n_nodes = 5;
        let features = random_array2((n_nodes, 4), 42);
        let adj = Array2::ones((n_nodes, n_nodes)); // Fully connected

        let output = gat.forward(&features, &adj).unwrap();

        assert_eq!(output.shape(), &[n_nodes, 4]);
    }

    #[test]
    fn test_single_node_graph() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((1, 4), 42);
        let adj = Array2::from_elem((1, 1), 1.0); // Self-loop

        let output = gat.forward(&features, &adj).unwrap();

        // Single node should attend only to itself
        assert_eq!(output.shape(), &[1, 4]);
    }

    #[test]
    fn test_weighted_edges() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((3, 4), 42);

        // Weighted adjacency matrix
        let adj = Array2::from_shape_vec(
            (3, 3),
            vec![
                1.0, 0.5, 0.0,
                0.5, 1.0, 0.8,
                0.0, 0.8, 1.0,
            ],
        ).unwrap();

        let output = gat.forward(&features, &adj).unwrap();

        assert!(output.iter().all(|&x| x.is_finite()));
    }
}

#[cfg(test)]
mod numerical_stability {
    use super::*;

    #[test]
    fn test_large_graph() {
        let gat = GraphAttention::new(8, 2);

        let n_nodes = 100;
        let features = random_array2((n_nodes, 8), 42);

        // Random sparse graph (10% density)
        let mut adj = Array2::zeros((n_nodes, n_nodes));
        for i in 0..n_nodes {
            adj[[i, i]] = 1.0; // Self-loops
            for j in (i + 1)..n_nodes {
                if (i * 7 + j * 13) % 10 == 0 {
                    adj[[i, j]] = 1.0;
                    adj[[j, i]] = 1.0;
                }
            }
        }

        let output = gat.forward(&features, &adj)
            .expect("Large graph should work");

        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_attention_normalization() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((3, 4), 42);
        let adj = Array2::ones((3, 3));

        let weights = gat.compute_attention_weights(&features, &adj).unwrap();

        // Each row should sum to 1 (attention distribution over neighbors)
        for i in 0..3 {
            let row_sum: f32 = weights.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 0.01);
        }
    }
}

#[cfg(test)]
mod gradient_correctness {
    use super::*;

    #[test]
    fn test_graph_attention_gradient() {
        let gat = GraphAttention::new(4, 1);

        let features = random_array2((3, 4), 42);
        let adj = Array2::eye(3);

        // Gradient with respect to first node's features
        let loss_fn = |f: &Array1<f32>| -> f32 {
            let mut features_mod = features.clone();
            features_mod.row_mut(0).assign(f);

            let output = gat.forward(&features_mod, &adj).unwrap();
            output.mean().unwrap()
        };

        let first_features = features.row(0).to_owned();
        let analytical_grad = gat.backward_node(&first_features, 0, &features, &adj).unwrap();

        let mut numerical_grad = Array1::zeros(4);
        for i in 0..4 {
            numerical_grad[i] = numerical_gradient(&loss_fn, &first_features, i);
        }

        for i in 0..4 {
            let rel_err = relative_error(analytical_grad[i], numerical_grad[i]);
            assert!(rel_err < 0.05, "Gradient error at {}: {}", i, rel_err);
        }
    }
}

#[cfg(test)]
mod property_based_tests {
    use super::*;

    proptest! {
        #[test]
        fn prop_output_shape_preserved(
            n_nodes in 2usize..20,
            dim in 2usize..16,
            num_heads in 1usize..4,
            seed in 0u64..100
        ) {
            let gat = GraphAttention::new(dim, num_heads);
            let features = random_array2((n_nodes, dim), seed);
            let adj = Array2::eye(n_nodes); // Simple case

            let output = gat.forward(&features, &adj).unwrap();

            prop_assert_eq!(output.shape(), &[n_nodes, dim]);
        }

        #[test]
        fn prop_attention_respects_mask(
            n_nodes in 3usize..10,
            dim in 2usize..8,
            seed in 0u64..50
        ) {
            let gat = GraphAttention::new(dim, 1);
            let features = random_array2((n_nodes, dim), seed);

            // Diagonal matrix (only self-loops)
            let adj = Array2::eye(n_nodes);

            let weights = gat.compute_attention_weights(&features, &adj).unwrap();

            // Off-diagonal elements should be zero
            for i in 0..n_nodes {
                for j in 0..n_nodes {
                    if i != j {
                        prop_assert_relative_eq!(weights[[i, j]], 0.0, epsilon = EPSILON);
                    }
                }
            }
        }

        #[test]
        fn prop_permutation_equivariance(
            n_nodes in 2usize..6,
            dim in 2usize..6,
            seed in 0u64..30
        ) {
            // Graph attention should be permutation equivariant:
            // permute input nodes → output nodes permuted accordingly
            let gat = GraphAttention::new(dim, 1);
            let features = random_array2((n_nodes, dim), seed);
            let adj = Array2::eye(n_nodes);

            let output = gat.forward(&features, &adj).unwrap();

            // Permute: swap first two rows
            if n_nodes >= 2 {
                let mut features_perm = features.clone();
                let row0 = features.row(0).to_owned();
                let row1 = features.row(1).to_owned();
                features_perm.row_mut(0).assign(&row1);
                features_perm.row_mut(1).assign(&row0);

                let output_perm = gat.forward(&features_perm, &adj).unwrap();

                // Output rows should also be swapped
                assert_arrays_close(
                    &output.row(0).to_owned(),
                    &output_perm.row(1).to_owned(),
                    EPSILON * 10.0
                );
            }
        }
    }
}

7. Mixture of Experts (MoE) Attention Tests

File: tests/unit/moe_tests.rs

use ruvector_attention::moe::MoEAttention;
use ndarray::{Array1, Array2};
use approx::assert_relative_eq;
use proptest::prelude::*;

mod test_utils;
use test_utils::*;

#[cfg(test)]
mod basic_functionality {
    use super::*;

    #[test]
    fn test_moe_initialization() {
        let moe = MoEAttention::new(8, 4, 2); // 8 dim, 4 experts, top-2 routing

        assert_eq!(moe.num_experts(), 4);
        assert_eq!(moe.top_k(), 2);
        assert_eq!(moe.dim(), 8);
    }

    #[test]
    fn test_router_selects_top_k() {
        let moe = MoEAttention::new(8, 4, 2);

        let query = random_array1(8, 42);

        let (selected_experts, weights) = moe.route(&query)
            .expect("Routing failed");

        assert_eq!(selected_experts.len(), 2, "Should select top-2 experts");
        assert_eq!(weights.len(), 2, "Should have 2 weights");

        // Weights should sum to ~1
        let sum: f32 = weights.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = EPSILON);
    }

    #[test]
    fn test_moe_forward_combines_experts() {
        let moe = MoEAttention::new(8, 3, 2);

        let query = random_array1(8, 42);
        let keys = random_array2((5, 8), 123);
        let values = random_array2((5, 8), 456);

        let output = moe.forward(&query, &keys, &values)
            .expect("MoE forward failed");

        assert_eq!(output.len(), 8);
        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_different_queries_route_differently() {
        let moe = MoEAttention::new(8, 4, 2);

        let query1 = Array1::from_elem(8, 0.5);
        let query2 = Array1::from_elem(8, -0.5);

        let (experts1, _) = moe.route(&query1).unwrap();
        let (experts2, _) = moe.route(&query2).unwrap();

        // Different queries may route to different experts
        // (Not guaranteed to be different, but likely)
        // Just verify routing works for both
        assert_eq!(experts1.len(), 2);
        assert_eq!(experts2.len(), 2);
    }

    #[test]
    fn test_expert_specialization() {
        // Experts should process inputs differently
        let moe = MoEAttention::new(8, 3, 1); // Top-1 for deterministic routing

        let query = random_array1(8, 42);
        let keys = random_array2((5, 8), 123);
        let values = random_array2((5, 8), 456);

        // Get output from specific expert
        let expert_outputs = moe.get_expert_outputs(&query, &keys, &values).unwrap();

        assert_eq!(expert_outputs.len(), 3, "Should have 3 expert outputs");

        // Experts should produce different outputs
        let diff_01 = (&expert_outputs[0] - &expert_outputs[1]).mapv(|x| x.abs()).sum();
        assert!(diff_01 > EPSILON, "Experts should specialize differently");
    }
}

#[cfg(test)]
mod routing_tests {
    use super::*;

    #[test]
    fn test_router_weights_sum_to_one() {
        let moe = MoEAttention::new(8, 5, 3);

        for seed in 0..10 {
            let query = random_array1(8, seed);
            let (_, weights) = moe.route(&query).unwrap();

            let sum: f32 = weights.iter().sum();
            assert_relative_eq!(sum, 1.0, epsilon = EPSILON);
        }
    }

    #[test]
    fn test_router_non_negative_weights() {
        let moe = MoEAttention::new(8, 4, 2);

        let query = random_array1(8, 42);
        let (_, weights) = moe.route(&query).unwrap();

        for &w in weights.iter() {
            assert!(w >= 0.0, "Router weights must be non-negative");
        }
    }

    #[test]
    fn test_top_k_larger_than_num_experts() {
        // Request more experts than available
        let moe = MoEAttention::new(8, 3, 5);

        let query = random_array1(8, 42);
        let (selected, _) = moe.route(&query).unwrap();

        // Should select all available experts
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_deterministic_routing() {
        let moe = MoEAttention::new(8, 4, 2);

        let query = random_array1(8, 42);

        let (experts1, weights1) = moe.route(&query).unwrap();
        let (experts2, weights2) = moe.route(&query).unwrap();

        // Same query should route identically
        assert_eq!(experts1, experts2);
        assert_arrays_close(&Array1::from_vec(weights1), &Array1::from_vec(weights2), EPSILON);
    }
}

#[cfg(test)]
mod load_balancing {
    use super::*;

    #[test]
    fn test_load_balancing_auxiliary_loss() {
        let moe = MoEAttention::new(8, 4, 2);

        // Process multiple queries
        let queries = vec![
            random_array1(8, 1),
            random_array1(8, 2),
            random_array1(8, 3),
            random_array1(8, 4),
            random_array1(8, 5),
        ];

        let mut expert_counts = vec![0; 4];

        for query in queries.iter() {
            let (selected, _) = moe.route(query).unwrap();
            for expert_idx in selected {
                expert_counts[expert_idx] += 1;
            }
        }

        // With load balancing, experts should be used somewhat evenly
        // (This is a weak test - just verify all experts used at least once)
        let total: usize = expert_counts.iter().sum();
        assert_eq!(total, 5 * 2, "Total selections should be 5 queries * top-2");
    }

    #[test]
    fn test_load_balancing_loss_computation() {
        let moe = MoEAttention::new(8, 4, 2);

        let batch_queries = vec![
            random_array1(8, 1),
            random_array1(8, 2),
            random_array1(8, 3),
        ];

        let loss = moe.compute_load_balancing_loss(&batch_queries)
            .expect("Load balancing loss failed");

        assert!(loss >= 0.0, "Loss should be non-negative");
        assert!(loss.is_finite());
    }
}

#[cfg(test)]
mod edge_cases {
    use super::*;

    #[test]
    fn test_single_expert() {
        let moe = MoEAttention::new(8, 1, 1);

        let query = random_array1(8, 42);
        let keys = random_array2((5, 8), 123);
        let values = random_array2((5, 8), 456);

        let output = moe.forward(&query, &keys, &values).unwrap();

        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_all_experts_selected() {
        let moe = MoEAttention::new(8, 3, 3); // Select all 3 experts

        let query = random_array1(8, 42);
        let (selected, _) = moe.route(&query).unwrap();

        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_zero_query() {
        let moe = MoEAttention::new(8, 4, 2);

        let query = Array1::zeros(8);
        let keys = random_array2((5, 8), 123);
        let values = random_array2((5, 8), 456);

        let output = moe.forward(&query, &keys, &values)
            .expect("Zero query should work");

        assert!(output.iter().all(|&x| x.is_finite()));
    }
}

#[cfg(test)]
mod numerical_stability {
    use super::*;

    #[test]
    fn test_many_experts() {
        let moe = MoEAttention::new(16, 16, 4); // Many experts

        let query = random_array1(16, 42);
        let keys = random_array2((10, 16), 123);
        let values = random_array2((10, 16), 456);

        let output = moe.forward(&query, &keys, &values)
            .expect("Many experts should work");

        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_extreme_routing_scores() {
        let moe = MoEAttention::new(8, 4, 2);

        // Query that might produce extreme routing scores
        let query = Array1::from_elem(8, 100.0);

        let (selected, weights) = moe.route(&query)
            .expect("Should handle extreme scores");

        assert_eq!(selected.len(), 2);

        for &w in weights.iter() {
            assert!(w.is_finite());
            assert!(w >= 0.0);
        }
    }
}

#[cfg(test)]
mod gradient_correctness {
    use super::*;

    #[test]
    fn test_moe_gradient_check() {
        let moe = MoEAttention::new(8, 3, 2);

        let query = random_array1(8, 42);
        let keys = random_array2((5, 8), 123);
        let values = random_array2((5, 8), 456);

        let loss_fn = |q: &Array1<f32>| -> f32 {
            moe.forward(q, &keys, &values).unwrap().mean().unwrap()
        };

        let analytical_grad = moe.backward(&query, &keys, &values).unwrap();

        let mut numerical_grad = Array1::zeros(8);
        for i in 0..8 {
            numerical_grad[i] = numerical_gradient(&loss_fn, &query, i);
        }

        for i in 0..8 {
            let rel_err = relative_error(analytical_grad[i], numerical_grad[i]);
            assert!(rel_err < 0.05, "Gradient error at {}: {}", i, rel_err);
        }
    }
}

#[cfg(test)]
mod property_based_tests {
    use super::*;

    proptest! {
        #[test]
        fn prop_routing_weights_valid(
            dim in 4usize..16,
            num_experts in 2usize..8,
            top_k in 1usize..5,
            seed in 0u64..100
        ) {
            let top_k_clamped = top_k.min(num_experts);
            let moe = MoEAttention::new(dim, num_experts, top_k_clamped);

            let query = random_array1(dim, seed);
            let (selected, weights) = moe.route(&query).unwrap();

            prop_assert_eq!(selected.len(), top_k_clamped);
            prop_assert_eq!(weights.len(), top_k_clamped);

            let sum: f32 = weights.iter().sum();
            prop_assert!((sum - 1.0).abs() < EPSILON);
        }

        #[test]
        fn prop_output_dimension_preserved(
            dim in 4usize..16,
            num_experts in 2usize..6,
            top_k in 1usize..4,
            n_keys in 2usize..10,
            seed in 0u64..50
        ) {
            let top_k_clamped = top_k.min(num_experts);
            let moe = MoEAttention::new(dim, num_experts, top_k_clamped);

            let query = random_array1(dim, seed);
            let keys = random_array2((n_keys, dim), seed + 1);
            let values = random_array2((n_keys, dim), seed + 2);

            let output = moe.forward(&query, &keys, &values).unwrap();

            prop_assert_eq!(output.len(), dim);
        }

        #[test]
        fn prop_expert_indices_valid(
            dim in 4usize..12,
            num_experts in 2usize..8,
            top_k in 1usize..5,
            seed in 0u64..50
        ) {
            let top_k_clamped = top_k.min(num_experts);
            let moe = MoEAttention::new(dim, num_experts, top_k_clamped);

            let query = random_array1(dim, seed);
            let (selected, _) = moe.route(&query).unwrap();

            for &expert_idx in selected.iter() {
                prop_assert!(expert_idx < num_experts);
            }
        }
    }
}

8. Test Coverage and CI Integration

Coverage Configuration (Cargo.toml additions)

[dev-dependencies]
proptest = "1.0"
approx = "0.5"
criterion = "0.5"
serde_json = "1.0"

[profile.test]
opt-level = 0
debug = true

[profile.bench]
opt-level = 3
debug = false

CI Test Script (.github/workflows/tests.yml)

name: Unit Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Install Rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
          override: true
          components: llvm-tools-preview

      - name: Run unit tests
        run: cargo test --all-features --verbose

      - name: Run property-based tests
        run: cargo test --release -- --test-threads=1 prop_

      - name: Generate coverage
        run: |
          cargo install cargo-llvm-cov
          cargo llvm-cov --all-features --lcov --output-path lcov.info

      - name: Upload coverage
        uses: codecov/codecov-action@v3
        with:
          files: lcov.info

Summary

This comprehensive unit test suite provides:

  1. 6 Complete Test Modules: One for each attention mechanism
  2. Test Categories: Basic functionality, edge cases, numerical stability, gradient correctness, serialization
  3. Property-Based Testing: 20+ property tests using proptest
  4. Test Utilities: Reusable fixtures and validation functions
  5. Coverage Goals: >85% code coverage with 100% critical path coverage
  6. CI Integration: Automated testing and coverage reporting

All tests follow Rust best practices with clear documentation, deterministic seeds for reproducibility, and comprehensive edge case coverage.