Losses (losses.rs — 1056 lines): - WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE) - generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen - compute_losses: unified functional API - 11 deterministic unit tests Metrics (metrics.rs — 984 lines): - PCK@0.2 / PCK@0.5 with torso-diameter normalisation - OKS with COCO standard per-joint sigmas - MetricsAccumulator for online streaming eval - hungarian_assignment: O(n³) Kuhn-Munkres min-cut via DFS augmenting paths for optimal multi-person keypoint assignment (ruvector min-cut) - build_oks_cost_matrix: 1−OKS cost for bipartite matching - 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/ rectangular/empty Hungarian cases) Model (model.rs — 713 lines): - WiFiDensePoseModel end-to-end with tch-rs - ModalityTranslator: amp+phase FC encoders → spatial pseudo-image - Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6] - KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps - DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV Trainer (trainer.rs — 777 lines): - Full training loop: Adam, LR milestones, gradient clipping - Deterministic batch shuffle via LCG (seed XOR epoch) - CSV logging, best-checkpoint saving, early stopping - evaluate() with MetricsAccumulator and heatmap argmax decode Binaries: - src/bin/train.rs: production MM-Fi training CLI (clap) - src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2) Benches: - benches/training_bench.rs: criterion benchmarks for key ops Tests: - tests/test_dataset.rs (459 lines) - tests/test_metrics.rs (449 lines) - tests/test_subcarrier.rs (389 lines) proof.rs still stub — trainer agent completing it. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
452 lines
16 KiB
Rust
452 lines
16 KiB
Rust
//! Integration tests for [`wifi_densepose_train::losses`].
|
||
//!
|
||
//! All tests are gated behind `#[cfg(feature = "tch-backend")]` because the
//! loss functions require PyTorch via `tch`. When that feature is disabled,
//! conditional compilation removes the test module entirely (it is never
//! compiled), and only a trivial placeholder test is registered.
|
||
//!
|
||
//! All input tensors are constructed from fixed, deterministic data — no
|
||
//! `rand` crate, no OS entropy.
|
||
|
||
#[cfg(feature = "tch-backend")]
mod tch_tests {
    use wifi_densepose_train::losses::{
        generate_gaussian_heatmap, generate_target_heatmaps, LossWeights, WiFiDensePoseLoss,
    };

    // -----------------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------------

    /// Device used by every tensor in this module (CPU only, so no GPU is needed).
    fn cpu() -> tch::Device {
        tch::Device::Cpu
    }

    /// Reads the value of a scalar loss tensor as `f32`.
    ///
    /// Centralises the `double_value(&[]) as f32` conversion used by every
    /// loss assertion below.
    fn scalar(loss: &tch::Tensor) -> f32 {
        loss.double_value(&[]) as f32
    }

    // -----------------------------------------------------------------------
    // generate_gaussian_heatmap
    // -----------------------------------------------------------------------

    /// The heatmap must have shape [heatmap_size, heatmap_size].
    #[test]
    fn gaussian_heatmap_has_correct_shape() {
        let hm = generate_gaussian_heatmap(0.5, 0.5, 56, 2.0);
        assert_eq!(
            hm.shape(),
            &[56, 56],
            "heatmap shape must be [56, 56], got {:?}",
            hm.shape()
        );
    }

    /// All values in the heatmap must lie in [0, 1].
    ///
    /// A 1e-6 tolerance on the upper bound absorbs float rounding at the peak.
    /// NaN values fail the check, because `contains` is false for NaN.
    #[test]
    fn gaussian_heatmap_values_in_unit_interval() {
        let hm = generate_gaussian_heatmap(0.3, 0.7, 56, 2.0);
        for &v in hm.iter() {
            assert!(
                (0.0..=1.0 + 1e-6).contains(&v),
                "heatmap value {v} is outside [0, 1]"
            );
        }
    }

    /// The peak must be at (or very close to) the keypoint pixel location.
    #[test]
    fn gaussian_heatmap_peak_at_keypoint_location() {
        let kp_x = 0.5_f32;
        let kp_y = 0.5_f32;
        let size = 56_usize;
        let sigma = 2.0_f32;

        let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma);

        // Map normalised coordinates to pixel space. This assumes the
        // implementation scales by (size - 1); TODO confirm against losses.rs.
        let s = (size - 1) as f32;
        let cx = (kp_x * s).round() as usize;
        let cy = (kp_y * s).round() as usize;

