fix: Review fixes for end-to-end training pipeline
- Snapshot best-epoch weights during training and restore before checkpoint/RVF export (prevents exporting overfit final-epoch params) - Add CsiToPoseTransformer::zeros() for fast zero-init when weights will be overwritten, avoiding wasteful Xavier init during gradient estimation (~2*param_count transformer constructions per batch) - Deduplicate synthetic data generation in main.rs training mode Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
@@ -398,6 +398,8 @@ pub struct Trainer {
|
||||
best_val_loss: f32,
|
||||
best_epoch: usize,
|
||||
epochs_without_improvement: usize,
|
||||
/// Snapshot of params at the best validation loss epoch.
|
||||
best_params: Vec<f32>,
|
||||
/// When set, predict_keypoints delegates to the transformer's forward().
|
||||
transformer: Option<CsiToPoseTransformer>,
|
||||
/// Transformer config (needed for unflatten during gradient estimation).
|
||||
@@ -411,10 +413,11 @@ impl Trainer {
|
||||
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
||||
);
|
||||
let params: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect();
|
||||
let best_params = params.clone();
|
||||
Self {
|
||||
config, optimizer, scheduler, params, history: Vec::new(),
|
||||
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
||||
transformer: None, transformer_config: None,
|
||||
best_params, transformer: None, transformer_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,10 +430,11 @@ impl Trainer {
|
||||
config.warmup_epochs, config.lr, config.min_lr, config.epochs,
|
||||
);
|
||||
let tc = transformer.config().clone();
|
||||
let best_params = params.clone();
|
||||
Self {
|
||||
config, optimizer, scheduler, params, history: Vec::new(),
|
||||
best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
|
||||
transformer: Some(transformer), transformer_config: Some(tc),
|
||||
best_params, transformer: Some(transformer), transformer_config: Some(tc),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,12 +527,15 @@ impl Trainer {
|
||||
if val_loss < self.best_val_loss {
|
||||
self.best_val_loss = val_loss;
|
||||
self.best_epoch = stats.epoch;
|
||||
self.best_params = self.params.clone();
|
||||
self.epochs_without_improvement = 0;
|
||||
} else {
|
||||
self.epochs_without_improvement += 1;
|
||||
}
|
||||
if self.should_stop() { break; }
|
||||
}
|
||||
// Restore best-epoch params for checkpoint and downstream use
|
||||
self.params = self.best_params.clone();
|
||||
let best = self.best_metrics().cloned().unwrap_or(EpochStats {
|
||||
epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0,
|
||||
oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(),
|
||||
@@ -625,12 +632,12 @@ impl Trainer {
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// Predict keypoints using the graph transformer. Creates a temporary
|
||||
/// transformer with the given params and runs forward().
|
||||
/// Predict keypoints using the graph transformer. Uses zero-init
|
||||
/// constructor (fast) then overwrites all weights from params.
|
||||
fn predict_keypoints_transformer(
|
||||
params: &[f32], sample: &TrainingSample, tc: &TransformerConfig,
|
||||
) -> Vec<(f32, f32, f32)> {
|
||||
let mut t = CsiToPoseTransformer::new(tc.clone());
|
||||
let mut t = CsiToPoseTransformer::zeros(tc.clone());
|
||||
if t.unflatten_weights(params).is_err() {
|
||||
return Self::predict_keypoints(params, sample);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user