fix: Review fixes for end-to-end training pipeline
- Snapshot best-epoch weights during training and restore before checkpoint/RVF export (prevents exporting overfit final-epoch params) - Add CsiToPoseTransformer::zeros() for fast zero-init when weights will be overwritten, avoiding wasteful Xavier init during gradient estimation (~2*param_count transformer constructions per batch) - Deduplicate synthetic data generation in main.rs training mode Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
@@ -452,6 +452,23 @@ impl CsiToPoseTransformer {
|
||||
config,
|
||||
}
|
||||
}
|
||||
/// Construct with zero-initialized weights (faster than Xavier init).
|
||||
/// Use with `unflatten_weights()` when you plan to overwrite all weights.
|
||||
pub fn zeros(config: TransformerConfig) -> Self {
|
||||
let d = config.d_model;
|
||||
let bg = BodyGraph::new();
|
||||
let kq = vec![vec![0.0f32; d]; config.n_keypoints];
|
||||
Self {
|
||||
csi_embed: Linear::zeros(config.n_subcarriers, d),
|
||||
keypoint_queries: kq,
|
||||
cross_attn: CrossAttention::new(d, config.n_heads), // small; kept for correct structure
|
||||
gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg),
|
||||
xyz_head: Linear::zeros(d, 3),
|
||||
conf_head: Linear::zeros(d, 1),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
|
||||
pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
|
||||
let embedded: Vec<Vec<f32>> = csi_features.iter()
|
||||
|
||||
Reference in New Issue
Block a user