- docs/adr/ADR-016: Full ruvector integration ADR with verified API details from source inspection (github.com/ruvnet/ruvector). Covers mincut, attn-mincut, temporal-tensor, solver, and attention at v2.0.4. - Cargo.toml: Add ruvector-mincut, ruvector-attn-mincut, ruvector-temporal- tensor, ruvector-solver, ruvector-attention = "2.0.4" to workspace deps and wifi-densepose-train crate deps. - metrics.rs: Add DynamicPersonMatcher wrapping ruvector_mincut::DynamicMinCut for subpolynomial O(n^1.5 log n) multi-frame person tracking; adds assignment_mincut() public entry point. - proof.rs, trainer.rs, model.rs, dataset.rs, subcarrier.rs: Agent improvements to full implementations (loss decrease verification, SHA-256 hash, LCG shuffle, ResNet18 backbone, MmFiDataset, linear interp). - tests: test_config, test_dataset, test_metrics, test_proof, training_bench all added/updated. 100+ tests pass with no-default-features. https://claude.ai/code/session_01BSBAQJ34SLkiJy4A8SoiL4
458 lines
15 KiB
Rust
458 lines
15 KiB
Rust
//! Integration tests for [`wifi_densepose_train::config`].
//!
//! All tests are deterministic: they use only fixed values and the
//! `TrainingConfig::default()` constructor. No OS entropy or `rand` crate
//! is used.
use wifi_densepose_train::config::TrainingConfig;
|
||
|
||
// ---------------------------------------------------------------------------
// Default config invariants
// ---------------------------------------------------------------------------

/// The default configuration must pass its own validation.
#[test]
fn default_config_is_valid() {
    let config: TrainingConfig = Default::default();
    config.validate().expect("default TrainingConfig must be valid");
}
||
/// Every numeric field in the default config must be strictly positive where
/// the domain requires it.
#[test]
fn default_config_all_positive_fields() {
    let config = TrainingConfig::default();

    // Geometry / model-shape fields.
    assert!(config.num_subcarriers > 0, "num_subcarriers must be > 0");
    assert!(config.native_subcarriers > 0, "native_subcarriers must be > 0");
    assert!(config.num_antennas_tx > 0, "num_antennas_tx must be > 0");
    assert!(config.num_antennas_rx > 0, "num_antennas_rx must be > 0");
    assert!(config.window_frames > 0, "window_frames must be > 0");
    assert!(config.heatmap_size > 0, "heatmap_size must be > 0");
    assert!(config.num_keypoints > 0, "num_keypoints must be > 0");
    assert!(config.num_body_parts > 0, "num_body_parts must be > 0");
    assert!(config.backbone_channels > 0, "backbone_channels must be > 0");

    // Optimisation hyper-parameters (weight_decay may legitimately be zero).
    assert!(config.batch_size > 0, "batch_size must be > 0");
    assert!(config.learning_rate > 0.0, "learning_rate must be > 0.0");
    assert!(config.weight_decay >= 0.0, "weight_decay must be >= 0.0");
    assert!(config.num_epochs > 0, "num_epochs must be > 0");
    assert!(config.grad_clip_norm > 0.0, "grad_clip_norm must be > 0.0");
}
||
/// The three loss weights in the default config must all be non-negative and
/// their sum must be positive (not all zero).
#[test]
fn default_config_loss_weights_sum_positive() {
    let config = TrainingConfig::default();

    // Table-driven non-negativity check for each individual weight.
    for (weight, msg) in [
        (config.lambda_kp, "lambda_kp must be >= 0.0"),
        (config.lambda_dp, "lambda_dp must be >= 0.0"),
        (config.lambda_tr, "lambda_tr must be >= 0.0"),
    ] {
        assert!(weight >= 0.0, "{msg}");
    }

    // At least one weight must be active, otherwise the loss is degenerate.
    let total = config.lambda_kp + config.lambda_dp + config.lambda_tr;
    assert!(
        total > 0.0,
        "sum of loss weights must be > 0.0, got {total}"
    );
}
|
||
/// The default loss weights should sum to exactly 1.0 (within floating-point
/// tolerance).
#[test]
fn default_config_loss_weights_sum_to_one() {
    let config = TrainingConfig::default();
    let total = config.lambda_kp + config.lambda_dp + config.lambda_tr;
    let diff = (total - 1.0_f64).abs();
    assert!(
        diff < 1e-9,
        "expected loss weights to sum to 1.0, got {total} (diff={diff})"
    );
}
||
// ---------------------------------------------------------------------------
// Specific default values
// ---------------------------------------------------------------------------

