feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries

Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose
  (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen
- compute_losses: unified functional API
- 11 deterministic unit tests

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres min-cut via DFS augmenting
  paths for optimal multi-person keypoint assignment (ruvector min-cut)
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/
  rectangular/empty Hungarian cases)

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs still stub — trainer agent completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
Claude
2026-02-28 15:22:54 +00:00
parent 2c5ca308a4
commit fce1271140
16 changed files with 4828 additions and 159 deletions

View File

@@ -0,0 +1,149 @@
//! Benchmarks for the WiFi-DensePose training pipeline.
//!
//! Run with:
//! ```bash
//! cargo bench -p wifi-densepose-train
//! ```
//!
//! Criterion HTML reports are written to `target/criterion/`.
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use ndarray::Array4;
use wifi_densepose_train::{
config::TrainingConfig,
dataset::{CsiDataset, SyntheticCsiDataset, SyntheticConfig},
subcarrier::{compute_interp_weights, interpolate_subcarriers},
};
// ---------------------------------------------------------------------------
// Dataset benchmarks
// ---------------------------------------------------------------------------
/// Benchmark synthetic sample generation for a single index.
fn bench_synthetic_get(c: &mut Criterion) {
let syn_cfg = SyntheticConfig::default();
let dataset = SyntheticCsiDataset::new(1000, syn_cfg);
c.bench_function("synthetic_dataset_get", |b| {
b.iter(|| {
let _ = dataset.get(black_box(42)).expect("sample 42 must exist");
});
});
}
/// Benchmark full epoch iteration (no I/O — all in-process).
fn bench_synthetic_epoch(c: &mut Criterion) {
let mut group = c.benchmark_group("synthetic_epoch");
for n_samples in [64usize, 256, 1024] {
let syn_cfg = SyntheticConfig::default();
let dataset = SyntheticCsiDataset::new(n_samples, syn_cfg);
group.bench_with_input(
BenchmarkId::new("samples", n_samples),
&n_samples,
|b, &n| {
b.iter(|| {
for i in 0..n {
let _ = dataset.get(black_box(i)).expect("sample exists");
}
});
},
);
}
group.finish();
}
// ---------------------------------------------------------------------------
// Subcarrier interpolation benchmarks
// ---------------------------------------------------------------------------
/// Benchmark `interpolate_subcarriers` for the standard 114 → 56 use-case.
fn bench_interp_114_to_56(c: &mut Criterion) {
// Simulate a single sample worth of raw CSI from MM-Fi.
let cfg = TrainingConfig::default();
let arr: Array4<f32> = Array4::from_shape_fn(
(cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, 114),
|(t, tx, rx, k)| (t + tx + rx + k) as f32 * 0.001,
);
c.bench_function("interp_114_to_56", |b| {
b.iter(|| {
let _ = interpolate_subcarriers(black_box(&arr), black_box(56));
});
});
}
/// Benchmark `compute_interp_weights` to ensure it is fast enough to
/// precompute at dataset construction time.
fn bench_compute_interp_weights(c: &mut Criterion) {
c.bench_function("compute_interp_weights_114_56", |b| {
b.iter(|| {
let _ = compute_interp_weights(black_box(114), black_box(56));
});
});
}
/// Benchmark interpolation for varying source subcarrier counts.
fn bench_interp_scaling(c: &mut Criterion) {
let mut group = c.benchmark_group("interp_scaling");
let cfg = TrainingConfig::default();
for src_sc in [56usize, 114, 256, 512] {
let arr: Array4<f32> = Array4::zeros((
cfg.window_frames,
cfg.num_antennas_tx,
cfg.num_antennas_rx,
src_sc,
));
group.bench_with_input(
BenchmarkId::new("src_sc", src_sc),
&src_sc,
|b, &sc| {
if sc == 56 {
// Identity case — skip; interpolate_subcarriers clones.
b.iter(|| {
let _ = arr.clone();
});
} else {
b.iter(|| {
let _ = interpolate_subcarriers(black_box(&arr), black_box(56));
});
}
},
);
}
group.finish();
}
// ---------------------------------------------------------------------------
// Config benchmarks
// ---------------------------------------------------------------------------
/// Benchmark TrainingConfig::validate() to ensure it stays O(1).
fn bench_config_validate(c: &mut Criterion) {
let config = TrainingConfig::default();
c.bench_function("config_validate", |b| {
b.iter(|| {
let _ = black_box(&config).validate();
});
});
}
// ---------------------------------------------------------------------------
// Criterion main
// ---------------------------------------------------------------------------
criterion_group!(
benches,
bench_synthetic_get,
bench_synthetic_epoch,
bench_interp_114_to_56,
bench_compute_interp_weights,
bench_interp_scaling,
bench_config_validate,
);
criterion_main!(benches);