Losses (losses.rs — 1056 lines): - WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE) - generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen - compute_losses: unified functional API - 11 deterministic unit tests Metrics (metrics.rs — 984 lines): - PCK@0.2 / PCK@0.5 with torso-diameter normalisation - OKS with COCO standard per-joint sigmas - MetricsAccumulator for online streaming eval - hungarian_assignment: O(n³) Kuhn-Munkres min-cut via DFS augmenting paths for optimal multi-person keypoint assignment (ruvector min-cut) - build_oks_cost_matrix: 1−OKS cost for bipartite matching - 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/ rectangular/empty Hungarian cases) Model (model.rs — 713 lines): - WiFiDensePoseModel end-to-end with tch-rs - ModalityTranslator: amp+phase FC encoders → spatial pseudo-image - Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6] - KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps - DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV Trainer (trainer.rs — 777 lines): - Full training loop: Adam, LR milestones, gradient clipping - Deterministic batch shuffle via LCG (seed XOR epoch) - CSV logging, best-checkpoint saving, early stopping - evaluate() with MetricsAccumulator and heatmap argmax decode Binaries: - src/bin/train.rs: production MM-Fi training CLI (clap) - src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2) Benches: - benches/training_bench.rs: criterion benchmarks for key ops Tests: - tests/test_dataset.rs (459 lines) - tests/test_metrics.rs (449 lines) - tests/test_subcarrier.rs (389 lines) proof.rs still stub — trainer agent completing it. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
80 lines
1.7 KiB
TOML
80 lines
1.7 KiB
TOML
[package]
|
|
name = "wifi-densepose-train"
|
|
version = "0.1.0"
|
|
edition = "2021"
|
|
authors = ["WiFi-DensePose Contributors"]
|
|
license = "MIT OR Apache-2.0"
|
|
description = "Training pipeline for WiFi-DensePose pose estimation"
|
|
keywords = ["wifi", "training", "pose-estimation", "deep-learning"]
|
|
|
|
[[bin]]
|
|
name = "train"
|
|
path = "src/bin/train.rs"
|
|
|
|
[[bin]]
|
|
name = "verify-training"
|
|
path = "src/bin/verify_training.rs"
|
|
|
|
[features]
|
|
default = []
|
|
tch-backend = ["tch"]
|
|
cuda = ["tch-backend"]
|
|
|
|
[dependencies]
|
|
# Internal crates
|
|
wifi-densepose-signal = { path = "../wifi-densepose-signal" }
|
|
wifi-densepose-nn = { path = "../wifi-densepose-nn" }
|
|
|
|
# Core
|
|
thiserror.workspace = true
|
|
anyhow.workspace = true
|
|
serde = { workspace = true, features = ["derive"] }
|
|
serde_json.workspace = true
|
|
|
|
# Tensor / math
|
|
ndarray.workspace = true
|
|
num-complex.workspace = true
|
|
num-traits.workspace = true
|
|
|
|
# PyTorch bindings (optional — only enabled by `tch-backend` feature)
|
|
tch = { workspace = true, optional = true }
|
|
|
|
# Graph algorithms (min-cut for optimal keypoint assignment)
|
|
petgraph.workspace = true
|
|
|
|
# Data loading
|
|
ndarray-npy.workspace = true
|
|
memmap2 = "0.9"
|
|
walkdir.workspace = true
|
|
|
|
# Serialization
|
|
csv.workspace = true
|
|
toml = "0.8"
|
|
|
|
# Logging / progress
|
|
tracing.workspace = true
|
|
tracing-subscriber.workspace = true
|
|
indicatif.workspace = true
|
|
|
|
# Async (subset of features needed by training pipeline)
|
|
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "fs"] }
|
|
|
|
# Crypto (for proof hash)
|
|
sha2.workspace = true
|
|
|
|
# CLI
|
|
clap.workspace = true
|
|
|
|
# Time
|
|
chrono = { version = "0.4", features = ["serde"] }
|
|
|
|
[dev-dependencies]
|
|
criterion.workspace = true
|
|
proptest.workspace = true
|
|
tempfile = "3.10"
|
|
approx = "0.5"
|
|
|
|
[[bench]]
|
|
name = "training_bench"
|
|
harness = false
|