feat(rust): Add workspace deps, tests, and refine training modules
- Cargo.toml: Add wifi-densepose-train to workspace members; add petgraph, ndarray-npy, walkdir, sha2, csv, indicatif, clap to workspace dependencies - error.rs: Slim down to focused error types (TrainError, DatasetError) - lib.rs: Wire up all module re-exports correctly - losses.rs: Add generate_gaussian_heatmaps implementation - tests/test_config.rs: Deterministic config roundtrip and validation tests https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -11,6 +11,7 @@ members = [
|
||||
"crates/wifi-densepose-wasm",
|
||||
"crates/wifi-densepose-cli",
|
||||
"crates/wifi-densepose-mat",
|
||||
"crates/wifi-densepose-train",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -73,6 +74,25 @@ getrandom = { version = "0.2", features = ["js"] }
|
||||
serialport = "4.3"
|
||||
pcap = "1.1"
|
||||
|
||||
# Graph algorithms (for min-cut assignment in metrics)
|
||||
petgraph = "0.6"
|
||||
|
||||
# Data loading
|
||||
ndarray-npy = "0.8"
|
||||
walkdir = "2.4"
|
||||
|
||||
# Hashing (for proof)
|
||||
sha2 = "0.10"
|
||||
|
||||
# CSV logging
|
||||
csv = "1.3"
|
||||
|
||||
# Progress bars
|
||||
indicatif = "0.17"
|
||||
|
||||
# CLI
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
|
||||
# Testing
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
proptest = "1.4"
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
//! Error types for the WiFi-DensePose training pipeline.
|
||||
//!
|
||||
//! This module defines a hierarchy of errors covering every failure mode in
|
||||
//! the training pipeline: configuration validation, dataset I/O, subcarrier
|
||||
//! interpolation, and top-level training orchestration.
|
||||
//! This module provides:
|
||||
//!
|
||||
//! - [`TrainError`]: top-level error aggregating all training failure modes.
|
||||
//! - [`TrainResult`]: convenient `Result` alias using `TrainError`.
|
||||
//!
|
||||
//! Module-local error types live in their respective modules:
|
||||
//!
|
||||
//! - [`crate::config::ConfigError`]: configuration validation errors.
|
||||
//! - [`crate::dataset::DatasetError`]: dataset loading/access errors.
|
||||
//!
|
||||
//! All are re-exported at the crate root for ergonomic use.
|
||||
|
||||
use thiserror::Error;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Import module-local error types so TrainError can wrap them via #[from].
|
||||
use crate::config::ConfigError;
|
||||
use crate::dataset::DatasetError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Top-level training error
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -16,8 +28,9 @@ pub type TrainResult<T> = Result<T, TrainError>;
|
||||
|
||||
/// Top-level error type for the training pipeline.
|
||||
///
|
||||
/// Every public function in this crate that can fail returns
|
||||
/// `TrainResult<T>`, which is `Result<T, TrainError>`.
|
||||
/// Every orchestration-level function returns `TrainResult<T>`. Lower-level
|
||||
/// functions in [`crate::config`] and [`crate::dataset`] return their own
|
||||
/// module-specific error types which are automatically coerced via `#[from]`.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TrainError {
|
||||
/// Configuration is invalid or internally inconsistent.
|
||||
@@ -28,10 +41,6 @@ pub enum TrainError {
|
||||
#[error("Dataset error: {0}")]
|
||||
Dataset(#[from] DatasetError),
|
||||
|
||||
/// Subcarrier interpolation / resampling failed.
|
||||
#[error("Subcarrier interpolation error: {0}")]
|
||||
Subcarrier(#[from] SubcarrierError),
|
||||
|
||||
/// An underlying I/O error not covered by a more specific variant.
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
@@ -40,14 +49,6 @@ pub enum TrainError {
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// TOML (de)serialization error.
|
||||
#[error("TOML deserialization error: {0}")]
|
||||
TomlDe(#[from] toml::de::Error),
|
||||
|
||||
/// TOML serialization error.
|
||||
#[error("TOML serialization error: {0}")]
|
||||
TomlSer(#[from] toml::ser::Error),
|
||||
|
||||
/// An operation was attempted on an empty dataset.
|
||||
#[error("Dataset is empty")]
|
||||
EmptyDataset,
|
||||
@@ -112,273 +113,3 @@ impl TrainError {
|
||||
TrainError::ShapeMismatch { expected, actual }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced when validating or loading a [`TrainingConfig`].
|
||||
///
|
||||
/// [`TrainingConfig`]: crate::config::TrainingConfig
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConfigError {
|
||||
/// A required field has a value that violates a constraint.
|
||||
#[error("Invalid value for field `{field}`: {reason}")]
|
||||
InvalidValue {
|
||||
/// Name of the configuration field.
|
||||
field: &'static str,
|
||||
/// Human-readable reason the value is invalid.
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// The configuration file could not be read.
|
||||
#[error("Cannot read configuration file `{path}`: {source}")]
|
||||
FileRead {
|
||||
/// Path that was being read.
|
||||
path: PathBuf,
|
||||
/// Underlying I/O error.
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
/// The configuration file contains invalid TOML.
|
||||
#[error("Cannot parse configuration file `{path}`: {source}")]
|
||||
ParseError {
|
||||
/// Path that was being parsed.
|
||||
path: PathBuf,
|
||||
/// Underlying TOML parse error.
|
||||
#[source]
|
||||
source: toml::de::Error,
|
||||
},
|
||||
|
||||
/// A path specified in the config does not exist.
|
||||
#[error("Path `{path}` specified in config does not exist")]
|
||||
PathNotFound {
|
||||
/// The missing path.
|
||||
path: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
/// Construct an [`ConfigError::InvalidValue`] error.
|
||||
pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
|
||||
ConfigError::InvalidValue {
|
||||
field,
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dataset errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced while loading or accessing dataset samples.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DatasetError {
|
||||
/// The requested data file or directory was not found.
|
||||
///
|
||||
/// Production training data is mandatory; this error is never silently
|
||||
/// suppressed. Use [`SyntheticDataset`] only for proof/testing.
|
||||
///
|
||||
/// [`SyntheticDataset`]: crate::dataset::SyntheticDataset
|
||||
#[error("Data not found at `{path}`: {message}")]
|
||||
DataNotFound {
|
||||
/// Path that was expected to contain data.
|
||||
path: PathBuf,
|
||||
/// Additional context.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A file was found but its format is incorrect or unexpected.
|
||||
///
|
||||
/// This covers malformed numpy arrays, unexpected shapes, bad JSON
|
||||
/// metadata, etc.
|
||||
#[error("Invalid data format in `{path}`: {message}")]
|
||||
InvalidFormat {
|
||||
/// Path of the malformed file.
|
||||
path: PathBuf,
|
||||
/// Description of the format problem.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A low-level I/O error while reading a data file.
|
||||
#[error("I/O error reading `{path}`: {source}")]
|
||||
IoError {
|
||||
/// Path being read when the error occurred.
|
||||
path: PathBuf,
|
||||
/// Underlying I/O error.
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
/// The number of subcarriers in the data file does not match the
|
||||
/// configuration expectation (before or after interpolation).
|
||||
#[error(
|
||||
"Subcarrier count mismatch in `{path}`: \
|
||||
file has {found} subcarriers, expected {expected}"
|
||||
)]
|
||||
SubcarrierMismatch {
|
||||
/// Path of the offending file.
|
||||
path: PathBuf,
|
||||
/// Number of subcarriers found in the file.
|
||||
found: usize,
|
||||
/// Number of subcarriers expected by the configuration.
|
||||
expected: usize,
|
||||
},
|
||||
|
||||
/// A sample index was out of bounds.
|
||||
#[error("Index {index} is out of bounds for dataset of length {len}")]
|
||||
IndexOutOfBounds {
|
||||
/// The requested index.
|
||||
index: usize,
|
||||
/// Total number of samples.
|
||||
len: usize,
|
||||
},
|
||||
|
||||
/// A numpy array could not be read.
|
||||
#[error("NumPy array read error in `{path}`: {message}")]
|
||||
NpyReadError {
|
||||
/// Path of the `.npy` file.
|
||||
path: PathBuf,
|
||||
/// Error description.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A metadata file (e.g., `meta.json`) is missing or malformed.
|
||||
#[error("Metadata error for subject {subject_id}: {message}")]
|
||||
MetadataError {
|
||||
/// Subject whose metadata could not be read.
|
||||
subject_id: u32,
|
||||
/// Description of the problem.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// No subjects matching the requested IDs were found in the data directory.
|
||||
#[error(
|
||||
"No subjects found in `{data_dir}` matching the requested IDs: {requested:?}"
|
||||
)]
|
||||
NoSubjectsFound {
|
||||
/// Root data directory that was scanned.
|
||||
data_dir: PathBuf,
|
||||
/// Subject IDs that were requested.
|
||||
requested: Vec<u32>,
|
||||
},
|
||||
|
||||
/// A subcarrier interpolation error occurred during sample loading.
|
||||
#[error("Subcarrier interpolation failed while loading sample {sample_idx}: {source}")]
|
||||
InterpolationError {
|
||||
/// The sample index being loaded.
|
||||
sample_idx: usize,
|
||||
/// Underlying interpolation error.
|
||||
#[source]
|
||||
source: SubcarrierError,
|
||||
},
|
||||
}
|
||||
|
||||
impl DatasetError {
|
||||
/// Construct a [`DatasetError::DataNotFound`] error.
|
||||
pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::DataNotFound {
|
||||
path: path.into(),
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::InvalidFormat`] error.
|
||||
pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::InvalidFormat {
|
||||
path: path.into(),
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::IoError`] error.
|
||||
pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
|
||||
DatasetError::IoError {
|
||||
path: path.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::SubcarrierMismatch`] error.
|
||||
pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
|
||||
DatasetError::SubcarrierMismatch {
|
||||
path: path.into(),
|
||||
found,
|
||||
expected,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::NpyReadError`] error.
|
||||
pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::NpyReadError {
|
||||
path: path.into(),
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subcarrier interpolation errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by the subcarrier resampling functions.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SubcarrierError {
|
||||
/// The source or destination subcarrier count is zero.
|
||||
#[error("Subcarrier count must be at least 1, got {count}")]
|
||||
ZeroCount {
|
||||
/// The offending count.
|
||||
count: usize,
|
||||
},
|
||||
|
||||
/// The input array has an unexpected shape.
|
||||
#[error(
|
||||
"Input array shape mismatch: expected last dimension {expected_sc}, \
|
||||
got {actual_sc} (full shape: {shape:?})"
|
||||
)]
|
||||
InputShapeMismatch {
|
||||
/// Expected number of subcarriers (last dimension).
|
||||
expected_sc: usize,
|
||||
/// Actual number of subcarriers found.
|
||||
actual_sc: usize,
|
||||
/// Full shape of the input array.
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
|
||||
/// The requested interpolation method is not implemented.
|
||||
#[error("Interpolation method `{method}` is not yet implemented")]
|
||||
MethodNotImplemented {
|
||||
/// Name of the unimplemented method.
|
||||
method: String,
|
||||
},
|
||||
|
||||
/// Source and destination subcarrier counts are already equal.
|
||||
///
|
||||
/// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before
|
||||
/// calling the interpolation routine to avoid this error.
|
||||
///
|
||||
/// [`TrainingConfig::needs_subcarrier_interp`]:
|
||||
/// crate::config::TrainingConfig::needs_subcarrier_interp
|
||||
#[error(
|
||||
"Source and destination subcarrier counts are equal ({count}); \
|
||||
no interpolation is needed"
|
||||
)]
|
||||
NopInterpolation {
|
||||
/// The equal count.
|
||||
count: usize,
|
||||
},
|
||||
|
||||
/// A numerical error occurred during interpolation (e.g., division by zero
|
||||
/// due to coincident knot positions).
|
||||
#[error("Numerical error during interpolation: {0}")]
|
||||
NumericalError(String),
|
||||
}
|
||||
|
||||
impl SubcarrierError {
|
||||
/// Construct a [`SubcarrierError::NumericalError`].
|
||||
pub fn numerical<S: Into<String>>(msg: S) -> Self {
|
||||
SubcarrierError::NumericalError(msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,9 +52,9 @@ pub mod subcarrier;
|
||||
pub mod trainer;
|
||||
|
||||
// Convenient re-exports at the crate root.
|
||||
pub use config::TrainingConfig;
|
||||
pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
|
||||
pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult};
|
||||
pub use config::{ConfigError, TrainingConfig};
|
||||
pub use dataset::{CsiDataset, CsiSample, DataLoader, DatasetError, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
|
||||
pub use error::{TrainError, TrainResult};
|
||||
pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance};
|
||||
|
||||
/// Crate version string.
|
||||
|
||||
@@ -906,4 +906,148 @@ mod tests {
|
||||
"DensePose loss with identical UV should be bounded by CE, got {val}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Standalone functional API tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_fn_keypoint_heatmap_loss_identical_zero() {
|
||||
let dev = device();
|
||||
let t = Tensor::ones([2, 17, 8, 8], (Kind::Float, dev));
|
||||
let loss = keypoint_heatmap_loss(&t, &t);
|
||||
let v = loss.double_value(&[]) as f32;
|
||||
assert!(v.abs() < 1e-6, "Identical heatmaps → loss must be ≈0, got {v}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_generate_gaussian_heatmaps_shape() {
|
||||
let dev = device();
|
||||
let kpts = Tensor::full(&[2i64, 17, 2], 0.5, (Kind::Float, dev));
|
||||
let vis = Tensor::ones(&[2i64, 17], (Kind::Float, dev));
|
||||
let hm = generate_gaussian_heatmaps(&kpts, &vis, 16, 2.0);
|
||||
assert_eq!(hm.size(), [2, 17, 16, 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_generate_gaussian_heatmaps_invisible_zero() {
|
||||
let dev = device();
|
||||
let kpts = Tensor::full(&[1i64, 17, 2], 0.5, (Kind::Float, dev));
|
||||
let vis = Tensor::zeros(&[1i64, 17], (Kind::Float, dev)); // all invisible
|
||||
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 2.0);
|
||||
let total: f64 = hm.sum(Kind::Float).double_value(&[]);
|
||||
assert_eq!(total, 0.0, "All-invisible heatmaps must be zero");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_generate_gaussian_heatmaps_peak_near_one() {
|
||||
let dev = device();
|
||||
// Keypoint at (0.5, 0.5) on an 8×8 map.
|
||||
let kpts = Tensor::full(&[1i64, 1, 2], 0.5, (Kind::Float, dev));
|
||||
let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev));
|
||||
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5);
|
||||
let max_val: f64 = hm.max().double_value(&[]);
|
||||
assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_densepose_part_loss_returns_finite() {
|
||||
let dev = device();
|
||||
let logits = Tensor::zeros(&[1i64, 25, 4, 4], (Kind::Float, dev));
|
||||
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
|
||||
let loss = densepose_part_loss(&logits, &labels);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.is_finite() && v >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_densepose_uv_loss_no_annotated_pixels_zero() {
|
||||
let dev = device();
|
||||
let pred = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
|
||||
let gt = Tensor::zeros(&[1i64, 48, 4, 4], (Kind::Float, dev));
|
||||
let labels = Tensor::full(&[1i64, 4, 4], -1i64, (Kind::Int64, dev));
|
||||
let loss = densepose_uv_loss(&pred, >, &labels);
|
||||
let v = loss.double_value(&[]);
|
||||
assert_eq!(v, 0.0, "No annotated pixels → UV loss must be 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_densepose_uv_loss_identical_zero() {
|
||||
let dev = device();
|
||||
let t = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
|
||||
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
|
||||
let loss = densepose_uv_loss(&t, &t, &labels);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.abs() < 1e-6, "Identical UV → loss ≈ 0, got {v}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_transfer_loss_identical_zero() {
|
||||
let dev = device();
|
||||
let t = Tensor::ones(&[2i64, 64, 8, 8], (Kind::Float, dev));
|
||||
let loss = fn_transfer_loss(&t, &t);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.abs() < 1e-6, "Identical features → transfer loss ≈ 0, got {v}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_transfer_loss_spatial_mismatch() {
|
||||
let dev = device();
|
||||
let student = Tensor::ones(&[1i64, 64, 16, 16], (Kind::Float, dev));
|
||||
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
|
||||
let loss = fn_transfer_loss(&student, &teacher);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.is_finite() && v >= 0.0, "Spatial-mismatch transfer loss must be finite");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_transfer_loss_channel_mismatch_divisible() {
|
||||
let dev = device();
|
||||
let student = Tensor::ones(&[1i64, 128, 8, 8], (Kind::Float, dev));
|
||||
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
|
||||
let loss = fn_transfer_loss(&student, &teacher);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.is_finite() && v >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_losses_keypoint_only() {
|
||||
let dev = device();
|
||||
let pred = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
|
||||
let gt = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
|
||||
let out = compute_losses(&pred, >, None, None, None, None, None, None,
|
||||
1.0, 1.0, 1.0);
|
||||
assert!(out.total.is_finite());
|
||||
assert!(out.keypoint >= 0.0);
|
||||
assert!(out.densepose_parts.is_none());
|
||||
assert!(out.densepose_uv.is_none());
|
||||
assert!(out.transfer.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_losses_all_components_finite() {
|
||||
let dev = device();
|
||||
let b = 1i64;
|
||||
let h = 4i64;
|
||||
let w = 4i64;
|
||||
let pred_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
|
||||
let gt_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
|
||||
let logits = Tensor::zeros(&[b, 25, h, w], (Kind::Float, dev));
|
||||
let labels = Tensor::zeros(&[b, h, w], (Kind::Int64, dev));
|
||||
let pred_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
|
||||
let gt_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
|
||||
let sf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
|
||||
let tf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
|
||||
|
||||
let out = compute_losses(
|
||||
&pred_kpt, >_kpt,
|
||||
Some(&logits), Some(&labels),
|
||||
Some(&pred_uv), Some(>_uv),
|
||||
Some(&sf), Some(&tf),
|
||||
1.0, 0.5, 0.1,
|
||||
);
|
||||
|
||||
assert!(out.total.is_finite() && out.total >= 0.0);
|
||||
assert!(out.densepose_parts.is_some());
|
||||
assert!(out.densepose_uv.is_some());
|
||||
assert!(out.transfer.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,458 @@
|
||||
//! Integration tests for [`wifi_densepose_train::config`].
|
||||
//!
|
||||
//! All tests are deterministic: they use only fixed values and the
|
||||
//! `TrainingConfig::default()` constructor. No OS entropy or `rand` crate
|
||||
//! is used.
|
||||
|
||||
use wifi_densepose_train::config::TrainingConfig;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Default config invariants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The default configuration must pass its own validation.
|
||||
#[test]
|
||||
fn default_config_is_valid() {
|
||||
let cfg = TrainingConfig::default();
|
||||
cfg.validate()
|
||||
.expect("default TrainingConfig must be valid");
|
||||
}
|
||||
|
||||
/// Every numeric field in the default config must be strictly positive where
|
||||
/// the domain requires it.
|
||||
#[test]
|
||||
fn default_config_all_positive_fields() {
|
||||
let cfg = TrainingConfig::default();
|
||||
|
||||
assert!(cfg.num_subcarriers > 0, "num_subcarriers must be > 0");
|
||||
assert!(cfg.native_subcarriers > 0, "native_subcarriers must be > 0");
|
||||
assert!(cfg.num_antennas_tx > 0, "num_antennas_tx must be > 0");
|
||||
assert!(cfg.num_antennas_rx > 0, "num_antennas_rx must be > 0");
|
||||
assert!(cfg.window_frames > 0, "window_frames must be > 0");
|
||||
assert!(cfg.heatmap_size > 0, "heatmap_size must be > 0");
|
||||
assert!(cfg.num_keypoints > 0, "num_keypoints must be > 0");
|
||||
assert!(cfg.num_body_parts > 0, "num_body_parts must be > 0");
|
||||
assert!(cfg.backbone_channels > 0, "backbone_channels must be > 0");
|
||||
assert!(cfg.batch_size > 0, "batch_size must be > 0");
|
||||
assert!(cfg.learning_rate > 0.0, "learning_rate must be > 0.0");
|
||||
assert!(cfg.weight_decay >= 0.0, "weight_decay must be >= 0.0");
|
||||
assert!(cfg.num_epochs > 0, "num_epochs must be > 0");
|
||||
assert!(cfg.grad_clip_norm > 0.0, "grad_clip_norm must be > 0.0");
|
||||
}
|
||||
|
||||
/// The three loss weights in the default config must all be non-negative and
|
||||
/// their sum must be positive (not all zero).
|
||||
#[test]
|
||||
fn default_config_loss_weights_sum_positive() {
|
||||
let cfg = TrainingConfig::default();
|
||||
|
||||
assert!(cfg.lambda_kp >= 0.0, "lambda_kp must be >= 0.0");
|
||||
assert!(cfg.lambda_dp >= 0.0, "lambda_dp must be >= 0.0");
|
||||
assert!(cfg.lambda_tr >= 0.0, "lambda_tr must be >= 0.0");
|
||||
|
||||
let total = cfg.lambda_kp + cfg.lambda_dp + cfg.lambda_tr;
|
||||
assert!(
|
||||
total > 0.0,
|
||||
"sum of loss weights must be > 0.0, got {total}"
|
||||
);
|
||||
}
|
||||
|
||||
/// The default loss weights should sum to exactly 1.0 (within floating-point
|
||||
/// tolerance).
|
||||
#[test]
|
||||
fn default_config_loss_weights_sum_to_one() {
|
||||
let cfg = TrainingConfig::default();
|
||||
let total = cfg.lambda_kp + cfg.lambda_dp + cfg.lambda_tr;
|
||||
let diff = (total - 1.0_f64).abs();
|
||||
assert!(
|
||||
diff < 1e-9,
|
||||
"expected loss weights to sum to 1.0, got {total} (diff={diff})"
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Specific default values
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The default number of subcarriers is 56 (MM-Fi target).
|
||||
#[test]
|
||||
fn default_num_subcarriers_is_56() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(
|
||||
cfg.num_subcarriers, 56,
|
||||
"expected default num_subcarriers = 56, got {}",
|
||||
cfg.num_subcarriers
|
||||
);
|
||||
}
|
||||
|
||||
/// The default number of native subcarriers is 114 (raw MM-Fi hardware output).
|
||||
#[test]
|
||||
fn default_native_subcarriers_is_114() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(
|
||||
cfg.native_subcarriers, 114,
|
||||
"expected default native_subcarriers = 114, got {}",
|
||||
cfg.native_subcarriers
|
||||
);
|
||||
}
|
||||
|
||||
/// The default number of keypoints is 17 (COCO skeleton).
|
||||
#[test]
|
||||
fn default_num_keypoints_is_17() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(
|
||||
cfg.num_keypoints, 17,
|
||||
"expected default num_keypoints = 17, got {}",
|
||||
cfg.num_keypoints
|
||||
);
|
||||
}
|
||||
|
||||
/// The default antenna counts are 3×3.
|
||||
#[test]
|
||||
fn default_antenna_counts_are_3x3() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(cfg.num_antennas_tx, 3, "expected num_antennas_tx = 3");
|
||||
assert_eq!(cfg.num_antennas_rx, 3, "expected num_antennas_rx = 3");
|
||||
}
|
||||
|
||||
/// The default window length is 100 frames.
|
||||
#[test]
|
||||
fn default_window_frames_is_100() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(
|
||||
cfg.window_frames, 100,
|
||||
"expected window_frames = 100, got {}",
|
||||
cfg.window_frames
|
||||
);
|
||||
}
|
||||
|
||||
/// The default seed is 42.
|
||||
#[test]
|
||||
fn default_seed_is_42() {
|
||||
let cfg = TrainingConfig::default();
|
||||
assert_eq!(cfg.seed, 42, "expected seed = 42, got {}", cfg.seed);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// needs_subcarrier_interp equivalent property
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// When native_subcarriers differs from num_subcarriers, interpolation is
|
||||
/// needed. The default config has 114 != 56, so this property must hold.
|
||||
#[test]
|
||||
fn default_config_needs_interpolation() {
|
||||
let cfg = TrainingConfig::default();
|
||||
// 114 native → 56 target: interpolation is required.
|
||||
assert_ne!(
|
||||
cfg.native_subcarriers, cfg.num_subcarriers,
|
||||
"default config must require subcarrier interpolation (native={} != target={})",
|
||||
cfg.native_subcarriers, cfg.num_subcarriers
|
||||
);
|
||||
}
|
||||
|
||||
/// When native_subcarriers equals num_subcarriers no interpolation is needed.
|
||||
#[test]
|
||||
fn equal_subcarrier_counts_means_no_interpolation_needed() {
|
||||
let mut cfg = TrainingConfig::default();
|
||||
cfg.native_subcarriers = cfg.num_subcarriers; // e.g., both = 56
|
||||
cfg.validate().expect("config with equal subcarrier counts must be valid");
|
||||
assert_eq!(
|
||||
cfg.native_subcarriers, cfg.num_subcarriers,
|
||||
"after setting equal counts, native ({}) must equal target ({})",
|
||||
cfg.native_subcarriers, cfg.num_subcarriers
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// csi_flat_size equivalent property
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The flat input size of a single CSI window is
|
||||
/// `window_frames × num_antennas_tx × num_antennas_rx × num_subcarriers`.
|
||||
/// Verify the arithmetic matches the default config.
|
||||
#[test]
|
||||
fn csi_flat_size_matches_expected() {
|
||||
let cfg = TrainingConfig::default();
|
||||
let expected = cfg.window_frames
|
||||
* cfg.num_antennas_tx
|
||||
* cfg.num_antennas_rx
|
||||
* cfg.num_subcarriers;
|
||||
// Default: 100 * 3 * 3 * 56 = 50400
|
||||
assert_eq!(
|
||||
expected, 50_400,
|
||||
"CSI flat size must be 50400 for default config, got {expected}"
|
||||
);
|
||||
}
|
||||
|
||||
/// The CSI flat size must be > 0 for any valid config.
|
||||
#[test]
|
||||
fn csi_flat_size_positive_for_valid_config() {
|
||||
let cfg = TrainingConfig::default();
|
||||
let flat_size = cfg.window_frames
|
||||
* cfg.num_antennas_tx
|
||||
* cfg.num_antennas_rx
|
||||
* cfg.num_subcarriers;
|
||||
assert!(
|
||||
flat_size > 0,
|
||||
"CSI flat size must be > 0, got {flat_size}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JSON serialization round-trip
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Serializing a config to JSON and deserializing it must yield an identical
|
||||
/// config (all fields must match).
|
||||
#[test]
|
||||
fn config_json_roundtrip_identical() {
|
||||
use std::path::PathBuf;
|
||||
use tempfile::tempdir;
|
||||
|
||||
let tmp = tempdir().expect("tempdir must be created");
|
||||
let path = tmp.path().join("config.json");
|
||||
|
||||
let original = TrainingConfig::default();
|
||||
original
|
||||
.to_json(&path)
|
||||
.expect("to_json must succeed for default config");
|
||||
|
||||
let loaded = TrainingConfig::from_json(&path)
|
||||
.expect("from_json must succeed for previously serialized config");
|
||||
|
||||
// Verify all fields are equal.
|
||||
assert_eq!(
|
||||
loaded.num_subcarriers, original.num_subcarriers,
|
||||
"num_subcarriers must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.native_subcarriers, original.native_subcarriers,
|
||||
"native_subcarriers must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.num_antennas_tx, original.num_antennas_tx,
|
||||
"num_antennas_tx must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.num_antennas_rx, original.num_antennas_rx,
|
||||
"num_antennas_rx must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.window_frames, original.window_frames,
|
||||
"window_frames must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.heatmap_size, original.heatmap_size,
|
||||
"heatmap_size must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.num_keypoints, original.num_keypoints,
|
||||
"num_keypoints must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.num_body_parts, original.num_body_parts,
|
||||
"num_body_parts must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.backbone_channels, original.backbone_channels,
|
||||
"backbone_channels must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.batch_size, original.batch_size,
|
||||
"batch_size must survive round-trip"
|
||||
);
|
||||
assert!(
|
||||
(loaded.learning_rate - original.learning_rate).abs() < 1e-12,
|
||||
"learning_rate must survive round-trip: got {}",
|
||||
loaded.learning_rate
|
||||
);
|
||||
assert!(
|
||||
(loaded.weight_decay - original.weight_decay).abs() < 1e-12,
|
||||
"weight_decay must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.num_epochs, original.num_epochs,
|
||||
"num_epochs must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.warmup_epochs, original.warmup_epochs,
|
||||
"warmup_epochs must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.lr_milestones, original.lr_milestones,
|
||||
"lr_milestones must survive round-trip"
|
||||
);
|
||||
assert!(
|
||||
(loaded.lr_gamma - original.lr_gamma).abs() < 1e-12,
|
||||
"lr_gamma must survive round-trip"
|
||||
);
|
||||
assert!(
|
||||
(loaded.grad_clip_norm - original.grad_clip_norm).abs() < 1e-12,
|
||||
"grad_clip_norm must survive round-trip"
|
||||
);
|
||||
assert!(
|
||||
(loaded.lambda_kp - original.lambda_kp).abs() < 1e-12,
|
||||
"lambda_kp must survive round-trip"
|
||||
);
|
||||
assert!(
|
||||
(loaded.lambda_dp - original.lambda_dp).abs() < 1e-12,
|
||||
"lambda_dp must survive round-trip"
|
||||
);
|
||||
assert!(
|
||||
(loaded.lambda_tr - original.lambda_tr).abs() < 1e-12,
|
||||
"lambda_tr must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.val_every_epochs, original.val_every_epochs,
|
||||
"val_every_epochs must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.early_stopping_patience, original.early_stopping_patience,
|
||||
"early_stopping_patience must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.save_top_k, original.save_top_k,
|
||||
"save_top_k must survive round-trip"
|
||||
);
|
||||
assert_eq!(loaded.use_gpu, original.use_gpu, "use_gpu must survive round-trip");
|
||||
assert_eq!(
|
||||
loaded.gpu_device_id, original.gpu_device_id,
|
||||
"gpu_device_id must survive round-trip"
|
||||
);
|
||||
assert_eq!(
|
||||
loaded.num_workers, original.num_workers,
|
||||
"num_workers must survive round-trip"
|
||||
);
|
||||
assert_eq!(loaded.seed, original.seed, "seed must survive round-trip");
|
||||
}
|
||||
|
||||
/// A modified config with non-default values must also survive a JSON
/// round-trip.
#[test]
fn config_json_roundtrip_modified_values() {
    use tempfile::tempdir;

    let tmp = tempdir().expect("tempdir must be created");
    let path = tmp.path().join("modified.json");

    // Struct-update syntax states the overrides in one expression instead of
    // mutating a default instance field-by-field
    // (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        batch_size: 16,
        learning_rate: 5e-4,
        num_epochs: 100,
        warmup_epochs: 10,
        lr_milestones: vec![50, 80],
        seed: 99,
        ..TrainingConfig::default()
    };

    // Guard: only a valid config should be serialized, so a later failure
    // points at the round-trip, not the fixture.
    cfg.validate().expect("modified config must be valid before serialization");
    cfg.to_json(&path).expect("to_json must succeed");

    let loaded = TrainingConfig::from_json(&path).expect("from_json must succeed");

    assert_eq!(loaded.batch_size, 16, "batch_size must match after round-trip");
    // Floats are compared with a tolerance rather than `==` to avoid
    // representation noise from the JSON round-trip.
    assert!(
        (loaded.learning_rate - 5e-4_f64).abs() < 1e-12,
        "learning_rate must match after round-trip"
    );
    assert_eq!(loaded.num_epochs, 100, "num_epochs must match after round-trip");
    assert_eq!(loaded.warmup_epochs, 10, "warmup_epochs must match after round-trip");
    assert_eq!(
        loaded.lr_milestones,
        vec![50, 80],
        "lr_milestones must match after round-trip"
    );
    assert_eq!(loaded.seed, 99, "seed must match after round-trip");
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Validation: invalid configurations are rejected
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Setting num_subcarriers to 0 must produce a validation error.
#[test]
fn zero_num_subcarriers_is_invalid() {
    // Struct-update syntax instead of mutating a default instance
    // (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        num_subcarriers: 0,
        ..TrainingConfig::default()
    };
    assert!(
        cfg.validate().is_err(),
        "num_subcarriers = 0 must be rejected by validate()"
    );
}
|
||||
|
||||
/// Setting native_subcarriers to 0 must produce a validation error.
#[test]
fn zero_native_subcarriers_is_invalid() {
    // Struct-update syntax instead of mutating a default instance
    // (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        native_subcarriers: 0,
        ..TrainingConfig::default()
    };
    assert!(
        cfg.validate().is_err(),
        "native_subcarriers = 0 must be rejected by validate()"
    );
}
|
||||
|
||||
/// Setting batch_size to 0 must produce a validation error.
#[test]
fn zero_batch_size_is_invalid() {
    // Struct-update syntax instead of mutating a default instance
    // (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        batch_size: 0,
        ..TrainingConfig::default()
    };
    assert!(
        cfg.validate().is_err(),
        "batch_size = 0 must be rejected by validate()"
    );
}
|
||||
|
||||
/// A negative learning rate must produce a validation error.
#[test]
fn negative_learning_rate_is_invalid() {
    // Struct-update syntax instead of mutating a default instance
    // (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        learning_rate: -0.001,
        ..TrainingConfig::default()
    };
    assert!(
        cfg.validate().is_err(),
        "learning_rate < 0 must be rejected by validate()"
    );
}
|
||||
|
||||
/// warmup_epochs >= num_epochs must produce a validation error.
#[test]
fn warmup_exceeding_epochs_is_invalid() {
    let defaults = TrainingConfig::default();
    // Read the default epoch count before the functional update consumes
    // `defaults`. Equality is the boundary case: warmup must be strictly
    // shorter than the full schedule, so `==` is still invalid.
    let cfg = TrainingConfig {
        warmup_epochs: defaults.num_epochs,
        ..defaults
    };
    assert!(
        cfg.validate().is_err(),
        "warmup_epochs >= num_epochs must be rejected by validate()"
    );
}
|
||||
|
||||
/// All loss weights set to 0.0 must produce a validation error.
#[test]
fn all_zero_loss_weights_are_invalid() {
    // Zero every loss weight in one struct-update expression instead of
    // three field reassignments (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        lambda_kp: 0.0,
        lambda_dp: 0.0,
        lambda_tr: 0.0,
        ..TrainingConfig::default()
    };
    assert!(
        cfg.validate().is_err(),
        "all-zero loss weights must be rejected by validate()"
    );
}
|
||||
|
||||
/// Non-increasing lr_milestones must produce a validation error.
#[test]
fn non_increasing_milestones_are_invalid() {
    // Milestones deliberately out of ascending order; struct-update syntax
    // instead of mutating a default instance
    // (clippy::field_reassign_with_default).
    let cfg = TrainingConfig {
        lr_milestones: vec![40, 30],
        ..TrainingConfig::default()
    };
    assert!(
        cfg.validate().is_err(),
        "non-increasing lr_milestones must be rejected by validate()"
    );
}
|
||||
|
||||
/// An lr_milestone beyond num_epochs must produce a validation error.
#[test]
fn milestone_beyond_num_epochs_is_invalid() {
    let defaults = TrainingConfig::default();
    // Read the default epoch count before the functional update consumes
    // `defaults`; the second milestone lands one past the schedule end,
    // which is the smallest out-of-range value.
    let cfg = TrainingConfig {
        lr_milestones: vec![30, defaults.num_epochs + 1],
        ..defaults
    };
    assert!(
        cfg.validate().is_err(),
        "lr_milestone > num_epochs must be rejected by validate()"
    );
}
|
||||
Reference in New Issue
Block a user