/// The default number of subcarriers is 56 (MM-Fi target).
#[test]
fn default_num_subcarriers_is_56() {
    let config = TrainingConfig::default();
    let actual = config.num_subcarriers;
    assert_eq!(actual, 56, "expected default num_subcarriers = 56, got {}", actual);
}
|
||
/// The default number of native subcarriers is 114 (raw MM-Fi hardware output).
#[test]
fn default_native_subcarriers_is_114() {
    let config = TrainingConfig::default();
    let actual = config.native_subcarriers;
    assert_eq!(actual, 114, "expected default native_subcarriers = 114, got {}", actual);
}
|
||
/// The default number of keypoints is 17 (COCO skeleton).
#[test]
fn default_num_keypoints_is_17() {
    let config = TrainingConfig::default();
    let actual = config.num_keypoints;
    assert_eq!(actual, 17, "expected default num_keypoints = 17, got {}", actual);
}
|
||
/// The default antenna layout is a 3×3 TX/RX array.
#[test]
fn default_antenna_counts_are_3x3() {
    let config = TrainingConfig::default();
    assert_eq!(config.num_antennas_tx, 3, "expected num_antennas_tx = 3");
    assert_eq!(config.num_antennas_rx, 3, "expected num_antennas_rx = 3");
}
|
||
/// The default window length is 100 frames.
#[test]
fn default_window_frames_is_100() {
    let config = TrainingConfig::default();
    let frames = config.window_frames;
    assert_eq!(frames, 100, "expected window_frames = 100, got {}", frames);
}
|
||
/// The default RNG seed is 42.
#[test]
fn default_seed_is_42() {
    let config = TrainingConfig::default();
    assert_eq!(config.seed, 42, "expected seed = 42, got {}", config.seed);
}
|
||
// ---------------------------------------------------------------------------
// needs_subcarrier_interp equivalent property
// ---------------------------------------------------------------------------

/// When native_subcarriers differs from num_subcarriers, interpolation is
/// needed. The default config has 114 != 56, so this property must hold.
#[test]
fn default_config_needs_interpolation() {
    let config = TrainingConfig::default();
    // 114 native → 56 target: interpolation is required.
    assert_ne!(
        config.native_subcarriers, config.num_subcarriers,
        "default config must require subcarrier interpolation (native={} != target={})",
        config.native_subcarriers, config.num_subcarriers
    );
}
|
||
/// When native_subcarriers equals num_subcarriers no interpolation is needed.
#[test]
fn equal_subcarrier_counts_means_no_interpolation_needed() {
    let mut config = TrainingConfig::default();
    // Collapse to a single subcarrier count (e.g. both = 56).
    config.native_subcarriers = config.num_subcarriers;
    config
        .validate()
        .expect("config with equal subcarrier counts must be valid");
    assert_eq!(
        config.native_subcarriers, config.num_subcarriers,
        "after setting equal counts, native ({}) must equal target ({})",
        config.native_subcarriers, config.num_subcarriers
    );
}
|
||
// ---------------------------------------------------------------------------
// csi_flat_size equivalent property
// ---------------------------------------------------------------------------

