fix: Review fixes for end-to-end training pipeline
- Snapshot best-epoch weights during training and restore before checkpoint/RVF export (prevents exporting overfit final-epoch params)
- Add CsiToPoseTransformer::zeros() for fast zero-init when weights will be overwritten, avoiding wasteful Xavier init during gradient estimation (~2*param_count transformer constructions per batch)
- Deduplicate synthetic data generation in main.rs training mode

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
@@ -452,6 +452,23 @@ impl CsiToPoseTransformer {
|
|||||||
config,
|
config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
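/// Construct with zero-initialized embedding, keypoint-query, and output-head weights (faster than Xavier init); `cross_attn` and `gnn` are still built via their regular `new` constructors.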
|
||||||
|
/// Use with `unflatten_weights()` when you plan to overwrite all weights.
|
||||||
|
pub fn zeros(config: TransformerConfig) -> Self {
|
||||||
|
let d = config.d_model;
|
||||||
|
let bg = BodyGraph::new();
|
||||||
|
let kq = vec![vec![0.0f32; d]; config.n_keypoints];
|
||||||
|
Self {
|
||||||
|
csi_embed: Linear::zeros(config.n_subcarriers, d),
|
||||||
|
keypoint_queries: kq,
|
||||||
|
cross_attn: CrossAttention::new(d, config.n_heads), // small; kept for correct structure
|
||||||
|
gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg),
|
||||||
|
xyz_head: Linear::zeros(d, 3),
|
||||||
|
conf_head: Linear::zeros(d, 1),
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
|
/// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
|
||||||
pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
|
pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
|
||||||
let embedded: Vec<Vec<f32>> = csi_features.iter()
|
let embedded: Vec<Vec<f32>> = csi_features.iter()
|
||||||
|
|||||||
@@ -1551,61 +1551,42 @@ async fn main() {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
// Load samples
|
// Generate synthetic training data (50 samples with deterministic CSI + keypoints)
|
||||||
|
let generate_synthetic = || -> Vec<dataset::TrainingSample> {
|
||||||
|
(0..50).map(|i| {
|
||||||
|
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
|
||||||
|
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
|
||||||
|
}).collect();
|
||||||
|
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
|
||||||
|
for (k, kp) in kps.iter_mut().enumerate() {
|
||||||
|
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
|
||||||
|
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
|
||||||
|
}
|
||||||
|
dataset::TrainingSample {
|
||||||
|
csi_window: csi,
|
||||||
|
pose_label: dataset::PoseLabel {
|
||||||
|
keypoints: kps,
|
||||||
|
body_parts: Vec::new(),
|
||||||
|
confidence: 1.0,
|
||||||
|
},
|
||||||
|
source: "synthetic",
|
||||||
|
}
|
||||||
|
}).collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load samples (fall back to synthetic if dataset missing/empty)
|
||||||
let samples = match pipeline.load() {
|
let samples = match pipeline.load() {
|
||||||
Ok(s) if !s.is_empty() => {
|
Ok(s) if !s.is_empty() => {
|
||||||
eprintln!("Loaded {} samples from {}", s.len(), ds_path.display());
|
eprintln!("Loaded {} samples from {}", s.len(), ds_path.display());
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
eprintln!("No samples found at {}. Generating synthetic training data...", ds_path.display());
|
eprintln!("No samples found at {}. Using synthetic data.", ds_path.display());
|
||||||
// Generate synthetic samples for testing the pipeline
|
generate_synthetic()
|
||||||
let mut synth = Vec::new();
|
|
||||||
for i in 0..50 {
|
|
||||||
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
|
|
||||||
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
|
|
||||||
}).collect();
|
|
||||||
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
|
|
||||||
for (k, kp) in kps.iter_mut().enumerate() {
|
|
||||||
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
|
|
||||||
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
|
|
||||||
}
|
|
||||||
synth.push(dataset::TrainingSample {
|
|
||||||
csi_window: csi,
|
|
||||||
pose_label: dataset::PoseLabel {
|
|
||||||
keypoints: kps,
|
|
||||||
body_parts: Vec::new(),
|
|
||||||
confidence: 1.0,
|
|
||||||
},
|
|
||||||
source: "synthetic",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
synth
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Failed to load dataset: {e}");
|
eprintln!("Failed to load dataset: {e}. Using synthetic data.");
|
||||||
eprintln!("Generating synthetic training data...");
|
generate_synthetic()
|
||||||
let mut synth = Vec::new();
|
|
||||||
for i in 0..50 {
|
|
||||||
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
|
|
||||||
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
|
|
||||||
}).collect();
|
|
||||||
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
|
|
||||||
for (k, kp) in kps.iter_mut().enumerate() {
|
|
||||||
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
|
|
||||||
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
|
|
||||||
}
|
|
||||||
synth.push(dataset::TrainingSample {
|
|
||||||
csi_window: csi,
|
|
||||||
pose_label: dataset::PoseLabel {
|
|
||||||
keypoints: kps,
|
|
||||||
body_parts: Vec::new(),
|
|
||||||
confidence: 1.0,
|
|
||||||
},
|
|
||||||
source: "synthetic",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
synth
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -398,6 +398,8 @@ pub struct Trainer {
|
|||||||
best_val_loss: f32,
|
best_val_loss: f32,
|
||||||
best_epoch: usize,
|
best_epoch: usize,
|
||||||
epochs_without_improvement: usize,
|
epochs_without_improvement: usize,
|
||||||
|
/// Snapshot of params at the best validation loss epoch.
|
||||||
|
best_params: Vec<f32>,
|
||||||
/// When set, predict_keypoints delegates to the transformer's forward().
|
/// When set, predict_keypoints delegates to the transformer's forward().
|
||||||
transformer: Option<CsiToPoseTransformer>,
|
transformer: Option<CsiToPoseTransformer>,
|
||||||
/// Transformer config (needed for unflatten during gradient estimation).
|
/// Transformer config (needed for unflatten during gradient estimation).
|
||||||
@@ -411,10 +413,11 @@ impl Trainer {
|
|||||||
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
||||||
);
|
);
|
||||||
let params: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect();
|
let params: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect();
|
||||||
|
let best_params = params.clone();
|
||||||
Self {
|
Self {
|
||||||
config, optimizer, scheduler, params, history: Vec::new(),
|
config, optimizer, scheduler, params, history: Vec::new(),
|
||||||
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
||||||
transformer: None, transformer_config: None,
|
best_params, transformer: None, transformer_config: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -427,10 +430,11 @@ impl Trainer {
|
|||||||
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
||||||
);
|
);
|
||||||
let tc = transformer.config().clone();
|
let tc = transformer.config().clone();
|
||||||
|
let best_params = params.clone();
|
||||||
Self {
|
Self {
|
||||||
config, optimizer, scheduler, params, history: Vec::new(),
|
config, optimizer, scheduler, params, history: Vec::new(),
|
||||||
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
||||||
transformer: Some(transformer), transformer_config: Some(tc),
|
best_params, transformer: Some(transformer), transformer_config: Some(tc),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -523,12 +527,15 @@ impl Trainer {
|
|||||||
if val_loss < self.best_val_loss {
|
if val_loss < self.best_val_loss {
|
||||||
self.best_val_loss = val_loss;
|
self.best_val_loss = val_loss;
|
||||||
self.best_epoch = stats.epoch;
|
self.best_epoch = stats.epoch;
|
||||||
|
self.best_params = self.params.clone();
|
||||||
self.epochs_without_improvement = 0;
|
self.epochs_without_improvement = 0;
|
||||||
} else {
|
} else {
|
||||||
self.epochs_without_improvement += 1;
|
self.epochs_without_improvement += 1;
|
||||||
}
|
}
|
||||||
if self.should_stop() { break; }
|
if self.should_stop() { break; }
|
||||||
}
|
}
|
||||||
|
// Restore best-epoch params for checkpoint and downstream use
|
||||||
|
self.params = self.best_params.clone();
|
||||||
let best = self.best_metrics().cloned().unwrap_or(EpochStats {
|
let best = self.best_metrics().cloned().unwrap_or(EpochStats {
|
||||||
epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0,
|
epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0,
|
||||||
oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(),
|
oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(),
|
||||||
@@ -625,12 +632,12 @@ impl Trainer {
|
|||||||
}).collect()
|
}).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict keypoints using the graph transformer. Creates a temporary
|
/// Predict keypoints using the graph transformer. Uses zero-init
|
||||||
/// transformer with the given params and runs forward().
|
/// constructor (fast) then overwrites all weights from params.
|
||||||
fn predict_keypoints_transformer(
|
fn predict_keypoints_transformer(
|
||||||
params: &[f32], sample: &TrainingSample, tc: &TransformerConfig,
|
params: &[f32], sample: &TrainingSample, tc: &TransformerConfig,
|
||||||
) -> Vec<(f32, f32, f32)> {
|
) -> Vec<(f32, f32, f32)> {
|
||||||
let mut t = CsiToPoseTransformer::new(tc.clone());
|
let mut t = CsiToPoseTransformer::zeros(tc.clone());
|
||||||
if t.unflatten_weights(params).is_err() {
|
if t.unflatten_weights(params).is_err() {
|
||||||
return Self::predict_keypoints(params, sample);
|
return Self::predict_keypoints(params, sample);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user