feat(rust): Add workspace deps, tests, and refine training modules
- Cargo.toml: Add wifi-densepose-train to workspace members; add petgraph, ndarray-npy, walkdir, sha2, csv, indicatif, clap to workspace dependencies - error.rs: Slim down to focused error types (TrainError, DatasetError) - lib.rs: Wire up all module re-exports correctly - losses.rs: Add generate_gaussian_heatmaps implementation - tests/test_config.rs: Deterministic config roundtrip and validation tests https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -1,12 +1,24 @@
|
||||
//! Error types for the WiFi-DensePose training pipeline.
|
||||
//!
|
||||
//! This module defines a hierarchy of errors covering every failure mode in
|
||||
//! the training pipeline: configuration validation, dataset I/O, subcarrier
|
||||
//! interpolation, and top-level training orchestration.
|
||||
//! This module provides:
|
||||
//!
|
||||
//! - [`TrainError`]: top-level error aggregating all training failure modes.
|
||||
//! - [`TrainResult`]: convenient `Result` alias using `TrainError`.
|
||||
//!
|
||||
//! Module-local error types live in their respective modules:
|
||||
//!
|
||||
//! - [`crate::config::ConfigError`]: configuration validation errors.
|
||||
//! - [`crate::dataset::DatasetError`]: dataset loading/access errors.
|
||||
//!
|
||||
//! All are re-exported at the crate root for ergonomic use.
|
||||
|
||||
use thiserror::Error;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Import module-local error types so TrainError can wrap them via #[from].
|
||||
use crate::config::ConfigError;
|
||||
use crate::dataset::DatasetError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Top-level training error
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -16,8 +28,9 @@ pub type TrainResult<T> = Result<T, TrainError>;
|
||||
|
||||
/// Top-level error type for the training pipeline.
|
||||
///
|
||||
/// Every public function in this crate that can fail returns
|
||||
/// `TrainResult<T>`, which is `Result<T, TrainError>`.
|
||||
/// Every orchestration-level function returns `TrainResult<T>`. Lower-level
|
||||
/// functions in [`crate::config`] and [`crate::dataset`] return their own
|
||||
/// module-specific error types which are automatically coerced via `#[from]`.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TrainError {
|
||||
/// Configuration is invalid or internally inconsistent.
|
||||
@@ -28,10 +41,6 @@ pub enum TrainError {
|
||||
#[error("Dataset error: {0}")]
|
||||
Dataset(#[from] DatasetError),
|
||||
|
||||
/// Subcarrier interpolation / resampling failed.
|
||||
#[error("Subcarrier interpolation error: {0}")]
|
||||
Subcarrier(#[from] SubcarrierError),
|
||||
|
||||
/// An underlying I/O error not covered by a more specific variant.
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
@@ -40,14 +49,6 @@ pub enum TrainError {
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// TOML (de)serialization error.
|
||||
#[error("TOML deserialization error: {0}")]
|
||||
TomlDe(#[from] toml::de::Error),
|
||||
|
||||
/// TOML serialization error.
|
||||
#[error("TOML serialization error: {0}")]
|
||||
TomlSer(#[from] toml::ser::Error),
|
||||
|
||||
/// An operation was attempted on an empty dataset.
|
||||
#[error("Dataset is empty")]
|
||||
EmptyDataset,
|
||||
@@ -112,273 +113,3 @@ impl TrainError {
|
||||
TrainError::ShapeMismatch { expected, actual }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced when validating or loading a [`TrainingConfig`].
|
||||
///
|
||||
/// [`TrainingConfig`]: crate::config::TrainingConfig
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConfigError {
|
||||
/// A required field has a value that violates a constraint.
|
||||
#[error("Invalid value for field `{field}`: {reason}")]
|
||||
InvalidValue {
|
||||
/// Name of the configuration field.
|
||||
field: &'static str,
|
||||
/// Human-readable reason the value is invalid.
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// The configuration file could not be read.
|
||||
#[error("Cannot read configuration file `{path}`: {source}")]
|
||||
FileRead {
|
||||
/// Path that was being read.
|
||||
path: PathBuf,
|
||||
/// Underlying I/O error.
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
/// The configuration file contains invalid TOML.
|
||||
#[error("Cannot parse configuration file `{path}`: {source}")]
|
||||
ParseError {
|
||||
/// Path that was being parsed.
|
||||
path: PathBuf,
|
||||
/// Underlying TOML parse error.
|
||||
#[source]
|
||||
source: toml::de::Error,
|
||||
},
|
||||
|
||||
/// A path specified in the config does not exist.
|
||||
#[error("Path `{path}` specified in config does not exist")]
|
||||
PathNotFound {
|
||||
/// The missing path.
|
||||
path: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
/// Construct an [`ConfigError::InvalidValue`] error.
|
||||
pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
|
||||
ConfigError::InvalidValue {
|
||||
field,
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dataset errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced while loading or accessing dataset samples.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DatasetError {
|
||||
/// The requested data file or directory was not found.
|
||||
///
|
||||
/// Production training data is mandatory; this error is never silently
|
||||
/// suppressed. Use [`SyntheticDataset`] only for proof/testing.
|
||||
///
|
||||
/// [`SyntheticDataset`]: crate::dataset::SyntheticDataset
|
||||
#[error("Data not found at `{path}`: {message}")]
|
||||
DataNotFound {
|
||||
/// Path that was expected to contain data.
|
||||
path: PathBuf,
|
||||
/// Additional context.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A file was found but its format is incorrect or unexpected.
|
||||
///
|
||||
/// This covers malformed numpy arrays, unexpected shapes, bad JSON
|
||||
/// metadata, etc.
|
||||
#[error("Invalid data format in `{path}`: {message}")]
|
||||
InvalidFormat {
|
||||
/// Path of the malformed file.
|
||||
path: PathBuf,
|
||||
/// Description of the format problem.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A low-level I/O error while reading a data file.
|
||||
#[error("I/O error reading `{path}`: {source}")]
|
||||
IoError {
|
||||
/// Path being read when the error occurred.
|
||||
path: PathBuf,
|
||||
/// Underlying I/O error.
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
/// The number of subcarriers in the data file does not match the
|
||||
/// configuration expectation (before or after interpolation).
|
||||
#[error(
|
||||
"Subcarrier count mismatch in `{path}`: \
|
||||
file has {found} subcarriers, expected {expected}"
|
||||
)]
|
||||
SubcarrierMismatch {
|
||||
/// Path of the offending file.
|
||||
path: PathBuf,
|
||||
/// Number of subcarriers found in the file.
|
||||
found: usize,
|
||||
/// Number of subcarriers expected by the configuration.
|
||||
expected: usize,
|
||||
},
|
||||
|
||||
/// A sample index was out of bounds.
|
||||
#[error("Index {index} is out of bounds for dataset of length {len}")]
|
||||
IndexOutOfBounds {
|
||||
/// The requested index.
|
||||
index: usize,
|
||||
/// Total number of samples.
|
||||
len: usize,
|
||||
},
|
||||
|
||||
/// A numpy array could not be read.
|
||||
#[error("NumPy array read error in `{path}`: {message}")]
|
||||
NpyReadError {
|
||||
/// Path of the `.npy` file.
|
||||
path: PathBuf,
|
||||
/// Error description.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A metadata file (e.g., `meta.json`) is missing or malformed.
|
||||
#[error("Metadata error for subject {subject_id}: {message}")]
|
||||
MetadataError {
|
||||
/// Subject whose metadata could not be read.
|
||||
subject_id: u32,
|
||||
/// Description of the problem.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// No subjects matching the requested IDs were found in the data directory.
|
||||
#[error(
|
||||
"No subjects found in `{data_dir}` matching the requested IDs: {requested:?}"
|
||||
)]
|
||||
NoSubjectsFound {
|
||||
/// Root data directory that was scanned.
|
||||
data_dir: PathBuf,
|
||||
/// Subject IDs that were requested.
|
||||
requested: Vec<u32>,
|
||||
},
|
||||
|
||||
/// A subcarrier interpolation error occurred during sample loading.
|
||||
#[error("Subcarrier interpolation failed while loading sample {sample_idx}: {source}")]
|
||||
InterpolationError {
|
||||
/// The sample index being loaded.
|
||||
sample_idx: usize,
|
||||
/// Underlying interpolation error.
|
||||
#[source]
|
||||
source: SubcarrierError,
|
||||
},
|
||||
}
|
||||
|
||||
impl DatasetError {
|
||||
/// Construct a [`DatasetError::DataNotFound`] error.
|
||||
pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::DataNotFound {
|
||||
path: path.into(),
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::InvalidFormat`] error.
|
||||
pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::InvalidFormat {
|
||||
path: path.into(),
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::IoError`] error.
|
||||
pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
|
||||
DatasetError::IoError {
|
||||
path: path.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::SubcarrierMismatch`] error.
|
||||
pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
|
||||
DatasetError::SubcarrierMismatch {
|
||||
path: path.into(),
|
||||
found,
|
||||
expected,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::NpyReadError`] error.
|
||||
pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::NpyReadError {
|
||||
path: path.into(),
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subcarrier interpolation errors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by the subcarrier resampling functions.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SubcarrierError {
|
||||
/// The source or destination subcarrier count is zero.
|
||||
#[error("Subcarrier count must be at least 1, got {count}")]
|
||||
ZeroCount {
|
||||
/// The offending count.
|
||||
count: usize,
|
||||
},
|
||||
|
||||
/// The input array has an unexpected shape.
|
||||
#[error(
|
||||
"Input array shape mismatch: expected last dimension {expected_sc}, \
|
||||
got {actual_sc} (full shape: {shape:?})"
|
||||
)]
|
||||
InputShapeMismatch {
|
||||
/// Expected number of subcarriers (last dimension).
|
||||
expected_sc: usize,
|
||||
/// Actual number of subcarriers found.
|
||||
actual_sc: usize,
|
||||
/// Full shape of the input array.
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
|
||||
/// The requested interpolation method is not implemented.
|
||||
#[error("Interpolation method `{method}` is not yet implemented")]
|
||||
MethodNotImplemented {
|
||||
/// Name of the unimplemented method.
|
||||
method: String,
|
||||
},
|
||||
|
||||
/// Source and destination subcarrier counts are already equal.
|
||||
///
|
||||
/// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before
|
||||
/// calling the interpolation routine to avoid this error.
|
||||
///
|
||||
/// [`TrainingConfig::needs_subcarrier_interp`]:
|
||||
/// crate::config::TrainingConfig::needs_subcarrier_interp
|
||||
#[error(
|
||||
"Source and destination subcarrier counts are equal ({count}); \
|
||||
no interpolation is needed"
|
||||
)]
|
||||
NopInterpolation {
|
||||
/// The equal count.
|
||||
count: usize,
|
||||
},
|
||||
|
||||
/// A numerical error occurred during interpolation (e.g., division by zero
|
||||
/// due to coincident knot positions).
|
||||
#[error("Numerical error during interpolation: {0}")]
|
||||
NumericalError(String),
|
||||
}
|
||||
|
||||
impl SubcarrierError {
|
||||
/// Construct a [`SubcarrierError::NumericalError`].
|
||||
pub fn numerical<S: Into<String>>(msg: S) -> Self {
|
||||
SubcarrierError::NumericalError(msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user