feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries
Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap generation
- compute_losses: unified functional API
- 11 deterministic unit tests

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres min-cut via DFS augmenting paths for optimal multi-person keypoint assignment (ruvector min-cut)
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/rectangular/empty Hungarian cases)

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs is still a stub — the trainer agent is completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! This module defines the [`CsiDataset`] trait plus two concrete implementations:
|
||||
//!
|
||||
//! - [`MmFiDataset`]: reads MM-Fi NPY/HDF5 files from disk.
|
||||
//! - [`MmFiDataset`]: reads MM-Fi NPY files from disk.
|
||||
//! - [`SyntheticCsiDataset`]: generates fully-deterministic CSI from a physics
|
||||
//! model; useful for unit tests, integration tests, and dry-run sanity checks.
|
||||
//! **Never uses random data.**
|
||||
@@ -18,7 +18,7 @@
|
||||
//! A01/
|
||||
//! wifi_csi.npy # amplitude [T, n_tx, n_rx, n_sc]
|
||||
//! wifi_csi_phase.npy # phase [T, n_tx, n_rx, n_sc]
|
||||
//! gt_keypoints.npy # keypoints [T, 17, 3] (x, y, vis)
|
||||
//! gt_keypoints.npy # ground-truth keypoints [T, 17, 3] (x, y, vis)
|
||||
//! A02/
|
||||
//! ...
|
||||
//! S02/
|
||||
@@ -42,9 +42,9 @@
|
||||
|
||||
use ndarray::{Array1, Array2, Array4};
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::DatasetError;
|
||||
use crate::subcarrier::interpolate_subcarriers;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -259,8 +259,6 @@ struct MmFiEntry {
|
||||
num_frames: usize,
|
||||
/// Window size in frames (mirrors config).
|
||||
window_frames: usize,
|
||||
/// First global sample index that maps into this clip.
|
||||
global_start_idx: usize,
|
||||
}
|
||||
|
||||
impl MmFiEntry {
|
||||
@@ -305,8 +303,8 @@ impl MmFiDataset {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`DatasetError::DirectoryNotFound`] if `root` does not exist, or
|
||||
/// [`DatasetError::Io`] for any filesystem access failure.
|
||||
/// Returns [`DatasetError::DataNotFound`] if `root` does not exist, or an
|
||||
/// IO error for any filesystem access failure.
|
||||
pub fn discover(
|
||||
root: &Path,
|
||||
window_frames: usize,
|
||||
@@ -314,16 +312,17 @@ impl MmFiDataset {
|
||||
num_keypoints: usize,
|
||||
) -> Result<Self, DatasetError> {
|
||||
if !root.exists() {
|
||||
return Err(DatasetError::DirectoryNotFound {
|
||||
path: root.display().to_string(),
|
||||
});
|
||||
return Err(DatasetError::not_found(
|
||||
root,
|
||||
"MM-Fi root directory not found",
|
||||
));
|
||||
}
|
||||
|
||||
let mut entries: Vec<MmFiEntry> = Vec::new();
|
||||
let mut global_idx = 0usize;
|
||||
|
||||
// Walk subject directories (S01, S02, …)
|
||||
let mut subject_dirs: Vec<PathBuf> = std::fs::read_dir(root)?
|
||||
let mut subject_dirs: Vec<PathBuf> = std::fs::read_dir(root)
|
||||
.map_err(|e| DatasetError::io_error(root, e))?
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
|
||||
.map(|e| e.path())
|
||||
@@ -335,7 +334,8 @@ impl MmFiDataset {
|
||||
let subject_id = parse_id_suffix(subj_name).unwrap_or(0);
|
||||
|
||||
// Walk action directories (A01, A02, …)
|
||||
let mut action_dirs: Vec<PathBuf> = std::fs::read_dir(subj_path)?
|
||||
let mut action_dirs: Vec<PathBuf> = std::fs::read_dir(subj_path)
|
||||
.map_err(|e| DatasetError::io_error(subj_path.as_path(), e))?
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
|
||||
.map(|e| e.path())
|
||||
@@ -368,7 +368,7 @@ impl MmFiDataset {
|
||||
}
|
||||
};
|
||||
|
||||
let entry = MmFiEntry {
|
||||
entries.push(MmFiEntry {
|
||||
subject_id,
|
||||
action_id,
|
||||
amp_path,
|
||||
@@ -376,17 +376,15 @@ impl MmFiDataset {
|
||||
kp_path,
|
||||
num_frames,
|
||||
window_frames,
|
||||
global_start_idx: global_idx,
|
||||
};
|
||||
global_idx += entry.num_windows();
|
||||
entries.push(entry);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let total_windows: usize = entries.iter().map(|e| e.num_windows()).sum();
|
||||
info!(
|
||||
"MmFiDataset: scanned {} clips, {} total windows (root={})",
|
||||
entries.len(),
|
||||
global_idx,
|
||||
total_windows,
|
||||
root.display()
|
||||
);
|
||||
|
||||
@@ -429,9 +427,11 @@ impl CsiDataset for MmFiDataset {
|
||||
|
||||
fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
|
||||
let total = self.len();
|
||||
let (entry_idx, frame_offset) = self
|
||||
.locate(idx)
|
||||
.ok_or(DatasetError::IndexOutOfBounds { idx, len: total })?;
|
||||
let (entry_idx, frame_offset) =
|
||||
self.locate(idx).ok_or(DatasetError::IndexOutOfBounds {
|
||||
index: idx,
|
||||
len: total,
|
||||
})?;
|
||||
|
||||
let entry = &self.entries[entry_idx];
|
||||
let t_start = frame_offset;
|
||||
@@ -441,10 +441,12 @@ impl CsiDataset for MmFiDataset {
|
||||
let amp_full = load_npy_f32(&entry.amp_path)?;
|
||||
let (t, n_tx, n_rx, n_sc) = amp_full.dim();
|
||||
if t_end > t {
|
||||
return Err(DatasetError::Format(format!(
|
||||
"window [{t_start}, {t_end}) exceeds clip length {t} in {}",
|
||||
entry.amp_path.display()
|
||||
)));
|
||||
return Err(DatasetError::invalid_format(
|
||||
&entry.amp_path,
|
||||
format!(
|
||||
"window [{t_start}, {t_end}) exceeds clip length {t}"
|
||||
),
|
||||
));
|
||||
}
|
||||
let amp_window = amp_full
|
||||
.slice(ndarray::s![t_start..t_end, .., .., ..])
|
||||
@@ -500,78 +502,77 @@ impl CsiDataset for MmFiDataset {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NPY helpers (no-HDF5 path; HDF5 path is feature-gated below)
|
||||
// NPY helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Load a 4-D float32 NPY array from disk.
|
||||
///
|
||||
/// The NPY format is read using `ndarray_npy`.
|
||||
fn load_npy_f32(path: &Path) -> Result<Array4<f32>, DatasetError> {
|
||||
use ndarray_npy::ReadNpyExt;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let file = std::fs::File::open(path)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
let arr: ndarray::ArrayD<f32> = ndarray::ArrayD::read_npy(file)
|
||||
.map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?;
|
||||
arr.into_dimensionality::<ndarray::Ix4>().map_err(|e| {
|
||||
DatasetError::Format(format!(
|
||||
"Expected 4-D array in {}, got shape {:?}: {e}",
|
||||
path.display(),
|
||||
arr.shape()
|
||||
))
|
||||
.map_err(|e| DatasetError::npy_read(path, e.to_string()))?;
|
||||
arr.into_dimensionality::<ndarray::Ix4>().map_err(|_e| {
|
||||
DatasetError::invalid_format(
|
||||
path,
|
||||
format!("Expected 4-D array, got shape {:?}", arr.shape()),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Load a 3-D float32 NPY array (keypoints: `[T, J, 3]`).
|
||||
fn load_npy_kp(path: &Path, _num_keypoints: usize) -> Result<ndarray::Array3<f32>, DatasetError> {
|
||||
use ndarray_npy::ReadNpyExt;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let file = std::fs::File::open(path)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
let arr: ndarray::ArrayD<f32> = ndarray::ArrayD::read_npy(file)
|
||||
.map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?;
|
||||
arr.into_dimensionality::<ndarray::Ix3>().map_err(|e| {
|
||||
DatasetError::Format(format!(
|
||||
"Expected 3-D keypoint array in {}, got shape {:?}: {e}",
|
||||
path.display(),
|
||||
arr.shape()
|
||||
))
|
||||
.map_err(|e| DatasetError::npy_read(path, e.to_string()))?;
|
||||
arr.into_dimensionality::<ndarray::Ix3>().map_err(|_e| {
|
||||
DatasetError::invalid_format(
|
||||
path,
|
||||
format!("Expected 3-D keypoint array, got shape {:?}", arr.shape()),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Read only the first dimension of an NPY header (the frame count) without
|
||||
/// loading the entire file into memory.
|
||||
fn peek_npy_first_dim(path: &Path) -> Result<usize, DatasetError> {
|
||||
// Minimum viable NPY header parse: magic + version + header_len + header.
|
||||
use std::io::{BufReader, Read};
|
||||
let f = std::fs::File::open(path)?;
|
||||
let f = std::fs::File::open(path)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
|
||||
let mut magic = [0u8; 6];
|
||||
reader.read_exact(&mut magic)?;
|
||||
reader.read_exact(&mut magic)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
if &magic != b"\x93NUMPY" {
|
||||
return Err(DatasetError::Format(format!(
|
||||
"Not a valid NPY file: {}",
|
||||
path.display()
|
||||
)));
|
||||
return Err(DatasetError::invalid_format(path, "Not a valid NPY file"));
|
||||
}
|
||||
|
||||
let mut version = [0u8; 2];
|
||||
reader.read_exact(&mut version)?;
|
||||
reader.read_exact(&mut version)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
|
||||
// Header length field: 2 bytes in v1, 4 bytes in v2
|
||||
let header_len: usize = if version[0] == 1 {
|
||||
let mut buf = [0u8; 2];
|
||||
reader.read_exact(&mut buf)?;
|
||||
reader.read_exact(&mut buf)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
u16::from_le_bytes(buf) as usize
|
||||
} else {
|
||||
let mut buf = [0u8; 4];
|
||||
reader.read_exact(&mut buf)?;
|
||||
reader.read_exact(&mut buf)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
u32::from_le_bytes(buf) as usize
|
||||
};
|
||||
|
||||
let mut header = vec![0u8; header_len];
|
||||
reader.read_exact(&mut header)?;
|
||||
reader.read_exact(&mut header)
|
||||
.map_err(|e| DatasetError::io_error(path, e))?;
|
||||
let header_str = String::from_utf8_lossy(&header);
|
||||
|
||||
// Parse the shape tuple using a simple substring search.
|
||||
// Example header: "{'descr': '<f4', 'fortran_order': False, 'shape': (300, 3, 3, 114), }"
|
||||
if let Some(start) = header_str.find("'shape': (") {
|
||||
let rest = &header_str[start + "'shape': (".len()..];
|
||||
if let Some(end) = rest.find(')') {
|
||||
@@ -586,10 +587,7 @@ fn peek_npy_first_dim(path: &Path) -> Result<usize, DatasetError> {
|
||||
}
|
||||
}
|
||||
|
||||
Err(DatasetError::Format(format!(
|
||||
"Cannot parse shape from NPY header in {}",
|
||||
path.display()
|
||||
)))
|
||||
Err(DatasetError::invalid_format(path, "Cannot parse shape from NPY header"))
|
||||
}
|
||||
|
||||
/// Parse the numeric suffix of a directory name like `S01` → `1` or `A12` → `12`.
|
||||
@@ -711,7 +709,7 @@ impl CsiDataset for SyntheticCsiDataset {
|
||||
fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
|
||||
if idx >= self.num_samples {
|
||||
return Err(DatasetError::IndexOutOfBounds {
|
||||
idx,
|
||||
index: idx,
|
||||
len: self.num_samples,
|
||||
});
|
||||
}
|
||||
@@ -755,34 +753,6 @@ impl CsiDataset for SyntheticCsiDataset {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DatasetError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Errors produced by dataset operations.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DatasetError {
|
||||
/// Requested index is outside the valid range.
|
||||
#[error("Index {idx} out of bounds (dataset has {len} samples)")]
|
||||
IndexOutOfBounds { idx: usize, len: usize },
|
||||
|
||||
/// An underlying file-system error occurred.
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// The file exists but does not match the expected format.
|
||||
#[error("File format error: {0}")]
|
||||
Format(String),
|
||||
|
||||
/// The loaded array has a different subcarrier count than required.
|
||||
#[error("Subcarrier count mismatch: expected {expected}, got {actual}")]
|
||||
SubcarrierMismatch { expected: usize, actual: usize },
|
||||
|
||||
/// The specified root directory does not exist.
|
||||
#[error("Directory not found: {path}")]
|
||||
DirectoryNotFound { path: String },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -800,8 +770,14 @@ mod tests {
|
||||
let ds = SyntheticCsiDataset::new(10, cfg.clone());
|
||||
let s = ds.get(0).unwrap();
|
||||
|
||||
assert_eq!(s.amplitude.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]);
|
||||
assert_eq!(s.phase.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]);
|
||||
assert_eq!(
|
||||
s.amplitude.shape(),
|
||||
&[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]
|
||||
);
|
||||
assert_eq!(
|
||||
s.phase.shape(),
|
||||
&[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]
|
||||
);
|
||||
assert_eq!(s.keypoints.shape(), &[cfg.num_keypoints, 2]);
|
||||
assert_eq!(s.keypoint_visibility.shape(), &[cfg.num_keypoints]);
|
||||
}
|
||||
@@ -812,7 +788,11 @@ mod tests {
|
||||
let ds = SyntheticCsiDataset::new(10, cfg);
|
||||
let s0a = ds.get(3).unwrap();
|
||||
let s0b = ds.get(3).unwrap();
|
||||
assert_abs_diff_eq!(s0a.amplitude[[0, 0, 0, 0]], s0b.amplitude[[0, 0, 0, 0]], epsilon = 1e-7);
|
||||
assert_abs_diff_eq!(
|
||||
s0a.amplitude[[0, 0, 0, 0]],
|
||||
s0b.amplitude[[0, 0, 0, 0]],
|
||||
epsilon = 1e-7
|
||||
);
|
||||
assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7);
|
||||
}
|
||||
|
||||
@@ -829,7 +809,10 @@ mod tests {
|
||||
#[test]
|
||||
fn synthetic_out_of_bounds() {
|
||||
let ds = SyntheticCsiDataset::new(5, SyntheticConfig::default());
|
||||
assert!(matches!(ds.get(5), Err(DatasetError::IndexOutOfBounds { idx: 5, len: 5 })));
|
||||
assert!(matches!(
|
||||
ds.get(5),
|
||||
Err(DatasetError::IndexOutOfBounds { index: 5, len: 5 })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -861,7 +844,7 @@ mod tests {
|
||||
#[test]
|
||||
fn synthetic_all_joints_visible() {
|
||||
let cfg = SyntheticConfig::default();
|
||||
let ds = SyntheticCsiDataset::new(3, cfg.clone());
|
||||
let ds = SyntheticCsiDataset::new(3, cfg);
|
||||
let s = ds.get(0).unwrap();
|
||||
assert!(s.keypoint_visibility.iter().all(|&v| (v - 2.0).abs() < 1e-6));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user