diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs index 38622f7..f4483ff 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs @@ -452,6 +452,23 @@ impl CsiToPoseTransformer { config, } } + /// Construct with zero-initialized weights (faster than Xavier init). + /// Use with `unflatten_weights()` when you plan to overwrite all weights. + pub fn zeros(config: TransformerConfig) -> Self { + let d = config.d_model; + let bg = BodyGraph::new(); + let kq = vec![vec![0.0f32; d]; config.n_keypoints]; + Self { + csi_embed: Linear::zeros(config.n_subcarriers, d), + keypoint_queries: kq, + cross_attn: CrossAttention::new(d, config.n_heads), // small; kept for correct structure + gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg), + xyz_head: Linear::zeros(d, 3), + conf_head: Linear::zeros(d, 1), + config, + } + } + /// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints. pub fn forward(&self, csi_features: &[Vec]) -> PoseOutput { let embedded: Vec> = csi_features.iter() diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs index b0078a0..36e40bb 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs @@ -1551,61 +1551,42 @@ async fn main() { ..Default::default() }); - // Load samples + // Generate synthetic training data (50 samples with deterministic CSI + keypoints) + let generate_synthetic = || -> Vec { + (0..50).map(|i| { + let csi: Vec> = (0..4).map(|a| { + (0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect() + }).collect(); + let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17]; + for (k, kp) in kps.iter_mut().enumerate() { + kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0; + kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0; + } + dataset::TrainingSample { + csi_window: csi, + pose_label: dataset::PoseLabel { + keypoints: kps, + body_parts: Vec::new(), + confidence: 1.0, + }, + source: "synthetic", + } + }).collect() + }; + + // Load samples (fall back to synthetic if dataset missing/empty) let samples = match pipeline.load() { Ok(s) if !s.is_empty() => { eprintln!("Loaded {} samples from {}", s.len(), ds_path.display()); s } Ok(_) => { - eprintln!("No samples found at {}. Generating synthetic training data...", ds_path.display()); - // Generate synthetic samples for testing the pipeline - let mut synth = Vec::new(); - for i in 0..50 { - let csi: Vec> = (0..4).map(|a| { - (0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect() - }).collect(); - let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17]; - for (k, kp) in kps.iter_mut().enumerate() { - kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0; - kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0; - } - synth.push(dataset::TrainingSample { - csi_window: csi, - pose_label: dataset::PoseLabel { - keypoints: kps, - body_parts: Vec::new(), - confidence: 1.0, - }, - source: "synthetic", - }); - } - synth + eprintln!("No samples found at {}. Using synthetic data.", ds_path.display()); + generate_synthetic() } Err(e) => { - eprintln!("Failed to load dataset: {e}"); - eprintln!("Generating synthetic training data..."); - let mut synth = Vec::new(); - for i in 0..50 { - let csi: Vec> = (0..4).map(|a| { - (0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect() - }).collect(); - let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17]; - for (k, kp) in kps.iter_mut().enumerate() { - kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0; - kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0; - } - synth.push(dataset::TrainingSample { - csi_window: csi, - pose_label: dataset::PoseLabel { - keypoints: kps, - body_parts: Vec::new(), - confidence: 1.0, - }, - source: "synthetic", - }); - } - synth + eprintln!("Failed to load dataset: {e}. Using synthetic data."); + generate_synthetic() } }; diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs index 2953274..e06b777 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs @@ -398,6 +398,8 @@ pub struct Trainer { best_val_loss: f32, best_epoch: usize, epochs_without_improvement: usize, + /// Snapshot of params at the best validation loss epoch. + best_params: Vec, /// When set, predict_keypoints delegates to the transformer's forward(). transformer: Option, /// Transformer config (needed for unflatten during gradient estimation). @@ -411,10 +413,11 @@ impl Trainer { config.warmup_epochs, config.lr, config.min_lr, config.epochs, ); let params: Vec = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect(); + let best_params = params.clone(); Self { config, optimizer, scheduler, params, history: Vec::new(), best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0, - transformer: None, transformer_config: None, + best_params, transformer: None, transformer_config: None, } } @@ -427,10 +430,11 @@ impl Trainer { config.warmup_epochs, config.lr, config.min_lr, config.epochs, ); let tc = transformer.config().clone(); + let best_params = params.clone(); Self { config, optimizer, scheduler, params, history: Vec::new(), best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0, - transformer: Some(transformer), transformer_config: Some(tc), + best_params, transformer: Some(transformer), transformer_config: Some(tc), } } @@ -523,12 +527,15 @@ impl Trainer { if val_loss < self.best_val_loss { self.best_val_loss = val_loss; self.best_epoch = stats.epoch; + self.best_params = self.params.clone(); self.epochs_without_improvement = 0; } else { self.epochs_without_improvement += 1; } if self.should_stop() { break; } } + // Restore best-epoch params for checkpoint and downstream use + self.params = self.best_params.clone(); let best = self.best_metrics().cloned().unwrap_or(EpochStats { epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0, oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(), @@ -625,12 +632,12 @@ impl Trainer { }).collect() } - /// Predict keypoints using the graph transformer. Creates a temporary - /// transformer with the given params and runs forward(). + /// Predict keypoints using the graph transformer. Uses zero-init + /// constructor (fast) then overwrites all weights from params. fn predict_keypoints_transformer( params: &[f32], sample: &TrainingSample, tc: &TransformerConfig, ) -> Vec<(f32, f32, f32)> { - let mut t = CsiToPoseTransformer::new(tc.clone()); + let mut t = CsiToPoseTransformer::zeros(tc.clone()); if t.unflatten_weights(params).is_err() { return Self::predict_keypoints(params, sample); }