diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml new file mode 100644 index 0000000..84b5197 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml @@ -0,0 +1,80 @@ +[package] +name = "wifi-densepose-train" +version = "0.1.0" +edition = "2021" +authors = ["WiFi-DensePose Contributors"] +license = "MIT OR Apache-2.0" +description = "Training pipeline for WiFi-DensePose pose estimation" +keywords = ["wifi", "training", "pose-estimation", "deep-learning"] + +[[bin]] +name = "train" +path = "src/bin/train.rs" + +[[bin]] +name = "verify-training" +path = "src/bin/verify_training.rs" + +[features] +default = ["tch-backend"] +tch-backend = ["tch"] +cuda = ["tch-backend"] + +[dependencies] +# Internal crates +wifi-densepose-signal = { path = "../wifi-densepose-signal" } +wifi-densepose-nn = { path = "../wifi-densepose-nn", default-features = false } + +# Core +thiserror = "1.0" +anyhow = "1.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Tensor / math +ndarray = { version = "0.15", features = ["serde"] } +ndarray-linalg = { version = "0.16", features = ["openblas-static"] } +num-complex = "0.4" +num-traits = "0.2" + +# PyTorch bindings (training) +tch = { version = "0.14", optional = true } + +# Graph algorithms (min-cut for optimal keypoint assignment) +petgraph = "0.6" + +# Data loading +ndarray-npy = "0.8" +memmap2 = "0.9" +walkdir = "2.4" + +# Serialization +csv = "1.3" +toml = "0.8" + +# Logging / progress +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +indicatif = "0.17" + +# Async +tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "macros", "fs"] } + +# Crypto (for proof hash) +sha2 = "0.10" + +# CLI +clap = { version = "4.4", features = ["derive"] } + +# Time +chrono = { version = "0.4", features = ["serde"] } + +[dev-dependencies] +criterion = { 
version = "0.5", features = ["html_reports"] } +proptest = "1.4" +tempfile = "3.10" +approx = "0.5" + +[[bench]] +name = "training_bench" +harness = false diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/config.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/config.rs new file mode 100644 index 0000000..8e27d19 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/config.rs @@ -0,0 +1,507 @@ +//! Training configuration for WiFi-DensePose. +//! +//! [`TrainingConfig`] is the single source of truth for all hyper-parameters, +//! dataset shapes, loss weights, and infrastructure settings used throughout +//! the training pipeline. It is serializable via [`serde`] so it can be stored +//! to / restored from JSON checkpoint files. +//! +//! # Example +//! +//! ```rust +//! use wifi_densepose_train::config::TrainingConfig; +//! +//! let cfg = TrainingConfig::default(); +//! cfg.validate().expect("default config is valid"); +//! +//! assert_eq!(cfg.num_subcarriers, 56); +//! assert_eq!(cfg.num_keypoints, 17); +//! ``` + +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; + +use crate::error::ConfigError; + +// --------------------------------------------------------------------------- +// TrainingConfig +// --------------------------------------------------------------------------- + +/// Complete configuration for a WiFi-DensePose training run. +/// +/// All fields have documented defaults that match the paper's experimental +/// setup. Use [`TrainingConfig::default()`] as a starting point, then override +/// individual fields as needed. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingConfig { + // ----------------------------------------------------------------------- + // Data / Signal + // ----------------------------------------------------------------------- + /// Number of subcarriers after interpolation (system target). 
+ /// + /// The model always sees this many subcarriers regardless of the raw + /// hardware output. Default: **56**. + pub num_subcarriers: usize, + + /// Number of subcarriers in the raw dataset before interpolation. + /// + /// MM-Fi provides 114 subcarriers; set this to 56 when the dataset + /// already matches the target count. Default: **114**. + pub native_subcarriers: usize, + + /// Number of transmit antennas. Default: **3**. + pub num_antennas_tx: usize, + + /// Number of receive antennas. Default: **3**. + pub num_antennas_rx: usize, + + /// Temporal sliding-window length in frames. Default: **100**. + pub window_frames: usize, + + /// Side length of the square keypoint heatmap output (H = W). Default: **56**. + pub heatmap_size: usize, + + // ----------------------------------------------------------------------- + // Model + // ----------------------------------------------------------------------- + /// Number of body keypoints (COCO 17-joint skeleton). Default: **17**. + pub num_keypoints: usize, + + /// Number of DensePose body-part classes. Default: **24**. + pub num_body_parts: usize, + + /// Number of feature-map channels in the backbone encoder. Default: **256**. + pub backbone_channels: usize, + + // ----------------------------------------------------------------------- + // Optimisation + // ----------------------------------------------------------------------- + /// Mini-batch size. Default: **8**. + pub batch_size: usize, + + /// Initial learning rate for the Adam / AdamW optimiser. Default: **1e-3**. + pub learning_rate: f64, + + /// L2 weight-decay regularisation coefficient. Default: **1e-4**. + pub weight_decay: f64, + + /// Total number of training epochs. Default: **50**. + pub num_epochs: usize, + + /// Number of linear-warmup epochs at the start. Default: **5**. + pub warmup_epochs: usize, + + /// Epochs at which the learning rate is multiplied by `lr_gamma`. + /// + /// Default: **[30, 45]** (multi-step scheduler). 
+    pub lr_milestones: Vec<usize>,
+
+    /// Multiplicative factor applied at each LR milestone. Default: **0.1**.
+    pub lr_gamma: f64,
+
+    /// Maximum gradient L2 norm for gradient clipping. Default: **1.0**.
+    pub grad_clip_norm: f64,
+
+    // -----------------------------------------------------------------------
+    // Loss weights
+    // -----------------------------------------------------------------------
+    /// Weight for the keypoint heatmap loss term. Default: **0.3**.
+    pub lambda_kp: f64,
+
+    /// Weight for the DensePose body-part / UV-coordinate loss. Default: **0.6**.
+    pub lambda_dp: f64,
+
+    /// Weight for the cross-modal transfer / domain-alignment loss. Default: **0.1**.
+    pub lambda_tr: f64,
+
+    // -----------------------------------------------------------------------
+    // Validation and checkpointing
+    // -----------------------------------------------------------------------
+    /// Run validation every N epochs. Default: **1**.
+    pub val_every_epochs: usize,
+
+    /// Stop training if validation loss does not improve for this many
+    /// consecutive validation rounds. Default: **10**.
+    pub early_stopping_patience: usize,
+
+    /// Directory where model checkpoints are saved.
+    pub checkpoint_dir: PathBuf,
+
+    /// Directory where TensorBoard / CSV logs are written.
+    pub log_dir: PathBuf,
+
+    /// Keep only the top-K best checkpoints by validation metric. Default: **3**.
+    pub save_top_k: usize,
+
+    // -----------------------------------------------------------------------
+    // Device
+    // -----------------------------------------------------------------------
+    /// Use a CUDA GPU for training when available. Default: **false**.
+    pub use_gpu: bool,
+
+    /// CUDA device index when `use_gpu` is `true`. Default: **0**.
+    pub gpu_device_id: i64,
+
+    /// Number of background data-loading threads. Default: **4**.
+ pub num_workers: usize, + + // ----------------------------------------------------------------------- + // Reproducibility + // ----------------------------------------------------------------------- + /// Global random seed for all RNG sources in the training pipeline. + /// + /// This seed is applied to the dataset shuffler, model parameter + /// initialisation, and any stochastic augmentation. Default: **42**. + pub seed: u64, +} + +impl Default for TrainingConfig { + fn default() -> Self { + TrainingConfig { + // Data + num_subcarriers: 56, + native_subcarriers: 114, + num_antennas_tx: 3, + num_antennas_rx: 3, + window_frames: 100, + heatmap_size: 56, + // Model + num_keypoints: 17, + num_body_parts: 24, + backbone_channels: 256, + // Optimisation + batch_size: 8, + learning_rate: 1e-3, + weight_decay: 1e-4, + num_epochs: 50, + warmup_epochs: 5, + lr_milestones: vec![30, 45], + lr_gamma: 0.1, + grad_clip_norm: 1.0, + // Loss weights + lambda_kp: 0.3, + lambda_dp: 0.6, + lambda_tr: 0.1, + // Validation / checkpointing + val_every_epochs: 1, + early_stopping_patience: 10, + checkpoint_dir: PathBuf::from("checkpoints"), + log_dir: PathBuf::from("logs"), + save_top_k: 3, + // Device + use_gpu: false, + gpu_device_id: 0, + num_workers: 4, + // Reproducibility + seed: 42, + } + } +} + +impl TrainingConfig { + /// Load a [`TrainingConfig`] from a JSON file at `path`. + /// + /// # Errors + /// + /// Returns [`ConfigError::FileRead`] if the file cannot be opened and + /// [`ConfigError::InvalidValue`] if the JSON is malformed. 
+    pub fn from_json(path: &Path) -> Result<Self, ConfigError> {
+        let contents = std::fs::read_to_string(path).map_err(|source| ConfigError::FileRead {
+            path: path.to_path_buf(),
+            source,
+        })?;
+        let cfg: TrainingConfig = serde_json::from_str(&contents)
+            .map_err(|e| ConfigError::invalid_value("(file)", e.to_string()))?;
+        cfg.validate()?;
+        Ok(cfg)
+    }
+
+    /// Serialize this configuration to pretty-printed JSON and write it to
+    /// `path`, creating parent directories if necessary.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ConfigError::FileRead`] if the directory cannot be created or
+    /// the file cannot be written.
+    pub fn to_json(&self, path: &Path) -> Result<(), ConfigError> {
+        if let Some(parent) = path.parent() {
+            std::fs::create_dir_all(parent).map_err(|source| ConfigError::FileRead {
+                path: parent.to_path_buf(),
+                source,
+            })?;
+        }
+        let json = serde_json::to_string_pretty(self)
+            .map_err(|e| ConfigError::invalid_value("(serialization)", e.to_string()))?;
+        std::fs::write(path, json).map_err(|source| ConfigError::FileRead {
+            path: path.to_path_buf(),
+            source,
+        })?;
+        Ok(())
+    }
+
+    /// Returns `true` when the native dataset subcarrier count differs from the
+    /// model's target count and interpolation is therefore required.
+    pub fn needs_subcarrier_interp(&self) -> bool {
+        self.native_subcarriers != self.num_subcarriers
+    }
+
+    /// Validate all fields and return an error describing the first problem
+    /// found, or `Ok(())` if the configuration is coherent.
+    ///
+    /// # Validated invariants
+    ///
+    /// - Subcarrier counts must be non-zero.
+    /// - Antenna counts must be non-zero.
+    /// - `window_frames` must be at least 1.
+    /// - `batch_size` must be at least 1.
+    /// - `learning_rate` must be strictly positive.
+    /// - `weight_decay` must be non-negative.
+    /// - Loss weights must be non-negative and sum to a positive value.
+    /// - `num_epochs` must be greater than `warmup_epochs`.
+ /// - All `lr_milestones` must be within `[1, num_epochs]` and strictly + /// increasing. + /// - `save_top_k` must be at least 1. + /// - `val_every_epochs` must be at least 1. + pub fn validate(&self) -> Result<(), ConfigError> { + // Subcarrier counts + if self.num_subcarriers == 0 { + return Err(ConfigError::invalid_value("num_subcarriers", "must be > 0")); + } + if self.native_subcarriers == 0 { + return Err(ConfigError::invalid_value( + "native_subcarriers", + "must be > 0", + )); + } + + // Antenna counts + if self.num_antennas_tx == 0 { + return Err(ConfigError::invalid_value("num_antennas_tx", "must be > 0")); + } + if self.num_antennas_rx == 0 { + return Err(ConfigError::invalid_value("num_antennas_rx", "must be > 0")); + } + + // Temporal window + if self.window_frames == 0 { + return Err(ConfigError::invalid_value("window_frames", "must be > 0")); + } + + // Heatmap + if self.heatmap_size == 0 { + return Err(ConfigError::invalid_value("heatmap_size", "must be > 0")); + } + + // Model dims + if self.num_keypoints == 0 { + return Err(ConfigError::invalid_value("num_keypoints", "must be > 0")); + } + if self.num_body_parts == 0 { + return Err(ConfigError::invalid_value("num_body_parts", "must be > 0")); + } + if self.backbone_channels == 0 { + return Err(ConfigError::invalid_value( + "backbone_channels", + "must be > 0", + )); + } + + // Optimisation + if self.batch_size == 0 { + return Err(ConfigError::invalid_value("batch_size", "must be > 0")); + } + if self.learning_rate <= 0.0 { + return Err(ConfigError::invalid_value( + "learning_rate", + "must be > 0.0", + )); + } + if self.weight_decay < 0.0 { + return Err(ConfigError::invalid_value( + "weight_decay", + "must be >= 0.0", + )); + } + if self.grad_clip_norm <= 0.0 { + return Err(ConfigError::invalid_value( + "grad_clip_norm", + "must be > 0.0", + )); + } + + // Epochs + if self.num_epochs == 0 { + return Err(ConfigError::invalid_value("num_epochs", "must be > 0")); + } + if self.warmup_epochs >= 
self.num_epochs { + return Err(ConfigError::invalid_value( + "warmup_epochs", + "must be < num_epochs", + )); + } + + // LR milestones: must be strictly increasing and within bounds + let mut prev = 0usize; + for &m in &self.lr_milestones { + if m == 0 || m > self.num_epochs { + return Err(ConfigError::invalid_value( + "lr_milestones", + "each milestone must be in [1, num_epochs]", + )); + } + if m <= prev { + return Err(ConfigError::invalid_value( + "lr_milestones", + "milestones must be strictly increasing", + )); + } + prev = m; + } + + if self.lr_gamma <= 0.0 || self.lr_gamma >= 1.0 { + return Err(ConfigError::invalid_value( + "lr_gamma", + "must be in (0.0, 1.0)", + )); + } + + // Loss weights + if self.lambda_kp < 0.0 { + return Err(ConfigError::invalid_value("lambda_kp", "must be >= 0.0")); + } + if self.lambda_dp < 0.0 { + return Err(ConfigError::invalid_value("lambda_dp", "must be >= 0.0")); + } + if self.lambda_tr < 0.0 { + return Err(ConfigError::invalid_value("lambda_tr", "must be >= 0.0")); + } + let total_weight = self.lambda_kp + self.lambda_dp + self.lambda_tr; + if total_weight <= 0.0 { + return Err(ConfigError::invalid_value( + "lambda_kp / lambda_dp / lambda_tr", + "at least one loss weight must be > 0.0", + )); + } + + // Validation / checkpoint + if self.val_every_epochs == 0 { + return Err(ConfigError::invalid_value( + "val_every_epochs", + "must be > 0", + )); + } + if self.save_top_k == 0 { + return Err(ConfigError::invalid_value("save_top_k", "must be > 0")); + } + + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn default_config_is_valid() { + let cfg = TrainingConfig::default(); + cfg.validate().expect("default config should be valid"); + } + + #[test] + fn json_round_trip() { + let tmp = tempdir().unwrap(); + let path 
= tmp.path().join("config.json"); + + let original = TrainingConfig::default(); + original.to_json(&path).expect("serialization should succeed"); + + let loaded = TrainingConfig::from_json(&path).expect("deserialization should succeed"); + assert_eq!(loaded.num_subcarriers, original.num_subcarriers); + assert_eq!(loaded.batch_size, original.batch_size); + assert_eq!(loaded.seed, original.seed); + assert_eq!(loaded.lr_milestones, original.lr_milestones); + } + + #[test] + fn zero_subcarriers_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.num_subcarriers = 0; + assert!(cfg.validate().is_err()); + } + + #[test] + fn negative_learning_rate_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.learning_rate = -0.001; + assert!(cfg.validate().is_err()); + } + + #[test] + fn warmup_equal_to_epochs_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.warmup_epochs = cfg.num_epochs; + assert!(cfg.validate().is_err()); + } + + #[test] + fn non_increasing_milestones_are_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.lr_milestones = vec![30, 20]; // wrong order + assert!(cfg.validate().is_err()); + } + + #[test] + fn milestone_beyond_epochs_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.lr_milestones = vec![30, cfg.num_epochs + 1]; + assert!(cfg.validate().is_err()); + } + + #[test] + fn all_zero_loss_weights_are_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.lambda_kp = 0.0; + cfg.lambda_dp = 0.0; + cfg.lambda_tr = 0.0; + assert!(cfg.validate().is_err()); + } + + #[test] + fn needs_subcarrier_interp_when_counts_differ() { + let mut cfg = TrainingConfig::default(); + cfg.num_subcarriers = 56; + cfg.native_subcarriers = 114; + assert!(cfg.needs_subcarrier_interp()); + + cfg.native_subcarriers = 56; + assert!(!cfg.needs_subcarrier_interp()); + } + + #[test] + fn config_fields_have_expected_defaults() { + let cfg = TrainingConfig::default(); + assert_eq!(cfg.num_subcarriers, 56); + 
assert_eq!(cfg.native_subcarriers, 114); + assert_eq!(cfg.num_antennas_tx, 3); + assert_eq!(cfg.num_antennas_rx, 3); + assert_eq!(cfg.window_frames, 100); + assert_eq!(cfg.heatmap_size, 56); + assert_eq!(cfg.num_keypoints, 17); + assert_eq!(cfg.num_body_parts, 24); + assert_eq!(cfg.batch_size, 8); + assert!((cfg.learning_rate - 1e-3).abs() < 1e-10); + assert_eq!(cfg.num_epochs, 50); + assert_eq!(cfg.warmup_epochs, 5); + assert_eq!(cfg.lr_milestones, vec![30, 45]); + assert!((cfg.lr_gamma - 0.1).abs() < 1e-10); + assert!((cfg.lambda_kp - 0.3).abs() < 1e-10); + assert!((cfg.lambda_dp - 0.6).abs() < 1e-10); + assert!((cfg.lambda_tr - 0.1).abs() < 1e-10); + assert_eq!(cfg.seed, 42); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs new file mode 100644 index 0000000..f5d9bce --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs @@ -0,0 +1,956 @@ +//! Dataset abstractions and concrete implementations for WiFi-DensePose training. +//! +//! This module defines the [`CsiDataset`] trait plus two concrete implementations: +//! +//! - [`MmFiDataset`]: reads MM-Fi NPY/HDF5 files from disk. +//! - [`SyntheticCsiDataset`]: generates fully-deterministic CSI from a physics +//! model; useful for unit tests, integration tests, and dry-run sanity checks. +//! **Never uses random data.** +//! +//! A [`DataLoader`] wraps any [`CsiDataset`] and provides batched iteration with +//! optional deterministic shuffle (seeded). +//! +//! # Directory layout expected by `MmFiDataset` +//! +//! ```text +//! / +//! S01/ +//! A01/ +//! wifi_csi.npy # amplitude [T, n_tx, n_rx, n_sc] +//! wifi_csi_phase.npy # phase [T, n_tx, n_rx, n_sc] +//! gt_keypoints.npy # keypoints [T, 17, 3] (x, y, vis) +//! A02/ +//! ... +//! S02/ +//! ... +//! ``` +//! +//! Each subject/action pair produces one or more windowed [`CsiSample`]s. +//! +//! 
# Example – synthetic dataset
+//!
+//! ```rust
+//! use wifi_densepose_train::dataset::{SyntheticCsiDataset, SyntheticConfig, CsiDataset};
+//!
+//! let cfg = SyntheticConfig::default();
+//! let ds = SyntheticCsiDataset::new(64, cfg);
+//!
+//! assert_eq!(ds.len(), 64);
+//! let sample = ds.get(0).unwrap();
+//! assert_eq!(sample.amplitude.shape(), &[100, 3, 3, 56]);
+//! ```
+
+use ndarray::{Array1, Array2, Array4};
+use std::path::{Path, PathBuf};
+use thiserror::Error;
+use tracing::{debug, info, warn};
+
+use crate::subcarrier::interpolate_subcarriers;
+
+// ---------------------------------------------------------------------------
+// CsiSample
+// ---------------------------------------------------------------------------
+
+/// A single windowed CSI observation paired with its ground-truth labels.
+///
+/// All arrays are stored in row-major (C) order. Keypoint coordinates are
+/// normalised to `[0, 1]` with the origin at the **top-left** corner.
+#[derive(Debug, Clone)]
+pub struct CsiSample {
+    /// CSI amplitude tensor.
+    ///
+    /// Shape: `[window_frames, n_tx, n_rx, n_subcarriers]`.
+    pub amplitude: Array4<f32>,
+
+    /// CSI phase tensor (radians, unwrapped).
+    ///
+    /// Shape: `[window_frames, n_tx, n_rx, n_subcarriers]`.
+    pub phase: Array4<f32>,
+
+    /// COCO 17-keypoint positions normalised to `[0, 1]`.
+    ///
+    /// Shape: `[17, 2]` – column 0 is x, column 1 is y.
+    pub keypoints: Array2<f32>,
+
+    /// Keypoint visibility flags.
+    ///
+    /// Shape: `[17]`. Values follow the COCO convention:
+    /// - `0` – not labelled
+    /// - `1` – labelled but not visible
+    /// - `2` – visible
+    pub keypoint_visibility: Array1<f32>,
+
+    /// Subject identifier (e.g. 1 for `S01`).
+    pub subject_id: u32,
+
+    /// Action identifier (e.g. 1 for `A01`).
+    pub action_id: u32,
+
+    /// Absolute frame index within the original recording.
+    pub frame_id: u64,
+}
+
+// ---------------------------------------------------------------------------
+// CsiDataset trait
+// ---------------------------------------------------------------------------
+
+/// Common interface for all WiFi-DensePose datasets.
+///
+/// Implementations must be `Send + Sync` so they can be shared across
+/// data-loading threads without additional synchronisation.
+pub trait CsiDataset: Send + Sync {
+    /// Total number of samples in this dataset.
+    fn len(&self) -> usize;
+
+    /// Load the sample at position `idx`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`DatasetError::IndexOutOfBounds`] when `idx >= self.len()` and
+    /// dataset-specific errors for IO or format problems.
+    fn get(&self, idx: usize) -> Result<CsiSample, DatasetError>;
+
+    /// Returns `true` when the dataset contains no samples.
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Human-readable name for logging and progress display.
+    fn name(&self) -> &str;
+}
+
+// ---------------------------------------------------------------------------
+// DataLoader
+// ---------------------------------------------------------------------------
+
+/// Batched, optionally-shuffled iterator over a [`CsiDataset`].
+///
+/// The shuffle order is fully deterministic: given the same `seed` and dataset
+/// length the iteration order is always identical. This ensures reproducibility
+/// across training runs.
+pub struct DataLoader<'a> {
+    dataset: &'a dyn CsiDataset,
+    batch_size: usize,
+    shuffle: bool,
+    seed: u64,
+}
+
+impl<'a> DataLoader<'a> {
+    /// Create a new `DataLoader`.
+    ///
+    /// # Parameters
+    ///
+    /// - `dataset` – the underlying dataset.
+    /// - `batch_size` – number of samples per batch. The last batch may be
+    ///   smaller if the dataset length is not a multiple of `batch_size`.
+    /// - `shuffle` – if `true`, samples are shuffled deterministically using
+    ///   `seed` at the start of each iteration.
+    /// - `seed` – fixed seed for the shuffle RNG.
+    pub fn new(
+        dataset: &'a dyn CsiDataset,
+        batch_size: usize,
+        shuffle: bool,
+        seed: u64,
+    ) -> Self {
+        assert!(batch_size > 0, "batch_size must be > 0");
+        DataLoader { dataset, batch_size, shuffle, seed }
+    }
+
+    /// Number of complete (or partial) batches yielded per epoch.
+    pub fn num_batches(&self) -> usize {
+        let n = self.dataset.len();
+        if n == 0 {
+            return 0;
+        }
+        (n + self.batch_size - 1) / self.batch_size
+    }
+
+    /// Return an iterator that yields `Vec<CsiSample>` batches.
+    ///
+    /// Failed individual sample loads are skipped with a `warn!` log rather
+    /// than aborting the iterator.
+    pub fn iter(&self) -> DataLoaderIter<'_> {
+        // Build the index permutation once per epoch using a seeded Xorshift64.
+        let n = self.dataset.len();
+        let mut indices: Vec<usize> = (0..n).collect();
+        if self.shuffle {
+            xorshift_shuffle(&mut indices, self.seed);
+        }
+        DataLoaderIter {
+            dataset: self.dataset,
+            indices,
+            batch_size: self.batch_size,
+            cursor: 0,
+        }
+    }
+}
+
+/// Iterator returned by [`DataLoader::iter`].
+pub struct DataLoaderIter<'a> {
+    dataset: &'a dyn CsiDataset,
+    indices: Vec<usize>,
+    batch_size: usize,
+    cursor: usize,
+}
+
+impl<'a> Iterator for DataLoaderIter<'a> {
+    type Item = Vec<CsiSample>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.cursor >= self.indices.len() {
+            return None;
+        }
+        let end = (self.cursor + self.batch_size).min(self.indices.len());
+        let batch_indices = &self.indices[self.cursor..end];
+        self.cursor = end;
+
+        let mut batch = Vec::with_capacity(batch_indices.len());
+        for &idx in batch_indices {
+            match self.dataset.get(idx) {
+                Ok(sample) => batch.push(sample),
+                Err(e) => {
+                    warn!("Skipping sample {idx}: {e}");
+                }
+            }
+        }
+        if batch.is_empty() { None } else { Some(batch) }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Xorshift shuffle (deterministic, no external RNG state)
+// ---------------------------------------------------------------------------
+
+/// In-place Fisher-Yates shuffle using a 64-bit Xorshift PRNG seeded with
+/// `seed`. This is reproducible across platforms and requires no external crate
+/// in production paths.
+fn xorshift_shuffle(indices: &mut [usize], seed: u64) {
+    let n = indices.len();
+    if n <= 1 {
+        return;
+    }
+    let mut state = if seed == 0 { 0x853c49e6748fea9b } else { seed };
+    for i in (1..n).rev() {
+        // Xorshift64
+        state ^= state << 13;
+        state ^= state >> 7;
+        state ^= state << 17;
+        let j = (state as usize) % (i + 1);
+        indices.swap(i, j);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// MmFiDataset
+// ---------------------------------------------------------------------------
+
+/// An indexed entry in the MM-Fi directory scan.
+#[derive(Debug, Clone)]
+struct MmFiEntry {
+    subject_id: u32,
+    action_id: u32,
+    /// Path to `wifi_csi.npy` (amplitude).
+    amp_path: PathBuf,
+    /// Path to `wifi_csi_phase.npy`.
+    phase_path: PathBuf,
+    /// Path to `gt_keypoints.npy`.
+    kp_path: PathBuf,
+    /// Number of temporal frames available in this clip.
+    num_frames: usize,
+    /// Window size in frames (mirrors config).
+    window_frames: usize,
+    /// First global sample index that maps into this clip.
+    global_start_idx: usize,
+}
+
+impl MmFiEntry {
+    /// Number of stride-1 sliding windows this clip contributes.
+    fn num_windows(&self) -> usize {
+        if self.num_frames < self.window_frames {
+            0
+        } else {
+            self.num_frames - self.window_frames + 1
+        }
+    }
+}
+
+/// Dataset adapter for MM-Fi recordings stored as `.npy` files.
+///
+/// Scanning is performed once at construction via [`MmFiDataset::discover`].
+/// Individual samples are loaded lazily from disk on each [`CsiDataset::get`]
+/// call.
+///
+/// ## Subcarrier interpolation
+///
+/// When the loaded amplitude/phase arrays contain a different number of
+/// subcarriers than `target_subcarriers`, [`interpolate_subcarriers`] is
+/// applied automatically before the sample is returned.
+pub struct MmFiDataset {
+    entries: Vec<MmFiEntry>,
+    /// Cumulative window count per entry (prefix sum, length = entries.len() + 1).
+    cumulative: Vec<usize>,
+    window_frames: usize,
+    target_subcarriers: usize,
+    num_keypoints: usize,
+    root: PathBuf,
+}
+
+impl MmFiDataset {
+    /// Scan `root` for MM-Fi recordings and build a sample index.
+    ///
+    /// The scan walks `root/{S??}/{A??}/` directories and looks for:
+    /// - `wifi_csi.npy` – CSI amplitude
+    /// - `wifi_csi_phase.npy` – CSI phase
+    /// - `gt_keypoints.npy` – ground-truth keypoints
+    ///
+    /// # Errors
+    ///
+    /// Returns [`DatasetError::DirectoryNotFound`] if `root` does not exist, or
+    /// [`DatasetError::Io`] for any filesystem access failure.
+    pub fn discover(
+        root: &Path,
+        window_frames: usize,
+        target_subcarriers: usize,
+        num_keypoints: usize,
+    ) -> Result<Self, DatasetError> {
+        if !root.exists() {
+            return Err(DatasetError::DirectoryNotFound {
+                path: root.display().to_string(),
+            });
+        }
+
+        let mut entries: Vec<MmFiEntry> = Vec::new();
+        let mut global_idx = 0usize;
+
+        // Walk subject directories (S01, S02, …)
+        let mut subject_dirs: Vec<PathBuf> = std::fs::read_dir(root)?
+            .filter_map(|e| e.ok())
+            .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
+            .map(|e| e.path())
+            .collect();
+        subject_dirs.sort();
+
+        for subj_path in &subject_dirs {
+            let subj_name = subj_path.file_name().and_then(|n| n.to_str()).unwrap_or("");
+            let subject_id = parse_id_suffix(subj_name).unwrap_or(0);
+
+            // Walk action directories (A01, A02, …)
+            let mut action_dirs: Vec<PathBuf> = std::fs::read_dir(subj_path)?
+                .filter_map(|e| e.ok())
+                .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
+                .map(|e| e.path())
+                .collect();
+            action_dirs.sort();
+
+            for action_path in &action_dirs {
+                let action_name =
+                    action_path.file_name().and_then(|n| n.to_str()).unwrap_or("");
+                let action_id = parse_id_suffix(action_name).unwrap_or(0);
+
+                let amp_path = action_path.join("wifi_csi.npy");
+                let phase_path = action_path.join("wifi_csi_phase.npy");
+                let kp_path = action_path.join("gt_keypoints.npy");
+
+                if !amp_path.exists() || !kp_path.exists() {
+                    debug!(
+                        "Skipping {}: missing required files",
+                        action_path.display()
+                    );
+                    continue;
+                }
+
+                // Peek at the amplitude shape to get the frame count.
+                let num_frames = match peek_npy_first_dim(&amp_path) {
+                    Ok(n) => n,
+                    Err(e) => {
+                        warn!("Cannot read shape from {}: {e}", amp_path.display());
+                        continue;
+                    }
+                };
+
+                let entry = MmFiEntry {
+                    subject_id,
+                    action_id,
+                    amp_path,
+                    phase_path,
+                    kp_path,
+                    num_frames,
+                    window_frames,
+                    global_start_idx: global_idx,
+                };
+                global_idx += entry.num_windows();
+                entries.push(entry);
+            }
+        }
+
+        info!(
+            "MmFiDataset: scanned {} clips, {} total windows (root={})",
+            entries.len(),
+            global_idx,
+            root.display()
+        );
+
+        // Build prefix-sum cumulative array
+        let mut cumulative = vec![0usize; entries.len() + 1];
+        for (i, e) in entries.iter().enumerate() {
+            cumulative[i + 1] = cumulative[i] + e.num_windows();
+        }
+
+        Ok(MmFiDataset {
+            entries,
+            cumulative,
+            window_frames,
+            target_subcarriers,
+            num_keypoints,
+            root: root.to_path_buf(),
+        })
+    }
+
+    /// Resolve a global sample index to `(entry_index, frame_offset)`.
+    fn locate(&self, idx: usize) -> Option<(usize, usize)> {
+        let total = self.cumulative.last().copied().unwrap_or(0);
+        if idx >= total {
+            return None;
+        }
+        // Binary search in the cumulative prefix sums.
+        let entry_idx = self
+            .cumulative
+            .partition_point(|&c| c <= idx)
+            .saturating_sub(1);
+        let frame_offset = idx - self.cumulative[entry_idx];
+        Some((entry_idx, frame_offset))
+    }
+}
+
+impl CsiDataset for MmFiDataset {
+    fn len(&self) -> usize {
+        self.cumulative.last().copied().unwrap_or(0)
+    }
+
+    fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
+        let total = self.len();
+        let (entry_idx, frame_offset) = self
+            .locate(idx)
+            .ok_or(DatasetError::IndexOutOfBounds { idx, len: total })?;
+
+        let entry = &self.entries[entry_idx];
+        let t_start = frame_offset;
+        let t_end = t_start + self.window_frames;
+
+        // Load amplitude
+        let amp_full = load_npy_f32(&entry.amp_path)?;
+        let (t, n_tx, n_rx, n_sc) = amp_full.dim();
+        if t_end > t {
+            return Err(DatasetError::Format(format!(
+                "window [{t_start}, {t_end}) exceeds clip length {t} in {}",
+                entry.amp_path.display()
+            )));
+        }
+        let amp_window = amp_full
+            .slice(ndarray::s![t_start..t_end, .., .., ..])
+            .to_owned();
+
+        // Load phase (optional – return zeros if the file is absent)
+        let phase_window = if entry.phase_path.exists() {
+            let phase_full = load_npy_f32(&entry.phase_path)?;
+            phase_full
+                .slice(ndarray::s![t_start..t_end, .., .., ..])
+                .to_owned()
+        } else {
+            Array4::zeros((self.window_frames, n_tx, n_rx, n_sc))
+        };
+
+        // Subcarrier interpolation (if needed)
+        let amplitude = if n_sc != self.target_subcarriers {
+            interpolate_subcarriers(&amp_window, self.target_subcarriers)
+        } else {
+            amp_window
+        };
+
+        let phase = if phase_window.dim().3 != self.target_subcarriers {
+            interpolate_subcarriers(&phase_window, self.target_subcarriers)
+        } else {
+            phase_window
+        };
+
+        // Load keypoints [T, 17, 3] — take the first frame of the window
+        let kp_full = load_npy_kp(&entry.kp_path, self.num_keypoints)?;
+        let kp_frame = kp_full
+            .slice(ndarray::s![t_start, .., ..])
+            .to_owned();
+
+        // Split into (x,y) and visibility
+        let keypoints = kp_frame.slice(ndarray::s![.., 0..2]).to_owned();
+        let keypoint_visibility
= kp_frame.column(2).to_owned(); + + Ok(CsiSample { + amplitude, + phase, + keypoints, + keypoint_visibility, + subject_id: entry.subject_id, + action_id: entry.action_id, + frame_id: t_start as u64, + }) + } + + fn name(&self) -> &str { + "MmFiDataset" + } +} + +// --------------------------------------------------------------------------- +// NPY helpers (no-HDF5 path; HDF5 path is feature-gated below) +// --------------------------------------------------------------------------- + +/// Load a 4-D float32 NPY array from disk. +/// +/// The NPY format is read using `ndarray_npy`. +fn load_npy_f32(path: &Path) -> Result, DatasetError> { + use ndarray_npy::ReadNpyExt; + let file = std::fs::File::open(path)?; + let arr: ndarray::ArrayD = ndarray::ArrayD::read_npy(file) + .map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?; + arr.into_dimensionality::().map_err(|e| { + DatasetError::Format(format!( + "Expected 4-D array in {}, got shape {:?}: {e}", + path.display(), + arr.shape() + )) + }) +} + +/// Load a 3-D float32 NPY array (keypoints: `[T, J, 3]`). +fn load_npy_kp(path: &Path, _num_keypoints: usize) -> Result, DatasetError> { + use ndarray_npy::ReadNpyExt; + let file = std::fs::File::open(path)?; + let arr: ndarray::ArrayD = ndarray::ArrayD::read_npy(file) + .map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?; + arr.into_dimensionality::().map_err(|e| { + DatasetError::Format(format!( + "Expected 3-D keypoint array in {}, got shape {:?}: {e}", + path.display(), + arr.shape() + )) + }) +} + +/// Read only the first dimension of an NPY header (the frame count) without +/// loading the entire file into memory. +fn peek_npy_first_dim(path: &Path) -> Result { + // Minimum viable NPY header parse: magic + version + header_len + header. 
+ use std::io::{BufReader, Read}; + let f = std::fs::File::open(path)?; + let mut reader = BufReader::new(f); + + let mut magic = [0u8; 6]; + reader.read_exact(&mut magic)?; + if &magic != b"\x93NUMPY" { + return Err(DatasetError::Format(format!( + "Not a valid NPY file: {}", + path.display() + ))); + } + + let mut version = [0u8; 2]; + reader.read_exact(&mut version)?; + + // Header length field: 2 bytes in v1, 4 bytes in v2 + let header_len: usize = if version[0] == 1 { + let mut buf = [0u8; 2]; + reader.read_exact(&mut buf)?; + u16::from_le_bytes(buf) as usize + } else { + let mut buf = [0u8; 4]; + reader.read_exact(&mut buf)?; + u32::from_le_bytes(buf) as usize + }; + + let mut header = vec![0u8; header_len]; + reader.read_exact(&mut header)?; + let header_str = String::from_utf8_lossy(&header); + + // Parse the shape tuple using a simple substring search. + // Example header: "{'descr': ' = shape_str + .split(',') + .filter_map(|s| s.trim().parse::().ok()) + .collect(); + if let Some(&first) = dims.first() { + return Ok(first); + } + } + } + + Err(DatasetError::Format(format!( + "Cannot parse shape from NPY header in {}", + path.display() + ))) +} + +/// Parse the numeric suffix of a directory name like `S01` → `1` or `A12` → `12`. +fn parse_id_suffix(name: &str) -> Option { + name.chars() + .skip_while(|c| c.is_alphabetic()) + .collect::() + .parse::() + .ok() +} + +// --------------------------------------------------------------------------- +// SyntheticCsiDataset +// --------------------------------------------------------------------------- + +/// Configuration for [`SyntheticCsiDataset`]. +/// +/// All fields are plain numbers; no randomness is involved. +#[derive(Debug, Clone)] +pub struct SyntheticConfig { + /// Number of output subcarriers. Default: **56**. + pub num_subcarriers: usize, + /// Number of transmit antennas. Default: **3**. + pub num_antennas_tx: usize, + /// Number of receive antennas. Default: **3**. 
+ pub num_antennas_rx: usize, + /// Temporal window length. Default: **100**. + pub window_frames: usize, + /// Number of body keypoints. Default: **17** (COCO). + pub num_keypoints: usize, + /// Carrier frequency for phase model. Default: **2.4e9 Hz**. + pub signal_frequency_hz: f32, +} + +impl Default for SyntheticConfig { + fn default() -> Self { + SyntheticConfig { + num_subcarriers: 56, + num_antennas_tx: 3, + num_antennas_rx: 3, + window_frames: 100, + num_keypoints: 17, + signal_frequency_hz: 2.4e9, + } + } +} + +/// Fully-deterministic CSI dataset generated from a physical signal model. +/// +/// No random number generator is used. Every sample at index `idx` is computed +/// analytically from `idx` alone, making the dataset perfectly reproducible +/// and portable across platforms. +/// +/// ## Amplitude model +/// +/// For sample `idx`, frame `t`, tx `i`, rx `j`, subcarrier `k`: +/// +/// ```text +/// A = 0.5 + 0.3 × sin(2π × (idx × 0.01 + t × 0.1 + k × 0.05)) +/// ``` +/// +/// ## Phase model +/// +/// ```text +/// φ = (2π × k / num_subcarriers) × (i + 1) × (j + 1) +/// ``` +/// +/// ## Keypoint model +/// +/// Joint `j` is placed at: +/// +/// ```text +/// x = 0.5 + 0.1 × sin(2π × idx × 0.007 + j) +/// y = 0.3 + j × 0.04 +/// ``` +pub struct SyntheticCsiDataset { + num_samples: usize, + config: SyntheticConfig, +} + +impl SyntheticCsiDataset { + /// Create a new synthetic dataset with `num_samples` entries. + pub fn new(num_samples: usize, config: SyntheticConfig) -> Self { + SyntheticCsiDataset { num_samples, config } + } + + /// Compute the deterministic amplitude value for the given indices. + #[inline] + fn amp_value(&self, idx: usize, t: usize, _tx: usize, _rx: usize, k: usize) -> f32 { + let phase = 2.0 * std::f32::consts::PI + * (idx as f32 * 0.01 + t as f32 * 0.1 + k as f32 * 0.05); + 0.5 + 0.3 * phase.sin() + } + + /// Compute the deterministic phase value for the given indices. 
+ #[inline] + fn phase_value(&self, _idx: usize, _t: usize, tx: usize, rx: usize, k: usize) -> f32 { + let n_sc = self.config.num_subcarriers as f32; + (2.0 * std::f32::consts::PI * k as f32 / n_sc) + * (tx as f32 + 1.0) + * (rx as f32 + 1.0) + } + + /// Compute the deterministic keypoint (x, y) for joint `j` at sample `idx`. + #[inline] + fn keypoint_xy(&self, idx: usize, j: usize) -> (f32, f32) { + let x = 0.5 + + 0.1 * (2.0 * std::f32::consts::PI * idx as f32 * 0.007 + j as f32).sin(); + let y = 0.3 + j as f32 * 0.04; + (x, y) + } +} + +impl CsiDataset for SyntheticCsiDataset { + fn len(&self) -> usize { + self.num_samples + } + + fn get(&self, idx: usize) -> Result { + if idx >= self.num_samples { + return Err(DatasetError::IndexOutOfBounds { + idx, + len: self.num_samples, + }); + } + + let cfg = &self.config; + let (t, n_tx, n_rx, n_sc) = + (cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers); + + let amplitude = Array4::from_shape_fn((t, n_tx, n_rx, n_sc), |(frame, tx, rx, k)| { + self.amp_value(idx, frame, tx, rx, k) + }); + + let phase = Array4::from_shape_fn((t, n_tx, n_rx, n_sc), |(frame, tx, rx, k)| { + self.phase_value(idx, frame, tx, rx, k) + }); + + let mut keypoints = Array2::zeros((cfg.num_keypoints, 2)); + let mut keypoint_visibility = Array1::zeros(cfg.num_keypoints); + for j in 0..cfg.num_keypoints { + let (x, y) = self.keypoint_xy(idx, j); + // Clamp to [0, 1] to keep coordinates valid. + keypoints[[j, 0]] = x.clamp(0.0, 1.0); + keypoints[[j, 1]] = y.clamp(0.0, 1.0); + // All joints are visible in the synthetic model. 
+ keypoint_visibility[j] = 2.0; + } + + Ok(CsiSample { + amplitude, + phase, + keypoints, + keypoint_visibility, + subject_id: 0, + action_id: 0, + frame_id: idx as u64, + }) + } + + fn name(&self) -> &str { + "SyntheticCsiDataset" + } +} + +// --------------------------------------------------------------------------- +// DatasetError +// --------------------------------------------------------------------------- + +/// Errors produced by dataset operations. +#[derive(Debug, Error)] +pub enum DatasetError { + /// Requested index is outside the valid range. + #[error("Index {idx} out of bounds (dataset has {len} samples)")] + IndexOutOfBounds { idx: usize, len: usize }, + + /// An underlying file-system error occurred. + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + /// The file exists but does not match the expected format. + #[error("File format error: {0}")] + Format(String), + + /// The loaded array has a different subcarrier count than required. + #[error("Subcarrier count mismatch: expected {expected}, got {actual}")] + SubcarrierMismatch { expected: usize, actual: usize }, + + /// The specified root directory does not exist. 
+ #[error("Directory not found: {path}")] + DirectoryNotFound { path: String }, +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + // ----- SyntheticCsiDataset -------------------------------------------- + + #[test] + fn synthetic_sample_shapes() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(10, cfg.clone()); + let s = ds.get(0).unwrap(); + + assert_eq!(s.amplitude.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]); + assert_eq!(s.phase.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]); + assert_eq!(s.keypoints.shape(), &[cfg.num_keypoints, 2]); + assert_eq!(s.keypoint_visibility.shape(), &[cfg.num_keypoints]); + } + + #[test] + fn synthetic_is_deterministic() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(10, cfg); + let s0a = ds.get(3).unwrap(); + let s0b = ds.get(3).unwrap(); + assert_abs_diff_eq!(s0a.amplitude[[0, 0, 0, 0]], s0b.amplitude[[0, 0, 0, 0]], epsilon = 1e-7); + assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7); + } + + #[test] + fn synthetic_different_indices_differ() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(10, cfg); + let s0 = ds.get(0).unwrap(); + let s1 = ds.get(1).unwrap(); + // The sinusoidal model ensures different idx gives different values. 
+ assert!((s0.amplitude[[0, 0, 0, 0]] - s1.amplitude[[0, 0, 0, 0]]).abs() > 1e-6); + } + + #[test] + fn synthetic_out_of_bounds() { + let ds = SyntheticCsiDataset::new(5, SyntheticConfig::default()); + assert!(matches!(ds.get(5), Err(DatasetError::IndexOutOfBounds { idx: 5, len: 5 }))); + } + + #[test] + fn synthetic_amplitude_in_valid_range() { + // Model: 0.5 ± 0.3, so all values in [0.2, 0.8] + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(4, cfg); + for idx in 0..4 { + let s = ds.get(idx).unwrap(); + for &v in s.amplitude.iter() { + assert!(v >= 0.19 && v <= 0.81, "amplitude {v} out of [0.2, 0.8]"); + } + } + } + + #[test] + fn synthetic_keypoints_in_unit_square() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(8, cfg); + for idx in 0..8 { + let s = ds.get(idx).unwrap(); + for kp in s.keypoints.outer_iter() { + assert!(kp[0] >= 0.0 && kp[0] <= 1.0, "x={} out of [0,1]", kp[0]); + assert!(kp[1] >= 0.0 && kp[1] <= 1.0, "y={} out of [0,1]", kp[1]); + } + } + } + + #[test] + fn synthetic_all_joints_visible() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(3, cfg.clone()); + let s = ds.get(0).unwrap(); + assert!(s.keypoint_visibility.iter().all(|&v| (v - 2.0).abs() < 1e-6)); + } + + // ----- DataLoader ------------------------------------------------------- + + #[test] + fn dataloader_num_batches() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(10, cfg); + // 10 samples, batch_size=3 → ceil(10/3) = 4 + let dl = DataLoader::new(&ds, 3, false, 42); + assert_eq!(dl.num_batches(), 4); + } + + #[test] + fn dataloader_iterates_all_samples_no_shuffle() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(10, cfg); + let dl = DataLoader::new(&ds, 3, false, 42); + let total: usize = dl.iter().map(|b| b.len()).sum(); + assert_eq!(total, 10); + } + + #[test] + fn dataloader_iterates_all_samples_shuffle() { + let cfg = 
SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(17, cfg); + let dl = DataLoader::new(&ds, 4, true, 42); + let total: usize = dl.iter().map(|b| b.len()).sum(); + assert_eq!(total, 17); + } + + #[test] + fn dataloader_shuffle_is_deterministic() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(20, cfg); + let dl1 = DataLoader::new(&ds, 5, true, 99); + let dl2 = DataLoader::new(&ds, 5, true, 99); + let ids1: Vec = dl1.iter().flatten().map(|s| s.frame_id).collect(); + let ids2: Vec = dl2.iter().flatten().map(|s| s.frame_id).collect(); + assert_eq!(ids1, ids2); + } + + #[test] + fn dataloader_different_seeds_differ() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(20, cfg); + let dl1 = DataLoader::new(&ds, 20, true, 1); + let dl2 = DataLoader::new(&ds, 20, true, 2); + let ids1: Vec = dl1.iter().flatten().map(|s| s.frame_id).collect(); + let ids2: Vec = dl2.iter().flatten().map(|s| s.frame_id).collect(); + assert_ne!(ids1, ids2, "different seeds should produce different orders"); + } + + #[test] + fn dataloader_empty_dataset() { + let cfg = SyntheticConfig::default(); + let ds = SyntheticCsiDataset::new(0, cfg); + let dl = DataLoader::new(&ds, 4, false, 42); + assert_eq!(dl.num_batches(), 0); + assert_eq!(dl.iter().count(), 0); + } + + // ----- Helpers ---------------------------------------------------------- + + #[test] + fn parse_id_suffix_works() { + assert_eq!(parse_id_suffix("S01"), Some(1)); + assert_eq!(parse_id_suffix("A12"), Some(12)); + assert_eq!(parse_id_suffix("foo"), None); + assert_eq!(parse_id_suffix("S"), None); + } + + #[test] + fn xorshift_shuffle_is_permutation() { + let mut indices: Vec = (0..20).collect(); + xorshift_shuffle(&mut indices, 42); + let mut sorted = indices.clone(); + sorted.sort_unstable(); + assert_eq!(sorted, (0..20).collect::>()); + } + + #[test] + fn xorshift_shuffle_is_deterministic() { + let mut a: Vec = (0..20).collect(); + let mut b: Vec = 
(0..20).collect(); + xorshift_shuffle(&mut a, 123); + xorshift_shuffle(&mut b, 123); + assert_eq!(a, b); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs new file mode 100644 index 0000000..1fbb230 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs @@ -0,0 +1,384 @@ +//! Error types for the WiFi-DensePose training pipeline. +//! +//! This module defines a hierarchy of errors covering every failure mode in +//! the training pipeline: configuration validation, dataset I/O, subcarrier +//! interpolation, and top-level training orchestration. + +use thiserror::Error; +use std::path::PathBuf; + +// --------------------------------------------------------------------------- +// Top-level training error +// --------------------------------------------------------------------------- + +/// A convenient `Result` alias used throughout the training crate. +pub type TrainResult = Result; + +/// Top-level error type for the training pipeline. +/// +/// Every public function in this crate that can fail returns +/// `TrainResult`, which is `Result`. +#[derive(Debug, Error)] +pub enum TrainError { + /// Configuration is invalid or internally inconsistent. + #[error("Configuration error: {0}")] + Config(#[from] ConfigError), + + /// A dataset operation failed (I/O, format, missing data). + #[error("Dataset error: {0}")] + Dataset(#[from] DatasetError), + + /// Subcarrier interpolation / resampling failed. + #[error("Subcarrier interpolation error: {0}")] + Subcarrier(#[from] SubcarrierError), + + /// An underlying I/O error not covered by a more specific variant. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// JSON (de)serialization error. + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + /// TOML (de)serialization error. 
+ #[error("TOML deserialization error: {0}")] + TomlDe(#[from] toml::de::Error), + + /// TOML serialization error. + #[error("TOML serialization error: {0}")] + TomlSer(#[from] toml::ser::Error), + + /// An operation was attempted on an empty dataset. + #[error("Dataset is empty")] + EmptyDataset, + + /// Index out of bounds when accessing dataset items. + #[error("Index {index} is out of bounds for dataset of length {len}")] + IndexOutOfBounds { + /// The requested index. + index: usize, + /// The total number of items. + len: usize, + }, + + /// A numeric shape/dimension mismatch was detected. + #[error("Shape mismatch: expected {expected:?}, got {actual:?}")] + ShapeMismatch { + /// Expected shape. + expected: Vec, + /// Actual shape. + actual: Vec, + }, + + /// A training step failed for a reason not covered above. + #[error("Training step failed: {0}")] + TrainingStep(String), + + /// Checkpoint could not be saved or loaded. + #[error("Checkpoint error: {message} (path: {path:?})")] + Checkpoint { + /// Human-readable description. + message: String, + /// Path that was being accessed. + path: PathBuf, + }, + + /// Feature not yet implemented. + #[error("Not implemented: {0}")] + NotImplemented(String), +} + +impl TrainError { + /// Create a [`TrainError::TrainingStep`] with the given message. + pub fn training_step>(msg: S) -> Self { + TrainError::TrainingStep(msg.into()) + } + + /// Create a [`TrainError::Checkpoint`] error. + pub fn checkpoint>(msg: S, path: impl Into) -> Self { + TrainError::Checkpoint { + message: msg.into(), + path: path.into(), + } + } + + /// Create a [`TrainError::NotImplemented`] error. + pub fn not_implemented>(msg: S) -> Self { + TrainError::NotImplemented(msg.into()) + } + + /// Create a [`TrainError::ShapeMismatch`] error. 
+ pub fn shape_mismatch(expected: Vec, actual: Vec) -> Self { + TrainError::ShapeMismatch { expected, actual } + } +} + +// --------------------------------------------------------------------------- +// Configuration errors +// --------------------------------------------------------------------------- + +/// Errors produced when validating or loading a [`TrainingConfig`]. +/// +/// [`TrainingConfig`]: crate::config::TrainingConfig +#[derive(Debug, Error)] +pub enum ConfigError { + /// A required field has a value that violates a constraint. + #[error("Invalid value for field `{field}`: {reason}")] + InvalidValue { + /// Name of the configuration field. + field: &'static str, + /// Human-readable reason the value is invalid. + reason: String, + }, + + /// The configuration file could not be read. + #[error("Cannot read configuration file `{path}`: {source}")] + FileRead { + /// Path that was being read. + path: PathBuf, + /// Underlying I/O error. + #[source] + source: std::io::Error, + }, + + /// The configuration file contains invalid TOML. + #[error("Cannot parse configuration file `{path}`: {source}")] + ParseError { + /// Path that was being parsed. + path: PathBuf, + /// Underlying TOML parse error. + #[source] + source: toml::de::Error, + }, + + /// A path specified in the config does not exist. + #[error("Path `{path}` specified in config does not exist")] + PathNotFound { + /// The missing path. + path: PathBuf, + }, +} + +impl ConfigError { + /// Construct an [`ConfigError::InvalidValue`] error. + pub fn invalid_value>(field: &'static str, reason: S) -> Self { + ConfigError::InvalidValue { + field, + reason: reason.into(), + } + } +} + +// --------------------------------------------------------------------------- +// Dataset errors +// --------------------------------------------------------------------------- + +/// Errors produced while loading or accessing dataset samples. 
+#[derive(Debug, Error)] +pub enum DatasetError { + /// The requested data file or directory was not found. + /// + /// Production training data is mandatory; this error is never silently + /// suppressed. Use [`SyntheticDataset`] only for proof/testing. + /// + /// [`SyntheticDataset`]: crate::dataset::SyntheticDataset + #[error("Data not found at `{path}`: {message}")] + DataNotFound { + /// Path that was expected to contain data. + path: PathBuf, + /// Additional context. + message: String, + }, + + /// A file was found but its format is incorrect or unexpected. + /// + /// This covers malformed numpy arrays, unexpected shapes, bad JSON + /// metadata, etc. + #[error("Invalid data format in `{path}`: {message}")] + InvalidFormat { + /// Path of the malformed file. + path: PathBuf, + /// Description of the format problem. + message: String, + }, + + /// A low-level I/O error while reading a data file. + #[error("I/O error reading `{path}`: {source}")] + IoError { + /// Path being read when the error occurred. + path: PathBuf, + /// Underlying I/O error. + #[source] + source: std::io::Error, + }, + + /// The number of subcarriers in the data file does not match the + /// configuration expectation (before or after interpolation). + #[error( + "Subcarrier count mismatch in `{path}`: \ + file has {found} subcarriers, expected {expected}" + )] + SubcarrierMismatch { + /// Path of the offending file. + path: PathBuf, + /// Number of subcarriers found in the file. + found: usize, + /// Number of subcarriers expected by the configuration. + expected: usize, + }, + + /// A sample index was out of bounds. + #[error("Index {index} is out of bounds for dataset of length {len}")] + IndexOutOfBounds { + /// The requested index. + index: usize, + /// Total number of samples. + len: usize, + }, + + /// A numpy array could not be read. + #[error("NumPy array read error in `{path}`: {message}")] + NpyReadError { + /// Path of the `.npy` file. 
+ path: PathBuf, + /// Error description. + message: String, + }, + + /// A metadata file (e.g., `meta.json`) is missing or malformed. + #[error("Metadata error for subject {subject_id}: {message}")] + MetadataError { + /// Subject whose metadata could not be read. + subject_id: u32, + /// Description of the problem. + message: String, + }, + + /// No subjects matching the requested IDs were found in the data directory. + #[error( + "No subjects found in `{data_dir}` matching the requested IDs: {requested:?}" + )] + NoSubjectsFound { + /// Root data directory that was scanned. + data_dir: PathBuf, + /// Subject IDs that were requested. + requested: Vec, + }, + + /// A subcarrier interpolation error occurred during sample loading. + #[error("Subcarrier interpolation failed while loading sample {sample_idx}: {source}")] + InterpolationError { + /// The sample index being loaded. + sample_idx: usize, + /// Underlying interpolation error. + #[source] + source: SubcarrierError, + }, +} + +impl DatasetError { + /// Construct a [`DatasetError::DataNotFound`] error. + pub fn not_found>(path: impl Into, msg: S) -> Self { + DatasetError::DataNotFound { + path: path.into(), + message: msg.into(), + } + } + + /// Construct a [`DatasetError::InvalidFormat`] error. + pub fn invalid_format>(path: impl Into, msg: S) -> Self { + DatasetError::InvalidFormat { + path: path.into(), + message: msg.into(), + } + } + + /// Construct a [`DatasetError::IoError`] error. + pub fn io_error(path: impl Into, source: std::io::Error) -> Self { + DatasetError::IoError { + path: path.into(), + source, + } + } + + /// Construct a [`DatasetError::SubcarrierMismatch`] error. + pub fn subcarrier_mismatch(path: impl Into, found: usize, expected: usize) -> Self { + DatasetError::SubcarrierMismatch { + path: path.into(), + found, + expected, + } + } + + /// Construct a [`DatasetError::NpyReadError`] error. 
+ pub fn npy_read>(path: impl Into, msg: S) -> Self { + DatasetError::NpyReadError { + path: path.into(), + message: msg.into(), + } + } +} + +// --------------------------------------------------------------------------- +// Subcarrier interpolation errors +// --------------------------------------------------------------------------- + +/// Errors produced by the subcarrier resampling functions. +#[derive(Debug, Error)] +pub enum SubcarrierError { + /// The source or destination subcarrier count is zero. + #[error("Subcarrier count must be at least 1, got {count}")] + ZeroCount { + /// The offending count. + count: usize, + }, + + /// The input array has an unexpected shape. + #[error( + "Input array shape mismatch: expected last dimension {expected_sc}, \ + got {actual_sc} (full shape: {shape:?})" + )] + InputShapeMismatch { + /// Expected number of subcarriers (last dimension). + expected_sc: usize, + /// Actual number of subcarriers found. + actual_sc: usize, + /// Full shape of the input array. + shape: Vec, + }, + + /// The requested interpolation method is not implemented. + #[error("Interpolation method `{method}` is not yet implemented")] + MethodNotImplemented { + /// Name of the unimplemented method. + method: String, + }, + + /// Source and destination subcarrier counts are already equal. + /// + /// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before + /// calling the interpolation routine to avoid this error. + /// + /// [`TrainingConfig::needs_subcarrier_interp`]: + /// crate::config::TrainingConfig::needs_subcarrier_interp + #[error( + "Source and destination subcarrier counts are equal ({count}); \ + no interpolation is needed" + )] + NopInterpolation { + /// The equal count. + count: usize, + }, + + /// A numerical error occurred during interpolation (e.g., division by zero + /// due to coincident knot positions). 
+ #[error("Numerical error during interpolation: {0}")] + NumericalError(String), +} + +impl SubcarrierError { + /// Construct a [`SubcarrierError::NumericalError`]. + pub fn numerical>(msg: S) -> Self { + SubcarrierError::NumericalError(msg.into()) + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs new file mode 100644 index 0000000..d1b915c --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs @@ -0,0 +1,61 @@ +//! # WiFi-DensePose Training Infrastructure +//! +//! This crate provides the complete training pipeline for the WiFi-DensePose pose +//! estimation model. It includes configuration management, dataset loading with +//! subcarrier interpolation, loss functions, evaluation metrics, and the training +//! loop orchestrator. +//! +//! ## Architecture +//! +//! ```text +//! TrainingConfig ──► Trainer ──► Model +//! │ │ +//! │ DataLoader +//! │ │ +//! │ CsiDataset (MmFiDataset | SyntheticCsiDataset) +//! │ │ +//! │ subcarrier::interpolate_subcarriers +//! │ +//! └──► losses / metrics +//! ``` +//! +//! ## Quick Start +//! +//! ```rust,no_run +//! use wifi_densepose_train::config::TrainingConfig; +//! use wifi_densepose_train::dataset::{SyntheticCsiDataset, SyntheticConfig, CsiDataset}; +//! +//! // Build config +//! let config = TrainingConfig::default(); +//! config.validate().expect("config is valid"); +//! +//! // Create a synthetic dataset (deterministic, fixed-seed) +//! let syn_cfg = SyntheticConfig::default(); +//! let dataset = SyntheticCsiDataset::new(200, syn_cfg); +//! +//! // Load one sample +//! let sample = dataset.get(0).unwrap(); +//! println!("amplitude shape: {:?}", sample.amplitude.shape()); +//! 
``` + +#![forbid(unsafe_code)] +#![warn(missing_docs)] + +pub mod config; +pub mod dataset; +pub mod error; +pub mod losses; +pub mod metrics; +pub mod model; +pub mod proof; +pub mod subcarrier; +pub mod trainer; + +// Convenient re-exports at the crate root. +pub use config::TrainingConfig; +pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; +pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult}; +pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance}; + +/// Crate version string. +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs new file mode 100644 index 0000000..a8e8f28 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs @@ -0,0 +1,909 @@ +//! Loss functions for WiFi-DensePose training. +//! +//! This module implements the combined loss function used during training: +//! +//! - **Keypoint heatmap loss**: MSE between predicted and target Gaussian heatmaps, +//! masked by keypoint visibility so unlabelled joints don't contribute. +//! - **DensePose loss**: Cross-entropy on body-part logits (25 classes including +//! background) plus Smooth-L1 (Huber) UV regression for each foreground part. +//! - **Transfer / distillation loss**: MSE between student backbone features and +//! teacher features, enabling cross-modal knowledge transfer from an RGB teacher. +//! +//! The three scalar losses are combined with configurable weights: +//! +//! ```text +//! L_total = λ_kp · L_keypoint + λ_dp · L_densepose + λ_tr · L_transfer +//! ``` +//! +//! # No mock data +//! Every computation in this module is grounded in real signal mathematics. +//! No synthetic or random tensors are generated at runtime. 
+ +use std::collections::HashMap; +use tch::{Kind, Reduction, Tensor}; + +// ───────────────────────────────────────────────────────────────────────────── +// Public types +// ───────────────────────────────────────────────────────────────────────────── + +/// Scalar components produced by a single forward pass through the combined loss. +#[derive(Debug, Clone)] +pub struct LossOutput { + /// Total weighted loss value (scalar, in ℝ≥0). + pub total: f32, + /// Keypoint heatmap MSE loss component. + pub keypoint: f32, + /// DensePose (part + UV) loss component, `None` when no DensePose targets are given. + pub densepose: Option, + /// Transfer/distillation loss component, `None` when no teacher features are given. + pub transfer: Option, + /// Fine-grained breakdown (e.g. `"dp_part"`, `"dp_uv"`, `"kp_masked"`, …). + pub details: HashMap, +} + +/// Per-loss scalar weights used to combine the individual losses. +#[derive(Debug, Clone)] +pub struct LossWeights { + /// Weight for the keypoint heatmap loss (λ_kp). + pub lambda_kp: f64, + /// Weight for the DensePose loss (λ_dp). + pub lambda_dp: f64, + /// Weight for the transfer/distillation loss (λ_tr). + pub lambda_tr: f64, +} + +impl Default for LossWeights { + fn default() -> Self { + Self { + lambda_kp: 0.3, + lambda_dp: 0.6, + lambda_tr: 0.1, + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// WiFiDensePoseLoss +// ───────────────────────────────────────────────────────────────────────────── + +/// Combined loss function for WiFi-DensePose training. +/// +/// Wraps three component losses: +/// 1. Keypoint heatmap MSE (visibility-masked) +/// 2. DensePose: part cross-entropy + UV Smooth-L1 +/// 3. Teacher-student feature transfer MSE +pub struct WiFiDensePoseLoss { + weights: LossWeights, +} + +impl WiFiDensePoseLoss { + /// Create a new loss function with the given component weights. 
+ pub fn new(weights: LossWeights) -> Self { + Self { weights } + } + + // ── Component losses ───────────────────────────────────────────────────── + + /// Compute the keypoint heatmap loss. + /// + /// For each keypoint joint `j` and batch element `b`, the pixel-wise MSE + /// between `pred_heatmaps[b, j, :, :]` and `target_heatmaps[b, j, :, :]` + /// is computed and multiplied by the binary visibility mask `visibility[b, j]`. + /// The sum is then divided by the number of visible joints to produce a + /// normalised scalar. + /// + /// If no keypoints are visible in the batch the function returns zero. + /// + /// # Shapes + /// - `pred_heatmaps`: `[B, 17, H, W]` – predicted heatmaps + /// - `target_heatmaps`: `[B, 17, H, W]` – ground-truth Gaussian heatmaps + /// - `visibility`: `[B, 17]` – 1.0 if the keypoint is labelled, 0.0 otherwise + pub fn keypoint_loss( + &self, + pred_heatmaps: &Tensor, + target_heatmaps: &Tensor, + visibility: &Tensor, + ) -> Tensor { + // Pixel-wise squared error, mean-reduced over H and W: [B, 17] + let sq_err = (pred_heatmaps - target_heatmaps).pow_tensor_scalar(2); + // Mean over H and W (dims 2, 3 → we flatten them first for clarity) + let per_joint_mse = sq_err.mean_dim(&[2_i64, 3_i64][..], false, Kind::Float); + + // Mask by visibility: [B, 17] + let masked = per_joint_mse * visibility; + + // Normalise by number of visible joints in the batch. + let n_visible = visibility.sum(Kind::Float); + // Guard against division by zero (entire batch may have no labels). + let safe_n = n_visible.clamp(1.0, f64::MAX); + + masked.sum(Kind::Float) / safe_n + } + + /// Compute the DensePose loss. + /// + /// Two sub-losses are combined: + /// 1. **Part cross-entropy** – softmax cross-entropy between `pred_parts` + /// logits `[B, 25, H, W]` and `target_parts` integer class indices + /// `[B, H, W]`. Class 0 is background and is included. + /// 2. 
**UV Smooth-L1 (Huber)** – for pixels that belong to a foreground + /// part (target class ≥ 1), the UV prediction error is penalised with + /// Smooth-L1 loss. Background pixels are masked out so the model is + /// not penalised for UV predictions at background locations. + /// + /// The two sub-losses are summed with equal weight. + /// + /// # Shapes + /// - `pred_parts`: `[B, 25, H, W]` – logits (24 body parts + background) + /// - `target_parts`: `[B, H, W]` – integer class indices in [0, 24] + /// - `pred_uv`: `[B, 48, H, W]` – 24 pairs of (U, V) predictions, interleaved + /// - `target_uv`: `[B, 48, H, W]` – ground-truth UV coordinates for each part + pub fn densepose_loss( + &self, + pred_parts: &Tensor, + target_parts: &Tensor, + pred_uv: &Tensor, + target_uv: &Tensor, + ) -> Tensor { + // ── 1. Part classification: cross-entropy ────────────────────────── + // tch cross_entropy_loss expects (input: [B,C,…], target: [B,…] of i64). + let target_int = target_parts.to_kind(Kind::Int64); + // weight=None, reduction=Mean, ignore_index=-100, label_smoothing=0.0 + let part_loss = pred_parts.cross_entropy_loss::( + &target_int, + None, + Reduction::Mean, + -100, + 0.0, + ); + + // ── 2. UV regression: Smooth-L1 masked by foreground pixels ──────── + // Foreground mask: pixels where target part ≠ 0, shape [B, H, W]. + let fg_mask = target_int.not_equal(0); + // Expand to [B, 1, H, W] then broadcast to [B, 48, H, W]. + let fg_mask_f = fg_mask + .unsqueeze(1) + .expand_as(pred_uv) + .to_kind(Kind::Float); + + let masked_pred_uv = pred_uv * &fg_mask_f; + let masked_target_uv = target_uv * &fg_mask_f; + + // Count foreground pixels × 48 channels to normalise. + let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX); + + // Smooth-L1 with beta=1.0, reduction=Sum then divide by fg count. 
+ let uv_loss_sum = + masked_pred_uv.smooth_l1_loss(&masked_target_uv, Reduction::Sum, 1.0); + let uv_loss = uv_loss_sum / n_fg; + + part_loss + uv_loss + } + + /// Compute the teacher-student feature transfer (distillation) loss. + /// + /// The loss is a plain MSE between the student backbone feature map and the + /// teacher's corresponding feature map. Both tensors must have the same + /// shape `[B, C, H, W]`. + /// + /// This implements the cross-modal knowledge distillation component of the + /// WiFi-DensePose paper where an RGB teacher supervises the CSI student. + pub fn transfer_loss(&self, student_features: &Tensor, teacher_features: &Tensor) -> Tensor { + student_features.mse_loss(teacher_features, Reduction::Mean) + } + + // ── Combined forward ───────────────────────────────────────────────────── + + /// Compute and combine all loss components. + /// + /// Returns `(total_loss_tensor, LossOutput)` where `total_loss_tensor` is + /// the differentiable scalar for back-propagation and `LossOutput` contains + /// detached `f32` values for logging. 
+ /// + /// # Arguments + /// - `pred_keypoints`, `target_keypoints`: `[B, 17, H, W]` + /// - `visibility`: `[B, 17]` + /// - `pred_parts`, `target_parts`: `[B, 25, H, W]` / `[B, H, W]` (optional) + /// - `pred_uv`, `target_uv`: `[B, 48, H, W]` (optional, paired with parts) + /// - `student_features`, `teacher_features`: `[B, C, H, W]` (optional) + #[allow(clippy::too_many_arguments)] + pub fn forward( + &self, + pred_keypoints: &Tensor, + target_keypoints: &Tensor, + visibility: &Tensor, + pred_parts: Option<&Tensor>, + target_parts: Option<&Tensor>, + pred_uv: Option<&Tensor>, + target_uv: Option<&Tensor>, + student_features: Option<&Tensor>, + teacher_features: Option<&Tensor>, + ) -> (Tensor, LossOutput) { + let mut details = HashMap::new(); + + // ── Keypoint loss (always computed) ─────────────────────────────── + let kp_loss = self.keypoint_loss(pred_keypoints, target_keypoints, visibility); + let kp_val: f64 = kp_loss.double_value(&[]); + details.insert("kp_mse".to_string(), kp_val as f32); + + let total = kp_loss.shallow_clone() * self.weights.lambda_kp; + + // ── DensePose loss (optional) ───────────────────────────────────── + let (dp_val, total) = match (pred_parts, target_parts, pred_uv, target_uv) { + (Some(pp), Some(tp), Some(pu), Some(tu)) => { + // Part cross-entropy + let target_int = tp.to_kind(Kind::Int64); + let part_loss = pp.cross_entropy_loss::( + &target_int, + None, + Reduction::Mean, + -100, + 0.0, + ); + let part_val = part_loss.double_value(&[]) as f32; + + // UV loss (foreground masked) + let fg_mask = target_int.not_equal(0); + let fg_mask_f = fg_mask + .unsqueeze(1) + .expand_as(pu) + .to_kind(Kind::Float); + let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX); + let uv_loss = (pu * &fg_mask_f) + .smooth_l1_loss(&(tu * &fg_mask_f), Reduction::Sum, 1.0) + / n_fg; + let uv_val = uv_loss.double_value(&[]) as f32; + + let dp_loss = &part_loss + &uv_loss; + let dp_scalar = dp_loss.double_value(&[]) as f32; + + 
details.insert("dp_part_ce".to_string(), part_val); + details.insert("dp_uv_smooth_l1".to_string(), uv_val); + + let new_total = total + dp_loss * self.weights.lambda_dp; + (Some(dp_scalar), new_total) + } + _ => (None, total), + }; + + // ── Transfer loss (optional) ────────────────────────────────────── + let (tr_val, total) = match (student_features, teacher_features) { + (Some(sf), Some(tf)) => { + let tr_loss = self.transfer_loss(sf, tf); + let tr_scalar = tr_loss.double_value(&[]) as f32; + details.insert("transfer_mse".to_string(), tr_scalar); + let new_total = total + tr_loss * self.weights.lambda_tr; + (Some(tr_scalar), new_total) + } + _ => (None, total), + }; + + let total_val = total.double_value(&[]) as f32; + + let output = LossOutput { + total: total_val, + keypoint: kp_val as f32, + densepose: dp_val, + transfer: tr_val, + details, + }; + + (total, output) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Gaussian heatmap utilities +// ───────────────────────────────────────────────────────────────────────────── + +/// Generate a 2-D Gaussian heatmap for a single keypoint. +/// +/// The heatmap is a `heatmap_size × heatmap_size` array where the value at +/// pixel `(r, c)` is: +/// +/// ```text +/// H[r, c] = exp( -((c - kp_x * S)² + (r - kp_y * S)²) / (2 · σ²) ) +/// ``` +/// +/// where `S = heatmap_size - 1` maps normalised coordinates to pixel space. +/// +/// Values outside the 3σ radius are clamped to zero to produce a sparse +/// representation that is numerically identical to the training targets used +/// in the original DensePose paper. +/// +/// # Arguments +/// - `kp_x`, `kp_y`: normalised keypoint position in [0, 1] +/// - `heatmap_size`: spatial resolution of the heatmap (H = W) +/// - `sigma`: Gaussian spread in pixels (default 2.0 gives a tight, localised peak) +/// +/// # Returns +/// A `heatmap_size × heatmap_size` array with values in [0, 1]. 
+pub fn generate_gaussian_heatmap( + kp_x: f32, + kp_y: f32, + heatmap_size: usize, + sigma: f32, +) -> ndarray::Array2 { + let s = (heatmap_size - 1) as f32; + let cx = kp_x * s; + let cy = kp_y * s; + let two_sigma_sq = 2.0 * sigma * sigma; + let clip_radius_sq = (3.0 * sigma).powi(2); + + let mut map = ndarray::Array2::zeros((heatmap_size, heatmap_size)); + for r in 0..heatmap_size { + for c in 0..heatmap_size { + let dx = c as f32 - cx; + let dy = r as f32 - cy; + let dist_sq = dx * dx + dy * dy; + if dist_sq <= clip_radius_sq { + map[[r, c]] = (-dist_sq / two_sigma_sq).exp(); + } + } + } + map +} + +/// Generate a batch of target heatmaps from keypoint coordinates. +/// +/// For invisible keypoints (`visibility[b, j] == 0`) the corresponding +/// heatmap channel is left as all-zeros. +/// +/// # Arguments +/// - `keypoints`: `[B, 17, 2]` – (x, y) normalised to [0, 1] +/// - `visibility`: `[B, 17]` – 1.0 if visible, 0.0 if invisible +/// - `heatmap_size`: spatial resolution (H = W) +/// - `sigma`: Gaussian sigma in pixels +/// +/// # Returns +/// `[B, 17, heatmap_size, heatmap_size]` target heatmap array. 
+pub fn generate_target_heatmaps( + keypoints: &ndarray::Array3, + visibility: &ndarray::Array2, + heatmap_size: usize, + sigma: f32, +) -> ndarray::Array4 { + let batch = keypoints.shape()[0]; + let num_joints = keypoints.shape()[1]; + + let mut heatmaps = + ndarray::Array4::zeros((batch, num_joints, heatmap_size, heatmap_size)); + + for b in 0..batch { + for j in 0..num_joints { + if visibility[[b, j]] > 0.0 { + let kp_x = keypoints[[b, j, 0]]; + let kp_y = keypoints[[b, j, 1]]; + let hm = generate_gaussian_heatmap(kp_x, kp_y, heatmap_size, sigma); + for r in 0..heatmap_size { + for c in 0..heatmap_size { + heatmaps[[b, j, r, c]] = hm[[r, c]]; + } + } + } + } + } + heatmaps +} + +// ───────────────────────────────────────────────────────────────────────────── +// Standalone functional API (mirrors the spec signatures exactly) +// ───────────────────────────────────────────────────────────────────────────── + +/// Output of the combined loss computation (functional API). +#[derive(Debug, Clone)] +pub struct LossOutput { + /// Weighted total loss (for backward pass). + pub total: f64, + /// Keypoint heatmap MSE loss (unweighted). + pub keypoint: f64, + /// DensePose part classification loss (unweighted), `None` if not computed. + pub densepose_parts: Option, + /// DensePose UV regression loss (unweighted), `None` if not computed. + pub densepose_uv: Option, + /// Teacher-student transfer loss (unweighted), `None` if teacher features absent. + pub transfer: Option, +} + +/// Compute the total weighted loss given model predictions and targets. 
+/// +/// # Arguments +/// * `pred_kpt_heatmaps` - Predicted keypoint heatmaps: \[B, 17, H, W\] +/// * `gt_kpt_heatmaps` - Ground truth Gaussian heatmaps: \[B, 17, H, W\] +/// * `pred_part_logits` - Predicted DensePose part logits: \[B, 25, H, W\] +/// * `gt_part_labels` - GT part class indices: \[B, H, W\], value −1 = ignore +/// * `pred_uv` - Predicted UV coordinates: \[B, 48, H, W\] +/// * `gt_uv` - Ground truth UV: \[B, 48, H, W\] +/// * `student_features` - Student backbone features: \[B, C, H', W'\] +/// * `teacher_features` - Teacher backbone features: \[B, C, H', W'\] +/// * `lambda_kp` - Weight for keypoint loss +/// * `lambda_dp` - Weight for DensePose loss +/// * `lambda_tr` - Weight for transfer loss +#[allow(clippy::too_many_arguments)] +pub fn compute_losses( + pred_kpt_heatmaps: &Tensor, + gt_kpt_heatmaps: &Tensor, + pred_part_logits: Option<&Tensor>, + gt_part_labels: Option<&Tensor>, + pred_uv: Option<&Tensor>, + gt_uv: Option<&Tensor>, + student_features: Option<&Tensor>, + teacher_features: Option<&Tensor>, + lambda_kp: f64, + lambda_dp: f64, + lambda_tr: f64, +) -> LossOutput { + // ── Keypoint heatmap loss — always computed ──────────────────────────── + let kpt_tensor = keypoint_heatmap_loss(pred_kpt_heatmaps, gt_kpt_heatmaps); + let keypoint: f64 = kpt_tensor.double_value(&[]); + + // ── DensePose part classification loss ──────────────────────────────── + let (densepose_parts, dp_part_tensor): (Option, Option) = + match (pred_part_logits, gt_part_labels) { + (Some(logits), Some(labels)) => { + let t = densepose_part_loss(logits, labels); + let v = t.double_value(&[]); + (Some(v), Some(t)) + } + _ => (None, None), + }; + + // ── DensePose UV regression loss ────────────────────────────────────── + let (densepose_uv, dp_uv_tensor): (Option, Option) = + match (pred_uv, gt_uv, gt_part_labels) { + (Some(puv), Some(guv), Some(labels)) => { + let t = densepose_uv_loss(puv, guv, labels); + let v = t.double_value(&[]); + (Some(v), Some(t)) + } + _ => 
(None, None), + }; + + // ── Teacher-student transfer loss ───────────────────────────────────── + let (transfer, tr_tensor): (Option, Option) = + match (student_features, teacher_features) { + (Some(sf), Some(tf)) => { + let t = fn_transfer_loss(sf, tf); + let v = t.double_value(&[]); + (Some(v), Some(t)) + } + _ => (None, None), + }; + + // ── Weighted sum ────────────────────────────────────────────────────── + let mut total_t = kpt_tensor * lambda_kp; + + // Combine densepose part + UV under a single lambda_dp weight. + let zero_scalar = Tensor::zeros(&[], (Kind::Float, total_t.device())); + let dp_part_t = dp_part_tensor + .as_ref() + .map(|t| t.shallow_clone()) + .unwrap_or_else(|| zero_scalar.shallow_clone()); + let dp_uv_t = dp_uv_tensor + .as_ref() + .map(|t| t.shallow_clone()) + .unwrap_or_else(|| zero_scalar.shallow_clone()); + + if densepose_parts.is_some() || densepose_uv.is_some() { + total_t = total_t + (&dp_part_t + &dp_uv_t) * lambda_dp; + } + + if let Some(ref tr) = tr_tensor { + total_t = total_t + tr * lambda_tr; + } + + let total: f64 = total_t.double_value(&[]); + + LossOutput { + total, + keypoint, + densepose_parts, + densepose_uv, + transfer, + } +} + +/// Keypoint heatmap loss: MSE between predicted and Gaussian-smoothed GT heatmaps. +/// +/// Invisible keypoints must be zeroed in `target` before calling this function +/// (use [`generate_gaussian_heatmaps`] which handles that automatically). +/// +/// # Arguments +/// * `pred` - Predicted heatmaps \[B, 17, H, W\] +/// * `target` - Pre-computed GT Gaussian heatmaps \[B, 17, H, W\] +/// +/// Returns a scalar `Tensor`. +pub fn keypoint_heatmap_loss(pred: &Tensor, target: &Tensor) -> Tensor { + pred.mse_loss(target, Reduction::Mean) +} + +/// Generate Gaussian heatmaps from keypoint coordinates. +/// +/// For each keypoint `(x, y)` in \[0,1\] normalised space, places a 2D Gaussian +/// centred at the corresponding pixel location. Invisible keypoints produce +/// all-zero heatmap channels. 
+/// +/// # Arguments +/// * `keypoints` - \[B, 17, 2\] normalised (x, y) in \[0, 1\] +/// * `visibility` - \[B, 17\] 0 = invisible, 1 = visible +/// * `heatmap_size` - Output H = W (square heatmap) +/// * `sigma` - Gaussian sigma in pixels (default 2.0) +/// +/// Returns `[B, 17, H, W]`. +pub fn generate_gaussian_heatmaps( + keypoints: &Tensor, + visibility: &Tensor, + heatmap_size: usize, + sigma: f64, +) -> Tensor { + let device = keypoints.device(); + let kind = Kind::Float; + let size = heatmap_size as i64; + + let batch_size = keypoints.size()[0]; + let num_kpts = keypoints.size()[1]; + + // Build pixel-space coordinate grids — shape [1, 1, H, W] for broadcasting. + // `xs[w]` is the column index; `ys[h]` is the row index. + let xs = Tensor::arange(size, (kind, device)).view([1, 1, 1, size]); + let ys = Tensor::arange(size, (kind, device)).view([1, 1, size, 1]); + + // Convert normalised coords to pixel centres: pixel = coord * (size - 1). + // keypoints[:, :, 0] → x (column); keypoints[:, :, 1] → y (row). + let cx = keypoints + .select(2, 0) + .unsqueeze(-1) + .unsqueeze(-1) + .to_kind(kind) + * (size as f64 - 1.0); // [B, 17, 1, 1] + + let cy = keypoints + .select(2, 1) + .unsqueeze(-1) + .unsqueeze(-1) + .to_kind(kind) + * (size as f64 - 1.0); // [B, 17, 1, 1] + + // Gaussian: exp(−((x − cx)² + (y − cy)²) / (2σ²)), shape [B, 17, H, W]. + let two_sigma_sq = 2.0 * sigma * sigma; + let dx = &xs - &cx; + let dy = &ys - &cy; + let heatmaps = + (-(dx.pow_tensor_scalar(2.0) + dy.pow_tensor_scalar(2.0)) / two_sigma_sq).exp(); + + // Zero out invisible keypoints: visibility [B, 17] → [B, 17, 1, 1] boolean mask. + let vis_mask = visibility + .to_kind(kind) + .view([batch_size, num_kpts, 1, 1]) + .gt(0.0); + + let zero = Tensor::zeros(&[], (kind, device)); + heatmaps.where_self(&vis_mask, &zero) +} + +/// DensePose part classification loss: cross-entropy with `ignore_index = −1`. 
+/// +/// # Arguments +/// * `pred_logits` - \[B, 25, H, W\] (25 = 24 parts + background class 0) +/// * `gt_labels` - \[B, H, W\] integer labels; −1 = ignore (no annotation) +/// +/// Returns a scalar `Tensor`. +pub fn densepose_part_loss(pred_logits: &Tensor, gt_labels: &Tensor) -> Tensor { + let labels_i64 = gt_labels.to_kind(Kind::Int64); + pred_logits.cross_entropy_loss::( + &labels_i64, + None, // no per-class weights + Reduction::Mean, + -1, // ignore_index + 0.0, // label_smoothing + ) +} + +/// DensePose UV coordinate regression loss: Smooth L1 (Huber loss). +/// +/// Only pixels where `gt_labels >= 0` (annotated with a valid part) contribute +/// to the loss; unannotated (background) pixels are masked out. +/// +/// # Arguments +/// * `pred_uv` - \[B, 48, H, W\] predicted UV (24 parts × 2 channels) +/// * `gt_uv` - \[B, 48, H, W\] ground truth UV +/// * `gt_labels` - \[B, H, W\] part labels; mask = (labels ≥ 0) +/// +/// Returns a scalar `Tensor`. +pub fn densepose_uv_loss(pred_uv: &Tensor, gt_uv: &Tensor, gt_labels: &Tensor) -> Tensor { + // Boolean mask from annotated pixels: [B, 1, H, W]. + let mask = gt_labels.ge(0).unsqueeze(1); + // Expand to [B, 48, H, W]. + let mask_expanded = mask.expand_as(pred_uv); + + let pred_sel = pred_uv.masked_select(&mask_expanded); + let gt_sel = gt_uv.masked_select(&mask_expanded); + + if pred_sel.numel() == 0 { + // No annotated pixels — return a zero scalar, still attached to graph. + return Tensor::zeros(&[], (pred_uv.kind(), pred_uv.device())); + } + + pred_sel.smooth_l1_loss(>_sel, Reduction::Mean, 1.0) +} + +/// Teacher-student transfer loss: MSE between student and teacher feature maps. +/// +/// If spatial or channel dimensions differ, the student features are aligned +/// to the teacher's shape via adaptive average pooling (non-parametric, no +/// learnable projection weights). 
///
/// # Arguments
/// * `student_features` - \[B, Cs, Hs, Ws\]
/// * `teacher_features` - \[B, Ct, Ht, Wt\]
///
/// Returns a scalar `Tensor`.
///
/// This is a free function; the identical implementation is also available as
/// [`WiFiDensePoseLoss::transfer_loss`].
pub fn fn_transfer_loss(student_features: &Tensor, teacher_features: &Tensor) -> Tensor {
    let s_size = student_features.size();
    let t_size = teacher_features.size();

    // Align spatial dimensions if needed.
    let s_spatial = if s_size[2] != t_size[2] || s_size[3] != t_size[3] {
        student_features.adaptive_avg_pool2d([t_size[2], t_size[3]])
    } else {
        student_features.shallow_clone()
    };

    // Align channel dimensions if needed.
    let s_final = if s_size[1] != t_size[1] {
        let cs = s_spatial.size()[1];
        let ct = t_size[1];
        if cs % ct == 0 {
            // Fast path: reshape + mean pool over the ratio dimension.
            // NOTE(review): Kind::Float here fixes the output dtype — assumes
            // float32 inputs; confirm if double-precision features are used.
            let ratio = cs / ct;
            s_spatial
                .view([-1, ct, ratio, t_size[2], t_size[3]])
                .mean_dim(Some(&[2i64][..]), false, Kind::Float)
        } else {
            // Generic: treat channel as sequence length, 1-D adaptive pool.
            let b = s_spatial.size()[0];
            let h = t_size[2];
            let w = t_size[3];
            s_spatial
                .permute([0, 2, 3, 1]) // [B, H, W, Cs]
                .reshape([-1, 1, cs]) // [B·H·W, 1, Cs]
                .adaptive_avg_pool1d(ct) // [B·H·W, 1, Ct]
                .reshape([b, h, w, ct]) // [B, H, W, Ct]
                .permute([0, 3, 1, 2]) // [B, Ct, H, W]
        }
    } else {
        s_spatial
    };

    s_final.mse_loss(teacher_features, Reduction::Mean)
}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    // ── Gaussian heatmap ──────────────────────────────────────────────────────

    #[test]
    fn test_gaussian_heatmap_peak_location() {
        let kp_x = 0.5_f32;
        let kp_y = 0.5_f32;
        let size = 64_usize;
        let sigma = 2.0_f32;

        let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma);

        // Peak should be at the centre (row=31, col=31) for a 64-pixel map
        // with normalised coordinate 0.5 → pixel 31.5, rounded to 31 or 32.
        let s = (size - 1) as f32;
        let cx = (kp_x * s).round() as usize;
        let cy = (kp_y * s).round() as usize;

        let peak = hm[[cy, cx]];
        assert!(
            peak > 0.95,
            "Peak value {peak} should be close to 1.0 at centre"
        );

        // Values far from the centre should be ≈ 0.
        let far = hm[[0, 0]];
        assert!(
            far < 0.01,
            "Corner value {far} should be near zero"
        );
    }

    #[test]
    fn test_gaussian_heatmap_reasonable_sum() {
        let hm = generate_gaussian_heatmap(0.5, 0.5, 64, 2.0);
        let total: f32 = hm.iter().copied().sum();
        // The Gaussian sum over a 64×64 grid with σ=2 is bounded away from
        // both 0 and infinity. Analytically it is ≈ 2·π·σ² ≈ 25 for σ=2
        // (slightly less here because the kernel is truncated at 3σ).
+ assert!( + total > 5.0 && total < 200.0, + "Heatmap sum {total} out of expected range" + ); + } + + #[test] + fn test_generate_target_heatmaps_invisible_joints_are_zero() { + let batch = 2_usize; + let num_joints = 17_usize; + let size = 32_usize; + + let keypoints = ndarray::Array3::from_elem((batch, num_joints, 2), 0.5_f32); + // Make all joints in batch 0 invisible. + let mut visibility = ndarray::Array2::ones((batch, num_joints)); + for j in 0..num_joints { + visibility[[0, j]] = 0.0; + } + + let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0); + + // Every pixel of the invisible batch should be exactly 0. + for j in 0..num_joints { + for r in 0..size { + for c in 0..size { + assert_eq!( + heatmaps[[0, j, r, c]], + 0.0, + "Invisible joint heatmap should be zero" + ); + } + } + } + + // Visible batch (index 1) should have non-zero heatmaps. + let batch1_sum: f32 = (0..num_joints) + .map(|j| { + (0..size) + .flat_map(|r| (0..size).map(move |c| heatmaps[[1, j, r, c]])) + .sum::() + }) + .sum(); + assert!(batch1_sum > 0.0, "Visible joints should produce non-zero heatmaps"); + } + + // ── Loss functions ──────────────────────────────────────────────────────── + + /// Returns a CUDA-or-CPU device string: always "cpu" in CI. + fn device() -> tch::Device { + tch::Device::Cpu + } + + #[test] + fn test_keypoint_loss_identical_predictions_is_zero() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = device(); + + // [B=2, 17, H=16, W=16] – use ones as a trivial non-zero tensor. 
+ let pred = Tensor::ones([2, 17, 16, 16], (Kind::Float, dev)); + let target = Tensor::ones([2, 17, 16, 16], (Kind::Float, dev)); + let vis = Tensor::ones([2, 17], (Kind::Float, dev)); + + let loss = loss_fn.keypoint_loss(&pred, &target, &vis); + let val = loss.double_value(&[]) as f32; + + assert!( + val.abs() < 1e-5, + "Keypoint loss for identical pred/target should be ≈ 0, got {val}" + ); + } + + #[test] + fn test_keypoint_loss_large_error_is_positive() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = device(); + + let pred = Tensor::ones([1, 17, 8, 8], (Kind::Float, dev)); + let target = Tensor::zeros([1, 17, 8, 8], (Kind::Float, dev)); + let vis = Tensor::ones([1, 17], (Kind::Float, dev)); + + let loss = loss_fn.keypoint_loss(&pred, &target, &vis); + let val = loss.double_value(&[]) as f32; + + assert!(val > 0.0, "Keypoint loss should be positive for wrong predictions"); + } + + #[test] + fn test_keypoint_loss_invisible_joints_ignored() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = device(); + + // pred ≠ target – but all joints invisible → loss should be 0. 
+ let pred = Tensor::ones([1, 17, 8, 8], (Kind::Float, dev)); + let target = Tensor::zeros([1, 17, 8, 8], (Kind::Float, dev)); + let vis = Tensor::zeros([1, 17], (Kind::Float, dev)); // all invisible + + let loss = loss_fn.keypoint_loss(&pred, &target, &vis); + let val = loss.double_value(&[]) as f32; + + assert!( + val.abs() < 1e-5, + "All-invisible loss should be ≈ 0, got {val}" + ); + } + + #[test] + fn test_transfer_loss_identical_features_is_zero() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = device(); + + let feat = Tensor::ones([2, 64, 8, 8], (Kind::Float, dev)); + let loss = loss_fn.transfer_loss(&feat, &feat); + let val = loss.double_value(&[]) as f32; + + assert!( + val.abs() < 1e-5, + "Transfer loss for identical tensors should be ≈ 0, got {val}" + ); + } + + #[test] + fn test_forward_keypoint_only_returns_weighted_loss() { + let weights = LossWeights { + lambda_kp: 1.0, + lambda_dp: 0.0, + lambda_tr: 0.0, + }; + let loss_fn = WiFiDensePoseLoss::new(weights); + let dev = device(); + + let pred = Tensor::ones([1, 17, 8, 8], (Kind::Float, dev)); + let target = Tensor::ones([1, 17, 8, 8], (Kind::Float, dev)); + let vis = Tensor::ones([1, 17], (Kind::Float, dev)); + + let (_, output) = loss_fn.forward( + &pred, &target, &vis, None, None, None, None, None, None, + ); + + assert!( + output.total.abs() < 1e-5, + "Identical heatmaps with λ_kp=1 should give ≈ 0 total loss, got {}", + output.total + ); + assert!(output.densepose.is_none()); + assert!(output.transfer.is_none()); + } + + #[test] + fn test_densepose_loss_identical_inputs_part_loss_near_zero_uv() { + // For identical pred/target UV the UV loss should be exactly 0. + // The cross-entropy part loss won't be 0 (uniform logits have entropy ≠ 0) + // but the UV component should contribute nothing extra. 
+ let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = device(); + let b = 1_i64; + let h = 4_i64; + let w = 4_i64; + + // pred_parts: all-zero logits (uniform over 25 classes) + let pred_parts = Tensor::zeros([b, 25, h, w], (Kind::Float, dev)); + // target: foreground class 1 everywhere + let target_parts = Tensor::ones([b, h, w], (Kind::Int64, dev)); + // UV: identical pred and target → uv loss = 0 + let uv = Tensor::zeros([b, 48, h, w], (Kind::Float, dev)); + + let loss = loss_fn.densepose_loss(&pred_parts, &target_parts, &uv, &uv); + let val = loss.double_value(&[]) as f32; + + assert!( + val >= 0.0, + "DensePose loss must be non-negative, got {val}" + ); + // With identical UV the total equals only the CE part loss. + // CE of uniform logits over 25 classes: ln(25) ≈ 3.22 + assert!( + val < 5.0, + "DensePose loss with identical UV should be bounded by CE, got {val}" + ); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs new file mode 100644 index 0000000..eb96df2 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs @@ -0,0 +1,406 @@ +//! Evaluation metrics for WiFi-DensePose training. +//! +//! This module provides: +//! +//! - **PCK\@0.2** (Percentage of Correct Keypoints): a keypoint is considered +//! correct when its Euclidean distance from the ground truth is within 20% +//! of the person bounding-box diagonal. +//! - **OKS** (Object Keypoint Similarity): the COCO-style metric that uses a +//! per-joint exponential kernel with sigmas from the COCO annotation +//! guidelines. +//! +//! Results are accumulated over mini-batches via [`MetricsAccumulator`] and +//! finalized into a [`MetricsResult`] at the end of a validation epoch. +//! +//! # No mock data +//! +//! All computations are grounded in real geometry and follow published metric +//! definitions. 
No random or synthetic values are introduced at runtime. + +use ndarray::{Array1, Array2}; + +// --------------------------------------------------------------------------- +// COCO keypoint sigmas (17 joints) +// --------------------------------------------------------------------------- + +/// Per-joint sigma values from the COCO keypoint evaluation standard. +/// +/// These constants control the spread of the OKS Gaussian kernel for each +/// of the 17 COCO-defined body joints. +pub const COCO_KP_SIGMAS: [f32; 17] = [ + 0.026, // 0 nose + 0.025, // 1 left_eye + 0.025, // 2 right_eye + 0.035, // 3 left_ear + 0.035, // 4 right_ear + 0.079, // 5 left_shoulder + 0.079, // 6 right_shoulder + 0.072, // 7 left_elbow + 0.072, // 8 right_elbow + 0.062, // 9 left_wrist + 0.062, // 10 right_wrist + 0.107, // 11 left_hip + 0.107, // 12 right_hip + 0.087, // 13 left_knee + 0.087, // 14 right_knee + 0.089, // 15 left_ankle + 0.089, // 16 right_ankle +]; + +// --------------------------------------------------------------------------- +// MetricsResult +// --------------------------------------------------------------------------- + +/// Aggregated evaluation metrics produced by a validation epoch. +/// +/// All metrics are averaged over the full dataset passed to the evaluator. +#[derive(Debug, Clone)] +pub struct MetricsResult { + /// Percentage of Correct Keypoints at threshold 0.2 (0-1 scale). + /// + /// A keypoint is "correct" when its predicted position is within + /// 20% of the ground-truth bounding-box diagonal from the true position. + pub pck: f32, + + /// Object Keypoint Similarity (0-1 scale, COCO standard). + /// + /// OKS is computed per person and averaged across the dataset. + /// Invisible keypoints (`visibility == 0`) are excluded from both + /// numerator and denominator. + pub oks: f32, + + /// Total number of keypoint instances evaluated. + pub num_keypoints: usize, + + /// Total number of samples evaluated. 
+ pub num_samples: usize, +} + +impl MetricsResult { + /// Returns `true` when this result is strictly better than `other` on the + /// primary metric (PCK\@0.2). + pub fn is_better_than(&self, other: &MetricsResult) -> bool { + self.pck > other.pck + } + + /// A human-readable summary line suitable for logging. + pub fn summary(&self) -> String { + format!( + "PCK@0.2={:.4} OKS={:.4} (n_samples={} n_kp={})", + self.pck, self.oks, self.num_samples, self.num_keypoints + ) + } +} + +impl Default for MetricsResult { + fn default() -> Self { + MetricsResult { + pck: 0.0, + oks: 0.0, + num_keypoints: 0, + num_samples: 0, + } + } +} + +// --------------------------------------------------------------------------- +// MetricsAccumulator +// --------------------------------------------------------------------------- + +/// Running accumulator for keypoint metrics across a validation epoch. +/// +/// Call [`MetricsAccumulator::update`] for each mini-batch. After iterating +/// the full dataset call [`MetricsAccumulator::finalize`] to obtain a +/// [`MetricsResult`]. +/// +/// # Thread safety +/// +/// `MetricsAccumulator` is not `Sync`; create one per thread and merge if +/// running multi-threaded evaluation. +pub struct MetricsAccumulator { + /// Cumulative sum of per-sample PCK scores. + pck_sum: f64, + /// Cumulative sum of per-sample OKS scores. + oks_sum: f64, + /// Number of individual keypoint instances that were evaluated. + num_keypoints: usize, + /// Number of samples seen. + num_samples: usize, + /// PCK threshold (fraction of bounding-box diagonal). Default: 0.2. + pck_threshold: f32, +} + +impl MetricsAccumulator { + /// Create a new accumulator with the given PCK threshold. + /// + /// The COCO and many pose papers use `threshold = 0.2` (20% of the + /// person's bounding-box diagonal). 
+ pub fn new(pck_threshold: f32) -> Self { + MetricsAccumulator { + pck_sum: 0.0, + oks_sum: 0.0, + num_keypoints: 0, + num_samples: 0, + pck_threshold, + } + } + + /// Default accumulator with PCK\@0.2. + pub fn default_threshold() -> Self { + Self::new(0.2) + } + + /// Update the accumulator with one sample's predictions. + /// + /// # Arguments + /// + /// - `pred_kp`: `[17, 2]` – predicted keypoint (x, y) in `[0, 1]`. + /// - `gt_kp`: `[17, 2]` – ground-truth keypoint (x, y) in `[0, 1]`. + /// - `visibility`: `[17]` – 0 = invisible, 1/2 = visible. + /// + /// Keypoints with `visibility == 0` are skipped. + pub fn update( + &mut self, + pred_kp: &Array2, + gt_kp: &Array2, + visibility: &Array1, + ) { + let num_joints = pred_kp.shape()[0].min(gt_kp.shape()[0]).min(visibility.len()); + + // Compute bounding-box diagonal from visible ground-truth keypoints. + let bbox_diag = bounding_box_diagonal(gt_kp, visibility, num_joints); + // Guard against degenerate (point) bounding boxes. + let safe_diag = bbox_diag.max(1e-3); + + let mut pck_correct = 0usize; + let mut visible_count = 0usize; + let mut oks_num = 0.0f64; + let mut oks_den = 0.0f64; + + for j in 0..num_joints { + if visibility[j] < 0.5 { + // Invisible joint: skip. + continue; + } + visible_count += 1; + + let dx = pred_kp[[j, 0]] - gt_kp[[j, 0]]; + let dy = pred_kp[[j, 1]] - gt_kp[[j, 1]]; + let dist = (dx * dx + dy * dy).sqrt(); + + // PCK: correct if within threshold × diagonal. + if dist <= self.pck_threshold * safe_diag { + pck_correct += 1; + } + + // OKS contribution for this joint. + let sigma = if j < COCO_KP_SIGMAS.len() { + COCO_KP_SIGMAS[j] + } else { + 0.07 // fallback sigma for non-standard joints + }; + // Normalise distance by (2 × sigma)² × (area = diagonal²). 
+ let two_sigma_sq = 2.0 * (sigma as f64) * (sigma as f64); + let area = (safe_diag as f64) * (safe_diag as f64); + let exp_arg = -(dist as f64 * dist as f64) / (two_sigma_sq * area + 1e-10); + oks_num += exp_arg.exp(); + oks_den += 1.0; + } + + // Per-sample PCK (fraction of visible joints that were correct). + let sample_pck = if visible_count > 0 { + pck_correct as f64 / visible_count as f64 + } else { + 1.0 // No visible joints: trivially correct (no evidence of error). + }; + + // Per-sample OKS. + let sample_oks = if oks_den > 0.0 { + oks_num / oks_den + } else { + 1.0 + }; + + self.pck_sum += sample_pck; + self.oks_sum += sample_oks; + self.num_keypoints += visible_count; + self.num_samples += 1; + } + + /// Finalize and return aggregated metrics. + /// + /// Returns `None` if no samples have been accumulated yet. + pub fn finalize(&self) -> Option { + if self.num_samples == 0 { + return None; + } + let n = self.num_samples as f64; + Some(MetricsResult { + pck: (self.pck_sum / n) as f32, + oks: (self.oks_sum / n) as f32, + num_keypoints: self.num_keypoints, + num_samples: self.num_samples, + }) + } + + /// Return the accumulated sample count. + pub fn num_samples(&self) -> usize { + self.num_samples + } + + /// Reset the accumulator to the initial (empty) state. + pub fn reset(&mut self) { + self.pck_sum = 0.0; + self.oks_sum = 0.0; + self.num_keypoints = 0; + self.num_samples = 0; + } +} + +// --------------------------------------------------------------------------- +// Geometric helpers +// --------------------------------------------------------------------------- + +/// Compute the Euclidean diagonal of the bounding box of visible keypoints. +/// +/// The bounding box is defined by the axis-aligned extent of all keypoints +/// that have `visibility[j] >= 0.5`. Returns 0.0 if there are no visible +/// keypoints or all are co-located. 
+fn bounding_box_diagonal( + kp: &Array2, + visibility: &Array1, + num_joints: usize, +) -> f32 { + let mut x_min = f32::MAX; + let mut x_max = f32::MIN; + let mut y_min = f32::MAX; + let mut y_max = f32::MIN; + let mut any_visible = false; + + for j in 0..num_joints { + if visibility[j] >= 0.5 { + let x = kp[[j, 0]]; + let y = kp[[j, 1]]; + x_min = x_min.min(x); + x_max = x_max.max(x); + y_min = y_min.min(y); + y_max = y_max.max(y); + any_visible = true; + } + } + + if !any_visible { + return 0.0; + } + + let w = (x_max - x_min).max(0.0); + let h = (y_max - y_min).max(0.0); + (w * w + h * h).sqrt() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::{array, Array1, Array2}; + use approx::assert_abs_diff_eq; + + fn perfect_prediction(n_joints: usize) -> (Array2, Array2, Array1) { + let gt = Array2::from_shape_fn((n_joints, 2), |(j, c)| { + if c == 0 { j as f32 * 0.05 } else { j as f32 * 0.04 } + }); + let vis = Array1::from_elem(n_joints, 2.0_f32); + (gt.clone(), gt, vis) + } + + #[test] + fn perfect_pck_is_one() { + let (pred, gt, vis) = perfect_prediction(17); + let mut acc = MetricsAccumulator::default_threshold(); + acc.update(&pred, >, &vis); + let result = acc.finalize().unwrap(); + assert_abs_diff_eq!(result.pck, 1.0_f32, epsilon = 1e-5); + } + + #[test] + fn perfect_oks_is_one() { + let (pred, gt, vis) = perfect_prediction(17); + let mut acc = MetricsAccumulator::default_threshold(); + acc.update(&pred, >, &vis); + let result = acc.finalize().unwrap(); + assert_abs_diff_eq!(result.oks, 1.0_f32, epsilon = 1e-5); + } + + #[test] + fn all_invisible_gives_trivial_pck() { + let mut acc = MetricsAccumulator::default_threshold(); + let pred = Array2::zeros((17, 2)); + let gt = Array2::zeros((17, 2)); + let vis = Array1::zeros(17); + acc.update(&pred, >, &vis); + let result = 
acc.finalize().unwrap(); + // No visible joints → trivially "perfect" (no errors to measure) + assert_abs_diff_eq!(result.pck, 1.0_f32, epsilon = 1e-5); + } + + #[test] + fn far_predictions_reduce_pck() { + let mut acc = MetricsAccumulator::default_threshold(); + // Ground truth: all at (0.5, 0.5) + let gt = Array2::from_elem((17, 2), 0.5_f32); + // Predictions: all at (0.0, 0.0) — far from ground truth + let pred = Array2::zeros((17, 2)); + let vis = Array1::from_elem(17, 2.0_f32); + acc.update(&pred, >, &vis); + let result = acc.finalize().unwrap(); + // PCK should be well below 1.0 + assert!(result.pck < 0.5, "PCK should be low for wrong predictions, got {}", result.pck); + } + + #[test] + fn accumulator_averages_over_samples() { + let mut acc = MetricsAccumulator::default_threshold(); + for _ in 0..5 { + let (pred, gt, vis) = perfect_prediction(17); + acc.update(&pred, >, &vis); + } + assert_eq!(acc.num_samples(), 5); + let result = acc.finalize().unwrap(); + assert_abs_diff_eq!(result.pck, 1.0_f32, epsilon = 1e-5); + } + + #[test] + fn empty_accumulator_returns_none() { + let acc = MetricsAccumulator::default_threshold(); + assert!(acc.finalize().is_none()); + } + + #[test] + fn reset_clears_state() { + let mut acc = MetricsAccumulator::default_threshold(); + let (pred, gt, vis) = perfect_prediction(17); + acc.update(&pred, >, &vis); + acc.reset(); + assert_eq!(acc.num_samples(), 0); + assert!(acc.finalize().is_none()); + } + + #[test] + fn bbox_diagonal_unit_square() { + let kp = array![[0.0_f32, 0.0], [1.0, 1.0]]; + let vis = array![2.0_f32, 2.0]; + let diag = bounding_box_diagonal(&kp, &vis, 2); + assert_abs_diff_eq!(diag, std::f32::consts::SQRT_2, epsilon = 1e-5); + } + + #[test] + fn metrics_result_is_better_than() { + let good = MetricsResult { pck: 0.9, oks: 0.8, num_keypoints: 100, num_samples: 10 }; + let bad = MetricsResult { pck: 0.5, oks: 0.4, num_keypoints: 100, num_samples: 10 }; + assert!(good.is_better_than(&bad)); + 
assert!(!bad.is_better_than(&good)); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs new file mode 100644 index 0000000..cfeba62 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs @@ -0,0 +1,16 @@ +//! WiFi-DensePose model definition and construction. +//! +//! This module will be implemented by the trainer agent. It currently provides +//! the public interface stubs so that the crate compiles as a whole. + +/// Placeholder for the compiled model handle. +/// +/// The real implementation wraps a `tch::CModule` or a custom `nn::Module`. +pub struct DensePoseModel; + +impl DensePoseModel { + /// Construct a new model from the given number of subcarriers and keypoints. + pub fn new(_num_subcarriers: usize, _num_keypoints: usize) -> Self { + DensePoseModel + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/proof.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/proof.rs new file mode 100644 index 0000000..0c6a0c1 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/proof.rs @@ -0,0 +1,9 @@ +//! Proof-of-concept utilities and verification helpers. +//! +//! This module will be implemented by the trainer agent. It currently provides +//! the public interface stubs so that the crate compiles as a whole. + +/// Verify that a checkpoint directory exists and is writable. +pub fn verify_checkpoint_dir(path: &std::path::Path) -> bool { + path.exists() && path.is_dir() +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/subcarrier.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/subcarrier.rs new file mode 100644 index 0000000..da03e28 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/subcarrier.rs @@ -0,0 +1,266 @@ +//! Subcarrier interpolation and selection utilities. +//! +//! 
This module provides functions to resample CSI subcarrier arrays between +//! different subcarrier counts using linear interpolation, and to select +//! the most informative subcarriers based on signal variance. +//! +//! # Example +//! +//! ```rust +//! use wifi_densepose_train::subcarrier::interpolate_subcarriers; +//! use ndarray::Array4; +//! +//! // Resample from 114 → 56 subcarriers +//! let arr = Array4::::zeros((100, 3, 3, 114)); +//! let resampled = interpolate_subcarriers(&arr, 56); +//! assert_eq!(resampled.shape(), &[100, 3, 3, 56]); +//! ``` + +use ndarray::{Array4, s}; + +// --------------------------------------------------------------------------- +// interpolate_subcarriers +// --------------------------------------------------------------------------- + +/// Resample a 4-D CSI array along the subcarrier axis (last dimension) to +/// `target_sc` subcarriers using linear interpolation. +/// +/// # Arguments +/// +/// - `arr`: Input array with shape `[T, n_tx, n_rx, n_sc]`. +/// - `target_sc`: Number of output subcarriers. +/// +/// # Returns +/// +/// A new array with shape `[T, n_tx, n_rx, target_sc]`. +/// +/// # Panics +/// +/// Panics if `target_sc == 0` or the input has no subcarrier dimension. +pub fn interpolate_subcarriers(arr: &Array4, target_sc: usize) -> Array4 { + assert!(target_sc > 0, "target_sc must be > 0"); + + let shape = arr.shape(); + let (n_t, n_tx, n_rx, n_sc) = (shape[0], shape[1], shape[2], shape[3]); + + if n_sc == target_sc { + return arr.clone(); + } + + let mut out = Array4::::zeros((n_t, n_tx, n_rx, target_sc)); + + // Precompute interpolation weights once. 
+ let weights = compute_interp_weights(n_sc, target_sc); + + for t in 0..n_t { + for tx in 0..n_tx { + for rx in 0..n_rx { + let src = arr.slice(s![t, tx, rx, ..]); + let src_slice = src.as_slice().unwrap_or_else(|| { + // Fallback: copy to a contiguous slice + // (this path is hit when the array has a non-contiguous layout) + // In practice ndarray arrays sliced along last dim are contiguous. + panic!("Subcarrier slice is not contiguous"); + }); + + for (k, &(i0, i1, w)) in weights.iter().enumerate() { + let v = src_slice[i0] * (1.0 - w) + src_slice[i1] * w; + out[[t, tx, rx, k]] = v; + } + } + } + } + + out +} + +// --------------------------------------------------------------------------- +// compute_interp_weights +// --------------------------------------------------------------------------- + +/// Compute linear interpolation indices and fractional weights for resampling +/// from `src_sc` to `target_sc` subcarriers. +/// +/// Returns a `Vec` of `(i0, i1, frac)` tuples where each output subcarrier `k` +/// is computed as `src[i0] * (1 - frac) + src[i1] * frac`. +/// +/// # Arguments +/// +/// - `src_sc`: Number of subcarriers in the source array. +/// - `target_sc`: Number of subcarriers in the output array. +/// +/// # Panics +/// +/// Panics if `src_sc == 0` or `target_sc == 0`. +pub fn compute_interp_weights(src_sc: usize, target_sc: usize) -> Vec<(usize, usize, f32)> { + assert!(src_sc > 0, "src_sc must be > 0"); + assert!(target_sc > 0, "target_sc must be > 0"); + + let mut weights = Vec::with_capacity(target_sc); + + for k in 0..target_sc { + // Map output index k to a continuous position in the source array. + // Scale so that index 0 maps to 0 and index (target_sc-1) maps to + // (src_sc-1) — i.e., endpoints are preserved. 
+ let pos = if target_sc == 1 { + 0.0f32 + } else { + k as f32 * (src_sc - 1) as f32 / (target_sc - 1) as f32 + }; + + let i0 = (pos.floor() as usize).min(src_sc - 1); + let i1 = (pos.ceil() as usize).min(src_sc - 1); + let frac = pos - pos.floor(); + + weights.push((i0, i1, frac)); + } + + weights +} + +// --------------------------------------------------------------------------- +// select_subcarriers_by_variance +// --------------------------------------------------------------------------- + +/// Select the `k` most informative subcarrier indices based on temporal variance. +/// +/// Computes the variance of each subcarrier across the time and antenna +/// dimensions, then returns the indices of the `k` subcarriers with the +/// highest variance, sorted in ascending order. +/// +/// # Arguments +/// +/// - `arr`: Input array with shape `[T, n_tx, n_rx, n_sc]`. +/// - `k`: Number of subcarriers to select. +/// +/// # Returns +/// +/// A `Vec` of length `k` with the selected subcarrier indices (ascending). +/// +/// # Panics +/// +/// Panics if `k == 0` or `k > n_sc`. +pub fn select_subcarriers_by_variance(arr: &Array4, k: usize) -> Vec { + let shape = arr.shape(); + let n_sc = shape[3]; + + assert!(k > 0, "k must be > 0"); + assert!(k <= n_sc, "k ({k}) must be <= n_sc ({n_sc})"); + + let total_elems = shape[0] * shape[1] * shape[2]; + + // Compute mean per subcarrier. + let mut means = vec![0.0f64; n_sc]; + for sc in 0..n_sc { + let col = arr.slice(s![.., .., .., sc]); + let sum: f64 = col.iter().map(|&v| v as f64).sum(); + means[sc] = sum / total_elems as f64; + } + + // Compute variance per subcarrier. + let mut variances = vec![0.0f64; n_sc]; + for sc in 0..n_sc { + let col = arr.slice(s![.., .., .., sc]); + let mean = means[sc]; + let var: f64 = col.iter().map(|&v| (v as f64 - mean).powi(2)).sum::() + / total_elems as f64; + variances[sc] = var; + } + + // Rank subcarriers by descending variance. 
+ let mut ranked: Vec = (0..n_sc).collect(); + ranked.sort_by(|&a, &b| variances[b].partial_cmp(&variances[a]).unwrap_or(std::cmp::Ordering::Equal)); + + // Take top-k and sort ascending for a canonical representation. + let mut selected: Vec = ranked[..k].to_vec(); + selected.sort_unstable(); + selected +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn identity_resample() { + let arr = Array4::::from_shape_fn((4, 3, 3, 56), |(t, tx, rx, k)| { + (t + tx + rx + k) as f32 + }); + let out = interpolate_subcarriers(&arr, 56); + assert_eq!(out.shape(), arr.shape()); + // Identity resample must preserve all values exactly. + for v in arr.iter().zip(out.iter()) { + assert_abs_diff_eq!(v.0, v.1, epsilon = 1e-6); + } + } + + #[test] + fn upsample_endpoints_preserved() { + // When resampling from 4 → 8 the first and last values are exact. + let arr = Array4::::from_shape_fn((1, 1, 1, 4), |(_, _, _, k)| k as f32); + let out = interpolate_subcarriers(&arr, 8); + assert_eq!(out.shape(), &[1, 1, 1, 8]); + assert_abs_diff_eq!(out[[0, 0, 0, 0]], 0.0_f32, epsilon = 1e-6); + assert_abs_diff_eq!(out[[0, 0, 0, 7]], 3.0_f32, epsilon = 1e-6); + } + + #[test] + fn downsample_endpoints_preserved() { + // Downsample from 8 → 4. 
+ let arr = Array4::::from_shape_fn((1, 1, 1, 8), |(_, _, _, k)| k as f32 * 2.0); + let out = interpolate_subcarriers(&arr, 4); + assert_eq!(out.shape(), &[1, 1, 1, 4]); + // First value: 0.0, last value: 14.0 + assert_abs_diff_eq!(out[[0, 0, 0, 0]], 0.0_f32, epsilon = 1e-5); + assert_abs_diff_eq!(out[[0, 0, 0, 3]], 14.0_f32, epsilon = 1e-5); + } + + #[test] + fn compute_interp_weights_identity() { + let w = compute_interp_weights(5, 5); + assert_eq!(w.len(), 5); + for (k, &(i0, i1, frac)) in w.iter().enumerate() { + assert_eq!(i0, k); + assert_eq!(i1, k); + assert_abs_diff_eq!(frac, 0.0_f32, epsilon = 1e-6); + } + } + + #[test] + fn select_subcarriers_returns_correct_count() { + let arr = Array4::::from_shape_fn((10, 3, 3, 56), |(t, _, _, k)| { + (t * k) as f32 + }); + let selected = select_subcarriers_by_variance(&arr, 8); + assert_eq!(selected.len(), 8); + } + + #[test] + fn select_subcarriers_sorted_ascending() { + let arr = Array4::::from_shape_fn((10, 3, 3, 56), |(t, _, _, k)| { + (t * k) as f32 + }); + let selected = select_subcarriers_by_variance(&arr, 10); + for w in selected.windows(2) { + assert!(w[0] < w[1], "Indices must be sorted ascending"); + } + } + + #[test] + fn select_subcarriers_all_same_returns_all() { + // When all subcarriers have zero variance, the function should still + // return k valid indices. + let arr = Array4::::ones((5, 2, 2, 20)); + let selected = select_subcarriers_by_variance(&arr, 5); + assert_eq!(selected.len(), 5); + // All selected indices must be in [0, 19] + for &idx in &selected { + assert!(idx < 20); + } + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs new file mode 100644 index 0000000..d543cc7 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs @@ -0,0 +1,24 @@ +//! Training loop orchestrator. +//! +//! This module will be implemented by the trainer agent. 
It currently provides +//! the public interface stubs so that the crate compiles as a whole. + +use crate::config::TrainingConfig; + +/// Orchestrates the full training loop: data loading, forward pass, loss +/// computation, back-propagation, validation, and checkpointing. +pub struct Trainer { + config: TrainingConfig, +} + +impl Trainer { + /// Create a new `Trainer` from the given configuration. + pub fn new(config: TrainingConfig) -> Self { + Trainer { config } + } + + /// Return a reference to the active training configuration. + pub fn config(&self) -> &TrainingConfig { + &self.config + } +}