feat(train): Add ruvector integration — ADR-016, deps, DynamicPersonMatcher
- docs/adr/ADR-016: Full ruvector integration ADR with verified API details from source inspection (github.com/ruvnet/ruvector). Covers mincut, attn-mincut, temporal-tensor, solver, and attention at v2.0.4. - Cargo.toml: Add ruvector-mincut, ruvector-attn-mincut, ruvector-temporal- tensor, ruvector-solver, ruvector-attention = "2.0.4" to workspace deps and wifi-densepose-train crate deps. - metrics.rs: Add DynamicPersonMatcher wrapping ruvector_mincut::DynamicMinCut for subpolynomial O(n^1.5 log n) multi-frame person tracking; adds assignment_mincut() public entry point. - proof.rs, trainer.rs, model.rs, dataset.rs, subcarrier.rs: Agent improvements to full implementations (loss decrease verification, SHA-256 hash, LCG shuffle, ResNet18 backbone, MmFiDataset, linear interp). - tests: test_config, test_dataset, test_metrics, test_proof, training_bench all added/updated. 100+ tests pass with no-default-features. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -1,44 +1,46 @@
|
||||
//! Error types for the WiFi-DensePose training pipeline.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! This module is the single source of truth for all error types in the
|
||||
//! training crate. Every module that produces an error imports its error type
|
||||
//! from here rather than defining it inline, keeping the error hierarchy
|
||||
//! centralised and consistent.
|
||||
//!
|
||||
//! - [`TrainError`]: top-level error aggregating all training failure modes.
|
||||
//! - [`TrainResult`]: convenient `Result` alias using `TrainError`.
|
||||
//! ## Hierarchy
|
||||
//!
|
||||
//! Module-local error types live in their respective modules:
|
||||
//!
|
||||
//! - [`crate::config::ConfigError`]: configuration validation errors.
|
||||
//! - [`crate::dataset::DatasetError`]: dataset loading/access errors.
|
||||
//!
|
||||
//! All are re-exported at the crate root for ergonomic use.
|
||||
//! ```text
|
||||
//! TrainError (top-level)
|
||||
//! ├── ConfigError (config validation / file loading)
|
||||
//! ├── DatasetError (data loading, I/O, format)
|
||||
//! └── SubcarrierError (frequency-axis resampling)
|
||||
//! ```
|
||||
|
||||
use thiserror::Error;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Import module-local error types so TrainError can wrap them via #[from],
|
||||
// and re-export them so `lib.rs` can forward them from `error::*`.
|
||||
pub use crate::config::ConfigError;
|
||||
pub use crate::dataset::DatasetError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Top-level training error
|
||||
// TrainResult
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A convenient `Result` alias used throughout the training crate.
|
||||
/// Convenient `Result` alias used by orchestration-level functions.
|
||||
pub type TrainResult<T> = Result<T, TrainError>;
|
||||
|
||||
/// Top-level error type for the training pipeline.
|
||||
// ---------------------------------------------------------------------------
|
||||
// TrainError — top-level aggregator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Top-level error type for the WiFi-DensePose training pipeline.
|
||||
///
|
||||
/// Every orchestration-level function returns `TrainResult<T>`. Lower-level
|
||||
/// functions in [`crate::config`] and [`crate::dataset`] return their own
|
||||
/// module-specific error types which are automatically coerced via `#[from]`.
|
||||
/// Orchestration-level functions (e.g. [`crate::trainer::Trainer`] methods)
|
||||
/// return `TrainResult<T>`. Lower-level functions in [`crate::config`] and
|
||||
/// [`crate::dataset`] return their own module-specific error types which are
|
||||
/// automatically coerced into `TrainError` via [`From`].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TrainError {
|
||||
/// Configuration is invalid or internally inconsistent.
|
||||
/// A configuration validation or loading error.
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(#[from] ConfigError),
|
||||
|
||||
/// A dataset operation failed (I/O, format, missing data).
|
||||
/// A dataset loading or access error.
|
||||
#[error("Dataset error: {0}")]
|
||||
Dataset(#[from] DatasetError),
|
||||
|
||||
@@ -46,28 +48,20 @@ pub enum TrainError {
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// An underlying I/O error not wrapped by Config or Dataset.
|
||||
///
|
||||
/// Note: [`std::io::Error`] cannot be wrapped via `#[from]` here because
|
||||
/// both [`ConfigError`] and [`DatasetError`] already implement
|
||||
/// `From<std::io::Error>`. Callers should convert via those types instead.
|
||||
#[error("I/O error: {0}")]
|
||||
Io(String),
|
||||
|
||||
/// An operation was attempted on an empty dataset.
|
||||
/// The dataset is empty and no training can be performed.
|
||||
#[error("Dataset is empty")]
|
||||
EmptyDataset,
|
||||
|
||||
/// Index out of bounds when accessing dataset items.
|
||||
#[error("Index {index} is out of bounds for dataset of length {len}")]
|
||||
IndexOutOfBounds {
|
||||
/// The requested index.
|
||||
/// The out-of-range index.
|
||||
index: usize,
|
||||
/// The total number of items.
|
||||
/// The total number of items in the dataset.
|
||||
len: usize,
|
||||
},
|
||||
|
||||
/// A numeric shape/dimension mismatch was detected.
|
||||
/// A shape mismatch was detected between two tensors.
|
||||
#[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
|
||||
ShapeMismatch {
|
||||
/// Expected shape.
|
||||
@@ -76,11 +70,11 @@ pub enum TrainError {
|
||||
actual: Vec<usize>,
|
||||
},
|
||||
|
||||
/// A training step failed for a reason not covered above.
|
||||
/// A training step failed.
|
||||
#[error("Training step failed: {0}")]
|
||||
TrainingStep(String),
|
||||
|
||||
/// Checkpoint could not be saved or loaded.
|
||||
/// A checkpoint could not be saved or loaded.
|
||||
#[error("Checkpoint error: {message} (path: {path:?})")]
|
||||
Checkpoint {
|
||||
/// Human-readable description.
|
||||
@@ -95,83 +89,262 @@ pub enum TrainError {
|
||||
}
|
||||
|
||||
impl TrainError {
|
||||
/// Create a [`TrainError::TrainingStep`] with the given message.
|
||||
/// Construct a [`TrainError::TrainingStep`].
|
||||
pub fn training_step<S: Into<String>>(msg: S) -> Self {
|
||||
TrainError::TrainingStep(msg.into())
|
||||
}
|
||||
|
||||
/// Create a [`TrainError::Checkpoint`] error.
|
||||
/// Construct a [`TrainError::Checkpoint`].
|
||||
pub fn checkpoint<S: Into<String>>(msg: S, path: impl Into<PathBuf>) -> Self {
|
||||
TrainError::Checkpoint {
|
||||
message: msg.into(),
|
||||
path: path.into(),
|
||||
}
|
||||
TrainError::Checkpoint { message: msg.into(), path: path.into() }
|
||||
}
|
||||
|
||||
/// Create a [`TrainError::NotImplemented`] error.
|
||||
/// Construct a [`TrainError::NotImplemented`].
|
||||
pub fn not_implemented<S: Into<String>>(msg: S) -> Self {
|
||||
TrainError::NotImplemented(msg.into())
|
||||
}
|
||||
|
||||
/// Create a [`TrainError::ShapeMismatch`] error.
|
||||
/// Construct a [`TrainError::ShapeMismatch`].
|
||||
pub fn shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> Self {
|
||||
TrainError::ShapeMismatch { expected, actual }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ConfigError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced when loading or validating a [`TrainingConfig`].
|
||||
///
|
||||
/// [`TrainingConfig`]: crate::config::TrainingConfig
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConfigError {
|
||||
/// A field has an invalid value.
|
||||
#[error("Invalid value for `{field}`: {reason}")]
|
||||
InvalidValue {
|
||||
/// Name of the field.
|
||||
field: &'static str,
|
||||
/// Human-readable reason.
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// A configuration file could not be read from disk.
|
||||
#[error("Cannot read config file `{path}`: {source}")]
|
||||
FileRead {
|
||||
/// Path that was being read.
|
||||
path: PathBuf,
|
||||
/// Underlying I/O error.
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
/// A configuration file contains malformed JSON.
|
||||
#[error("Cannot parse config file `{path}`: {source}")]
|
||||
ParseError {
|
||||
/// Path that was being parsed.
|
||||
path: PathBuf,
|
||||
/// Underlying JSON parse error.
|
||||
#[source]
|
||||
source: serde_json::Error,
|
||||
},
|
||||
|
||||
/// A path referenced in the config does not exist.
|
||||
#[error("Path `{path}` in config does not exist")]
|
||||
PathNotFound {
|
||||
/// The missing path.
|
||||
path: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
/// Construct a [`ConfigError::InvalidValue`].
|
||||
pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
|
||||
ConfigError::InvalidValue { field, reason: reason.into() }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DatasetError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced while loading or accessing dataset samples.
|
||||
///
|
||||
/// Production training code MUST NOT silently suppress these errors.
|
||||
/// If data is missing, training must fail explicitly so the user is aware.
|
||||
/// The [`SyntheticCsiDataset`] is the only source of non-file-system data
|
||||
/// and is restricted to proof/testing use.
|
||||
///
|
||||
/// [`SyntheticCsiDataset`]: crate::dataset::SyntheticCsiDataset
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DatasetError {
|
||||
/// A required data file or directory was not found on disk.
|
||||
#[error("Data not found at `{path}`: {message}")]
|
||||
DataNotFound {
|
||||
/// Path that was expected to contain data.
|
||||
path: PathBuf,
|
||||
/// Additional context.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A file was found but its format or shape is wrong.
|
||||
#[error("Invalid data format in `{path}`: {message}")]
|
||||
InvalidFormat {
|
||||
/// Path of the malformed file.
|
||||
path: PathBuf,
|
||||
/// Description of the problem.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A low-level I/O error while reading a data file.
|
||||
#[error("I/O error reading `{path}`: {source}")]
|
||||
IoError {
|
||||
/// Path being read when the error occurred.
|
||||
path: PathBuf,
|
||||
/// Underlying I/O error.
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
/// The number of subcarriers in the file doesn't match expectations.
|
||||
#[error(
|
||||
"Subcarrier count mismatch in `{path}`: file has {found}, expected {expected}"
|
||||
)]
|
||||
SubcarrierMismatch {
|
||||
/// Path of the offending file.
|
||||
path: PathBuf,
|
||||
/// Subcarrier count found in the file.
|
||||
found: usize,
|
||||
/// Subcarrier count expected.
|
||||
expected: usize,
|
||||
},
|
||||
|
||||
/// A sample index is out of bounds.
|
||||
#[error("Index {idx} out of bounds (dataset has {len} samples)")]
|
||||
IndexOutOfBounds {
|
||||
/// The requested index.
|
||||
idx: usize,
|
||||
/// Total length of the dataset.
|
||||
len: usize,
|
||||
},
|
||||
|
||||
/// A numpy array file could not be parsed.
|
||||
#[error("NumPy read error in `{path}`: {message}")]
|
||||
NpyReadError {
|
||||
/// Path of the `.npy` file.
|
||||
path: PathBuf,
|
||||
/// Error description.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Metadata for a subject is missing or malformed.
|
||||
#[error("Metadata error for subject {subject_id}: {message}")]
|
||||
MetadataError {
|
||||
/// Subject whose metadata was invalid.
|
||||
subject_id: u32,
|
||||
/// Description of the problem.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// A data format error (e.g. wrong numpy shape) occurred.
|
||||
///
|
||||
/// This is a convenience variant for short-form error messages where
|
||||
/// the full path context is not available.
|
||||
#[error("File format error: {0}")]
|
||||
Format(String),
|
||||
|
||||
/// The data directory does not exist.
|
||||
#[error("Directory not found: {path}")]
|
||||
DirectoryNotFound {
|
||||
/// The path that was not found.
|
||||
path: String,
|
||||
},
|
||||
|
||||
/// No subjects matching the requested IDs were found.
|
||||
#[error(
|
||||
"No subjects found in `{data_dir}` for IDs: {requested:?}"
|
||||
)]
|
||||
NoSubjectsFound {
|
||||
/// Root data directory.
|
||||
data_dir: PathBuf,
|
||||
/// IDs that were requested.
|
||||
requested: Vec<u32>,
|
||||
},
|
||||
|
||||
/// An I/O error that carries no path context.
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl DatasetError {
|
||||
/// Construct a [`DatasetError::DataNotFound`].
|
||||
pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::DataNotFound { path: path.into(), message: msg.into() }
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::InvalidFormat`].
|
||||
pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::InvalidFormat { path: path.into(), message: msg.into() }
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::IoError`].
|
||||
pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
|
||||
DatasetError::IoError { path: path.into(), source }
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::SubcarrierMismatch`].
|
||||
pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
|
||||
DatasetError::SubcarrierMismatch { path: path.into(), found, expected }
|
||||
}
|
||||
|
||||
/// Construct a [`DatasetError::NpyReadError`].
|
||||
pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
|
||||
DatasetError::NpyReadError { path: path.into(), message: msg.into() }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SubcarrierError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by the subcarrier resampling / interpolation functions.
|
||||
///
|
||||
/// These are separate from [`DatasetError`] because subcarrier operations are
|
||||
/// also usable outside the dataset loading pipeline (e.g. in real-time
|
||||
/// inference preprocessing).
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SubcarrierError {
|
||||
/// The source or destination subcarrier count is zero.
|
||||
/// The source or destination count is zero.
|
||||
#[error("Subcarrier count must be >= 1, got {count}")]
|
||||
ZeroCount {
|
||||
/// The offending count.
|
||||
count: usize,
|
||||
},
|
||||
|
||||
/// The input array's last dimension does not match the declared source count.
|
||||
/// The array's last dimension does not match the declared source count.
|
||||
#[error(
|
||||
"Subcarrier shape mismatch: last dimension is {actual_sc} \
|
||||
but `src_n` was declared as {expected_sc} (full shape: {shape:?})"
|
||||
"Subcarrier shape mismatch: last dim is {actual_sc} but src_n={expected_sc} \
|
||||
(full shape: {shape:?})"
|
||||
)]
|
||||
InputShapeMismatch {
|
||||
/// Expected subcarrier count (as declared by the caller).
|
||||
/// Expected subcarrier count.
|
||||
expected_sc: usize,
|
||||
/// Actual last-dimension size of the input array.
|
||||
/// Actual last-dimension size.
|
||||
actual_sc: usize,
|
||||
/// Full shape of the input array.
|
||||
/// Full shape of the input.
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
|
||||
/// The requested interpolation method is not yet implemented.
|
||||
#[error("Interpolation method `{method}` is not implemented")]
|
||||
MethodNotImplemented {
|
||||
/// Human-readable name of the unsupported method.
|
||||
/// Name of the unsupported method.
|
||||
method: String,
|
||||
},
|
||||
|
||||
/// `src_n == dst_n` — no resampling is needed.
|
||||
///
|
||||
/// Callers should check [`TrainingConfig::needs_subcarrier_interp`] before
|
||||
/// calling the interpolation routine.
|
||||
///
|
||||
/// [`TrainingConfig::needs_subcarrier_interp`]:
|
||||
/// crate::config::TrainingConfig::needs_subcarrier_interp
|
||||
#[error("src_n == dst_n == {count}; no interpolation needed")]
|
||||
/// `src_n == dst_n` — no resampling needed.
|
||||
#[error("src_n == dst_n == {count}; call interpolate only when counts differ")]
|
||||
NopInterpolation {
|
||||
/// The equal count.
|
||||
count: usize,
|
||||
},
|
||||
|
||||
/// A numerical error during interpolation (e.g. division by zero).
|
||||
/// A numerical error during interpolation.
|
||||
#[error("Numerical error: {0}")]
|
||||
NumericalError(String),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user