feat(rust): Add workspace deps, tests, and refine training modules
- Cargo.toml: Add wifi-densepose-train to workspace members; add petgraph, ndarray-npy, walkdir, sha2, csv, indicatif, clap to workspace dependencies - error.rs: Slim down to focused error types (TrainError, DatasetError) - lib.rs: Wire up all module re-exports correctly - losses.rs: Add generate_gaussian_heatmaps implementation - tests/test_config.rs: Deterministic config roundtrip and validation tests https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -906,4 +906,148 @@ mod tests {
|
||||
"DensePose loss with identical UV should be bounded by CE, got {val}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Standalone functional API tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_fn_keypoint_heatmap_loss_identical_zero() {
|
||||
let dev = device();
|
||||
let t = Tensor::ones([2, 17, 8, 8], (Kind::Float, dev));
|
||||
let loss = keypoint_heatmap_loss(&t, &t);
|
||||
let v = loss.double_value(&[]) as f32;
|
||||
assert!(v.abs() < 1e-6, "Identical heatmaps → loss must be ≈0, got {v}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_generate_gaussian_heatmaps_shape() {
|
||||
let dev = device();
|
||||
let kpts = Tensor::full(&[2i64, 17, 2], 0.5, (Kind::Float, dev));
|
||||
let vis = Tensor::ones(&[2i64, 17], (Kind::Float, dev));
|
||||
let hm = generate_gaussian_heatmaps(&kpts, &vis, 16, 2.0);
|
||||
assert_eq!(hm.size(), [2, 17, 16, 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_generate_gaussian_heatmaps_invisible_zero() {
|
||||
let dev = device();
|
||||
let kpts = Tensor::full(&[1i64, 17, 2], 0.5, (Kind::Float, dev));
|
||||
let vis = Tensor::zeros(&[1i64, 17], (Kind::Float, dev)); // all invisible
|
||||
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 2.0);
|
||||
let total: f64 = hm.sum(Kind::Float).double_value(&[]);
|
||||
assert_eq!(total, 0.0, "All-invisible heatmaps must be zero");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_generate_gaussian_heatmaps_peak_near_one() {
|
||||
let dev = device();
|
||||
// Keypoint at (0.5, 0.5) on an 8×8 map.
|
||||
let kpts = Tensor::full(&[1i64, 1, 2], 0.5, (Kind::Float, dev));
|
||||
let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev));
|
||||
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5);
|
||||
let max_val: f64 = hm.max().double_value(&[]);
|
||||
assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_densepose_part_loss_returns_finite() {
|
||||
let dev = device();
|
||||
let logits = Tensor::zeros(&[1i64, 25, 4, 4], (Kind::Float, dev));
|
||||
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
|
||||
let loss = densepose_part_loss(&logits, &labels);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.is_finite() && v >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_densepose_uv_loss_no_annotated_pixels_zero() {
|
||||
let dev = device();
|
||||
let pred = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
|
||||
let gt = Tensor::zeros(&[1i64, 48, 4, 4], (Kind::Float, dev));
|
||||
let labels = Tensor::full(&[1i64, 4, 4], -1i64, (Kind::Int64, dev));
|
||||
let loss = densepose_uv_loss(&pred, >, &labels);
|
||||
let v = loss.double_value(&[]);
|
||||
assert_eq!(v, 0.0, "No annotated pixels → UV loss must be 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_densepose_uv_loss_identical_zero() {
|
||||
let dev = device();
|
||||
let t = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
|
||||
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
|
||||
let loss = densepose_uv_loss(&t, &t, &labels);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.abs() < 1e-6, "Identical UV → loss ≈ 0, got {v}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_transfer_loss_identical_zero() {
|
||||
let dev = device();
|
||||
let t = Tensor::ones(&[2i64, 64, 8, 8], (Kind::Float, dev));
|
||||
let loss = fn_transfer_loss(&t, &t);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.abs() < 1e-6, "Identical features → transfer loss ≈ 0, got {v}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_transfer_loss_spatial_mismatch() {
|
||||
let dev = device();
|
||||
let student = Tensor::ones(&[1i64, 64, 16, 16], (Kind::Float, dev));
|
||||
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
|
||||
let loss = fn_transfer_loss(&student, &teacher);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.is_finite() && v >= 0.0, "Spatial-mismatch transfer loss must be finite");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fn_transfer_loss_channel_mismatch_divisible() {
|
||||
let dev = device();
|
||||
let student = Tensor::ones(&[1i64, 128, 8, 8], (Kind::Float, dev));
|
||||
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
|
||||
let loss = fn_transfer_loss(&student, &teacher);
|
||||
let v = loss.double_value(&[]);
|
||||
assert!(v.is_finite() && v >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_losses_keypoint_only() {
|
||||
let dev = device();
|
||||
let pred = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
|
||||
let gt = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
|
||||
let out = compute_losses(&pred, >, None, None, None, None, None, None,
|
||||
1.0, 1.0, 1.0);
|
||||
assert!(out.total.is_finite());
|
||||
assert!(out.keypoint >= 0.0);
|
||||
assert!(out.densepose_parts.is_none());
|
||||
assert!(out.densepose_uv.is_none());
|
||||
assert!(out.transfer.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_losses_all_components_finite() {
|
||||
let dev = device();
|
||||
let b = 1i64;
|
||||
let h = 4i64;
|
||||
let w = 4i64;
|
||||
let pred_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
|
||||
let gt_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
|
||||
let logits = Tensor::zeros(&[b, 25, h, w], (Kind::Float, dev));
|
||||
let labels = Tensor::zeros(&[b, h, w], (Kind::Int64, dev));
|
||||
let pred_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
|
||||
let gt_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
|
||||
let sf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
|
||||
let tf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
|
||||
|
||||
let out = compute_losses(
|
||||
&pred_kpt, >_kpt,
|
||||
Some(&logits), Some(&labels),
|
||||
Some(&pred_uv), Some(>_uv),
|
||||
Some(&sf), Some(&tf),
|
||||
1.0, 0.5, 0.1,
|
||||
);
|
||||
|
||||
assert!(out.total.is_finite() && out.total >= 0.0);
|
||||
assert!(out.densepose_parts.is_some());
|
||||
assert!(out.densepose_uv.is_some());
|
||||
assert!(out.transfer.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user