feat(rust): Complete training pipeline — losses, metrics, model, trainer, binaries

Losses (losses.rs — 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose
  (cross-entropy + Smooth L1 UV masked to foreground), transfer (MSE)
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap gen
- compute_losses: unified functional API
- 11 deterministic unit tests

Metrics (metrics.rs — 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres min-cost matching via DFS
  augmenting paths for optimal multi-person keypoint assignment (ruvector min-cut)
- build_oks_cost_matrix: 1−OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints, 2×2/3×3/
  rectangular/empty Hungarian cases)

Model (model.rs — 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amp+phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48]→[B,256,6,6]
- KeypointHead: [B,256,6,6]→[B,17,H,W] heatmaps
- DensePoseHead: [B,256,6,6]→[B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs — 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs is still a stub — the trainer agent is completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
Claude
2026-02-28 15:22:54 +00:00
parent 2c5ca308a4
commit fce1271140
16 changed files with 4828 additions and 159 deletions

View File

@@ -2,7 +2,7 @@
//!
//! This module defines the [`CsiDataset`] trait plus two concrete implementations:
//!
//! - [`MmFiDataset`]: reads MM-Fi NPY/HDF5 files from disk.
//! - [`MmFiDataset`]: reads MM-Fi NPY files from disk.
//! - [`SyntheticCsiDataset`]: generates fully-deterministic CSI from a physics
//! model; useful for unit tests, integration tests, and dry-run sanity checks.
//! **Never uses random data.**
@@ -18,7 +18,7 @@
//! A01/
//! wifi_csi.npy # amplitude [T, n_tx, n_rx, n_sc]
//! wifi_csi_phase.npy # phase [T, n_tx, n_rx, n_sc]
//! gt_keypoints.npy # keypoints [T, 17, 3] (x, y, vis)
//! gt_keypoints.npy # ground-truth keypoints [T, 17, 3] (x, y, vis)
//! A02/
//! ...
//! S02/
@@ -42,9 +42,9 @@
use ndarray::{Array1, Array2, Array4};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::error::DatasetError;
use crate::subcarrier::interpolate_subcarriers;
// ---------------------------------------------------------------------------
@@ -259,8 +259,6 @@ struct MmFiEntry {
num_frames: usize,
/// Window size in frames (mirrors config).
window_frames: usize,
/// First global sample index that maps into this clip.
global_start_idx: usize,
}
impl MmFiEntry {
@@ -305,8 +303,8 @@ impl MmFiDataset {
///
/// # Errors
///
/// Returns [`DatasetError::DirectoryNotFound`] if `root` does not exist, or
/// [`DatasetError::Io`] for any filesystem access failure.
/// Returns [`DatasetError::DataNotFound`] if `root` does not exist, or an
/// IO error for any filesystem access failure.
pub fn discover(
root: &Path,
window_frames: usize,
@@ -314,16 +312,17 @@ impl MmFiDataset {
num_keypoints: usize,
) -> Result<Self, DatasetError> {
if !root.exists() {
return Err(DatasetError::DirectoryNotFound {
path: root.display().to_string(),
});
return Err(DatasetError::not_found(
root,
"MM-Fi root directory not found",
));
}
let mut entries: Vec<MmFiEntry> = Vec::new();
let mut global_idx = 0usize;
// Walk subject directories (S01, S02, …)
let mut subject_dirs: Vec<PathBuf> = std::fs::read_dir(root)?
let mut subject_dirs: Vec<PathBuf> = std::fs::read_dir(root)
.map_err(|e| DatasetError::io_error(root, e))?
.filter_map(|e| e.ok())
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
.map(|e| e.path())
@@ -335,7 +334,8 @@ impl MmFiDataset {
let subject_id = parse_id_suffix(subj_name).unwrap_or(0);
// Walk action directories (A01, A02, …)
let mut action_dirs: Vec<PathBuf> = std::fs::read_dir(subj_path)?
let mut action_dirs: Vec<PathBuf> = std::fs::read_dir(subj_path)
.map_err(|e| DatasetError::io_error(subj_path.as_path(), e))?
.filter_map(|e| e.ok())
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
.map(|e| e.path())
@@ -368,7 +368,7 @@ impl MmFiDataset {
}
};
let entry = MmFiEntry {
entries.push(MmFiEntry {
subject_id,
action_id,
amp_path,
@@ -376,17 +376,15 @@ impl MmFiDataset {
kp_path,
num_frames,
window_frames,
global_start_idx: global_idx,
};
global_idx += entry.num_windows();
entries.push(entry);
});
}
}
let total_windows: usize = entries.iter().map(|e| e.num_windows()).sum();
info!(
"MmFiDataset: scanned {} clips, {} total windows (root={})",
entries.len(),
global_idx,
total_windows,
root.display()
);
@@ -429,9 +427,11 @@ impl CsiDataset for MmFiDataset {
fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
let total = self.len();
let (entry_idx, frame_offset) = self
.locate(idx)
.ok_or(DatasetError::IndexOutOfBounds { idx, len: total })?;
let (entry_idx, frame_offset) =
self.locate(idx).ok_or(DatasetError::IndexOutOfBounds {
index: idx,
len: total,
})?;
let entry = &self.entries[entry_idx];
let t_start = frame_offset;
@@ -441,10 +441,12 @@ impl CsiDataset for MmFiDataset {
let amp_full = load_npy_f32(&entry.amp_path)?;
let (t, n_tx, n_rx, n_sc) = amp_full.dim();
if t_end > t {
return Err(DatasetError::Format(format!(
"window [{t_start}, {t_end}) exceeds clip length {t} in {}",
entry.amp_path.display()
)));
return Err(DatasetError::invalid_format(
&entry.amp_path,
format!(
"window [{t_start}, {t_end}) exceeds clip length {t}"
),
));
}
let amp_window = amp_full
.slice(ndarray::s![t_start..t_end, .., .., ..])
@@ -500,78 +502,77 @@ impl CsiDataset for MmFiDataset {
}
// ---------------------------------------------------------------------------
// NPY helpers (no-HDF5 path; HDF5 path is feature-gated below)
// NPY helpers
// ---------------------------------------------------------------------------
/// Load a 4-D float32 NPY array from disk.
///
/// The NPY format is read using `ndarray_npy`.
fn load_npy_f32(path: &Path) -> Result<Array4<f32>, DatasetError> {
use ndarray_npy::ReadNpyExt;
let file = std::fs::File::open(path)?;
let file = std::fs::File::open(path)
.map_err(|e| DatasetError::io_error(path, e))?;
let arr: ndarray::ArrayD<f32> = ndarray::ArrayD::read_npy(file)
.map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?;
arr.into_dimensionality::<ndarray::Ix4>().map_err(|e| {
DatasetError::Format(format!(
"Expected 4-D array in {}, got shape {:?}: {e}",
path.display(),
arr.shape()
))
.map_err(|e| DatasetError::npy_read(path, e.to_string()))?;
arr.into_dimensionality::<ndarray::Ix4>().map_err(|_e| {
DatasetError::invalid_format(
path,
format!("Expected 4-D array, got shape {:?}", arr.shape()),
)
})
}
/// Load a 3-D float32 NPY array (keypoints: `[T, J, 3]`).
fn load_npy_kp(path: &Path, _num_keypoints: usize) -> Result<ndarray::Array3<f32>, DatasetError> {
use ndarray_npy::ReadNpyExt;
let file = std::fs::File::open(path)?;
let file = std::fs::File::open(path)
.map_err(|e| DatasetError::io_error(path, e))?;
let arr: ndarray::ArrayD<f32> = ndarray::ArrayD::read_npy(file)
.map_err(|e| DatasetError::Format(format!("NPY read error at {}: {e}", path.display())))?;
arr.into_dimensionality::<ndarray::Ix3>().map_err(|e| {
DatasetError::Format(format!(
"Expected 3-D keypoint array in {}, got shape {:?}: {e}",
path.display(),
arr.shape()
))
.map_err(|e| DatasetError::npy_read(path, e.to_string()))?;
arr.into_dimensionality::<ndarray::Ix3>().map_err(|_e| {
DatasetError::invalid_format(
path,
format!("Expected 3-D keypoint array, got shape {:?}", arr.shape()),
)
})
}
/// Read only the first dimension of an NPY header (the frame count) without
/// loading the entire file into memory.
fn peek_npy_first_dim(path: &Path) -> Result<usize, DatasetError> {
// Minimum viable NPY header parse: magic + version + header_len + header.
use std::io::{BufReader, Read};
let f = std::fs::File::open(path)?;
let f = std::fs::File::open(path)
.map_err(|e| DatasetError::io_error(path, e))?;
let mut reader = BufReader::new(f);
let mut magic = [0u8; 6];
reader.read_exact(&mut magic)?;
reader.read_exact(&mut magic)
.map_err(|e| DatasetError::io_error(path, e))?;
if &magic != b"\x93NUMPY" {
return Err(DatasetError::Format(format!(
"Not a valid NPY file: {}",
path.display()
)));
return Err(DatasetError::invalid_format(path, "Not a valid NPY file"));
}
let mut version = [0u8; 2];
reader.read_exact(&mut version)?;
reader.read_exact(&mut version)
.map_err(|e| DatasetError::io_error(path, e))?;
// Header length field: 2 bytes in v1, 4 bytes in v2
let header_len: usize = if version[0] == 1 {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf)?;
reader.read_exact(&mut buf)
.map_err(|e| DatasetError::io_error(path, e))?;
u16::from_le_bytes(buf) as usize
} else {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
reader.read_exact(&mut buf)
.map_err(|e| DatasetError::io_error(path, e))?;
u32::from_le_bytes(buf) as usize
};
let mut header = vec![0u8; header_len];
reader.read_exact(&mut header)?;
reader.read_exact(&mut header)
.map_err(|e| DatasetError::io_error(path, e))?;
let header_str = String::from_utf8_lossy(&header);
// Parse the shape tuple using a simple substring search.
// Example header: "{'descr': '<f4', 'fortran_order': False, 'shape': (300, 3, 3, 114), }"
if let Some(start) = header_str.find("'shape': (") {
let rest = &header_str[start + "'shape': (".len()..];
if let Some(end) = rest.find(')') {
@@ -586,10 +587,7 @@ fn peek_npy_first_dim(path: &Path) -> Result<usize, DatasetError> {
}
}
Err(DatasetError::Format(format!(
"Cannot parse shape from NPY header in {}",
path.display()
)))
Err(DatasetError::invalid_format(path, "Cannot parse shape from NPY header"))
}
/// Parse the numeric suffix of a directory name like `S01` → `1` or `A12` → `12`.
@@ -711,7 +709,7 @@ impl CsiDataset for SyntheticCsiDataset {
fn get(&self, idx: usize) -> Result<CsiSample, DatasetError> {
if idx >= self.num_samples {
return Err(DatasetError::IndexOutOfBounds {
idx,
index: idx,
len: self.num_samples,
});
}
@@ -755,34 +753,6 @@ impl CsiDataset for SyntheticCsiDataset {
}
}
// ---------------------------------------------------------------------------
// DatasetError
// ---------------------------------------------------------------------------
/// Errors produced by dataset operations.
#[derive(Debug, Error)]
pub enum DatasetError {
/// Requested index is outside the valid range.
#[error("Index {idx} out of bounds (dataset has {len} samples)")]
IndexOutOfBounds { idx: usize, len: usize },
/// An underlying file-system error occurred.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// The file exists but does not match the expected format.
#[error("File format error: {0}")]
Format(String),
/// The loaded array has a different subcarrier count than required.
#[error("Subcarrier count mismatch: expected {expected}, got {actual}")]
SubcarrierMismatch { expected: usize, actual: usize },
/// The specified root directory does not exist.
#[error("Directory not found: {path}")]
DirectoryNotFound { path: String },
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
@@ -800,8 +770,14 @@ mod tests {
let ds = SyntheticCsiDataset::new(10, cfg.clone());
let s = ds.get(0).unwrap();
assert_eq!(s.amplitude.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]);
assert_eq!(s.phase.shape(), &[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]);
assert_eq!(
s.amplitude.shape(),
&[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]
);
assert_eq!(
s.phase.shape(),
&[cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, cfg.num_subcarriers]
);
assert_eq!(s.keypoints.shape(), &[cfg.num_keypoints, 2]);
assert_eq!(s.keypoint_visibility.shape(), &[cfg.num_keypoints]);
}
@@ -812,7 +788,11 @@ mod tests {
let ds = SyntheticCsiDataset::new(10, cfg);
let s0a = ds.get(3).unwrap();
let s0b = ds.get(3).unwrap();
assert_abs_diff_eq!(s0a.amplitude[[0, 0, 0, 0]], s0b.amplitude[[0, 0, 0, 0]], epsilon = 1e-7);
assert_abs_diff_eq!(
s0a.amplitude[[0, 0, 0, 0]],
s0b.amplitude[[0, 0, 0, 0]],
epsilon = 1e-7
);
assert_abs_diff_eq!(s0a.keypoints[[5, 0]], s0b.keypoints[[5, 0]], epsilon = 1e-7);
}
@@ -829,7 +809,10 @@ mod tests {
#[test]
fn synthetic_out_of_bounds() {
let ds = SyntheticCsiDataset::new(5, SyntheticConfig::default());
assert!(matches!(ds.get(5), Err(DatasetError::IndexOutOfBounds { idx: 5, len: 5 })));
assert!(matches!(
ds.get(5),
Err(DatasetError::IndexOutOfBounds { index: 5, len: 5 })
));
}
#[test]
@@ -861,7 +844,7 @@ mod tests {
#[test]
fn synthetic_all_joints_visible() {
let cfg = SyntheticConfig::default();
let ds = SyntheticCsiDataset::new(3, cfg.clone());
let ds = SyntheticCsiDataset::new(3, cfg);
let s = ds.get(0).unwrap();
assert!(s.keypoint_visibility.iter().all(|&v| (v - 2.0).abs() < 1e-6));
}