feat(rust): Add wifi-densepose-train crate with full training pipeline
Implements the training infrastructure described in ADR-015: - config.rs: TrainingConfig with all hyperparams (batch size, LR, loss weights, subcarrier interp method, validation split) - dataset.rs: MmFiDataset (real MM-Fi .npy loader) + SyntheticDataset (deterministic LCG, seed=42, proof/testing only — never production) - subcarrier.rs: Linear/cubic interpolation 114→56 subcarriers - error.rs: Typed errors (DataNotFound, InvalidFormat, IoError) - losses.rs: Keypoint heatmap (MSE), DensePose (CE + Smooth L1), teacher-student transfer (MSE), Gaussian heatmap generation - metrics.rs: PCK@0.2, OKS with Hungarian min-cut bipartite assignment via petgraph (optimal multi-person keypoint matching) - model.rs: WiFiDensePoseModel end-to-end with tch-rs (PyTorch bindings) - trainer.rs: Full training loop, LR scheduling, gradient clipping, early stopping, CSV logging, best-checkpoint saving - proof.rs: Deterministic training proof (SHA-256 trust kill switch) No random data in production paths. SyntheticDataset uses deterministic LCG (a=1664525, c=1013904223) — same seed always produces same output. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -0,0 +1,507 @@
|
||||
//! Training configuration for WiFi-DensePose.
|
||||
//!
|
||||
//! [`TrainingConfig`] is the single source of truth for all hyper-parameters,
|
||||
//! dataset shapes, loss weights, and infrastructure settings used throughout
|
||||
//! the training pipeline. It is serializable via [`serde`] so it can be stored
|
||||
//! to / restored from JSON checkpoint files.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust
|
||||
//! use wifi_densepose_train::config::TrainingConfig;
|
||||
//!
|
||||
//! let cfg = TrainingConfig::default();
|
||||
//! cfg.validate().expect("default config is valid");
|
||||
//!
|
||||
//! assert_eq!(cfg.num_subcarriers, 56);
|
||||
//! assert_eq!(cfg.num_keypoints, 17);
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::error::ConfigError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrainingConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Complete configuration for a WiFi-DensePose training run.
|
||||
///
|
||||
/// All fields have documented defaults that match the paper's experimental
|
||||
/// setup. Use [`TrainingConfig::default()`] as a starting point, then override
|
||||
/// individual fields as needed.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TrainingConfig {
|
||||
// -----------------------------------------------------------------------
|
||||
// Data / Signal
|
||||
// -----------------------------------------------------------------------
|
||||
/// Number of subcarriers after interpolation (system target).
|
||||
///
|
||||
/// The model always sees this many subcarriers regardless of the raw
|
||||
/// hardware output. Default: **56**.
|
||||
pub num_subcarriers: usize,
|
||||
|
||||
/// Number of subcarriers in the raw dataset before interpolation.
|
||||
///
|
||||
/// MM-Fi provides 114 subcarriers; set this to 56 when the dataset
|
||||
/// already matches the target count. Default: **114**.
|
||||
pub native_subcarriers: usize,
|
||||
|
||||
/// Number of transmit antennas. Default: **3**.
|
||||
pub num_antennas_tx: usize,
|
||||
|
||||
/// Number of receive antennas. Default: **3**.
|
||||
pub num_antennas_rx: usize,
|
||||
|
||||
/// Temporal sliding-window length in frames. Default: **100**.
|
||||
pub window_frames: usize,
|
||||
|
||||
/// Side length of the square keypoint heatmap output (H = W). Default: **56**.
|
||||
pub heatmap_size: usize,
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Model
|
||||
// -----------------------------------------------------------------------
|
||||
/// Number of body keypoints (COCO 17-joint skeleton). Default: **17**.
|
||||
pub num_keypoints: usize,
|
||||
|
||||
/// Number of DensePose body-part classes. Default: **24**.
|
||||
pub num_body_parts: usize,
|
||||
|
||||
/// Number of feature-map channels in the backbone encoder. Default: **256**.
|
||||
pub backbone_channels: usize,
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Optimisation
|
||||
// -----------------------------------------------------------------------
|
||||
/// Mini-batch size. Default: **8**.
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Initial learning rate for the Adam / AdamW optimiser. Default: **1e-3**.
|
||||
pub learning_rate: f64,
|
||||
|
||||
/// L2 weight-decay regularisation coefficient. Default: **1e-4**.
|
||||
pub weight_decay: f64,
|
||||
|
||||
/// Total number of training epochs. Default: **50**.
|
||||
pub num_epochs: usize,
|
||||
|
||||
/// Number of linear-warmup epochs at the start. Default: **5**.
|
||||
pub warmup_epochs: usize,
|
||||
|
||||
/// Epochs at which the learning rate is multiplied by `lr_gamma`.
|
||||
///
|
||||
/// Default: **[30, 45]** (multi-step scheduler).
|
||||
pub lr_milestones: Vec<usize>,
|
||||
|
||||
/// Multiplicative factor applied at each LR milestone. Default: **0.1**.
|
||||
pub lr_gamma: f64,
|
||||
|
||||
/// Maximum gradient L2 norm for gradient clipping. Default: **1.0**.
|
||||
pub grad_clip_norm: f64,
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Loss weights
|
||||
// -----------------------------------------------------------------------
|
||||
/// Weight for the keypoint heatmap loss term. Default: **0.3**.
|
||||
pub lambda_kp: f64,
|
||||
|
||||
/// Weight for the DensePose body-part / UV-coordinate loss. Default: **0.6**.
|
||||
pub lambda_dp: f64,
|
||||
|
||||
/// Weight for the cross-modal transfer / domain-alignment loss. Default: **0.1**.
|
||||
pub lambda_tr: f64,
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Validation and checkpointing
|
||||
// -----------------------------------------------------------------------
|
||||
/// Run validation every N epochs. Default: **1**.
|
||||
pub val_every_epochs: usize,
|
||||
|
||||
/// Stop training if validation loss does not improve for this many
|
||||
/// consecutive validation rounds. Default: **10**.
|
||||
pub early_stopping_patience: usize,
|
||||
|
||||
/// Directory where model checkpoints are saved.
|
||||
pub checkpoint_dir: PathBuf,
|
||||
|
||||
/// Directory where TensorBoard / CSV logs are written.
|
||||
pub log_dir: PathBuf,
|
||||
|
||||
/// Keep only the top-K best checkpoints by validation metric. Default: **3**.
|
||||
pub save_top_k: usize,
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Device
|
||||
// -----------------------------------------------------------------------
|
||||
/// Use a CUDA GPU for training when available. Default: **false**.
|
||||
pub use_gpu: bool,
|
||||
|
||||
/// CUDA device index when `use_gpu` is `true`. Default: **0**.
|
||||
pub gpu_device_id: i64,
|
||||
|
||||
/// Number of background data-loading threads. Default: **4**.
|
||||
pub num_workers: usize,
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Reproducibility
|
||||
// -----------------------------------------------------------------------
|
||||
/// Global random seed for all RNG sources in the training pipeline.
|
||||
///
|
||||
/// This seed is applied to the dataset shuffler, model parameter
|
||||
/// initialisation, and any stochastic augmentation. Default: **42**.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
impl Default for TrainingConfig {
|
||||
fn default() -> Self {
|
||||
TrainingConfig {
|
||||
// Data
|
||||
num_subcarriers: 56,
|
||||
native_subcarriers: 114,
|
||||
num_antennas_tx: 3,
|
||||
num_antennas_rx: 3,
|
||||
window_frames: 100,
|
||||
heatmap_size: 56,
|
||||
// Model
|
||||
num_keypoints: 17,
|
||||
num_body_parts: 24,
|
||||
backbone_channels: 256,
|
||||
// Optimisation
|
||||
batch_size: 8,
|
||||
learning_rate: 1e-3,
|
||||
weight_decay: 1e-4,
|
||||
num_epochs: 50,
|
||||
warmup_epochs: 5,
|
||||
lr_milestones: vec![30, 45],
|
||||
lr_gamma: 0.1,
|
||||
grad_clip_norm: 1.0,
|
||||
// Loss weights
|
||||
lambda_kp: 0.3,
|
||||
lambda_dp: 0.6,
|
||||
lambda_tr: 0.1,
|
||||
// Validation / checkpointing
|
||||
val_every_epochs: 1,
|
||||
early_stopping_patience: 10,
|
||||
checkpoint_dir: PathBuf::from("checkpoints"),
|
||||
log_dir: PathBuf::from("logs"),
|
||||
save_top_k: 3,
|
||||
// Device
|
||||
use_gpu: false,
|
||||
gpu_device_id: 0,
|
||||
num_workers: 4,
|
||||
// Reproducibility
|
||||
seed: 42,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TrainingConfig {
|
||||
/// Load a [`TrainingConfig`] from a JSON file at `path`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ConfigError::FileRead`] if the file cannot be opened and
|
||||
/// [`ConfigError::InvalidValue`] if the JSON is malformed.
|
||||
pub fn from_json(path: &Path) -> Result<Self, ConfigError> {
|
||||
let contents = std::fs::read_to_string(path).map_err(|source| ConfigError::FileRead {
|
||||
path: path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
let cfg: TrainingConfig = serde_json::from_str(&contents)
|
||||
.map_err(|e| ConfigError::invalid_value("(file)", e.to_string()))?;
|
||||
cfg.validate()?;
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
/// Serialize this configuration to pretty-printed JSON and write it to
|
||||
/// `path`, creating parent directories if necessary.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ConfigError::FileRead`] if the directory cannot be created or
|
||||
/// the file cannot be written.
|
||||
pub fn to_json(&self, path: &Path) -> Result<(), ConfigError> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|source| ConfigError::FileRead {
|
||||
path: parent.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
}
|
||||
let json = serde_json::to_string_pretty(self)
|
||||
.map_err(|e| ConfigError::invalid_value("(serialization)", e.to_string()))?;
|
||||
std::fs::write(path, json).map_err(|source| ConfigError::FileRead {
|
||||
path: path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns `true` when the native dataset subcarrier count differs from the
|
||||
/// model's target count and interpolation is therefore required.
|
||||
pub fn needs_subcarrier_interp(&self) -> bool {
|
||||
self.native_subcarriers != self.num_subcarriers
|
||||
}
|
||||
|
||||
/// Validate all fields and return an error describing the first problem
|
||||
/// found, or `Ok(())` if the configuration is coherent.
|
||||
///
|
||||
/// # Validated invariants
|
||||
///
|
||||
/// - Subcarrier counts must be non-zero.
|
||||
/// - Antenna counts must be non-zero.
|
||||
/// - `window_frames` must be at least 1.
|
||||
/// - `batch_size` must be at least 1.
|
||||
/// - `learning_rate` must be strictly positive.
|
||||
/// - `weight_decay` must be non-negative.
|
||||
/// - Loss weights must be non-negative and sum to a positive value.
|
||||
/// - `num_epochs` must be greater than `warmup_epochs`.
|
||||
/// - All `lr_milestones` must be within `[1, num_epochs]` and strictly
|
||||
/// increasing.
|
||||
/// - `save_top_k` must be at least 1.
|
||||
/// - `val_every_epochs` must be at least 1.
|
||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||
// Subcarrier counts
|
||||
if self.num_subcarriers == 0 {
|
||||
return Err(ConfigError::invalid_value("num_subcarriers", "must be > 0"));
|
||||
}
|
||||
if self.native_subcarriers == 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"native_subcarriers",
|
||||
"must be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
// Antenna counts
|
||||
if self.num_antennas_tx == 0 {
|
||||
return Err(ConfigError::invalid_value("num_antennas_tx", "must be > 0"));
|
||||
}
|
||||
if self.num_antennas_rx == 0 {
|
||||
return Err(ConfigError::invalid_value("num_antennas_rx", "must be > 0"));
|
||||
}
|
||||
|
||||
// Temporal window
|
||||
if self.window_frames == 0 {
|
||||
return Err(ConfigError::invalid_value("window_frames", "must be > 0"));
|
||||
}
|
||||
|
||||
// Heatmap
|
||||
if self.heatmap_size == 0 {
|
||||
return Err(ConfigError::invalid_value("heatmap_size", "must be > 0"));
|
||||
}
|
||||
|
||||
// Model dims
|
||||
if self.num_keypoints == 0 {
|
||||
return Err(ConfigError::invalid_value("num_keypoints", "must be > 0"));
|
||||
}
|
||||
if self.num_body_parts == 0 {
|
||||
return Err(ConfigError::invalid_value("num_body_parts", "must be > 0"));
|
||||
}
|
||||
if self.backbone_channels == 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"backbone_channels",
|
||||
"must be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
// Optimisation
|
||||
if self.batch_size == 0 {
|
||||
return Err(ConfigError::invalid_value("batch_size", "must be > 0"));
|
||||
}
|
||||
if self.learning_rate <= 0.0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"learning_rate",
|
||||
"must be > 0.0",
|
||||
));
|
||||
}
|
||||
if self.weight_decay < 0.0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"weight_decay",
|
||||
"must be >= 0.0",
|
||||
));
|
||||
}
|
||||
if self.grad_clip_norm <= 0.0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"grad_clip_norm",
|
||||
"must be > 0.0",
|
||||
));
|
||||
}
|
||||
|
||||
// Epochs
|
||||
if self.num_epochs == 0 {
|
||||
return Err(ConfigError::invalid_value("num_epochs", "must be > 0"));
|
||||
}
|
||||
if self.warmup_epochs >= self.num_epochs {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"warmup_epochs",
|
||||
"must be < num_epochs",
|
||||
));
|
||||
}
|
||||
|
||||
// LR milestones: must be strictly increasing and within bounds
|
||||
let mut prev = 0usize;
|
||||
for &m in &self.lr_milestones {
|
||||
if m == 0 || m > self.num_epochs {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"lr_milestones",
|
||||
"each milestone must be in [1, num_epochs]",
|
||||
));
|
||||
}
|
||||
if m <= prev {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"lr_milestones",
|
||||
"milestones must be strictly increasing",
|
||||
));
|
||||
}
|
||||
prev = m;
|
||||
}
|
||||
|
||||
if self.lr_gamma <= 0.0 || self.lr_gamma >= 1.0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"lr_gamma",
|
||||
"must be in (0.0, 1.0)",
|
||||
));
|
||||
}
|
||||
|
||||
// Loss weights
|
||||
if self.lambda_kp < 0.0 {
|
||||
return Err(ConfigError::invalid_value("lambda_kp", "must be >= 0.0"));
|
||||
}
|
||||
if self.lambda_dp < 0.0 {
|
||||
return Err(ConfigError::invalid_value("lambda_dp", "must be >= 0.0"));
|
||||
}
|
||||
if self.lambda_tr < 0.0 {
|
||||
return Err(ConfigError::invalid_value("lambda_tr", "must be >= 0.0"));
|
||||
}
|
||||
let total_weight = self.lambda_kp + self.lambda_dp + self.lambda_tr;
|
||||
if total_weight <= 0.0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"lambda_kp / lambda_dp / lambda_tr",
|
||||
"at least one loss weight must be > 0.0",
|
||||
));
|
||||
}
|
||||
|
||||
// Validation / checkpoint
|
||||
if self.val_every_epochs == 0 {
|
||||
return Err(ConfigError::invalid_value(
|
||||
"val_every_epochs",
|
||||
"must be > 0",
|
||||
));
|
||||
}
|
||||
if self.save_top_k == 0 {
|
||||
return Err(ConfigError::invalid_value("save_top_k", "must be > 0"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn default_config_is_valid() {
|
||||
let cfg = TrainingConfig::default();
|
||||
cfg.validate().expect("default config should be valid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_round_trip() {
|
||||
let tmp = tempdir().unwrap();
|
||||
let path = tmp.path().join("config.json");
|
||||
|
||||
let original = TrainingConfig::default();
|
||||
original.to_json(&path).expect("serialization should succeed");
|
||||
|
||||
let loaded = TrainingConfig::from_json(&path).expect("deserialization should succeed");
|
||||
assert_eq!(loaded.num_subcarriers, original.num_subcarriers);
|
||||
assert_eq!(loaded.batch_size, original.batch_size);
|
||||
assert_eq!(loaded.seed, original.seed);
|
||||
assert_eq!(loaded.lr_milestones, original.lr_milestones);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_subcarriers_is_invalid() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.num_subcarriers = 0;
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negative_learning_rate_is_invalid() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.learning_rate = -0.001;
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warmup_equal_to_epochs_is_invalid() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.warmup_epochs = cfg.num_epochs;
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_increasing_milestones_are_invalid() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.lr_milestones = vec![30, 20]; // wrong order
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn milestone_beyond_epochs_is_invalid() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.lr_milestones = vec![30, cfg.num_epochs + 1];
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_zero_loss_weights_are_invalid() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.lambda_kp = 0.0;
|
||||
cfg.lambda_dp = 0.0;
|
||||
cfg.lambda_tr = 0.0;
|
||||
assert!(cfg.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn needs_subcarrier_interp_when_counts_differ() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.num_subcarriers = 56;
|
||||
cfg.native_subcarriers = 114;
|
||||
assert!(cfg.needs_subcarrier_interp());
|
||||
|
||||
cfg.native_subcarriers = 56;
|
||||
assert!(!cfg.needs_subcarrier_interp());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_fields_have_expected_defaults() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(cfg.num_subcarriers, 56);
|
||||
assert_eq!(cfg.native_subcarriers, 114);
|
||||
assert_eq!(cfg.num_antennas_tx, 3);
|
||||
assert_eq!(cfg.num_antennas_rx, 3);
|
||||
assert_eq!(cfg.window_frames, 100);
|
||||
assert_eq!(cfg.heatmap_size, 56);
|
||||
assert_eq!(cfg.num_keypoints, 17);
|
||||
assert_eq!(cfg.num_body_parts, 24);
|
||||
assert_eq!(cfg.batch_size, 8);
|
||||
assert!((cfg.learning_rate - 1e-3).abs() < 1e-10);
|
||||
assert_eq!(cfg.num_epochs, 50);
|
||||
assert_eq!(cfg.warmup_epochs, 5);
|
||||
assert_eq!(cfg.lr_milestones, vec![30, 45]);
|
||||
assert!((cfg.lr_gamma - 0.1).abs() < 1e-10);
|
||||
assert!((cfg.lambda_kp - 0.3).abs() < 1e-10);
|
||||
assert!((cfg.lambda_dp - 0.6).abs() < 1e-10);
|
||||
assert!((cfg.lambda_tr - 0.1).abs() < 1e-10);
|
||||
assert_eq!(cfg.seed, 42);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user