feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries

Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose
  (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen
- compute_losses: unified functional API
- 11 deterministic unit tests
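A tensor-free sketch of the Gaussian heatmap generation described above (the
actual generate_gaussian_heatmaps works on tch Tensors; the name, signature,
and unit-peak convention here are illustrative assumptions):

```rust
/// Render one keypoint as a 2D Gaussian on an h x w grid (row-major layout).
/// `cx`/`cy` are the keypoint centre in pixel coordinates, `sigma` the
/// Gaussian std-dev in pixels. Peak value is 1.0 at the centre.
fn gaussian_heatmap(h: usize, w: usize, cx: f32, cy: f32, sigma: f32) -> Vec<f32> {
    let mut hm = vec![0.0f32; h * w];
    let denom = 2.0 * sigma * sigma;
    for y in 0..h {
        for x in 0..w {
            let dx = x as f32 - cx;
            let dy = y as f32 - cy;
            // exp(-(dx^2 + dy^2) / (2 * sigma^2))
            hm[y * w + x] = (-(dx * dx + dy * dy) / denom).exp();
        }
    }
    hm
}
```

The visibility-masked MSE then compares predicted heatmaps against these
targets only where the keypoint is annotated as visible.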

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres minimum-cost assignment via DFS
  augmenting paths for optimal multi-person keypoint matching
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/
  rectangular/empty Hungarian cases)
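The OKS underlying both the accumulator and the 1−OKS cost matrix follows the
COCO definition; a minimal single-person sketch (function name and signature
are illustrative, not the crate's API):

```rust
/// Object Keypoint Similarity, per the COCO definition:
///   OKS = sum_i exp(-d_i^2 / (2 * s^2 * k_i^2)) * [v_i > 0] / sum_i [v_i > 0]
/// where s^2 is the object scale (segment area) and k_i = 2 * sigma_i with the
/// COCO per-joint sigma constants. Returns 0.0 if no keypoint is visible.
fn oks(
    pred: &[(f32, f32)],
    gt: &[(f32, f32)],
    vis: &[u8],
    sigmas: &[f32],
    area: f32,
) -> f32 {
    let mut num = 0.0f32;
    let mut cnt = 0u32;
    for i in 0..gt.len() {
        if vis[i] == 0 {
            continue; // invisible keypoints are excluded from the average
        }
        let dx = pred[i].0 - gt[i].0;
        let dy = pred[i].1 - gt[i].1;
        let k = 2.0 * sigmas[i];
        num += (-(dx * dx + dy * dy) / (2.0 * area * k * k)).exp();
        cnt += 1;
    }
    if cnt == 0 { 0.0 } else { num / cnt as f32 }
}
```

Filling an n×m matrix with `1.0 - oks(...)` for every prediction/ground-truth
pair yields the bipartite cost matrix consumed by hungarian_assignment.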

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode
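The deterministic shuffle can be illustrated as a Fisher-Yates pass driven by
an LCG whose state starts at seed XOR epoch, giving a different but fully
reproducible order per epoch (the LCG constants here are the common MMIX ones;
the trainer's actual parameters may differ):

```rust
/// Return a deterministic permutation of 0..n for the given seed and epoch.
fn shuffle_indices(n: usize, seed: u64, epoch: u64) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..n).collect();
    let mut state = seed ^ epoch; // per-epoch deterministic stream
    for i in (1..n).rev() {
        // LCG step: state = state * a + c  (mod 2^64 via wrapping arithmetic)
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let j = (state % (i as u64 + 1)) as usize;
        idx.swap(i, j); // Fisher-Yates swap
    }
    idx
}
```

Because the order depends only on (seed, epoch), a resumed or re-run training
job sees identical batch composition, which keeps loss curves reproducible.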

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs is still a stub; the trainer agent is completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
Claude
2026-02-28 15:22:54 +00:00
parent 2c5ca308a4
commit fce1271140
16 changed files with 4828 additions and 159 deletions


@@ -15,9 +15,10 @@
 use thiserror::Error;
 use std::path::PathBuf;
-// Import module-local error types so TrainError can wrap them via #[from].
-use crate::config::ConfigError;
-use crate::dataset::DatasetError;
+// Import module-local error types so TrainError can wrap them via #[from],
+// and re-export them so `lib.rs` can forward them from `error::*`.
+pub use crate::config::ConfigError;
+pub use crate::dataset::DatasetError;
 
 // ---------------------------------------------------------------------------
 // Top-level training error
@@ -41,14 +42,18 @@ pub enum TrainError {
     #[error("Dataset error: {0}")]
     Dataset(#[from] DatasetError),
 
-    /// An underlying I/O error not covered by a more specific variant.
-    #[error("I/O error: {0}")]
-    Io(#[from] std::io::Error),
-
-    /// JSON (de)serialization error.
-    #[error("JSON error: {0}")]
-    Json(#[from] serde_json::Error),
+    /// An underlying I/O error not wrapped by Config or Dataset.
+    ///
+    /// Note: [`std::io::Error`] cannot be wrapped via `#[from]` here because
+    /// both [`ConfigError`] and [`DatasetError`] already implement
+    /// `From<std::io::Error>`. Callers should convert via those types instead.
+    #[error("I/O error: {0}")]
+    Io(String),
 
     /// An operation was attempted on an empty dataset.
     #[error("Dataset is empty")]
     EmptyDataset,
@@ -113,3 +118,67 @@ impl TrainError {
         TrainError::ShapeMismatch { expected, actual }
     }
 }
+
+// ---------------------------------------------------------------------------
+// SubcarrierError
+// ---------------------------------------------------------------------------
+
+/// Errors produced by the subcarrier resampling / interpolation functions.
+///
+/// These are separate from [`DatasetError`] because subcarrier operations are
+/// also usable outside the dataset loading pipeline (e.g. in real-time
+/// inference preprocessing).
+#[derive(Debug, Error)]
+pub enum SubcarrierError {
+    /// The source or destination subcarrier count is zero.
+    #[error("Subcarrier count must be >= 1, got {count}")]
+    ZeroCount {
+        /// The offending count.
+        count: usize,
+    },
+
+    /// The input array's last dimension does not match the declared source count.
+    #[error(
+        "Subcarrier shape mismatch: last dimension is {actual_sc} \
+         but `src_n` was declared as {expected_sc} (full shape: {shape:?})"
+    )]
+    InputShapeMismatch {
+        /// Expected subcarrier count (as declared by the caller).
+        expected_sc: usize,
+        /// Actual last-dimension size of the input array.
+        actual_sc: usize,
+        /// Full shape of the input array.
+        shape: Vec<usize>,
+    },
+
+    /// The requested interpolation method is not yet implemented.
+    #[error("Interpolation method `{method}` is not implemented")]
+    MethodNotImplemented {
+        /// Human-readable name of the unsupported method.
+        method: String,
+    },
+
+    /// `src_n == dst_n` — no resampling is needed.
+    ///
+    /// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before
+    /// calling the interpolation routine.
+    ///
+    /// [`TrainingConfig::needs_subcarrier_interp`]:
+    /// crate::config::TrainingConfig::needs_subcarrier_interp
+    #[error("src_n == dst_n == {count}; no interpolation needed")]
+    NopInterpolation {
+        /// The equal count.
+        count: usize,
+    },
+
+    /// A numerical error during interpolation (e.g. division by zero).
+    #[error("Numerical error: {0}")]
+    NumericalError(String),
+}
+
+impl SubcarrierError {
+    /// Construct a [`SubcarrierError::NumericalError`].
+    pub fn numerical<S: Into<String>>(msg: S) -> Self {
+        SubcarrierError::NumericalError(msg.into())
+    }
+}
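As an illustration of the contract these variants encode, a minimal linear
resampler might look like the following. This is a sketch under assumed
signatures, with only the relevant error variants reproduced locally; it is
not the commit's actual interpolation code:

```rust
/// Subset of the error variants from the diff above, redeclared locally so
/// this sketch is self-contained (the real type derives thiserror::Error).
#[derive(Debug, PartialEq)]
enum SubcarrierError {
    ZeroCount { count: usize },
    NopInterpolation { count: usize },
}

/// Linearly resample one row of subcarrier values to `dst_n` points,
/// surfacing the zero-count and no-op cases as errors.
fn resample_linear(row: &[f32], dst_n: usize) -> Result<Vec<f32>, SubcarrierError> {
    let src_n = row.len();
    if src_n == 0 || dst_n == 0 {
        return Err(SubcarrierError::ZeroCount { count: src_n.min(dst_n) });
    }
    if src_n == dst_n {
        // Callers are expected to check needs_subcarrier_interp first.
        return Err(SubcarrierError::NopInterpolation { count: src_n });
    }
    let mut out = Vec::with_capacity(dst_n);
    for d in 0..dst_n {
        // Map destination index d onto the source grid [0, src_n - 1].
        let pos = d as f32 * (src_n - 1) as f32 / (dst_n - 1).max(1) as f32;
        let lo = pos.floor() as usize;
        let hi = (lo + 1).min(src_n - 1);
        let t = pos - lo as f32;
        out.push(row[lo] * (1.0 - t) + row[hi] * t);
    }
    Ok(out)
}
```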