fix: Review fixes for end-to-end training pipeline
- Snapshot best-epoch weights during training and restore before checkpoint/RVF export (prevents exporting overfit final-epoch params) - Add CsiToPoseTransformer::zeros() for fast zero-init when weights will be overwritten, avoiding wasteful Xavier init during gradient estimation (~2*param_count transformer constructions per batch) - Deduplicate synthetic data generation in main.rs training mode Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
@@ -1551,61 +1551,42 @@ async fn main() {
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Load samples
|
||||
// Generate synthetic training data (50 samples with deterministic CSI + keypoints)
|
||||
let generate_synthetic = || -> Vec<dataset::TrainingSample> {
|
||||
(0..50).map(|i| {
|
||||
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
|
||||
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
|
||||
}).collect();
|
||||
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
|
||||
for (k, kp) in kps.iter_mut().enumerate() {
|
||||
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
|
||||
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
|
||||
}
|
||||
dataset::TrainingSample {
|
||||
csi_window: csi,
|
||||
pose_label: dataset::PoseLabel {
|
||||
keypoints: kps,
|
||||
body_parts: Vec::new(),
|
||||
confidence: 1.0,
|
||||
},
|
||||
source: "synthetic",
|
||||
}
|
||||
}).collect()
|
||||
};
|
||||
|
||||
// Load samples (fall back to synthetic if dataset missing/empty)
|
||||
let samples = match pipeline.load() {
|
||||
Ok(s) if !s.is_empty() => {
|
||||
eprintln!("Loaded {} samples from {}", s.len(), ds_path.display());
|
||||
s
|
||||
}
|
||||
Ok(_) => {
|
||||
eprintln!("No samples found at {}. Generating synthetic training data...", ds_path.display());
|
||||
// Generate synthetic samples for testing the pipeline
|
||||
let mut synth = Vec::new();
|
||||
for i in 0..50 {
|
||||
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
|
||||
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
|
||||
}).collect();
|
||||
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
|
||||
for (k, kp) in kps.iter_mut().enumerate() {
|
||||
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
|
||||
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
|
||||
}
|
||||
synth.push(dataset::TrainingSample {
|
||||
csi_window: csi,
|
||||
pose_label: dataset::PoseLabel {
|
||||
keypoints: kps,
|
||||
body_parts: Vec::new(),
|
||||
confidence: 1.0,
|
||||
},
|
||||
source: "synthetic",
|
||||
});
|
||||
}
|
||||
synth
|
||||
eprintln!("No samples found at {}. Using synthetic data.", ds_path.display());
|
||||
generate_synthetic()
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load dataset: {e}");
|
||||
eprintln!("Generating synthetic training data...");
|
||||
let mut synth = Vec::new();
|
||||
for i in 0..50 {
|
||||
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
|
||||
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
|
||||
}).collect();
|
||||
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
|
||||
for (k, kp) in kps.iter_mut().enumerate() {
|
||||
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
|
||||
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
|
||||
}
|
||||
synth.push(dataset::TrainingSample {
|
||||
csi_window: csi,
|
||||
pose_label: dataset::PoseLabel {
|
||||
keypoints: kps,
|
||||
body_parts: Vec::new(),
|
||||
confidence: 1.0,
|
||||
},
|
||||
source: "synthetic",
|
||||
});
|
||||
}
|
||||
synth
|
||||
eprintln!("Failed to load dataset: {e}. Using synthetic data.");
|
||||
generate_synthetic()
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user