Files
wifi-densepose/crates/ruvector-mincut-gated-transformer/tests/sparse_attention.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

115 lines
3.0 KiB
Rust

//! Comprehensive tests for mincut-aware sparse attention.
#![cfg(feature = "sparse_attention")]
use ruvector_mincut_gated_transformer::{
GatePacket, LambdaDensitySchedule, MincutSparseAttention, SparseMask, SparsityConfig,
};
/// Checks the two constructor extremes of `SparseMask`: `empty()` holds no
/// positions and reports zero density, while `full(n)` holds every causal
/// position and reports full density.
#[test]
fn test_sparse_mask_creation() {
    let none_kept = SparseMask::empty();
    assert_eq!(none_kept.num_positions(), 0);
    assert_eq!(none_kept.density, 0.0);
    assert_eq!(none_kept.sparsity(), 1.0);

    // For 8 tokens the causal triangle has 8 * 9 / 2 = 36 (query, key) pairs.
    let all_kept = SparseMask::full(8);
    assert_eq!(all_kept.num_positions(), 36);
    assert_eq!(all_kept.density, 1.0);
    assert_eq!(all_kept.sparsity(), 0.0);
}
/// After pushing pairs into `positions` by hand, `can_attend` accepts those
/// pairs and rejects pairs that were never added.
#[test]
fn test_sparse_mask_can_attend() {
    let mut mask = SparseMask::empty();

    // Let query 2 see keys 0, 1 and 2, in that order.
    for key in 0..=2 {
        mask.positions.push((2, key));
    }

    for key in 0..=2 {
        assert!(mask.can_attend(2, key));
    }

    // (2, 3) was never inserted, and query 1 has no entries at all.
    assert!(!mask.can_attend(2, 3));
    assert!(!mask.can_attend(1, 0));
}
/// Under the `Adaptive` lambda schedule, a gate with high lambda and few
/// boundary edges must get a higher density than a gate with low lambda and
/// many boundary edges.
#[test]
fn test_density_calculation_adaptive() {
    let sparse_attn = MincutSparseAttention::new(SparsityConfig {
        lambda_based_density: Some(LambdaDensitySchedule::Adaptive),
        ..Default::default()
    });

    // Stable partitioning: strong lambda, few cut edges, low concentration.
    let calm_gate = GatePacket {
        lambda: 200,
        boundary_edges: 5,
        boundary_concentration_q15: 4096,
        partition_count: 2,
        ..Default::default()
    };

    // Unstable partitioning: weak lambda, many cut edges, high concentration.
    let noisy_gate = GatePacket {
        lambda: 50,
        boundary_edges: 40,
        boundary_concentration_q15: 24576,
        partition_count: 8,
        ..Default::default()
    };

    let (calm_density, noisy_density) = (
        sparse_attn.calculate_density(&calm_gate),
        sparse_attn.calculate_density(&noisy_gate),
    );
    assert!(calm_density > noisy_density);
}
/// Builds a mask for 32 tokens from a gate reporting 3 partitions and checks
/// that the mask:
/// - is non-empty;
/// - has density below 1.0;
/// - records 3 partition boundaries;
/// - contains only causal (query, key) pairs.
#[test]
fn test_mask_building_with_partitions() {
    let sparse_attn = MincutSparseAttention::new(SparsityConfig::default());
    let gate = GatePacket {
        lambda: 100,
        partition_count: 3,
        boundary_edges: 5,
        ..Default::default()
    };

    let mask = sparse_attn.build_mask(&gate, 32);

    assert!(mask.num_positions() > 0);
    assert!(mask.density < 1.0);
    assert_eq!(mask.partition_boundaries.len(), 3);

    // A key index may never exceed the query index that attends to it.
    for &(query, key) in &mask.positions {
        assert!(key <= query, "Non-causal position: ({}, {})", query, key);
    }
}
/// The estimated FLOPs ratio of a mask built for a 4-partition gate over
/// 64 tokens lies strictly between 0 and 1. Below 1 means the mask saves
/// work relative to the 1.0 upper bound; above 0 means the estimate is
/// positive.
#[test]
fn test_flops_estimation() {
    let sparse_attn = MincutSparseAttention::new(SparsityConfig::default());
    let gate = GatePacket {
        lambda: 100,
        partition_count: 4,
        boundary_edges: 10,
        ..Default::default()
    };

    let flops_ratio = {
        let mask = sparse_attn.build_mask(&gate, 64);
        sparse_attn.estimated_flops_ratio(&mask, 64)
    };

    assert!(flops_ratio < 1.0);
    assert!(flops_ratio > 0.0);
}