Files
wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/config.rs
Claude ec98e40fff feat(rust): Add wifi-densepose-train crate with full training pipeline
Implements the training infrastructure described in ADR-015:

- config.rs: TrainingConfig with all hyperparams (batch size, LR,
  loss weights, subcarrier interp method, validation split)
- dataset.rs: MmFiDataset (real MM-Fi .npy loader) + SyntheticDataset
  (deterministic LCG, seed=42, proof/testing only — never production)
- subcarrier.rs: Linear/cubic interpolation 114→56 subcarriers
- error.rs: Typed errors (DataNotFound, InvalidFormat, IoError)
- losses.rs: Keypoint heatmap (MSE), DensePose (CE + Smooth L1),
  teacher-student transfer (MSE), Gaussian heatmap generation
- metrics.rs: PCK@0.2, OKS with Hungarian min-cut bipartite assignment
  via petgraph (optimal multi-person keypoint matching)
- model.rs: WiFiDensePoseModel end-to-end with tch-rs (PyTorch bindings)
- trainer.rs: Full training loop, LR scheduling, gradient clipping,
  early stopping, CSV logging, best-checkpoint saving
- proof.rs: Deterministic training proof (SHA-256 trust kill switch)

No random data in production paths. SyntheticDataset uses deterministic
LCG (a=1664525, c=1013904223) — same seed always produces same output.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
2026-02-28 15:15:31 +00:00

508 lines
17 KiB
Rust

