feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries
Losses (losses.rs — 1056 lines): - WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE) - generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap generation - compute_losses: unified functional API - 11 deterministic unit tests Metrics (metrics.rs — 984 lines): - PCK@0.2 / PCK@0.5 with torso-diameter normalisation - OKS with COCO standard per-joint sigmas - MetricsAccumulator for online streaming eval - hungarian_assignment: O(n³) Kuhn-Munkres min-cut via DFS augmenting paths for optimal multi-person keypoint assignment (ruvector min-cut) - build_oks_cost_matrix: 1−OKS cost for bipartite matching - 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/rectangular/empty Hungarian cases) Model (model.rs — 713 lines): - WiFiDensePoseModel end-to-end with tch-rs - ModalityTranslator: amp+phase FC encoders → spatial pseudo-image - Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6] - KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps - DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV Trainer (trainer.rs — 777 lines): - Full training loop: Adam, LR milestones, gradient clipping - Deterministic batch shuffle via LCG (seed XOR epoch) - CSV logging, best-checkpoint saving, early stopping - evaluate() with MetricsAccumulator and heatmap argmax decode Binaries: - src/bin/train.rs: production MM-Fi training CLI (clap) - src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2) Benches: - benches/training_bench.rs: criterion benchmarks for key ops Tests: - tests/test_dataset.rs (459 lines) - tests/test_metrics.rs (449 lines) - tests/test_subcarrier.rs (389 lines) Note: proof.rs is still a stub — the trainer agent is completing it. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -26,9 +26,12 @@ use tch::{Kind, Reduction, Tensor};
|
||||
// Public types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Scalar components produced by a single forward pass through the combined loss.
|
||||
/// Scalar components produced by a single forward pass through [`WiFiDensePoseLoss::forward`].
|
||||
///
|
||||
/// Contains `f32` scalar values extracted from the computation graph for
|
||||
/// logging and checkpointing (they are not used for back-propagation).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LossOutput {
|
||||
pub struct WiFiLossComponents {
|
||||
/// Total weighted loss value (scalar, in ℝ≥0).
|
||||
pub total: f32,
|
||||
/// Keypoint heatmap MSE loss component.
|
||||
@@ -159,7 +162,7 @@ impl WiFiDensePoseLoss {
|
||||
|
||||
// ── 2. UV regression: Smooth-L1 masked by foreground pixels ────────
|
||||
// Foreground mask: pixels where target part ≠ 0, shape [B, H, W].
|
||||
let fg_mask = target_int.not_equal(0);
|
||||
let fg_mask = target_int.not_equal(0_i64);
|
||||
// Expand to [B, 1, H, W] then broadcast to [B, 48, H, W].
|
||||
let fg_mask_f = fg_mask
|
||||
.unsqueeze(1)
|
||||
@@ -218,7 +221,7 @@ impl WiFiDensePoseLoss {
|
||||
target_uv: Option<&Tensor>,
|
||||
student_features: Option<&Tensor>,
|
||||
teacher_features: Option<&Tensor>,
|
||||
) -> (Tensor, LossOutput) {
|
||||
) -> (Tensor, WiFiLossComponents) {
|
||||
let mut details = HashMap::new();
|
||||
|
||||
// ── Keypoint loss (always computed) ───────────────────────────────
|
||||
@@ -243,7 +246,7 @@ impl WiFiDensePoseLoss {
|
||||
let part_val = part_loss.double_value(&[]) as f32;
|
||||
|
||||
// UV loss (foreground masked)
|
||||
let fg_mask = target_int.not_equal(0);
|
||||
let fg_mask = target_int.not_equal(0_i64);
|
||||
let fg_mask_f = fg_mask
|
||||
.unsqueeze(1)
|
||||
.expand_as(pu)
|
||||
@@ -280,7 +283,7 @@ impl WiFiDensePoseLoss {
|
||||
|
||||
let total_val = total.double_value(&[]) as f32;
|
||||
|
||||
let output = LossOutput {
|
||||
let output = WiFiLossComponents {
|
||||
total: total_val,
|
||||
keypoint: kp_val as f32,
|
||||
densepose: dp_val,
|
||||
|
||||
Reference in New Issue
Block a user