/// The flat input size of a single CSI window is
/// `window_frames × num_antennas_tx × num_antennas_rx × num_subcarriers`.
/// Verify the arithmetic matches the default config.
#[test]
fn csi_flat_size_matches_expected() {
    let config = TrainingConfig::default();
    // Default: 100 * 3 * 3 * 56 = 50400
    let expected = config.window_frames
        * config.num_antennas_tx
        * config.num_antennas_rx
        * config.num_subcarriers;
    assert_eq!(
        expected, 50_400,
        "CSI flat size must be 50400 for default config, got {expected}"
    );
}
|
||
/// The CSI flat size must be > 0 for any valid config.
///
/// Uses checked multiplication: with plain `*` an overflow would wrap
/// silently in release builds (where tests may run with overflow checks
/// disabled), making the `> 0` assertion unreliable. `checked_mul` turns an
/// overflow into an explicit test failure instead.
#[test]
fn csi_flat_size_positive_for_valid_config() {
    let cfg = TrainingConfig::default();
    let flat_size = cfg
        .window_frames
        .checked_mul(cfg.num_antennas_tx)
        .and_then(|n| n.checked_mul(cfg.num_antennas_rx))
        .and_then(|n| n.checked_mul(cfg.num_subcarriers))
        .expect("CSI flat size must not overflow");
    assert!(
        flat_size > 0,
        "CSI flat size must be > 0, got {flat_size}"
    );
}
|
||
// ---------------------------------------------------------------------------
// JSON serialization round-trip
// ---------------------------------------------------------------------------

/// Serializing a config to JSON and deserializing it must yield an identical
/// config (all fields must match).
#[test]
fn config_json_roundtrip_identical() {
    use tempfile::tempdir;

    // Local macros replace ~250 lines of copy-pasted per-field assertions
    // while keeping exactly the same field coverage and failure messages
    // derived from the field name.

    // Exact equality for integer / bool / Vec fields.
    macro_rules! assert_field_eq {
        ($loaded:expr, $original:expr, $($field:ident),+ $(,)?) => {$(
            assert_eq!(
                $loaded.$field, $original.$field,
                concat!(stringify!($field), " must survive round-trip")
            );
        )+};
    }
    // Tolerance-based equality for f64 fields (JSON may not round-trip
    // floats bit-exactly through a decimal representation).
    macro_rules! assert_field_close {
        ($loaded:expr, $original:expr, $($field:ident),+ $(,)?) => {$(
            assert!(
                ($loaded.$field - $original.$field).abs() < 1e-12,
                concat!(stringify!($field), " must survive round-trip: got {}"),
                $loaded.$field
            );
        )+};
    }

    let tmp = tempdir().expect("tempdir must be created");
    let path = tmp.path().join("config.json");

    let original = TrainingConfig::default();
    original
        .to_json(&path)
        .expect("to_json must succeed for default config");

    let loaded = TrainingConfig::from_json(&path)
        .expect("from_json must succeed for previously serialized config");

    // Verify all fields are equal (same set as the original hand-written list).
    assert_field_eq!(
        loaded, original,
        num_subcarriers, native_subcarriers, num_antennas_tx, num_antennas_rx,
        window_frames, heatmap_size, num_keypoints, num_body_parts,
        backbone_channels, batch_size, num_epochs, warmup_epochs,
        lr_milestones, val_every_epochs, early_stopping_patience, save_top_k,
        use_gpu, gpu_device_id, num_workers, seed,
    );
    assert_field_close!(
        loaded, original,
        learning_rate, weight_decay, lr_gamma, grad_clip_norm,
        lambda_kp, lambda_dp, lambda_tr,
    );
}
|
||
/// A modified config with non-default values must also survive a JSON
/// round-trip.
#[test]
fn config_json_roundtrip_modified_values() {
    use tempfile::tempdir;

    let tmp = tempdir().expect("tempdir must be created");
    let path = tmp.path().join("modified.json");

    // Start from defaults and override a representative subset of fields.
    let mut config = TrainingConfig::default();
    config.batch_size = 16;
    config.learning_rate = 5e-4;
    config.num_epochs = 100;
    config.warmup_epochs = 10;
    config.lr_milestones = vec![50, 80];
    config.seed = 99;

    config.validate().expect("modified config must be valid before serialization");
    config.to_json(&path).expect("to_json must succeed");

    let loaded = TrainingConfig::from_json(&path).expect("from_json must succeed");

    assert_eq!(loaded.batch_size, 16, "batch_size must match after round-trip");
    assert!(
        (loaded.learning_rate - 5e-4_f64).abs() < 1e-12,
        "learning_rate must match after round-trip"
    );
    assert_eq!(loaded.num_epochs, 100, "num_epochs must match after round-trip");
    assert_eq!(loaded.warmup_epochs, 10, "warmup_epochs must match after round-trip");
    assert_eq!(
        loaded.lr_milestones,
        vec![50, 80],
        "lr_milestones must match after round-trip"
    );
    assert_eq!(loaded.seed, 99, "seed must match after round-trip");
}
|
||
// ---------------------------------------------------------------------------
// Validation: invalid configurations are rejected
// ---------------------------------------------------------------------------

