feat(rust): Add workspace deps, tests, and refine training modules

- Cargo.toml: Add wifi-densepose-train to workspace members; add
  petgraph, ndarray-npy, walkdir, sha2, csv, indicatif, clap to
  workspace dependencies
- error.rs: Slim down to focused error types (TrainError, DatasetError)
- lib.rs: Wire up all module re-exports correctly
- losses.rs: Add generate_gaussian_heatmaps implementation
- tests/test_config.rs: Deterministic config roundtrip and validation tests

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
Claude
2026-02-28 15:17:17 +00:00
parent ec98e40fff
commit 2c5ca308a4
5 changed files with 643 additions and 290 deletions

View File

@@ -906,4 +906,148 @@ mod tests {
"DensePose loss with identical UV should be bounded by CE, got {val}"
);
}
// ── Standalone functional API tests ──────────────────────────────────────
#[test]
fn test_fn_keypoint_heatmap_loss_identical_zero() {
let dev = device();
let t = Tensor::ones([2, 17, 8, 8], (Kind::Float, dev));
let loss = keypoint_heatmap_loss(&t, &t);
let v = loss.double_value(&[]) as f32;
assert!(v.abs() < 1e-6, "Identical heatmaps → loss must be ≈0, got {v}");
}
#[test]
fn test_fn_generate_gaussian_heatmaps_shape() {
let dev = device();
let kpts = Tensor::full(&[2i64, 17, 2], 0.5, (Kind::Float, dev));
let vis = Tensor::ones(&[2i64, 17], (Kind::Float, dev));
let hm = generate_gaussian_heatmaps(&kpts, &vis, 16, 2.0);
assert_eq!(hm.size(), [2, 17, 16, 16]);
}
#[test]
fn test_fn_generate_gaussian_heatmaps_invisible_zero() {
let dev = device();
let kpts = Tensor::full(&[1i64, 17, 2], 0.5, (Kind::Float, dev));
let vis = Tensor::zeros(&[1i64, 17], (Kind::Float, dev)); // all invisible
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 2.0);
let total: f64 = hm.sum(Kind::Float).double_value(&[]);
assert_eq!(total, 0.0, "All-invisible heatmaps must be zero");
}
#[test]
fn test_fn_generate_gaussian_heatmaps_peak_near_one() {
let dev = device();
// Keypoint at (0.5, 0.5) on an 8×8 map.
let kpts = Tensor::full(&[1i64, 1, 2], 0.5, (Kind::Float, dev));
let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev));
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5);
let max_val: f64 = hm.max().double_value(&[]);
assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9");
}
#[test]
fn test_fn_densepose_part_loss_returns_finite() {
let dev = device();
let logits = Tensor::zeros(&[1i64, 25, 4, 4], (Kind::Float, dev));
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
let loss = densepose_part_loss(&logits, &labels);
let v = loss.double_value(&[]);
assert!(v.is_finite() && v >= 0.0);
}
#[test]
fn test_fn_densepose_uv_loss_no_annotated_pixels_zero() {
let dev = device();
let pred = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
let gt = Tensor::zeros(&[1i64, 48, 4, 4], (Kind::Float, dev));
let labels = Tensor::full(&[1i64, 4, 4], -1i64, (Kind::Int64, dev));
let loss = densepose_uv_loss(&pred, &gt, &labels);
let v = loss.double_value(&[]);
assert_eq!(v, 0.0, "No annotated pixels → UV loss must be 0");
}
#[test]
fn test_fn_densepose_uv_loss_identical_zero() {
let dev = device();
let t = Tensor::ones(&[1i64, 48, 4, 4], (Kind::Float, dev));
let labels = Tensor::zeros(&[1i64, 4, 4], (Kind::Int64, dev));
let loss = densepose_uv_loss(&t, &t, &labels);
let v = loss.double_value(&[]);
assert!(v.abs() < 1e-6, "Identical UV → loss ≈ 0, got {v}");
}
#[test]
fn test_fn_transfer_loss_identical_zero() {
let dev = device();
let t = Tensor::ones(&[2i64, 64, 8, 8], (Kind::Float, dev));
let loss = fn_transfer_loss(&t, &t);
let v = loss.double_value(&[]);
assert!(v.abs() < 1e-6, "Identical features → transfer loss ≈ 0, got {v}");
}
#[test]
fn test_fn_transfer_loss_spatial_mismatch() {
let dev = device();
let student = Tensor::ones(&[1i64, 64, 16, 16], (Kind::Float, dev));
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
let loss = fn_transfer_loss(&student, &teacher);
let v = loss.double_value(&[]);
assert!(v.is_finite() && v >= 0.0, "Spatial-mismatch transfer loss must be finite");
}
#[test]
fn test_fn_transfer_loss_channel_mismatch_divisible() {
let dev = device();
let student = Tensor::ones(&[1i64, 128, 8, 8], (Kind::Float, dev));
let teacher = Tensor::ones(&[1i64, 64, 8, 8], (Kind::Float, dev));
let loss = fn_transfer_loss(&student, &teacher);
let v = loss.double_value(&[]);
assert!(v.is_finite() && v >= 0.0);
}
#[test]
fn test_compute_losses_keypoint_only() {
let dev = device();
let pred = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
let gt = Tensor::ones(&[1i64, 17, 8, 8], (Kind::Float, dev));
let out = compute_losses(&pred, &gt, None, None, None, None, None, None,
1.0, 1.0, 1.0);
assert!(out.total.is_finite());
assert!(out.keypoint >= 0.0);
assert!(out.densepose_parts.is_none());
assert!(out.densepose_uv.is_none());
assert!(out.transfer.is_none());
}
#[test]
fn test_compute_losses_all_components_finite() {
let dev = device();
let b = 1i64;
let h = 4i64;
let w = 4i64;
let pred_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
let gt_kpt = Tensor::ones(&[b, 17, h, w], (Kind::Float, dev));
let logits = Tensor::zeros(&[b, 25, h, w], (Kind::Float, dev));
let labels = Tensor::zeros(&[b, h, w], (Kind::Int64, dev));
let pred_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
let gt_uv = Tensor::ones(&[b, 48, h, w], (Kind::Float, dev));
let sf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
let tf = Tensor::ones(&[b, 64, 2, 2], (Kind::Float, dev));
let out = compute_losses(
&pred_kpt, &gt_kpt,
Some(&logits), Some(&labels),
Some(&pred_uv), Some(&gt_uv),
Some(&sf), Some(&tf),
1.0, 0.5, 0.1,
);
assert!(out.total.is_finite() && out.total >= 0.0);
assert!(out.densepose_parts.is_some());
assert!(out.densepose_uv.is_some());
assert!(out.transfer.is_some());
}
}