diff --git a/rust-port/wifi-densepose-rs/Cargo.toml b/rust-port/wifi-densepose-rs/Cargo.toml index 0641447..6eee3f1 100644 --- a/rust-port/wifi-densepose-rs/Cargo.toml +++ b/rust-port/wifi-densepose-rs/Cargo.toml @@ -11,6 +11,7 @@ members = [ "crates/wifi-densepose-wasm", "crates/wifi-densepose-cli", "crates/wifi-densepose-mat", + "crates/wifi-densepose-train", ] [workspace.package] @@ -73,6 +74,25 @@ getrandom = { version = "0.2", features = ["js"] } serialport = "4.3" pcap = "1.1" +# Graph algorithms (for min-cut assignment in metrics) +petgraph = "0.6" + +# Data loading +ndarray-npy = "0.8" +walkdir = "2.4" + +# Hashing (for proof) +sha2 = "0.10" + +# CSV logging +csv = "1.3" + +# Progress bars +indicatif = "0.17" + +# CLI +clap = { version = "4.4", features = ["derive"] } + # Testing criterion = { version = "0.5", features = ["html_reports"] } proptest = "1.4" diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs index 1fbb230..d7f3fcd 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs @@ -1,12 +1,24 @@ //! Error types for the WiFi-DensePose training pipeline. //! -//! This module defines a hierarchy of errors covering every failure mode in -//! the training pipeline: configuration validation, dataset I/O, subcarrier -//! interpolation, and top-level training orchestration. +//! This module provides: +//! +//! - [`TrainError`]: top-level error aggregating all training failure modes. +//! - [`TrainResult`]: convenient `Result` alias using `TrainError`. +//! +//! Module-local error types live in their respective modules: +//! +//! - [`crate::config::ConfigError`]: configuration validation errors. +//! - [`crate::dataset::DatasetError`]: dataset loading/access errors. +//! +//! All are re-exported at the crate root for ergonomic use. use thiserror::Error; use std::path::PathBuf; +// Import module-local error types so TrainError can wrap them via #[from]. +use crate::config::ConfigError; +use crate::dataset::DatasetError; + // --------------------------------------------------------------------------- // Top-level training error // --------------------------------------------------------------------------- @@ -16,8 +28,9 @@ pub type TrainResult = Result; /// Top-level error type for the training pipeline. /// -/// Every public function in this crate that can fail returns -/// `TrainResult`, which is `Result`. +/// Every orchestration-level function returns `TrainResult`. Lower-level +/// functions in [`crate::config`] and [`crate::dataset`] return their own +/// module-specific error types which are automatically coerced via `#[from]`. #[derive(Debug, Error)] pub enum TrainError { /// Configuration is invalid or internally inconsistent. @@ -28,10 +41,6 @@ pub enum TrainError { #[error("Dataset error: {0}")] Dataset(#[from] DatasetError), - /// Subcarrier interpolation / resampling failed. - #[error("Subcarrier interpolation error: {0}")] - Subcarrier(#[from] SubcarrierError), - /// An underlying I/O error not covered by a more specific variant. #[error("I/O error: {0}")] Io(#[from] std::io::Error), @@ -40,14 +49,6 @@ pub enum TrainError { #[error("JSON error: {0}")] Json(#[from] serde_json::Error), - /// TOML (de)serialization error. - #[error("TOML deserialization error: {0}")] - TomlDe(#[from] toml::de::Error), - - /// TOML serialization error. - #[error("TOML serialization error: {0}")] - TomlSer(#[from] toml::ser::Error), - /// An operation was attempted on an empty dataset. #[error("Dataset is empty")] EmptyDataset, @@ -112,273 +113,3 @@ impl TrainError { TrainError::ShapeMismatch { expected, actual } } } - -// --------------------------------------------------------------------------- -// Configuration errors -// --------------------------------------------------------------------------- - -/// Errors produced when validating or loading a [`TrainingConfig`]. -/// -/// [`TrainingConfig`]: crate::config::TrainingConfig -#[derive(Debug, Error)] -pub enum ConfigError { - /// A required field has a value that violates a constraint. - #[error("Invalid value for field `{field}`: {reason}")] - InvalidValue { - /// Name of the configuration field. - field: &'static str, - /// Human-readable reason the value is invalid. - reason: String, - }, - - /// The configuration file could not be read. - #[error("Cannot read configuration file `{path}`: {source}")] - FileRead { - /// Path that was being read. - path: PathBuf, - /// Underlying I/O error. - #[source] - source: std::io::Error, - }, - - /// The configuration file contains invalid TOML. - #[error("Cannot parse configuration file `{path}`: {source}")] - ParseError { - /// Path that was being parsed. - path: PathBuf, - /// Underlying TOML parse error. - #[source] - source: toml::de::Error, - }, - - /// A path specified in the config does not exist. - #[error("Path `{path}` specified in config does not exist")] - PathNotFound { - /// The missing path. - path: PathBuf, - }, -} - -impl ConfigError { - /// Construct an [`ConfigError::InvalidValue`] error. - pub fn invalid_value>(field: &'static str, reason: S) -> Self { - ConfigError::InvalidValue { - field, - reason: reason.into(), - } - } -} - -// --------------------------------------------------------------------------- -// Dataset errors -// --------------------------------------------------------------------------- - -/// Errors produced while loading or accessing dataset samples. -#[derive(Debug, Error)] -pub enum DatasetError { - /// The requested data file or directory was not found. - /// - /// Production training data is mandatory; this error is never silently - /// suppressed. Use [`SyntheticDataset`] only for proof/testing. - /// - /// [`SyntheticDataset`]: crate::dataset::SyntheticDataset - #[error("Data not found at `{path}`: {message}")] - DataNotFound { - /// Path that was expected to contain data. - path: PathBuf, - /// Additional context. - message: String, - }, - - /// A file was found but its format is incorrect or unexpected. - /// - /// This covers malformed numpy arrays, unexpected shapes, bad JSON - /// metadata, etc. - #[error("Invalid data format in `{path}`: {message}")] - InvalidFormat { - /// Path of the malformed file. - path: PathBuf, - /// Description of the format problem. - message: String, - }, - - /// A low-level I/O error while reading a data file. - #[error("I/O error reading `{path}`: {source}")] - IoError { - /// Path being read when the error occurred. - path: PathBuf, - /// Underlying I/O error. - #[source] - source: std::io::Error, - }, - - /// The number of subcarriers in the data file does not match the - /// configuration expectation (before or after interpolation). - #[error( - "Subcarrier count mismatch in `{path}`: \ - file has {found} subcarriers, expected {expected}" - )] - SubcarrierMismatch { - /// Path of the offending file. - path: PathBuf, - /// Number of subcarriers found in the file. - found: usize, - /// Number of subcarriers expected by the configuration. - expected: usize, - }, - - /// A sample index was out of bounds. - #[error("Index {index} is out of bounds for dataset of length {len}")] - IndexOutOfBounds { - /// The requested index. - index: usize, - /// Total number of samples. - len: usize, - }, - - /// A numpy array could not be read. - #[error("NumPy array read error in `{path}`: {message}")] - NpyReadError { - /// Path of the `.npy` file. - path: PathBuf, - /// Error description. - message: String, - }, - - /// A metadata file (e.g., `meta.json`) is missing or malformed. - #[error("Metadata error for subject {subject_id}: {message}")] - MetadataError { - /// Subject whose metadata could not be read. - subject_id: u32, - /// Description of the problem. - message: String, - }, - - /// No subjects matching the requested IDs were found in the data directory. - #[error( - "No subjects found in `{data_dir}` matching the requested IDs: {requested:?}" - )] - NoSubjectsFound { - /// Root data directory that was scanned. - data_dir: PathBuf, - /// Subject IDs that were requested. - requested: Vec, - }, - - /// A subcarrier interpolation error occurred during sample loading. - #[error("Subcarrier interpolation failed while loading sample {sample_idx}: {source}")] - InterpolationError { - /// The sample index being loaded. - sample_idx: usize, - /// Underlying interpolation error. - #[source] - source: SubcarrierError, - }, -} - -impl DatasetError { - /// Construct a [`DatasetError::DataNotFound`] error. - pub fn not_found>(path: impl Into, msg: S) -> Self { - DatasetError::DataNotFound { - path: path.into(), - message: msg.into(), - } - } - - /// Construct a [`DatasetError::InvalidFormat`] error. - pub fn invalid_format>(path: impl Into, msg: S) -> Self { - DatasetError::InvalidFormat { - path: path.into(), - message: msg.into(), - } - } - - /// Construct a [`DatasetError::IoError`] error. - pub fn io_error(path: impl Into, source: std::io::Error) -> Self { - DatasetError::IoError { - path: path.into(), - source, - } - } - - /// Construct a [`DatasetError::SubcarrierMismatch`] error. - pub fn subcarrier_mismatch(path: impl Into, found: usize, expected: usize) -> Self { - DatasetError::SubcarrierMismatch { - path: path.into(), - found, - expected, - } - } - - /// Construct a [`DatasetError::NpyReadError`] error. - pub fn npy_read>(path: impl Into, msg: S) -> Self { - DatasetError::NpyReadError { - path: path.into(), - message: msg.into(), - } - } -} - -// --------------------------------------------------------------------------- -// Subcarrier interpolation errors -// --------------------------------------------------------------------------- - -/// Errors produced by the subcarrier resampling functions. -#[derive(Debug, Error)] -pub enum SubcarrierError { - /// The source or destination subcarrier count is zero. - #[error("Subcarrier count must be at least 1, got {count}")] - ZeroCount { - /// The offending count. - count: usize, - }, - - /// The input array has an unexpected shape. - #[error( - "Input array shape mismatch: expected last dimension {expected_sc}, \ - got {actual_sc} (full shape: {shape:?})" - )] - InputShapeMismatch { - /// Expected number of subcarriers (last dimension). - expected_sc: usize, - /// Actual number of subcarriers found. - actual_sc: usize, - /// Full shape of the input array. - shape: Vec, - }, - - /// The requested interpolation method is not implemented. - #[error("Interpolation method `{method}` is not yet implemented")] - MethodNotImplemented { - /// Name of the unimplemented method. - method: String, - }, - - /// Source and destination subcarrier counts are already equal. - /// - /// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before - /// calling the interpolation routine to avoid this error. - /// - /// [`TrainingConfig::needs_subcarrier_interp`]: - /// crate::config::TrainingConfig::needs_subcarrier_interp - #[error( - "Source and destination subcarrier counts are equal ({count}); \ - no interpolation is needed" - )] - NopInterpolation { - /// The equal count. - count: usize, - }, - - /// A numerical error occurred during interpolation (e.g., division by zero - /// due to coincident knot positions). - #[error("Numerical error during interpolation: {0}")] - NumericalError(String), -} - -impl SubcarrierError { - /// Construct a [`SubcarrierError::NumericalError`]. - pub fn numerical>(msg: S) -> Self { - SubcarrierError::NumericalError(msg.into()) - } -} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs index d1b915c..b55d787 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs @@ -52,9 +52,9 @@ pub mod subcarrier; pub mod trainer; // Convenient re-exports at the crate root. -pub use config::TrainingConfig; -pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; -pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult}; +pub use config::{ConfigError, TrainingConfig}; +pub use dataset::{CsiDataset, CsiSample, DataLoader, DatasetError, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; +pub use error::{TrainError, TrainResult}; pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance}; /// Crate version string. diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs index a8e8f28..0fe343c 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs @@ -906,4 +906,148 @@ mod tests { "DensePose loss with identical UV should be bounded by CE, got {val}" ); } + + // ── Standalone functional API tests ────────────────────────────────────── + + #[test] + fn test_fn_keypoint_heatmap_loss_identical_zero() { + let dev = device(); + let t = Tensor::ones([2, 17, 8, 8], (Kind::Float, dev)); + let loss = keypoint_heatmap_loss(&t, &t); + let v = loss.double_value(&[]) as f32; + assert!(v.abs() < 1e-6, "Identical heatmaps → loss must be ≈0, got {v}"); + } + + #[test] + fn test_fn_generate_gaussian_heatmaps_shape() { + let dev = device(); + let kpts = Tensor::full(&[2i64, 17, 2], 0.5, (Kind::Float, dev)); + let vis = Tensor::ones(&[2i64, 17], (Kind::Float, dev)); + let hm = generate_gaussian_heatmaps(&kpts, &vis, 16, 2.0); + assert_eq!(hm.size(), [2, 17, 16, 16]); + } + + #[test] + fn test_fn_generate_gaussian_heatmaps_invisible_zero() { + let dev = device(); + let kpts = Tensor::full(&[1i64, 17, 2], 0.5, (Kind::Float, dev)); + let vis = Tensor::zeros(&[1i64, 17], (Kind::Float, dev)); // all invisible + let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 2.0); + let total: f64 = hm.sum(Kind::Float).double_value(&[]); + assert_eq!(total, 0.0, "All-invisible heatmaps must be zero"); + } + + #[test] + fn test_fn_generate_gaussian_heatmaps_peak_near_one() { + let dev = device(); + // Keypoint at (0.5, 0.5) on an 8×8 map. + let kpts = Tensor::full(&[1i64, 1, 2], 0.5, (Kind::Float, dev)); + let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev)); + let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5); + let max_val: f64 = hm.max().double_value(&[]); + assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9"); + } + + #[test] + fn test_fn_densepose_part_loss_returns_finite() { + let dev = device(); + let logits = Tensor::zeros(&[1i64, 25, 4, 4], (Kind::Float, dev)); + let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev)); + let loss = densepose_part_loss(&logits, &labels); + let v = loss.double_value(&[]); + assert!(v.is_finite() && v >= 0.0); + } + + #[test] + fn test_fn_densepose_uv_loss_no_annotated_pixels_zero() { + let dev = device(); + let pred = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev)); + let gt = Tensor::zeros(&[1i64, 48, 4, 4], (Kind::Float, dev)); + let labels = Tensor::full(&[1i64, 4, 4], -1i64, (Kind::Int64, dev)); + let loss = densepose_uv_loss(&pred, >, &labels); + let v = loss.double_value(&[]); + assert_eq!(v, 0.0, "No annotated pixels → UV loss must be 0"); + } + + #[test] + fn test_fn_densepose_uv_loss_identical_zero() { + let dev = device(); + let t = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev)); + let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev)); + let loss = densepose_uv_loss(&t, &t, &labels); + let v = loss.double_value(&[]); + assert!(v.abs() < 1e-6, "Identical UV → loss ≈ 0, got {v}"); + } + + #[test] + fn test_fn_transfer_loss_identical_zero() { + let dev = device(); + let t = Tensor::ones(&[2i64, 64, 8, 8], (Kind::Float, dev)); + let loss = fn_transfer_loss(&t, &t); + let v = loss.double_value(&[]); + assert!(v.abs() < 1e-6, "Identical features → transfer loss ≈ 0, got {v}"); + } + + #[test] + fn test_fn_transfer_loss_spatial_mismatch() { + let dev = device(); + let student = Tensor::ones(&[1i64, 64, 16, 16], (Kind::Float, dev)); + let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev)); + let loss = fn_transfer_loss(&student, &teacher); + let v = loss.double_value(&[]); + assert!(v.is_finite() && v >= 0.0, "Spatial-mismatch transfer loss must be finite"); + } + + #[test] + fn test_fn_transfer_loss_channel_mismatch_divisible() { + let dev = device(); + let student = Tensor::ones(&[1i64, 128, 8, 8], (Kind::Float, dev)); + let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev)); + let loss = fn_transfer_loss(&student, &teacher); + let v = loss.double_value(&[]); + assert!(v.is_finite() && v >= 0.0); + } + + #[test] + fn test_compute_losses_keypoint_only() { + let dev = device(); + let pred = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev)); + let gt = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev)); + let out = compute_losses(&pred, >, None, None, None, None, None, None, + 1.0, 1.0, 1.0); + assert!(out.total.is_finite()); + assert!(out.keypoint >= 0.0); + assert!(out.densepose_parts.is_none()); + assert!(out.densepose_uv.is_none()); + assert!(out.transfer.is_none()); + } + + #[test] + fn test_compute_losses_all_components_finite() { + let dev = device(); + let b = 1i64; + let h = 4i64; + let w = 4i64; + let pred_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev)); + let gt_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev)); + let logits = Tensor::zeros(&[b, 25, h, w], (Kind::Float, dev)); + let labels = Tensor::zeros(&[b, h, w], (Kind::Int64, dev)); + let pred_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev)); + let gt_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev)); + let sf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev)); + let tf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev)); + + let out = compute_losses( + &pred_kpt, >_kpt, + Some(&logits), Some(&labels), + Some(&pred_uv), Some(>_uv), + Some(&sf), Some(&tf), + 1.0, 0.5, 0.1, + ); + + assert!(out.total.is_finite() && out.total >= 0.0); + assert!(out.densepose_parts.is_some()); + assert!(out.densepose_uv.is_some()); + assert!(out.transfer.is_some()); + } } diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_config.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_config.rs new file mode 100644 index 0000000..e9928f0 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_config.rs @@ -0,0 +1,458 @@ +//! Integration tests for [`wifi_densepose_train::config`]. +//! +//! All tests are deterministic: they use only fixed values and the +//! `TrainingConfig::default()` constructor. No OS entropy or `rand` crate +//! is used. + +use wifi_densepose_train::config::TrainingConfig; + +// --------------------------------------------------------------------------- +// Default config invariants +// --------------------------------------------------------------------------- + +/// The default configuration must pass its own validation. +#[test] +fn default_config_is_valid() { + let cfg = TrainingConfig::default(); + cfg.validate() + .expect("default TrainingConfig must be valid"); +} + +/// Every numeric field in the default config must be strictly positive where +/// the domain requires it. +#[test] +fn default_config_all_positive_fields() { + let cfg = TrainingConfig::default(); + + assert!(cfg.num_subcarriers > 0, "num_subcarriers must be > 0"); + assert!(cfg.native_subcarriers > 0, "native_subcarriers must be > 0"); + assert!(cfg.num_antennas_tx > 0, "num_antennas_tx must be > 0"); + assert!(cfg.num_antennas_rx > 0, "num_antennas_rx must be > 0"); + assert!(cfg.window_frames > 0, "window_frames must be > 0"); + assert!(cfg.heatmap_size > 0, "heatmap_size must be > 0"); + assert!(cfg.num_keypoints > 0, "num_keypoints must be > 0"); + assert!(cfg.num_body_parts > 0, "num_body_parts must be > 0"); + assert!(cfg.backbone_channels > 0, "backbone_channels must be > 0"); + assert!(cfg.batch_size > 0, "batch_size must be > 0"); + assert!(cfg.learning_rate > 0.0, "learning_rate must be > 0.0"); + assert!(cfg.weight_decay >= 0.0, "weight_decay must be >= 0.0"); + assert!(cfg.num_epochs > 0, "num_epochs must be > 0"); + assert!(cfg.grad_clip_norm > 0.0, "grad_clip_norm must be > 0.0"); +} + +/// The three loss weights in the default config must all be non-negative and +/// their sum must be positive (not all zero). +#[test] +fn default_config_loss_weights_sum_positive() { + let cfg = TrainingConfig::default(); + + assert!(cfg.lambda_kp >= 0.0, "lambda_kp must be >= 0.0"); + assert!(cfg.lambda_dp >= 0.0, "lambda_dp must be >= 0.0"); + assert!(cfg.lambda_tr >= 0.0, "lambda_tr must be >= 0.0"); + + let total = cfg.lambda_kp + cfg.lambda_dp + cfg.lambda_tr; + assert!( + total > 0.0, + "sum of loss weights must be > 0.0, got {total}" + ); +} + +/// The default loss weights should sum to exactly 1.0 (within floating-point +/// tolerance). +#[test] +fn default_config_loss_weights_sum_to_one() { + let cfg = TrainingConfig::default(); + let total = cfg.lambda_kp + cfg.lambda_dp + cfg.lambda_tr; + let diff = (total - 1.0_f64).abs(); + assert!( + diff < 1e-9, + "expected loss weights to sum to 1.0, got {total} (diff={diff})" + ); +} + +// --------------------------------------------------------------------------- +// Specific default values +// --------------------------------------------------------------------------- + +/// The default number of subcarriers is 56 (MM-Fi target). +#[test] +fn default_num_subcarriers_is_56() { + let cfg = TrainingConfig::default(); + assert_eq!( + cfg.num_subcarriers, 56, + "expected default num_subcarriers = 56, got {}", + cfg.num_subcarriers + ); +} + +/// The default number of native subcarriers is 114 (raw MM-Fi hardware output). +#[test] +fn default_native_subcarriers_is_114() { + let cfg = TrainingConfig::default(); + assert_eq!( + cfg.native_subcarriers, 114, + "expected default native_subcarriers = 114, got {}", + cfg.native_subcarriers + ); +} + +/// The default number of keypoints is 17 (COCO skeleton). +#[test] +fn default_num_keypoints_is_17() { + let cfg = TrainingConfig::default(); + assert_eq!( + cfg.num_keypoints, 17, + "expected default num_keypoints = 17, got {}", + cfg.num_keypoints + ); +} + +/// The default antenna counts are 3×3. +#[test] +fn default_antenna_counts_are_3x3() { + let cfg = TrainingConfig::default(); + assert_eq!(cfg.num_antennas_tx, 3, "expected num_antennas_tx = 3"); + assert_eq!(cfg.num_antennas_rx, 3, "expected num_antennas_rx = 3"); +} + +/// The default window length is 100 frames. +#[test] +fn default_window_frames_is_100() { + let cfg = TrainingConfig::default(); + assert_eq!( + cfg.window_frames, 100, + "expected window_frames = 100, got {}", + cfg.window_frames + ); +} + +/// The default seed is 42. +#[test] +fn default_seed_is_42() { + let cfg = TrainingConfig::default(); + assert_eq!(cfg.seed, 42, "expected seed = 42, got {}", cfg.seed); +} + +// --------------------------------------------------------------------------- +// needs_subcarrier_interp equivalent property +// --------------------------------------------------------------------------- + +/// When native_subcarriers differs from num_subcarriers, interpolation is +/// needed. The default config has 114 != 56, so this property must hold. +#[test] +fn default_config_needs_interpolation() { + let cfg = TrainingConfig::default(); + // 114 native → 56 target: interpolation is required. + assert_ne!( + cfg.native_subcarriers, cfg.num_subcarriers, + "default config must require subcarrier interpolation (native={} != target={})", + cfg.native_subcarriers, cfg.num_subcarriers + ); +} + +/// When native_subcarriers equals num_subcarriers no interpolation is needed. +#[test] +fn equal_subcarrier_counts_means_no_interpolation_needed() { + let mut cfg = TrainingConfig::default(); + cfg.native_subcarriers = cfg.num_subcarriers; // e.g., both = 56 + cfg.validate().expect("config with equal subcarrier counts must be valid"); + assert_eq!( + cfg.native_subcarriers, cfg.num_subcarriers, + "after setting equal counts, native ({}) must equal target ({})", + cfg.native_subcarriers, cfg.num_subcarriers + ); +} + +// --------------------------------------------------------------------------- +// csi_flat_size equivalent property +// --------------------------------------------------------------------------- + +/// The flat input size of a single CSI window is +/// `window_frames × num_antennas_tx × num_antennas_rx × num_subcarriers`. +/// Verify the arithmetic matches the default config. +#[test] +fn csi_flat_size_matches_expected() { + let cfg = TrainingConfig::default(); + let expected = cfg.window_frames + * cfg.num_antennas_tx + * cfg.num_antennas_rx + * cfg.num_subcarriers; + // Default: 100 * 3 * 3 * 56 = 50400 + assert_eq!( + expected, 50_400, + "CSI flat size must be 50400 for default config, got {expected}" + ); +} + +/// The CSI flat size must be > 0 for any valid config. +#[test] +fn csi_flat_size_positive_for_valid_config() { + let cfg = TrainingConfig::default(); + let flat_size = cfg.window_frames + * cfg.num_antennas_tx + * cfg.num_antennas_rx + * cfg.num_subcarriers; + assert!( + flat_size > 0, + "CSI flat size must be > 0, got {flat_size}" + ); +} + +// --------------------------------------------------------------------------- +// JSON serialization round-trip +// --------------------------------------------------------------------------- + +/// Serializing a config to JSON and deserializing it must yield an identical +/// config (all fields must match). +#[test] +fn config_json_roundtrip_identical() { + use std::path::PathBuf; + use tempfile::tempdir; + + let tmp = tempdir().expect("tempdir must be created"); + let path = tmp.path().join("config.json"); + + let original = TrainingConfig::default(); + original + .to_json(&path) + .expect("to_json must succeed for default config"); + + let loaded = TrainingConfig::from_json(&path) + .expect("from_json must succeed for previously serialized config"); + + // Verify all fields are equal. + assert_eq!( + loaded.num_subcarriers, original.num_subcarriers, + "num_subcarriers must survive round-trip" + ); + assert_eq!( + loaded.native_subcarriers, original.native_subcarriers, + "native_subcarriers must survive round-trip" + ); + assert_eq!( + loaded.num_antennas_tx, original.num_antennas_tx, + "num_antennas_tx must survive round-trip" + ); + assert_eq!( + loaded.num_antennas_rx, original.num_antennas_rx, + "num_antennas_rx must survive round-trip" + ); + assert_eq!( + loaded.window_frames, original.window_frames, + "window_frames must survive round-trip" + ); + assert_eq!( + loaded.heatmap_size, original.heatmap_size, + "heatmap_size must survive round-trip" + ); + assert_eq!( + loaded.num_keypoints, original.num_keypoints, + "num_keypoints must survive round-trip" + ); + assert_eq!( + loaded.num_body_parts, original.num_body_parts, + "num_body_parts must survive round-trip" + ); + assert_eq!( + loaded.backbone_channels, original.backbone_channels, + "backbone_channels must survive round-trip" + ); + assert_eq!( + loaded.batch_size, original.batch_size, + "batch_size must survive round-trip" + ); + assert!( + (loaded.learning_rate - original.learning_rate).abs() < 1e-12, + "learning_rate must survive round-trip: got {}", + loaded.learning_rate + ); + assert!( + (loaded.weight_decay - original.weight_decay).abs() < 1e-12, + "weight_decay must survive round-trip" + ); + assert_eq!( + loaded.num_epochs, original.num_epochs, + "num_epochs must survive round-trip" + ); + assert_eq!( + loaded.warmup_epochs, original.warmup_epochs, + "warmup_epochs must survive round-trip" + ); + assert_eq!( + loaded.lr_milestones, original.lr_milestones, + "lr_milestones must survive round-trip" + ); + assert!( + (loaded.lr_gamma - original.lr_gamma).abs() < 1e-12, + "lr_gamma must survive round-trip" + ); + assert!( + (loaded.grad_clip_norm - original.grad_clip_norm).abs() < 1e-12, + "grad_clip_norm must survive round-trip" + ); + assert!( + (loaded.lambda_kp - original.lambda_kp).abs() < 1e-12, + "lambda_kp must survive round-trip" + ); + assert!( + (loaded.lambda_dp - original.lambda_dp).abs() < 1e-12, + "lambda_dp must survive round-trip" + ); + assert!( + (loaded.lambda_tr - original.lambda_tr).abs() < 1e-12, + "lambda_tr must survive round-trip" + ); + assert_eq!( + loaded.val_every_epochs, original.val_every_epochs, + "val_every_epochs must survive round-trip" + ); + assert_eq!( + loaded.early_stopping_patience, original.early_stopping_patience, + "early_stopping_patience must survive round-trip" + ); + assert_eq!( + loaded.save_top_k, original.save_top_k, + "save_top_k must survive round-trip" + ); + assert_eq!(loaded.use_gpu, original.use_gpu, "use_gpu must survive round-trip"); + assert_eq!( + loaded.gpu_device_id, original.gpu_device_id, + "gpu_device_id must survive round-trip" + ); + assert_eq!( + loaded.num_workers, original.num_workers, + "num_workers must survive round-trip" + ); + assert_eq!(loaded.seed, original.seed, "seed must survive round-trip"); +} + +/// A modified config with non-default values must also survive a JSON +/// round-trip. +#[test] +fn config_json_roundtrip_modified_values() { + use tempfile::tempdir; + + let tmp = tempdir().expect("tempdir must be created"); + let path = tmp.path().join("modified.json"); + + let mut cfg = TrainingConfig::default(); + cfg.batch_size = 16; + cfg.learning_rate = 5e-4; + cfg.num_epochs = 100; + cfg.warmup_epochs = 10; + cfg.lr_milestones = vec![50, 80]; + cfg.seed = 99; + + cfg.validate().expect("modified config must be valid before serialization"); + cfg.to_json(&path).expect("to_json must succeed"); + + let loaded = TrainingConfig::from_json(&path).expect("from_json must succeed"); + + assert_eq!(loaded.batch_size, 16, "batch_size must match after round-trip"); + assert!( + (loaded.learning_rate - 5e-4_f64).abs() < 1e-12, + "learning_rate must match after round-trip" + ); + assert_eq!(loaded.num_epochs, 100, "num_epochs must match after round-trip"); + assert_eq!(loaded.warmup_epochs, 10, "warmup_epochs must match after round-trip"); + assert_eq!( + loaded.lr_milestones, + vec![50, 80], + "lr_milestones must match after round-trip" + ); + assert_eq!(loaded.seed, 99, "seed must match after round-trip"); +} + +// --------------------------------------------------------------------------- +// Validation: invalid configurations are rejected +// --------------------------------------------------------------------------- + +/// Setting num_subcarriers to 0 must produce a validation error. +#[test] +fn zero_num_subcarriers_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.num_subcarriers = 0; + assert!( + cfg.validate().is_err(), + "num_subcarriers = 0 must be rejected by validate()" + ); +} + +/// Setting native_subcarriers to 0 must produce a validation error. +#[test] +fn zero_native_subcarriers_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.native_subcarriers = 0; + assert!( + cfg.validate().is_err(), + "native_subcarriers = 0 must be rejected by validate()" + ); +} + +/// Setting batch_size to 0 must produce a validation error. +#[test] +fn zero_batch_size_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.batch_size = 0; + assert!( + cfg.validate().is_err(), + "batch_size = 0 must be rejected by validate()" + ); +} + +/// A negative learning rate must produce a validation error. +#[test] +fn negative_learning_rate_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.learning_rate = -0.001; + assert!( + cfg.validate().is_err(), + "learning_rate < 0 must be rejected by validate()" + ); +} + +/// warmup_epochs >= num_epochs must produce a validation error. +#[test] +fn warmup_exceeding_epochs_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.warmup_epochs = cfg.num_epochs; // equal, which is still invalid + assert!( + cfg.validate().is_err(), + "warmup_epochs >= num_epochs must be rejected by validate()" + ); +} + +/// All loss weights set to 0.0 must produce a validation error. +#[test] +fn all_zero_loss_weights_are_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.lambda_kp = 0.0; + cfg.lambda_dp = 0.0; + cfg.lambda_tr = 0.0; + assert!( + cfg.validate().is_err(), + "all-zero loss weights must be rejected by validate()" + ); +} + +/// Non-increasing lr_milestones must produce a validation error. +#[test] +fn non_increasing_milestones_are_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.lr_milestones = vec![40, 30]; // wrong order + assert!( + cfg.validate().is_err(), + "non-increasing lr_milestones must be rejected by validate()" + ); +} + +/// An lr_milestone beyond num_epochs must produce a validation error. +#[test] +fn milestone_beyond_num_epochs_is_invalid() { + let mut cfg = TrainingConfig::default(); + cfg.lr_milestones = vec![30, cfg.num_epochs + 1]; + assert!( + cfg.validate().is_err(), + "lr_milestone > num_epochs must be rejected by validate()" + ); +}