wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml
Claude fce1271140 feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries
Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose
  (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen
- compute_losses: unified functional API
- 11 deterministic unit tests
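
The keypoint term described above (Gaussian heatmap targets plus a visibility-masked MSE) can be sketched in plain Rust without tch. This is only an illustration under stated assumptions: the function names `gaussian_heatmap` and `masked_heatmap_mse`, the row-major `Vec<f32>` layout, and the choice to normalise by the number of visible pixels are hypothetical and are not taken from losses.rs.

```rust
/// Render one 2D Gaussian heatmap of size `h` x `w` (row-major), centred on
/// the keypoint `(cx, cy)` given in pixel coordinates.
/// Hypothetical helper; not the signature used in losses.rs.
pub fn gaussian_heatmap(cx: f32, cy: f32, h: usize, w: usize, sigma: f32) -> Vec<f32> {
    let two_sigma_sq = 2.0 * sigma * sigma;
    let mut map = Vec::with_capacity(h * w);
    for y in 0..h {
        for x in 0..w {
            let dx = x as f32 - cx;
            let dy = y as f32 - cy;
            // Unnormalised Gaussian: the peak value at the keypoint is 1.0.
            map.push((-(dx * dx + dy * dy) / two_sigma_sq).exp());
        }
    }
    map
}

/// Visibility-masked MSE over per-joint heatmaps. Invisible joints contribute
/// neither to the squared-error sum nor to the normaliser (an assumption).
pub fn masked_heatmap_mse(pred: &[Vec<f32>], target: &[Vec<f32>], visible: &[bool]) -> f32 {
    assert_eq!(pred.len(), target.len());
    assert_eq!(pred.len(), visible.len());
    let mut sum = 0.0f32;
    let mut count = 0usize;
    for ((p, t), &v) in pred.iter().zip(target).zip(visible) {
        if !v {
            continue; // masked out: joint is not annotated as visible
        }
        assert_eq!(p.len(), t.len());
        sum += p.iter().zip(t).map(|(a, b)| (a - b) * (a - b)).sum::<f32>();
        count += p.len();
    }
    if count == 0 { 0.0 } else { sum / count as f32 }
}
```

With this masking, a prediction that is wrong only on invisible joints yields a loss of zero, which is the behaviour the "visibility-masked" wording suggests.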

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres minimum-cost assignment via DFS
  augmenting paths for optimal multi-person keypoint assignment (ruvector min-cut)
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/
  rectangular/empty Hungarian cases)
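
As a rough sketch of the OKS metric and the 1−OKS cost matrix listed above, the following follows the COCO keypoint evaluation formula with the standard COCO per-joint sigmas. The `GtPose` struct and the function signatures are hypothetical; the actual metrics.rs API may differ.

```rust
/// COCO per-joint sigmas for the 17-keypoint skeleton.
pub const COCO_SIGMAS: [f32; 17] = [
    0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
    0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089,
];

/// Ground-truth pose (hypothetical type for this sketch).
pub struct GtPose {
    pub kpts: [[f32; 2]; 17],
    pub visible: [bool; 17],
    /// Object area in the same squared units as the keypoint coordinates.
    pub area: f32,
}

/// Object Keypoint Similarity: mean over visible joints of
/// exp(-d^2 / (2 * area * (2 * sigma)^2)). Returns 0.0 if nothing is visible.
pub fn oks(pred: &[[f32; 2]; 17], gt: &GtPose) -> f32 {
    let mut sum = 0.0f32;
    let mut n = 0u32;
    for j in 0..17 {
        if !gt.visible[j] {
            continue;
        }
        let dx = pred[j][0] - gt.kpts[j][0];
        let dy = pred[j][1] - gt.kpts[j][1];
        let k = 2.0 * COCO_SIGMAS[j];
        let e = (dx * dx + dy * dy) / (2.0 * (gt.area + f32::EPSILON) * k * k);
        sum += (-e).exp();
        n += 1;
    }
    if n == 0 { 0.0 } else { sum / n as f32 }
}

/// cost[i][j] = 1 - OKS(pred i, ground truth j), suitable as input to a
/// minimum-cost bipartite assignment.
pub fn oks_cost_matrix(preds: &[[[f32; 2]; 17]], gts: &[GtPose]) -> Vec<Vec<f32>> {
    preds
        .iter()
        .map(|p| gts.iter().map(|g| 1.0 - oks(p, g)).collect())
        .collect()
}
```

A perfect prediction gives OKS = 1 and therefore cost 0, so a minimum-cost assignment over this matrix pairs each prediction with its most similar ground-truth person.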

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV
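
Per-joint heatmaps such as the [B,17,H,W] output of KeypointHead are commonly turned back into coordinates by taking the argmax of each H×W map. The sketch below shows that decode for a single map; the function name `argmax_decode` and the row-major slice layout are assumptions for illustration, not the code in trainer.rs.

```rust
/// Decode one row-major `h` x `w` heatmap into `(x, y, score)` by argmax.
/// Returns `None` for an empty heatmap. Hypothetical helper for illustration.
pub fn argmax_decode(heatmap: &[f32], h: usize, w: usize) -> Option<(usize, usize, f32)> {
    assert_eq!(heatmap.len(), h * w);
    let (idx, &score) = heatmap
        .iter()
        .enumerate()
        // total_cmp gives a total order on f32, so NaN cannot cause a panic.
        .max_by(|a, b| a.1.total_cmp(b.1))?;
    // Row-major layout: column is idx % w, row is idx / w.
    Some((idx % w, idx / w, score))
}
```

Applied to each of the 17 channels, this yields one (x, y) estimate and a confidence score per joint at heatmap resolution.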

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode
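
The deterministic batch shuffle mentioned above can be illustrated with a small linear congruential generator seeded by `seed XOR epoch` and a Fisher-Yates pass. This is a minimal sketch: the `Lcg` type, the Knuth MMIX constants, and `shuffled_indices` are assumptions chosen for the example, and trainer.rs may use different constants or a different shuffle.

```rust
/// 64-bit linear congruential generator (Knuth's MMIX constants).
/// Hypothetical type; the constants used in trainer.rs are not known here.
pub struct Lcg(u64);

impl Lcg {
    pub fn new(seed: u64) -> Self {
        Lcg(seed)
    }

    pub fn next_u64(&mut self) -> u64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }
}

/// Return a permutation of `0..n` that depends only on `seed` and `epoch`.
pub fn shuffled_indices(n: usize, seed: u64, epoch: u64) -> Vec<usize> {
    // Mixing the epoch into the seed gives a different order every epoch,
    // while keeping each epoch's order reproducible across runs.
    let mut rng = Lcg::new(seed ^ epoch);
    let mut idx: Vec<usize> = (0..n).collect();
    // Fisher-Yates shuffle, walking from the back of the vector.
    for i in (1..n).rev() {
        // Use the high bits: the low bits of a power-of-two LCG have short periods.
        let j = ((rng.next_u64() >> 33) % (i as u64 + 1)) as usize;
        idx.swap(i, j);
    }
    idx
}
```

Because the order is a pure function of `(seed, epoch)`, a resumed or repeated run visits samples in the same order without storing any RNG state in the checkpoint.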

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs still stub — trainer agent completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
2026-02-28 15:22:54 +00:00


[package]
name = "wifi-densepose-train"
version = "0.1.0"
edition = "2021"
authors = ["WiFi-DensePose Contributors"]
license = "MIT OR Apache-2.0"
description = "Training pipeline for WiFi-DensePose pose estimation"
keywords = ["wifi", "training", "pose-estimation", "deep-learning"]

[[bin]]
name = "train"
path = "src/bin/train.rs"

[[bin]]
name = "verify-training"
path = "src/bin/verify_training.rs"

[features]
default = []
tch-backend = ["tch"]
cuda = ["tch-backend"]

[dependencies]
# Internal crates
wifi-densepose-signal = { path = "../wifi-densepose-signal" }
wifi-densepose-nn = { path = "../wifi-densepose-nn" }

# Core
thiserror.workspace = true
anyhow.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true

# Tensor / math
ndarray.workspace = true
num-complex.workspace = true
num-traits.workspace = true

# PyTorch bindings (optional — only enabled by `tch-backend` feature)
tch = { workspace = true, optional = true }

# Graph algorithms (min-cut for optimal keypoint assignment)
petgraph.workspace = true

# Data loading
ndarray-npy.workspace = true
memmap2 = "0.9"
walkdir.workspace = true

# Serialization
csv.workspace = true
toml = "0.8"

# Logging / progress
tracing.workspace = true
tracing-subscriber.workspace = true
indicatif.workspace = true

# Async (subset of features needed by training pipeline)
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "fs"] }

# Crypto (for proof hash)
sha2.workspace = true

# CLI
clap.workspace = true

# Time
chrono = { version = "0.4", features = ["serde"] }

[dev-dependencies]
criterion.workspace = true
proptest.workspace = true
tempfile = "3.10"
approx = "0.5"

[[bench]]
name = "training_bench"
harness = false
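
The `[features]` table above makes `tch` optional, so code that touches PyTorch has to be gated on the `tch-backend` feature. The sketch below shows one way such gating might look in the crate's Rust sources; `backend_name` and `default_device` are hypothetical names and are not confirmed to exist in this crate.

```rust
/// Report which tensor backend this build was compiled with.
/// `cuda` is checked first because it implies `tch-backend` in Cargo.toml.
pub fn backend_name() -> &'static str {
    if cfg!(feature = "cuda") {
        "tch (CUDA)"
    } else if cfg!(feature = "tch-backend") {
        "tch (CPU)"
    } else {
        "ndarray only"
    }
}

/// Only compiled when the optional `tch` dependency is enabled, so a default
/// build never needs libtorch to be installed.
#[cfg(feature = "tch-backend")]
pub fn default_device() -> tch::Device {
    tch::Device::cuda_if_available()
}
```

With the manifest above, `cargo build` would compile only the ndarray path, while `cargo build --features tch-backend` (or `--features cuda`) pulls in `tch` and the gated items.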