/// Setting num_subcarriers to 0 must produce a validation error.
#[test]
fn zero_num_subcarriers_is_invalid() {
    let mut config = TrainingConfig::default();
    config.num_subcarriers = 0;
    let result = config.validate();
    assert!(result.is_err(), "num_subcarriers = 0 must be rejected by validate()");
}
|
||
/// Setting native_subcarriers to 0 must produce a validation error.
#[test]
fn zero_native_subcarriers_is_invalid() {
    let mut config = TrainingConfig::default();
    config.native_subcarriers = 0;
    let result = config.validate();
    assert!(result.is_err(), "native_subcarriers = 0 must be rejected by validate()");
}
|
||
/// Setting batch_size to 0 must produce a validation error.
#[test]
fn zero_batch_size_is_invalid() {
    let mut config = TrainingConfig::default();
    config.batch_size = 0;
    let result = config.validate();
    assert!(result.is_err(), "batch_size = 0 must be rejected by validate()");
}
|
||
/// A negative learning rate must produce a validation error.
#[test]
fn negative_learning_rate_is_invalid() {
    let mut config = TrainingConfig::default();
    config.learning_rate = -0.001;
    let result = config.validate();
    assert!(result.is_err(), "learning_rate < 0 must be rejected by validate()");
}
|
||
/// warmup_epochs >= num_epochs must produce a validation error.
#[test]
fn warmup_exceeding_epochs_is_invalid() {
    let mut config = TrainingConfig::default();
    // Equality is the boundary case and must already be rejected.
    config.warmup_epochs = config.num_epochs;
    let result = config.validate();
    assert!(result.is_err(), "warmup_epochs >= num_epochs must be rejected by validate()");
}
|
||
/// All loss weights set to 0.0 must produce a validation error.
#[test]
fn all_zero_loss_weights_are_invalid() {
    let mut config = TrainingConfig::default();
    // Zero out every loss term: the total loss would be degenerate.
    config.lambda_kp = 0.0;
    config.lambda_dp = 0.0;
    config.lambda_tr = 0.0;
    let result = config.validate();
    assert!(result.is_err(), "all-zero loss weights must be rejected by validate()");
}
|
||
/// Non-increasing lr_milestones must produce a validation error.
#[test]
fn non_increasing_milestones_are_invalid() {
    let mut config = TrainingConfig::default();
    // Milestones must be strictly increasing; this pair is out of order.
    config.lr_milestones = vec![40, 30];
    let result = config.validate();
    assert!(result.is_err(), "non-increasing lr_milestones must be rejected by validate()");
}
|
||
/// An lr_milestone beyond num_epochs must produce a validation error.
#[test]
fn milestone_beyond_num_epochs_is_invalid() {
    let mut config = TrainingConfig::default();
    // Second milestone lands one epoch past the end of training.
    config.lr_milestones = vec![30, config.num_epochs + 1];
    let result = config.validate();
    assert!(result.is_err(), "lr_milestone > num_epochs must be rejected by validate()");
}