Files
wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_subcarrier.rs
Claude fce1271140 feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries
Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose
  (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen
- compute_losses: unified functional API
- 11 deterministic unit tests

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres minimum-cost assignment via DFS
  augmenting paths for optimal multi-person keypoint matching (ruvector)
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/
  rectangular/empty Hungarian cases)

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs still stub — trainer agent completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
2026-02-28 15:22:54 +00:00

390 lines
13 KiB
Rust

//! Integration tests for [`wifi_densepose_train::subcarrier`].
//!
//! All test data is constructed from fixed, deterministic arrays — no `rand`
//! crate or OS entropy is used. The same input always produces the same
//! output regardless of the platform or execution order.
use ndarray::Array4;
use wifi_densepose_train::subcarrier::{
compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance,
};
// ---------------------------------------------------------------------------
// Output shape tests
// ---------------------------------------------------------------------------
/// Resampling 114 → 56 subcarriers must produce shape [T, n_tx, n_rx, 56].
#[test]
fn resample_114_to_56_output_shape() {
    // Dimensions of a typical capture: 10 frames over a 3×3 antenna array.
    let (t, n_tx, n_rx) = (10_usize, 3_usize, 3_usize);
    let (src_sc, tgt_sc) = (114_usize, 56_usize);
    // Deterministic data: value = t_idx + tx + rx + k (no randomness).
    let input = Array4::<f32>::from_shape_fn((t, n_tx, n_rx, src_sc), |(ti, tx, rx, k)| {
        (ti + tx + rx + k) as f32
    });
    let out = interpolate_subcarriers(&input, tgt_sc);
    // Only the subcarrier axis may change size; all other axes are untouched.
    assert_eq!(
        out.shape(),
        &[t, n_tx, n_rx, tgt_sc],
        "resampled shape must be [{t}, {n_tx}, {n_rx}, {tgt_sc}], got {:?}",
        out.shape()
    );
}
/// Resampling 56 → 114 (upsampling) must produce shape [T, n_tx, n_rx, 114].
#[test]
fn resample_56_to_114_output_shape() {
    // Deterministic fixture; the 0.1 scale keeps values small but non-trivial.
    let input = Array4::<f32>::from_shape_fn((8, 2, 2, 56), |(ti, tx, rx, k)| {
        (ti + tx + rx + k) as f32 * 0.1
    });
    let upsampled = interpolate_subcarriers(&input, 114);
    assert_eq!(
        upsampled.shape(),
        &[8, 2, 2, 114],
        "upsampled shape must be [8, 2, 2, 114], got {:?}",
        upsampled.shape()
    );
}
// ---------------------------------------------------------------------------
// Identity case: 56 → 56
// ---------------------------------------------------------------------------
/// Resampling from 56 → 56 subcarriers must return a tensor identical to the
/// input (element-wise equality within floating-point precision).
#[test]
fn identity_resample_56_to_56_preserves_values() {
    let arr = Array4::<f32>::from_shape_fn((5, 3, 3, 56), |(ti, tx, rx, k)| {
        // Deterministic: a sine of a simple arithmetic combination of indices.
        (ti as f32 * 1000.0 + tx as f32 * 100.0 + rx as f32 * 10.0 + k as f32).sin()
    });
    let out = interpolate_subcarriers(&arr, 56);
    assert_eq!(
        out.shape(),
        arr.shape(),
        "identity resample must preserve shape"
    );
    // Element-wise comparison; indices are kept for a useful failure message.
    for ((ti, tx, rx, k), &orig) in arr.indexed_iter() {
        let resampled = out[[ti, tx, rx, k]];
        assert!(
            (resampled - orig).abs() < 1e-5,
            "identity resample mismatch at [{ti},{tx},{rx},{k}]: \
             orig={orig}, resampled={resampled}"
        );
    }
}
// ---------------------------------------------------------------------------
// Monotone (linearly-increasing) input interpolates correctly
// ---------------------------------------------------------------------------
/// For a linearly-increasing input across the subcarrier axis, the resampled
/// output must also be linearly increasing (all values lie on the same line).
#[test]
fn monotone_input_interpolates_linearly() {
    // src[k] = k as f32 for k in 0..8 — a straight line through the origin.
    let line = Array4::<f32>::from_shape_fn((1, 1, 1, 8), |(_, _, _, k)| k as f32);
    let out = interpolate_subcarriers(&line, 16);
    // The output must be a linearly-spaced sequence from 0.0 to 7.0:
    // out[i] = i * 7.0 / 15.0 (endpoints preserved by the mapping).
    for (i, expected) in (0..16_usize).map(|i| (i, i as f32 * 7.0 / 15.0)) {
        let actual = out[[0, 0, 0, i]];
        assert!(
            (actual - expected).abs() < 1e-5,
            "linear interpolation wrong at index {i}: expected {expected}, got {actual}"
        );
    }
}
/// Downsampling a linearly-increasing input must also produce a linear output.
#[test]
fn monotone_downsample_interpolates_linearly() {
    // src[k] = k * 2.0 for k in 0..16 (values 0, 2, 4, …, 30).
    let ramp = Array4::<f32>::from_shape_fn((1, 1, 1, 16), |(_, _, _, k)| 2.0 * k as f32);
    let out = interpolate_subcarriers(&ramp, 8);
    // out[i] = i * 30.0 / 7.0 (endpoints at 0.0 and 30.0).
    for i in 0..8_usize {
        let actual = out[[0, 0, 0, i]];
        let expected = i as f32 * 30.0 / 7.0;
        assert!(
            (actual - expected).abs() < 1e-4,
            "linear downsampling wrong at index {i}: expected {expected}, got {actual}"
        );
    }
}
// ---------------------------------------------------------------------------
// Boundary value preservation
// ---------------------------------------------------------------------------
/// The first output subcarrier must equal the first input subcarrier exactly.
#[test]
fn boundary_first_subcarrier_preserved_on_downsample() {
    // Logarithmic curve: deterministic, non-trivial, and distinct per k.
    let src = Array4::<f32>::from_shape_fn((1, 1, 1, 114), |(_, _, _, k)| {
        (k as f32 * 0.1 + 1.0).ln()
    });
    let first_value = src[[0, 0, 0, 0]];
    let out = interpolate_subcarriers(&src, 56);
    let first_out = out[[0, 0, 0, 0]];
    assert!(
        (first_out - first_value).abs() < 1e-5,
        "first output subcarrier must equal first input subcarrier: \
         expected {first_value}, got {first_out}"
    );
}
/// The last output subcarrier must equal the last input subcarrier exactly.
#[test]
fn boundary_last_subcarrier_preserved_on_downsample() {
    // Same deterministic log curve as the first-subcarrier test.
    let src = Array4::<f32>::from_shape_fn((1, 1, 1, 114), |(_, _, _, k)| {
        (k as f32 * 0.1 + 1.0).ln()
    });
    let last_input = src[[0, 0, 0, 113]];
    let resampled = interpolate_subcarriers(&src, 56);
    let last_output = resampled[[0, 0, 0, 55]];
    assert!(
        (last_output - last_input).abs() < 1e-5,
        "last output subcarrier must equal last input subcarrier: \
         expected {last_input}, got {last_output}"
    );
}
/// The same boundary preservation holds when upsampling.
#[test]
fn boundary_endpoints_preserved_on_upsample() {
    // Quadratic curve: deterministic and strictly varying across k.
    let src = Array4::<f32>::from_shape_fn((1, 1, 1, 56), |(_, _, _, k)| {
        (k as f32 * 0.05 + 0.5).powi(2)
    });
    let first_input = src[[0, 0, 0, 0]];
    let last_input = src[[0, 0, 0, 55]];
    let upsampled = interpolate_subcarriers(&src, 114);
    let first_output = upsampled[[0, 0, 0, 0]];
    let last_output = upsampled[[0, 0, 0, 113]];
    // Both endpoints must survive the 56 → 114 mapping exactly.
    assert!(
        (first_output - first_input).abs() < 1e-5,
        "first output must equal first input on upsample: \
         expected {first_input}, got {first_output}"
    );
    assert!(
        (last_output - last_input).abs() < 1e-5,
        "last output must equal last input on upsample: \
         expected {last_input}, got {last_output}"
    );
}
// ---------------------------------------------------------------------------
// Determinism
// ---------------------------------------------------------------------------
/// Calling `interpolate_subcarriers` twice with the same input must yield
/// bit-identical results — no non-deterministic behavior allowed.
#[test]
fn resample_is_deterministic() {
    // Fixed deterministic array built from one LCG step per element (seed 42),
    // mimicking SyntheticDataset's pattern without any entropy source.
    let arr = Array4::<f32>::from_shape_fn((10, 3, 3, 114), |(ti, tx, rx, k)| {
        // Flatten the 4-D index (Horner form of ti*3*3*114 + tx*3*114 + rx*114 + k).
        let flat = ((ti * 3 + tx) * 3 + rx) * 114 + k;
        // LCG: state = (a * state + c) mod 2^64 with seed = 42.
        let state = 6364136223846793005_u64
            .wrapping_mul(flat as u64 + 42)
            .wrapping_add(1442695040888963407);
        ((state >> 33) as f32) / (u32::MAX as f32) // in [0, 1)
    });
    let out1 = interpolate_subcarriers(&arr, 56);
    let out2 = interpolate_subcarriers(&arr, 56);
    // Compare raw bit patterns — equality up to epsilon is not good enough here.
    for (((ti, tx, rx, k), v1), v2) in out1.indexed_iter().zip(out2.iter()) {
        assert_eq!(
            v1.to_bits(),
            v2.to_bits(),
            "bit-identical result required at [{ti},{tx},{rx},{k}]: \
             first={v1}, second={v2}"
        );
    }
}
/// Same input parameters → same `compute_interp_weights` output every time.
#[test]
fn compute_interp_weights_is_deterministic() {
    let first = compute_interp_weights(114, 56);
    let second = compute_interp_weights(114, 56);
    assert_eq!(first.len(), second.len(), "weight vector lengths must match");
    // Entry-by-entry equality, reporting the first divergent index.
    for (i, (a, b)) in first.iter().zip(second.iter()).enumerate() {
        assert_eq!(
            a, b,
            "weight at index {i} must be bit-identical across calls"
        );
    }
}
// ---------------------------------------------------------------------------
// compute_interp_weights properties
// ---------------------------------------------------------------------------
/// `compute_interp_weights(n, n)` must produce identity weights (i0==i1==k,
/// frac==0).
#[test]
fn compute_interp_weights_identity_case() {
    const N: usize = 56;
    let weights = compute_interp_weights(N, N);
    assert_eq!(weights.len(), N, "identity weights length must equal n");
    // Every entry must map output k straight back to source k with no blend.
    for (k, &(i0, i1, frac)) in weights.iter().enumerate() {
        assert_eq!(i0, k, "i0 must equal k for identity weights at {k}");
        assert_eq!(i1, k, "i1 must equal k for identity weights at {k}");
        assert!(
            frac.abs() < 1e-6,
            "frac must be 0 for identity weights at {k}, got {frac}"
        );
    }
}
/// `compute_interp_weights` must produce exactly `target_sc` entries.
#[test]
fn compute_interp_weights_correct_length() {
    let len = compute_interp_weights(114, 56).len();
    assert_eq!(
        len,
        56,
        "114→56 weights must have 56 entries, got {}",
        len
    );
}
/// All weights must have fractions in [0, 1].
#[test]
fn compute_interp_weights_frac_in_unit_interval() {
    let weights = compute_interp_weights(114, 56);
    for (i, &(_, _, frac)) in weights.iter().enumerate() {
        // Idiomatic range check (clippy: manual_range_contains); the small
        // epsilon on the upper bound absorbs floating-point rounding.
        assert!(
            (0.0..=1.0 + 1e-6).contains(&frac),
            "fractional weight at index {i} must be in [0, 1], got {frac}"
        );
    }
}
/// All i0 and i1 indices must be within bounds of the source array.
#[test]
fn compute_interp_weights_indices_in_bounds() {
    let src_sc = 114_usize;
    let weights = compute_interp_weights(src_sc, 56);
    // Both source indices of every output position must be valid lookups.
    for (k, &(i0, i1, _)) in weights.iter().enumerate() {
        assert!(
            i0 < src_sc,
            "i0={i0} at output {k} is out of bounds for src_sc={src_sc}"
        );
        assert!(
            i1 < src_sc,
            "i1={i1} at output {k} is out of bounds for src_sc={src_sc}"
        );
    }
}
// ---------------------------------------------------------------------------
// select_subcarriers_by_variance
// ---------------------------------------------------------------------------
/// `select_subcarriers_by_variance` must return exactly k indices.
#[test]
fn select_subcarriers_returns_k_indices() {
    // Deterministic fixture: value = ti * k, so variance differs across k.
    let csi = Array4::<f32>::from_shape_fn((20, 3, 3, 56), |(ti, _, _, k)| (ti * k) as f32);
    let selected = select_subcarriers_by_variance(&csi, 8);
    assert_eq!(
        selected.len(),
        8,
        "must select exactly 8 subcarriers, got {}",
        selected.len()
    );
}
/// The returned indices must be sorted in ascending order.
#[test]
fn select_subcarriers_indices_are_sorted_ascending() {
    // Give every subcarrier a DISTINCT temporal variance by scaling the time
    // index with (k + 1): var over ti grows with k, so the top-10 selection is
    // well-defined. The previous fixture added `k * 11` as a constant offset,
    // which leaves the per-subcarrier variance identical for every k and makes
    // the result depend purely on the implementation's tie-breaking.
    let arr = Array4::<f32>::from_shape_fn((10, 2, 2, 56), |(ti, tx, rx, k)| {
        (ti * (k + 1) + tx * 3 + rx * 7) as f32
    });
    let selected = select_subcarriers_by_variance(&arr, 10);
    for window in selected.windows(2) {
        assert!(
            window[0] < window[1],
            "selected indices must be strictly ascending: {:?}",
            selected
        );
    }
}
/// All returned indices must be within [0, n_sc).
#[test]
fn select_subcarriers_indices_are_valid() {
    let n_sc = 56_usize;
    // Cosine of a linear index combination: deterministic, varies with ti and k.
    let csi = Array4::<f32>::from_shape_fn((8, 3, 3, n_sc), |(ti, _, _, k)| {
        (ti as f32 * 0.7 + k as f32 * 1.3).cos()
    });
    let selected = select_subcarriers_by_variance(&csi, 5);
    for &idx in selected.iter() {
        assert!(
            idx < n_sc,
            "selected index {idx} is out of bounds for n_sc={n_sc}"
        );
    }
}
/// High-variance subcarriers should be preferred over low-variance ones.
/// Create an array where subcarriers 0..4 have zero variance and
/// subcarriers 4..8 have high variance — the top-4 selection must exclude 0..4.
#[test]
fn select_subcarriers_prefers_high_variance() {
    // Subcarriers 0..4: constant 0.5 (zero variance across time).
    // Subcarriers 4..8: swing between +100 and -100 over time (high variance).
    let arr = Array4::<f32>::from_shape_fn((20, 1, 1, 8), |(ti, _, _, k)| {
        match (k < 4, ti % 2 == 0) {
            (true, _) => 0.5_f32,
            (false, true) => 100.0,
            (false, false) => -100.0,
        }
    });
    let selected = select_subcarriers_by_variance(&arr, 4);
    // All selected indices should be in {4, 5, 6, 7}.
    for &idx in &selected {
        assert!(
            idx >= 4,
            "expected only high-variance subcarriers (4..8) to be selected, \
             but got index {idx}: selected = {:?}",
            selected
        );
    }
}