fix: Review fixes for end-to-end training pipeline

- Snapshot best-epoch weights during training and restore before
  checkpoint/RVF export (prevents exporting overfit final-epoch params)
- Add CsiToPoseTransformer::zeros() for fast zero-init when weights
  will be overwritten, avoiding wasteful Xavier init during gradient
  estimation (~2*param_count transformer constructions per batch)
- Deduplicate synthetic data generation in main.rs training mode

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv
2026-02-28 23:58:20 -05:00
parent 4cabffa726
commit 45f0304d52
3 changed files with 57 additions and 52 deletions

View File

@@ -452,6 +452,23 @@ impl CsiToPoseTransformer {
config,
}
}
/// Construct with zero-initialized weights (faster than Xavier init).
/// Use with `unflatten_weights()` when you plan to overwrite all weights.
pub fn zeros(config: TransformerConfig) -> Self {
    let dim = config.d_model;
    // One zeroed d_model-sized query row per keypoint.
    let queries: Vec<Vec<f32>> = (0..config.n_keypoints)
        .map(|_| vec![0.0f32; dim])
        .collect();
    let graph = BodyGraph::new();
    Self {
        csi_embed: Linear::zeros(config.n_subcarriers, dim),
        keypoint_queries: queries,
        // Attention/GNN keep their regular constructors: their parameter
        // count is small and the values are overwritten on weight load.
        cross_attn: CrossAttention::new(dim, config.n_heads),
        gnn: GnnStack::new(dim, dim, config.n_gnn_layers, &graph),
        xyz_head: Linear::zeros(dim, 3),
        conf_head: Linear::zeros(dim, 1),
        config,
    }
}
/// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
let embedded: Vec<Vec<f32>> = csi_features.iter()