feat: Training mode, ADR docs, vitals and wifiscan crates
- Add --train CLI flag with dataset loading, graph transformer training, cosine-scheduled SGD, PCK/OKS validation, and checkpoint saving - Refactor main.rs to import training modules from lib.rs instead of duplicating mod declarations - Add ADR-021 (vital sign detection), ADR-022 (Windows WiFi enhanced fidelity), ADR-023 (trained DensePose pipeline) documentation - Add wifi-densepose-vitals crate: breathing, heartrate, anomaly detection, preprocessor, and temporal store - Add wifi-densepose-wifiscan crate: 8-stage signal intelligence pipeline with netsh/wlanapi adapters, multi-BSSID registry, attention weighting, spatial correlation, and breathing extraction Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
@@ -745,4 +745,94 @@ mod tests {
|
||||
assert!((sum - 1.0).abs() < 1e-5);
|
||||
for &wi in &w3 { assert!(wi.is_finite()); }
|
||||
}
|
||||
|
||||
// ── Weight serialization integration tests ────────────────────────
|
||||
|
||||
#[test]
|
||||
fn linear_flatten_unflatten_roundtrip() {
|
||||
let lin = Linear::with_seed(8, 4, 42);
|
||||
let mut flat = Vec::new();
|
||||
lin.flatten_into(&mut flat);
|
||||
assert_eq!(flat.len(), lin.param_count());
|
||||
let (restored, consumed) = Linear::unflatten_from(&flat, 8, 4);
|
||||
assert_eq!(consumed, flat.len());
|
||||
let inp = vec![1.0f32; 8];
|
||||
assert_eq!(lin.forward(&inp), restored.forward(&inp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cross_attention_flatten_unflatten_roundtrip() {
|
||||
let ca = CrossAttention::new(16, 4);
|
||||
let mut flat = Vec::new();
|
||||
ca.flatten_into(&mut flat);
|
||||
assert_eq!(flat.len(), ca.param_count());
|
||||
let (restored, consumed) = CrossAttention::unflatten_from(&flat, 16, 4);
|
||||
assert_eq!(consumed, flat.len());
|
||||
let q = vec![vec![0.5f32; 16]; 3];
|
||||
let k = vec![vec![0.3f32; 16]; 5];
|
||||
let v = vec![vec![0.7f32; 16]; 5];
|
||||
let orig = ca.forward(&q, &k, &v);
|
||||
let rest = restored.forward(&q, &k, &v);
|
||||
for (a, b) in orig.iter().zip(rest.iter()) {
|
||||
for (x, y) in a.iter().zip(b.iter()) {
|
||||
assert!((x - y).abs() < 1e-6, "mismatch: {x} vs {y}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transformer_weight_roundtrip() {
|
||||
let config = TransformerConfig {
|
||||
n_subcarriers: 16, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1,
|
||||
};
|
||||
let t = CsiToPoseTransformer::new(config.clone());
|
||||
let weights = t.flatten_weights();
|
||||
assert_eq!(weights.len(), t.param_count());
|
||||
|
||||
let mut t2 = CsiToPoseTransformer::new(config);
|
||||
t2.unflatten_weights(&weights).expect("unflatten should succeed");
|
||||
|
||||
// Forward pass should produce identical results
|
||||
let csi = vec![vec![0.5f32; 16]; 4];
|
||||
let out1 = t.forward(&csi);
|
||||
let out2 = t2.forward(&csi);
|
||||
for (a, b) in out1.keypoints.iter().zip(out2.keypoints.iter()) {
|
||||
assert!((a.0 - b.0).abs() < 1e-6);
|
||||
assert!((a.1 - b.1).abs() < 1e-6);
|
||||
assert!((a.2 - b.2).abs() < 1e-6);
|
||||
}
|
||||
for (a, b) in out1.confidences.iter().zip(out2.confidences.iter()) {
|
||||
assert!((a - b).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transformer_param_count_positive() {
|
||||
let t = CsiToPoseTransformer::new(TransformerConfig::default());
|
||||
assert!(t.param_count() > 1000, "expected many params, got {}", t.param_count());
|
||||
let flat = t.flatten_weights();
|
||||
assert_eq!(flat.len(), t.param_count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gnn_stack_flatten_unflatten() {
|
||||
let bg = BodyGraph::new();
|
||||
let gnn = GnnStack::new(8, 8, 2, &bg);
|
||||
let mut flat = Vec::new();
|
||||
gnn.flatten_into(&mut flat);
|
||||
assert_eq!(flat.len(), gnn.param_count());
|
||||
|
||||
let mut gnn2 = GnnStack::new(8, 8, 2, &bg);
|
||||
let consumed = gnn2.unflatten_from(&flat);
|
||||
assert_eq!(consumed, flat.len());
|
||||
|
||||
let feats = vec![vec![1.0f32; 8]; 17];
|
||||
let o1 = gnn.forward(&feats);
|
||||
let o2 = gnn2.forward(&feats);
|
||||
for (a, b) in o1.iter().zip(o2.iter()) {
|
||||
for (x, y) in a.iter().zip(b.iter()) {
|
||||
assert!((x - y).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user