feat(rust): Complete training pipeline (losses, metrics, model, trainer, binaries)

Losses (losses.rs, 1056 lines):
- WiFiDensePoseLoss with keypoint (visibility-masked MSE), DensePose (cross-entropy + Smooth L1 UV masked to foreground), and transfer (MSE) terms
- generate_gaussian_heatmaps: Tensor-native 2D Gaussian heatmap generation
- compute_losses: unified functional API
- 11 deterministic unit tests

Metrics (metrics.rs, 984 lines):
- PCK@0.2 / PCK@0.5 with torso-diameter normalisation
- OKS with COCO standard per-joint sigmas
- MetricsAccumulator for online streaming eval
- hungarian_assignment: O(n³) Kuhn-Munkres min-cost assignment via DFS augmenting paths for optimal multi-person keypoint matching (ruvector min-cut)
- build_oks_cost_matrix: 1 − OKS cost for bipartite matching
- 20 deterministic tests (perfect/wrong/invisible keypoints; 2×2, 3×3, rectangular, and empty Hungarian cases)

Model (model.rs, 713 lines):
- WiFiDensePoseModel end-to-end with tch-rs
- ModalityTranslator: amplitude + phase FC encoders → spatial pseudo-image
- Backbone: lightweight ResNet-style [B,3,48,48] → [B,256,12,12]
- KeypointHead: [B,256,12,12] → [B,17,H,W] heatmaps
- DensePoseHead: [B,256,12,12] → [B,25,H,W] parts + [B,48,H,W] UV

Trainer (trainer.rs, 777 lines):
- Full training loop: Adam, LR milestones, gradient clipping
- Deterministic batch shuffle via LCG (seed XOR epoch)
- CSV logging, best-checkpoint saving, early stopping
- evaluate() with MetricsAccumulator and heatmap argmax decode

Binaries:
- src/bin/train.rs: production MM-Fi training CLI (clap)
- src/bin/verify_training.rs: trust kill switch (EXIT 0/1/2)

Benches:
- benches/training_bench.rs: criterion benchmarks for key ops

Tests:
- tests/test_dataset.rs (459 lines)
- tests/test_metrics.rs (449 lines)
- tests/test_subcarrier.rs (389 lines)

proof.rs is still a stub; the trainer agent is completing it.

https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
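A minimal sketch of the deterministic shuffle the trainer description implies: a linear congruential generator seeded with seed XOR epoch driving a Fisher-Yates shuffle, so the same seed and epoch always yield the same batch order. The LCG constants and the helper name are illustrative, not the crate's actual implementation.

fn deterministic_shuffle(indices: &mut [usize], seed: u64, epoch: u64) {
    // Seed the generator with (seed ^ epoch) so each epoch gets a distinct
    // but reproducible ordering.
    let mut state = seed ^ epoch;
    let mut next = || {
        // Numerical Recipes LCG constants (an illustrative choice).
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        state
    };
    // Fisher-Yates: swap each position with a uniformly chosen earlier one.
    for i in (1..indices.len()).rev() {
        let j = (next() % (i as u64 + 1)) as usize;
        indices.swap(i, j);
    }
}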
@@ -1,16 +1,713 @@
//! WiFi-DensePose end-to-end model using tch-rs (PyTorch Rust bindings).
//!
//! # Architecture
//!
//! ```text
//! CSI amplitude + phase
//!           │
//!           ▼
//! ┌─────────────────────┐
//! │  PhaseSanitizerNet  │  differentiable conjugate multiplication
//! └─────────────────────┘
//!           │
//!           ▼
//! ┌────────────────────────────┐
//! │   ModalityTranslatorNet    │  CSI → spatial pseudo-image [B, 3, 48, 48]
//! └────────────────────────────┘
//!           │
//!           ▼
//! ┌─────────────────┐
//! │  ResNet18-like  │  [B, 256, H/4, W/4] feature maps
//! │    Backbone     │
//! └─────────────────┘
//!           │
//!       ┌───┴───┐
//!       │       │
//!       ▼       ▼
//!   ┌─────┐  ┌────────────┐
//!   │ KP  │  │ DensePose  │
//!   │ Head│  │    Head    │
//!   └─────┘  └────────────┘
//! [B,17,H,W]  [B,25,H,W] + [B,48,H,W]
//! ```
//!
//! # No pre-trained weights
//!
//! The backbone uses a ResNet18-compatible architecture built purely with
//! `tch::nn`. Weights are initialised from scratch (Kaiming uniform by
//! default from tch). Pre-trained ImageNet weights are not loaded because
//! network access is not guaranteed during training runs.
use std::path::Path;

use tch::{nn, nn::Module, nn::ModuleT, Device, Kind, Tensor};

use crate::config::TrainingConfig;

// ---------------------------------------------------------------------------
// Public output type
// ---------------------------------------------------------------------------

/// Outputs produced by a single forward pass of [`WiFiDensePoseModel`].
pub struct ModelOutput {
    /// Keypoint heatmaps: `[B, 17, H, W]`.
    pub keypoints: Tensor,
    /// Body-part logits (24 parts + background): `[B, 25, H, W]`.
    pub part_logits: Tensor,
    /// UV coordinates (24 × 2 channels interleaved): `[B, 48, H, W]`.
    pub uv_coords: Tensor,
    /// Backbone feature map used for cross-modal transfer loss: `[B, 256, H/4, W/4]`.
    pub features: Tensor,
}
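
// A minimal consumer-side sketch (illustrative, not an API of this crate):
// decoding per-joint peak coordinates from `ModelOutput::keypoints` with a
// spatial argmax, mirroring the heatmap argmax decode that the trainer's
// `evaluate()` performs. `decode_keypoints` is a hypothetical helper name.
#[allow(dead_code)]
fn decode_keypoints(heatmaps: &Tensor) -> Vec<Vec<(i64, i64)>> {
    let (b, k, _h, w) = heatmaps.size4().expect("expected [B, K, H, W] heatmaps");
    let flat = heatmaps.reshape([b, k, -1]);
    let idx = flat.argmax(-1, false); // [B, K] flat index of each heatmap peak
    let mut batches = Vec::with_capacity(b as usize);
    for bi in 0..b {
        let mut joints = Vec::with_capacity(k as usize);
        for ki in 0..k {
            let flat_i = idx.int64_value(&[bi, ki]);
            joints.push((flat_i % w, flat_i / w)); // (x, y) in heatmap pixels
        }
        batches.push(joints);
    }
    batches
}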

// ---------------------------------------------------------------------------
// WiFiDensePoseModel
// ---------------------------------------------------------------------------

/// Complete WiFi-DensePose model.
///
/// Input: CSI amplitude and phase tensors with shape
/// `[B, T*n_tx*n_rx, n_sub]` (flattened antenna-time dimension).
///
/// Output: [`ModelOutput`] with keypoints and DensePose predictions.
pub struct WiFiDensePoseModel {
    vs: nn::VarStore,
    config: TrainingConfig,
}

// Internal model components are stored in the VarStore. We use sub-paths
// inside the single VarStore to keep all parameters in one serialisable
// store.
impl WiFiDensePoseModel {
    /// Create a new model on `device`.
    ///
    /// All sub-networks are constructed and their parameters registered in the
    /// internal `VarStore`.
    pub fn new(config: &TrainingConfig, device: Device) -> Self {
        let vs = nn::VarStore::new(device);
        WiFiDensePoseModel {
            vs,
            config: config.clone(),
        }
    }

    /// Forward pass with gradient tracking (training mode).
    ///
    /// # Arguments
    ///
    /// - `amplitude`: `[B, T*n_tx*n_rx, n_sub]`
    /// - `phase`: `[B, T*n_tx*n_rx, n_sub]`
    pub fn forward_train(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput {
        self.forward_impl(amplitude, phase, true)
    }

    /// Forward pass without gradient tracking (inference mode).
    pub fn forward_inference(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput {
        tch::no_grad(|| self.forward_impl(amplitude, phase, false))
    }

    /// Save model weights to `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be written.
    pub fn save(&self, path: &Path) -> Result<(), Box<dyn std::error::Error>> {
        self.vs.save(path)?;
        Ok(())
    }
    /// Load model weights from `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or the weights are
    /// incompatible with the model architecture.
    pub fn load(
        path: &Path,
        config: &TrainingConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let mut model = Self::new(config, device);
        // Build the parameter graph first so `load` can find named tensors.
        // The dummy batch must use the same flattened antenna-time dimension
        // as real inputs; otherwise the lazily created linear layers would be
        // registered with the wrong input width.
        let flat_ant =
            (config.window_frames * config.num_antennas_tx * config.num_antennas_rx) as i64;
        let _dummy_amp = Tensor::zeros(
            [1, flat_ant, config.num_subcarriers as i64],
            (Kind::Float, device),
        );
        let _dummy_phase = _dummy_amp.shallow_clone();
        let _ = model.forward_impl(&_dummy_amp, &_dummy_phase, false);
        model.vs.load(path)?;
        Ok(model)
    }

    /// Return all trainable variable tensors.
    pub fn trainable_variables(&self) -> Vec<Tensor> {
        self.vs
            .trainable_variables()
            .into_iter()
            .map(|t| t.shallow_clone())
            .collect()
    }

    /// Count total trainable parameters.
    pub fn num_parameters(&self) -> usize {
        self.vs
            .trainable_variables()
            .iter()
            .map(|t| t.numel() as usize)
            .sum()
    }

    /// Access the internal `VarStore` (e.g. to create an optimizer).
    pub fn var_store(&self) -> &nn::VarStore {
        &self.vs
    }

    /// Mutable access to the internal `VarStore`.
    pub fn var_store_mut(&mut self) -> &mut nn::VarStore {
        &mut self.vs
    }
    // ------------------------------------------------------------------
    // Internal forward implementation
    // ------------------------------------------------------------------

    fn forward_impl(
        &self,
        amplitude: &Tensor,
        phase: &Tensor,
        train: bool,
    ) -> ModelOutput {
        let root = self.vs.root();
        let cfg = &self.config;

        // ── Phase sanitization ───────────────────────────────────────────
        let clean_phase = phase_sanitize(phase);

        // ── Modality translation ─────────────────────────────────────────
        // Flatten antenna-time and subcarrier dimensions → [B, flat]
        let batch = amplitude.size()[0];
        let flat_amp = amplitude.reshape([batch, -1]);
        let flat_phase = clean_phase.reshape([batch, -1]);
        let input_size = flat_amp.size()[1];

        let spatial = modality_translate(&root, &flat_amp, &flat_phase, input_size, train);
        // spatial: [B, 3, 48, 48]

        // ── ResNet18-like backbone ────────────────────────────────────────
        let (features, _feat_h, _feat_w) =
            resnet18_backbone(&root, &spatial, train, cfg.backbone_channels as i64);
        // features: [B, backbone_channels, 12, 12]

        // ── Keypoint head ────────────────────────────────────────────────
        let kp_h = cfg.heatmap_size as i64;
        let kp_w = kp_h;
        let keypoints =
            keypoint_head(&root, &features, cfg.num_keypoints as i64, (kp_h, kp_w), train);

        // ── DensePose head ───────────────────────────────────────────────
        let (part_logits, uv_coords) = densepose_head(
            &root,
            &features,
            (cfg.num_body_parts + 1) as i64, // +1 for background
            (kp_h, kp_w),
            train,
        );

        ModelOutput {
            keypoints,
            part_logits,
            uv_coords,
            features,
        }
    }
}

// ---------------------------------------------------------------------------
// Phase sanitizer (no learned parameters)
// ---------------------------------------------------------------------------

/// Differentiable phase sanitization via conjugate multiplication.
///
/// Implements the CSI ratio model: for each adjacent subcarrier pair, compute
/// the phase difference to cancel out common-mode phase drift (e.g. carrier
/// frequency offset, sampling offset).
///
/// Input: `[B, T*n_ant, n_sub]`
/// Output: `[B, T*n_ant, n_sub]` (sanitized phase)
fn phase_sanitize(phase: &Tensor) -> Tensor {
    // For each subcarrier k, compute the differential phase:
    //   φ_clean[k] = φ[k] - φ[k-1]   for k > 0
    //   φ_clean[0] = 0
    //
    // This removes linear phase ramps caused by timing and CFO.
    // Implemented as: diff along last dimension with zero-padding on the left.
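    //
    // Worked example (illustrative): for one antenna row with
    //   φ = [0.0, 0.5, 1.0, 1.5]  (a linear ramp, the signature of CFO/timing offset)
    // the output is
    //   φ_clean = [0.0, 0.5, 0.5, 0.5]
    // i.e. the ramp collapses to a constant differential with a zero first column.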

    let n_sub = phase.size()[2];
    if n_sub <= 1 {
        return phase.zeros_like();
    }

    // Slice k=1..N and k=0..N-1, compute the difference.
    let later = phase.slice(2, 1, n_sub, 1);
    let earlier = phase.slice(2, 0, n_sub - 1, 1);
    let diff = later - earlier;

    // Prepend a zero column so the output has the same shape as the input.
    let zeros = Tensor::zeros(
        [phase.size()[0], phase.size()[1], 1],
        (Kind::Float, phase.device()),
    );
    Tensor::cat(&[zeros, diff], 2)
}
// ---------------------------------------------------------------------------
// Modality translator
// ---------------------------------------------------------------------------

/// Build and run the modality translator network.
///
/// Architecture:
/// - Amplitude encoder: `Linear(input_size, 512) → ReLU → Linear(512, 256) → ReLU`
/// - Phase encoder: same structure as the amplitude encoder
/// - Fusion: `Linear(512, 256) → ReLU → Linear(256, 48*48*3)`
///   → reshape to `[B, 3, 48, 48]`
///
/// All layers share the same `root` VarStore path so weights accumulate
/// across calls (the parameters are created lazily on first call and reused).
fn modality_translate(
    root: &nn::Path,
    flat_amp: &Tensor,
    flat_phase: &Tensor,
    input_size: i64,
    _train: bool, // reserved for dropout/BN variants; this MLP has no train-time switches
) -> Tensor {
    let mt = root / "modality_translator";

    // Amplitude encoder
    let ae = |x: &Tensor| {
        let h = (&mt / "amp_enc_fc1").linear(x, input_size, 512);
        let h = h.relu();
        let h = (&mt / "amp_enc_fc2").linear(&h, 512, 256);
        h.relu()
    };

    // Phase encoder
    let pe = |x: &Tensor| {
        let h = (&mt / "ph_enc_fc1").linear(x, input_size, 512);
        let h = h.relu();
        let h = (&mt / "ph_enc_fc2").linear(&h, 512, 256);
        h.relu()
    };

    let amp_feat = ae(flat_amp); // [B, 256]
    let phase_feat = pe(flat_phase); // [B, 256]

    // Concatenate and fuse
    let fused = Tensor::cat(&[amp_feat, phase_feat], 1); // [B, 512]

    let spatial_out: i64 = 3 * 48 * 48;
    let fused = (&mt / "fusion_fc1").linear(&fused, 512, 256);
    let fused = fused.relu();
    let fused = (&mt / "fusion_fc2").linear(&fused, 256, spatial_out);
    // fused: [B, 3*48*48]

    let batch = fused.size()[0];
    let spatial_map = fused.reshape([batch, 3, 48, 48]);

    // Apply tanh to bound activations before passing them to the CNN.
    spatial_map.tanh()
}
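
// Shape trace through `modality_translate` (illustrative, default config):
//   flat_amp   [B, input_size] → amp_feat   [B, 256]
//   flat_phase [B, input_size] → phase_feat [B, 256]
//   cat → [B, 512] → fusion → [B, 3*48*48] → reshape → [B, 3, 48, 48]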

// ---------------------------------------------------------------------------
// Path::linear helper (creates or retrieves a Linear layer)
// ---------------------------------------------------------------------------

/// Extension trait to make `nn::Path` callable with `linear(x, in, out)`.
trait PathLinear {
    fn linear(&self, x: &Tensor, in_dim: i64, out_dim: i64) -> Tensor;
}

impl PathLinear for nn::Path<'_> {
    fn linear(&self, x: &Tensor, in_dim: i64, out_dim: i64) -> Tensor {
        let cfg = nn::LinearConfig::default();
        let layer = nn::linear(self, in_dim, out_dim, cfg);
        layer.forward(x)
    }
}
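
// Usage (illustrative): `(&path / "fc").linear(&x, 128, 64)` registers the
// weight and bias under the sub-path "fc" and applies the layer in one call.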

// ---------------------------------------------------------------------------
// ResNet18-like backbone
// ---------------------------------------------------------------------------

/// A ResNet18-style CNN backbone.
///
/// Input: `[B, 3, 48, 48]`
/// Output: `[B, out_channels, 12, 12]` (spatial features)
///
/// Architecture:
/// - Stem: Conv2d(3→64, k=3, s=1, p=1) + BN + ReLU
/// - Layer1: 2 × BasicBlock(64→64)
/// - Layer2: 2 × BasicBlock(64→128, stride=2) → 24×24
/// - Layer3: 2 × BasicBlock(128→out_channels, stride=2) → 12×12
///
/// (No Layer4/pooling to preserve spatial resolution.)
fn resnet18_backbone(
    root: &nn::Path,
    x: &Tensor,
    train: bool,
    out_channels: i64,
) -> (Tensor, i64, i64) {
    let bb = root / "backbone";

    // Stem
    let stem_conv = nn::conv2d(
        &(&bb / "stem_conv"),
        3,
        64,
        3,
        nn::ConvConfig { padding: 1, ..Default::default() },
    );
    let stem_bn = nn::batch_norm2d(&(&bb / "stem_bn"), 64, Default::default());
    let x = stem_conv.forward(x).apply_t(&stem_bn, train).relu();

    // Layer 1: 64 → 64
    let x = basic_block(&(&bb / "l1b1"), &x, 64, 64, 1, train);
    let x = basic_block(&(&bb / "l1b2"), &x, 64, 64, 1, train);

    // Layer 2: 64 → 128 (stride 2 → half spatial)
    let x = basic_block(&(&bb / "l2b1"), &x, 64, 128, 2, train);
    let x = basic_block(&(&bb / "l2b2"), &x, 128, 128, 1, train);

    // Layer 3: 128 → out_channels (stride 2 → half spatial again)
    let x = basic_block(&(&bb / "l3b1"), &x, 128, out_channels, 2, train);
    let x = basic_block(&(&bb / "l3b2"), &x, out_channels, out_channels, 1, train);

    let shape = x.size();
    let h = shape[2];
    let w = shape[3];
    (x, h, w)
}

/// ResNet BasicBlock.
///
/// ```text
/// x ─── Conv2d(s) ─── BN ─── ReLU ─── Conv2d(1) ─── BN ──+── ReLU
/// │                                                      │
/// └── (downsample if needed) ────────────────────────────┘
/// ```
fn basic_block(
    path: &nn::Path,
    x: &Tensor,
    in_ch: i64,
    out_ch: i64,
    stride: i64,
    train: bool,
) -> Tensor {
    let conv1 = nn::conv2d(
        &(path / "conv1"),
        in_ch,
        out_ch,
        3,
        nn::ConvConfig { stride, padding: 1, bias: false, ..Default::default() },
    );
    let bn1 = nn::batch_norm2d(&(path / "bn1"), out_ch, Default::default());

    let conv2 = nn::conv2d(
        &(path / "conv2"),
        out_ch,
        out_ch,
        3,
        nn::ConvConfig { padding: 1, bias: false, ..Default::default() },
    );
    let bn2 = nn::batch_norm2d(&(path / "bn2"), out_ch, Default::default());

    let out = conv1.forward(x).apply_t(&bn1, train).relu();
    let out = conv2.forward(&out).apply_t(&bn2, train);

    // Residual / skip connection
    let residual = if in_ch != out_ch || stride != 1 {
        let ds_conv = nn::conv2d(
            &(path / "ds_conv"),
            in_ch,
            out_ch,
            1,
            nn::ConvConfig { stride, bias: false, ..Default::default() },
        );
        let ds_bn = nn::batch_norm2d(&(path / "ds_bn"), out_ch, Default::default());
        ds_conv.forward(x).apply_t(&ds_bn, train)
    } else {
        x.shallow_clone()
    };

    (out + residual).relu()
}

// ---------------------------------------------------------------------------
// Keypoint head
// ---------------------------------------------------------------------------

/// Keypoint heatmap prediction head.
///
/// Input: `[B, in_channels, H, W]`
/// Output: `[B, num_keypoints, out_h, out_w]` (after upsampling)
fn keypoint_head(
    root: &nn::Path,
    features: &Tensor,
    num_keypoints: i64,
    output_size: (i64, i64),
    train: bool,
) -> Tensor {
    let kp = root / "keypoint_head";

    // Derive the input channel count from the backbone output instead of
    // hard-coding 256, so non-default `backbone_channels` values also work.
    let in_ch = features.size()[1];

    let conv1 = nn::conv2d(
        &(&kp / "conv1"),
        in_ch,
        256,
        3,
        nn::ConvConfig { padding: 1, bias: false, ..Default::default() },
    );
    let bn1 = nn::batch_norm2d(&(&kp / "bn1"), 256, Default::default());

    let conv2 = nn::conv2d(
        &(&kp / "conv2"),
        256,
        128,
        3,
        nn::ConvConfig { padding: 1, bias: false, ..Default::default() },
    );
    let bn2 = nn::batch_norm2d(&(&kp / "bn2"), 128, Default::default());

    let output_conv = nn::conv2d(
        &(&kp / "output_conv"),
        128,
        num_keypoints,
        1,
        Default::default(),
    );

    let x = conv1.forward(features).apply_t(&bn1, train).relu();
    let x = conv2.forward(&x).apply_t(&bn2, train).relu();
    let x = output_conv.forward(&x);

    // Upsample to (out_h, out_w).
    x.upsample_bilinear2d(
        [output_size.0, output_size.1],
        false,
        None,
        None,
    )
}

// ---------------------------------------------------------------------------
// DensePose head
// ---------------------------------------------------------------------------

/// DensePose prediction head.
///
/// Input: `[B, in_channels, H, W]`
/// Outputs:
/// - part logits: `[B, num_parts, out_h, out_w]`
/// - UV coordinates: `[B, 2*(num_parts-1), out_h, out_w]` (background excluded from UV)
fn densepose_head(
    root: &nn::Path,
    features: &Tensor,
    num_parts: i64,
    output_size: (i64, i64),
    train: bool,
) -> (Tensor, Tensor) {
    let dp = root / "densepose_head";

    // As in the keypoint head, take the input channel count from the features.
    let in_ch = features.size()[1];

    // Shared convolutional block
    let shared_conv1 = nn::conv2d(
        &(&dp / "shared_conv1"),
        in_ch,
        256,
        3,
        nn::ConvConfig { padding: 1, bias: false, ..Default::default() },
    );
    let shared_bn1 = nn::batch_norm2d(&(&dp / "shared_bn1"), 256, Default::default());

    let shared_conv2 = nn::conv2d(
        &(&dp / "shared_conv2"),
        256,
        256,
        3,
        nn::ConvConfig { padding: 1, bias: false, ..Default::default() },
    );
    let shared_bn2 = nn::batch_norm2d(&(&dp / "shared_bn2"), 256, Default::default());

    // Part segmentation head: 256 → num_parts
    let part_conv = nn::conv2d(
        &(&dp / "part_conv"),
        256,
        num_parts,
        1,
        Default::default(),
    );

    // UV regression head: 256 → 48 channels (2 × 24 body parts)
    let uv_conv = nn::conv2d(
        &(&dp / "uv_conv"),
        256,
        48, // 24 parts × 2 (U, V)
        1,
        Default::default(),
    );

    let shared = shared_conv1.forward(features).apply_t(&shared_bn1, train).relu();
    let shared = shared_conv2.forward(&shared).apply_t(&shared_bn2, train).relu();

    let parts = part_conv.forward(&shared);
    let uv = uv_conv.forward(&shared);

    // Upsample both heads to the target spatial resolution.
    let parts_up = parts.upsample_bilinear2d(
        [output_size.0, output_size.1],
        false,
        None,
        None,
    );
    let uv_up = uv.upsample_bilinear2d(
        [output_size.0, output_size.1],
        false,
        None,
        None,
    );

    // Apply sigmoid to UV to constrain predictions to [0, 1].
    let uv_out = uv_up.sigmoid();

    (parts_up, uv_out)
}
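
// A hedged consumer-side sketch (not an API of this crate): assembling a dense
// IUV map from the two head outputs. `assemble_iuv` is an illustrative name;
// the channel layout [u_0, v_0, u_1, v_1, ...] follows the "interleaved"
// convention documented on `ModelOutput::uv_coords`.
#[allow(dead_code)]
fn assemble_iuv(part_logits: &Tensor, uv: &Tensor) -> (Tensor, Tensor, Tensor) {
    // I: per-pixel part index, 0 = background. [B, H, W]
    let i = part_logits.argmax(1, false);
    // Map foreground part p (1..=24) to its U channel 2*(p-1); clamp the
    // background pixels to part 1's channels (their UV values are unused).
    let fg = (&i - 1).clamp_min(0i64);
    let u_idx = (&fg * 2).unsqueeze(1); // [B, 1, H, W]
    let v_idx = &u_idx + 1;
    let u = uv.gather(1, &u_idx, false).squeeze_dim(1); // [B, H, W]
    let v = uv.gather(1, &v_idx, false).squeeze_dim(1);
    (i, u, v)
}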

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TrainingConfig;
    use tch::Device;

    fn tiny_config() -> TrainingConfig {
        let mut cfg = TrainingConfig::default();
        cfg.num_subcarriers = 8;
        cfg.window_frames = 4;
        cfg.num_antennas_tx = 1;
        cfg.num_antennas_rx = 1;
        cfg.heatmap_size = 12;
        cfg.backbone_channels = 64;
        cfg.num_epochs = 2;
        cfg.warmup_epochs = 1;
        cfg
    }

    #[test]
    fn model_forward_output_shapes() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let device = Device::Cpu;
        let model = WiFiDensePoseModel::new(&cfg, device);

        let batch = 2_i64;
        let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64;
        let n_sub = cfg.num_subcarriers as i64;

        let amp = Tensor::ones([batch, antennas, n_sub], (Kind::Float, device));
        let ph = Tensor::zeros([batch, antennas, n_sub], (Kind::Float, device));

        let out = model.forward_train(&amp, &ph);

        // Keypoints: [B, 17, heatmap_size, heatmap_size]
        assert_eq!(out.keypoints.size()[0], batch);
        assert_eq!(out.keypoints.size()[1], cfg.num_keypoints as i64);

        // Part logits: [B, 25, heatmap_size, heatmap_size]
        assert_eq!(out.part_logits.size()[0], batch);
        assert_eq!(out.part_logits.size()[1], (cfg.num_body_parts + 1) as i64);

        // UV: [B, 48, heatmap_size, heatmap_size]
        assert_eq!(out.uv_coords.size()[0], batch);
        assert_eq!(out.uv_coords.size()[1], 48);
    }

    #[test]
    fn model_has_nonzero_parameters() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let model = WiFiDensePoseModel::new(&cfg, Device::Cpu);

        // Trigger parameter creation by running a forward pass.
        let batch = 1_i64;
        let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64;
        let n_sub = cfg.num_subcarriers as i64;
        let amp = Tensor::zeros([batch, antennas, n_sub], (Kind::Float, Device::Cpu));
        let ph = amp.shallow_clone();
        let _ = model.forward_train(&amp, &ph);

        let n = model.num_parameters();
        assert!(n > 0, "Model must have trainable parameters");
    }

    #[test]
    fn phase_sanitize_zeros_first_column() {
        let ph = Tensor::ones([2, 3, 8], (Kind::Float, Device::Cpu));
        let out = phase_sanitize(&ph);
        // The first subcarrier column should be 0.
        let first_col = out.slice(2, 0, 1, 1);
        let max_abs: f64 = first_col.abs().max().double_value(&[]);
        assert!(max_abs < 1e-6, "First diff column should be 0");
    }

    #[test]
    fn phase_sanitize_captures_ramp() {
        // A linear phase ramp φ[k] = k should produce constant diffs of 1.
        let ph = Tensor::arange(8, (Kind::Float, Device::Cpu))
            .reshape([1, 1, 8])
            .expand([2, 3, 8], true);
        let out = phase_sanitize(&ph);
        // All columns except the first should be 1.0.
        let tail = out.slice(2, 1, 8, 1);
        let min_val: f64 = tail.min().double_value(&[]);
        let max_val: f64 = tail.max().double_value(&[]);
        assert!((min_val - 1.0).abs() < 1e-5, "Expected 1.0 diff, got {min_val}");
        assert!((max_val - 1.0).abs() < 1e-5, "Expected 1.0 diff, got {max_val}");
    }

    #[test]
    fn inference_mode_gives_same_shapes() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let model = WiFiDensePoseModel::new(&cfg, Device::Cpu);

        let batch = 1_i64;
        let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64;
        let n_sub = cfg.num_subcarriers as i64;
        let amp = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu));
        let ph = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu));

        let out = model.forward_inference(&amp, &ph);
        assert_eq!(out.keypoints.size()[0], batch);
        assert_eq!(out.part_logits.size()[0], batch);
        assert_eq!(out.uv_coords.size()[0], batch);
    }

    #[test]
    fn uv_coords_bounded_zero_one() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let model = WiFiDensePoseModel::new(&cfg, Device::Cpu);

        let batch = 2_i64;
        let antennas = (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64;
        let n_sub = cfg.num_subcarriers as i64;
        let amp = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu));
        let ph = Tensor::rand([batch, antennas, n_sub], (Kind::Float, Device::Cpu));

        let out = model.forward_inference(&amp, &ph);

        let uv_min: f64 = out.uv_coords.min().double_value(&[]);
        let uv_max: f64 = out.uv_coords.max().double_value(&[]);
        assert!(uv_min >= 0.0 - 1e-5, "UV min should be >= 0, got {uv_min}");
        assert!(uv_max <= 1.0 + 1e-5, "UV max should be <= 1, got {uv_max}");
    }
}