feat(train): Add ruvector integration — ADR-016, deps, DynamicPersonMatcher
- docs/adr/ADR-016: Full ruvector integration ADR with verified API details from source inspection (github.com/ruvnet/ruvector). Covers mincut, attn-mincut, temporal-tensor, solver, and attention at v2.0.4. - Cargo.toml: Add ruvector-mincut, ruvector-attn-mincut, ruvector-temporal- tensor, ruvector-solver, ruvector-attention = "2.0.4" to workspace deps and wifi-densepose-train crate deps. - metrics.rs: Add DynamicPersonMatcher wrapping ruvector_mincut::DynamicMinCut for subpolynomial O(n^1.5 log n) multi-frame person tracking; adds assignment_mincut() public entry point. - proof.rs, trainer.rs, model.rs, dataset.rs, subcarrier.rs: Agent improvements to full implementations (loss decrease verification, SHA-256 hash, LCG shuffle, ResNet18 backbone, MmFiDataset, linear interp). - tests: test_config, test_dataset, test_metrics, test_proof, training_bench all added/updated. 100+ tests pass with no-default-features. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
//! Benchmarks for the WiFi-DensePose training pipeline.
|
||||
//!
|
||||
//! All benchmark inputs are constructed from fixed, deterministic data — no
|
||||
//! `rand` crate or OS entropy is used. This ensures that benchmark numbers are
|
||||
//! reproducible and that the benchmark harness itself cannot introduce
|
||||
//! non-determinism.
|
||||
//!
|
||||
//! Run with:
|
||||
//!
|
||||
//! ```bash
|
||||
//! cargo bench -p wifi-densepose-train
|
||||
//! ```
|
||||
@@ -15,95 +21,52 @@ use wifi_densepose_train::{
|
||||
subcarrier::{compute_interp_weights, interpolate_subcarriers},
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dataset benchmarks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Benchmark synthetic sample generation for a single index.
|
||||
fn bench_synthetic_get(c: &mut Criterion) {
|
||||
let syn_cfg = SyntheticConfig::default();
|
||||
let dataset = SyntheticCsiDataset::new(1000, syn_cfg);
|
||||
|
||||
c.bench_function("synthetic_dataset_get", |b| {
|
||||
b.iter(|| {
|
||||
let _ = dataset.get(black_box(42)).expect("sample 42 must exist");
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark full epoch iteration (no I/O — all in-process).
|
||||
fn bench_synthetic_epoch(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("synthetic_epoch");
|
||||
|
||||
for n_samples in [64usize, 256, 1024] {
|
||||
let syn_cfg = SyntheticConfig::default();
|
||||
let dataset = SyntheticCsiDataset::new(n_samples, syn_cfg);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("samples", n_samples),
|
||||
&n_samples,
|
||||
|b, &n| {
|
||||
b.iter(|| {
|
||||
for i in 0..n {
|
||||
let _ = dataset.get(black_box(i)).expect("sample exists");
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Subcarrier interpolation benchmarks
|
||||
// ---------------------------------------------------------------------------
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Benchmark `interpolate_subcarriers` for the standard 114 → 56 use-case.
|
||||
fn bench_interp_114_to_56(c: &mut Criterion) {
|
||||
// Simulate a single sample worth of raw CSI from MM-Fi.
|
||||
/// Benchmark `interpolate_subcarriers` 114 → 56 for a batch of 32 windows.
|
||||
///
|
||||
/// Represents the per-batch preprocessing step during a real training epoch.
|
||||
fn bench_interp_114_to_56_batch32(c: &mut Criterion) {
|
||||
let cfg = TrainingConfig::default();
|
||||
let arr: Array4<f32> = Array4::from_shape_fn(
|
||||
(cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, 114),
|
||||
let batch_size = 32_usize;
|
||||
|
||||
// Deterministic data: linear ramp across all axes.
|
||||
let arr = Array4::<f32>::from_shape_fn(
|
||||
(
|
||||
cfg.window_frames,
|
||||
cfg.num_antennas_tx * batch_size, // stack batch along tx dimension
|
||||
cfg.num_antennas_rx,
|
||||
114,
|
||||
),
|
||||
|(t, tx, rx, k)| (t + tx + rx + k) as f32 * 0.001,
|
||||
);
|
||||
|
||||
c.bench_function("interp_114_to_56", |b| {
|
||||
c.bench_function("interp_114_to_56_batch32", |b| {
|
||||
b.iter(|| {
|
||||
let _ = interpolate_subcarriers(black_box(&arr), black_box(56));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark `compute_interp_weights` to ensure it is fast enough to
|
||||
/// precompute at dataset construction time.
|
||||
fn bench_compute_interp_weights(c: &mut Criterion) {
|
||||
c.bench_function("compute_interp_weights_114_56", |b| {
|
||||
b.iter(|| {
|
||||
let _ = compute_interp_weights(black_box(114), black_box(56));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark interpolation for varying source subcarrier counts.
|
||||
/// Benchmark `interpolate_subcarriers` for varying source subcarrier counts.
|
||||
fn bench_interp_scaling(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("interp_scaling");
|
||||
let cfg = TrainingConfig::default();
|
||||
|
||||
for src_sc in [56usize, 114, 256, 512] {
|
||||
let arr: Array4<f32> = Array4::zeros((
|
||||
cfg.window_frames,
|
||||
cfg.num_antennas_tx,
|
||||
cfg.num_antennas_rx,
|
||||
src_sc,
|
||||
));
|
||||
for src_sc in [56_usize, 114, 256, 512] {
|
||||
let arr = Array4::<f32>::from_shape_fn(
|
||||
(cfg.window_frames, cfg.num_antennas_tx, cfg.num_antennas_rx, src_sc),
|
||||
|(t, tx, rx, k)| (t + tx + rx + k) as f32 * 0.001,
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("src_sc", src_sc),
|
||||
&src_sc,
|
||||
|b, &sc| {
|
||||
if sc == 56 {
|
||||
// Identity case — skip; interpolate_subcarriers clones.
|
||||
// Identity case: the function just clones the array.
|
||||
b.iter(|| {
|
||||
let _ = arr.clone();
|
||||
});
|
||||
@@ -119,11 +82,59 @@ fn bench_interp_scaling(c: &mut Criterion) {
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config benchmarks
|
||||
// ---------------------------------------------------------------------------
|
||||
/// Benchmark interpolation weight precomputation (called once at dataset
|
||||
/// construction time).
|
||||
fn bench_compute_interp_weights(c: &mut Criterion) {
|
||||
c.bench_function("compute_interp_weights_114_56", |b| {
|
||||
b.iter(|| {
|
||||
let _ = compute_interp_weights(black_box(114), black_box(56));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark TrainingConfig::validate() to ensure it stays O(1).
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// SyntheticCsiDataset benchmarks
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Benchmark a single `get()` call on the synthetic dataset.
|
||||
fn bench_synthetic_get(c: &mut Criterion) {
|
||||
let dataset = SyntheticCsiDataset::new(1000, SyntheticConfig::default());
|
||||
|
||||
c.bench_function("synthetic_dataset_get", |b| {
|
||||
b.iter(|| {
|
||||
let _ = dataset.get(black_box(42)).expect("sample 42 must exist");
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Benchmark sequential full-epoch iteration at varying dataset sizes.
|
||||
fn bench_synthetic_epoch(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("synthetic_epoch");
|
||||
|
||||
for n_samples in [64_usize, 256, 1024] {
|
||||
let dataset = SyntheticCsiDataset::new(n_samples, SyntheticConfig::default());
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("samples", n_samples),
|
||||
&n_samples,
|
||||
|b, &n| {
|
||||
b.iter(|| {
|
||||
for i in 0..n {
|
||||
let _ = dataset.get(black_box(i)).expect("sample must exist");
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Config benchmarks
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Benchmark `TrainingConfig::validate()` to ensure it stays O(1).
|
||||
fn bench_config_validate(c: &mut Criterion) {
|
||||
let config = TrainingConfig::default();
|
||||
c.bench_function("config_validate", |b| {
|
||||
@@ -133,17 +144,86 @@ fn bench_config_validate(c: &mut Criterion) {
|
||||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Criterion main
|
||||
// ---------------------------------------------------------------------------
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// PCK computation benchmark (pure Rust, no tch dependency)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Inline PCK@threshold computation for a single (pred, gt) sample.
|
||||
#[inline(always)]
|
||||
fn compute_pck(pred: &[[f32; 2]], gt: &[[f32; 2]], threshold: f32) -> f32 {
|
||||
let n = pred.len();
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let correct = pred
|
||||
.iter()
|
||||
.zip(gt.iter())
|
||||
.filter(|(p, g)| {
|
||||
let dx = p[0] - g[0];
|
||||
let dy = p[1] - g[1];
|
||||
(dx * dx + dy * dy).sqrt() <= threshold
|
||||
})
|
||||
.count();
|
||||
correct as f32 / n as f32
|
||||
}
|
||||
|
||||
/// Benchmark PCK computation over 100 deterministic samples.
|
||||
fn bench_pck_100_samples(c: &mut Criterion) {
|
||||
let num_samples = 100_usize;
|
||||
let num_joints = 17_usize;
|
||||
let threshold = 0.05_f32;
|
||||
|
||||
// Build deterministic fixed pred/gt pairs using sines for variety.
|
||||
let samples: Vec<(Vec<[f32; 2]>, Vec<[f32; 2]>)> = (0..num_samples)
|
||||
.map(|i| {
|
||||
let pred: Vec<[f32; 2]> = (0..num_joints)
|
||||
.map(|j| {
|
||||
[
|
||||
((i as f32 * 0.03 + j as f32 * 0.05).sin() * 0.5 + 0.5).clamp(0.0, 1.0),
|
||||
(j as f32 * 0.04 + 0.2_f32).clamp(0.0, 1.0),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
let gt: Vec<[f32; 2]> = (0..num_joints)
|
||||
.map(|j| {
|
||||
[
|
||||
((i as f32 * 0.03 + j as f32 * 0.05 + 0.01).sin() * 0.5 + 0.5)
|
||||
.clamp(0.0, 1.0),
|
||||
(j as f32 * 0.04 + 0.2_f32).clamp(0.0, 1.0),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
(pred, gt)
|
||||
})
|
||||
.collect();
|
||||
|
||||
c.bench_function("pck_100_samples", |b| {
|
||||
b.iter(|| {
|
||||
let total: f32 = samples
|
||||
.iter()
|
||||
.map(|(p, g)| compute_pck(black_box(p), black_box(g), threshold))
|
||||
.sum();
|
||||
let _ = total / num_samples as f32;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Criterion registration
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
// Subcarrier interpolation
|
||||
bench_interp_114_to_56_batch32,
|
||||
bench_interp_scaling,
|
||||
bench_compute_interp_weights,
|
||||
// Dataset
|
||||
bench_synthetic_get,
|
||||
bench_synthetic_epoch,
|
||||
bench_interp_114_to_56,
|
||||
bench_compute_interp_weights,
|
||||
bench_interp_scaling,
|
||||
// Config
|
||||
bench_config_validate,
|
||||
// Metrics (pure Rust, no tch)
|
||||
bench_pck_100_samples,
|
||||
);
|
||||
criterion_main!(benches);
|
||||
|
||||
Reference in New Issue
Block a user