//! Training configuration for WiFi-DensePose.
//!
//! [`TrainingConfig`] is the single source of truth for all hyper-parameters,
//! dataset shapes, loss weights, and infrastructure settings used throughout
//! the training pipeline. It is serializable via [`serde`] so it can be stored
//! to / restored from JSON checkpoint files.
//!
//! # Example
//!
//! ```rust
//! use wifi_densepose_train::config::TrainingConfig;
//!
//! let cfg = TrainingConfig::default();
//! cfg.validate().expect("default config is valid");
//!
//! assert_eq!(cfg.num_subcarriers, 56);
//! assert_eq!(cfg.num_keypoints, 17);
//! ```
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use crate::error::ConfigError;
// ---------------------------------------------------------------------------
// TrainingConfig
// ---------------------------------------------------------------------------
/// Complete configuration for a WiFi-DensePose training run.
///
/// All fields have documented defaults that match the paper's experimental
/// setup. Use [`TrainingConfig::default()`] as a starting point, then override
/// individual fields as needed.
///
/// Every field is public and nothing is checked at construction time; call
/// [`TrainingConfig::validate`] after overriding values to confirm that the
/// configuration is coherent. The struct derives `Serialize` / `Deserialize`,
/// which is what [`TrainingConfig::to_json`] and [`TrainingConfig::from_json`]
/// rely on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
// -----------------------------------------------------------------------
// Data / Signal
// -----------------------------------------------------------------------
/// Number of subcarriers after interpolation (system target).
///
/// The model always sees this many subcarriers regardless of the raw
/// hardware output. Must be non-zero. Default: **56**.
pub num_subcarriers: usize,
/// Number of subcarriers in the raw dataset before interpolation.
///
/// MM-Fi provides 114 subcarriers; set this to 56 when the dataset
/// already matches the target count. Must be non-zero. When this differs
/// from `num_subcarriers`, [`TrainingConfig::needs_subcarrier_interp`]
/// returns `true`. Default: **114**.
pub native_subcarriers: usize,
/// Number of transmit antennas. Must be non-zero. Default: **3**.
pub num_antennas_tx: usize,
/// Number of receive antennas. Must be non-zero. Default: **3**.
pub num_antennas_rx: usize,
/// Temporal sliding-window length in frames. Must be non-zero. Default: **100**.
pub window_frames: usize,
/// Side length of the square keypoint heatmap output (H = W). Must be
/// non-zero. Default: **56**.
pub heatmap_size: usize,
// -----------------------------------------------------------------------
// Model
// -----------------------------------------------------------------------
/// Number of body keypoints (COCO 17-joint skeleton). Must be non-zero.
/// Default: **17**.
pub num_keypoints: usize,
/// Number of DensePose body-part classes. Must be non-zero. Default: **24**.
pub num_body_parts: usize,
/// Number of feature-map channels in the backbone encoder. Must be
/// non-zero. Default: **256**.
pub backbone_channels: usize,
// -----------------------------------------------------------------------
// Optimisation
// -----------------------------------------------------------------------
/// Mini-batch size. Must be non-zero. Default: **8**.
pub batch_size: usize,
/// Initial learning rate for the Adam / AdamW optimiser. Must be strictly
/// positive. Default: **1e-3**.
pub learning_rate: f64,
/// L2 weight-decay regularisation coefficient. Must be non-negative.
/// Default: **1e-4**.
pub weight_decay: f64,
/// Total number of training epochs. Must be non-zero and strictly greater
/// than `warmup_epochs`. Default: **50**.
pub num_epochs: usize,
/// Number of linear-warmup epochs at the start. Must be strictly less than
/// `num_epochs`. Default: **5**.
pub warmup_epochs: usize,
/// Epochs at which the learning rate is multiplied by `lr_gamma`.
///
/// Each entry must lie in `[1, num_epochs]` and the list must be strictly
/// increasing; an empty list passes validation.
///
/// Default: **[30, 45]** (multi-step scheduler).
pub lr_milestones: Vec<usize>,
/// Multiplicative factor applied at each LR milestone. Must lie strictly
/// between 0.0 and 1.0. Default: **0.1**.
pub lr_gamma: f64,
/// Maximum gradient L2 norm for gradient clipping. Must be strictly
/// positive. Default: **1.0**.
pub grad_clip_norm: f64,
// -----------------------------------------------------------------------
// Loss weights
// -----------------------------------------------------------------------
// Each weight must be non-negative, and the three together must sum to a
// strictly positive value (at least one loss term has to be active).
/// Weight for the keypoint heatmap loss term. Default: **0.3**.
pub lambda_kp: f64,
/// Weight for the DensePose body-part / UV-coordinate loss. Default: **0.6**.
pub lambda_dp: f64,
/// Weight for the cross-modal transfer / domain-alignment loss. Default: **0.1**.
pub lambda_tr: f64,
// -----------------------------------------------------------------------
// Validation and checkpointing
// -----------------------------------------------------------------------
/// Run validation every N epochs. Must be non-zero. Default: **1**.
pub val_every_epochs: usize,
/// Stop training if validation loss does not improve for this many
/// consecutive validation rounds. Not range-checked by
/// [`TrainingConfig::validate`]. Default: **10**.
pub early_stopping_patience: usize,
/// Directory where model checkpoints are saved. Default: **`checkpoints`**
/// (a relative path).
pub checkpoint_dir: PathBuf,
/// Directory where TensorBoard / CSV logs are written. Default: **`logs`**
/// (a relative path).
pub log_dir: PathBuf,
/// Keep only the top-K best checkpoints by validation metric. Must be
/// non-zero. Default: **3**.
pub save_top_k: usize,
// -----------------------------------------------------------------------
// Device
// -----------------------------------------------------------------------
/// Use a CUDA GPU for training when available. Default: **false**.
pub use_gpu: bool,
/// CUDA device index when `use_gpu` is `true`. Default: **0**.
///
/// NOTE(review): the type is signed and [`TrainingConfig::validate`] does
/// not range-check it, so a negative index is not rejected here — confirm
/// that the consumer handles it.
pub gpu_device_id: i64,
/// Number of background data-loading threads. Not range-checked by
/// [`TrainingConfig::validate`]. Default: **4**.
pub num_workers: usize,
// -----------------------------------------------------------------------
// Reproducibility
// -----------------------------------------------------------------------
/// Global random seed for all RNG sources in the training pipeline.
///
/// This seed is applied to the dataset shuffler, model parameter
/// initialisation, and any stochastic augmentation. Default: **42**.
pub seed: u64,
}
impl Default for TrainingConfig {
fn default() -> Self {
TrainingConfig {
// Data
num_subcarriers: 56,
native_subcarriers: 114,
num_antennas_tx: 3,
num_antennas_rx: 3,
window_frames: 100,
heatmap_size: 56,
// Model
num_keypoints: 17,
num_body_parts: 24,
backbone_channels: 256,
// Optimisation
batch_size: 8,
learning_rate: 1e-3,
weight_decay: 1e-4,
num_epochs: 50,
warmup_epochs: 5,
lr_milestones: vec![30, 45],
lr_gamma: 0.1,
grad_clip_norm: 1.0,
// Loss weights
lambda_kp: 0.3,
lambda_dp: 0.6,
lambda_tr: 0.1,
// Validation / checkpointing
val_every_epochs: 1,
early_stopping_patience: 10,
checkpoint_dir: PathBuf::from("checkpoints"),
log_dir: PathBuf::from("logs"),
save_top_k: 3,
// Device
use_gpu: false,
gpu_device_id: 0,
num_workers: 4,
// Reproducibility
seed: 42,
}
}
}
impl TrainingConfig {
    /// Load a [`TrainingConfig`] from a JSON file at `path` and validate it.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::FileRead`] if the file cannot be read.
    /// - [`ConfigError::InvalidValue`] (reported against the pseudo-field
    ///   `"(file)"`) if the contents cannot be deserialized into a
    ///   `TrainingConfig`.
    /// - Any error returned by [`TrainingConfig::validate`] if the file parses
    ///   but describes an incoherent configuration.
    pub fn from_json(path: &Path) -> Result<Self, ConfigError> {
        let contents = std::fs::read_to_string(path).map_err(|source| ConfigError::FileRead {
            path: path.to_path_buf(),
            source,
        })?;
        let cfg: TrainingConfig = serde_json::from_str(&contents)
            .map_err(|e| ConfigError::invalid_value("(file)", e.to_string()))?;
        // Never hand back a configuration that would fail validation later.
        cfg.validate()?;
        Ok(cfg)
    }

