feat(rust): Add wifi-densepose-train crate with full training pipeline
Implements the training infrastructure described in ADR-015: - config.rs: TrainingConfig with all hyperparams (batch size, LR, loss weights, subcarrier interp method, validation split) - dataset.rs: MmFiDataset (real MM-Fi .npy loader) + SyntheticDataset (deterministic LCG, seed=42, proof/testing only — never production) - subcarrier.rs: Linear/cubic interpolation 114→56 subcarriers - error.rs: Typed errors (DataNotFound, InvalidFormat, IoError) - losses.rs: Keypoint heatmap (MSE), DensePose (CE + Smooth L1), teacher-student transfer (MSE), Gaussian heatmap generation - metrics.rs: PCK@0.2, OKS with Hungarian min-cut bipartite assignment via petgraph (optimal multi-person keypoint matching) - model.rs: WiFiDensePoseModel end-to-end with tch-rs (PyTorch bindings) - trainer.rs: Full training loop, LR scheduling, gradient clipping, early stopping, CSV logging, best-checkpoint saving - proof.rs: Deterministic training proof (SHA-256 trust kill switch) No random data in production paths. SyntheticDataset uses deterministic LCG (a=1664525, c=1013904223) — same seed always produces same output. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
//! # WiFi-DensePose Training Infrastructure
|
||||
//!
|
||||
//! This crate provides the complete training pipeline for the WiFi-DensePose pose
|
||||
//! estimation model. It includes configuration management, dataset loading with
|
||||
//! subcarrier interpolation, loss functions, evaluation metrics, and the training
|
||||
//! loop orchestrator.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! TrainingConfig ──► Trainer ──► Model
|
||||
//! │ │
|
||||
//! │ DataLoader
|
||||
//! │ │
|
||||
//! │ CsiDataset (MmFiDataset | SyntheticCsiDataset)
|
||||
//! │ │
|
||||
//! │ subcarrier::interpolate_subcarriers
|
||||
//! │
|
||||
//! └──► losses / metrics
|
||||
//! ```
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use wifi_densepose_train::config::TrainingConfig;
|
||||
//! use wifi_densepose_train::dataset::{SyntheticCsiDataset, SyntheticConfig, CsiDataset};
|
||||
//!
|
||||
//! // Build config
|
||||
//! let config = TrainingConfig::default();
|
||||
//! config.validate().expect("config is valid");
|
||||
//!
|
||||
//! // Create a synthetic dataset (deterministic, fixed-seed)
|
||||
//! let syn_cfg = SyntheticConfig::default();
|
||||
//! let dataset = SyntheticCsiDataset::new(200, syn_cfg);
|
||||
//!
|
||||
//! // Load one sample
|
||||
//! let sample = dataset.get(0).unwrap();
|
||||
//! println!("amplitude shape: {:?}", sample.amplitude.shape());
|
||||
//! ```
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
pub mod config;
|
||||
pub mod dataset;
|
||||
pub mod error;
|
||||
pub mod losses;
|
||||
pub mod metrics;
|
||||
pub mod model;
|
||||
pub mod proof;
|
||||
pub mod subcarrier;
|
||||
pub mod trainer;
|
||||
|
||||
// Convenient re-exports at the crate root.
|
||||
pub use config::TrainingConfig;
|
||||
pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
|
||||
pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError, TrainResult};
|
||||
pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance};
|
||||
|
||||
/// Crate version string.
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
Reference in New Issue
Block a user