        let peak_val = hm[[cy, cx]];
        assert!(
            peak_val > 0.9,
            "peak value {peak_val} at ({cx},{cy}) must be > 0.9 for σ=2.0"
        );

        // Verify it really is the maximum.
        let global_max = hm.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        assert!(
            (global_max - peak_val).abs() < 1e-4,
            "peak at keypoint location {peak_val} must equal the global max {global_max}"
        );
    }

    /// Values outside the 3σ radius must be zero (clamped).
    #[test]
    fn gaussian_heatmap_zero_outside_3sigma_radius() {
        let size = 56_usize;
        let sigma = 2.0_f32;
        let kp_x = 0.5_f32;
        let kp_y = 0.5_f32;

        let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma);

        let s = (size - 1) as f32;
        let cx = kp_x * s;
        let cy = kp_y * s;
        let clip_radius = 3.0 * sigma;

        for r in 0..size {
            for c in 0..size {
                let dist = (c as f32 - cx).hypot(r as f32 - cy);
                // Half-pixel slack so pixels straddling the clip boundary are not tested.
                if dist > clip_radius + 0.5 {
                    assert_eq!(
                        hm[[r, c]],
                        0.0,
                        "pixel at ({r},{c}) with dist={dist:.2} from kp must be 0 (outside 3σ)"
                    );
                }
            }
        }
    }

    // -----------------------------------------------------------------------
    // generate_target_heatmaps (batch)
    // -----------------------------------------------------------------------

    /// Output shape must be [B, 17, H, W].
    #[test]
    fn target_heatmaps_output_shape() {
        let batch = 4_usize;
        let joints = 17_usize;
        let size = 56_usize;

        let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32);
        let visibility = ndarray::Array2::ones((batch, joints));

        let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0);

        assert_eq!(
            heatmaps.shape(),
            &[batch, joints, size, size],
            "target heatmaps shape must be [{batch}, {joints}, {size}, {size}], \
             got {:?}",
            heatmaps.shape()
        );
    }

    /// Invisible keypoints (visibility = 0) must produce all-zero heatmap
    /// channels, while a fully visible sample in the same batch must not be
    /// zeroed. This shows that masking is applied per sample.
    #[test]
    fn target_heatmaps_invisible_joints_are_zero() {
        let batch = 2_usize;
        let joints = 17_usize;
        let size = 32_usize;

        let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32);
        // Make all joints in batch 0 invisible; batch 1 stays fully visible.
        let mut visibility = ndarray::Array2::ones((batch, joints));
        for j in 0..joints {
            visibility[[0, j]] = 0.0;
        }

        let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0);

        for j in 0..joints {
            for r in 0..size {
                for c in 0..size {
                    assert_eq!(
                        heatmaps[[0, j, r, c]],
                        0.0,
                        "invisible joint heatmap at [0,{j},{r},{c}] must be zero"
                    );
                }
            }
        }

        // Guard against a trivially passing implementation that zeroes everything.
        let visible_sum: f32 = heatmaps
            .slice(ndarray::s![1, .., .., ..])
            .iter()
            .copied()
            .sum();
        assert!(
            visible_sum > 0.0,
            "visible sample 1 must keep non-zero heatmaps, sum={visible_sum}"
        );
    }

    /// Every visible keypoint must produce a non-zero heatmap channel.
    #[test]
    fn target_heatmaps_visible_joints_are_nonzero() {
        let batch = 1_usize;
        let joints = 17_usize;
        let size = 56_usize;

        let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32);
        let visibility = ndarray::Array2::ones((batch, joints));

        let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0);

        // Check each channel separately: a global sum would hide empty channels.
        for j in 0..joints {
            let channel_sum: f32 = heatmaps
                .slice(ndarray::s![0, j, .., ..])
                .iter()
                .copied()
                .sum();
            assert!(
                channel_sum > 0.0,
                "visible joint {j} must produce a non-zero heatmap, sum={channel_sum}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // WiFiDensePoseLoss::keypoint_loss
    // -----------------------------------------------------------------------

    /// Loss of identical pred and target heatmaps must be ≈ 0.0.
    #[test]
    fn keypoint_heatmap_loss_identical_tensors_is_zero() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let pred = tch::Tensor::ones([2, 17, 16, 16], (tch::Kind::Float, dev));
        let target = tch::Tensor::ones([2, 17, 16, 16], (tch::Kind::Float, dev));
        let vis = tch::Tensor::ones([2, 17], (tch::Kind::Float, dev));

        let val = scalar(&loss_fn.keypoint_loss(&pred, &target, &vis));

        assert!(
            val.abs() < 1e-5,
            "keypoint loss for identical pred/target must be ≈ 0.0, got {val}"
        );
    }

    /// Loss of all-zeros pred vs all-ones target must be > 0.0.
    #[test]
    fn keypoint_heatmap_loss_zero_pred_vs_ones_target_is_positive() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let pred = tch::Tensor::zeros([1, 17, 8, 8], (tch::Kind::Float, dev));
        let target = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev));
        let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev));

        let val = scalar(&loss_fn.keypoint_loss(&pred, &target, &vis));

        assert!(
            val > 0.0,
            "keypoint loss for zero vs ones must be > 0.0, got {val}"
        );
    }

    /// Invisible joints must not contribute to the loss.
    ///
    /// With every joint invisible, a visibility-normalised loss can end up
    /// computing 0/0. The loss is therefore checked for finiteness first, so
    /// that failure mode gets its own diagnostic.
    #[test]
    fn keypoint_heatmap_loss_invisible_joints_contribute_nothing() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        // Large error but all visibility = 0 → loss must be ≈ 0.
        let pred = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev));
        let target = tch::Tensor::zeros([1, 17, 8, 8], (tch::Kind::Float, dev));
        let vis = tch::Tensor::zeros([1, 17], (tch::Kind::Float, dev));

        let val = scalar(&loss_fn.keypoint_loss(&pred, &target, &vis));

        assert!(
            val.is_finite(),
            "all-invisible loss must be finite (guard against 0/0 normalisation), got {val}"
        );
        assert!(
            val.abs() < 1e-5,
            "all-invisible loss must be ≈ 0.0 (no joints contribute), got {val}"
        );
    }

    // -----------------------------------------------------------------------
    // WiFiDensePoseLoss::densepose_loss
    // -----------------------------------------------------------------------

    /// densepose_loss must return a finite (no NaN, no ±inf), non-negative value.
    #[test]
    fn densepose_part_loss_no_nan() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let b = 1_i64;
        let h = 8_i64;
        let w = 8_i64;

        let pred_parts = tch::Tensor::zeros([b, 25, h, w], (tch::Kind::Float, dev));
        // Every pixel is labelled part 1.
        let target_parts = tch::Tensor::ones([b, h, w], (tch::Kind::Int64, dev));
        let uv = tch::Tensor::zeros([b, 48, h, w], (tch::Kind::Float, dev));

        let val = scalar(&loss_fn.densepose_loss(&pred_parts, &target_parts, &uv, &uv));

        // `is_finite` rejects both NaN and ±inf; `!is_nan` alone would let inf through.
        assert!(
            val.is_finite(),
            "densepose_loss must be finite (no NaN/inf), got {val}"
        );
        assert!(
            val >= 0.0,
            "densepose_loss must be non-negative, got {val}"
        );
    }

    // -----------------------------------------------------------------------
    // WiFiDensePoseLoss::forward (combined loss)
    // -----------------------------------------------------------------------

    /// The combined forward pass must produce a total loss > 0 for non-trivial
    /// (non-identical) inputs.
    #[test]
    fn compute_losses_total_positive_for_nonzero_error() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        // pred = zeros, target = ones → non-zero keypoint error.
        let pred_kp = tch::Tensor::zeros([2, 17, 8, 8], (tch::Kind::Float, dev));
        let target_kp = tch::Tensor::ones([2, 17, 8, 8], (tch::Kind::Float, dev));
        let vis = tch::Tensor::ones([2, 17], (tch::Kind::Float, dev));

        let (_, output) = loss_fn.forward(
            &pred_kp, &target_kp, &vis,
            None, None, None, None,
            None, None,
        );

        assert!(
            output.total > 0.0,
            "total loss must be > 0 for non-trivial predictions, got {}",
            output.total
        );
    }

    /// The combined forward pass with identical tensors must produce total ≈ 0.
    #[test]
    fn compute_losses_total_zero_for_perfect_prediction() {
        // Only the keypoint term is weighted, so "perfect" is well defined.
        let weights = LossWeights {
            lambda_kp: 1.0,
            lambda_dp: 0.0,
            lambda_tr: 0.0,
        };
        let loss_fn = WiFiDensePoseLoss::new(weights);
        let dev = cpu();

        let perfect = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev));
        let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev));

        let (_, output) = loss_fn.forward(
            &perfect, &perfect, &vis,
            None, None, None, None,
            None, None,
        );

        assert!(
            output.total.abs() < 1e-5,
            "perfect prediction must yield total ≈ 0.0, got {}",
            output.total
        );
    }

    /// Optional densepose and transfer outputs must be None when not supplied.
    #[test]
    fn compute_losses_optional_components_are_none() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let t = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev));
        let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev));

        let (_, output) = loss_fn.forward(
            &t, &t, &vis,
            None, None, None, None,
            None, None,
        );

        assert!(
            output.densepose.is_none(),
            "densepose component must be None when not supplied"
        );
        assert!(
            output.transfer.is_none(),
            "transfer component must be None when not supplied"
        );
    }

    /// Full forward pass with all optional components must populate all fields
    /// with finite values.
    #[test]
    fn compute_losses_with_all_components_populates_all_fields() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let pred_kp = tch::Tensor::zeros([1, 17, 8, 8], (tch::Kind::Float, dev));
        let target_kp = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev));
        let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev));

        let pred_parts = tch::Tensor::zeros([1, 25, 8, 8], (tch::Kind::Float, dev));
        let target_parts = tch::Tensor::ones([1, 8, 8], (tch::Kind::Int64, dev));
        let uv = tch::Tensor::zeros([1, 48, 8, 8], (tch::Kind::Float, dev));

        let student = tch::Tensor::zeros([1, 64, 4, 4], (tch::Kind::Float, dev));
        let teacher = tch::Tensor::ones([1, 64, 4, 4], (tch::Kind::Float, dev));

        let (_, output) = loss_fn.forward(
            &pred_kp, &target_kp, &vis,
            Some(&pred_parts), Some(&target_parts), Some(&uv), Some(&uv),
            Some(&student), Some(&teacher),
        );

        assert!(
            output.total > 0.0,
            "total loss must be > 0 when pred ≠ target, got {}",
            output.total
        );

        // Presence and finiteness in one step; `is_finite` also rejects ±inf.
        let dp = output
            .densepose
            .expect("densepose component must be Some when all inputs provided");
        assert!(dp.is_finite(), "densepose component must be finite, got {dp}");

        let tr = output
            .transfer
            .expect("transfer component must be Some when student/teacher provided");
        assert!(tr.is_finite(), "transfer component must be finite, got {tr}");
    }

    // -----------------------------------------------------------------------
    // WiFiDensePoseLoss::transfer_loss
    // -----------------------------------------------------------------------

    /// Transfer loss for identical tensors must be ≈ 0.0.
    #[test]
    fn transfer_loss_identical_features_is_zero() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let feat = tch::Tensor::ones([2, 64, 8, 8], (tch::Kind::Float, dev));
        let val = scalar(&loss_fn.transfer_loss(&feat, &feat));

        assert!(
            val.abs() < 1e-5,
            "transfer loss for identical tensors must be ≈ 0.0, got {val}"
        );
    }

    /// Transfer loss for different tensors must be > 0.0.
    #[test]
    fn transfer_loss_different_features_is_positive() {
        let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
        let dev = cpu();

        let student = tch::Tensor::zeros([2, 64, 8, 8], (tch::Kind::Float, dev));
        let teacher = tch::Tensor::ones([2, 64, 8, 8], (tch::Kind::Float, dev));

        let val = scalar(&loss_fn.transfer_loss(&student, &teacher));

        assert!(
            val > 0.0,
            "transfer loss for different tensors must be > 0.0, got {val}"
        );
    }
}
|
||
|
||
// Fallback for builds without `tch-backend`: the `tch_tests` module is removed
// by `#[cfg]`, so this placeholder is the only test compiled from this file.
#[cfg(not(feature = "tch-backend"))]
#[test]
/// Passes unconditionally; its presence in the test report shows that the
/// loss tests were compiled out rather than silently missing.
fn tch_backend_not_enabled() {}
|