Files
wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_losses.rs
Claude fce1271140 feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries
Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose
  (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen
- compute_losses: unified functional API
- 11 deterministic unit tests

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres min-cut via DFS augmenting
  paths for optimal multi-person keypoint assignment (ruvector min-cut)
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/
  rectangular/empty Hungarian cases)

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs still stub — trainer agent completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
2026-02-28 15:22:54 +00:00

452 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Integration tests for [`wifi_densepose_train::losses`].
//!
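//! All loss tests live in the `tch_tests` module, which is gated behind
//! `#[cfg(feature = "tch-backend")]` because the loss functions require
//! PyTorch via `tch`. When that feature is disabled the module is removed by
//! conditional compilation, so none of its tests are built or registered;
//! only the `tch_backend_not_enabled` placeholder test remains.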
//!
//! All input tensors are constructed from fixed, deterministic data — no
//! `rand` crate, no OS entropy.
#[cfg(feature = "tch-backend")]
mod tch_tests {
use wifi_densepose_train::losses::{
generate_gaussian_heatmap, generate_target_heatmaps, LossWeights, WiFiDensePoseLoss,
};
// -----------------------------------------------------------------------
// Helper: CPU device
// -----------------------------------------------------------------------
fn cpu() -> tch::Device {
tch::Device::Cpu
}
// -----------------------------------------------------------------------
// generate_gaussian_heatmap
// -----------------------------------------------------------------------
/// A single-keypoint heatmap requested with size 56 must come back as a
/// 56 × 56 grid.
#[test]
fn gaussian_heatmap_has_correct_shape() {
    let hm = generate_gaussian_heatmap(0.5, 0.5, 56, 2.0);
    let shape = hm.shape();
    assert_eq!(
        shape,
        &[56, 56],
        "heatmap shape must be [56, 56], got {:?}",
        shape
    );
}
/// Every sample of the heatmap must fall inside [0, 1]; the upper bound is
/// checked with a 1e-6 tolerance.
#[test]
fn gaussian_heatmap_values_in_unit_interval() {
    let hm = generate_gaussian_heatmap(0.3, 0.7, 56, 2.0);
    hm.iter().copied().for_each(|v| {
        assert!(
            (0.0..=1.0 + 1e-6).contains(&v),
            "heatmap value {v} is outside [0, 1]"
        );
    });
}
/// The pixel nearest to the keypoint must hold a value above 0.9 and must
/// match the global maximum of the heatmap (within 1e-4).
#[test]
fn gaussian_heatmap_peak_at_keypoint_location() {
    let (kp_x, kp_y) = (0.5_f32, 0.5_f32);
    let size = 56_usize;
    let sigma = 2.0_f32;
    let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma);

    // Normalised coordinates are scaled by the last pixel index and rounded
    // to the nearest pixel.
    let to_pixel = |norm: f32| (norm * (size - 1) as f32).round() as usize;
    let cx = to_pixel(kp_x);
    let cy = to_pixel(kp_y);

    // Heatmap is indexed [row, column], i.e. [y, x].
    let peak_val = hm[[cy, cx]];
    assert!(
        peak_val > 0.9,
        "peak value {peak_val} at ({cx},{cy}) must be > 0.9 for σ=2.0"
    );

    // No other pixel may exceed the value found at the keypoint.
    let global_max = hm.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    assert!(
        (global_max - peak_val).abs() < 1e-4,
        "peak at keypoint location {peak_val} must equal the global max {global_max}"
    );
}
/// Any pixel farther than 3σ (plus half a pixel of slack) from the keypoint
/// must be exactly zero.
#[test]
fn gaussian_heatmap_zero_outside_3sigma_radius() {
    let size = 56_usize;
    let sigma = 2.0_f32;
    let (kp_x, kp_y) = (0.5_f32, 0.5_f32);
    let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma);

    // Keypoint centre in (unrounded) pixel coordinates.
    let last = (size - 1) as f32;
    let (cx, cy) = (kp_x * last, kp_y * last);
    // 3σ clip radius widened by 0.5 px to tolerate pixel-centre rounding.
    let cutoff = 3.0 * sigma + 0.5;

    // Row-major walk over every (row, column) pair.
    let pixels = (0..size).flat_map(|r| (0..size).map(move |c| (r, c)));
    for (r, c) in pixels {
        let dx = c as f32 - cx;
        let dy = r as f32 - cy;
        let dist = (dx * dx + dy * dy).sqrt();
        if dist > cutoff {
            assert_eq!(
                hm[[r, c]],
                0.0,
                "pixel at ({r},{c}) with dist={dist:.2} from kp must be 0 (outside 3σ)"
            );
        }
    }
}
// -----------------------------------------------------------------------
// generate_target_heatmaps (batch)
// -----------------------------------------------------------------------
/// Batched target generation must yield a `[batch, joints, size, size]` array.
#[test]
fn target_heatmaps_output_shape() {
    let (batch, joints, size) = (4_usize, 17_usize, 56_usize);
    // Every keypoint sits at the image centre and is marked visible.
    let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32);
    let visibility = ndarray::Array2::ones((batch, joints));

    let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0);
    let shape = heatmaps.shape();
    assert_eq!(
        shape,
        &[batch, joints, size, size],
        "target heatmaps shape must be [{batch}, {joints}, {size}, {size}], \
         got {:?}",
        shape
    );
}
/// A joint whose visibility flag is 0 must yield a heatmap channel that is
/// zero at every pixel.
#[test]
fn target_heatmaps_invisible_joints_are_zero() {
    let (batch, joints, size) = (2_usize, 17_usize, 32_usize);
    let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32);

    // Start fully visible, then hide every joint of sample 0 only.
    let mut visibility = ndarray::Array2::ones((batch, joints));
    (0..joints).for_each(|j| visibility[[0, j]] = 0.0);

    let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0);

    // Visit sample 0 in (joint, row, column) order.
    let cells = (0..joints)
        .flat_map(|j| (0..size).flat_map(move |r| (0..size).map(move |c| (j, r, c))));
    for (j, r, c) in cells {
        assert_eq!(
            heatmaps[[0, j, r, c]],
            0.0,
            "invisible joint heatmap at [0,{j},{r},{c}] must be zero"
        );
    }
}
/// When every joint is visible, the generated heatmaps must carry some
/// positive mass (their total sum is strictly greater than zero).
#[test]
fn target_heatmaps_visible_joints_are_nonzero() {
    let (batch, joints, size) = (1_usize, 17_usize, 56_usize);
    let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32);
    let visibility = ndarray::Array2::ones((batch, joints));

    let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0);
    let total_sum = heatmaps.iter().sum::<f32>();
    assert!(
        total_sum > 0.0,
        "visible joints must produce non-zero heatmaps, sum={total_sum}"
    );
}
// -----------------------------------------------------------------------
// keypoint_heatmap_loss (exercises WiFiDensePoseLoss::keypoint_loss)
// -----------------------------------------------------------------------
/// When prediction and target heatmaps are identical and every joint is
/// visible, the keypoint loss must be (numerically) zero.
#[test]
fn keypoint_heatmap_loss_identical_tensors_is_zero() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let opts = (tch::Kind::Float, cpu());
    let heatmap_shape: [i64; 4] = [2, 17, 16, 16];

    let pred = tch::Tensor::ones(heatmap_shape, opts);
    let target = tch::Tensor::ones(heatmap_shape, opts);
    let vis = tch::Tensor::ones([2, 17], opts);

    let val = loss_fn.keypoint_loss(&pred, &target, &vis).double_value(&[]) as f32;
    assert!(
        val.abs() < 1e-5,
        "keypoint loss for identical pred/target must be ≈ 0.0, got {val}"
    );
}
/// An all-zero prediction against an all-one target, with every joint
/// visible, must give a strictly positive keypoint loss.
#[test]
fn keypoint_heatmap_loss_zero_pred_vs_ones_target_is_positive() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let opts = (tch::Kind::Float, cpu());
    let heatmap_shape: [i64; 4] = [1, 17, 8, 8];

    let pred = tch::Tensor::zeros(heatmap_shape, opts);
    let target = tch::Tensor::ones(heatmap_shape, opts);
    let vis = tch::Tensor::ones([1, 17], opts);

    let val = loss_fn.keypoint_loss(&pred, &target, &vis).double_value(&[]) as f32;
    assert!(
        val > 0.0,
        "keypoint loss for zero vs ones must be > 0.0, got {val}"
    );
}
/// With every visibility flag set to 0, even a maximally wrong prediction
/// must leave the keypoint loss at (numerically) zero.
#[test]
fn keypoint_heatmap_loss_invisible_joints_contribute_nothing() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let opts = (tch::Kind::Float, cpu());
    let heatmap_shape: [i64; 4] = [1, 17, 8, 8];

    // Ones vs zeros is a large per-pixel error; the all-zero visibility mask
    // is what must cancel it out.
    let pred = tch::Tensor::ones(heatmap_shape, opts);
    let target = tch::Tensor::zeros(heatmap_shape, opts);
    let vis = tch::Tensor::zeros([1, 17], opts);

    let val = loss_fn.keypoint_loss(&pred, &target, &vis).double_value(&[]) as f32;
    assert!(
        val.abs() < 1e-5,
        "all-invisible loss must be ≈ 0.0 (no joints contribute), got {val}"
    );
}
// -----------------------------------------------------------------------
// densepose_part_loss (exercises WiFiDensePoseLoss::densepose_loss)
// -----------------------------------------------------------------------
/// `densepose_loss` must yield a scalar that is neither NaN nor negative.
#[test]
fn densepose_part_loss_no_nan() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let float_opts = (tch::Kind::Float, cpu());
    let (b, h, w) = (1_i64, 8_i64, 8_i64);

    // 25-channel part logits, integer part labels (all 1), 48-channel UV maps.
    let pred_parts = tch::Tensor::zeros([b, 25, h, w], float_opts);
    let target_parts = tch::Tensor::ones([b, h, w], (tch::Kind::Int64, cpu()));
    // The same zero tensor serves as both UV arguments.
    let uv = tch::Tensor::zeros([b, 48, h, w], float_opts);

    let val = loss_fn
        .densepose_loss(&pred_parts, &target_parts, &uv, &uv)
        .double_value(&[]) as f32;
    assert!(
        !val.is_nan(),
        "densepose_loss must not produce NaN, got {val}"
    );
    assert!(
        val >= 0.0,
        "densepose_loss must be non-negative, got {val}"
    );
}
// -----------------------------------------------------------------------
// compute_losses (forward)
// -----------------------------------------------------------------------
/// With only the keypoint inputs supplied and a prediction that differs from
/// the target, `forward` must report a strictly positive total loss.
#[test]
fn compute_losses_total_positive_for_nonzero_error() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let opts = (tch::Kind::Float, cpu());
    let heatmap_shape: [i64; 4] = [2, 17, 8, 8];

    // Zeros against ones: every heatmap pixel is wrong.
    let pred_kp = tch::Tensor::zeros(heatmap_shape, opts);
    let target_kp = tch::Tensor::ones(heatmap_shape, opts);
    let vis = tch::Tensor::ones([2, 17], opts);

    // All six optional inputs are omitted.
    let (_, output) = loss_fn.forward(
        &pred_kp,
        &target_kp,
        &vis,
        None,
        None,
        None,
        None,
        None,
        None,
    );
    assert!(
        output.total > 0.0,
        "total loss must be > 0 for non-trivial predictions, got {}",
        output.total
    );
}
/// With keypoint-only weights and the same tensor used as prediction and
/// target, `forward` must report a total loss of (numerically) zero.
#[test]
fn compute_losses_total_zero_for_perfect_prediction() {
    // Only the keypoint term is weighted; the other two lambdas are zero.
    let loss_fn = WiFiDensePoseLoss::new(LossWeights {
        lambda_kp: 1.0,
        lambda_dp: 0.0,
        lambda_tr: 0.0,
    });
    let opts = (tch::Kind::Float, cpu());

    let perfect = tch::Tensor::ones([1, 17, 8, 8], opts);
    let vis = tch::Tensor::ones([1, 17], opts);

    // All six optional inputs are omitted.
    let (_, output) = loss_fn.forward(
        &perfect,
        &perfect,
        &vis,
        None,
        None,
        None,
        None,
        None,
        None,
    );
    assert!(
        output.total.abs() < 1e-5,
        "perfect prediction must yield total ≈ 0.0, got {}",
        output.total
    );
}
/// If no optional inputs are passed to `forward`, the `densepose` and
/// `transfer` fields of its output must both be `None`.
#[test]
fn compute_losses_optional_components_are_none() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let opts = (tch::Kind::Float, cpu());

    let t = tch::Tensor::ones([1, 17, 8, 8], opts);
    let vis = tch::Tensor::ones([1, 17], opts);

    // All six optional inputs are omitted.
    let (_, output) = loss_fn.forward(
        &t,
        &t,
        &vis,
        None,
        None,
        None,
        None,
        None,
        None,
    );
    assert!(
        output.densepose.is_none(),
        "densepose component must be None when not supplied"
    );
    assert!(
        output.transfer.is_none(),
        "transfer component must be None when not supplied"
    );
}
/// When every optional input is supplied, `forward` must fill in both the
/// `densepose` and `transfer` fields, report a positive total, and produce no
/// NaN component.
#[test]
fn compute_losses_with_all_components_populates_all_fields() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let float_opts = (tch::Kind::Float, cpu());

    // Keypoint inputs: zeros vs ones, all joints visible.
    let heatmap_shape: [i64; 4] = [1, 17, 8, 8];
    let pred_kp = tch::Tensor::zeros(heatmap_shape, float_opts);
    let target_kp = tch::Tensor::ones(heatmap_shape, float_opts);
    let vis = tch::Tensor::ones([1, 17], float_opts);

    // Part inputs: 25-channel logits, integer labels, and one UV tensor that
    // is passed for both UV arguments.
    let pred_parts = tch::Tensor::zeros([1, 25, 8, 8], float_opts);
    let target_parts = tch::Tensor::ones([1, 8, 8], (tch::Kind::Int64, cpu()));
    let uv = tch::Tensor::zeros([1, 48, 8, 8], float_opts);

    // Feature-transfer inputs: student zeros vs teacher ones.
    let feature_shape: [i64; 4] = [1, 64, 4, 4];
    let student = tch::Tensor::zeros(feature_shape, float_opts);
    let teacher = tch::Tensor::ones(feature_shape, float_opts);

    let (_, output) = loss_fn.forward(
        &pred_kp,
        &target_kp,
        &vis,
        Some(&pred_parts),
        Some(&target_parts),
        Some(&uv),
        Some(&uv),
        Some(&student),
        Some(&teacher),
    );

    assert!(
        output.densepose.is_some(),
        "densepose component must be Some when all inputs provided"
    );
    assert!(
        output.transfer.is_some(),
        "transfer component must be Some when student/teacher provided"
    );
    assert!(
        output.total > 0.0,
        "total loss must be > 0 when pred ≠ target, got {}",
        output.total
    );

    // A present component must hold a real number; an absent one is not
    // checked here.
    assert!(
        output.densepose.map_or(true, |dp| !dp.is_nan()),
        "densepose component must not be NaN"
    );
    assert!(
        output.transfer.map_or(true, |tr| !tr.is_nan()),
        "transfer component must not be NaN"
    );
}
// -----------------------------------------------------------------------
// transfer_loss
// -----------------------------------------------------------------------
/// Passing the same feature tensor as both arguments of `transfer_loss`
/// must give a (numerically) zero loss.
#[test]
fn transfer_loss_identical_features_is_zero() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let feat = tch::Tensor::ones([2, 64, 8, 8], (tch::Kind::Float, cpu()));

    let val = loss_fn.transfer_loss(&feat, &feat).double_value(&[]) as f32;
    assert!(
        val.abs() < 1e-5,
        "transfer loss for identical tensors must be ≈ 0.0, got {val}"
    );
}
/// An all-zero first argument against an all-one second argument must give
/// a strictly positive transfer loss.
#[test]
fn transfer_loss_different_features_is_positive() {
    let loss_fn = WiFiDensePoseLoss::new(LossWeights::default());
    let opts = (tch::Kind::Float, cpu());
    let feature_shape: [i64; 4] = [2, 64, 8, 8];

    let student = tch::Tensor::zeros(feature_shape, opts);
    let teacher = tch::Tensor::ones(feature_shape, opts);

    let val = loss_fn.transfer_loss(&student, &teacher).double_value(&[]) as f32;
    assert!(
        val > 0.0,
        "transfer loss for different tensors must be > 0.0, got {val}"
    );
}
}
// Placeholder for builds without the `tch-backend` feature: it keeps this
// file compiling cleanly and registers one trivially passing test.
#[cfg(not(feature = "tch-backend"))]
#[test]
fn tch_backend_not_enabled() {
// Intentionally empty: there is nothing to assert when `tch` is unavailable.
// In this configuration the `tch_tests` module above is excluded by its
// `#[cfg(feature = "tch-backend")]` attribute.
}