Files
wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml
Claude ec98e40fff feat(rust): Add wifi-densepose-train crate with full training pipeline
Implements the training infrastructure described in ADR-015:

- config.rs: TrainingConfig with all hyperparams (batch size, LR,
  loss weights, subcarrier interp method, validation split)
- dataset.rs: MmFiDataset (real MM-Fi .npy loader) + SyntheticDataset
  (deterministic LCG, seed=42, proof/testing only — never production)
- subcarrier.rs: Linear/cubic interpolation 114→56 subcarriers
- error.rs: Typed errors (DataNotFound, InvalidFormat, IoError)
- losses.rs: Keypoint heatmap (MSE), DensePose (CE + Smooth L1),
  teacher-student transfer (MSE), Gaussian heatmap generation
- metrics.rs: PCK@0.2, OKS with optimal (Hungarian-style, minimum-cost)
  bipartite assignment via petgraph (multi-person keypoint matching)
- model.rs: WiFiDensePoseModel end-to-end with tch-rs (PyTorch bindings)
- trainer.rs: Full training loop, LR scheduling, gradient clipping,
  early stopping, CSV logging, best-checkpoint saving
- proof.rs: Deterministic training proof (SHA-256 trust kill switch)

No random data in production paths. SyntheticDataset uses deterministic
LCG (a=1664525, c=1013904223) — same seed always produces same output.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
2026-02-28 15:15:31 +00:00

81 lines
1.7 KiB
TOML

[package]
name = "wifi-densepose-train"
version = "0.1.0"
authors = ["WiFi-DensePose Contributors"]
edition = "2021"
keywords = ["wifi", "training", "pose-estimation", "deep-learning"]
license = "MIT OR Apache-2.0"
description = "Training pipeline for WiFi-DensePose pose estimation"

# Training entry points. `tch` is an optional dependency (enabled through the
# `tch-backend` feature), so both binaries declare it as required: a build with
# `--no-default-features` then skips them instead of failing to compile.
# NOTE(review): assumes src/bin/train.rs and src/bin/verify_training.rs go
# through the tch-backed model/trainer — confirm before relaxing this gate.
[[bin]]
name = "train"
path = "src/bin/train.rs"
required-features = ["tch-backend"]

[[bin]]
name = "verify-training"
path = "src/bin/verify_training.rs"
required-features = ["tch-backend"]

[features]
# Active unless the crate is built with `--no-default-features`.
default = ["tch-backend"]
# Currently just pulls in `tch-backend`; it enables no CUDA-specific feature on
# any dependency. NOTE(review): GPU execution with tch depends on the libtorch
# build it links against — presumably source code checks `cfg(feature = "cuda")`
# to select the device; confirm.
cuda = ["tch-backend"]
# Turns on the optional `tch` (PyTorch/libtorch bindings) dependency.
tch-backend = ["tch"]

[dependencies]
# Sibling crates in this workspace
wifi-densepose-nn = { path = "../wifi-densepose-nn", default-features = false }
wifi-densepose-signal = { path = "../wifi-densepose-signal" }
# Error handling and serde core
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
# N-dimensional arrays, linear algebra, numeric traits.
# NOTE(review): `openblas-static` builds OpenBLAS as part of this crate's build
# and links it statically — consider moving it behind a feature if build times
# or toolchain requirements become a problem.
ndarray = { version = "0.15", features = ["serde"] }
ndarray-linalg = { version = "0.16", features = ["openblas-static"] }
num-complex = "0.4"
num-traits = "0.2"
# PyTorch bindings; only built when `tch-backend` (directly or via `cuda`) is on
tch = { version = "0.14", optional = true }
# Graph algorithms used for multi-person keypoint assignment in the metrics
petgraph = "0.6"
# Dataset loading: .npy reading, memory-mapped files, directory traversal
memmap2 = "0.9"
ndarray-npy = "0.8"
walkdir = "2.4"
# Serialization formats (CSV training logs, TOML)
csv = "1.3"
toml = "0.8"
# Structured logging and progress bars
indicatif = "0.17"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Async runtime
tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "macros", "fs"] }
# SHA-256 hashing for the deterministic training proof
sha2 = "0.10"
# Command-line argument parsing
clap = { version = "4.4", features = ["derive"] }
# Timestamps
chrono = { version = "0.4", features = ["serde"] }

[dev-dependencies]
approx = "0.5"
# Benchmark framework used by the `training_bench` target
criterion = { version = "0.5", features = ["html_reports"] }
proptest = "1.4"
tempfile = "3.10"

# `harness = false` turns off libtest's built-in harness so the benchmark file
# provides its own `main` (presumably via criterion's `criterion_main!`).
[[bench]]
name = "training_bench"
harness = false