feat(rust): Add workspace deps, tests, and refine training modules

- Cargo.toml: Add wifi-densepose-train to workspace members; add
  petgraph, ndarray-npy, walkdir, sha2, csv, indicatif, clap to
  workspace dependencies
- error.rs: Slim down to focused error types (TrainError, DatasetError)
- lib.rs: Wire up all module re-exports correctly
- losses.rs: Add generate_gaussian_heatmaps implementation
- tests/test_config.rs: Deterministic config roundtrip and validation tests

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
Claude
2026-02-28 15:17:17 +00:00
parent ec98e40fff
commit 2c5ca308a4
5 changed files with 643 additions and 290 deletions

View File

@@ -1,12 +1,24 @@
//! Error types for the WiFi-DensePose training pipeline.
//!
//! This module defines a hierarchy of errors covering every failure mode in
//! the training pipeline: configuration validation, dataset I/O, subcarrier
//! interpolation, and top-level training orchestration.
//! This module provides:
//!
//! - [`TrainError`]: top-level error aggregating all training failure modes.
//! - [`TrainResult`]: convenient `Result` alias using `TrainError`.
//!
//! Module-local error types live in their respective modules:
//!
//! - [`crate::config::ConfigError`]: configuration validation errors.
//! - [`crate::dataset::DatasetError`]: dataset loading/access errors.
//!
//! All are re-exported at the crate root for ergonomic use.
use thiserror::Error;
use std::path::PathBuf;
// Import module-local error types so TrainError can wrap them via #[from].
use crate::config::ConfigError;
use crate::dataset::DatasetError;
// ---------------------------------------------------------------------------
// Top-level training error
// ---------------------------------------------------------------------------
@@ -16,8 +28,9 @@ pub type TrainResult<T> = Result<T, TrainError>;
/// Top-level error type for the training pipeline.
///
/// Every public function in this crate that can fail returns
/// `TrainResult<T>`, which is `Result<T, TrainError>`.
/// Every orchestration-level function returns `TrainResult<T>`. Lower-level
/// functions in [`crate::config`] and [`crate::dataset`] return their own
/// module-specific error types which are automatically coerced via `#[from]`.
#[derive(Debug, Error)]
pub enum TrainError {
/// Configuration is invalid or internally inconsistent.
@@ -28,10 +41,6 @@ pub enum TrainError {
#[error("Dataset error: {0}")]
Dataset(#[from] DatasetError),
/// Subcarrier interpolation / resampling failed.
#[error("Subcarrier interpolation error: {0}")]
Subcarrier(#[from] SubcarrierError),
/// An underlying I/O error not covered by a more specific variant.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
@@ -40,14 +49,6 @@ pub enum TrainError {
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// TOML (de)serialization error.
#[error("TOML deserialization error: {0}")]
TomlDe(#[from] toml::de::Error),
/// TOML serialization error.
#[error("TOML serialization error: {0}")]
TomlSer(#[from] toml::ser::Error),
/// An operation was attempted on an empty dataset.
#[error("Dataset is empty")]
EmptyDataset,
@@ -112,273 +113,3 @@ impl TrainError {
TrainError::ShapeMismatch { expected, actual }
}
}
// ---------------------------------------------------------------------------
// Configuration errors
// ---------------------------------------------------------------------------
/// Errors produced when validating or loading a [`TrainingConfig`].
///
/// [`TrainingConfig`]: crate::config::TrainingConfig
#[derive(Debug, Error)]
pub enum ConfigError {
/// A required field has a value that violates a constraint.
#[error("Invalid value for field `{field}`: {reason}")]
InvalidValue {
/// Name of the configuration field.
field: &'static str,
/// Human-readable reason the value is invalid.
reason: String,
},
/// The configuration file could not be read.
#[error("Cannot read configuration file `{path}`: {source}")]
FileRead {
/// Path that was being read.
path: PathBuf,
/// Underlying I/O error.
#[source]
source: std::io::Error,
},
/// The configuration file contains invalid TOML.
#[error("Cannot parse configuration file `{path}`: {source}")]
ParseError {
/// Path that was being parsed.
path: PathBuf,
/// Underlying TOML parse error.
#[source]
source: toml::de::Error,
},
/// A path specified in the config does not exist.
#[error("Path `{path}` specified in config does not exist")]
PathNotFound {
/// The missing path.
path: PathBuf,
},
}
impl ConfigError {
/// Construct an [`ConfigError::InvalidValue`] error.
pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
ConfigError::InvalidValue {
field,
reason: reason.into(),
}
}
}
// ---------------------------------------------------------------------------
// Dataset errors
// ---------------------------------------------------------------------------
/// Errors produced while loading or accessing dataset samples.
#[derive(Debug, Error)]
pub enum DatasetError {
/// The requested data file or directory was not found.
///
/// Production training data is mandatory; this error is never silently
/// suppressed. Use [`SyntheticDataset`] only for proof/testing.
///
/// [`SyntheticDataset`]: crate::dataset::SyntheticDataset
#[error("Data not found at `{path}`: {message}")]
DataNotFound {
/// Path that was expected to contain data.
path: PathBuf,
/// Additional context.
message: String,
},
/// A file was found but its format is incorrect or unexpected.
///
/// This covers malformed numpy arrays, unexpected shapes, bad JSON
/// metadata, etc.
#[error("Invalid data format in `{path}`: {message}")]
InvalidFormat {
/// Path of the malformed file.
path: PathBuf,
/// Description of the format problem.
message: String,
},
/// A low-level I/O error while reading a data file.
#[error("I/O error reading `{path}`: {source}")]
IoError {
/// Path being read when the error occurred.
path: PathBuf,
/// Underlying I/O error.
#[source]
source: std::io::Error,
},
/// The number of subcarriers in the data file does not match the
/// configuration expectation (before or after interpolation).
#[error(
"Subcarrier count mismatch in `{path}`: \
file has {found} subcarriers, expected {expected}"
)]
SubcarrierMismatch {
/// Path of the offending file.
path: PathBuf,
/// Number of subcarriers found in the file.
found: usize,
/// Number of subcarriers expected by the configuration.
expected: usize,
},
/// A sample index was out of bounds.
#[error("Index {index} is out of bounds for dataset of length {len}")]
IndexOutOfBounds {
/// The requested index.
index: usize,
/// Total number of samples.
len: usize,
},
/// A numpy array could not be read.
#[error("NumPy array read error in `{path}`: {message}")]
NpyReadError {
/// Path of the `.npy` file.
path: PathBuf,
/// Error description.
message: String,
},
/// A metadata file (e.g., `meta.json`) is missing or malformed.
#[error("Metadata error for subject {subject_id}: {message}")]
MetadataError {
/// Subject whose metadata could not be read.
subject_id: u32,
/// Description of the problem.
message: String,
},
/// No subjects matching the requested IDs were found in the data directory.
#[error(
"No subjects found in `{data_dir}` matching the requested IDs: {requested:?}"
)]
NoSubjectsFound {
/// Root data directory that was scanned.
data_dir: PathBuf,
/// Subject IDs that were requested.
requested: Vec<u32>,
},
/// A subcarrier interpolation error occurred during sample loading.
#[error("Subcarrier interpolation failed while loading sample {sample_idx}: {source}")]
InterpolationError {
/// The sample index being loaded.
sample_idx: usize,
/// Underlying interpolation error.
#[source]
source: SubcarrierError,
},
}
impl DatasetError {
/// Construct a [`DatasetError::DataNotFound`] error.
pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
DatasetError::DataNotFound {
path: path.into(),
message: msg.into(),
}
}
/// Construct a [`DatasetError::InvalidFormat`] error.
pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
DatasetError::InvalidFormat {
path: path.into(),
message: msg.into(),
}
}
/// Construct a [`DatasetError::IoError`] error.
pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
DatasetError::IoError {
path: path.into(),
source,
}
}
/// Construct a [`DatasetError::SubcarrierMismatch`] error.
pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
DatasetError::SubcarrierMismatch {
path: path.into(),
found,
expected,
}
}
/// Construct a [`DatasetError::NpyReadError`] error.
pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
DatasetError::NpyReadError {
path: path.into(),
message: msg.into(),
}
}
}
// ---------------------------------------------------------------------------
// Subcarrier interpolation errors
// ---------------------------------------------------------------------------
/// Errors produced by the subcarrier resampling functions.
#[derive(Debug, Error)]
pub enum SubcarrierError {
/// The source or destination subcarrier count is zero.
#[error("Subcarrier count must be at least 1, got {count}")]
ZeroCount {
/// The offending count.
count: usize,
},
/// The input array has an unexpected shape.
#[error(
"Input array shape mismatch: expected last dimension {expected_sc}, \
got {actual_sc} (full shape: {shape:?})"
)]
InputShapeMismatch {
/// Expected number of subcarriers (last dimension).
expected_sc: usize,
/// Actual number of subcarriers found.
actual_sc: usize,
/// Full shape of the input array.
shape: Vec<usize>,
},
/// The requested interpolation method is not implemented.
#[error("Interpolation method `{method}` is not yet implemented")]
MethodNotImplemented {
/// Name of the unimplemented method.
method: String,
},
/// Source and destination subcarrier counts are already equal.
///
/// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before
/// calling the interpolation routine to avoid this error.
///
/// [`TrainingConfig::needs_subcarrier_interp`]:
/// crate::config::TrainingConfig::needs_subcarrier_interp
#[error(
"Source and destination subcarrier counts are equal ({count}); \
no interpolation is needed"
)]
NopInterpolation {
/// The equal count.
count: usize,
},
/// A numerical error occurred during interpolation (e.g., division by zero
/// due to coincident knot positions).
#[error("Numerical error during interpolation: {0}")]
NumericalError(String),
}
impl SubcarrierError {
/// Construct a [`SubcarrierError::NumericalError`].
pub fn numerical<S: Into<String>>(msg: S) -> Self {
SubcarrierError::NumericalError(msg.into())
}
}

