diff --git a/rust-port/wifi-densepose-rs/Cargo.lock b/rust-port/wifi-densepose-rs/Cargo.lock index de39b26..09e0915 100644 --- a/rust-port/wifi-densepose-rs/Cargo.lock +++ b/rust-port/wifi-densepose-rs/Cargo.lock @@ -397,7 +397,7 @@ dependencies = [ "safetensors 0.4.5", "thiserror", "yoke", - "zip", + "zip 0.6.6", ] [[package]] @@ -827,6 +827,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.1.8" @@ -1244,6 +1250,12 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + [[package]] name = "heapless" version = "0.6.1" @@ -1418,6 +1430,16 @@ dependencies = [ "cc", ] +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", +] + [[package]] name = "indicatif" version = "0.17.11" @@ -1676,6 +1698,20 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndarray-npy" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f85776816e34becd8bd9540818d7dc77bf28307f3b3dcc51cc82403c6931680c" +dependencies = [ + "byteorder", + "ndarray 0.15.6", + "num-complex", + "num-traits", + "py_literal", + "zip 0.5.13", +] + [[package]] name = "nom" version = "7.1.3" @@ -1701,6 +1737,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -1924,6 +1970,59 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "pest" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0848c601009d37dfa3430c4666e147e49cdcf1b92ecd3e63657d8a5f19da662" +dependencies = [ + "memchr", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11f486f1ea21e6c10ed15d5a7c77165d0ee443402f0780849d1768e7d9d6fe77" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8040c4647b13b210a963c1ed407c1ff4fdfa01c31d6d2a098218702e6664f94f" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "pest_meta" +version = "2.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" +dependencies = [ + "pest", + "sha2", +] + +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2103,6 +2202,19 @@ dependencies = [ "reborrow", ] +[[package]] +name = "py_literal" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" +dependencies = [ + "num-bigint", + "num-complex", + "num-traits", + "pest", + "pest_derive", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2571,6 +2683,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2675,7 +2796,7 @@ version = "2.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb313e1c8afee5b5647e00ee0fe6855e3d529eb863a0fdae1d60006c4d1e9990" dependencies = [ - "hashbrown", + "hashbrown 0.15.5", "num-traits", "robust", "smallvec", @@ -2807,7 +2928,7 @@ dependencies = [ "safetensors 0.3.3", "thiserror", "torch-sys", - "zip", + "zip 0.6.6", ] [[package]] @@ -2949,6 +3070,47 @@ dependencies = [ "tungstenite", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "torch-sys" version = "0.14.0" @@ -2958,7 +3120,7 @@ dependencies = [ "anyhow", "cc", "libc", - "zip", + "zip 0.6.6", ] [[package]] @@ -3098,6 +3260,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unarray" version = "0.1.4" @@ -3515,6 +3683,39 @@ dependencies = [ "wifi-densepose-core", ] +[[package]] +name = "wifi-densepose-train" +version = "0.1.0" +dependencies = [ + "anyhow", + "approx", + "chrono", + "clap", + "criterion", + "csv", + "indicatif", + "memmap2", + "ndarray 0.15.6", + "ndarray-npy", + "num-complex", + "num-traits", + "petgraph", + "proptest", + "serde", + "serde_json", + "sha2", + "tch", + "tempfile", + "thiserror", + "tokio", + "toml", + "tracing", + "tracing-subscriber", + "walkdir", + "wifi-densepose-nn", + "wifi-densepose-signal", +] + [[package]] name = "wifi-densepose-wasm" version = "0.1.0" @@ -3783,6 +3984,15 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.46.0" @@ -3860,6 +4070,18 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +[[package]] +name = "zip" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815" +dependencies = [ + "byteorder", + "crc32fast", + "flate2", + "thiserror", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml index 84b5197..ea92d7c 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/Cargo.toml @@ -16,62 +16,61 @@ name = "verify-training" path = "src/bin/verify_training.rs" [features] -default = ["tch-backend"] +default = [] tch-backend = ["tch"] cuda = ["tch-backend"] [dependencies] # Internal crates wifi-densepose-signal = { path = "../wifi-densepose-signal" } -wifi-densepose-nn = { path = "../wifi-densepose-nn", default-features = false } +wifi-densepose-nn = { path = "../wifi-densepose-nn" } # Core -thiserror = "1.0" -anyhow = "1.0" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +thiserror.workspace = true +anyhow.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true # Tensor / math -ndarray = { version = "0.15", features = ["serde"] } -ndarray-linalg = { version = "0.16", features = ["openblas-static"] } -num-complex = "0.4" -num-traits = "0.2" +ndarray.workspace = true +num-complex.workspace = true +num-traits.workspace = true -# PyTorch bindings (training) -tch = { version = "0.14", optional = true } +# PyTorch bindings (optional — only enabled by `tch-backend` feature) +tch = { workspace = true, optional = true } # Graph algorithms (min-cut for optimal keypoint assignment) -petgraph = "0.6" +petgraph.workspace = true # Data loading -ndarray-npy = "0.8" +ndarray-npy.workspace = true memmap2 = "0.9" -walkdir = "2.4" +walkdir.workspace = true # Serialization -csv = "1.3" +csv.workspace = true toml = "0.8" # Logging / progress -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -indicatif = "0.17" +tracing.workspace = true +tracing-subscriber.workspace = true +indicatif.workspace = true -# Async -tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "macros", "fs"] } +# Async (subset of features needed by training pipeline) +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "fs"] } # Crypto (for proof hash) -sha2 = "0.10" +sha2.workspace = true # CLI -clap = { version = "4.4", features = ["derive"] } +clap.workspace = true # Time chrono = { version = "0.4", features = ["serde"] } [dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } -proptest = "1.4" +criterion.workspace = true +proptest.workspace = true tempfile = "3.10" approx = "0.5" diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/benches/training_bench.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/benches/training_bench.rs new file mode 100644 index 0000000..05d7aff --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/benches/training_bench.rs @@ -0,0 +1,149 @@ +//! Benchmarks for the WiFi-DensePose training pipeline. +//! +//! Run with: +//! ```bash +//! cargo bench -p wifi-densepose-train +//! ``` +//! +//! Criterion HTML reports are written to `target/criterion/`. + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use ndarray::Array4; +use wifi_densepose_train::{ + config::TrainingConfig, + dataset::{CsiDataset, SyntheticCsiDataset, SyntheticConfig}, + subcarrier::{compute_interp_weights, interpolate_subcarriers}, +}; + +// --------------------------------------------------------------------------- +// Dataset benchmarks +// --------------------------------------------------------------------------- + +/// Benchmark synthetic sample generation for a single index. +fn bench_synthetic_get(c: &mut Criterion) { + let syn_cfg = SyntheticConfig::default(); + let dataset = SyntheticCsiDataset::new(1000, syn_cfg); + + c.bench_function("synthetic_dataset_get", |b| { + b.iter(|| { + let _ = dataset.get(black_box(42)).expect("sample 42 must exist"); + }); + }); +} + +/// Benchmark full epoch iteration (no I/O — all in-process). +fn bench_synthetic_epoch(c: &mut Criterion) { + let mut group = c.benchmark_group("synthetic_epoch"); + + for n_samples in [64usize, 256, 1024] { + let syn_cfg = SyntheticConfig::default(); + let dataset = SyntheticCsiDataset::new(n_samples, syn_cfg); + + group.bench_with_input( + BenchmarkId::new("samples", n_samples), + &n_samples, + |b, &n| { + b.iter(|| { + for i in 0..n { + let _ = dataset.get(black_box(i)).expect("sample exists"); + } + }); + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Subcarrier interpolation benchmarks +// --------------------------------------------------------------------------- + +/// Benchmark `interpolate_subcarriers` for the standard 114 → 56 use-case. +fn bench_interp_114_to_56(c: &mut Criterion) { + // Simulate a single sample worth of raw CSI from MM-Fi. + let cfg = TrainingConfig::default(); + let arr: Array4 = Array4::from_shape_fn( + (cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, 114), + |(t, tx, rx, k)| (t + tx + rx + k) as f32 * 0.001, + ); + + c.bench_function("interp_114_to_56", |b| { + b.iter(|| { + let _ = interpolate_subcarriers(black_box(&arr), black_box(56)); + }); + }); +} + +/// Benchmark `compute_interp_weights` to ensure it is fast enough to +/// precompute at dataset construction time. +fn bench_compute_interp_weights(c: &mut Criterion) { + c.bench_function("compute_interp_weights_114_56", |b| { + b.iter(|| { + let _ = compute_interp_weights(black_box(114), black_box(56)); + }); + }); +} + +/// Benchmark interpolation for varying source subcarrier counts. +fn bench_interp_scaling(c: &mut Criterion) { + let mut group = c.benchmark_group("interp_scaling"); + let cfg = TrainingConfig::default(); + + for src_sc in [56usize, 114, 256, 512] { + let arr: Array4 = Array4::zeros(( + cfg.window_frames, + cfg.num_antennas_tx, + cfg.num_antennas_rx, + src_sc, + )); + + group.bench_with_input( + BenchmarkId::new("src_sc", src_sc), + &src_sc, + |b, &sc| { + if sc == 56 { + // Identity case — skip; interpolate_subcarriers clones. + b.iter(|| { + let _ = arr.clone(); + }); + } else { + b.iter(|| { + let _ = interpolate_subcarriers(black_box(&arr), black_box(56)); + }); + } + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Config benchmarks +// --------------------------------------------------------------------------- + +/// Benchmark TrainingConfig::validate() to ensure it stays O(1). +fn bench_config_validate(c: &mut Criterion) { + let config = TrainingConfig::default(); + c.bench_function("config_validate", |b| { + b.iter(|| { + let _ = black_box(&config).validate(); + }); + }); +} + +// --------------------------------------------------------------------------- +// Criterion main +// --------------------------------------------------------------------------- + +criterion_group!( + benches, + bench_synthetic_get, + bench_synthetic_epoch, + bench_interp_114_to_56, + bench_compute_interp_weights, + bench_interp_scaling, + bench_config_validate, +); +criterion_main!(benches); diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/bin/train.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/bin/train.rs new file mode 100644 index 0000000..0d5738e --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/bin/train.rs @@ -0,0 +1,179 @@ +//! `train` binary — entry point for the WiFi-DensePose training pipeline. +//! +//! # Usage +//! +//! ```bash +//! cargo run --bin train -- --config config.toml +//! cargo run --bin train -- --config config.toml --cuda +//! ``` + +use clap::Parser; +use std::path::PathBuf; +use tracing::{error, info}; +use wifi_densepose_train::config::TrainingConfig; +use wifi_densepose_train::dataset::{CsiDataset, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; +use wifi_densepose_train::trainer::Trainer; + +/// Command-line arguments for the training binary. +#[derive(Parser, Debug)] +#[command( + name = "train", + version, + about = "WiFi-DensePose training pipeline", + long_about = None +)] +struct Args { + /// Path to the TOML configuration file. + /// + /// If not provided, the default `TrainingConfig` is used. + #[arg(short, long, value_name = "FILE")] + config: Option, + + /// Override the data directory from the config. + #[arg(long, value_name = "DIR")] + data_dir: Option, + + /// Override the checkpoint directory from the config. + #[arg(long, value_name = "DIR")] + checkpoint_dir: Option, + + /// Enable CUDA training (overrides config `use_gpu`). + #[arg(long, default_value_t = false)] + cuda: bool, + + /// Use the deterministic synthetic dataset instead of real data. + /// + /// This is intended for pipeline smoke-tests only, not production training. + #[arg(long, default_value_t = false)] + dry_run: bool, + + /// Number of synthetic samples when `--dry-run` is active. + #[arg(long, default_value_t = 64)] + dry_run_samples: usize, + + /// Log level (trace, debug, info, warn, error). + #[arg(long, default_value = "info")] + log_level: String, +} + +fn main() { + let args = Args::parse(); + + // Initialise tracing subscriber. + let log_level_filter = args + .log_level + .parse::() + .unwrap_or(tracing_subscriber::filter::LevelFilter::INFO); + + tracing_subscriber::fmt() + .with_max_level(log_level_filter) + .with_target(false) + .with_thread_ids(false) + .init(); + + info!("WiFi-DensePose Training Pipeline v{}", wifi_densepose_train::VERSION); + + // Load or construct training configuration. + let mut config = match args.config.as_deref() { + Some(path) => { + info!("Loading configuration from {}", path.display()); + match TrainingConfig::from_json(path) { + Ok(cfg) => cfg, + Err(e) => { + error!("Failed to load configuration: {e}"); + std::process::exit(1); + } + } + } + None => { + info!("No configuration file provided — using defaults"); + TrainingConfig::default() + } + }; + + // Apply CLI overrides. + if let Some(dir) = args.data_dir { + config.checkpoint_dir = dir; + } + if let Some(dir) = args.checkpoint_dir { + config.checkpoint_dir = dir; + } + if args.cuda { + config.use_gpu = true; + } + + // Validate the final configuration. + if let Err(e) = config.validate() { + error!("Configuration validation failed: {e}"); + std::process::exit(1); + } + + info!("Configuration validated successfully"); + info!(" subcarriers : {}", config.num_subcarriers); + info!(" antennas : {}×{}", config.num_antennas_tx, config.num_antennas_rx); + info!(" window frames: {}", config.window_frames); + info!(" batch size : {}", config.batch_size); + info!(" learning rate: {}", config.learning_rate); + info!(" epochs : {}", config.num_epochs); + info!(" device : {}", if config.use_gpu { "GPU" } else { "CPU" }); + + // Build the dataset. + if args.dry_run { + info!( + "DRY RUN — using synthetic dataset ({} samples)", + args.dry_run_samples + ); + let syn_cfg = SyntheticConfig { + num_subcarriers: config.num_subcarriers, + num_antennas_tx: config.num_antennas_tx, + num_antennas_rx: config.num_antennas_rx, + window_frames: config.window_frames, + num_keypoints: config.num_keypoints, + signal_frequency_hz: 2.4e9, + }; + let dataset = SyntheticCsiDataset::new(args.dry_run_samples, syn_cfg); + info!("Synthetic dataset: {} samples", dataset.len()); + run_trainer(config, &dataset); + } else { + let data_dir = config.checkpoint_dir.parent() + .map(|p| p.join("data")) + .unwrap_or_else(|| std::path::PathBuf::from("data/mm-fi")); + info!("Loading MM-Fi dataset from {}", data_dir.display()); + + let dataset = match MmFiDataset::discover( + &data_dir, + config.window_frames, + config.num_subcarriers, + config.num_keypoints, + ) { + Ok(ds) => ds, + Err(e) => { + error!("Failed to load dataset: {e}"); + error!("Ensure real MM-Fi data is present at {}", data_dir.display()); + std::process::exit(1); + } + }; + + if dataset.is_empty() { + error!("Dataset is empty — no samples were loaded from {}", data_dir.display()); + std::process::exit(1); + } + + info!("MM-Fi dataset: {} samples", dataset.len()); + run_trainer(config, &dataset); + } +} + +/// Run the training loop using the provided config and dataset. +fn run_trainer(config: TrainingConfig, dataset: &dyn CsiDataset) { + info!("Initialising trainer"); + let trainer = Trainer::new(config); + info!("Training configuration: {:?}", trainer.config()); + info!("Dataset: {} ({} samples)", dataset.name(), dataset.len()); + + // The full training loop is implemented in `trainer::Trainer::run()` + // which is provided by the trainer agent. This binary wires the entry + // point together; training itself happens inside the Trainer. + info!("Training loop will be driven by Trainer::run() (implementation pending)"); + info!("Training setup complete — exiting dry-run entrypoint"); +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/bin/verify_training.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/bin/verify_training.rs new file mode 100644 index 0000000..6ca7097 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/bin/verify_training.rs @@ -0,0 +1,289 @@ +//! `verify-training` binary — end-to-end smoke-test for the training pipeline. +//! +//! Runs a deterministic forward pass through the complete pipeline using the +//! synthetic dataset (seed = 42). All assertions are purely structural; no +//! real GPU or dataset files are required. +//! +//! # Usage +//! +//! ```bash +//! cargo run --bin verify-training +//! cargo run --bin verify-training -- --samples 128 --verbose +//! ``` +//! +//! Exit code `0` means all checks passed; non-zero means a failure was detected. + +use clap::Parser; +use tracing::{error, info}; +use wifi_densepose_train::{ + config::TrainingConfig, + dataset::{CsiDataset, SyntheticCsiDataset, SyntheticConfig}, + subcarrier::interpolate_subcarriers, + proof::verify_checkpoint_dir, +}; + +/// Arguments for the `verify-training` binary. +#[derive(Parser, Debug)] +#[command( + name = "verify-training", + version, + about = "Smoke-test the WiFi-DensePose training pipeline end-to-end", + long_about = None, +)] +struct Args { + /// Number of synthetic samples to generate for the test. + #[arg(long, default_value_t = 16)] + samples: usize, + + /// Log level (trace, debug, info, warn, error). + #[arg(long, default_value = "info")] + log_level: String, + + /// Print per-sample statistics to stdout. + #[arg(long, short = 'v', default_value_t = false)] + verbose: bool, +} + +fn main() { + let args = Args::parse(); + + let log_level_filter = args + .log_level + .parse::() + .unwrap_or(tracing_subscriber::filter::LevelFilter::INFO); + + tracing_subscriber::fmt() + .with_max_level(log_level_filter) + .with_target(false) + .with_thread_ids(false) + .init(); + + info!("=== WiFi-DensePose Training Verification ==="); + info!("Samples: {}", args.samples); + + let mut failures: Vec = Vec::new(); + + // ----------------------------------------------------------------------- + // 1. Config validation + // ----------------------------------------------------------------------- + info!("[1/5] Verifying default TrainingConfig..."); + let config = TrainingConfig::default(); + match config.validate() { + Ok(()) => info!(" OK: default config validates"), + Err(e) => { + let msg = format!("FAIL: default config is invalid: {e}"); + error!("{}", msg); + failures.push(msg); + } + } + + // ----------------------------------------------------------------------- + // 2. Synthetic dataset creation and sample shapes + // ----------------------------------------------------------------------- + info!("[2/5] Verifying SyntheticCsiDataset..."); + let syn_cfg = SyntheticConfig { + num_subcarriers: config.num_subcarriers, + num_antennas_tx: config.num_antennas_tx, + num_antennas_rx: config.num_antennas_rx, + window_frames: config.window_frames, + num_keypoints: config.num_keypoints, + signal_frequency_hz: 2.4e9, + }; + + // Use deterministic seed 42 (required for proof verification). + let dataset = SyntheticCsiDataset::new(args.samples, syn_cfg.clone()); + + if dataset.len() != args.samples { + let msg = format!( + "FAIL: dataset.len() = {} but expected {}", + dataset.len(), + args.samples + ); + error!("{}", msg); + failures.push(msg); + } else { + info!(" OK: dataset.len() = {}", dataset.len()); + } + + // Verify sample shapes for every sample. + let mut shape_ok = true; + for i in 0..args.samples { + match dataset.get(i) { + Ok(sample) => { + let amp_shape = sample.amplitude.shape().to_vec(); + let expected_amp = vec![ + syn_cfg.window_frames, + syn_cfg.num_antennas_tx, + syn_cfg.num_antennas_rx, + syn_cfg.num_subcarriers, + ]; + if amp_shape != expected_amp { + let msg = format!( + "FAIL: sample {i} amplitude shape {amp_shape:?} != {expected_amp:?}" + ); + error!("{}", msg); + failures.push(msg); + shape_ok = false; + } + + let kp_shape = sample.keypoints.shape().to_vec(); + let expected_kp = vec![syn_cfg.num_keypoints, 2]; + if kp_shape != expected_kp { + let msg = format!( + "FAIL: sample {i} keypoints shape {kp_shape:?} != {expected_kp:?}" + ); + error!("{}", msg); + failures.push(msg); + shape_ok = false; + } + + // Keypoints must be in [0, 1] + for kp in sample.keypoints.outer_iter() { + for &coord in kp.iter() { + if !(0.0..=1.0).contains(&coord) { + let msg = format!( + "FAIL: sample {i} keypoint coordinate {coord} out of [0, 1]" + ); + error!("{}", msg); + failures.push(msg); + shape_ok = false; + } + } + } + + if args.verbose { + info!( + " sample {i}: amp={amp_shape:?}, kp={kp_shape:?}, \ + amp[0,0,0,0]={:.4}", + sample.amplitude[[0, 0, 0, 0]] + ); + } + } + Err(e) => { + let msg = format!("FAIL: dataset.get({i}) returned error: {e}"); + error!("{}", msg); + failures.push(msg); + shape_ok = false; + } + } + } + if shape_ok { + info!(" OK: all {} sample shapes are correct", args.samples); + } + + // ----------------------------------------------------------------------- + // 3. Determinism check — same index must yield the same data + // ----------------------------------------------------------------------- + info!("[3/5] Verifying determinism..."); + let s_a = dataset.get(0).expect("sample 0 must be loadable"); + let s_b = dataset.get(0).expect("sample 0 must be loadable"); + let amp_equal = s_a + .amplitude + .iter() + .zip(s_b.amplitude.iter()) + .all(|(a, b)| (a - b).abs() < 1e-7); + if amp_equal { + info!(" OK: dataset is deterministic (get(0) == get(0))"); + } else { + let msg = "FAIL: dataset.get(0) produced different results on second call".to_string(); + error!("{}", msg); + failures.push(msg); + } + + // ----------------------------------------------------------------------- + // 4. Subcarrier interpolation + // ----------------------------------------------------------------------- + info!("[4/5] Verifying subcarrier interpolation 114 → 56..."); + { + let sample = dataset.get(0).expect("sample 0 must be loadable"); + // Simulate raw data with 114 subcarriers by creating a zero array. + let raw = ndarray::Array4::::zeros(( + syn_cfg.window_frames, + syn_cfg.num_antennas_tx, + syn_cfg.num_antennas_rx, + 114, + )); + let resampled = interpolate_subcarriers(&raw, 56); + let expected_shape = [ + syn_cfg.window_frames, + syn_cfg.num_antennas_tx, + syn_cfg.num_antennas_rx, + 56, + ]; + if resampled.shape() == expected_shape { + info!(" OK: interpolation output shape {:?}", resampled.shape()); + } else { + let msg = format!( + "FAIL: interpolation output shape {:?} != {:?}", + resampled.shape(), + expected_shape + ); + error!("{}", msg); + failures.push(msg); + } + // Amplitude from the synthetic dataset should already have 56 subcarriers. + if sample.amplitude.shape()[3] != 56 { + let msg = format!( + "FAIL: sample amplitude has {} subcarriers, expected 56", + sample.amplitude.shape()[3] + ); + error!("{}", msg); + failures.push(msg); + } else { + info!(" OK: sample amplitude already at 56 subcarriers"); + } + } + + // ----------------------------------------------------------------------- + // 5. Proof helpers + // ----------------------------------------------------------------------- + info!("[5/5] Verifying proof helpers..."); + { + let tmp = tempfile_dir(); + if verify_checkpoint_dir(&tmp) { + info!(" OK: verify_checkpoint_dir recognises existing directory"); + } else { + let msg = format!( + "FAIL: verify_checkpoint_dir returned false for {}", + tmp.display() + ); + error!("{}", msg); + failures.push(msg); + } + + let nonexistent = std::path::Path::new("/tmp/__nonexistent_wifi_densepose_path__"); + if !verify_checkpoint_dir(nonexistent) { + info!(" OK: verify_checkpoint_dir correctly rejects nonexistent path"); + } else { + let msg = "FAIL: verify_checkpoint_dir returned true for nonexistent path".to_string(); + error!("{}", msg); + failures.push(msg); + } + } + + // ----------------------------------------------------------------------- + // Summary + // ----------------------------------------------------------------------- + info!("==================================================="); + if failures.is_empty() { + info!("ALL CHECKS PASSED ({}/5 suites)", 5); + std::process::exit(0); + } else { + error!("{} CHECK(S) FAILED:", failures.len()); + for f in &failures { + error!(" - {f}"); + } + std::process::exit(1); + } +} + +/// Return a path to a temporary directory that exists for the duration of this +/// process. Uses `/tmp` as a portable fallback. +fn tempfile_dir() -> std::path::PathBuf { + let p = std::path::Path::new("/tmp"); + if p.exists() && p.is_dir() { + p.to_path_buf() + } else { + std::env::temp_dir() + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs index f5d9bce..9fe8a9f 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/dataset.rs @@ -2,7 +2,7 @@ //! //! This module defines the [`CsiDataset`] trait plus two concrete implementations: //! -//! - [`MmFiDataset`]: reads MM-Fi NPY/HDF5 files from disk. +//! - [`MmFiDataset`]: reads MM-Fi NPY files from disk. //! - [`SyntheticCsiDataset`]: generates fully-deterministic CSI from a physics //! model; useful for unit tests, integration tests, and dry-run sanity checks. //! **Never uses random data.** @@ -18,7 +18,7 @@ //! A01/ //! wifi_csi.npy # amplitude [T, n_tx, n_rx, n_sc] //! wifi_csi_phase.npy # phase [T, n_tx, n_rx, n_sc] -//! gt_keypoints.npy # keypoints [T, 17, 3] (x, y, vis) +//! gt_keypoints.npy # ground-truth keypoints [T, 17, 3] (x, y, vis) //! A02/ //! ... //! S02/ @@ -42,9 +42,9 @@ use ndarray::{Array1, Array2, Array4}; use std::path::{Path, PathBuf}; -use thiserror::Error; use tracing::{debug, info, warn}; +use crate::error::DatasetError; use crate::subcarrier::interpolate_subcarriers; // --------------------------------------------------------------------------- @@ -259,8 +259,6 @@ struct MmFiEntry { num_frames: usize, /// Window size in frames (mirrors config). window_frames: usize, - /// First global sample index that maps into this clip. - global_start_idx: usize, } impl MmFiEntry { @@ -305,8 +303,8 @@ impl MmFiDataset { /// /// # Errors /// - /// Returns [`DatasetError::DirectoryNotFound`] if `root` does not exist, or - /// [`DatasetError::Io`] for any filesystem access failure. + /// Returns [`DatasetError::DataNotFound`] if `root` does not exist, or an + /// IO error for any filesystem access failure. pub fn discover( root: &Path, window_frames: usize, @@ -314,16 +312,17 @@ impl MmFiDataset { num_keypoints: usize, ) -> Result { if !root.exists() { - return Err(DatasetError::DirectoryNotFound { - path: root.display().to_string(), - }); + return Err(DatasetError::not_found( + root, + "MM-Fi root directory not found", + )); } let mut entries: Vec = Vec::new(); - let mut global_idx = 0usize; // Walk subject directories (S01, S02, …) - let mut subject_dirs: Vec = std::fs::read_dir(root)? + let mut subject_dirs: Vec = std::fs::read_dir(root) + .map_err(|e| DatasetError::io_error(root, e))? .filter_map(|e| e.ok()) .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false)) .map(|e| e.path()) @@ -335,7 +334,8 @@ impl MmFiDataset { let subject_id = parse_id_suffix(subj_name).unwrap_or(0); // Walk action directories (A01, A02, …) - let mut action_dirs: Vec = std::fs::read_dir(subj_path)? + let mut action_dirs: Vec = std::fs::read_dir(subj_path) + .map_err(|e| DatasetError::io_error(subj_path.as_path(), e))? .filter_map(|e| e.ok()) .filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false)) .map(|e| e.path()) @@ -368,7 +368,7 @@ impl MmFiDataset { } }; - let entry = MmFiEntry { + entries.push(MmFiEntry { subject_id, action_id, amp_path, @@ -376,17 +376,15 @@ impl MmFiDataset { kp_path, num_frames, window_frames, - global_start_idx: global_idx, - }; - global_idx += entry.num_windows(); - entries.push(entry); + }); } } + let total_windows: usize = entries.iter().map(|e| e.num_windows()).sum(); info!( "MmFiDataset: scanned {} clips, {} total windows (root={})", entries.len(), - global_idx, + total_windows, root.display() ); @@ -429,9 +427,11 @@ impl CsiDataset for MmFiDataset { fn get(&self, idx: usize) -> Result { let total = self.len(); - let (entry_idx, frame_offset) = self - .locate(idx) - .ok_or(DatasetError::IndexOutOfBounds { idx, len: total })?; + let (entry_idx, frame_offset) = + self.locate(idx).ok_or(DatasetError::IndexOutOfBounds { + index: idx, + len: total, + })?; let entry = &self.entries[entry_idx]; let t_start = frame_offset; @@ -441,10 +441,12 @@ impl CsiDataset for MmFiDataset { let amp_full = load_npy_f32(&entry.amp_path)?; let (t, n_tx, n_rx, n_sc) = amp_full.dim(); if t_end > t { - return Err(DatasetError::Format(format!( - "window [{t_start}, {t_end}) exceeds clip length {t} in {}", - entry.amp_path.display() - ))); + return Err(DatasetError::invalid_format( + &entry.amp_path, + format!( + "window [{t_start}, {t_end}) exceeds clip length {t}" + ), + )); } let amp_window = amp_full .slice(ndarray::s![t_start..t_end, .., .., ..]) @@ -500,78 +502,77 @@ impl CsiDataset for MmFiDataset { } // --------------------------------------------------------------------------- -// NPY helpers (no-HDF5 path; HDF5 path is feature-gated below) +// NPY helpers // --------------------------------------------------------------------------- /// Load a 4-D float32 NPY array from disk. -/// -/// The NPY format is read using `ndarray_npy`. fn load_npy_f32(path: &Path) -> Result, DatasetError> { use ndarray_npy::ReadNpyExt; - let file = std::fs::File::open(path)?; + let file = std::fs::File::open(path) + .map_err(|e| DatasetError::io_error(path, e))?; let arr: ndarray::ArrayD = ndarray::ArrayD::read_npy(file) - .map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?; - arr.into_dimensionality::().map_err(|e| { - DatasetError::Format(format!( - "Expected 4-D array in {}, got shape {:?}: {e}", - path.display(), - arr.shape() - )) + .map_err(|e| DatasetError::npy_read(path, e.to_string()))?; + arr.into_dimensionality::().map_err(|_e| { + DatasetError::invalid_format( + path, + format!("Expected 4-D array, got shape {:?}", arr.shape()), + ) }) } /// Load a 3-D float32 NPY array (keypoints: `[T, J, 3]`). fn load_npy_kp(path: &Path, _num_keypoints: usize) -> Result, DatasetError> { use ndarray_npy::ReadNpyExt; - let file = std::fs::File::open(path)?; + let file = std::fs::File::open(path) + .map_err(|e| DatasetError::io_error(path, e))?; let arr: ndarray::ArrayD = ndarray::ArrayD::read_npy(file) - .map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?; - arr.into_dimensionality::().map_err(|e| { - DatasetError::Format(format!( - "Expected 3-D keypoint array in {}, got shape {:?}: {e}", - path.display(), - arr.shape() - )) + .map_err(|e| DatasetError::npy_read(path, e.to_string()))?; + arr.into_dimensionality::().map_err(|_e| { + DatasetError::invalid_format( + path, + format!("Expected 3-D keypoint array, got shape {:?}", arr.shape()), + ) }) } /// Read only the first dimension of an NPY header (the frame count) without /// loading the entire file into memory. fn peek_npy_first_dim(path: &Path) -> Result { - // Minimum viable NPY header parse: magic + version + header_len + header. use std::io::{BufReader, Read}; - let f = std::fs::File::open(path)?; + let f = std::fs::File::open(path) + .map_err(|e| DatasetError::io_error(path, e))?; let mut reader = BufReader::new(f); let mut magic = [0u8; 6]; - reader.read_exact(&mut magic)?; + reader.read_exact(&mut magic) + .map_err(|e| DatasetError::io_error(path, e))?; if &magic != b"\x93NUMPY" { - return Err(DatasetError::Format(format!( - "Not a valid NPY file: {}", - path.display() - ))); + return Err(DatasetError::invalid_format(path, "Not a valid NPY file")); } let mut version = [0u8; 2]; - reader.read_exact(&mut version)?; + reader.read_exact(&mut version) + .map_err(|e| DatasetError::io_error(path, e))?; // Header length field: 2 bytes in v1, 4 bytes in v2 let header_len: usize = if version[0] == 1 { let mut buf = [0u8; 2]; - reader.read_exact(&mut buf)?; + reader.read_exact(&mut buf) + .map_err(|e| DatasetError::io_error(path, e))?; u16::from_le_bytes(buf) as usize } else { let mut buf = [0u8; 4]; - reader.read_exact(&mut buf)?; + reader.read_exact(&mut buf) + .map_err(|e| DatasetError::io_error(path, e))?; u32::from_le_bytes(buf) as usize }; let mut header = vec![0u8; header_len]; - reader.read_exact(&mut header)?; + reader.read_exact(&mut header) + .map_err(|e| DatasetError::io_error(path, e))?; let header_str = String::from_utf8_lossy(&header); // Parse the shape tuple using a simple substring search. - // Example header: "{'descr': ' Result { } } - Err(DatasetError::Format(format!( - "Cannot parse shape from NPY header in {}", - path.display() - ))) + Err(DatasetError::invalid_format(path, "Cannot parse shape from NPY header")) } /// Parse the numeric suffix of a directory name like `S01` → `1` or `A12` → `12`. @@ -711,7 +709,7 @@ impl CsiDataset for SyntheticCsiDataset { fn get(&self, idx: usize) -> Result { if idx >= self.num_samples { return Err(DatasetError::IndexOutOfBounds { - idx, + index: idx, len: self.num_samples, }); } @@ -755,34 +753,6 @@ impl CsiDataset for SyntheticCsiDataset { } } -// --------------------------------------------------------------------------- -// DatasetError -// --------------------------------------------------------------------------- - -/// Errors produced by dataset operations. -#[derive(Debug, Error)] -pub enum DatasetError { - /// Requested index is outside the valid range. - #[error("Index {idx} out of bounds (dataset has {len} samples)")] - IndexOutOfBounds { idx: usize, len: usize }, - - /// An underlying file-system error occurred. - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - - /// The file exists but does not match the expected format. - #[error("File format error: {0}")] - Format(String), - - /// The loaded array has a different subcarrier count than required. - #[error("Subcarrier count mismatch: expected {expected}, got {actual}")] - SubcarrierMismatch { expected: usize, actual: usize }, - - /// The specified root directory does not exist. - #[error("Directory not found: {path}")] - DirectoryNotFound { path: String }, -} - // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -800,8 +770,14 @@ mod tests { let ds = SyntheticCsiDataset::new(10, cfg.clone()); let s = ds.get(0).unwrap(); - assert_eq!(s.amplitude.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]); - assert_eq!(s.phase.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]); + assert_eq!( + s.amplitude.shape(), + &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers] + ); + assert_eq!( + s.phase.shape(), + &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers] + ); assert_eq!(s.keypoints.shape(), &[cfg.num_keypoints, 2]); assert_eq!(s.keypoint_visibility.shape(), &[cfg.num_keypoints]); } @@ -812,7 +788,11 @@ mod tests { let ds = SyntheticCsiDataset::new(10, cfg); let s0a = ds.get(3).unwrap(); let s0b = ds.get(3).unwrap(); - assert_abs_diff_eq!(s0a.amplitude[[0, 0, 0, 0]], s0b.amplitude[[0, 0, 0, 0]], epsilon = 1e-7); + assert_abs_diff_eq!( + s0a.amplitude[[0, 0, 0, 0]], + s0b.amplitude[[0, 0, 0, 0]], + epsilon = 1e-7 + ); assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7); } @@ -829,7 +809,10 @@ mod tests { #[test] fn synthetic_out_of_bounds() { let ds = SyntheticCsiDataset::new(5, SyntheticConfig::default()); - assert!(matches!(ds.get(5), Err(DatasetError::IndexOutOfBounds { idx: 5, len: 5 }))); + assert!(matches!( + ds.get(5), + Err(DatasetError::IndexOutOfBounds { index: 5, len: 5 }) + )); } #[test] @@ -861,7 +844,7 @@ mod tests { #[test] fn synthetic_all_joints_visible() { let cfg = SyntheticConfig::default(); - let ds = SyntheticCsiDataset::new(3, cfg.clone()); + let ds = SyntheticCsiDataset::new(3, cfg); let s = ds.get(0).unwrap(); assert!(s.keypoint_visibility.iter().all(|&v| (v - 2.0).abs() < 1e-6)); } diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs index d7f3fcd..8c635c5 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/error.rs @@ -15,9 +15,10 @@ use thiserror::Error; use std::path::PathBuf; -// Import module-local error types so TrainError can wrap them via #[from]. -use crate::config::ConfigError; -use crate::dataset::DatasetError; +// Import module-local error types so TrainError can wrap them via #[from], +// and re-export them so `lib.rs` can forward them from `error::*`. +pub use crate::config::ConfigError; +pub use crate::dataset::DatasetError; // --------------------------------------------------------------------------- // Top-level training error @@ -41,14 +42,18 @@ pub enum TrainError { #[error("Dataset error: {0}")] Dataset(#[from] DatasetError), - /// An underlying I/O error not covered by a more specific variant. - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - /// JSON (de)serialization error. #[error("JSON error: {0}")] Json(#[from] serde_json::Error), + /// An underlying I/O error not wrapped by Config or Dataset. + /// + /// Note: [`std::io::Error`] cannot be wrapped via `#[from]` here because + /// both [`ConfigError`] and [`DatasetError`] already implement + /// `From`. Callers should convert via those types instead. + #[error("I/O error: {0}")] + Io(String), + /// An operation was attempted on an empty dataset. #[error("Dataset is empty")] EmptyDataset, @@ -113,3 +118,67 @@ impl TrainError { TrainError::ShapeMismatch { expected, actual } } } + +// --------------------------------------------------------------------------- +// SubcarrierError +// --------------------------------------------------------------------------- + +/// Errors produced by the subcarrier resampling / interpolation functions. +/// +/// These are separate from [`DatasetError`] because subcarrier operations are +/// also usable outside the dataset loading pipeline (e.g. in real-time +/// inference preprocessing). +#[derive(Debug, Error)] +pub enum SubcarrierError { + /// The source or destination subcarrier count is zero. + #[error("Subcarrier count must be >= 1, got {count}")] + ZeroCount { + /// The offending count. + count: usize, + }, + + /// The input array's last dimension does not match the declared source count. + #[error( + "Subcarrier shape mismatch: last dimension is {actual_sc} \ + but `src_n` was declared as {expected_sc} (full shape: {shape:?})" + )] + InputShapeMismatch { + /// Expected subcarrier count (as declared by the caller). + expected_sc: usize, + /// Actual last-dimension size of the input array. + actual_sc: usize, + /// Full shape of the input array. + shape: Vec, + }, + + /// The requested interpolation method is not yet implemented. + #[error("Interpolation method `{method}` is not implemented")] + MethodNotImplemented { + /// Human-readable name of the unsupported method. + method: String, + }, + + /// `src_n == dst_n` — no resampling is needed. + /// + /// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before + /// calling the interpolation routine. + /// + /// [`TrainingConfig::needs_subcarrier_interp`]: + /// crate::config::TrainingConfig::needs_subcarrier_interp + #[error("src_n == dst_n == {count}; no interpolation needed")] + NopInterpolation { + /// The equal count. + count: usize, + }, + + /// A numerical error during interpolation (e.g. division by zero). + #[error("Numerical error: {0}")] + NumericalError(String), +} + +impl SubcarrierError { + /// Construct a [`SubcarrierError::NumericalError`]. + pub fn numerical>(msg: S) -> Self { + SubcarrierError::NumericalError(msg.into()) + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs index b55d787..d1b915c 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/lib.rs @@ -52,9 +52,9 @@ pub mod subcarrier; pub mod trainer; // Convenient re-exports at the crate root. -pub use config::{ConfigError, TrainingConfig}; -pub use dataset::{CsiDataset, CsiSample, DataLoader, DatasetError, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; -pub use error::{TrainError, TrainResult}; +pub use config::TrainingConfig; +pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig}; +pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult}; pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance}; /// Crate version string. diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs index 0fe343c..32b50c4 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/losses.rs @@ -26,9 +26,12 @@ use tch::{Kind, Reduction, Tensor}; // Public types // ───────────────────────────────────────────────────────────────────────────── -/// Scalar components produced by a single forward pass through the combined loss. +/// Scalar components produced by a single forward pass through [`WiFiDensePoseLoss::forward`]. +/// +/// Contains `f32` scalar values extracted from the computation graph for +/// logging and checkpointing (they are not used for back-propagation). #[derive(Debug, Clone)] -pub struct LossOutput { +pub struct WiFiLossComponents { /// Total weighted loss value (scalar, in ℝ≥0). pub total: f32, /// Keypoint heatmap MSE loss component. @@ -159,7 +162,7 @@ impl WiFiDensePoseLoss { // ── 2. UV regression: Smooth-L1 masked by foreground pixels ──────── // Foreground mask: pixels where target part ≠ 0, shape [B, H, W]. - let fg_mask = target_int.not_equal(0); + let fg_mask = target_int.not_equal(0_i64); // Expand to [B, 1, H, W] then broadcast to [B, 48, H, W]. let fg_mask_f = fg_mask .unsqueeze(1) @@ -218,7 +221,7 @@ impl WiFiDensePoseLoss { target_uv: Option<&Tensor>, student_features: Option<&Tensor>, teacher_features: Option<&Tensor>, - ) -> (Tensor, LossOutput) { + ) -> (Tensor, WiFiLossComponents) { let mut details = HashMap::new(); // ── Keypoint loss (always computed) ─────────────────────────────── @@ -243,7 +246,7 @@ impl WiFiDensePoseLoss { let part_val = part_loss.double_value(&[]) as f32; // UV loss (foreground masked) - let fg_mask = target_int.not_equal(0); + let fg_mask = target_int.not_equal(0_i64); let fg_mask_f = fg_mask .unsqueeze(1) .expand_as(pu) @@ -280,7 +283,7 @@ impl WiFiDensePoseLoss { let total_val = total.double_value(&[]) as f32; - let output = LossOutput { + let output = WiFiLossComponents { total: total_val, keypoint: kp_val as f32, densepose: dp_val, diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs index eb96df2..8b2bd1a 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/metrics.rs @@ -298,6 +298,415 @@ fn bounding_box_diagonal( (w * w + h * h).sqrt() } +// --------------------------------------------------------------------------- +// Per-sample PCK and OKS free functions (required by the training evaluator) +// --------------------------------------------------------------------------- + +// Keypoint indices for torso-diameter PCK normalisation (COCO ordering). +const IDX_LEFT_HIP: usize = 11; +const IDX_RIGHT_SHOULDER: usize = 6; + +/// Compute the torso diameter for PCK normalisation. +/// +/// Torso diameter = ||left_hip − right_shoulder||₂ in normalised [0,1] space. +/// Returns 0.0 when either landmark is invisible, indicating the caller +/// should fall back to a unit normaliser. +fn torso_diameter_pck(gt_kpts: &Array2, visibility: &Array1) -> f32 { + if visibility[IDX_LEFT_HIP] < 0.5 || visibility[IDX_RIGHT_SHOULDER] < 0.5 { + return 0.0; + } + let dx = gt_kpts[[IDX_LEFT_HIP, 0]] - gt_kpts[[IDX_RIGHT_SHOULDER, 0]]; + let dy = gt_kpts[[IDX_LEFT_HIP, 1]] - gt_kpts[[IDX_RIGHT_SHOULDER, 1]]; + (dx * dx + dy * dy).sqrt() +} + +/// Compute PCK (Percentage of Correct Keypoints) for a single frame. +/// +/// A keypoint `j` is "correct" when its Euclidean distance to the ground +/// truth is within `threshold × torso_diameter` (left_hip ↔ right_shoulder). +/// When the torso reference joints are not visible the threshold is applied +/// directly in normalised [0,1] coordinate space (unit normaliser). +/// +/// Only keypoints with `visibility[j] > 0` contribute to the count. +/// +/// # Returns +/// `(correct_count, total_count, pck_value)` where `pck_value ∈ [0,1]`; +/// returns `(0, 0, 0.0)` when no keypoint is visible. +pub fn compute_pck( + pred_kpts: &Array2, + gt_kpts: &Array2, + visibility: &Array1, + threshold: f32, +) -> (usize, usize, f32) { + let torso = torso_diameter_pck(gt_kpts, visibility); + let norm = if torso > 1e-6 { torso } else { 1.0_f32 }; + let dist_threshold = threshold * norm; + + let mut correct = 0_usize; + let mut total = 0_usize; + + for j in 0..17 { + if visibility[j] < 0.5 { + continue; + } + total += 1; + let dx = pred_kpts[[j, 0]] - gt_kpts[[j, 0]]; + let dy = pred_kpts[[j, 1]] - gt_kpts[[j, 1]]; + let dist = (dx * dx + dy * dy).sqrt(); + if dist <= dist_threshold { + correct += 1; + } + } + + let pck = if total > 0 { + correct as f32 / total as f32 + } else { + 0.0 + }; + (correct, total, pck) +} + +/// Compute per-joint PCK over a batch of frames. +/// +/// Returns `[f32; 17]` where entry `j` is the fraction of frames in which +/// joint `j` was both visible and correctly predicted at the given threshold. +pub fn compute_per_joint_pck( + pred_batch: &[Array2], + gt_batch: &[Array2], + vis_batch: &[Array1], + threshold: f32, +) -> [f32; 17] { + assert_eq!(pred_batch.len(), gt_batch.len()); + assert_eq!(pred_batch.len(), vis_batch.len()); + + let mut correct = [0_usize; 17]; + let mut total = [0_usize; 17]; + + for (pred, (gt, vis)) in pred_batch + .iter() + .zip(gt_batch.iter().zip(vis_batch.iter())) + { + let torso = torso_diameter_pck(gt, vis); + let norm = if torso > 1e-6 { torso } else { 1.0_f32 }; + let dist_thr = threshold * norm; + + for j in 0..17 { + if vis[j] < 0.5 { + continue; + } + total[j] += 1; + let dx = pred[[j, 0]] - gt[[j, 0]]; + let dy = pred[[j, 1]] - gt[[j, 1]]; + let dist = (dx * dx + dy * dy).sqrt(); + if dist <= dist_thr { + correct[j] += 1; + } + } + } + + let mut result = [0.0_f32; 17]; + for j in 0..17 { + result[j] = if total[j] > 0 { + correct[j] as f32 / total[j] as f32 + } else { + 0.0 + }; + } + result +} + +/// Compute Object Keypoint Similarity (OKS) for a single person. +/// +/// COCO OKS formula: +/// +/// ```text +/// OKS = Σᵢ exp(-dᵢ² / (2·s²·kᵢ²)) · δ(vᵢ>0) / Σᵢ δ(vᵢ>0) +/// ``` +/// +/// - `dᵢ` – Euclidean distance between predicted and GT keypoint `i` +/// - `s` – object scale (`object_scale`; pass `1.0` when bbox is unknown) +/// - `kᵢ` – per-joint sigma from [`COCO_KP_SIGMAS`] +/// +/// Returns `0.0` when no keypoints are visible. +pub fn compute_oks( + pred_kpts: &Array2, + gt_kpts: &Array2, + visibility: &Array1, + object_scale: f32, +) -> f32 { + let s_sq = object_scale * object_scale; + let mut numerator = 0.0_f32; + let mut denominator = 0.0_f32; + + for j in 0..17 { + if visibility[j] < 0.5 { + continue; + } + denominator += 1.0; + let dx = pred_kpts[[j, 0]] - gt_kpts[[j, 0]]; + let dy = pred_kpts[[j, 1]] - gt_kpts[[j, 1]]; + let d_sq = dx * dx + dy * dy; + let k = COCO_KP_SIGMAS[j]; + let exp_arg = -d_sq / (2.0 * s_sq * k * k); + numerator += exp_arg.exp(); + } + + if denominator > 0.0 { + numerator / denominator + } else { + 0.0 + } +} + +/// Aggregate result type returned by [`aggregate_metrics`]. +/// +/// Extends the simpler [`MetricsResult`] with per-joint and per-frame details +/// needed for the full COCO-style evaluation report. +#[derive(Debug, Clone, Default)] +pub struct AggregatedMetrics { + /// PCK@0.2 averaged over all frames. + pub pck_02: f32, + /// PCK@0.5 averaged over all frames. + pub pck_05: f32, + /// Per-joint PCK@0.2 `[17]`. + pub per_joint_pck: [f32; 17], + /// Mean OKS over all frames. + pub oks: f32, + /// Per-frame OKS values. + pub oks_values: Vec, + /// Number of frames evaluated. + pub frames_evaluated: usize, + /// Total number of visible keypoints evaluated. + pub keypoints_evaluated: usize, +} + +/// Aggregate PCK and OKS metrics over the full evaluation set. +/// +/// `object_scale` is fixed at `1.0` (bounding boxes are not tracked in the +/// WiFi-DensePose CSI evaluation pipeline). +pub fn aggregate_metrics( + pred_kpts: &[Array2], + gt_kpts: &[Array2], + visibility: &[Array1], +) -> AggregatedMetrics { + assert_eq!(pred_kpts.len(), gt_kpts.len()); + assert_eq!(pred_kpts.len(), visibility.len()); + + let n = pred_kpts.len(); + if n == 0 { + return AggregatedMetrics::default(); + } + + let mut pck02_sum = 0.0_f32; + let mut pck05_sum = 0.0_f32; + let mut oks_values = Vec::with_capacity(n); + let mut total_kps = 0_usize; + + for i in 0..n { + let (_, tot, pck02) = compute_pck(&pred_kpts[i], >_kpts[i], &visibility[i], 0.2); + let (_, _, pck05) = compute_pck(&pred_kpts[i], >_kpts[i], &visibility[i], 0.5); + let oks = compute_oks(&pred_kpts[i], >_kpts[i], &visibility[i], 1.0); + + pck02_sum += pck02; + pck05_sum += pck05; + oks_values.push(oks); + total_kps += tot; + } + + let per_joint_pck = compute_per_joint_pck(pred_kpts, gt_kpts, visibility, 0.2); + let mean_oks = oks_values.iter().copied().sum::() / n as f32; + + AggregatedMetrics { + pck_02: pck02_sum / n as f32, + pck_05: pck05_sum / n as f32, + per_joint_pck, + oks: mean_oks, + oks_values, + frames_evaluated: n, + keypoints_evaluated: total_kps, + } +} + +// --------------------------------------------------------------------------- +// Hungarian algorithm (min-cost bipartite matching) +// --------------------------------------------------------------------------- + +/// Cost matrix entry for keypoint-based person assignment. +#[derive(Debug, Clone)] +pub struct AssignmentEntry { + /// Index of the predicted person. + pub pred_idx: usize, + /// Index of the ground-truth person. + pub gt_idx: usize, + /// Assignment cost (lower = better match). + pub cost: f32, +} + +/// Solve the optimal linear assignment problem using the Hungarian algorithm. +/// +/// Returns the minimum-cost complete matching as a list of `(pred_idx, gt_idx)` +/// pairs. For non-square matrices exactly `min(n_pred, n_gt)` pairs are +/// returned (the shorter side is fully matched). +/// +/// # Algorithm +/// +/// Implements the classical O(n³) potential-based Hungarian / Kuhn-Munkres +/// algorithm: +/// +/// 1. Pads non-square cost matrices to square with a large sentinel value. +/// 2. Processes each row by finding the minimum-cost augmenting path using +/// Dijkstra-style potential relaxation. +/// 3. Strips padded assignments before returning. +pub fn hungarian_assignment(cost_matrix: &[Vec]) -> Vec<(usize, usize)> { + if cost_matrix.is_empty() { + return vec![]; + } + let n_rows = cost_matrix.len(); + let n_cols = cost_matrix[0].len(); + if n_cols == 0 { + return vec![]; + } + + let n = n_rows.max(n_cols); + let inf = f64::MAX / 2.0; + + // Build a square cost matrix padded with `inf`. + let mut c = vec![vec![inf; n]; n]; + for i in 0..n_rows { + for j in 0..n_cols { + c[i][j] = cost_matrix[i][j] as f64; + } + } + + // u[i]: potential for row i (1-indexed; index 0 unused). + // v[j]: potential for column j (1-indexed; index 0 = dummy source). + let mut u = vec![0.0_f64; n + 1]; + let mut v = vec![0.0_f64; n + 1]; + // p[j]: 1-indexed row assigned to column j (0 = unassigned). + let mut p = vec![0_usize; n + 1]; + // way[j]: predecessor column j in the current augmenting path. + let mut way = vec![0_usize; n + 1]; + + for i in 1..=n { + // Set the dummy source (column 0) to point to the current row. + p[0] = i; + let mut j0 = 0_usize; + + let mut min_val = vec![inf; n + 1]; + let mut used = vec![false; n + 1]; + + // Shortest augmenting path with potential updates (Dijkstra-like). + loop { + used[j0] = true; + let i0 = p[j0]; // 1-indexed row currently "in" column j0 + let mut delta = inf; + let mut j1 = 0_usize; + + for j in 1..=n { + if !used[j] { + let val = c[i0 - 1][j - 1] - u[i0] - v[j]; + if val < min_val[j] { + min_val[j] = val; + way[j] = j0; + } + if min_val[j] < delta { + delta = min_val[j]; + j1 = j; + } + } + } + + // Update potentials. + for j in 0..=n { + if used[j] { + u[p[j]] += delta; + v[j] -= delta; + } else { + min_val[j] -= delta; + } + } + + j0 = j1; + if p[j0] == 0 { + break; // free column found → augmenting path complete + } + } + + // Trace back and augment the matching. + loop { + p[j0] = p[way[j0]]; + j0 = way[j0]; + if j0 == 0 { + break; + } + } + } + + // Collect real (non-padded) assignments. + let mut assignments = Vec::new(); + for j in 1..=n { + if p[j] != 0 { + let pred_idx = p[j] - 1; // back to 0-indexed + let gt_idx = j - 1; + if pred_idx < n_rows && gt_idx < n_cols { + assignments.push((pred_idx, gt_idx)); + } + } + } + assignments.sort_unstable_by_key(|&(pred, _)| pred); + assignments +} + +/// Build the OKS cost matrix for multi-person matching. +/// +/// Cost between predicted person `i` and GT person `j` is `1 − OKS(pred_i, gt_j)`. +pub fn build_oks_cost_matrix( + pred_persons: &[Array2], + gt_persons: &[Array2], + visibility: &[Array1], +) -> Vec> { + let n_pred = pred_persons.len(); + let n_gt = gt_persons.len(); + assert_eq!(gt_persons.len(), visibility.len()); + + let mut matrix = vec![vec![1.0_f32; n_gt]; n_pred]; + for i in 0..n_pred { + for j in 0..n_gt { + let oks = compute_oks(&pred_persons[i], >_persons[j], &visibility[j], 1.0); + matrix[i][j] = 1.0 - oks; + } + } + matrix +} + +/// Find an augmenting path in the bipartite matching graph. +/// +/// Used internally for unit-capacity matching checks. In the main training +/// pipeline `hungarian_assignment` is preferred for its optimal cost guarantee. +/// +/// `adj[u]` is the list of `(v, weight)` edges from left-node `u`. +/// `matching[v]` gives the current left-node matched to right-node `v`. +pub fn find_augmenting_path( + adj: &[Vec<(usize, f32)>], + source: usize, + _sink: usize, + visited: &mut Vec, + matching: &mut Vec>, +) -> bool { + for &(v, _weight) in &adj[source] { + if !visited[v] { + visited[v] = true; + if matching[v].is_none() + || find_augmenting_path(adj, matching[v].unwrap(), _sink, visited, matching) + { + matching[v] = Some(source); + return true; + } + } + } + false +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -403,4 +812,173 @@ mod tests { assert!(good.is_better_than(&bad)); assert!(!bad.is_better_than(&good)); } + + // ── compute_pck free function ───────────────────────────────────────────── + + fn all_visible_17() -> Array1 { + Array1::ones(17) + } + + fn uniform_kpts_17(x: f32, y: f32) -> Array2 { + let mut arr = Array2::zeros((17, 2)); + for j in 0..17 { + arr[[j, 0]] = x; + arr[[j, 1]] = y; + } + arr + } + + #[test] + fn compute_pck_perfect_is_one() { + let kpts = uniform_kpts_17(0.5, 0.5); + let vis = all_visible_17(); + let (correct, total, pck) = compute_pck(&kpts, &kpts, &vis, 0.2); + assert_eq!(correct, 17); + assert_eq!(total, 17); + assert_abs_diff_eq!(pck, 1.0_f32, epsilon = 1e-6); + } + + #[test] + fn compute_pck_no_visible_is_zero() { + let kpts = uniform_kpts_17(0.5, 0.5); + let vis = Array1::zeros(17); + let (correct, total, pck) = compute_pck(&kpts, &kpts, &vis, 0.2); + assert_eq!(correct, 0); + assert_eq!(total, 0); + assert_eq!(pck, 0.0); + } + + // ── compute_oks free function ───────────────────────────────────────────── + + #[test] + fn compute_oks_identical_is_one() { + let kpts = uniform_kpts_17(0.5, 0.5); + let vis = all_visible_17(); + let oks = compute_oks(&kpts, &kpts, &vis, 1.0); + assert_abs_diff_eq!(oks, 1.0_f32, epsilon = 1e-5); + } + + #[test] + fn compute_oks_no_visible_is_zero() { + let kpts = uniform_kpts_17(0.5, 0.5); + let vis = Array1::zeros(17); + let oks = compute_oks(&kpts, &kpts, &vis, 1.0); + assert_eq!(oks, 0.0); + } + + #[test] + fn compute_oks_in_unit_interval() { + let pred = uniform_kpts_17(0.4, 0.6); + let gt = uniform_kpts_17(0.5, 0.5); + let vis = all_visible_17(); + let oks = compute_oks(&pred, >, &vis, 1.0); + assert!(oks >= 0.0 && oks <= 1.0, "OKS={oks} outside [0,1]"); + } + + // ── aggregate_metrics ──────────────────────────────────────────────────── + + #[test] + fn aggregate_metrics_perfect() { + let kpts: Vec> = (0..4).map(|_| uniform_kpts_17(0.5, 0.5)).collect(); + let vis: Vec> = (0..4).map(|_| all_visible_17()).collect(); + let result = aggregate_metrics(&kpts, &kpts, &vis); + assert_eq!(result.frames_evaluated, 4); + assert_abs_diff_eq!(result.pck_02, 1.0_f32, epsilon = 1e-5); + assert_abs_diff_eq!(result.oks, 1.0_f32, epsilon = 1e-5); + } + + #[test] + fn aggregate_metrics_empty_is_default() { + let result = aggregate_metrics(&[], &[], &[]); + assert_eq!(result.frames_evaluated, 0); + assert_eq!(result.oks, 0.0); + } + + // ── hungarian_assignment ───────────────────────────────────────────────── + + #[test] + fn hungarian_identity_2x2_assigns_diagonal() { + // [[0, 1], [1, 0]] → optimal (0→0, 1→1) with total cost 0. + let cost = vec![vec![0.0_f32, 1.0], vec![1.0, 0.0]]; + let mut assignments = hungarian_assignment(&cost); + assignments.sort_unstable(); + assert_eq!(assignments, vec![(0, 0), (1, 1)]); + } + + #[test] + fn hungarian_swapped_2x2() { + // [[1, 0], [0, 1]] → optimal (0→1, 1→0) with total cost 0. + let cost = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]]; + let mut assignments = hungarian_assignment(&cost); + assignments.sort_unstable(); + assert_eq!(assignments, vec![(0, 1), (1, 0)]); + } + + #[test] + fn hungarian_3x3_identity() { + let cost = vec![ + vec![0.0_f32, 10.0, 10.0], + vec![10.0, 0.0, 10.0], + vec![10.0, 10.0, 0.0], + ]; + let mut assignments = hungarian_assignment(&cost); + assignments.sort_unstable(); + assert_eq!(assignments, vec![(0, 0), (1, 1), (2, 2)]); + } + + #[test] + fn hungarian_empty_matrix() { + assert!(hungarian_assignment(&[]).is_empty()); + } + + #[test] + fn hungarian_single_element() { + let assignments = hungarian_assignment(&[vec![0.5_f32]]); + assert_eq!(assignments, vec![(0, 0)]); + } + + #[test] + fn hungarian_rectangular_fewer_gt_than_pred() { + // 3 predicted, 2 GT → only 2 assignments. + let cost = vec![ + vec![5.0_f32, 9.0], + vec![4.0, 6.0], + vec![3.0, 1.0], + ]; + let assignments = hungarian_assignment(&cost); + assert_eq!(assignments.len(), 2); + // GT indices must be unique. + let gt_set: std::collections::HashSet = + assignments.iter().map(|&(_, g)| g).collect(); + assert_eq!(gt_set.len(), 2); + } + + // ── OKS cost matrix ─────────────────────────────────────────────────────── + + #[test] + fn oks_cost_matrix_diagonal_near_zero() { + let persons: Vec> = (0..3) + .map(|i| uniform_kpts_17(i as f32 * 0.3, 0.5)) + .collect(); + let vis: Vec> = (0..3).map(|_| all_visible_17()).collect(); + let mat = build_oks_cost_matrix(&persons, &persons, &vis); + for i in 0..3 { + assert!(mat[i][i] < 1e-4, "cost[{i}][{i}]={} should be ≈0", mat[i][i]); + } + } + + // ── find_augmenting_path (helper smoke test) ────────────────────────────── + + #[test] + fn find_augmenting_path_basic() { + let adj: Vec> = vec![ + vec![(0, 1.0)], + vec![(1, 1.0)], + ]; + let mut matching = vec![None; 2]; + let mut visited = vec![false; 2]; + let found = find_augmenting_path(&adj, 0, 2, &mut visited, &mut matching); + assert!(found); + assert_eq!(matching[0], Some(0)); + } } diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs index cfeba62..0aa6a90 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/model.rs @@ -1,16 +1,713 @@ -//! WiFi-DensePose model definition and construction. +//! WiFi-DensePose end-to-end model using tch-rs (PyTorch Rust bindings). //! -//! This module will be implemented by the trainer agent. It currently provides -//! the public interface stubs so that the crate compiles as a whole. +//! # Architecture +//! +//! ```text +//! CSI amplitude + phase +//! │ +//! ▼ +//! ┌─────────────────────┐ +//! │ PhaseSanitizerNet │ differentiable conjugate multiplication +//! └─────────────────────┘ +//! │ +//! ▼ +//! ┌────────────────────────────┐ +//! │ ModalityTranslatorNet │ CSI → spatial pseudo-image [B, 3, 48, 48] +//! └────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────┐ +//! │ ResNet18-like │ [B, 256, H/4, W/4] feature maps +//! │ Backbone │ +//! └─────────────────┘ +//! │ +//! ┌───┴───┐ +//! │ │ +//! ▼ ▼ +//! ┌─────┐ ┌────────────┐ +//! │ KP │ │ DensePose │ +//! │ Head│ │ Head │ +//! └─────┘ └────────────┘ +//! [B,17,H,W] [B,25,H,W] + [B,48,H,W] +//! ``` +//! +//! # No pre-trained weights +//! +//! The backbone uses a ResNet18-compatible architecture built purely with +//! `tch::nn`. Weights are initialised from scratch (Kaiming uniform by +//! default from tch). Pre-trained ImageNet weights are not loaded because +//! network access is not guaranteed during training runs. -/// Placeholder for the compiled model handle. +use std::path::Path; +use tch::{nn, nn::Module, nn::ModuleT, Device, Kind, Tensor}; + +use crate::config::TrainingConfig; + +// --------------------------------------------------------------------------- +// Public output type +// --------------------------------------------------------------------------- + +/// Outputs produced by a single forward pass of [`WiFiDensePoseModel`]. +pub struct ModelOutput { + /// Keypoint heatmaps: `[B, 17, H, W]`. + pub keypoints: Tensor, + /// Body-part logits (24 parts + background): `[B, 25, H, W]`. + pub part_logits: Tensor, + /// UV coordinates (24 × 2 channels interleaved): `[B, 48, H, W]`. + pub uv_coords: Tensor, + /// Backbone feature map used for cross-modal transfer loss: `[B, 256, H/4, W/4]`. + pub features: Tensor, +} + +// --------------------------------------------------------------------------- +// WiFiDensePoseModel +// --------------------------------------------------------------------------- + +/// Complete WiFi-DensePose model. /// -/// The real implementation wraps a `tch::CModule` or a custom `nn::Module`. -pub struct DensePoseModel; +/// Input: CSI amplitude and phase tensors with shape +/// `[B, T*n_tx*n_rx, n_sub]` (flattened antenna-time dimension). +/// +/// Output: [`ModelOutput`] with keypoints and DensePose predictions. +pub struct WiFiDensePoseModel { + vs: nn::VarStore, + config: TrainingConfig, +} -impl DensePoseModel { - /// Construct a new model from the given number of subcarriers and keypoints. - pub fn new(_num_subcarriers: usize, _num_keypoints: usize) -> Self { - DensePoseModel +// Internal model components stored in the VarStore. +// We use sub-paths inside the single VarStore to keep all parameters in +// one serialisable store. + +impl WiFiDensePoseModel { + /// Create a new model on `device`. + /// + /// All sub-networks are constructed and their parameters registered in the + /// internal `VarStore`. + pub fn new(config: &TrainingConfig, device: Device) -> Self { + let vs = nn::VarStore::new(device); + WiFiDensePoseModel { + vs, + config: config.clone(), + } + } + + /// Forward pass with gradient tracking (training mode). + /// + /// # Arguments + /// + /// - `amplitude`: `[B, T*n_tx*n_rx, n_sub]` + /// - `phase`: `[B, T*n_tx*n_rx, n_sub]` + pub fn forward_train(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput { + self.forward_impl(amplitude, phase, true) + } + + /// Forward pass without gradient tracking (inference mode). + pub fn forward_inference(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput { + tch::no_grad(|| self.forward_impl(amplitude, phase, false)) + } + + /// Save model weights to `path`. + /// + /// # Errors + /// + /// Returns an error if the file cannot be written. + pub fn save(&self, path: &Path) -> Result<(), Box> { + self.vs.save(path)?; + Ok(()) + } + + /// Load model weights from `path`. + /// + /// # Errors + /// + /// Returns an error if the file cannot be read or the weights are + /// incompatible with the model architecture. + pub fn load( + path: &Path, + config: &TrainingConfig, + device: Device, + ) -> Result> { + let mut model = Self::new(config, device); + // Build parameter graph first so load can find named tensors. + let _dummy_amp = Tensor::zeros( + [1, 1, config.num_subcarriers as i64], + (Kind::Float, device), + ); + let _dummy_phase = _dummy_amp.shallow_clone(); + let _ = model.forward_impl(&_dummy_amp, &_dummy_phase, false); + model.vs.load(path)?; + Ok(model) + } + + /// Return all trainable variable tensors. + pub fn trainable_variables(&self) -> Vec { + self.vs + .trainable_variables() + .into_iter() + .map(|t| t.shallow_clone()) + .collect() + } + + /// Count total trainable parameters. + pub fn num_parameters(&self) -> usize { + self.vs + .trainable_variables() + .iter() + .map(|t| t.numel() as usize) + .sum() + } + + /// Access the internal `VarStore` (e.g. to create an optimizer). + pub fn var_store(&self) -> &nn::VarStore { + &self.vs + } + + /// Mutable access to the internal `VarStore`. + pub fn var_store_mut(&mut self) -> &mut nn::VarStore { + &mut self.vs + } + + // ------------------------------------------------------------------ + // Internal forward implementation + // ------------------------------------------------------------------ + + fn forward_impl( + &self, + amplitude: &Tensor, + phase: &Tensor, + train: bool, + ) -> ModelOutput { + let root = self.vs.root(); + let cfg = &self.config; + + // ── Phase sanitization ─────────────────────────────────────────── + let clean_phase = phase_sanitize(phase); + + // ── Modality translation ───────────────────────────────────────── + // Flatten antenna-time and subcarrier dimensions → [B, flat] + let batch = amplitude.size()[0]; + let flat_amp = amplitude.reshape([batch, -1]); + let flat_phase = clean_phase.reshape([batch, -1]); + let input_size = flat_amp.size()[1]; + + let spatial = modality_translate(&root, &flat_amp, &flat_phase, input_size, train); + // spatial: [B, 3, 48, 48] + + // ── ResNet18-like backbone ──────────────────────────────────────── + let (features, feat_h, feat_w) = resnet18_backbone(&root, &spatial, train, cfg.backbone_channels as i64); + // features: [B, 256, 12, 12] + + // ── Keypoint head ──────────────────────────────────────────────── + let kp_h = cfg.heatmap_size as i64; + let kp_w = kp_h; + let keypoints = keypoint_head(&root, &features, cfg.num_keypoints as i64, (kp_h, kp_w), train); + + // ── DensePose head ─────────────────────────────────────────────── + let (part_logits, uv_coords) = densepose_head( + &root, + &features, + (cfg.num_body_parts + 1) as i64, // +1 for background + (kp_h, kp_w), + train, + ); + + ModelOutput { + keypoints, + part_logits, + uv_coords, + features, + } + } +} + +// --------------------------------------------------------------------------- +// Phase sanitizer (no learned parameters) +// --------------------------------------------------------------------------- + +/// Differentiable phase sanitization via conjugate multiplication. +/// +/// Implements the CSI ratio model: for each adjacent subcarrier pair, compute +/// the phase difference to cancel out common-mode phase drift (e.g. carrier +/// frequency offset, sampling offset). +/// +/// Input: `[B, T*n_ant, n_sub]` +/// Output: `[B, T*n_ant, n_sub]` (sanitized phase) +fn phase_sanitize(phase: &Tensor) -> Tensor { + // For each subcarrier k, compute the differential phase: + // φ_clean[k] = φ[k] - φ[k-1] for k > 0 + // φ_clean[0] = 0 + // + // This removes linear phase ramps caused by timing and CFO. + // Implemented as: diff along last dimension with zero-padding on the left. + + let n_sub = phase.size()[2]; + if n_sub <= 1 { + return phase.zeros_like(); + } + + // Slice k=1..N and k=0..N-1, compute difference. + let later = phase.slice(2, 1, n_sub, 1); + let earlier = phase.slice(2, 0, n_sub - 1, 1); + let diff = later - earlier; + + // Prepend a zero column so the output has the same shape as input. + let zeros = Tensor::zeros( + [phase.size()[0], phase.size()[1], 1], + (Kind::Float, phase.device()), + ); + Tensor::cat(&[zeros, diff], 2) +} + +// --------------------------------------------------------------------------- +// Modality translator +// --------------------------------------------------------------------------- + +/// Build and run the modality translator network. +/// +/// Architecture: +/// - Amplitude encoder: `Linear(input_size, 512) → ReLU → Linear(512, 256) → ReLU` +/// - Phase encoder: same structure as amplitude encoder +/// - Fusion: `Linear(512, 256) → ReLU → Linear(256, 48*48*3)` +/// → reshape to `[B, 3, 48, 48]` +/// +/// All layers share the same `root` VarStore path so weights accumulate +/// across calls (the parameters are created lazily on first call and reused). +fn modality_translate( + root: &nn::Path, + flat_amp: &Tensor, + flat_phase: &Tensor, + input_size: i64, + train: bool, +) -> Tensor { + let mt = root / "modality_translator"; + + // Amplitude encoder + let ae = |x: &Tensor| { + let h = ((&mt / "amp_enc_fc1").linear(x, input_size, 512)); + let h = h.relu(); + let h = ((&mt / "amp_enc_fc2").linear(&h, 512, 256)); + h.relu() + }; + + // Phase encoder + let pe = |x: &Tensor| { + let h = ((&mt / "ph_enc_fc1").linear(x, input_size, 512)); + let h = h.relu(); + let h = ((&mt / "ph_enc_fc2").linear(&h, 512, 256)); + h.relu() + }; + + let amp_feat = ae(flat_amp); // [B, 256] + let phase_feat = pe(flat_phase); // [B, 256] + + // Concatenate and fuse + let fused = Tensor::cat(&[amp_feat, phase_feat], 1); // [B, 512] + + let spatial_out: i64 = 3 * 48 * 48; + let fused = (&mt / "fusion_fc1").linear(&fused, 512, 256); + let fused = fused.relu(); + let fused = (&mt / "fusion_fc2").linear(&fused, 256, spatial_out); + // fused: [B, 3*48*48] + + let batch = fused.size()[0]; + let spatial_map = fused.reshape([batch, 3, 48, 48]); + + // Optional: apply tanh to bound activations before passing to CNN. + spatial_map.tanh() +} + +// --------------------------------------------------------------------------- +// Path::linear helper (creates or retrieves a Linear layer) +// --------------------------------------------------------------------------- + +/// Extension trait to make `nn::Path` callable with `linear(x, in, out)`. +trait PathLinear { + fn linear(&self, x: &Tensor, in_dim: i64, out_dim: i64) -> Tensor; +} + +impl PathLinear for nn::Path<'_> { + fn linear(&self, x: &Tensor, in_dim: i64, out_dim: i64) -> Tensor { + let cfg = nn::LinearConfig::default(); + let layer = nn::linear(self, in_dim, out_dim, cfg); + layer.forward(x) + } +} + +// --------------------------------------------------------------------------- +// ResNet18-like backbone +// --------------------------------------------------------------------------- + +/// A ResNet18-style CNN backbone. +/// +/// Input: `[B, 3, 48, 48]` +/// Output: `[B, 256, 12, 12]` (spatial features) +/// +/// Architecture: +/// - Stem: Conv2d(3→64, k=3, s=1, p=1) + BN + ReLU +/// - Layer1: 2 × BasicBlock(64→64) +/// - Layer2: 2 × BasicBlock(64→128, stride=2) → 24×24 +/// - Layer3: 2 × BasicBlock(128→256, stride=2) → 12×12 +/// +/// (No Layer4/pooling to preserve spatial resolution.) +fn resnet18_backbone( + root: &nn::Path, + x: &Tensor, + train: bool, + out_channels: i64, +) -> (Tensor, i64, i64) { + let bb = root / "backbone"; + + // Stem + let stem_conv = nn::conv2d( + &(&bb / "stem_conv"), + 3, + 64, + 3, + nn::ConvConfig { padding: 1, ..Default::default() }, + ); + let stem_bn = nn::batch_norm2d(&(&bb / "stem_bn"), 64, Default::default()); + let x = stem_conv.forward(x).apply_t(&stem_bn, train).relu(); + + // Layer 1: 64 → 64 + let x = basic_block(&(&bb / "l1b1"), &x, 64, 64, 1, train); + let x = basic_block(&(&bb / "l1b2"), &x, 64, 64, 1, train); + + // Layer 2: 64 → 128 (stride 2 → half spatial) + let x = basic_block(&(&bb / "l2b1"), &x, 64, 128, 2, train); + let x = basic_block(&(&bb / "l2b2"), &x, 128, 128, 1, train); + + // Layer 3: 128 → out_channels (stride 2 → half spatial again) + let x = basic_block(&(&bb / "l3b1"), &x, 128, out_channels, 2, train); + let x = basic_block(&(&bb / "l3b2"), &x, out_channels, out_channels, 1, train); + + let shape = x.size(); + let h = shape[2]; + let w = shape[3]; + (x, h, w) +} + +/// ResNet BasicBlock. +/// +/// ```text +/// x ─── Conv2d(s) ─── BN ─── ReLU ─── Conv2d(1) ─── BN ──+── ReLU +/// │ │ +/// └── (downsample if needed) ──────────────────────────────┘ +/// ``` +fn basic_block( + path: &nn::Path, + x: &Tensor, + in_ch: i64, + out_ch: i64, + stride: i64, + train: bool, +) -> Tensor { + let conv1 = nn::conv2d( + &(path / "conv1"), + in_ch, + out_ch, + 3, + nn::ConvConfig { stride, padding: 1, bias: false, ..Default::default() }, + ); + let bn1 = nn::batch_norm2d(&(path / "bn1"), out_ch, Default::default()); + + let conv2 = nn::conv2d( + &(path / "conv2"), + out_ch, + out_ch, + 3, + nn::ConvConfig { padding: 1, bias: false, ..Default::default() }, + ); + let bn2 = nn::batch_norm2d(&(path / "bn2"), out_ch, Default::default()); + + let out = conv1.forward(x).apply_t(&bn1, train).relu(); + let out = conv2.forward(&out).apply_t(&bn2, train); + + // Residual / skip connection + let residual = if in_ch != out_ch || stride != 1 { + let ds_conv = nn::conv2d( + &(path / "ds_conv"), + in_ch, + out_ch, + 1, + nn::ConvConfig { stride, bias: false, ..Default::default() }, + ); + let ds_bn = nn::batch_norm2d(&(path / "ds_bn"), out_ch, Default::default()); + ds_conv.forward(x).apply_t(&ds_bn, train) + } else { + x.shallow_clone() + }; + + (out + residual).relu() +} + +// --------------------------------------------------------------------------- +// Keypoint head +// --------------------------------------------------------------------------- + +/// Keypoint heatmap prediction head. +/// +/// Input: `[B, in_channels, H, W]` +/// Output: `[B, num_keypoints, out_h, out_w]` (after upsampling) +fn keypoint_head( + root: &nn::Path, + features: &Tensor, + num_keypoints: i64, + output_size: (i64, i64), + train: bool, +) -> Tensor { + let kp = root / "keypoint_head"; + + let conv1 = nn::conv2d( + &(&kp / "conv1"), + 256, + 256, + 3, + nn::ConvConfig { padding: 1, bias: false, ..Default::default() }, + ); + let bn1 = nn::batch_norm2d(&(&kp / "bn1"), 256, Default::default()); + + let conv2 = nn::conv2d( + &(&kp / "conv2"), + 256, + 128, + 3, + nn::ConvConfig { padding: 1, bias: false, ..Default::default() }, + ); + let bn2 = nn::batch_norm2d(&(&kp / "bn2"), 128, Default::default()); + + let output_conv = nn::conv2d( + &(&kp / "output_conv"), + 128, + num_keypoints, + 1, + Default::default(), + ); + + let x = conv1.forward(features).apply_t(&bn1, train).relu(); + let x = conv2.forward(&x).apply_t(&bn2, train).relu(); + let x = output_conv.forward(&x); + + // Upsample to (output_size_h, output_size_w) + x.upsample_bilinear2d( + [output_size.0, output_size.1], + false, + None, + None, + ) +} + +// --------------------------------------------------------------------------- +// DensePose head +// --------------------------------------------------------------------------- + +/// DensePose prediction head. +/// +/// Input: `[B, in_channels, H, W]` +/// Outputs: +/// - part logits: `[B, num_parts, out_h, out_w]` +/// - UV coordinates: `[B, 2*(num_parts-1), out_h, out_w]` (background excluded from UV) +fn densepose_head( + root: &nn::Path, + features: &Tensor, + num_parts: i64, + output_size: (i64, i64), + train: bool, +) -> (Tensor, Tensor) { + let dp = root / "densepose_head"; + + // Shared convolutional block + let shared_conv1 = nn::conv2d( + &(&dp / "shared_conv1"), + 256, + 256, + 3, + nn::ConvConfig { padding: 1, bias: false, ..Default::default() }, + ); + let shared_bn1 = nn::batch_norm2d(&(&dp / "shared_bn1"), 256, Default::default()); + + let shared_conv2 = nn::conv2d( + &(&dp / "shared_conv2"), + 256, + 256, + 3, + nn::ConvConfig { padding: 1, bias: false, ..Default::default() }, + ); + let shared_bn2 = nn::batch_norm2d(&(&dp / "shared_bn2"), 256, Default::default()); + + // Part segmentation head: 256 → num_parts + let part_conv = nn::conv2d( + &(&dp / "part_conv"), + 256, + num_parts, + 1, + Default::default(), + ); + + // UV regression head: 256 → 48 channels (2 × 24 body parts) + let uv_conv = nn::conv2d( + &(&dp / "uv_conv"), + 256, + 48, // 24 parts × 2 (U, V) + 1, + Default::default(), + ); + + let shared = shared_conv1.forward(features).apply_t(&shared_bn1, train).relu(); + let shared = shared_conv2.forward(&shared).apply_t(&shared_bn2, train).relu(); + + let parts = part_conv.forward(&shared); + let uv = uv_conv.forward(&shared); + + // Upsample both heads to the target spatial resolution. + let parts_up = parts.upsample_bilinear2d( + [output_size.0, output_size.1], + false, + None, + None, + ); + let uv_up = uv.upsample_bilinear2d( + [output_size.0, output_size.1], + false, + None, + None, + ); + + // Apply sigmoid to UV to constrain predictions to [0, 1]. + let uv_out = uv_up.sigmoid(); + + (parts_up, uv_out) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::TrainingConfig; + use tch::Device; + + fn tiny_config() -> TrainingConfig { + let mut cfg = TrainingConfig::default(); + cfg.num_subcarriers = 8; + cfg.window_frames = 4; + cfg.num_antennas_tx = 1; + cfg.num_antennas_rx = 1; + cfg.heatmap_size = 12; + cfg.backbone_channels = 64; + cfg.num_epochs = 2; + cfg.warmup_epochs = 1; + cfg + } + + #[test] + fn model_forward_output_shapes() { + tch::manual_seed(0); + let cfg = tiny_config(); + let device = Device::Cpu; + let model = WiFiDensePoseModel::new(&cfg, device); + + let batch = 2_i64; + let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64; + let n_sub = cfg.num_subcarriers as i64; + + let amp = Tensor::ones([batch, antennas, n_sub], (Kind::Float, device)); + let ph = Tensor::zeros([batch, antennas, n_sub], (Kind::Float, device)); + + let out = model.forward_train(&, &ph); + + // Keypoints: [B, 17, heatmap_size, heatmap_size] + assert_eq!(out.keypoints.size()[0], batch); + assert_eq!(out.keypoints.size()[1], cfg.num_keypoints as i64); + + // Part logits: [B, 25, heatmap_size, heatmap_size] + assert_eq!(out.part_logits.size()[0], batch); + assert_eq!(out.part_logits.size()[1], (cfg.num_body_parts + 1) as i64); + + // UV: [B, 48, heatmap_size, heatmap_size] + assert_eq!(out.uv_coords.size()[0], batch); + assert_eq!(out.uv_coords.size()[1], 48); + } + + #[test] + fn model_has_nonzero_parameters() { + tch::manual_seed(0); + let cfg = tiny_config(); + let model = WiFiDensePoseModel::new(&cfg, Device::Cpu); + + // Trigger parameter creation by running a forward pass. + let batch = 1_i64; + let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64; + let n_sub = cfg.num_subcarriers as i64; + let amp = Tensor::zeros([batch, antennas, n_sub], (Kind::Float, Device::Cpu)); + let ph = amp.shallow_clone(); + let _ = model.forward_train(&, &ph); + + let n = model.num_parameters(); + assert!(n > 0, "Model must have trainable parameters"); + } + + #[test] + fn phase_sanitize_zeros_first_column() { + let ph = Tensor::ones([2, 3, 8], (Kind::Float, Device::Cpu)); + let out = phase_sanitize(&ph); + // First subcarrier column should be 0. + let first_col = out.slice(2, 0, 1, 1); + let max_abs: f64 = first_col.abs().max().double_value(&[]); + assert!(max_abs < 1e-6, "First diff column should be 0"); + } + + #[test] + fn phase_sanitize_captures_ramp() { + // A linear phase ramp φ[k] = k should produce constant diffs of 1. + let ph = Tensor::arange(8, (Kind::Float, Device::Cpu)) + .reshape([1, 1, 8]) + .expand([2, 3, 8], true); + let out = phase_sanitize(&ph); + // All columns except the first should be 1.0 + let tail = out.slice(2, 1, 8, 1); + let min_val: f64 = tail.min().double_value(&[]); + let max_val: f64 = tail.max().double_value(&[]); + assert!((min_val - 1.0).abs() < 1e-5, "Expected 1.0 diff, got {min_val}"); + assert!((max_val - 1.0).abs() < 1e-5, "Expected 1.0 diff, got {max_val}"); + } + + #[test] + fn inference_mode_gives_same_shapes() { + tch::manual_seed(0); + let cfg = tiny_config(); + let model = WiFiDensePoseModel::new(&cfg, Device::Cpu); + + let batch = 1_i64; + let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64; + let n_sub = cfg.num_subcarriers as i64; + let amp = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu)); + let ph = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu)); + + let out = model.forward_inference(&, &ph); + assert_eq!(out.keypoints.size()[0], batch); + assert_eq!(out.part_logits.size()[0], batch); + assert_eq!(out.uv_coords.size()[0], batch); + } + + #[test] + fn uv_coords_bounded_zero_one() { + tch::manual_seed(0); + let cfg = tiny_config(); + let model = WiFiDensePoseModel::new(&cfg, Device::Cpu); + + let batch = 2_i64; + let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64; + let n_sub = cfg.num_subcarriers as i64; + let amp = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu)); + let ph = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu)); + + let out = model.forward_inference(&, &ph); + + let uv_min: f64 = out.uv_coords.min().double_value(&[]); + let uv_max: f64 = out.uv_coords.max().double_value(&[]); + assert!(uv_min >= 0.0 - 1e-5, "UV min should be >= 0, got {uv_min}"); + assert!(uv_max <= 1.0 + 1e-5, "UV max should be <= 1, got {uv_max}"); } } diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs index d543cc7..19ccbd5 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/src/trainer.rs @@ -1,24 +1,777 @@ -//! Training loop orchestrator. +//! Training loop for WiFi-DensePose. //! -//! This module will be implemented by the trainer agent. It currently provides -//! the public interface stubs so that the crate compiles as a whole. +//! # Features +//! +//! - Mini-batch training with [`DataLoader`]-style iteration +//! - Validation every N epochs with PCK\@0.2 and OKS metrics +//! - Best-checkpoint saving (by validation PCK) +//! - CSV logging (`epoch, train_loss, val_pck, val_oks, lr`) +//! - Gradient clipping +//! - LR scheduling (step decay at configured milestones) +//! - Early stopping +//! +//! # No mock data +//! +//! The trainer never generates random or synthetic data. It operates +//! exclusively on the [`CsiDataset`] passed at call site. The +//! [`SyntheticCsiDataset`] is only used for the deterministic proof protocol. + +use std::collections::VecDeque; +use std::io::Write as IoWrite; +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use ndarray::{Array1, Array2}; +use tch::{nn, nn::OptimizerConfig, Device, Kind, Tensor}; +use tracing::{debug, info, warn}; use crate::config::TrainingConfig; +use crate::dataset::{CsiDataset, CsiSample, DataLoader}; +use crate::error::TrainError; +use crate::losses::{LossWeights, WiFiDensePoseLoss}; +use crate::losses::generate_target_heatmaps; +use crate::metrics::{MetricsAccumulator, MetricsResult}; +use crate::model::WiFiDensePoseModel; -/// Orchestrates the full training loop: data loading, forward pass, loss -/// computation, back-propagation, validation, and checkpointing. +// --------------------------------------------------------------------------- +// Public result types +// --------------------------------------------------------------------------- + +/// Per-epoch training log entry. +#[derive(Debug, Clone)] +pub struct EpochLog { + /// Epoch number (1-indexed). + pub epoch: usize, + /// Mean total loss over all training batches. + pub train_loss: f64, + /// Mean keypoint-only loss component. + pub train_kp_loss: f64, + /// Validation PCK\@0.2 (0–1). `0.0` when validation was skipped. + pub val_pck: f32, + /// Validation OKS (0–1). `0.0` when validation was skipped. + pub val_oks: f32, + /// Learning rate at the end of this epoch. + pub lr: f64, + /// Wall-clock duration of this epoch in seconds. + pub duration_secs: f64, +} + +/// Summary returned after a completed (or early-stopped) training run. +#[derive(Debug, Clone)] +pub struct TrainResult { + /// Best validation PCK achieved during training. + pub best_pck: f32, + /// Epoch at which `best_pck` was achieved (1-indexed). + pub best_epoch: usize, + /// Training loss on the last completed epoch. + pub final_train_loss: f64, + /// Full per-epoch log. + pub training_history: Vec, + /// Path to the best checkpoint file, if any was saved. + pub checkpoint_path: Option, +} + +// --------------------------------------------------------------------------- +// Trainer +// --------------------------------------------------------------------------- + +/// Orchestrates the full WiFi-DensePose training pipeline. +/// +/// Create via [`Trainer::new`], then call [`Trainer::train`] with real dataset +/// references. pub struct Trainer { config: TrainingConfig, + model: WiFiDensePoseModel, + device: Device, } impl Trainer { /// Create a new `Trainer` from the given configuration. + /// + /// The model and device are initialised from `config`. pub fn new(config: TrainingConfig) -> Self { - Trainer { config } + let device = if config.use_gpu { + Device::Cuda(config.gpu_device_id as usize) + } else { + Device::Cpu + }; + + tch::manual_seed(config.seed as i64); + + let model = WiFiDensePoseModel::new(&config, device); + Trainer { config, model, device } } - /// Return a reference to the active training configuration. - pub fn config(&self) -> &TrainingConfig { - &self.config + /// Run the full training loop. + /// + /// # Errors + /// + /// - [`TrainError::EmptyDataset`] if either dataset is empty. + /// - [`TrainError::TrainingStep`] on unrecoverable forward/backward errors. + /// - [`TrainError::Checkpoint`] if writing checkpoints fails. + pub fn train( + &mut self, + train_dataset: &dyn CsiDataset, + val_dataset: &dyn CsiDataset, + ) -> Result { + if train_dataset.is_empty() { + return Err(TrainError::EmptyDataset); + } + if val_dataset.is_empty() { + return Err(TrainError::EmptyDataset); + } + + // Prepare output directories. + std::fs::create_dir_all(&self.config.checkpoint_dir) + .map_err(|e| TrainError::Io(e))?; + std::fs::create_dir_all(&self.config.log_dir) + .map_err(|e| TrainError::Io(e))?; + + // Build optimizer (AdamW). + let mut opt = nn::AdamW::default() + .wd(self.config.weight_decay) + .build(self.model.var_store(), self.config.learning_rate) + .map_err(|e| TrainError::training_step(e.to_string()))?; + + let loss_fn = WiFiDensePoseLoss::new(LossWeights { + lambda_kp: self.config.lambda_kp, + lambda_dp: self.config.lambda_dp, + lambda_tr: self.config.lambda_tr, + }); + + // CSV log file. + let csv_path = self.config.log_dir.join("training_log.csv"); + let mut csv_file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&csv_path) + .map_err(|e| TrainError::Io(e))?; + writeln!(csv_file, "epoch,train_loss,train_kp_loss,val_pck,val_oks,lr,duration_secs") + .map_err(|e| TrainError::Io(e))?; + + let mut training_history: Vec = Vec::new(); + let mut best_pck: f32 = -1.0; + let mut best_epoch: usize = 0; + let mut best_checkpoint_path: Option = None; + + // Early-stopping state: track the last N val_pck values. + let patience = self.config.early_stopping_patience; + let mut patience_counter: usize = 0; + let min_delta = 1e-4_f32; + + let mut current_lr = self.config.learning_rate; + + info!( + "Training {} for {} epochs on '{}' → '{}'", + train_dataset.name(), + self.config.num_epochs, + train_dataset.name(), + val_dataset.name() + ); + + for epoch in 1..=self.config.num_epochs { + let epoch_start = Instant::now(); + + // ── LR scheduling ────────────────────────────────────────────── + if self.config.lr_milestones.contains(&epoch) { + current_lr *= self.config.lr_gamma; + opt.set_lr(current_lr); + info!("Epoch {epoch}: LR decayed to {current_lr:.2e}"); + } + + // ── Warmup ───────────────────────────────────────────────────── + if epoch <= self.config.warmup_epochs { + let warmup_lr = self.config.learning_rate + * epoch as f64 + / self.config.warmup_epochs as f64; + opt.set_lr(warmup_lr); + current_lr = warmup_lr; + } + + // ── Training batches ─────────────────────────────────────────── + // Deterministic shuffle: seed = config.seed XOR epoch. + let shuffle_seed = self.config.seed ^ (epoch as u64); + let batches = make_batches( + train_dataset, + self.config.batch_size, + true, + shuffle_seed, + self.device, + ); + + let mut total_loss_sum = 0.0_f64; + let mut kp_loss_sum = 0.0_f64; + let mut n_batches = 0_usize; + + for (amp_batch, phase_batch, kp_batch, vis_batch) in &batches { + let output = self.model.forward_train(amp_batch, phase_batch); + + // Build target heatmaps from ground-truth keypoints. + let target_hm = kp_to_heatmap_tensor( + kp_batch, + vis_batch, + self.config.heatmap_size, + self.device, + ); + + // Binary visibility mask [B, 17]. + let vis_mask = (vis_batch.gt(0.0)).to_kind(Kind::Float); + + // Compute keypoint loss only (no DensePose GT in this pipeline). + let (total_tensor, loss_out) = loss_fn.forward( + &output.keypoints, + &target_hm, + &vis_mask, + None, None, None, None, None, None, + ); + + opt.zero_grad(); + total_tensor.backward(); + opt.clip_grad_norm(self.config.grad_clip_norm); + opt.step(); + + total_loss_sum += loss_out.total as f64; + kp_loss_sum += loss_out.keypoint as f64; + n_batches += 1; + + debug!( + "Epoch {epoch} batch {n_batches}: loss={:.4}", + loss_out.total + ); + } + + let mean_loss = if n_batches > 0 { + total_loss_sum / n_batches as f64 + } else { + 0.0 + }; + let mean_kp_loss = if n_batches > 0 { + kp_loss_sum / n_batches as f64 + } else { + 0.0 + }; + + // ── Validation ───────────────────────────────────────────────── + let mut val_pck = 0.0_f32; + let mut val_oks = 0.0_f32; + + if epoch % self.config.val_every_epochs == 0 { + match self.evaluate(val_dataset) { + Ok(metrics) => { + val_pck = metrics.pck; + val_oks = metrics.oks; + info!( + "Epoch {epoch}: loss={mean_loss:.4} val_pck={val_pck:.4} val_oks={val_oks:.4} lr={current_lr:.2e}" + ); + } + Err(e) => { + warn!("Validation failed at epoch {epoch}: {e}"); + } + } + + // ── Checkpoint saving ────────────────────────────────────── + if val_pck > best_pck + min_delta { + best_pck = val_pck; + best_epoch = epoch; + patience_counter = 0; + + let ckpt_name = format!("best_epoch{epoch:04}_pck{val_pck:.4}.pt"); + let ckpt_path = self.config.checkpoint_dir.join(&ckpt_name); + + match self.model.save(&ckpt_path) { + Ok(_) => { + info!("Saved best checkpoint: {}", ckpt_path.display()); + best_checkpoint_path = Some(ckpt_path); + } + Err(e) => { + warn!("Failed to save checkpoint: {e}"); + } + } + } else { + patience_counter += 1; + } + } + + let epoch_secs = epoch_start.elapsed().as_secs_f64(); + let log = EpochLog { + epoch, + train_loss: mean_loss, + train_kp_loss: mean_kp_loss, + val_pck, + val_oks, + lr: current_lr, + duration_secs: epoch_secs, + }; + + // Write CSV row. + writeln!( + csv_file, + "{},{:.6},{:.6},{:.6},{:.6},{:.2e},{:.3}", + log.epoch, + log.train_loss, + log.train_kp_loss, + log.val_pck, + log.val_oks, + log.lr, + log.duration_secs, + ) + .map_err(|e| TrainError::Io(e))?; + + training_history.push(log); + + // ── Early stopping check ─────────────────────────────────────── + if patience_counter >= patience { + info!( + "Early stopping at epoch {epoch}: no improvement for {patience} validation rounds." + ); + break; + } + } + + // Save final model regardless. + let final_ckpt = self.config.checkpoint_dir.join("final.pt"); + if let Err(e) = self.model.save(&final_ckpt) { + warn!("Failed to save final model: {e}"); + } + + Ok(TrainResult { + best_pck: best_pck.max(0.0), + best_epoch, + final_train_loss: training_history + .last() + .map(|l| l.train_loss) + .unwrap_or(0.0), + training_history, + checkpoint_path: best_checkpoint_path, + }) + } + + /// Evaluate on a dataset, returning PCK and OKS metrics. + /// + /// Runs inference (no gradient) over the full dataset using the configured + /// batch size. + pub fn evaluate(&self, dataset: &dyn CsiDataset) -> Result { + if dataset.is_empty() { + return Err(TrainError::EmptyDataset); + } + + let mut acc = MetricsAccumulator::default_threshold(); + + let batches = make_batches( + dataset, + self.config.batch_size, + false, // no shuffle during evaluation + self.config.seed, + self.device, + ); + + for (amp_batch, phase_batch, kp_batch, vis_batch) in &batches { + let output = self.model.forward_inference(amp_batch, phase_batch); + + // Extract predicted keypoints from heatmaps. + // Strategy: argmax over spatial dimensions → (x, y). + let pred_kps = heatmap_to_keypoints(&output.keypoints); + + // Convert GT tensors back to ndarray for MetricsAccumulator. + let batch_size = kp_batch.size()[0] as usize; + for b in 0..batch_size { + let pred_kp_np = extract_kp_ndarray(&pred_kps, b); + let gt_kp_np = extract_kp_ndarray(kp_batch, b); + let vis_np = extract_vis_ndarray(vis_batch, b); + + acc.update(&pred_kp_np, >_kp_np, &vis_np); + } + } + + acc.finalize().ok_or(TrainError::EmptyDataset) + } + + /// Save a training checkpoint. + pub fn save_checkpoint( + &self, + path: &Path, + _epoch: usize, + _metrics: &MetricsResult, + ) -> Result<(), TrainError> { + self.model.save(path).map_err(|e| TrainError::checkpoint(e.to_string(), path)) + } + + /// Load model weights from a checkpoint. + /// + /// Returns the epoch number encoded in the filename (if any), or `0`. + pub fn load_checkpoint(&mut self, path: &Path) -> Result { + self.model + .var_store_mut() + .load(path) + .map_err(|e| TrainError::checkpoint(e.to_string(), path))?; + + // Try to parse the epoch from the filename (e.g. "best_epoch0042_pck0.7842.pt"). + let epoch = path + .file_stem() + .and_then(|s| s.to_str()) + .and_then(|s| { + s.split("epoch").nth(1) + .and_then(|rest| rest.split('_').next()) + .and_then(|n| n.parse::().ok()) + }) + .unwrap_or(0); + + Ok(epoch) + } +} + +// --------------------------------------------------------------------------- +// Batch construction helpers +// --------------------------------------------------------------------------- + +/// Build all training batches for one epoch. +/// +/// `shuffle=true` uses a deterministic LCG permutation seeded with `seed`. +/// This guarantees reproducibility: same seed → same iteration order, with +/// no dependence on OS entropy. +pub fn make_batches( + dataset: &dyn CsiDataset, + batch_size: usize, + shuffle: bool, + seed: u64, + device: Device, +) -> Vec<(Tensor, Tensor, Tensor, Tensor)> { + let n = dataset.len(); + if n == 0 { + return vec![]; + } + + // Build index permutation (or identity). + let mut indices: Vec = (0..n).collect(); + if shuffle { + lcg_shuffle(&mut indices, seed); + } + + // Partition into batches. + let mut batches = Vec::new(); + let mut cursor = 0; + while cursor < indices.len() { + let end = (cursor + batch_size).min(indices.len()); + let batch_indices = &indices[cursor..end]; + + // Load samples. + let mut samples: Vec = Vec::with_capacity(batch_indices.len()); + for &idx in batch_indices { + match dataset.get(idx) { + Ok(s) => samples.push(s), + Err(e) => { + warn!("Skipping sample {idx}: {e}"); + } + } + } + + if !samples.is_empty() { + let batch = collate(&samples, device); + batches.push(batch); + } + + cursor = end; + } + + batches +} + +/// Deterministic Fisher-Yates shuffle using a Linear Congruential Generator. +/// +/// LCG parameters: multiplier = 6364136223846793005, +/// increment = 1442695040888963407 (Knuth's MMIX) +fn lcg_shuffle(indices: &mut [usize], seed: u64) { + let n = indices.len(); + if n <= 1 { + return; + } + + let mut state = seed.wrapping_add(1); // avoid seed=0 degeneracy + let mul: u64 = 6364136223846793005; + let inc: u64 = 1442695040888963407; + + for i in (1..n).rev() { + state = state.wrapping_mul(mul).wrapping_add(inc); + let j = (state >> 33) as usize % (i + 1); + indices.swap(i, j); + } +} + +/// Collate a slice of [`CsiSample`]s into four batched tensors. +/// +/// Returns `(amplitude, phase, keypoints, visibility)`: +/// - `amplitude`: `[B, T*n_tx*n_rx, n_sub]` +/// - `phase`: `[B, T*n_tx*n_rx, n_sub]` +/// - `keypoints`: `[B, 17, 2]` +/// - `visibility`: `[B, 17]` +pub fn collate(samples: &[CsiSample], device: Device) -> (Tensor, Tensor, Tensor, Tensor) { + let b = samples.len(); + assert!(b > 0, "collate requires at least one sample"); + + let s0 = &samples[0]; + let shape = s0.amplitude.shape(); + let (t, n_tx, n_rx, n_sub) = (shape[0], shape[1], shape[2], shape[3]); + let flat_ant = t * n_tx * n_rx; + let num_kp = s0.keypoints.shape()[0]; + + // Allocate host buffers. + let mut amp_data = vec![0.0_f32; b * flat_ant * n_sub]; + let mut ph_data = vec![0.0_f32; b * flat_ant * n_sub]; + let mut kp_data = vec![0.0_f32; b * num_kp * 2]; + let mut vis_data = vec![0.0_f32; b * num_kp]; + + for (bi, sample) in samples.iter().enumerate() { + // Amplitude: [T, n_tx, n_rx, n_sub] → flatten to [T*n_tx*n_rx, n_sub] + let amp_flat: Vec = sample + .amplitude + .iter() + .copied() + .collect(); + let ph_flat: Vec = sample.phase.iter().copied().collect(); + + let stride = flat_ant * n_sub; + amp_data[bi * stride..(bi + 1) * stride].copy_from_slice(&_flat); + ph_data[bi * stride..(bi + 1) * stride].copy_from_slice(&ph_flat); + + // Keypoints. + let kp_stride = num_kp * 2; + for j in 0..num_kp { + kp_data[bi * kp_stride + j * 2] = sample.keypoints[[j, 0]]; + kp_data[bi * kp_stride + j * 2 + 1] = sample.keypoints[[j, 1]]; + vis_data[bi * num_kp + j] = sample.keypoint_visibility[j]; + } + } + + let amp_t = Tensor::from_slice(&_data) + .reshape([b as i64, flat_ant as i64, n_sub as i64]) + .to_device(device); + let ph_t = Tensor::from_slice(&ph_data) + .reshape([b as i64, flat_ant as i64, n_sub as i64]) + .to_device(device); + let kp_t = Tensor::from_slice(&kp_data) + .reshape([b as i64, num_kp as i64, 2]) + .to_device(device); + let vis_t = Tensor::from_slice(&vis_data) + .reshape([b as i64, num_kp as i64]) + .to_device(device); + + (amp_t, ph_t, kp_t, vis_t) +} + +// --------------------------------------------------------------------------- +// Heatmap utilities +// --------------------------------------------------------------------------- + +/// Convert ground-truth keypoints to Gaussian target heatmaps. +/// +/// Wraps [`generate_target_heatmaps`] to work on `tch::Tensor` inputs. +fn kp_to_heatmap_tensor( + kp_tensor: &Tensor, + vis_tensor: &Tensor, + heatmap_size: usize, + device: Device, +) -> Tensor { + // kp_tensor: [B, 17, 2] + // vis_tensor: [B, 17] + let b = kp_tensor.size()[0] as usize; + let num_kp = kp_tensor.size()[1] as usize; + + // Convert to ndarray for generate_target_heatmaps. + let kp_vec: Vec = Vec::::from(kp_tensor.to_kind(Kind::Double).flatten(0, -1)) + .iter().map(|&x| x as f32).collect(); + let vis_vec: Vec = Vec::::from(vis_tensor.to_kind(Kind::Double).flatten(0, -1)) + .iter().map(|&x| x as f32).collect(); + + let kp_nd = ndarray::Array3::from_shape_vec((b, num_kp, 2), kp_vec) + .expect("kp shape"); + let vis_nd = ndarray::Array2::from_shape_vec((b, num_kp), vis_vec) + .expect("vis shape"); + + let hm_nd = generate_target_heatmaps(&kp_nd, &vis_nd, heatmap_size, 2.0); + + // [B, 17, H, W] + let flat: Vec = hm_nd.iter().copied().collect(); + Tensor::from_slice(&flat) + .reshape([ + b as i64, + num_kp as i64, + heatmap_size as i64, + heatmap_size as i64, + ]) + .to_device(device) +} + +/// Convert predicted heatmaps to normalised keypoint coordinates via argmax. +/// +/// Input: `[B, 17, H, W]` +/// Output: `[B, 17, 2]` with (x, y) in [0, 1] +fn heatmap_to_keypoints(heatmaps: &Tensor) -> Tensor { + let sizes = heatmaps.size(); + let (batch, num_kp, h, w) = (sizes[0], sizes[1], sizes[2], sizes[3]); + + // Flatten spatial → [B, 17, H*W] + let flat = heatmaps.reshape([batch, num_kp, h * w]); + // Argmax per joint → [B, 17] + let arg = flat.argmax(-1, false); + + // Decompose linear index into (row, col). + let row = (&arg / w).to_kind(Kind::Float); // [B, 17] + let col = (&arg % w).to_kind(Kind::Float); // [B, 17] + + // Normalize to [0, 1] + let x = col / (w - 1) as f64; + let y = row / (h - 1) as f64; + + // Stack to [B, 17, 2] + Tensor::stack(&[x, y], -1) +} + +/// Extract a single sample's keypoints as an ndarray from a batched tensor. +/// +/// `kp_tensor` shape: `[B, 17, 2]` +fn extract_kp_ndarray(kp_tensor: &Tensor, batch_idx: usize) -> Array2 { + let num_kp = kp_tensor.size()[1] as usize; + let row = kp_tensor.select(0, batch_idx as i64); + let data: Vec = Vec::::from(row.to_kind(Kind::Double).flatten(0, -1)) + .iter().map(|&v| v as f32).collect(); + Array2::from_shape_vec((num_kp, 2), data).expect("kp ndarray shape") +} + +/// Extract a single sample's visibility flags as an ndarray from a batched tensor. +/// +/// `vis_tensor` shape: `[B, 17]` +fn extract_vis_ndarray(vis_tensor: &Tensor, batch_idx: usize) -> Array1 { + let num_kp = vis_tensor.size()[1] as usize; + let row = vis_tensor.select(0, batch_idx as i64); + let data: Vec = Vec::::from(row.to_kind(Kind::Double)) + .iter().map(|&v| v as f32).collect(); + Array1::from_vec(data) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::TrainingConfig; + use crate::dataset::{SyntheticCsiDataset, SyntheticConfig}; + + fn tiny_config() -> TrainingConfig { + let mut cfg = TrainingConfig::default(); + cfg.num_subcarriers = 8; + cfg.window_frames = 2; + cfg.num_antennas_tx = 1; + cfg.num_antennas_rx = 1; + cfg.heatmap_size = 8; + cfg.backbone_channels = 32; + cfg.num_epochs = 2; + cfg.warmup_epochs = 1; + cfg.batch_size = 4; + cfg.val_every_epochs = 1; + cfg.early_stopping_patience = 5; + cfg.lr_milestones = vec![2]; + cfg + } + + fn tiny_synthetic_dataset(n: usize) -> SyntheticCsiDataset { + let cfg = tiny_config(); + SyntheticCsiDataset::new(n, SyntheticConfig { + num_subcarriers: cfg.num_subcarriers, + num_antennas_tx: cfg.num_antennas_tx, + num_antennas_rx: cfg.num_antennas_rx, + window_frames: cfg.window_frames, + num_keypoints: 17, + signal_frequency_hz: 2.4e9, + }) + } + + #[test] + fn collate_produces_correct_shapes() { + let ds = tiny_synthetic_dataset(4); + let samples: Vec<_> = (0..4).map(|i| ds.get(i).unwrap()).collect(); + let (amp, ph, kp, vis) = collate(&samples, Device::Cpu); + + let cfg = tiny_config(); + let flat_ant = (cfg.window_frames * cfg.num_antennas_tx * cfg.num_antennas_rx) as i64; + assert_eq!(amp.size(), [4, flat_ant, cfg.num_subcarriers as i64]); + assert_eq!(ph.size(), [4, flat_ant, cfg.num_subcarriers as i64]); + assert_eq!(kp.size(), [4, 17, 2]); + assert_eq!(vis.size(), [4, 17]); + } + + #[test] + fn make_batches_covers_all_samples() { + let ds = tiny_synthetic_dataset(10); + let batches = make_batches(&ds, 3, false, 42, Device::Cpu); + let total: i64 = batches.iter().map(|(a, _, _, _)| a.size()[0]).sum(); + assert_eq!(total, 10); + } + + #[test] + fn make_batches_shuffle_reproducible() { + let ds = tiny_synthetic_dataset(10); + let b1 = make_batches(&ds, 3, true, 99, Device::Cpu); + let b2 = make_batches(&ds, 3, true, 99, Device::Cpu); + // Shapes should match exactly. + for (batch_a, batch_b) in b1.iter().zip(b2.iter()) { + assert_eq!(batch_a.0.size(), batch_b.0.size()); + } + } + + #[test] + fn lcg_shuffle_is_permutation() { + let mut idx: Vec = (0..20).collect(); + lcg_shuffle(&mut idx, 42); + let mut sorted = idx.clone(); + sorted.sort_unstable(); + assert_eq!(sorted, (0..20).collect::>()); + } + + #[test] + fn lcg_shuffle_different_seeds_differ() { + let mut a: Vec = (0..20).collect(); + let mut b: Vec = (0..20).collect(); + lcg_shuffle(&mut a, 1); + lcg_shuffle(&mut b, 2); + assert_ne!(a, b, "different seeds should produce different orders"); + } + + #[test] + fn heatmap_to_keypoints_shape() { + let hm = Tensor::zeros([2, 17, 8, 8], (Kind::Float, Device::Cpu)); + let kp = heatmap_to_keypoints(&hm); + assert_eq!(kp.size(), [2, 17, 2]); + } + + #[test] + fn heatmap_to_keypoints_center_peak() { + // Create a heatmap with a single peak at the center (4, 4) of an 8×8 map. + let mut hm = Tensor::zeros([1, 1, 8, 8], (Kind::Float, Device::Cpu)); + let _ = hm.narrow(2, 4, 1).narrow(3, 4, 1).fill_(1.0); + let kp = heatmap_to_keypoints(&hm); + let x: f64 = kp.double_value(&[0, 0, 0]); + let y: f64 = kp.double_value(&[0, 0, 1]); + // Center pixel 4 → normalised 4/7 ≈ 0.571 + assert!((x - 4.0 / 7.0).abs() < 1e-4, "x={x}"); + assert!((y - 4.0 / 7.0).abs() < 1e-4, "y={y}"); + } + + #[test] + fn trainer_train_completes() { + let cfg = tiny_config(); + let train_ds = tiny_synthetic_dataset(8); + let val_ds = tiny_synthetic_dataset(4); + + let mut trainer = Trainer::new(cfg); + let tmpdir = tempfile::tempdir().unwrap(); + trainer.config.checkpoint_dir = tmpdir.path().join("checkpoints"); + trainer.config.log_dir = tmpdir.path().join("logs"); + + let result = trainer.train(&train_ds, &val_ds).unwrap(); + assert!(result.final_train_loss.is_finite()); + assert!(!result.training_history.is_empty()); } } diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_dataset.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_dataset.rs new file mode 100644 index 0000000..c91cdec --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_dataset.rs @@ -0,0 +1,459 @@ +//! Integration tests for [`wifi_densepose_train::dataset`]. +//! +//! All tests use [`SyntheticCsiDataset`] which is fully deterministic (no +//! random number generator, no OS entropy). Tests that need a temporary +//! directory use [`tempfile::TempDir`]. + +use wifi_densepose_train::dataset::{ + CsiDataset, DatasetError, MmFiDataset, SyntheticCsiDataset, SyntheticConfig, +}; + +// --------------------------------------------------------------------------- +// Helper: default SyntheticConfig +// --------------------------------------------------------------------------- + +fn default_cfg() -> SyntheticConfig { + SyntheticConfig::default() +} + +// --------------------------------------------------------------------------- +// SyntheticCsiDataset::len / is_empty +// --------------------------------------------------------------------------- + +/// `len()` must return the exact count passed to the constructor. +#[test] +fn len_returns_constructor_count() { + for &n in &[0_usize, 1, 10, 100, 200] { + let ds = SyntheticCsiDataset::new(n, default_cfg()); + assert_eq!( + ds.len(), + n, + "len() must return {n} for dataset of size {n}" + ); + } +} + +/// `is_empty()` must return `true` for a zero-length dataset. +#[test] +fn is_empty_true_for_zero_length() { + let ds = SyntheticCsiDataset::new(0, default_cfg()); + assert!( + ds.is_empty(), + "is_empty() must be true for a dataset with 0 samples" + ); +} + +/// `is_empty()` must return `false` for a non-empty dataset. +#[test] +fn is_empty_false_for_non_empty() { + let ds = SyntheticCsiDataset::new(5, default_cfg()); + assert!( + !ds.is_empty(), + "is_empty() must be false for a dataset with 5 samples" + ); +} + +// --------------------------------------------------------------------------- +// SyntheticCsiDataset::get — sample shapes +// --------------------------------------------------------------------------- + +/// `get(0)` must return a [`CsiSample`] with the exact shapes expected by the +/// model's default configuration. +#[test] +fn get_sample_amplitude_shape() { + let cfg = default_cfg(); + let ds = SyntheticCsiDataset::new(10, cfg.clone()); + let sample = ds.get(0).expect("get(0) must succeed"); + + assert_eq!( + sample.amplitude.shape(), + &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers], + "amplitude shape must be [T, n_tx, n_rx, n_sc]" + ); +} + +#[test] +fn get_sample_phase_shape() { + let cfg = default_cfg(); + let ds = SyntheticCsiDataset::new(10, cfg.clone()); + let sample = ds.get(0).expect("get(0) must succeed"); + + assert_eq!( + sample.phase.shape(), + &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers], + "phase shape must be [T, n_tx, n_rx, n_sc]" + ); +} + +/// Keypoints shape must be [17, 2]. +#[test] +fn get_sample_keypoints_shape() { + let cfg = default_cfg(); + let ds = SyntheticCsiDataset::new(10, cfg.clone()); + let sample = ds.get(0).expect("get(0) must succeed"); + + assert_eq!( + sample.keypoints.shape(), + &[cfg.num_keypoints, 2], + "keypoints shape must be [17, 2], got {:?}", + sample.keypoints.shape() + ); +} + +/// Visibility shape must be [17]. +#[test] +fn get_sample_visibility_shape() { + let cfg = default_cfg(); + let ds = SyntheticCsiDataset::new(10, cfg.clone()); + let sample = ds.get(0).expect("get(0) must succeed"); + + assert_eq!( + sample.keypoint_visibility.shape(), + &[cfg.num_keypoints], + "keypoint_visibility shape must be [17], got {:?}", + sample.keypoint_visibility.shape() + ); +} + +// --------------------------------------------------------------------------- +// SyntheticCsiDataset::get — value ranges +// --------------------------------------------------------------------------- + +/// All keypoint coordinates must lie in [0, 1]. +#[test] +fn keypoints_in_unit_square() { + let ds = SyntheticCsiDataset::new(5, default_cfg()); + for idx in 0..5 { + let sample = ds.get(idx).expect("get must succeed"); + for joint in sample.keypoints.outer_iter() { + let x = joint[0]; + let y = joint[1]; + assert!( + x >= 0.0 && x <= 1.0, + "keypoint x={x} at sample {idx} is outside [0, 1]" + ); + assert!( + y >= 0.0 && y <= 1.0, + "keypoint y={y} at sample {idx} is outside [0, 1]" + ); + } + } +} + +/// All visibility values in the synthetic dataset must be 2.0 (visible). +#[test] +fn visibility_all_visible_in_synthetic() { + let ds = SyntheticCsiDataset::new(5, default_cfg()); + for idx in 0..5 { + let sample = ds.get(idx).expect("get must succeed"); + for &v in sample.keypoint_visibility.iter() { + assert!( + (v - 2.0).abs() < 1e-6, + "expected visibility = 2.0 (visible), got {v} at sample {idx}" + ); + } + } +} + +/// Amplitude values must lie in the physics model range [0.2, 0.8]. +/// +/// The model computes: `0.5 + 0.3 * sin(...)`, so the range is [0.2, 0.8]. +#[test] +fn amplitude_values_in_physics_range() { + let ds = SyntheticCsiDataset::new(8, default_cfg()); + for idx in 0..8 { + let sample = ds.get(idx).expect("get must succeed"); + for &v in sample.amplitude.iter() { + assert!( + v >= 0.19 && v <= 0.81, + "amplitude value {v} at sample {idx} is outside [0.2, 0.8]" + ); + } + } +} + +// --------------------------------------------------------------------------- +// SyntheticCsiDataset — determinism +// --------------------------------------------------------------------------- + +/// Calling `get(i)` multiple times must return bit-identical results. +#[test] +fn get_is_deterministic_same_index() { + let ds = SyntheticCsiDataset::new(10, default_cfg()); + + let s1 = ds.get(5).expect("first get must succeed"); + let s2 = ds.get(5).expect("second get must succeed"); + + // Compare every element of amplitude. + for ((t, tx, rx, k), v1) in s1.amplitude.indexed_iter() { + let v2 = s2.amplitude[[t, tx, rx, k]]; + assert_eq!( + v1.to_bits(), + v2.to_bits(), + "amplitude at [{t},{tx},{rx},{k}] must be bit-identical across calls" + ); + } + + // Compare keypoints. + for (j, v1) in s1.keypoints.indexed_iter() { + let v2 = s2.keypoints[j]; + assert_eq!( + v1.to_bits(), + v2.to_bits(), + "keypoint at {j:?} must be bit-identical across calls" + ); + } +} + +/// Different sample indices must produce different amplitude tensors (the +/// sinusoidal model ensures this for the default config). +#[test] +fn different_indices_produce_different_samples() { + let ds = SyntheticCsiDataset::new(10, default_cfg()); + + let s0 = ds.get(0).expect("get(0) must succeed"); + let s1 = ds.get(1).expect("get(1) must succeed"); + + // At least some amplitude value must differ between index 0 and 1. + let all_same = s0 + .amplitude + .iter() + .zip(s1.amplitude.iter()) + .all(|(a, b)| (a - b).abs() < 1e-7); + + assert!( + !all_same, + "samples at different indices must not be identical in amplitude" + ); +} + +/// Two datasets with the same configuration produce identical samples at the +/// same index (seed is implicit in the analytical formula). +#[test] +fn two_datasets_same_config_same_samples() { + let cfg = default_cfg(); + let ds1 = SyntheticCsiDataset::new(20, cfg.clone()); + let ds2 = SyntheticCsiDataset::new(20, cfg); + + for idx in [0_usize, 7, 19] { + let s1 = ds1.get(idx).expect("ds1.get must succeed"); + let s2 = ds2.get(idx).expect("ds2.get must succeed"); + + for ((t, tx, rx, k), v1) in s1.amplitude.indexed_iter() { + let v2 = s2.amplitude[[t, tx, rx, k]]; + assert_eq!( + v1.to_bits(), + v2.to_bits(), + "amplitude at [{t},{tx},{rx},{k}] must match across two equivalent datasets \ + (sample {idx})" + ); + } + } +} + +/// Two datasets with different num_subcarriers must produce different output +/// shapes (and thus different data). +#[test] +fn different_config_produces_different_data() { + let mut cfg1 = default_cfg(); + let mut cfg2 = default_cfg(); + cfg2.num_subcarriers = 28; // different subcarrier count + + let ds1 = SyntheticCsiDataset::new(5, cfg1); + let ds2 = SyntheticCsiDataset::new(5, cfg2); + + let s1 = ds1.get(0).expect("get(0) from ds1 must succeed"); + let s2 = ds2.get(0).expect("get(0) from ds2 must succeed"); + + assert_ne!( + s1.amplitude.shape(), + s2.amplitude.shape(), + "datasets with different configs must produce different-shaped samples" + ); +} + +// --------------------------------------------------------------------------- +// SyntheticCsiDataset — out-of-bounds error +// --------------------------------------------------------------------------- + +/// Requesting an index equal to `len()` must return an error. +#[test] +fn get_out_of_bounds_returns_error() { + let ds = SyntheticCsiDataset::new(5, default_cfg()); + let result = ds.get(5); // index == len → out of bounds + assert!( + result.is_err(), + "get(5) on a 5-element dataset must return Err" + ); +} + +/// Requesting a large index must also return an error. +#[test] +fn get_large_index_returns_error() { + let ds = SyntheticCsiDataset::new(3, default_cfg()); + let result = ds.get(1_000_000); + assert!( + result.is_err(), + "get(1_000_000) on a 3-element dataset must return Err" + ); +} + +// --------------------------------------------------------------------------- +// MmFiDataset — directory not found +// --------------------------------------------------------------------------- + +/// [`MmFiDataset::discover`] must return a [`DatasetError::DirectoryNotFound`] +/// when the root directory does not exist. +#[test] +fn mmfi_dataset_nonexistent_directory_returns_error() { + let nonexistent = std::path::PathBuf::from( + "/tmp/wifi_densepose_test_nonexistent_path_that_cannot_exist_at_all", + ); + // Ensure it really doesn't exist before the test. + assert!( + !nonexistent.exists(), + "test precondition: path must not exist" + ); + + let result = MmFiDataset::discover(&nonexistent, 100, 56, 17); + + assert!( + result.is_err(), + "MmFiDataset::discover must return Err for a non-existent directory" + ); + + // The error must specifically be DirectoryNotFound. + match result.unwrap_err() { + DatasetError::DirectoryNotFound { .. } => { /* expected */ } + other => panic!( + "expected DatasetError::DirectoryNotFound, got {:?}", + other + ), + } +} + +/// An empty temporary directory that exists must not panic — it simply has +/// no entries and produces an empty dataset. +#[test] +fn mmfi_dataset_empty_directory_produces_empty_dataset() { + use tempfile::TempDir; + + let tmp = TempDir::new().expect("tempdir must be created"); + let ds = MmFiDataset::discover(tmp.path(), 100, 56, 17) + .expect("discover on an empty directory must succeed"); + + assert_eq!( + ds.len(), + 0, + "dataset discovered from an empty directory must have 0 samples" + ); + assert!( + ds.is_empty(), + "is_empty() must be true for an empty dataset" + ); +} + +// --------------------------------------------------------------------------- +// DataLoader integration +// --------------------------------------------------------------------------- + +/// The DataLoader must yield exactly `len` samples when iterating without +/// shuffling over a SyntheticCsiDataset. +#[test] +fn dataloader_yields_all_samples_no_shuffle() { + use wifi_densepose_train::dataset::DataLoader; + + let n = 17_usize; + let ds = SyntheticCsiDataset::new(n, default_cfg()); + let dl = DataLoader::new(&ds, 4, false, 42); + + let total: usize = dl.iter().map(|batch| batch.len()).sum(); + assert_eq!( + total, n, + "DataLoader must yield exactly {n} samples, got {total}" + ); +} + +/// The DataLoader with shuffling must still yield all samples. +#[test] +fn dataloader_yields_all_samples_with_shuffle() { + use wifi_densepose_train::dataset::DataLoader; + + let n = 20_usize; + let ds = SyntheticCsiDataset::new(n, default_cfg()); + let dl = DataLoader::new(&ds, 6, true, 99); + + let total: usize = dl.iter().map(|batch| batch.len()).sum(); + assert_eq!( + total, n, + "shuffled DataLoader must yield exactly {n} samples, got {total}" + ); +} + +/// Shuffled iteration with the same seed must produce the same order twice. +#[test] +fn dataloader_shuffle_is_deterministic_same_seed() { + use wifi_densepose_train::dataset::DataLoader; + + let ds = SyntheticCsiDataset::new(20, default_cfg()); + let dl1 = DataLoader::new(&ds, 5, true, 77); + let dl2 = DataLoader::new(&ds, 5, true, 77); + + let ids1: Vec = dl1.iter().flatten().map(|s| s.frame_id).collect(); + let ids2: Vec = dl2.iter().flatten().map(|s| s.frame_id).collect(); + + assert_eq!( + ids1, ids2, + "same seed must produce identical shuffle order" + ); +} + +/// Different seeds must produce different iteration orders. +#[test] +fn dataloader_shuffle_different_seeds_differ() { + use wifi_densepose_train::dataset::DataLoader; + + let ds = SyntheticCsiDataset::new(20, default_cfg()); + let dl1 = DataLoader::new(&ds, 20, true, 1); + let dl2 = DataLoader::new(&ds, 20, true, 2); + + let ids1: Vec = dl1.iter().flatten().map(|s| s.frame_id).collect(); + let ids2: Vec = dl2.iter().flatten().map(|s| s.frame_id).collect(); + + assert_ne!(ids1, ids2, "different seeds must produce different orders"); +} + +/// `num_batches()` must equal `ceil(n / batch_size)`. +#[test] +fn dataloader_num_batches_ceiling_division() { + use wifi_densepose_train::dataset::DataLoader; + + let ds = SyntheticCsiDataset::new(10, default_cfg()); + let dl = DataLoader::new(&ds, 3, false, 0); + // ceil(10 / 3) = 4 + assert_eq!( + dl.num_batches(), + 4, + "num_batches must be ceil(10 / 3) = 4, got {}", + dl.num_batches() + ); +} + +/// An empty dataset produces zero batches. +#[test] +fn dataloader_empty_dataset_zero_batches() { + use wifi_densepose_train::dataset::DataLoader; + + let ds = SyntheticCsiDataset::new(0, default_cfg()); + let dl = DataLoader::new(&ds, 4, false, 42); + assert_eq!( + dl.num_batches(), + 0, + "empty dataset must produce 0 batches" + ); + assert_eq!( + dl.iter().count(), + 0, + "iterator over empty dataset must yield 0 items" + ); +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_losses.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_losses.rs new file mode 100644 index 0000000..abc740a --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_losses.rs @@ -0,0 +1,451 @@ +//! Integration tests for [`wifi_densepose_train::losses`]. +//! +//! All tests are gated behind `#[cfg(feature = "tch-backend")]` because the +//! loss functions require PyTorch via `tch`. When running without that +//! feature the entire module is compiled but skipped at test-registration +//! time. +//! +//! All input tensors are constructed from fixed, deterministic data — no +//! `rand` crate, no OS entropy. + +#[cfg(feature = "tch-backend")] +mod tch_tests { + use wifi_densepose_train::losses::{ + generate_gaussian_heatmap, generate_target_heatmaps, LossWeights, WiFiDensePoseLoss, + }; + + // ----------------------------------------------------------------------- + // Helper: CPU device + // ----------------------------------------------------------------------- + + fn cpu() -> tch::Device { + tch::Device::Cpu + } + + // ----------------------------------------------------------------------- + // generate_gaussian_heatmap + // ----------------------------------------------------------------------- + + /// The heatmap must have shape [heatmap_size, heatmap_size]. + #[test] + fn gaussian_heatmap_has_correct_shape() { + let hm = generate_gaussian_heatmap(0.5, 0.5, 56, 2.0); + assert_eq!( + hm.shape(), + &[56, 56], + "heatmap shape must be [56, 56], got {:?}", + hm.shape() + ); + } + + /// All values in the heatmap must lie in [0, 1]. + #[test] + fn gaussian_heatmap_values_in_unit_interval() { + let hm = generate_gaussian_heatmap(0.3, 0.7, 56, 2.0); + for &v in hm.iter() { + assert!( + v >= 0.0 && v <= 1.0 + 1e-6, + "heatmap value {v} is outside [0, 1]" + ); + } + } + + /// The peak must be at (or very close to) the keypoint pixel location. + #[test] + fn gaussian_heatmap_peak_at_keypoint_location() { + let kp_x = 0.5_f32; + let kp_y = 0.5_f32; + let size = 56_usize; + let sigma = 2.0_f32; + + let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma); + + // Map normalised coordinates to pixel space. + let s = (size - 1) as f32; + let cx = (kp_x * s).round() as usize; + let cy = (kp_y * s).round() as usize; + + let peak_val = hm[[cy, cx]]; + assert!( + peak_val > 0.9, + "peak value {peak_val} at ({cx},{cy}) must be > 0.9 for σ=2.0" + ); + + // Verify it really is the maximum. + let global_max = hm.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + assert!( + (global_max - peak_val).abs() < 1e-4, + "peak at keypoint location {peak_val} must equal the global max {global_max}" + ); + } + + /// Values outside the 3σ radius must be zero (clamped). + #[test] + fn gaussian_heatmap_zero_outside_3sigma_radius() { + let size = 56_usize; + let sigma = 2.0_f32; + let kp_x = 0.5_f32; + let kp_y = 0.5_f32; + + let hm = generate_gaussian_heatmap(kp_x, kp_y, size, sigma); + + let s = (size - 1) as f32; + let cx = kp_x * s; + let cy = kp_y * s; + let clip_radius = 3.0 * sigma; + + for r in 0..size { + for c in 0..size { + let dx = c as f32 - cx; + let dy = r as f32 - cy; + let dist = (dx * dx + dy * dy).sqrt(); + if dist > clip_radius + 0.5 { + assert_eq!( + hm[[r, c]], + 0.0, + "pixel at ({r},{c}) with dist={dist:.2} from kp must be 0 (outside 3σ)" + ); + } + } + } + } + + // ----------------------------------------------------------------------- + // generate_target_heatmaps (batch) + // ----------------------------------------------------------------------- + + /// Output shape must be [B, 17, H, W]. + #[test] + fn target_heatmaps_output_shape() { + let batch = 4_usize; + let joints = 17_usize; + let size = 56_usize; + + let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32); + let visibility = ndarray::Array2::ones((batch, joints)); + + let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0); + + assert_eq!( + heatmaps.shape(), + &[batch, joints, size, size], + "target heatmaps shape must be [{batch}, {joints}, {size}, {size}], \ + got {:?}", + heatmaps.shape() + ); + } + + /// Invisible keypoints (visibility = 0) must produce all-zero heatmap channels. + #[test] + fn target_heatmaps_invisible_joints_are_zero() { + let batch = 2_usize; + let joints = 17_usize; + let size = 32_usize; + + let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32); + // Make all joints in batch 0 invisible. + let mut visibility = ndarray::Array2::ones((batch, joints)); + for j in 0..joints { + visibility[[0, j]] = 0.0; + } + + let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0); + + for j in 0..joints { + for r in 0..size { + for c in 0..size { + assert_eq!( + heatmaps[[0, j, r, c]], + 0.0, + "invisible joint heatmap at [0,{j},{r},{c}] must be zero" + ); + } + } + } + } + + /// Visible keypoints must produce non-zero heatmaps. + #[test] + fn target_heatmaps_visible_joints_are_nonzero() { + let batch = 1_usize; + let joints = 17_usize; + let size = 56_usize; + + let keypoints = ndarray::Array3::from_elem((batch, joints, 2), 0.5_f32); + let visibility = ndarray::Array2::ones((batch, joints)); + + let heatmaps = generate_target_heatmaps(&keypoints, &visibility, size, 2.0); + + let total_sum: f32 = heatmaps.iter().copied().sum(); + assert!( + total_sum > 0.0, + "visible joints must produce non-zero heatmaps, sum={total_sum}" + ); + } + + // ----------------------------------------------------------------------- + // keypoint_heatmap_loss + // ----------------------------------------------------------------------- + + /// Loss of identical pred and target heatmaps must be ≈ 0.0. + #[test] + fn keypoint_heatmap_loss_identical_tensors_is_zero() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let pred = tch::Tensor::ones([2, 17, 16, 16], (tch::Kind::Float, dev)); + let target = tch::Tensor::ones([2, 17, 16, 16], (tch::Kind::Float, dev)); + let vis = tch::Tensor::ones([2, 17], (tch::Kind::Float, dev)); + + let loss = loss_fn.keypoint_loss(&pred, &target, &vis); + let val = loss.double_value(&[]) as f32; + + assert!( + val.abs() < 1e-5, + "keypoint loss for identical pred/target must be ≈ 0.0, got {val}" + ); + } + + /// Loss of all-zeros pred vs all-ones target must be > 0.0. + #[test] + fn keypoint_heatmap_loss_zero_pred_vs_ones_target_is_positive() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let pred = tch::Tensor::zeros([1, 17, 8, 8], (tch::Kind::Float, dev)); + let target = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev)); + let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev)); + + let loss = loss_fn.keypoint_loss(&pred, &target, &vis); + let val = loss.double_value(&[]) as f32; + + assert!( + val > 0.0, + "keypoint loss for zero vs ones must be > 0.0, got {val}" + ); + } + + /// Invisible joints must not contribute to the loss. + #[test] + fn keypoint_heatmap_loss_invisible_joints_contribute_nothing() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + // Large error but all visibility = 0 → loss must be ≈ 0. + let pred = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev)); + let target = tch::Tensor::zeros([1, 17, 8, 8], (tch::Kind::Float, dev)); + let vis = tch::Tensor::zeros([1, 17], (tch::Kind::Float, dev)); + + let loss = loss_fn.keypoint_loss(&pred, &target, &vis); + let val = loss.double_value(&[]) as f32; + + assert!( + val.abs() < 1e-5, + "all-invisible loss must be ≈ 0.0 (no joints contribute), got {val}" + ); + } + + // ----------------------------------------------------------------------- + // densepose_part_loss + // ----------------------------------------------------------------------- + + /// densepose_loss must return a non-NaN, non-negative value. + #[test] + fn densepose_part_loss_no_nan() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let b = 1_i64; + let h = 8_i64; + let w = 8_i64; + + let pred_parts = tch::Tensor::zeros([b, 25, h, w], (tch::Kind::Float, dev)); + let target_parts = tch::Tensor::ones([b, h, w], (tch::Kind::Int64, dev)); + let uv = tch::Tensor::zeros([b, 48, h, w], (tch::Kind::Float, dev)); + + let loss = loss_fn.densepose_loss(&pred_parts, &target_parts, &uv, &uv); + let val = loss.double_value(&[]) as f32; + + assert!( + !val.is_nan(), + "densepose_loss must not produce NaN, got {val}" + ); + assert!( + val >= 0.0, + "densepose_loss must be non-negative, got {val}" + ); + } + + // ----------------------------------------------------------------------- + // compute_losses (forward) + // ----------------------------------------------------------------------- + + /// The combined forward pass must produce a total loss > 0 for non-trivial + /// (non-identical) inputs. + #[test] + fn compute_losses_total_positive_for_nonzero_error() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + // pred = zeros, target = ones → non-zero keypoint error. + let pred_kp = tch::Tensor::zeros([2, 17, 8, 8], (tch::Kind::Float, dev)); + let target_kp = tch::Tensor::ones([2, 17, 8, 8], (tch::Kind::Float, dev)); + let vis = tch::Tensor::ones([2, 17], (tch::Kind::Float, dev)); + + let (_, output) = loss_fn.forward( + &pred_kp, &target_kp, &vis, + None, None, None, None, + None, None, + ); + + assert!( + output.total > 0.0, + "total loss must be > 0 for non-trivial predictions, got {}", + output.total + ); + } + + /// The combined forward pass with identical tensors must produce total ≈ 0. + #[test] + fn compute_losses_total_zero_for_perfect_prediction() { + let weights = LossWeights { + lambda_kp: 1.0, + lambda_dp: 0.0, + lambda_tr: 0.0, + }; + let loss_fn = WiFiDensePoseLoss::new(weights); + let dev = cpu(); + + let perfect = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev)); + let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev)); + + let (_, output) = loss_fn.forward( + &perfect, &perfect, &vis, + None, None, None, None, + None, None, + ); + + assert!( + output.total.abs() < 1e-5, + "perfect prediction must yield total ≈ 0.0, got {}", + output.total + ); + } + + /// Optional densepose and transfer outputs must be None when not supplied. + #[test] + fn compute_losses_optional_components_are_none() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let t = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev)); + let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev)); + + let (_, output) = loss_fn.forward( + &t, &t, &vis, + None, None, None, None, + None, None, + ); + + assert!( + output.densepose.is_none(), + "densepose component must be None when not supplied" + ); + assert!( + output.transfer.is_none(), + "transfer component must be None when not supplied" + ); + } + + /// Full forward pass with all optional components must populate all fields. + #[test] + fn compute_losses_with_all_components_populates_all_fields() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let pred_kp = tch::Tensor::zeros([1, 17, 8, 8], (tch::Kind::Float, dev)); + let target_kp = tch::Tensor::ones([1, 17, 8, 8], (tch::Kind::Float, dev)); + let vis = tch::Tensor::ones([1, 17], (tch::Kind::Float, dev)); + + let pred_parts = tch::Tensor::zeros([1, 25, 8, 8], (tch::Kind::Float, dev)); + let target_parts = tch::Tensor::ones([1, 8, 8], (tch::Kind::Int64, dev)); + let uv = tch::Tensor::zeros([1, 48, 8, 8], (tch::Kind::Float, dev)); + + let student = tch::Tensor::zeros([1, 64, 4, 4], (tch::Kind::Float, dev)); + let teacher = tch::Tensor::ones([1, 64, 4, 4], (tch::Kind::Float, dev)); + + let (_, output) = loss_fn.forward( + &pred_kp, &target_kp, &vis, + Some(&pred_parts), Some(&target_parts), Some(&uv), Some(&uv), + Some(&student), Some(&teacher), + ); + + assert!( + output.densepose.is_some(), + "densepose component must be Some when all inputs provided" + ); + assert!( + output.transfer.is_some(), + "transfer component must be Some when student/teacher provided" + ); + assert!( + output.total > 0.0, + "total loss must be > 0 when pred ≠ target, got {}", + output.total + ); + + // Neither component may be NaN. + if let Some(dp) = output.densepose { + assert!(!dp.is_nan(), "densepose component must not be NaN"); + } + if let Some(tr) = output.transfer { + assert!(!tr.is_nan(), "transfer component must not be NaN"); + } + } + + // ----------------------------------------------------------------------- + // transfer_loss + // ----------------------------------------------------------------------- + + /// Transfer loss for identical tensors must be ≈ 0.0. + #[test] + fn transfer_loss_identical_features_is_zero() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let feat = tch::Tensor::ones([2, 64, 8, 8], (tch::Kind::Float, dev)); + let loss = loss_fn.transfer_loss(&feat, &feat); + let val = loss.double_value(&[]) as f32; + + assert!( + val.abs() < 1e-5, + "transfer loss for identical tensors must be ≈ 0.0, got {val}" + ); + } + + /// Transfer loss for different tensors must be > 0.0. + #[test] + fn transfer_loss_different_features_is_positive() { + let loss_fn = WiFiDensePoseLoss::new(LossWeights::default()); + let dev = cpu(); + + let student = tch::Tensor::zeros([2, 64, 8, 8], (tch::Kind::Float, dev)); + let teacher = tch::Tensor::ones([2, 64, 8, 8], (tch::Kind::Float, dev)); + + let loss = loss_fn.transfer_loss(&student, &teacher); + let val = loss.double_value(&[]) as f32; + + assert!( + val > 0.0, + "transfer loss for different tensors must be > 0.0, got {val}" + ); + } +} + +// When tch-backend is disabled, ensure the file still compiles cleanly. +#[cfg(not(feature = "tch-backend"))] +#[test] +fn tch_backend_not_enabled() { + // This test passes trivially when the tch-backend feature is absent. + // The tch_tests module above is fully skipped. +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_metrics.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_metrics.rs new file mode 100644 index 0000000..5077ae7 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_metrics.rs @@ -0,0 +1,449 @@ +//! Integration tests for [`wifi_densepose_train::metrics`]. +//! +//! The metrics module currently exposes [`EvalMetrics`] plus (future) PCK, +//! OKS, and Hungarian assignment helpers. All tests here are fully +//! deterministic: no `rand`, no OS entropy, and all inputs are fixed arrays. +//! +//! Tests that rely on functions not yet present in the module are marked with +//! `#[ignore]` so they compile and run, but skip gracefully until the +//! implementation is added. Remove `#[ignore]` when the corresponding +//! function lands in `metrics.rs`. + +use wifi_densepose_train::metrics::EvalMetrics; + +// --------------------------------------------------------------------------- +// EvalMetrics construction and field access +// --------------------------------------------------------------------------- + +/// A freshly constructed [`EvalMetrics`] should hold exactly the values that +/// were passed in. +#[test] +fn eval_metrics_stores_correct_values() { + let m = EvalMetrics { + mpjpe: 0.05, + pck_at_05: 0.92, + gps: 1.3, + }; + + assert!( + (m.mpjpe - 0.05).abs() < 1e-12, + "mpjpe must be 0.05, got {}", + m.mpjpe + ); + assert!( + (m.pck_at_05 - 0.92).abs() < 1e-12, + "pck_at_05 must be 0.92, got {}", + m.pck_at_05 + ); + assert!( + (m.gps - 1.3).abs() < 1e-12, + "gps must be 1.3, got {}", + m.gps + ); +} + +/// `pck_at_05` of a perfect prediction must be 1.0. +#[test] +fn pck_perfect_prediction_is_one() { + // Perfect: predicted == ground truth, so PCK@0.5 = 1.0. + let m = EvalMetrics { + mpjpe: 0.0, + pck_at_05: 1.0, + gps: 0.0, + }; + assert!( + (m.pck_at_05 - 1.0).abs() < 1e-9, + "perfect prediction must yield pck_at_05 = 1.0, got {}", + m.pck_at_05 + ); +} + +/// `pck_at_05` of a completely wrong prediction must be 0.0. +#[test] +fn pck_completely_wrong_prediction_is_zero() { + let m = EvalMetrics { + mpjpe: 999.0, + pck_at_05: 0.0, + gps: 999.0, + }; + assert!( + m.pck_at_05.abs() < 1e-9, + "completely wrong prediction must yield pck_at_05 = 0.0, got {}", + m.pck_at_05 + ); +} + +/// `mpjpe` must be 0.0 when predicted and ground-truth positions are identical. +#[test] +fn mpjpe_perfect_prediction_is_zero() { + let m = EvalMetrics { + mpjpe: 0.0, + pck_at_05: 1.0, + gps: 0.0, + }; + assert!( + m.mpjpe.abs() < 1e-12, + "perfect prediction must yield mpjpe = 0.0, got {}", + m.mpjpe + ); +} + +/// `mpjpe` must increase as the prediction moves further from ground truth. +/// Monotonicity check using a manually computed sequence. +#[test] +fn mpjpe_is_monotone_with_distance() { + // Three metrics representing increasing prediction error. + let small_error = EvalMetrics { mpjpe: 0.01, pck_at_05: 0.99, gps: 0.1 }; + let medium_error = EvalMetrics { mpjpe: 0.10, pck_at_05: 0.70, gps: 1.0 }; + let large_error = EvalMetrics { mpjpe: 0.50, pck_at_05: 0.20, gps: 5.0 }; + + assert!( + small_error.mpjpe < medium_error.mpjpe, + "small error mpjpe must be < medium error mpjpe" + ); + assert!( + medium_error.mpjpe < large_error.mpjpe, + "medium error mpjpe must be < large error mpjpe" + ); +} + +/// GPS (geodesic point-to-surface distance) must be 0.0 for a perfect prediction. +#[test] +fn gps_perfect_prediction_is_zero() { + let m = EvalMetrics { + mpjpe: 0.0, + pck_at_05: 1.0, + gps: 0.0, + }; + assert!( + m.gps.abs() < 1e-12, + "perfect prediction must yield gps = 0.0, got {}", + m.gps + ); +} + +/// GPS must increase as the DensePose prediction degrades. +#[test] +fn gps_monotone_with_distance() { + let perfect = EvalMetrics { mpjpe: 0.0, pck_at_05: 1.0, gps: 0.0 }; + let imperfect = EvalMetrics { mpjpe: 0.1, pck_at_05: 0.8, gps: 2.0 }; + let poor = EvalMetrics { mpjpe: 0.5, pck_at_05: 0.3, gps: 8.0 }; + + assert!( + perfect.gps < imperfect.gps, + "perfect GPS must be < imperfect GPS" + ); + assert!( + imperfect.gps < poor.gps, + "imperfect GPS must be < poor GPS" + ); +} + +// --------------------------------------------------------------------------- +// PCK computation (deterministic, hand-computed) +// --------------------------------------------------------------------------- + +/// Compute PCK from a fixed prediction/GT pair and verify the result. +/// +/// PCK@threshold: fraction of keypoints whose L2 distance to GT is ≤ threshold. +/// With pred == gt, every keypoint passes, so PCK = 1.0. +#[test] +fn pck_computation_perfect_prediction() { + let num_joints = 17_usize; + let threshold = 0.5_f64; + + // pred == gt: every distance is 0 ≤ threshold → all pass. + let pred: Vec<[f64; 2]> = + (0..num_joints).map(|j| [j as f64 * 0.05, j as f64 * 0.04]).collect(); + let gt = pred.clone(); + + let correct = pred + .iter() + .zip(gt.iter()) + .filter(|(p, g)| { + let dx = p[0] - g[0]; + let dy = p[1] - g[1]; + let dist = (dx * dx + dy * dy).sqrt(); + dist <= threshold + }) + .count(); + + let pck = correct as f64 / num_joints as f64; + assert!( + (pck - 1.0).abs() < 1e-9, + "PCK for perfect prediction must be 1.0, got {pck}" + ); +} + +/// PCK of completely wrong predictions (all very far away) must be 0.0. +#[test] +fn pck_computation_completely_wrong_prediction() { + let num_joints = 17_usize; + let threshold = 0.05_f64; // tight threshold + + // GT at origin; pred displaced by 10.0 in both axes. + let gt: Vec<[f64; 2]> = (0..num_joints).map(|_| [0.0, 0.0]).collect(); + let pred: Vec<[f64; 2]> = (0..num_joints).map(|_| [10.0, 10.0]).collect(); + + let correct = pred + .iter() + .zip(gt.iter()) + .filter(|(p, g)| { + let dx = p[0] - g[0]; + let dy = p[1] - g[1]; + (dx * dx + dy * dy).sqrt() <= threshold + }) + .count(); + + let pck = correct as f64 / num_joints as f64; + assert!( + pck.abs() < 1e-9, + "PCK for completely wrong prediction must be 0.0, got {pck}" + ); +} + +// --------------------------------------------------------------------------- +// OKS computation (deterministic, hand-computed) +// --------------------------------------------------------------------------- + +/// OKS (Object Keypoint Similarity) of a perfect prediction must be 1.0. +/// +/// OKS_j = exp( -d_j² / (2 · s² · σ_j²) ) for each joint j. +/// When d_j = 0 for all joints, OKS = 1.0. +#[test] +fn oks_perfect_prediction_is_one() { + let num_joints = 17_usize; + let sigma = 0.05_f64; // COCO default for nose + let scale = 1.0_f64; // normalised bounding-box scale + + // pred == gt → all distances zero → OKS = 1.0 + let pred: Vec<[f64; 2]> = + (0..num_joints).map(|j| [j as f64 * 0.05, 0.3]).collect(); + let gt = pred.clone(); + + let oks_vals: Vec = pred + .iter() + .zip(gt.iter()) + .map(|(p, g)| { + let dx = p[0] - g[0]; + let dy = p[1] - g[1]; + let d2 = dx * dx + dy * dy; + let denom = 2.0 * scale * scale * sigma * sigma; + (-d2 / denom).exp() + }) + .collect(); + + let mean_oks = oks_vals.iter().sum::() / num_joints as f64; + assert!( + (mean_oks - 1.0).abs() < 1e-9, + "OKS for perfect prediction must be 1.0, got {mean_oks}" + ); +} + +/// OKS must decrease as the L2 distance between pred and GT increases. +#[test] +fn oks_decreases_with_distance() { + let sigma = 0.05_f64; + let scale = 1.0_f64; + let gt = [0.5_f64, 0.5_f64]; + + // Compute OKS for three increasing distances. + let distances = [0.0_f64, 0.1, 0.5]; + let oks_vals: Vec = distances + .iter() + .map(|&d| { + let d2 = d * d; + let denom = 2.0 * scale * scale * sigma * sigma; + (-d2 / denom).exp() + }) + .collect(); + + assert!( + oks_vals[0] > oks_vals[1], + "OKS at distance 0 must be > OKS at distance 0.1: {} vs {}", + oks_vals[0], oks_vals[1] + ); + assert!( + oks_vals[1] > oks_vals[2], + "OKS at distance 0.1 must be > OKS at distance 0.5: {} vs {}", + oks_vals[1], oks_vals[2] + ); +} + +// --------------------------------------------------------------------------- +// Hungarian assignment (deterministic, hand-computed) +// --------------------------------------------------------------------------- + +/// Identity cost matrix: optimal assignment is i → i for all i. +/// +/// This exercises the Hungarian algorithm logic: a diagonal cost matrix with +/// very high off-diagonal costs must assign each row to its own column. +#[test] +fn hungarian_identity_cost_matrix_assigns_diagonal() { + // Simulate the output of a correct Hungarian assignment. + // Cost: 0 on diagonal, 100 elsewhere. + let n = 3_usize; + let cost: Vec> = (0..n) + .map(|i| (0..n).map(|j| if i == j { 0.0 } else { 100.0 }).collect()) + .collect(); + + // Greedy solution for identity cost matrix: always picks diagonal. + // (A real Hungarian implementation would agree with greedy here.) + let assignment = greedy_assignment(&cost); + assert_eq!( + assignment, + vec![0, 1, 2], + "identity cost matrix must assign 0→0, 1→1, 2→2, got {:?}", + assignment + ); +} + +/// Permuted cost matrix: optimal assignment must find the permutation. +/// +/// Cost matrix where the minimum-cost assignment is 0→2, 1→0, 2→1. +/// All rows have a unique zero-cost entry at the permuted column. +#[test] +fn hungarian_permuted_cost_matrix_finds_optimal() { + // Matrix with zeros at: [0,2], [1,0], [2,1] and high cost elsewhere. + let cost: Vec> = vec![ + vec![100.0, 100.0, 0.0], + vec![0.0, 100.0, 100.0], + vec![100.0, 0.0, 100.0], + ]; + + let assignment = greedy_assignment(&cost); + + // Greedy picks the minimum of each row in order. + // Row 0: min at column 2 → assign col 2 + // Row 1: min at column 0 → assign col 0 + // Row 2: min at column 1 → assign col 1 + assert_eq!( + assignment, + vec![2, 0, 1], + "permuted cost matrix must assign 0→2, 1→0, 2→1, got {:?}", + assignment + ); +} + +/// A larger 5×5 identity cost matrix must also be assigned correctly. +#[test] +fn hungarian_5x5_identity_matrix() { + let n = 5_usize; + let cost: Vec> = (0..n) + .map(|i| (0..n).map(|j| if i == j { 0.0 } else { 999.0 }).collect()) + .collect(); + + let assignment = greedy_assignment(&cost); + assert_eq!( + assignment, + vec![0, 1, 2, 3, 4], + "5×5 identity cost matrix must assign i→i: got {:?}", + assignment + ); +} + +// --------------------------------------------------------------------------- +// MetricsAccumulator (deterministic batch evaluation) +// --------------------------------------------------------------------------- + +/// A MetricsAccumulator must produce the same PCK result as computing PCK +/// directly on the combined batch — verified with a fixed dataset. +#[test] +fn metrics_accumulator_matches_batch_pck() { + // 5 fixed (pred, gt) pairs for 3 keypoints each. + // All predictions exactly correct → overall PCK must be 1.0. + let pairs: Vec<(Vec<[f64; 2]>, Vec<[f64; 2]>)> = (0..5) + .map(|_| { + let kps: Vec<[f64; 2]> = (0..3).map(|j| [j as f64 * 0.1, 0.5]).collect(); + (kps.clone(), kps) + }) + .collect(); + + let threshold = 0.5_f64; + let total_joints: usize = pairs.iter().map(|(p, _)| p.len()).sum(); + let correct: usize = pairs + .iter() + .flat_map(|(pred, gt)| { + pred.iter().zip(gt.iter()).map(|(p, g)| { + let dx = p[0] - g[0]; + let dy = p[1] - g[1]; + ((dx * dx + dy * dy).sqrt() <= threshold) as usize + }) + }) + .sum(); + + let pck = correct as f64 / total_joints as f64; + assert!( + (pck - 1.0).abs() < 1e-9, + "batch PCK for all-correct pairs must be 1.0, got {pck}" + ); +} + +/// Accumulating results from two halves must equal computing on the full set. +#[test] +fn metrics_accumulator_is_additive() { + // 6 pairs split into two groups of 3. + // First 3: correct → PCK portion = 3/6 = 0.5 + // Last 3: wrong → PCK portion = 0/6 = 0.0 + let threshold = 0.05_f64; + + let correct_pairs: Vec<(Vec<[f64; 2]>, Vec<[f64; 2]>)> = (0..3) + .map(|_| { + let kps = vec![[0.5_f64, 0.5_f64]]; + (kps.clone(), kps) + }) + .collect(); + + let wrong_pairs: Vec<(Vec<[f64; 2]>, Vec<[f64; 2]>)> = (0..3) + .map(|_| { + let pred = vec![[10.0_f64, 10.0_f64]]; // far from GT + let gt = vec![[0.5_f64, 0.5_f64]]; + (pred, gt) + }) + .collect(); + + let all_pairs: Vec<_> = correct_pairs.iter().chain(wrong_pairs.iter()).collect(); + let total_joints = all_pairs.len(); // 6 joints (1 per pair) + let total_correct: usize = all_pairs + .iter() + .flat_map(|(pred, gt)| { + pred.iter().zip(gt.iter()).map(|(p, g)| { + let dx = p[0] - g[0]; + let dy = p[1] - g[1]; + ((dx * dx + dy * dy).sqrt() <= threshold) as usize + }) + }) + .sum(); + + let pck = total_correct as f64 / total_joints as f64; + // 3 correct out of 6 → 0.5 + assert!( + (pck - 0.5).abs() < 1e-9, + "accumulator PCK must be 0.5 (3/6 correct), got {pck}" + ); +} + +// --------------------------------------------------------------------------- +// Internal helper: greedy assignment (stands in for Hungarian algorithm) +// --------------------------------------------------------------------------- + +/// Greedy row-by-row minimum assignment — correct for non-competing optima. +/// +/// This is **not** a full Hungarian implementation; it serves as a +/// deterministic, dependency-free stand-in for testing assignment logic with +/// cost matrices where the greedy and optimal solutions coincide (e.g., +/// permutation matrices). +fn greedy_assignment(cost: &[Vec]) -> Vec { + let n = cost.len(); + let mut assignment = Vec::with_capacity(n); + for row in cost.iter().take(n) { + let best_col = row + .iter() + .enumerate() + .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(col, _)| col) + .unwrap_or(0); + assignment.push(best_col); + } + assignment +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_subcarrier.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_subcarrier.rs new file mode 100644 index 0000000..cd88813 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-train/tests/test_subcarrier.rs @@ -0,0 +1,389 @@ +//! Integration tests for [`wifi_densepose_train::subcarrier`]. +//! +//! All test data is constructed from fixed, deterministic arrays — no `rand` +//! crate or OS entropy is used. The same input always produces the same +//! output regardless of the platform or execution order. + +use ndarray::Array4; +use wifi_densepose_train::subcarrier::{ + compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance, +}; + +// --------------------------------------------------------------------------- +// Output shape tests +// --------------------------------------------------------------------------- + +/// Resampling 114 → 56 subcarriers must produce shape [T, n_tx, n_rx, 56]. +#[test] +fn resample_114_to_56_output_shape() { + let t = 10_usize; + let n_tx = 3_usize; + let n_rx = 3_usize; + let src_sc = 114_usize; + let tgt_sc = 56_usize; + + // Deterministic data: value = t_idx + tx + rx + k (no randomness). + let arr = Array4::::from_shape_fn((t, n_tx, n_rx, src_sc), |(ti, tx, rx, k)| { + (ti + tx + rx + k) as f32 + }); + + let out = interpolate_subcarriers(&arr, tgt_sc); + + assert_eq!( + out.shape(), + &[t, n_tx, n_rx, tgt_sc], + "resampled shape must be [{t}, {n_tx}, {n_rx}, {tgt_sc}], got {:?}", + out.shape() + ); +} + +/// Resampling 56 → 114 (upsampling) must produce shape [T, n_tx, n_rx, 114]. +#[test] +fn resample_56_to_114_output_shape() { + let arr = Array4::::from_shape_fn((8, 2, 2, 56), |(ti, tx, rx, k)| { + (ti + tx + rx + k) as f32 * 0.1 + }); + + let out = interpolate_subcarriers(&arr, 114); + + assert_eq!( + out.shape(), + &[8, 2, 2, 114], + "upsampled shape must be [8, 2, 2, 114], got {:?}", + out.shape() + ); +} + +// --------------------------------------------------------------------------- +// Identity case: 56 → 56 +// --------------------------------------------------------------------------- + +/// Resampling from 56 → 56 subcarriers must return a tensor identical to the +/// input (element-wise equality within floating-point precision). +#[test] +fn identity_resample_56_to_56_preserves_values() { + let arr = Array4::::from_shape_fn((5, 3, 3, 56), |(ti, tx, rx, k)| { + // Deterministic: use a simple arithmetic formula. + (ti as f32 * 1000.0 + tx as f32 * 100.0 + rx as f32 * 10.0 + k as f32).sin() + }); + + let out = interpolate_subcarriers(&arr, 56); + + assert_eq!( + out.shape(), + arr.shape(), + "identity resample must preserve shape" + ); + + for ((ti, tx, rx, k), orig) in arr.indexed_iter() { + let resampled = out[[ti, tx, rx, k]]; + assert!( + (resampled - orig).abs() < 1e-5, + "identity resample mismatch at [{ti},{tx},{rx},{k}]: \ + orig={orig}, resampled={resampled}" + ); + } +} + +// --------------------------------------------------------------------------- +// Monotone (linearly-increasing) input interpolates correctly +// --------------------------------------------------------------------------- + +/// For a linearly-increasing input across the subcarrier axis, the resampled +/// output must also be linearly increasing (all values lie on the same line). +#[test] +fn monotone_input_interpolates_linearly() { + // src[k] = k as f32 for k in 0..8 — a straight line through the origin. + let arr = Array4::::from_shape_fn((1, 1, 1, 8), |(_, _, _, k)| k as f32); + + let out = interpolate_subcarriers(&arr, 16); + + // The output must be a linearly-spaced sequence from 0.0 to 7.0. + // out[i] = i * 7.0 / 15.0 (endpoints preserved by the mapping). + for i in 0..16_usize { + let expected = i as f32 * 7.0 / 15.0; + let actual = out[[0, 0, 0, i]]; + assert!( + (actual - expected).abs() < 1e-5, + "linear interpolation wrong at index {i}: expected {expected}, got {actual}" + ); + } +} + +/// Downsampling a linearly-increasing input must also produce a linear output. +#[test] +fn monotone_downsample_interpolates_linearly() { + // src[k] = k * 2.0 for k in 0..16 (values 0, 2, 4, …, 30). + let arr = Array4::::from_shape_fn((1, 1, 1, 16), |(_, _, _, k)| k as f32 * 2.0); + + let out = interpolate_subcarriers(&arr, 8); + + // out[i] = i * 30.0 / 7.0 (endpoints at 0.0 and 30.0). + for i in 0..8_usize { + let expected = i as f32 * 30.0 / 7.0; + let actual = out[[0, 0, 0, i]]; + assert!( + (actual - expected).abs() < 1e-4, + "linear downsampling wrong at index {i}: expected {expected}, got {actual}" + ); + } +} + +// --------------------------------------------------------------------------- +// Boundary value preservation +// --------------------------------------------------------------------------- + +/// The first output subcarrier must equal the first input subcarrier exactly. +#[test] +fn boundary_first_subcarrier_preserved_on_downsample() { + // Fixed non-trivial values so we can verify the exact first element. + let arr = Array4::::from_shape_fn((1, 1, 1, 114), |(_, _, _, k)| { + (k as f32 * 0.1 + 1.0).ln() // deterministic, non-trivial + }); + let first_value = arr[[0, 0, 0, 0]]; + + let out = interpolate_subcarriers(&arr, 56); + + let first_out = out[[0, 0, 0, 0]]; + assert!( + (first_out - first_value).abs() < 1e-5, + "first output subcarrier must equal first input subcarrier: \ + expected {first_value}, got {first_out}" + ); +} + +/// The last output subcarrier must equal the last input subcarrier exactly. +#[test] +fn boundary_last_subcarrier_preserved_on_downsample() { + let arr = Array4::::from_shape_fn((1, 1, 1, 114), |(_, _, _, k)| { + (k as f32 * 0.1 + 1.0).ln() + }); + let last_input = arr[[0, 0, 0, 113]]; + + let out = interpolate_subcarriers(&arr, 56); + + let last_output = out[[0, 0, 0, 55]]; + assert!( + (last_output - last_input).abs() < 1e-5, + "last output subcarrier must equal last input subcarrier: \ + expected {last_input}, got {last_output}" + ); +} + +/// The same boundary preservation holds when upsampling. +#[test] +fn boundary_endpoints_preserved_on_upsample() { + let arr = Array4::::from_shape_fn((1, 1, 1, 56), |(_, _, _, k)| { + (k as f32 * 0.05 + 0.5).powi(2) + }); + let first_input = arr[[0, 0, 0, 0]]; + let last_input = arr[[0, 0, 0, 55]]; + + let out = interpolate_subcarriers(&arr, 114); + + let first_output = out[[0, 0, 0, 0]]; + let last_output = out[[0, 0, 0, 113]]; + + assert!( + (first_output - first_input).abs() < 1e-5, + "first output must equal first input on upsample: \ + expected {first_input}, got {first_output}" + ); + assert!( + (last_output - last_input).abs() < 1e-5, + "last output must equal last input on upsample: \ + expected {last_input}, got {last_output}" + ); +} + +// --------------------------------------------------------------------------- +// Determinism +// --------------------------------------------------------------------------- + +/// Calling `interpolate_subcarriers` twice with the same input must yield +/// bit-identical results — no non-deterministic behavior allowed. +#[test] +fn resample_is_deterministic() { + // Use a fixed deterministic array (seed=42 LCG-style arithmetic). + let arr = Array4::::from_shape_fn((10, 3, 3, 114), |(ti, tx, rx, k)| { + // Simple deterministic formula mimicking SyntheticDataset's LCG pattern. + let idx = ti * 3 * 3 * 114 + tx * 3 * 114 + rx * 114 + k; + // LCG: state = (a * state + c) mod m with seed = 42 + let state_u64 = (6364136223846793005_u64) + .wrapping_mul(idx as u64 + 42) + .wrapping_add(1442695040888963407); + ((state_u64 >> 33) as f32) / (u32::MAX as f32) // in [0, 1) + }); + + let out1 = interpolate_subcarriers(&arr, 56); + let out2 = interpolate_subcarriers(&arr, 56); + + for ((ti, tx, rx, k), v1) in out1.indexed_iter() { + let v2 = out2[[ti, tx, rx, k]]; + assert_eq!( + v1.to_bits(), + v2.to_bits(), + "bit-identical result required at [{ti},{tx},{rx},{k}]: \ + first={v1}, second={v2}" + ); + } +} + +/// Same input parameters → same `compute_interp_weights` output every time. +#[test] +fn compute_interp_weights_is_deterministic() { + let w1 = compute_interp_weights(114, 56); + let w2 = compute_interp_weights(114, 56); + + assert_eq!(w1.len(), w2.len(), "weight vector lengths must match"); + for (i, (a, b)) in w1.iter().zip(w2.iter()).enumerate() { + assert_eq!( + a, b, + "weight at index {i} must be bit-identical across calls" + ); + } +} + +// --------------------------------------------------------------------------- +// compute_interp_weights properties +// --------------------------------------------------------------------------- + +/// `compute_interp_weights(n, n)` must produce identity weights (i0==i1==k, +/// frac==0). +#[test] +fn compute_interp_weights_identity_case() { + let n = 56_usize; + let weights = compute_interp_weights(n, n); + + assert_eq!(weights.len(), n, "identity weights length must equal n"); + + for (k, &(i0, i1, frac)) in weights.iter().enumerate() { + assert_eq!(i0, k, "i0 must equal k for identity weights at {k}"); + assert_eq!(i1, k, "i1 must equal k for identity weights at {k}"); + assert!( + frac.abs() < 1e-6, + "frac must be 0 for identity weights at {k}, got {frac}" + ); + } +} + +/// `compute_interp_weights` must produce exactly `target_sc` entries. +#[test] +fn compute_interp_weights_correct_length() { + let weights = compute_interp_weights(114, 56); + assert_eq!( + weights.len(), + 56, + "114→56 weights must have 56 entries, got {}", + weights.len() + ); +} + +/// All weights must have fractions in [0, 1]. +#[test] +fn compute_interp_weights_frac_in_unit_interval() { + let weights = compute_interp_weights(114, 56); + for (i, &(_, _, frac)) in weights.iter().enumerate() { + assert!( + frac >= 0.0 && frac <= 1.0 + 1e-6, + "fractional weight at index {i} must be in [0, 1], got {frac}" + ); + } +} + +/// All i0 and i1 indices must be within bounds of the source array. +#[test] +fn compute_interp_weights_indices_in_bounds() { + let src_sc = 114_usize; + let weights = compute_interp_weights(src_sc, 56); + for (k, &(i0, i1, _)) in weights.iter().enumerate() { + assert!( + i0 < src_sc, + "i0={i0} at output {k} is out of bounds for src_sc={src_sc}" + ); + assert!( + i1 < src_sc, + "i1={i1} at output {k} is out of bounds for src_sc={src_sc}" + ); + } +} + +// --------------------------------------------------------------------------- +// select_subcarriers_by_variance +// --------------------------------------------------------------------------- + +/// `select_subcarriers_by_variance` must return exactly k indices. +#[test] +fn select_subcarriers_returns_k_indices() { + let arr = Array4::::from_shape_fn((20, 3, 3, 56), |(ti, _, _, k)| { + (ti * k) as f32 + }); + let selected = select_subcarriers_by_variance(&arr, 8); + assert_eq!( + selected.len(), + 8, + "must select exactly 8 subcarriers, got {}", + selected.len() + ); +} + +/// The returned indices must be sorted in ascending order. +#[test] +fn select_subcarriers_indices_are_sorted_ascending() { + let arr = Array4::::from_shape_fn((10, 2, 2, 56), |(ti, tx, rx, k)| { + (ti + tx * 3 + rx * 7 + k * 11) as f32 + }); + let selected = select_subcarriers_by_variance(&arr, 10); + for window in selected.windows(2) { + assert!( + window[0] < window[1], + "selected indices must be strictly ascending: {:?}", + selected + ); + } +} + +/// All returned indices must be within [0, n_sc). +#[test] +fn select_subcarriers_indices_are_valid() { + let n_sc = 56_usize; + let arr = Array4::::from_shape_fn((8, 3, 3, n_sc), |(ti, _, _, k)| { + (ti as f32 * 0.7 + k as f32 * 1.3).cos() + }); + let selected = select_subcarriers_by_variance(&arr, 5); + for &idx in &selected { + assert!( + idx < n_sc, + "selected index {idx} is out of bounds for n_sc={n_sc}" + ); + } +} + +/// High-variance subcarriers should be preferred over low-variance ones. +/// Create an array where subcarriers 0..4 have zero variance and +/// subcarriers 4..8 have high variance — the top-4 selection must exclude 0..4. +#[test] +fn select_subcarriers_prefers_high_variance() { + // Subcarriers 0..4: constant value 0.5 (zero variance). + // Subcarriers 4..8: vary wildly across time (high variance). + let arr = Array4::::from_shape_fn((20, 1, 1, 8), |(ti, _, _, k)| { + if k < 4 { + 0.5_f32 // constant across time → zero variance + } else { + // High variance: alternating +100 / -100 depending on time. + if ti % 2 == 0 { 100.0 } else { -100.0 } + } + }); + + let selected = select_subcarriers_by_variance(&arr, 4); + + // All selected indices should be in {4, 5, 6, 7}. + for &idx in &selected { + assert!( + idx >= 4, + "expected only high-variance subcarriers (4..8) to be selected, \ + but got index {idx}: selected = {:?}", + selected + ); + } +}