    /// Serialize this configuration to pretty-printed JSON and write it to
    /// `path`, creating parent directories if necessary.
    ///
    /// The configuration is written as-is; it is not validated first.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::FileRead`] if the parent directory cannot be created
    ///   or the file cannot be written (the variant is reused for write
    ///   failures).
    /// - [`ConfigError::InvalidValue`] (reported against the pseudo-field
    ///   `"(serialization)"`) if JSON serialization fails.
    pub fn to_json(&self, path: &Path) -> Result<(), ConfigError> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|source| ConfigError::FileRead {
                path: parent.to_path_buf(),
                source,
            })?;
        }
        let json = serde_json::to_string_pretty(self)
            .map_err(|e| ConfigError::invalid_value("(serialization)", e.to_string()))?;
        std::fs::write(path, json).map_err(|source| ConfigError::FileRead {
            path: path.to_path_buf(),
            source,
        })?;
        Ok(())
    }

    /// Returns `true` when the native dataset subcarrier count differs from the
    /// model's target count and interpolation is therefore required.
    pub fn needs_subcarrier_interp(&self) -> bool {
        self.native_subcarriers != self.num_subcarriers
    }

    /// Rejects NaN and infinite values for the floating-point field `field`.
    ///
    /// Ordinary comparisons such as `x <= 0.0` are `false` for NaN, so range
    /// checks alone let NaN through; this helper closes that gap.
    fn ensure_finite(field: &'static str, value: f64) -> Result<(), ConfigError> {
        if value.is_finite() {
            Ok(())
        } else {
            Err(ConfigError::invalid_value(field, "must be a finite number"))
        }
    }

    /// Validate all fields and return an error describing the first problem
    /// found, or `Ok(())` if the configuration is coherent.
    ///
    /// Checks run in the order listed below, so the reported error is always
    /// the first violated invariant in that order.
    ///
    /// # Validated invariants
    ///
    /// - `num_subcarriers`, `native_subcarriers`, `num_antennas_tx`,
    ///   `num_antennas_rx`, `window_frames`, `heatmap_size`, `num_keypoints`,
    ///   `num_body_parts`, `backbone_channels` and `batch_size` must be
    ///   non-zero.
    /// - `learning_rate` must be finite and strictly positive.
    /// - `weight_decay` must be finite and non-negative.
    /// - `grad_clip_norm` must be strictly positive and not NaN
    ///   (`+inf` is accepted).
    /// - `num_epochs` must be non-zero and greater than `warmup_epochs`.
    /// - All `lr_milestones` must be within `[1, num_epochs]` and strictly
    ///   increasing (an empty list is fine).
    /// - `lr_gamma` must lie strictly between 0.0 and 1.0.
    /// - Loss weights must be finite, non-negative and sum to a positive value.
    /// - `val_every_epochs` must be at least 1.
    /// - `save_top_k` must be at least 1.
    ///
    /// # Errors
    ///
    /// Returns the error built by `ConfigError::invalid_value` naming the
    /// offending field and the violated constraint.
    pub fn validate(&self) -> Result<(), ConfigError> {
        // Size-like fields that must all be non-zero, in reporting order:
        // subcarriers, antennas, temporal window, heatmap, model dims, batch.
        let non_zero_fields = [
            ("num_subcarriers", self.num_subcarriers),
            ("native_subcarriers", self.native_subcarriers),
            ("num_antennas_tx", self.num_antennas_tx),
            ("num_antennas_rx", self.num_antennas_rx),
            ("window_frames", self.window_frames),
            ("heatmap_size", self.heatmap_size),
            ("num_keypoints", self.num_keypoints),
            ("num_body_parts", self.num_body_parts),
            ("backbone_channels", self.backbone_channels),
            ("batch_size", self.batch_size),
        ];
        for &(field, value) in &non_zero_fields {
            if value == 0 {
                return Err(ConfigError::invalid_value(field, "must be > 0"));
            }
        }

        // Optimisation. Each sign check is followed by a NaN/infinity check:
        // the comparison alone is `false` for NaN and would accept it.
        if self.learning_rate <= 0.0 {
            return Err(ConfigError::invalid_value(
                "learning_rate",
                "must be > 0.0",
            ));
        }
        Self::ensure_finite("learning_rate", self.learning_rate)?;
        if self.weight_decay < 0.0 {
            return Err(ConfigError::invalid_value(
                "weight_decay",
                "must be >= 0.0",
            ));
        }
        Self::ensure_finite("weight_decay", self.weight_decay)?;
        if self.grad_clip_norm <= 0.0 {
            return Err(ConfigError::invalid_value(
                "grad_clip_norm",
                "must be > 0.0",
            ));
        }
        // `+inf` stays legal here (it can express "no clipping"); only NaN is
        // newly rejected.
        if self.grad_clip_norm.is_nan() {
            return Err(ConfigError::invalid_value(
                "grad_clip_norm",
                "must not be NaN",
            ));
        }

        // Epochs
        if self.num_epochs == 0 {
            return Err(ConfigError::invalid_value("num_epochs", "must be > 0"));
        }
        if self.warmup_epochs >= self.num_epochs {
            return Err(ConfigError::invalid_value(
                "warmup_epochs",
                "must be < num_epochs",
            ));
        }

        // LR milestones: must be strictly increasing and within bounds.
        // `prev` starts at 0, which is below the smallest legal milestone (1).
        let mut prev = 0usize;
        for &m in &self.lr_milestones {
            if m == 0 || m > self.num_epochs {
                return Err(ConfigError::invalid_value(
                    "lr_milestones",
                    "each milestone must be in [1, num_epochs]",
                ));
            }
            if m <= prev {
                return Err(ConfigError::invalid_value(
                    "lr_milestones",
                    "milestones must be strictly increasing",
                ));
            }
            prev = m;
        }
        // NaN is outside every interval, so it is reported with the range message.
        if self.lr_gamma.is_nan() || self.lr_gamma <= 0.0 || self.lr_gamma >= 1.0 {
            return Err(ConfigError::invalid_value(
                "lr_gamma",
                "must be in (0.0, 1.0)",
            ));
        }

        // Loss weights
        if self.lambda_kp < 0.0 {
            return Err(ConfigError::invalid_value("lambda_kp", "must be >= 0.0"));
        }
        Self::ensure_finite("lambda_kp", self.lambda_kp)?;
        if self.lambda_dp < 0.0 {
            return Err(ConfigError::invalid_value("lambda_dp", "must be >= 0.0"));
        }
        Self::ensure_finite("lambda_dp", self.lambda_dp)?;
        if self.lambda_tr < 0.0 {
            return Err(ConfigError::invalid_value("lambda_tr", "must be >= 0.0"));
        }
        Self::ensure_finite("lambda_tr", self.lambda_tr)?;
        // All three weights are finite and non-negative at this point, so the
        // sum cannot be NaN and `<= 0.0` only fires when every weight is zero.
        let total_weight = self.lambda_kp + self.lambda_dp + self.lambda_tr;
        if total_weight <= 0.0 {
            return Err(ConfigError::invalid_value(
                "lambda_kp / lambda_dp / lambda_tr",
                "at least one loss weight must be > 0.0",
            ));
        }

        // Validation / checkpoint
        if self.val_every_epochs == 0 {
            return Err(ConfigError::invalid_value(
                "val_every_epochs",
                "must be > 0",
            ));
        }
        if self.save_top_k == 0 {
            return Err(ConfigError::invalid_value("save_top_k", "must be > 0"));
        }
        Ok(())
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Invalid configurations are built with struct-update syntax so each test
    // shows exactly which field deviates from the defaults.

    /// The out-of-the-box configuration must pass validation.
    #[test]
    fn default_config_is_valid() {
        TrainingConfig::default()
            .validate()
            .expect("default config should be valid");
    }

    /// Writing a config to disk and reading it back preserves its values.
    #[test]
    fn json_round_trip() {
        let scratch = tempdir().unwrap();
        let file = scratch.path().join("config.json");
        let original = TrainingConfig::default();

        original.to_json(&file).expect("serialization should succeed");
        let loaded = TrainingConfig::from_json(&file).expect("deserialization should succeed");

        assert_eq!(loaded.num_subcarriers, original.num_subcarriers);
        assert_eq!(loaded.batch_size, original.batch_size);
        assert_eq!(loaded.seed, original.seed);
        assert_eq!(loaded.lr_milestones, original.lr_milestones);
    }

    #[test]
    fn zero_subcarriers_is_invalid() {
        let cfg = TrainingConfig {
            num_subcarriers: 0,
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn negative_learning_rate_is_invalid() {
        let cfg = TrainingConfig {
            learning_rate: -0.001,
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn warmup_equal_to_epochs_is_invalid() {
        let base = TrainingConfig::default();
        let cfg = TrainingConfig {
            warmup_epochs: base.num_epochs,
            ..base
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn non_increasing_milestones_are_invalid() {
        let cfg = TrainingConfig {
            lr_milestones: vec![30, 20], // wrong order
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn milestone_beyond_epochs_is_invalid() {
        let base = TrainingConfig::default();
        let cfg = TrainingConfig {
            lr_milestones: vec![30, base.num_epochs + 1],
            ..base
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn all_zero_loss_weights_are_invalid() {
        let cfg = TrainingConfig {
            lambda_kp: 0.0,
            lambda_dp: 0.0,
            lambda_tr: 0.0,
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    /// Interpolation is needed exactly when native and target counts differ.
    #[test]
    fn needs_subcarrier_interp_when_counts_differ() {
        let mut cfg = TrainingConfig {
            num_subcarriers: 56,
            native_subcarriers: 114,
            ..TrainingConfig::default()
        };
        assert!(cfg.needs_subcarrier_interp());

        cfg.native_subcarriers = 56;
        assert!(!cfg.needs_subcarrier_interp());
    }

    /// Pins the documented default values so accidental changes are caught.
    #[test]
    fn config_fields_have_expected_defaults() {
        let cfg = TrainingConfig::default();

        // Integer-valued settings.
        assert_eq!(cfg.num_subcarriers, 56);
        assert_eq!(cfg.native_subcarriers, 114);
        assert_eq!(cfg.num_antennas_tx, 3);
        assert_eq!(cfg.num_antennas_rx, 3);
        assert_eq!(cfg.window_frames, 100);
        assert_eq!(cfg.heatmap_size, 56);
        assert_eq!(cfg.num_keypoints, 17);
        assert_eq!(cfg.num_body_parts, 24);
        assert_eq!(cfg.batch_size, 8);
        assert!((cfg.learning_rate - 1e-3).abs() < 1e-10);
        assert_eq!(cfg.num_epochs, 50);
        assert_eq!(cfg.warmup_epochs, 5);
        assert_eq!(cfg.lr_milestones, vec![30, 45]);

        // Floating-point settings are compared with a small tolerance.
        assert!((cfg.lr_gamma - 0.1).abs() < 1e-10);
        assert!((cfg.lambda_kp - 0.3).abs() < 1e-10);
        assert!((cfg.lambda_dp - 0.6).abs() < 1e-10);
        assert!((cfg.lambda_tr - 0.1).abs() < 1e-10);

        assert_eq!(cfg.seed, 42);
    }
}