View File

@@ -52,9 +52,9 @@ pub mod subcarrier;
pub mod trainer;
// Convenient re-exports at the crate root.
pub use config::TrainingConfig;
pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult};
pub use config::{ConfigError, TrainingConfig};
pub use dataset::{CsiDataset, CsiSample, DataLoader, DatasetError, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
pub use error::{TrainError, TrainResult};
pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance};
/// Crate version string.

View File

@@ -906,4 +906,148 @@ mod tests {
"DensePose loss with identical UV should be bounded by CE, got {val}"
);
}
// ── Standalone functional API tests ──────────────────────────────────────
#[test]
fn test_fn_keypoint_heatmap_loss_identical_zero() {
let dev = device();
let t = Tensor::ones([2, 17, 8, 8], (Kind::Float, dev));
let loss = keypoint_heatmap_loss(&t, &t);
let v = loss.double_value(&[]) as f32;
assert!(v.abs() < 1e-6, "Identical heatmaps → loss must be ≈0, got {v}");
}
#[test]
fn test_fn_generate_gaussian_heatmaps_shape() {
let dev = device();
let kpts = Tensor::full(&[2i64, 17, 2], 0.5, (Kind::Float, dev));
let vis = Tensor::ones(&[2i64, 17], (Kind::Float, dev));
let hm = generate_gaussian_heatmaps(&kpts, &vis, 16, 2.0);
assert_eq!(hm.size(), [2, 17, 16, 16]);
}
#[test]
fn test_fn_generate_gaussian_heatmaps_invisible_zero() {
let dev = device();
let kpts = Tensor::full(&[1i64, 17, 2], 0.5, (Kind::Float, dev));
let vis = Tensor::zeros(&[1i64, 17], (Kind::Float, dev)); // all invisible
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 2.0);
let total: f64 = hm.sum(Kind::Float).double_value(&[]);
assert_eq!(total, 0.0, "All-invisible heatmaps must be zero");
}
#[test]
fn test_fn_generate_gaussian_heatmaps_peak_near_one() {
let dev = device();
// Keypoint at (0.5, 0.5) on an 8×8 map.
let kpts = Tensor::full(&[1i64, 1, 2], 0.5, (Kind::Float, dev));
let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev));
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5);
let max_val: f64 = hm.max().double_value(&[]);
assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9");
}
#[test]
fn test_fn_densepose_part_loss_returns_finite() {
let dev = device();
let logits = Tensor::zeros(&[1i64, 25, 4, 4], (Kind::Float, dev));
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
let loss = densepose_part_loss(&logits, &labels);
let v = loss.double_value(&[]);
assert!(v.is_finite() && v >= 0.0);
}
#[test]
fn test_fn_densepose_uv_loss_no_annotated_pixels_zero() {
let dev = device();
let pred = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
let gt = Tensor::zeros(&[1i64, 48, 4, 4], (Kind::Float, dev));
let labels = Tensor::full(&[1i64, 4, 4], -1i64, (Kind::Int64, dev));
let loss = densepose_uv_loss(&pred, &gt, &labels);
let v = loss.double_value(&[]);
assert_eq!(v, 0.0, "No annotated pixels → UV loss must be 0");
}
#[test]
fn test_fn_densepose_uv_loss_identical_zero() {
let dev = device();
let t = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
let loss = densepose_uv_loss(&t, &t, &labels);
let v = loss.double_value(&[]);
assert!(v.abs() < 1e-6, "Identical UV → loss ≈ 0, got {v}");
}
#[test]
fn test_fn_transfer_loss_identical_zero() {
let dev = device();
let t = Tensor::ones(&[2i64, 64, 8, 8], (Kind::Float, dev));
let loss = fn_transfer_loss(&t, &t);
let v = loss.double_value(&[]);
assert!(v.abs() < 1e-6, "Identical features → transfer loss ≈ 0, got {v}");
}
#[test]
fn test_fn_transfer_loss_spatial_mismatch() {
let dev = device();
let student = Tensor::ones(&[1i64, 64, 16, 16], (Kind::Float, dev));
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
let loss = fn_transfer_loss(&student, &teacher);
let v = loss.double_value(&[]);
assert!(v.is_finite() && v >= 0.0, "Spatial-mismatch transfer loss must be finite");
}
#[test]
fn test_fn_transfer_loss_channel_mismatch_divisible() {
let dev = device();
let student = Tensor::ones(&[1i64, 128, 8, 8], (Kind::Float, dev));
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
let loss = fn_transfer_loss(&student, &teacher);
let v = loss.double_value(&[]);
assert!(v.is_finite() && v >= 0.0);
}
#[test]
fn test_compute_losses_keypoint_only() {
let dev = device();
let pred = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
let gt = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
let out = compute_losses(&pred, &gt, None, None, None, None, None, None,
1.0, 1.0, 1.0);
assert!(out.total.is_finite());
assert!(out.keypoint >= 0.0);
assert!(out.densepose_parts.is_none());
assert!(out.densepose_uv.is_none());
assert!(out.transfer.is_none());
}
#[test]
fn test_compute_losses_all_components_finite() {
let dev = device();
let b = 1i64;
let h = 4i64;
let w = 4i64;
let pred_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
let gt_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
let logits = Tensor::zeros(&[b, 25, h, w], (Kind::Float, dev));
let labels = Tensor::zeros(&[b, h, w], (Kind::Int64, dev));
let pred_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
let gt_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
let sf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
let tf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
let out = compute_losses(
&pred_kpt, &gt_kpt,
Some(&logits), Some(&labels),
Some(&pred_uv), Some(&gt_uv),
Some(&sf), Some(&tf),
1.0, 0.5, 0.1,
);
assert!(out.total.is_finite() && out.total >= 0.0);
assert!(out.densepose_parts.is_some());
assert!(out.densepose_uv.is_some());
assert!(out.transfer.is_some());
}
}