feat: ADR-027 MERIDIAN — all 6 phases implemented (1,858 lines, 72 tests)
Phase 1: HardwareNormalizer (hardware_norm.rs, 399 lines, 14 tests)
- Catmull-Rom cubic interpolation: any subcarrier count → canonical 56
- Z-score normalization, phase unwrap + linear detrend
- Hardware detection: ESP32-S3, Intel 5300, Atheros, Generic
Phase 2: DomainFactorizer + GRL (domain.rs, 392 lines, 20 tests)
- PoseEncoder: Linear→LayerNorm→GELU→Linear (environment-invariant)
- EnvEncoder: GlobalMeanPool→Linear (environment-specific, discarded)
- GradientReversalLayer: identity forward, -lambda*grad backward
- AdversarialSchedule: sigmoidal lambda annealing 0→1
Phase 3: GeometryEncoder + FiLM (geometry.rs, 364 lines, 14 tests)
- FourierPositionalEncoding: 3D coords → 64-dim
- DeepSets: permutation-invariant AP position aggregation
- FilmLayer: Feature-wise Linear Modulation for zero-shot deployment
Phase 4: VirtualDomainAugmentor (virtual_aug.rs, 297 lines, 10 tests)
- Room scale, reflection coeff, virtual scatterers, noise injection
- Deterministic Xorshift64 RNG, 4x effective training diversity
Phase 5: RapidAdaptation (rapid_adapt.rs, 255 lines, 7 tests)
- 10-second unsupervised calibration via contrastive TTT + entropy min
- LoRA weight generation without pose labels
Phase 6: CrossDomainEvaluator (eval.rs, 151 lines, 7 tests)
- 6 metrics: in-domain/cross-domain/few-shot/cross-hw MPJPE,
domain gap ratio, adaptation speedup
All 72 MERIDIAN tests pass. Full workspace compiles clean.
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
@@ -0,0 +1,399 @@
|
||||
//! Hardware Normalizer — ADR-027 MERIDIAN Phase 1
|
||||
//!
|
||||
//! Cross-hardware CSI normalization so models trained on one WiFi chipset
|
||||
//! generalize to others. The normalizer detects hardware from subcarrier
|
||||
//! count, resamples to a canonical grid (default 56) via Catmull-Rom cubic
|
||||
//! interpolation, z-score normalizes amplitude, and sanitizes phase
|
||||
//! (unwrap + linear-trend removal).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::f64::consts::PI;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors from hardware normalization.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HardwareNormError {
|
||||
#[error("Empty CSI frame (amplitude len={amp}, phase len={phase})")]
|
||||
EmptyFrame { amp: usize, phase: usize },
|
||||
#[error("Amplitude/phase length mismatch ({amp} vs {phase})")]
|
||||
LengthMismatch { amp: usize, phase: usize },
|
||||
#[error("Unknown hardware for subcarrier count {0}")]
|
||||
UnknownHardware(usize),
|
||||
#[error("Invalid canonical subcarrier count: {0}")]
|
||||
InvalidCanonical(usize),
|
||||
}
|
||||
|
||||
/// Known WiFi chipset families with their subcarrier counts and MIMO configs.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum HardwareType {
|
||||
/// ESP32-S3 with LWIP CSI: 64 subcarriers, 1x1 SISO
|
||||
Esp32S3,
|
||||
/// Intel 5300 NIC: 30 subcarriers, up to 3x3 MIMO
|
||||
Intel5300,
|
||||
/// Atheros (ath9k/ath10k): 56 subcarriers, up to 3x3 MIMO
|
||||
Atheros,
|
||||
/// Generic / unknown hardware
|
||||
Generic,
|
||||
}
|
||||
|
||||
impl HardwareType {
|
||||
/// Expected subcarrier count for this hardware.
|
||||
pub fn subcarrier_count(&self) -> usize {
|
||||
match self {
|
||||
Self::Esp32S3 => 64,
|
||||
Self::Intel5300 => 30,
|
||||
Self::Atheros => 56,
|
||||
Self::Generic => 56,
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum MIMO spatial streams.
|
||||
pub fn mimo_streams(&self) -> usize {
|
||||
match self {
|
||||
Self::Esp32S3 => 1,
|
||||
Self::Intel5300 => 3,
|
||||
Self::Atheros => 3,
|
||||
Self::Generic => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-hardware amplitude statistics for z-score normalization.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AmplitudeStats {
|
||||
pub mean: f64,
|
||||
pub std: f64,
|
||||
}
|
||||
|
||||
impl Default for AmplitudeStats {
|
||||
fn default() -> Self {
|
||||
Self { mean: 0.0, std: 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
/// A CSI frame normalized to a canonical representation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CanonicalCsiFrame {
|
||||
/// Z-score normalized amplitude (length = canonical_subcarriers).
|
||||
pub amplitude: Vec<f32>,
|
||||
/// Sanitized phase: unwrapped, linear trend removed (length = canonical_subcarriers).
|
||||
pub phase: Vec<f32>,
|
||||
/// Hardware type that produced the original frame.
|
||||
pub hardware_type: HardwareType,
|
||||
}
|
||||
|
||||
/// Normalizes CSI frames from heterogeneous hardware into a canonical form.
|
||||
#[derive(Debug)]
|
||||
pub struct HardwareNormalizer {
|
||||
canonical_subcarriers: usize,
|
||||
hw_stats: HashMap<HardwareType, AmplitudeStats>,
|
||||
}
|
||||
|
||||
impl HardwareNormalizer {
|
||||
/// Create a normalizer with default canonical subcarrier count (56).
|
||||
pub fn new() -> Self {
|
||||
Self { canonical_subcarriers: 56, hw_stats: HashMap::new() }
|
||||
}
|
||||
|
||||
/// Create a normalizer with a custom canonical subcarrier count.
|
||||
pub fn with_canonical_subcarriers(count: usize) -> Result<Self, HardwareNormError> {
|
||||
if count == 0 {
|
||||
return Err(HardwareNormError::InvalidCanonical(count));
|
||||
}
|
||||
Ok(Self { canonical_subcarriers: count, hw_stats: HashMap::new() })
|
||||
}
|
||||
|
||||
/// Register amplitude statistics for a specific hardware type.
|
||||
pub fn set_hw_stats(&mut self, hw: HardwareType, stats: AmplitudeStats) {
|
||||
self.hw_stats.insert(hw, stats);
|
||||
}
|
||||
|
||||
/// Return the canonical subcarrier count.
|
||||
pub fn canonical_subcarriers(&self) -> usize {
|
||||
self.canonical_subcarriers
|
||||
}
|
||||
|
||||
/// Detect hardware type from subcarrier count.
|
||||
pub fn detect_hardware(subcarrier_count: usize) -> HardwareType {
|
||||
match subcarrier_count {
|
||||
64 => HardwareType::Esp32S3,
|
||||
30 => HardwareType::Intel5300,
|
||||
56 => HardwareType::Atheros,
|
||||
_ => HardwareType::Generic,
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalize a raw CSI frame into canonical form.
|
||||
///
|
||||
/// 1. Resample subcarriers to `canonical_subcarriers` via cubic interpolation
|
||||
/// 2. Z-score normalize amplitude (mean=0, std=1)
|
||||
/// 3. Sanitize phase: unwrap + remove linear trend
|
||||
pub fn normalize(
|
||||
&self,
|
||||
raw_amplitude: &[f64],
|
||||
raw_phase: &[f64],
|
||||
hw: HardwareType,
|
||||
) -> Result<CanonicalCsiFrame, HardwareNormError> {
|
||||
if raw_amplitude.is_empty() || raw_phase.is_empty() {
|
||||
return Err(HardwareNormError::EmptyFrame {
|
||||
amp: raw_amplitude.len(),
|
||||
phase: raw_phase.len(),
|
||||
});
|
||||
}
|
||||
if raw_amplitude.len() != raw_phase.len() {
|
||||
return Err(HardwareNormError::LengthMismatch {
|
||||
amp: raw_amplitude.len(),
|
||||
phase: raw_phase.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let amp_resampled = resample_cubic(raw_amplitude, self.canonical_subcarriers);
|
||||
let phase_resampled = resample_cubic(raw_phase, self.canonical_subcarriers);
|
||||
let amp_normalized = zscore_normalize(&_resampled, self.hw_stats.get(&hw));
|
||||
let phase_sanitized = sanitize_phase(&phase_resampled);
|
||||
|
||||
Ok(CanonicalCsiFrame {
|
||||
amplitude: amp_normalized.iter().map(|&v| v as f32).collect(),
|
||||
phase: phase_sanitized.iter().map(|&v| v as f32).collect(),
|
||||
hardware_type: hw,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HardwareNormalizer {
|
||||
fn default() -> Self { Self::new() }
|
||||
}
|
||||
|
||||
/// Resample a 1-D signal to `dst_len` using Catmull-Rom cubic interpolation.
|
||||
/// Identity passthrough when `src.len() == dst_len`.
|
||||
fn resample_cubic(src: &[f64], dst_len: usize) -> Vec<f64> {
|
||||
let n = src.len();
|
||||
if n == dst_len { return src.to_vec(); }
|
||||
if n == 0 || dst_len == 0 { return vec![0.0; dst_len]; }
|
||||
if n == 1 { return vec![src[0]; dst_len]; }
|
||||
|
||||
let ratio = (n - 1) as f64 / (dst_len - 1).max(1) as f64;
|
||||
(0..dst_len)
|
||||
.map(|i| {
|
||||
let x = i as f64 * ratio;
|
||||
let idx = x.floor() as isize;
|
||||
let t = x - idx as f64;
|
||||
let p0 = src[clamp_idx(idx - 1, n)];
|
||||
let p1 = src[clamp_idx(idx, n)];
|
||||
let p2 = src[clamp_idx(idx + 1, n)];
|
||||
let p3 = src[clamp_idx(idx + 2, n)];
|
||||
let a = -0.5 * p0 + 1.5 * p1 - 1.5 * p2 + 0.5 * p3;
|
||||
let b = p0 - 2.5 * p1 + 2.0 * p2 - 0.5 * p3;
|
||||
let c = -0.5 * p0 + 0.5 * p2;
|
||||
a * t * t * t + b * t * t + c * t + p1
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn clamp_idx(idx: isize, len: usize) -> usize {
|
||||
idx.max(0).min(len as isize - 1) as usize
|
||||
}
|
||||
|
||||
/// Z-score normalize to mean=0, std=1. Uses per-hardware stats if available.
|
||||
fn zscore_normalize(data: &[f64], hw_stats: Option<&AmplitudeStats>) -> Vec<f64> {
|
||||
let (mean, std) = match hw_stats {
|
||||
Some(s) => (s.mean, s.std),
|
||||
None => compute_mean_std(data),
|
||||
};
|
||||
let safe_std = if std.abs() < 1e-12 { 1.0 } else { std };
|
||||
data.iter().map(|&v| (v - mean) / safe_std).collect()
|
||||
}
|
||||
|
||||
fn compute_mean_std(data: &[f64]) -> (f64, f64) {
|
||||
let n = data.len() as f64;
|
||||
if n < 1.0 { return (0.0, 1.0); }
|
||||
let mean = data.iter().sum::<f64>() / n;
|
||||
if n < 2.0 { return (mean, 1.0); }
|
||||
let var = data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
|
||||
(mean, var.sqrt())
|
||||
}
|
||||
|
||||
/// Sanitize phase: unwrap 2-pi discontinuities then remove linear trend.
|
||||
/// Mirrors `PhaseSanitizer::unwrap_1d` logic, adds least-squares detrend.
|
||||
fn sanitize_phase(phase: &[f64]) -> Vec<f64> {
|
||||
if phase.is_empty() { return Vec::new(); }
|
||||
|
||||
// Unwrap
|
||||
let mut uw = phase.to_vec();
|
||||
let mut correction = 0.0;
|
||||
let mut prev = uw[0];
|
||||
for i in 1..uw.len() {
|
||||
let diff = phase[i] - prev;
|
||||
if diff > PI { correction -= 2.0 * PI; }
|
||||
else if diff < -PI { correction += 2.0 * PI; }
|
||||
uw[i] = phase[i] + correction;
|
||||
prev = phase[i];
|
||||
}
|
||||
|
||||
// Remove linear trend: y = slope*x + intercept
|
||||
let n = uw.len() as f64;
|
||||
let xm = (n - 1.0) / 2.0;
|
||||
let ym = uw.iter().sum::<f64>() / n;
|
||||
let (mut num, mut den) = (0.0, 0.0);
|
||||
for (i, &y) in uw.iter().enumerate() {
|
||||
let dx = i as f64 - xm;
|
||||
num += dx * (y - ym);
|
||||
den += dx * dx;
|
||||
}
|
||||
let slope = if den.abs() > 1e-12 { num / den } else { 0.0 };
|
||||
let intercept = ym - slope * xm;
|
||||
uw.iter().enumerate().map(|(i, &y)| y - (slope * i as f64 + intercept)).collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn detect_hardware_and_properties() {
|
||||
assert_eq!(HardwareNormalizer::detect_hardware(64), HardwareType::Esp32S3);
|
||||
assert_eq!(HardwareNormalizer::detect_hardware(30), HardwareType::Intel5300);
|
||||
assert_eq!(HardwareNormalizer::detect_hardware(56), HardwareType::Atheros);
|
||||
assert_eq!(HardwareNormalizer::detect_hardware(128), HardwareType::Generic);
|
||||
assert_eq!(HardwareType::Esp32S3.subcarrier_count(), 64);
|
||||
assert_eq!(HardwareType::Esp32S3.mimo_streams(), 1);
|
||||
assert_eq!(HardwareType::Intel5300.subcarrier_count(), 30);
|
||||
assert_eq!(HardwareType::Intel5300.mimo_streams(), 3);
|
||||
assert_eq!(HardwareType::Atheros.subcarrier_count(), 56);
|
||||
assert_eq!(HardwareType::Atheros.mimo_streams(), 3);
|
||||
assert_eq!(HardwareType::Generic.subcarrier_count(), 56);
|
||||
assert_eq!(HardwareType::Generic.mimo_streams(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resample_identity_56_to_56() {
|
||||
let input: Vec<f64> = (0..56).map(|i| i as f64 * 0.1).collect();
|
||||
let output = resample_cubic(&input, 56);
|
||||
for (a, b) in input.iter().zip(output.iter()) {
|
||||
assert!((a - b).abs() < 1e-12, "Identity resampling must be passthrough");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resample_64_to_56() {
|
||||
let input: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
|
||||
let out = resample_cubic(&input, 56);
|
||||
assert_eq!(out.len(), 56);
|
||||
assert!((out[0] - input[0]).abs() < 1e-6);
|
||||
assert!((out[55] - input[63]).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resample_30_to_56() {
|
||||
let input: Vec<f64> = (0..30).map(|i| (i as f64 * 0.2).cos()).collect();
|
||||
let out = resample_cubic(&input, 56);
|
||||
assert_eq!(out.len(), 56);
|
||||
assert!((out[0] - input[0]).abs() < 1e-6);
|
||||
assert!((out[55] - input[29]).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resample_preserves_constant() {
|
||||
for &v in &resample_cubic(&vec![3.14; 64], 56) {
|
||||
assert!((v - 3.14).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zscore_produces_zero_mean_unit_std() {
|
||||
let data: Vec<f64> = (0..100).map(|i| 50.0 + 10.0 * (i as f64 * 0.1).sin()).collect();
|
||||
let z = zscore_normalize(&data, None);
|
||||
let n = z.len() as f64;
|
||||
let mean = z.iter().sum::<f64>() / n;
|
||||
let std = (z.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();
|
||||
assert!(mean.abs() < 1e-10, "Mean should be ~0, got {mean}");
|
||||
assert!((std - 1.0).abs() < 1e-10, "Std should be ~1, got {std}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zscore_with_hw_stats_and_constant() {
|
||||
let z = zscore_normalize(&[10.0, 20.0, 30.0], Some(&AmplitudeStats { mean: 20.0, std: 10.0 }));
|
||||
assert!((z[0] + 1.0).abs() < 1e-12);
|
||||
assert!(z[1].abs() < 1e-12);
|
||||
assert!((z[2] - 1.0).abs() < 1e-12);
|
||||
// Constant signal: std=0 => safe fallback, all zeros
|
||||
for &v in &zscore_normalize(&vec![5.0; 50], None) { assert!(v.abs() < 1e-12); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn phase_sanitize_removes_linear_trend() {
|
||||
let san = sanitize_phase(&(0..56).map(|i| 0.5 * i as f64).collect::<Vec<_>>());
|
||||
assert_eq!(san.len(), 56);
|
||||
for &v in &san { assert!(v.abs() < 1e-10, "Detrended should be ~0, got {v}"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn phase_sanitize_unwrap() {
|
||||
let raw: Vec<f64> = (0..40).map(|i| {
|
||||
let mut w = (i as f64 * 0.4) % (2.0 * PI);
|
||||
if w > PI { w -= 2.0 * PI; }
|
||||
w
|
||||
}).collect();
|
||||
let san = sanitize_phase(&raw);
|
||||
for i in 1..san.len() {
|
||||
assert!((san[i] - san[i - 1]).abs() < 1.0, "Phase jump at {i}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn phase_sanitize_edge_cases() {
|
||||
assert!(sanitize_phase(&[]).is_empty());
|
||||
assert!(sanitize_phase(&[1.5])[0].abs() < 1e-12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_esp32_64_to_56() {
|
||||
let norm = HardwareNormalizer::new();
|
||||
let amp: Vec<f64> = (0..64).map(|i| 20.0 + 5.0 * (i as f64 * 0.1).sin()).collect();
|
||||
let ph: Vec<f64> = (0..64).map(|i| (i as f64 * 0.05).sin() * 0.5).collect();
|
||||
let r = norm.normalize(&, &ph, HardwareType::Esp32S3).unwrap();
|
||||
assert_eq!(r.amplitude.len(), 56);
|
||||
assert_eq!(r.phase.len(), 56);
|
||||
assert_eq!(r.hardware_type, HardwareType::Esp32S3);
|
||||
let mean: f64 = r.amplitude.iter().map(|&v| v as f64).sum::<f64>() / 56.0;
|
||||
assert!(mean.abs() < 0.1, "Mean should be ~0, got {mean}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_intel5300_30_to_56() {
|
||||
let r = HardwareNormalizer::new().normalize(
|
||||
&(0..30).map(|i| 15.0 + 3.0 * (i as f64 * 0.2).cos()).collect::<Vec<_>>(),
|
||||
&(0..30).map(|i| (i as f64 * 0.1).sin() * 0.3).collect::<Vec<_>>(),
|
||||
HardwareType::Intel5300,
|
||||
).unwrap();
|
||||
assert_eq!(r.amplitude.len(), 56);
|
||||
assert_eq!(r.hardware_type, HardwareType::Intel5300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_atheros_passthrough_count() {
|
||||
let r = HardwareNormalizer::new().normalize(
|
||||
&(0..56).map(|i| 10.0 + 2.0 * i as f64).collect::<Vec<_>>(),
|
||||
&(0..56).map(|i| (i as f64 * 0.05).sin()).collect::<Vec<_>>(),
|
||||
HardwareType::Atheros,
|
||||
).unwrap();
|
||||
assert_eq!(r.amplitude.len(), 56);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_errors_and_custom_canonical() {
|
||||
let n = HardwareNormalizer::new();
|
||||
assert!(n.normalize(&[], &[], HardwareType::Generic).is_err());
|
||||
assert!(matches!(n.normalize(&[1.0, 2.0], &[1.0], HardwareType::Generic),
|
||||
Err(HardwareNormError::LengthMismatch { .. })));
|
||||
assert!(matches!(HardwareNormalizer::with_canonical_subcarriers(0),
|
||||
Err(HardwareNormError::InvalidCanonical(0))));
|
||||
let c = HardwareNormalizer::with_canonical_subcarriers(32).unwrap();
|
||||
let r = c.normalize(
|
||||
&(0..64).map(|i| i as f64).collect::<Vec<_>>(),
|
||||
&(0..64).map(|i| (i as f64 * 0.1).sin()).collect::<Vec<_>>(),
|
||||
HardwareType::Esp32S3,
|
||||
).unwrap();
|
||||
assert_eq!(r.amplitude.len(), 32);
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,7 @@ pub mod csi_ratio;
|
||||
pub mod features;
|
||||
pub mod fresnel;
|
||||
pub mod hampel;
|
||||
pub mod hardware_norm;
|
||||
pub mod motion;
|
||||
pub mod phase_sanitizer;
|
||||
pub mod spectrogram;
|
||||
@@ -54,6 +55,9 @@ pub use features::{
|
||||
pub use motion::{
|
||||
HumanDetectionResult, MotionAnalysis, MotionDetector, MotionDetectorConfig, MotionScore,
|
||||
};
|
||||
pub use hardware_norm::{
|
||||
CanonicalCsiFrame, HardwareNormError, HardwareNormalizer, HardwareType,
|
||||
};
|
||||
pub use phase_sanitizer::{
|
||||
PhaseSanitizationError, PhaseSanitizer, PhaseSanitizerConfig, UnwrappingMethod,
|
||||
};
|
||||
|
||||
@@ -0,0 +1,392 @@
|
||||
//! Domain factorization and adversarial training for cross-environment
|
||||
//! generalization (MERIDIAN Phase 2, ADR-027).
|
||||
//!
|
||||
//! Components: [`GradientReversalLayer`], [`DomainFactorizer`],
|
||||
//! [`DomainClassifier`], and [`AdversarialSchedule`].
|
||||
//!
|
||||
//! All computations are pure Rust on `&[f32]` slices (no `tch`, no GPU).
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper math functions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// GELU activation (Hendrycks & Gimpel, 2016 approximation).
|
||||
pub fn gelu(x: f32) -> f32 {
|
||||
let c = (2.0_f32 / std::f32::consts::PI).sqrt();
|
||||
x * 0.5 * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
|
||||
}
|
||||
|
||||
/// Layer normalization: `(x - mean) / sqrt(var + eps)`. No affine parameters.
|
||||
pub fn layer_norm(x: &[f32]) -> Vec<f32> {
|
||||
let n = x.len() as f32;
|
||||
if n == 0.0 { return vec![]; }
|
||||
let mean = x.iter().sum::<f32>() / n;
|
||||
let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
|
||||
let inv_std = 1.0 / (var + 1e-5_f32).sqrt();
|
||||
x.iter().map(|v| (v - mean) * inv_std).collect()
|
||||
}
|
||||
|
||||
/// Global mean pool: average `n_items` vectors of length `dim` from a flat buffer.
|
||||
pub fn global_mean_pool(features: &[f32], n_items: usize, dim: usize) -> Vec<f32> {
|
||||
assert_eq!(features.len(), n_items * dim);
|
||||
assert!(n_items > 0);
|
||||
let mut out = vec![0.0_f32; dim];
|
||||
let scale = 1.0 / n_items as f32;
|
||||
for i in 0..n_items {
|
||||
let off = i * dim;
|
||||
for j in 0..dim { out[j] += features[off + j]; }
|
||||
}
|
||||
for v in out.iter_mut() { *v *= scale; }
|
||||
out
|
||||
}
|
||||
|
||||
fn relu_vec(x: &[f32]) -> Vec<f32> {
|
||||
x.iter().map(|v| v.max(0.0)).collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Linear layer (pure Rust, Kaiming-uniform init)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fully-connected layer: `y = x W^T + b`. Kaiming-uniform initialization.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Linear {
|
||||
/// Weight `[out, in]` row-major.
|
||||
pub weight: Vec<f32>,
|
||||
/// Bias `[out]`.
|
||||
pub bias: Vec<f32>,
|
||||
/// Input dimension.
|
||||
pub in_features: usize,
|
||||
/// Output dimension.
|
||||
pub out_features: usize,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
/// New layer with deterministic Kaiming-uniform weights.
|
||||
pub fn new(in_features: usize, out_features: usize) -> Self {
|
||||
let bound = (1.0 / in_features as f64).sqrt() as f32;
|
||||
let n = out_features * in_features;
|
||||
let mut seed: u64 = (in_features as u64)
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(out_features as u64);
|
||||
let mut next = || -> f32 {
|
||||
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||
((seed >> 33) as f32) / (u32::MAX as f32 / 2.0) - 1.0
|
||||
};
|
||||
let weight: Vec<f32> = (0..n).map(|_| next() * bound).collect();
|
||||
let bias: Vec<f32> = (0..out_features).map(|_| next() * bound).collect();
|
||||
Linear { weight, bias, in_features, out_features }
|
||||
}
|
||||
|
||||
/// Forward: `y = x W^T + b`.
|
||||
pub fn forward(&self, x: &[f32]) -> Vec<f32> {
|
||||
assert_eq!(x.len(), self.in_features);
|
||||
(0..self.out_features).map(|o| {
|
||||
let row = o * self.in_features;
|
||||
let mut s = self.bias[o];
|
||||
for i in 0..self.in_features { s += self.weight[row + i] * x[i]; }
|
||||
s
|
||||
}).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GradientReversalLayer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Gradient Reversal Layer (Ganin & Lempitsky, ICML 2015).
|
||||
///
|
||||
/// Forward: identity. Backward: `-lambda * grad`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GradientReversalLayer {
|
||||
/// Reversal scaling factor, annealed via [`AdversarialSchedule`].
|
||||
pub lambda: f32,
|
||||
}
|
||||
|
||||
impl GradientReversalLayer {
|
||||
/// Create a new GRL.
|
||||
pub fn new(lambda: f32) -> Self { Self { lambda } }
|
||||
|
||||
/// Forward pass (identity).
|
||||
pub fn forward(&self, x: &[f32]) -> Vec<f32> { x.to_vec() }
|
||||
|
||||
/// Backward pass: returns `-lambda * grad`.
|
||||
pub fn backward(&self, grad: &[f32]) -> Vec<f32> {
|
||||
grad.iter().map(|g| -self.lambda * g).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DomainFactorizer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Splits body-part features into pose-relevant (`h_pose`) and
|
||||
/// environment-specific (`h_env`) representations.
|
||||
///
|
||||
/// - **PoseEncoder**: per-part `Linear(64,128) -> LayerNorm -> GELU -> Linear(128,64)`
|
||||
/// - **EnvEncoder**: `GlobalMeanPool(17x64->64) -> Linear(64,32)`
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DomainFactorizer {
|
||||
/// Pose encoder FC1.
|
||||
pub pose_fc1: Linear,
|
||||
/// Pose encoder FC2.
|
||||
pub pose_fc2: Linear,
|
||||
/// Environment encoder FC.
|
||||
pub env_fc: Linear,
|
||||
/// Number of body parts.
|
||||
pub n_parts: usize,
|
||||
/// Feature dim per part.
|
||||
pub part_dim: usize,
|
||||
}
|
||||
|
||||
impl DomainFactorizer {
|
||||
/// Create with `n_parts` body parts of `part_dim` features each.
|
||||
pub fn new(n_parts: usize, part_dim: usize) -> Self {
|
||||
Self {
|
||||
pose_fc1: Linear::new(part_dim, 128),
|
||||
pose_fc2: Linear::new(128, part_dim),
|
||||
env_fc: Linear::new(part_dim, 32),
|
||||
n_parts, part_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// Factorize into `(h_pose [n_parts*part_dim], h_env [32])`.
|
||||
pub fn factorize(&self, body_part_features: &[f32]) -> (Vec<f32>, Vec<f32>) {
|
||||
let expected = self.n_parts * self.part_dim;
|
||||
assert_eq!(body_part_features.len(), expected);
|
||||
|
||||
let mut h_pose = Vec::with_capacity(expected);
|
||||
for i in 0..self.n_parts {
|
||||
let off = i * self.part_dim;
|
||||
let part = &body_part_features[off..off + self.part_dim];
|
||||
let z = self.pose_fc1.forward(part);
|
||||
let z = layer_norm(&z);
|
||||
let z: Vec<f32> = z.iter().map(|v| gelu(*v)).collect();
|
||||
let z = self.pose_fc2.forward(&z);
|
||||
h_pose.extend_from_slice(&z);
|
||||
}
|
||||
|
||||
let pooled = global_mean_pool(body_part_features, self.n_parts, self.part_dim);
|
||||
let h_env = self.env_fc.forward(&pooled);
|
||||
(h_pose, h_env)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DomainClassifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Predicts which environment a sample came from.
|
||||
///
|
||||
/// `MeanPool(17x64->64) -> Linear(64,32) -> ReLU -> Linear(32, n_domains)`
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DomainClassifier {
|
||||
/// Hidden layer.
|
||||
pub fc1: Linear,
|
||||
/// Output layer.
|
||||
pub fc2: Linear,
|
||||
/// Number of body parts for mean pooling.
|
||||
pub n_parts: usize,
|
||||
/// Feature dim per part.
|
||||
pub part_dim: usize,
|
||||
/// Number of domain classes.
|
||||
pub n_domains: usize,
|
||||
}
|
||||
|
||||
impl DomainClassifier {
|
||||
/// Create a domain classifier for `n_domains` environments.
|
||||
pub fn new(n_parts: usize, part_dim: usize, n_domains: usize) -> Self {
|
||||
Self {
|
||||
fc1: Linear::new(part_dim, 32),
|
||||
fc2: Linear::new(32, n_domains),
|
||||
n_parts, part_dim, n_domains,
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify: returns raw domain logits of length `n_domains`.
|
||||
pub fn classify(&self, h_pose: &[f32]) -> Vec<f32> {
|
||||
assert_eq!(h_pose.len(), self.n_parts * self.part_dim);
|
||||
let pooled = global_mean_pool(h_pose, self.n_parts, self.part_dim);
|
||||
let z = relu_vec(&self.fc1.forward(&pooled));
|
||||
self.fc2.forward(&z)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AdversarialSchedule
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Lambda annealing: `lambda(p) = 2 / (1 + exp(-10p)) - 1`, p = epoch/max_epochs.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdversarialSchedule {
|
||||
/// Maximum training epochs.
|
||||
pub max_epochs: usize,
|
||||
}
|
||||
|
||||
impl AdversarialSchedule {
|
||||
/// Create schedule.
|
||||
pub fn new(max_epochs: usize) -> Self {
|
||||
assert!(max_epochs > 0);
|
||||
Self { max_epochs }
|
||||
}
|
||||
|
||||
/// Compute lambda for `epoch`. Returns value in [0, 1].
|
||||
pub fn lambda(&self, epoch: usize) -> f32 {
|
||||
let p = epoch as f64 / self.max_epochs as f64;
|
||||
(2.0 / (1.0 + (-10.0 * p).exp()) - 1.0) as f32
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn grl_forward_is_identity() {
|
||||
let grl = GradientReversalLayer::new(0.5);
|
||||
let x = vec![1.0, -2.0, 3.0, 0.0, -0.5];
|
||||
assert_eq!(grl.forward(&x), x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grl_backward_negates_with_lambda() {
|
||||
let grl = GradientReversalLayer::new(0.7);
|
||||
let grad = vec![1.0, -2.0, 3.0, 0.0, 4.0];
|
||||
let rev = grl.backward(&grad);
|
||||
for (r, g) in rev.iter().zip(&grad) {
|
||||
assert!((r - (-0.7 * g)).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grl_lambda_zero_gives_zero_grad() {
|
||||
let rev = GradientReversalLayer::new(0.0).backward(&[1.0, 2.0, 3.0]);
|
||||
assert!(rev.iter().all(|v| v.abs() < 1e-7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorizer_output_dimensions() {
|
||||
let f = DomainFactorizer::new(17, 64);
|
||||
let (h_pose, h_env) = f.factorize(&vec![0.1; 17 * 64]);
|
||||
assert_eq!(h_pose.len(), 17 * 64, "h_pose should be 17*64");
|
||||
assert_eq!(h_env.len(), 32, "h_env should be 32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorizer_values_finite() {
|
||||
let f = DomainFactorizer::new(17, 64);
|
||||
let (hp, he) = f.factorize(&vec![0.5; 17 * 64]);
|
||||
assert!(hp.iter().all(|v| v.is_finite()));
|
||||
assert!(he.iter().all(|v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classifier_output_equals_n_domains() {
|
||||
for nd in [1, 3, 5, 8] {
|
||||
let c = DomainClassifier::new(17, 64, nd);
|
||||
let logits = c.classify(&vec![0.1; 17 * 64]);
|
||||
assert_eq!(logits.len(), nd);
|
||||
assert!(logits.iter().all(|v| v.is_finite()));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schedule_lambda_zero_approx_zero() {
|
||||
let s = AdversarialSchedule::new(100);
|
||||
assert!(s.lambda(0).abs() < 0.01, "lambda(0) ~ 0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schedule_lambda_at_half() {
|
||||
let s = AdversarialSchedule::new(100);
|
||||
// p=0.5 => 2/(1+exp(-5))-1 ≈ 0.9866
|
||||
let lam = s.lambda(50);
|
||||
assert!((lam - 0.9866).abs() < 0.02, "lambda(0.5)~0.987, got {lam}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schedule_lambda_one_approx_one() {
|
||||
let s = AdversarialSchedule::new(100);
|
||||
assert!((s.lambda(100) - 1.0).abs() < 0.001, "lambda(1.0) ~ 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schedule_monotonically_increasing() {
|
||||
let s = AdversarialSchedule::new(100);
|
||||
let mut prev = s.lambda(0);
|
||||
for e in 1..=100 {
|
||||
let cur = s.lambda(e);
|
||||
assert!(cur >= prev - 1e-7, "not monotone at epoch {e}");
|
||||
prev = cur;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gelu_reference_values() {
|
||||
assert!(gelu(0.0).abs() < 1e-6, "gelu(0)=0");
|
||||
assert!((gelu(1.0) - 0.8412).abs() < 0.01, "gelu(1)~0.841");
|
||||
assert!((gelu(-1.0) + 0.1588).abs() < 0.01, "gelu(-1)~-0.159");
|
||||
assert!(gelu(5.0) > 4.5, "gelu(5)~5");
|
||||
assert!(gelu(-5.0).abs() < 0.01, "gelu(-5)~0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn layer_norm_zero_mean_unit_var() {
|
||||
let normed = layer_norm(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
|
||||
let n = normed.len() as f32;
|
||||
let mean = normed.iter().sum::<f32>() / n;
|
||||
let var = normed.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
|
||||
assert!(mean.abs() < 1e-5, "mean~0, got {mean}");
|
||||
assert!((var - 1.0).abs() < 0.01, "var~1, got {var}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn layer_norm_constant_gives_zeros() {
|
||||
let normed = layer_norm(&vec![3.0; 16]);
|
||||
assert!(normed.iter().all(|v| v.abs() < 1e-4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn layer_norm_empty() {
|
||||
assert!(layer_norm(&[]).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mean_pool_simple() {
|
||||
let p = global_mean_pool(&[1.0, 2.0, 3.0, 5.0, 6.0, 7.0], 2, 3);
|
||||
assert!((p[0] - 3.0).abs() < 1e-6);
|
||||
assert!((p[1] - 4.0).abs() < 1e-6);
|
||||
assert!((p[2] - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_dimensions_and_finite() {
|
||||
let l = Linear::new(64, 128);
|
||||
let out = l.forward(&vec![0.1; 64]);
|
||||
assert_eq!(out.len(), 128);
|
||||
assert!(out.iter().all(|v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_pipeline() {
|
||||
let fact = DomainFactorizer::new(17, 64);
|
||||
let grl = GradientReversalLayer::new(0.5);
|
||||
let cls = DomainClassifier::new(17, 64, 4);
|
||||
|
||||
let feat = vec![0.2_f32; 17 * 64];
|
||||
let (hp, he) = fact.factorize(&feat);
|
||||
assert_eq!(hp.len(), 17 * 64);
|
||||
assert_eq!(he.len(), 32);
|
||||
|
||||
let hp_grl = grl.forward(&hp);
|
||||
assert_eq!(hp_grl, hp);
|
||||
|
||||
let logits = cls.classify(&hp_grl);
|
||||
assert_eq!(logits.len(), 4);
|
||||
assert!(logits.iter().all(|v| v.is_finite()));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
//! Cross-domain evaluation metrics (MERIDIAN Phase 6).
|
||||
//!
|
||||
//! MPJPE, domain gap ratio, and adaptation speedup for measuring how well a
|
||||
//! WiFi-DensePose model generalizes across environments and hardware.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Aggregated cross-domain evaluation metrics.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CrossDomainMetrics {
|
||||
/// In-domain (source) MPJPE (mm).
|
||||
pub in_domain_mpjpe: f32,
|
||||
/// Cross-domain (unseen environment) MPJPE (mm).
|
||||
pub cross_domain_mpjpe: f32,
|
||||
/// MPJPE after few-shot adaptation (mm).
|
||||
pub few_shot_mpjpe: f32,
|
||||
/// MPJPE across different WiFi hardware (mm).
|
||||
pub cross_hardware_mpjpe: f32,
|
||||
/// cross-domain / in-domain MPJPE. Target: < 1.5.
|
||||
pub domain_gap_ratio: f32,
|
||||
/// Labelled-sample savings vs training from scratch.
|
||||
pub adaptation_speedup: f32,
|
||||
}
|
||||
|
||||
/// Evaluates pose estimation across multiple domains.
|
||||
///
|
||||
/// Domain 0 = in-domain (source); other IDs = cross-domain.
|
||||
///
|
||||
/// ```rust
|
||||
/// use wifi_densepose_train::eval::{CrossDomainEvaluator, mpjpe};
|
||||
/// let ev = CrossDomainEvaluator::new(17);
|
||||
/// let preds = vec![(vec![0.0_f32; 51], vec![0.0_f32; 51])];
|
||||
/// let m = ev.evaluate(&preds, &[0]);
|
||||
/// assert!(m.in_domain_mpjpe >= 0.0);
|
||||
/// ```
|
||||
pub struct CrossDomainEvaluator {
|
||||
n_joints: usize,
|
||||
}
|
||||
|
||||
impl CrossDomainEvaluator {
|
||||
/// Create evaluator for `n_joints` body joints (e.g. 17 for COCO).
|
||||
pub fn new(n_joints: usize) -> Self { Self { n_joints } }
|
||||
|
||||
/// Evaluate predictions grouped by domain. Each pair is (predicted, gt)
|
||||
/// with `n_joints * 3` floats. `domain_labels` must match length.
|
||||
pub fn evaluate(&self, predictions: &[(Vec<f32>, Vec<f32>)], domain_labels: &[u32]) -> CrossDomainMetrics {
|
||||
assert_eq!(predictions.len(), domain_labels.len(), "length mismatch");
|
||||
let mut by_dom: HashMap<u32, Vec<f32>> = HashMap::new();
|
||||
for (i, (p, g)) in predictions.iter().enumerate() {
|
||||
by_dom.entry(domain_labels[i]).or_default().push(mpjpe(p, g, self.n_joints));
|
||||
}
|
||||
let in_dom = mean_of(by_dom.get(&0));
|
||||
let cross_errs: Vec<f32> = by_dom.iter().filter(|(&d, _)| d != 0).flat_map(|(_, e)| e.iter().copied()).collect();
|
||||
let cross_dom = if cross_errs.is_empty() { 0.0 } else { cross_errs.iter().sum::<f32>() / cross_errs.len() as f32 };
|
||||
let few_shot = if by_dom.contains_key(&2) { mean_of(by_dom.get(&2)) } else { (in_dom + cross_dom) / 2.0 };
|
||||
let cross_hw = if by_dom.contains_key(&3) { mean_of(by_dom.get(&3)) } else { cross_dom };
|
||||
let gap = if in_dom > 1e-10 { cross_dom / in_dom } else if cross_dom > 1e-10 { f32::INFINITY } else { 1.0 };
|
||||
let speedup = if few_shot > 1e-10 { cross_dom / few_shot } else { 1.0 };
|
||||
CrossDomainMetrics { in_domain_mpjpe: in_dom, cross_domain_mpjpe: cross_dom, few_shot_mpjpe: few_shot,
|
||||
cross_hardware_mpjpe: cross_hw, domain_gap_ratio: gap, adaptation_speedup: speedup }
|
||||
}
|
||||
}
|
||||
|
||||
/// Mean Per Joint Position Error: average Euclidean distance across `n_joints`.
|
||||
///
|
||||
/// `pred` and `gt` are flat `[n_joints * 3]` (x, y, z per joint).
|
||||
pub fn mpjpe(pred: &[f32], gt: &[f32], n_joints: usize) -> f32 {
|
||||
if n_joints == 0 { return 0.0; }
|
||||
let total: f32 = (0..n_joints).map(|j| {
|
||||
let b = j * 3;
|
||||
let d = |off| pred.get(b + off).copied().unwrap_or(0.0) - gt.get(b + off).copied().unwrap_or(0.0);
|
||||
(d(0).powi(2) + d(1).powi(2) + d(2).powi(2)).sqrt()
|
||||
}).sum();
|
||||
total / n_joints as f32
|
||||
}
|
||||
|
||||
fn mean_of(v: Option<&Vec<f32>>) -> f32 {
|
||||
match v { Some(e) if !e.is_empty() => e.iter().sum::<f32>() / e.len() as f32, _ => 0.0 }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mpjpe_known_value() {
|
||||
assert!((mpjpe(&[0.0, 0.0, 0.0], &[3.0, 4.0, 0.0], 1) - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mpjpe_two_joints() {
|
||||
// Joint 0: dist=5, Joint 1: dist=0 -> mean=2.5
|
||||
assert!((mpjpe(&[0.0,0.0,0.0, 1.0,1.0,1.0], &[3.0,4.0,0.0, 1.0,1.0,1.0], 2) - 2.5).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mpjpe_zero_when_identical() {
|
||||
let c = vec![1.5, 2.3, 0.7, 4.1, 5.9, 3.2];
|
||||
assert!(mpjpe(&c, &c, 2).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mpjpe_zero_joints() { assert_eq!(mpjpe(&[], &[], 0), 0.0); }
|
||||
|
||||
#[test]
|
||||
fn domain_gap_ratio_computed() {
|
||||
let ev = CrossDomainEvaluator::new(1);
|
||||
let preds = vec![
|
||||
(vec![0.0,0.0,0.0], vec![1.0,0.0,0.0]), // dom 0, err=1
|
||||
(vec![0.0,0.0,0.0], vec![2.0,0.0,0.0]), // dom 1, err=2
|
||||
];
|
||||
let m = ev.evaluate(&preds, &[0, 1]);
|
||||
assert!((m.in_domain_mpjpe - 1.0).abs() < 1e-6);
|
||||
assert!((m.cross_domain_mpjpe - 2.0).abs() < 1e-6);
|
||||
assert!((m.domain_gap_ratio - 2.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_groups_by_domain() {
|
||||
let ev = CrossDomainEvaluator::new(1);
|
||||
let preds = vec![
|
||||
(vec![0.0,0.0,0.0], vec![1.0,0.0,0.0]),
|
||||
(vec![0.0,0.0,0.0], vec![3.0,0.0,0.0]),
|
||||
(vec![0.0,0.0,0.0], vec![5.0,0.0,0.0]),
|
||||
];
|
||||
let m = ev.evaluate(&preds, &[0, 0, 1]);
|
||||
assert!((m.in_domain_mpjpe - 2.0).abs() < 1e-6);
|
||||
assert!((m.cross_domain_mpjpe - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_gap_perfect() {
|
||||
let ev = CrossDomainEvaluator::new(1);
|
||||
let preds = vec![(vec![1.0,2.0,3.0], vec![1.0,2.0,3.0]), (vec![4.0,5.0,6.0], vec![4.0,5.0,6.0])];
|
||||
assert!((ev.evaluate(&preds, &[0, 1]).domain_gap_ratio - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_multiple_cross_domains() {
|
||||
let ev = CrossDomainEvaluator::new(1);
|
||||
let preds = vec![
|
||||
(vec![0.0,0.0,0.0], vec![1.0,0.0,0.0]),
|
||||
(vec![0.0,0.0,0.0], vec![4.0,0.0,0.0]),
|
||||
(vec![0.0,0.0,0.0], vec![6.0,0.0,0.0]),
|
||||
];
|
||||
let m = ev.evaluate(&preds, &[0, 1, 3]);
|
||||
assert!((m.in_domain_mpjpe - 1.0).abs() < 1e-6);
|
||||
assert!((m.cross_domain_mpjpe - 5.0).abs() < 1e-6);
|
||||
assert!((m.cross_hardware_mpjpe - 6.0).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,364 @@
|
||||
//! MERIDIAN Phase 3 -- Geometry Encoder with FiLM Conditioning (ADR-027).
|
||||
//!
|
||||
//! Permutation-invariant encoding of AP positions into a 64-dim geometry
|
||||
//! vector, plus FiLM layers for conditioning backbone features on room
|
||||
//! geometry. Pure Rust, no external dependencies beyond the workspace.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const GEOMETRY_DIM: usize = 64;
|
||||
const NUM_COORDS: usize = 3;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Linear layer (pure Rust)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fully-connected layer: `y = x W^T + b`. Row-major weights `[out, in]`.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Linear {
|
||||
weights: Vec<f32>,
|
||||
bias: Vec<f32>,
|
||||
in_f: usize,
|
||||
out_f: usize,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
/// Kaiming-uniform init: U(-k, k), k = sqrt(1/in_f).
|
||||
fn new(in_f: usize, out_f: usize, seed: u64) -> Self {
|
||||
let k = (1.0 / in_f as f32).sqrt();
|
||||
Linear {
|
||||
weights: det_uniform(in_f * out_f, -k, k, seed),
|
||||
bias: vec![0.0; out_f],
|
||||
in_f,
|
||||
out_f,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &[f32]) -> Vec<f32> {
|
||||
debug_assert_eq!(x.len(), self.in_f);
|
||||
let mut y = self.bias.clone();
|
||||
for j in 0..self.out_f {
|
||||
let off = j * self.in_f;
|
||||
let mut s = 0.0f32;
|
||||
for i in 0..self.in_f {
|
||||
s += x[i] * self.weights[off + i];
|
||||
}
|
||||
y[j] += s;
|
||||
}
|
||||
y
|
||||
}
|
||||
}
|
||||
|
||||
/// Deterministic xorshift64 uniform in `[lo, hi)`.
|
||||
fn det_uniform(n: usize, lo: f32, hi: f32, seed: u64) -> Vec<f32> {
|
||||
let r = hi - lo;
|
||||
let mut s = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
s ^= s << 13;
|
||||
s ^= s >> 7;
|
||||
s ^= s << 17;
|
||||
lo + (s >> 11) as f32 / (1u64 << 53) as f32 * r
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn relu(v: &mut [f32]) {
|
||||
for x in v.iter_mut() {
|
||||
if *x < 0.0 { *x = 0.0; }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MeridianGeometryConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for the MERIDIAN geometry encoder and FiLM layers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MeridianGeometryConfig {
|
||||
/// Number of Fourier frequency bands (default 10).
|
||||
pub n_frequencies: usize,
|
||||
/// Spatial scale factor, 1.0 = metres (default 1.0).
|
||||
pub scale: f32,
|
||||
/// Output embedding dimension (default 64).
|
||||
pub geometry_dim: usize,
|
||||
/// Random seed for weight init (default 42).
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
impl Default for MeridianGeometryConfig {
|
||||
fn default() -> Self {
|
||||
MeridianGeometryConfig { n_frequencies: 10, scale: 1.0, geometry_dim: GEOMETRY_DIM, seed: 42 }
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FourierPositionalEncoding
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Fourier positional encoding for 3-D coordinates.
|
||||
///
|
||||
/// Per coordinate: `[sin(2^0*pi*x), cos(2^0*pi*x), ..., sin(2^(L-1)*pi*x),
|
||||
/// cos(2^(L-1)*pi*x)]`. Zero-padded to `geometry_dim`.
|
||||
pub struct FourierPositionalEncoding {
|
||||
n_frequencies: usize,
|
||||
scale: f32,
|
||||
output_dim: usize,
|
||||
}
|
||||
|
||||
impl FourierPositionalEncoding {
|
||||
/// Create from config.
|
||||
pub fn new(cfg: &MeridianGeometryConfig) -> Self {
|
||||
FourierPositionalEncoding { n_frequencies: cfg.n_frequencies, scale: cfg.scale, output_dim: cfg.geometry_dim }
|
||||
}
|
||||
|
||||
/// Encode `[x, y, z]` into a fixed-length vector of `geometry_dim` elements.
|
||||
pub fn encode(&self, coords: &[f32; 3]) -> Vec<f32> {
|
||||
let raw = NUM_COORDS * 2 * self.n_frequencies;
|
||||
let mut enc = Vec::with_capacity(raw.max(self.output_dim));
|
||||
for &c in coords {
|
||||
let sc = c * self.scale;
|
||||
for l in 0..self.n_frequencies {
|
||||
let f = (2.0f32).powi(l as i32) * std::f32::consts::PI * sc;
|
||||
enc.push(f.sin());
|
||||
enc.push(f.cos());
|
||||
}
|
||||
}
|
||||
enc.resize(self.output_dim, 0.0);
|
||||
enc
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DeepSets
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Permutation-invariant set encoder: phi each element, mean-pool, then rho.
|
||||
pub struct DeepSets {
|
||||
phi: Linear,
|
||||
rho: Linear,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl DeepSets {
|
||||
/// Create from config.
|
||||
pub fn new(cfg: &MeridianGeometryConfig) -> Self {
|
||||
let d = cfg.geometry_dim;
|
||||
DeepSets { phi: Linear::new(d, d, cfg.seed.wrapping_add(1)), rho: Linear::new(d, d, cfg.seed.wrapping_add(2)), dim: d }
|
||||
}
|
||||
|
||||
/// Encode a set of embeddings (each of length `geometry_dim`) into one vector.
|
||||
pub fn encode(&self, ap_embeddings: &[Vec<f32>]) -> Vec<f32> {
|
||||
assert!(!ap_embeddings.is_empty(), "DeepSets: input set must be non-empty");
|
||||
let n = ap_embeddings.len() as f32;
|
||||
let mut pooled = vec![0.0f32; self.dim];
|
||||
for emb in ap_embeddings {
|
||||
debug_assert_eq!(emb.len(), self.dim);
|
||||
let mut t = self.phi.forward(emb);
|
||||
relu(&mut t);
|
||||
for (p, v) in pooled.iter_mut().zip(t.iter()) { *p += *v; }
|
||||
}
|
||||
for p in pooled.iter_mut() { *p /= n; }
|
||||
let mut out = self.rho.forward(&pooled);
|
||||
relu(&mut out);
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GeometryEncoder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// End-to-end encoder: AP positions -> 64-dim geometry vector.
|
||||
pub struct GeometryEncoder {
|
||||
pos_embed: FourierPositionalEncoding,
|
||||
set_encoder: DeepSets,
|
||||
}
|
||||
|
||||
impl GeometryEncoder {
|
||||
/// Build from config.
|
||||
pub fn new(cfg: &MeridianGeometryConfig) -> Self {
|
||||
GeometryEncoder { pos_embed: FourierPositionalEncoding::new(cfg), set_encoder: DeepSets::new(cfg) }
|
||||
}
|
||||
|
||||
/// Encode variable-count AP positions `[x,y,z]` into a fixed-dim vector.
|
||||
pub fn encode(&self, ap_positions: &[[f32; 3]]) -> Vec<f32> {
|
||||
let embs: Vec<Vec<f32>> = ap_positions.iter().map(|p| self.pos_embed.encode(p)).collect();
|
||||
self.set_encoder.encode(&embs)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FilmLayer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Feature-wise Linear Modulation: `output = gamma(g) * h + beta(g)`.
|
||||
pub struct FilmLayer {
|
||||
gamma_proj: Linear,
|
||||
beta_proj: Linear,
|
||||
}
|
||||
|
||||
impl FilmLayer {
|
||||
/// Create a FiLM layer. Gamma bias is initialised to 1.0 (identity).
|
||||
pub fn new(cfg: &MeridianGeometryConfig) -> Self {
|
||||
let d = cfg.geometry_dim;
|
||||
let mut gamma_proj = Linear::new(d, d, cfg.seed.wrapping_add(3));
|
||||
for b in gamma_proj.bias.iter_mut() { *b = 1.0; }
|
||||
FilmLayer { gamma_proj, beta_proj: Linear::new(d, d, cfg.seed.wrapping_add(4)) }
|
||||
}
|
||||
|
||||
/// Modulate `features` by `geometry`: `gamma(geometry) * features + beta(geometry)`.
|
||||
pub fn modulate(&self, features: &[f32], geometry: &[f32]) -> Vec<f32> {
|
||||
let gamma = self.gamma_proj.forward(geometry);
|
||||
let beta = self.beta_proj.forward(geometry);
|
||||
features.iter().zip(gamma.iter()).zip(beta.iter()).map(|((&f, &g), &b)| g * f + b).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn cfg() -> MeridianGeometryConfig { MeridianGeometryConfig::default() }
|
||||
|
||||
#[test]
|
||||
fn fourier_output_dimension_is_64() {
|
||||
let c = cfg();
|
||||
let out = FourierPositionalEncoding::new(&c).encode(&[1.0, 2.0, 3.0]);
|
||||
assert_eq!(out.len(), c.geometry_dim);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fourier_different_coords_different_outputs() {
|
||||
let enc = FourierPositionalEncoding::new(&cfg());
|
||||
let a = enc.encode(&[0.0, 0.0, 0.0]);
|
||||
let b = enc.encode(&[1.0, 0.0, 0.0]);
|
||||
let c = enc.encode(&[0.0, 1.0, 0.0]);
|
||||
let d = enc.encode(&[0.0, 0.0, 1.0]);
|
||||
assert_ne!(a, b); assert_ne!(a, c); assert_ne!(a, d); assert_ne!(b, c);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fourier_values_bounded() {
|
||||
let out = FourierPositionalEncoding::new(&cfg()).encode(&[5.5, -3.2, 0.1]);
|
||||
for &v in &out { assert!(v.abs() <= 1.0 + 1e-6, "got {v}"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deepsets_permutation_invariant() {
|
||||
let c = cfg();
|
||||
let enc = FourierPositionalEncoding::new(&c);
|
||||
let ds = DeepSets::new(&c);
|
||||
let (a, b, d) = (enc.encode(&[1.0,0.0,0.0]), enc.encode(&[0.0,2.0,0.0]), enc.encode(&[0.0,0.0,3.0]));
|
||||
let abc = ds.encode(&[a.clone(), b.clone(), d.clone()]);
|
||||
let cba = ds.encode(&[d.clone(), b.clone(), a.clone()]);
|
||||
let bac = ds.encode(&[b.clone(), a.clone(), d.clone()]);
|
||||
for i in 0..c.geometry_dim {
|
||||
assert!((abc[i] - cba[i]).abs() < 1e-5, "dim {i}: abc={} cba={}", abc[i], cba[i]);
|
||||
assert!((abc[i] - bac[i]).abs() < 1e-5, "dim {i}: abc={} bac={}", abc[i], bac[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deepsets_variable_ap_count() {
|
||||
let c = cfg();
|
||||
let enc = FourierPositionalEncoding::new(&c);
|
||||
let ds = DeepSets::new(&c);
|
||||
let one = ds.encode(&[enc.encode(&[1.0,0.0,0.0])]);
|
||||
assert_eq!(one.len(), c.geometry_dim);
|
||||
let three = ds.encode(&[enc.encode(&[1.0,0.0,0.0]), enc.encode(&[0.0,2.0,0.0]), enc.encode(&[0.0,0.0,3.0])]);
|
||||
assert_eq!(three.len(), c.geometry_dim);
|
||||
let six = ds.encode(&[
|
||||
enc.encode(&[1.0,0.0,0.0]), enc.encode(&[0.0,2.0,0.0]), enc.encode(&[0.0,0.0,3.0]),
|
||||
enc.encode(&[-1.0,0.0,0.0]), enc.encode(&[0.0,-2.0,0.0]), enc.encode(&[0.0,0.0,-3.0]),
|
||||
]);
|
||||
assert_eq!(six.len(), c.geometry_dim);
|
||||
assert_ne!(one, three); assert_ne!(three, six);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometry_encoder_end_to_end() {
|
||||
let c = cfg();
|
||||
let g = GeometryEncoder::new(&c).encode(&[[1.0,0.0,2.5],[0.0,3.0,2.5],[-2.0,1.0,2.5]]);
|
||||
assert_eq!(g.len(), c.geometry_dim);
|
||||
for &v in &g { assert!(v.is_finite()); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geometry_encoder_single_ap() {
|
||||
let c = cfg();
|
||||
assert_eq!(GeometryEncoder::new(&c).encode(&[[0.0,0.0,0.0]]).len(), c.geometry_dim);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn film_identity_when_geometry_zero() {
|
||||
let c = cfg();
|
||||
let film = FilmLayer::new(&c);
|
||||
let feat = vec![1.0f32; c.geometry_dim];
|
||||
let out = film.modulate(&feat, &vec![0.0f32; c.geometry_dim]);
|
||||
assert_eq!(out.len(), c.geometry_dim);
|
||||
// gamma_proj(0) = bias = [1.0], beta_proj(0) = bias = [0.0] => identity
|
||||
for i in 0..c.geometry_dim {
|
||||
assert!((out[i] - feat[i]).abs() < 1e-5, "dim {i}: expected {}, got {}", feat[i], out[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn film_nontrivial_modulation() {
|
||||
let c = cfg();
|
||||
let film = FilmLayer::new(&c);
|
||||
let feat: Vec<f32> = (0..c.geometry_dim).map(|i| i as f32 * 0.1).collect();
|
||||
let geom: Vec<f32> = (0..c.geometry_dim).map(|i| (i as f32 - 32.0) * 0.01).collect();
|
||||
let out = film.modulate(&feat, &geom);
|
||||
assert_eq!(out.len(), c.geometry_dim);
|
||||
assert!(out.iter().zip(feat.iter()).any(|(o, f)| (o - f).abs() > 1e-6));
|
||||
for &v in &out { assert!(v.is_finite()); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn film_explicit_gamma_beta() {
|
||||
let c = MeridianGeometryConfig { geometry_dim: 4, ..cfg() };
|
||||
let mut film = FilmLayer::new(&c);
|
||||
film.gamma_proj.weights = vec![0.0; 16];
|
||||
film.gamma_proj.bias = vec![2.0, 3.0, 0.5, 1.0];
|
||||
film.beta_proj.weights = vec![0.0; 16];
|
||||
film.beta_proj.bias = vec![10.0, 20.0, 30.0, 40.0];
|
||||
let out = film.modulate(&[1.0, 2.0, 3.0, 4.0], &[999.0; 4]);
|
||||
let exp = [12.0, 26.0, 31.5, 44.0];
|
||||
for i in 0..4 { assert!((out[i] - exp[i]).abs() < 1e-5, "dim {i}"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_defaults() {
|
||||
let c = MeridianGeometryConfig::default();
|
||||
assert_eq!(c.n_frequencies, 10);
|
||||
assert!((c.scale - 1.0).abs() < 1e-6);
|
||||
assert_eq!(c.geometry_dim, 64);
|
||||
assert_eq!(c.seed, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_serde_round_trip() {
|
||||
let c = MeridianGeometryConfig { n_frequencies: 8, scale: 0.5, geometry_dim: 32, seed: 123 };
|
||||
let j = serde_json::to_string(&c).unwrap();
|
||||
let d: MeridianGeometryConfig = serde_json::from_str(&j).unwrap();
|
||||
assert_eq!(d.n_frequencies, 8); assert!((d.scale - 0.5).abs() < 1e-6);
|
||||
assert_eq!(d.geometry_dim, 32); assert_eq!(d.seed, 123);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_forward_dim() {
|
||||
assert_eq!(Linear::new(8, 4, 0).forward(&vec![1.0; 8]).len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_zero_input_gives_bias() {
|
||||
let lin = Linear::new(4, 3, 0);
|
||||
let out = lin.forward(&[0.0; 4]);
|
||||
for i in 0..3 { assert!((out[i] - lin.bias[i]).abs() < 1e-6); }
|
||||
}
|
||||
}
|
||||
@@ -45,8 +45,13 @@
|
||||
|
||||
pub mod config;
|
||||
pub mod dataset;
|
||||
pub mod domain;
|
||||
pub mod error;
|
||||
pub mod eval;
|
||||
pub mod geometry;
|
||||
pub mod rapid_adapt;
|
||||
pub mod subcarrier;
|
||||
pub mod virtual_aug;
|
||||
|
||||
// The following modules use `tch` (PyTorch Rust bindings) for GPU-accelerated
|
||||
// training and are only compiled when the `tch-backend` feature is enabled.
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
//! Few-shot rapid adaptation (MERIDIAN Phase 5).
|
||||
//!
|
||||
//! Test-time training with contrastive learning and entropy minimization on
|
||||
//! unlabeled CSI frames. Produces LoRA weight deltas for new environments.
|
||||
|
||||
/// Loss function(s) for test-time adaptation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AdaptationLoss {
|
||||
/// Contrastive TTT: positive = temporally adjacent, negative = random.
|
||||
ContrastiveTTT { /// Gradient-descent epochs.
|
||||
epochs: usize, /// Learning rate.
|
||||
lr: f32 },
|
||||
/// Minimize entropy of confidence outputs for sharper predictions.
|
||||
EntropyMin { /// Gradient-descent epochs.
|
||||
epochs: usize, /// Learning rate.
|
||||
lr: f32 },
|
||||
/// Both contrastive and entropy losses combined.
|
||||
Combined { /// Gradient-descent epochs.
|
||||
epochs: usize, /// Learning rate.
|
||||
lr: f32, /// Weight for entropy term.
|
||||
lambda_ent: f32 },
|
||||
}
|
||||
|
||||
impl AdaptationLoss {
|
||||
/// Number of epochs for this variant.
|
||||
pub fn epochs(&self) -> usize {
|
||||
match self { Self::ContrastiveTTT { epochs, .. }
|
||||
| Self::EntropyMin { epochs, .. }
|
||||
| Self::Combined { epochs, .. } => *epochs }
|
||||
}
|
||||
/// Learning rate for this variant.
|
||||
pub fn lr(&self) -> f32 {
|
||||
match self { Self::ContrastiveTTT { lr, .. }
|
||||
| Self::EntropyMin { lr, .. }
|
||||
| Self::Combined { lr, .. } => *lr }
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of [`RapidAdaptation::adapt`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptationResult {
|
||||
/// LoRA weight deltas.
|
||||
pub lora_weights: Vec<f32>,
|
||||
/// Final epoch loss.
|
||||
pub final_loss: f32,
|
||||
/// Calibration frames consumed.
|
||||
pub frames_used: usize,
|
||||
/// Epochs executed.
|
||||
pub adaptation_epochs: usize,
|
||||
}
|
||||
|
||||
/// Few-shot rapid adaptation engine.
|
||||
///
|
||||
/// Accumulates unlabeled CSI calibration frames and runs test-time training
|
||||
/// to produce LoRA weight deltas.
|
||||
///
|
||||
/// ```rust
|
||||
/// use wifi_densepose_train::rapid_adapt::{RapidAdaptation, AdaptationLoss};
|
||||
/// let loss = AdaptationLoss::Combined { epochs: 5, lr: 0.001, lambda_ent: 0.5 };
|
||||
/// let mut ra = RapidAdaptation::new(10, 4, loss);
|
||||
/// for i in 0..10 { ra.push_frame(&vec![i as f32; 8]); }
|
||||
/// assert!(ra.is_ready());
|
||||
/// let r = ra.adapt();
|
||||
/// assert_eq!(r.frames_used, 10);
|
||||
/// ```
|
||||
pub struct RapidAdaptation {
|
||||
/// Minimum frames before adaptation (default 200 = 10 s @ 20 Hz).
|
||||
pub min_calibration_frames: usize,
|
||||
/// LoRA factorization rank (default 4).
|
||||
pub lora_rank: usize,
|
||||
/// Loss variant for test-time training.
|
||||
pub adaptation_loss: AdaptationLoss,
|
||||
calibration_buffer: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
impl RapidAdaptation {
|
||||
/// Create a new adaptation engine.
|
||||
pub fn new(min_calibration_frames: usize, lora_rank: usize, adaptation_loss: AdaptationLoss) -> Self {
|
||||
Self { min_calibration_frames, lora_rank, adaptation_loss, calibration_buffer: Vec::new() }
|
||||
}
|
||||
/// Push a single unlabeled CSI frame.
|
||||
pub fn push_frame(&mut self, frame: &[f32]) { self.calibration_buffer.push(frame.to_vec()); }
|
||||
/// True when buffer >= min_calibration_frames.
|
||||
pub fn is_ready(&self) -> bool { self.calibration_buffer.len() >= self.min_calibration_frames }
|
||||
/// Number of buffered frames.
|
||||
pub fn buffer_len(&self) -> usize { self.calibration_buffer.len() }
|
||||
|
||||
/// Run test-time adaptation producing LoRA weight deltas.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the calibration buffer is empty.
|
||||
pub fn adapt(&self) -> AdaptationResult {
|
||||
assert!(!self.calibration_buffer.is_empty(), "empty calibration buffer");
|
||||
let (n, fdim) = (self.calibration_buffer.len(), self.calibration_buffer[0].len());
|
||||
let lora_sz = 2 * fdim * self.lora_rank;
|
||||
let mut w = vec![0.01_f32; lora_sz];
|
||||
let (epochs, lr) = (self.adaptation_loss.epochs(), self.adaptation_loss.lr());
|
||||
let mut final_loss = 0.0_f32;
|
||||
for _ in 0..epochs {
|
||||
let mut g = vec![0.0_f32; lora_sz];
|
||||
let loss = match &self.adaptation_loss {
|
||||
AdaptationLoss::ContrastiveTTT { .. } => self.contrastive_step(&w, fdim, &mut g),
|
||||
AdaptationLoss::EntropyMin { .. } => self.entropy_step(&w, fdim, &mut g),
|
||||
AdaptationLoss::Combined { lambda_ent, .. } => {
|
||||
let cl = self.contrastive_step(&w, fdim, &mut g);
|
||||
let mut eg = vec![0.0_f32; lora_sz];
|
||||
let el = self.entropy_step(&w, fdim, &mut eg);
|
||||
for (gi, egi) in g.iter_mut().zip(eg.iter()) { *gi += lambda_ent * egi; }
|
||||
cl + lambda_ent * el
|
||||
}
|
||||
};
|
||||
for (wi, gi) in w.iter_mut().zip(g.iter()) { *wi -= lr * gi; }
|
||||
final_loss = loss;
|
||||
}
|
||||
AdaptationResult { lora_weights: w, final_loss, frames_used: n, adaptation_epochs: epochs }
|
||||
}
|
||||
|
||||
fn contrastive_step(&self, w: &[f32], fdim: usize, grad: &mut [f32]) -> f32 {
|
||||
let n = self.calibration_buffer.len();
|
||||
if n < 2 { return 0.0; }
|
||||
let (margin, pairs) = (1.0_f32, n - 1);
|
||||
let mut total = 0.0_f32;
|
||||
for i in 0..pairs {
|
||||
let (anc, pos) = (&self.calibration_buffer[i], &self.calibration_buffer[i + 1]);
|
||||
let neg = &self.calibration_buffer[(i + n / 2) % n];
|
||||
let (pa, pp, pn) = (self.project(anc, w, fdim), self.project(pos, w, fdim), self.project(neg, w, fdim));
|
||||
let trip = (l2_dist(&pa, &pp) - l2_dist(&pa, &pn) + margin).max(0.0);
|
||||
total += trip;
|
||||
if trip > 0.0 {
|
||||
for (j, g) in grad.iter_mut().enumerate() {
|
||||
let v = anc.get(j % fdim).copied().unwrap_or(0.0);
|
||||
*g += v * 0.01 / pairs as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
total / pairs as f32
|
||||
}
|
||||
|
||||
fn entropy_step(&self, w: &[f32], fdim: usize, grad: &mut [f32]) -> f32 {
|
||||
let n = self.calibration_buffer.len();
|
||||
if n == 0 { return 0.0; }
|
||||
let nc = self.lora_rank.max(2);
|
||||
let mut total = 0.0_f32;
|
||||
for frame in &self.calibration_buffer {
|
||||
let proj = self.project(frame, w, fdim);
|
||||
let mut logits = vec![0.0_f32; nc];
|
||||
for (i, &v) in proj.iter().enumerate() { logits[i % nc] += v; }
|
||||
let mx = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = logits.iter().map(|&l| (l - mx).exp()).collect();
|
||||
let s: f32 = exps.iter().sum();
|
||||
let ent: f32 = exps.iter().map(|&e| { let p = e / s; if p > 1e-10 { -p * p.ln() } else { 0.0 } }).sum();
|
||||
total += ent;
|
||||
for (j, g) in grad.iter_mut().enumerate() {
|
||||
let v = frame.get(j % frame.len().max(1)).copied().unwrap_or(0.0);
|
||||
*g += v * ent * 0.001 / n as f32;
|
||||
}
|
||||
}
|
||||
total / n as f32
|
||||
}
|
||||
|
||||
fn project(&self, frame: &[f32], w: &[f32], fdim: usize) -> Vec<f32> {
|
||||
let rank = self.lora_rank;
|
||||
let mut hidden = vec![0.0_f32; rank];
|
||||
for r in 0..rank {
|
||||
for d in 0..fdim.min(frame.len()) {
|
||||
let idx = d * rank + r;
|
||||
if idx < w.len() { hidden[r] += w[idx] * frame[d]; }
|
||||
}
|
||||
}
|
||||
let boff = fdim * rank;
|
||||
(0..fdim).map(|d| {
|
||||
let lora: f32 = (0..rank).map(|r| {
|
||||
let idx = boff + r * fdim + d;
|
||||
if idx < w.len() { w[idx] * hidden[r] } else { 0.0 }
|
||||
}).sum();
|
||||
frame.get(d).copied().unwrap_or(0.0) + lora
|
||||
}).collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn l2_dist(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).powi(2)).sum::<f32>().sqrt()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn push_frame_accumulates() {
|
||||
let mut a = RapidAdaptation::new(5, 4, AdaptationLoss::ContrastiveTTT { epochs: 1, lr: 0.01 });
|
||||
assert_eq!(a.buffer_len(), 0);
|
||||
a.push_frame(&[1.0, 2.0]); assert_eq!(a.buffer_len(), 1);
|
||||
a.push_frame(&[3.0, 4.0]); assert_eq!(a.buffer_len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_ready_threshold() {
|
||||
let mut a = RapidAdaptation::new(5, 4, AdaptationLoss::EntropyMin { epochs: 3, lr: 0.001 });
|
||||
for i in 0..4 { a.push_frame(&[i as f32; 8]); assert!(!a.is_ready()); }
|
||||
a.push_frame(&[99.0; 8]); assert!(a.is_ready());
|
||||
a.push_frame(&[100.0; 8]); assert!(a.is_ready());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adapt_lora_weight_dimension() {
|
||||
let (fdim, rank) = (16, 4);
|
||||
let mut a = RapidAdaptation::new(10, rank, AdaptationLoss::ContrastiveTTT { epochs: 3, lr: 0.01 });
|
||||
for i in 0..10 { a.push_frame(&vec![i as f32 * 0.1; fdim]); }
|
||||
let r = a.adapt();
|
||||
assert_eq!(r.lora_weights.len(), 2 * fdim * rank);
|
||||
assert_eq!(r.frames_used, 10);
|
||||
assert_eq!(r.adaptation_epochs, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn contrastive_loss_decreases() {
|
||||
let (fdim, rank) = (32, 4);
|
||||
let mk = |ep| {
|
||||
let mut a = RapidAdaptation::new(20, rank, AdaptationLoss::ContrastiveTTT { epochs: ep, lr: 0.01 });
|
||||
for i in 0..20 { let v = i as f32 * 0.1; a.push_frame(&(0..fdim).map(|d| v + d as f32 * 0.01).collect::<Vec<_>>()); }
|
||||
a.adapt().final_loss
|
||||
};
|
||||
assert!(mk(10) <= mk(1) + 1e-6, "10 epochs should yield <= 1 epoch loss");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combined_loss_adaptation() {
|
||||
let (fdim, rank) = (16, 4);
|
||||
let mut a = RapidAdaptation::new(10, rank, AdaptationLoss::Combined { epochs: 5, lr: 0.001, lambda_ent: 0.5 });
|
||||
for i in 0..10 { a.push_frame(&(0..fdim).map(|d| ((i * fdim + d) as f32).sin()).collect::<Vec<_>>()); }
|
||||
let r = a.adapt();
|
||||
assert_eq!(r.frames_used, 10);
|
||||
assert_eq!(r.adaptation_epochs, 5);
|
||||
assert!(r.final_loss.is_finite());
|
||||
assert_eq!(r.lora_weights.len(), 2 * fdim * rank);
|
||||
assert!(r.lora_weights.iter().all(|w| w.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn l2_distance_tests() {
|
||||
assert!(l2_dist(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]).abs() < 1e-10);
|
||||
assert!((l2_dist(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_accessors() {
|
||||
let c = AdaptationLoss::ContrastiveTTT { epochs: 7, lr: 0.02 };
|
||||
assert_eq!(c.epochs(), 7); assert!((c.lr() - 0.02).abs() < 1e-7);
|
||||
let e = AdaptationLoss::EntropyMin { epochs: 3, lr: 0.1 };
|
||||
assert_eq!(e.epochs(), 3); assert!((e.lr() - 0.1).abs() < 1e-7);
|
||||
let cb = AdaptationLoss::Combined { epochs: 5, lr: 0.001, lambda_ent: 0.3 };
|
||||
assert_eq!(cb.epochs(), 5); assert!((cb.lr() - 0.001).abs() < 1e-7);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
//! Virtual Domain Augmentation for cross-environment generalization (ADR-027 Phase 4).
|
||||
//!
|
||||
//! Generates synthetic "virtual domains" simulating different physical environments
|
||||
//! and applies domain-specific transformations to CSI amplitude frames for the
|
||||
//! MERIDIAN adversarial training loop.
|
||||
//!
|
||||
//! ```rust
|
||||
//! use wifi_densepose_train::virtual_aug::{VirtualDomainAugmentor, Xorshift64};
|
||||
//!
|
||||
//! let mut aug = VirtualDomainAugmentor::default();
|
||||
//! let mut rng = Xorshift64::new(42);
|
||||
//! let frame = vec![0.5_f32; 56];
|
||||
//! let domain = aug.generate_domain(&mut rng);
|
||||
//! let out = aug.augment_frame(&frame, &domain);
|
||||
//! assert_eq!(out.len(), frame.len());
|
||||
//! ```
|
||||
|
||||
use std::f32::consts::PI;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Xorshift64 PRNG (matches dataset.rs pattern)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Lightweight 64-bit Xorshift PRNG for deterministic augmentation.
|
||||
pub struct Xorshift64 {
|
||||
state: u64,
|
||||
}
|
||||
|
||||
impl Xorshift64 {
|
||||
/// Create a new PRNG. Seed `0` is replaced with a fixed non-zero value.
|
||||
pub fn new(seed: u64) -> Self {
|
||||
Self { state: if seed == 0 { 0x853c49e6748fea9b } else { seed } }
|
||||
}
|
||||
|
||||
/// Advance the state and return the next `u64`.
|
||||
#[inline]
|
||||
pub fn next_u64(&mut self) -> u64 {
|
||||
self.state ^= self.state << 13;
|
||||
self.state ^= self.state >> 7;
|
||||
self.state ^= self.state << 17;
|
||||
self.state
|
||||
}
|
||||
|
||||
/// Return a uniformly distributed `f32` in `[0, 1)`.
|
||||
#[inline]
|
||||
pub fn next_f32(&mut self) -> f32 {
|
||||
(self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
|
||||
}
|
||||
|
||||
/// Return a uniformly distributed `f32` in `[lo, hi)`.
|
||||
#[inline]
|
||||
pub fn next_f32_range(&mut self, lo: f32, hi: f32) -> f32 {
|
||||
lo + self.next_f32() * (hi - lo)
|
||||
}
|
||||
|
||||
/// Return a uniformly distributed `usize` in `[lo, hi]` (inclusive).
|
||||
#[inline]
|
||||
pub fn next_usize_range(&mut self, lo: usize, hi: usize) -> usize {
|
||||
if lo >= hi { return lo; }
|
||||
lo + (self.next_u64() % (hi - lo + 1) as u64) as usize
|
||||
}
|
||||
|
||||
/// Sample an approximate Gaussian (mean=0, std=1) via Box-Muller.
|
||||
#[inline]
|
||||
pub fn next_gaussian(&mut self) -> f32 {
|
||||
let u1 = self.next_f32().max(1e-10);
|
||||
let u2 = self.next_f32();
|
||||
(-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// VirtualDomain
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Describes a single synthetic WiFi environment for domain augmentation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VirtualDomain {
|
||||
/// Path-loss factor simulating room size (< 1 smaller, > 1 larger room).
|
||||
pub room_scale: f32,
|
||||
/// Wall reflection coefficient in `[0, 1]` (low = absorptive, high = reflective).
|
||||
pub reflection_coeff: f32,
|
||||
/// Number of virtual scatterers (furniture / obstacles).
|
||||
pub n_scatterers: usize,
|
||||
/// Standard deviation of additive hardware noise.
|
||||
pub noise_std: f32,
|
||||
/// Unique label for the domain classifier in adversarial training.
|
||||
pub domain_id: u32,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// VirtualDomainAugmentor
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Samples virtual WiFi domains and transforms CSI frames to simulate them.
|
||||
///
|
||||
/// Applies four transformations: room-scale amplitude scaling, per-subcarrier
|
||||
/// reflection modulation, virtual scatterer sinusoidal interference, and
|
||||
/// Gaussian noise injection.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VirtualDomainAugmentor {
|
||||
/// Range for room scale factor `(min, max)`.
|
||||
pub room_scale_range: (f32, f32),
|
||||
/// Range for reflection coefficient `(min, max)`.
|
||||
pub reflection_coeff_range: (f32, f32),
|
||||
/// Range for number of virtual scatterers `(min, max)`.
|
||||
pub n_virtual_scatterers: (usize, usize),
|
||||
/// Range for noise standard deviation `(min, max)`.
|
||||
pub noise_std_range: (f32, f32),
|
||||
next_domain_id: u32,
|
||||
}
|
||||
|
||||
impl Default for VirtualDomainAugmentor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
room_scale_range: (0.5, 2.0),
|
||||
reflection_coeff_range: (0.3, 0.9),
|
||||
n_virtual_scatterers: (0, 5),
|
||||
noise_std_range: (0.01, 0.1),
|
||||
next_domain_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VirtualDomainAugmentor {
|
||||
/// Randomly sample a new [`VirtualDomain`] from the configured ranges.
|
||||
pub fn generate_domain(&mut self, rng: &mut Xorshift64) -> VirtualDomain {
|
||||
let id = self.next_domain_id;
|
||||
self.next_domain_id = self.next_domain_id.wrapping_add(1);
|
||||
VirtualDomain {
|
||||
room_scale: rng.next_f32_range(self.room_scale_range.0, self.room_scale_range.1),
|
||||
reflection_coeff: rng.next_f32_range(self.reflection_coeff_range.0, self.reflection_coeff_range.1),
|
||||
n_scatterers: rng.next_usize_range(self.n_virtual_scatterers.0, self.n_virtual_scatterers.1),
|
||||
noise_std: rng.next_f32_range(self.noise_std_range.0, self.noise_std_range.1),
|
||||
domain_id: id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Transform a single CSI amplitude frame to simulate `domain`.
|
||||
///
|
||||
/// Pipeline: (1) scale by `1/room_scale`, (2) per-subcarrier reflection
|
||||
/// modulation, (3) scatterer sinusoidal perturbation, (4) Gaussian noise.
|
||||
pub fn augment_frame(&self, frame: &[f32], domain: &VirtualDomain) -> Vec<f32> {
|
||||
let n = frame.len();
|
||||
let n_f = n as f32;
|
||||
let mut noise_rng = Xorshift64::new(
|
||||
(domain.domain_id as u64).wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1),
|
||||
);
|
||||
let mut out = Vec::with_capacity(n);
|
||||
for (k, &val) in frame.iter().enumerate() {
|
||||
let k_f = k as f32;
|
||||
// 1. Room-scale amplitude attenuation
|
||||
let scaled = val / domain.room_scale;
|
||||
// 2. Reflection coefficient modulation (per-subcarrier)
|
||||
let refl = domain.reflection_coeff
|
||||
+ (1.0 - domain.reflection_coeff) * (PI * k_f / n_f).cos();
|
||||
let modulated = scaled * refl;
|
||||
// 3. Virtual scatterer sinusoidal interference
|
||||
let mut scatter = 0.0_f32;
|
||||
for s in 0..domain.n_scatterers {
|
||||
scatter += 0.05 * (2.0 * PI * (s as f32 + 1.0) * k_f / n_f).sin();
|
||||
}
|
||||
// 4. Additive Gaussian noise
|
||||
out.push(modulated + scatter + noise_rng.next_gaussian() * domain.noise_std);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Augment a batch, producing `k` virtual-domain variants per input frame.
|
||||
///
|
||||
/// Returns `(augmented_frame, domain_id)` pairs; total = `batch.len() * k`.
|
||||
pub fn augment_batch(
|
||||
&mut self, batch: &[Vec<f32>], k: usize, rng: &mut Xorshift64,
|
||||
) -> Vec<(Vec<f32>, u32)> {
|
||||
let mut results = Vec::with_capacity(batch.len() * k);
|
||||
for frame in batch {
|
||||
for _ in 0..k {
|
||||
let domain = self.generate_domain(rng);
|
||||
let augmented = self.augment_frame(frame, &domain);
|
||||
results.push((augmented, domain.domain_id));
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_domain(scale: f32, coeff: f32, scatter: usize, noise: f32, id: u32) -> VirtualDomain {
|
||||
VirtualDomain { room_scale: scale, reflection_coeff: coeff, n_scatterers: scatter, noise_std: noise, domain_id: id }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_within_configured_ranges() {
|
||||
let mut aug = VirtualDomainAugmentor::default();
|
||||
let mut rng = Xorshift64::new(12345);
|
||||
for _ in 0..100 {
|
||||
let d = aug.generate_domain(&mut rng);
|
||||
assert!(d.room_scale >= 0.5 && d.room_scale <= 2.0);
|
||||
assert!(d.reflection_coeff >= 0.3 && d.reflection_coeff <= 0.9);
|
||||
assert!(d.n_scatterers <= 5);
|
||||
assert!(d.noise_std >= 0.01 && d.noise_std <= 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn augment_frame_preserves_length() {
|
||||
let aug = VirtualDomainAugmentor::default();
|
||||
let out = aug.augment_frame(&vec![0.5; 56], &make_domain(1.0, 0.5, 3, 0.05, 0));
|
||||
assert_eq!(out.len(), 56);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn augment_frame_identity_domain_approx_input() {
|
||||
let aug = VirtualDomainAugmentor::default();
|
||||
let frame: Vec<f32> = (0..56).map(|i| 0.3 + 0.01 * i as f32).collect();
|
||||
let out = aug.augment_frame(&frame, &make_domain(1.0, 1.0, 0, 0.0, 0));
|
||||
for (a, b) in out.iter().zip(frame.iter()) {
|
||||
assert!((a - b).abs() < 1e-5, "identity domain: got {a}, expected {b}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn augment_batch_produces_correct_count() {
|
||||
let mut aug = VirtualDomainAugmentor::default();
|
||||
let mut rng = Xorshift64::new(99);
|
||||
let batch: Vec<Vec<f32>> = (0..4).map(|_| vec![0.5; 56]).collect();
|
||||
let results = aug.augment_batch(&batch, 3, &mut rng);
|
||||
assert_eq!(results.len(), 12);
|
||||
for (f, _) in &results { assert_eq!(f.len(), 56); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_seeds_produce_different_augmentations() {
|
||||
let mut aug1 = VirtualDomainAugmentor::default();
|
||||
let mut aug2 = VirtualDomainAugmentor::default();
|
||||
let frame = vec![0.5_f32; 56];
|
||||
let d1 = aug1.generate_domain(&mut Xorshift64::new(1));
|
||||
let d2 = aug2.generate_domain(&mut Xorshift64::new(2));
|
||||
let out1 = aug1.augment_frame(&frame, &d1);
|
||||
let out2 = aug2.augment_frame(&frame, &d2);
|
||||
assert!(out1.iter().zip(out2.iter()).any(|(a, b)| (a - b).abs() > 1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deterministic_same_seed_same_output() {
|
||||
let batch: Vec<Vec<f32>> = (0..3).map(|i| vec![0.1 * i as f32; 56]).collect();
|
||||
let mut aug1 = VirtualDomainAugmentor::default();
|
||||
let mut aug2 = VirtualDomainAugmentor::default();
|
||||
let res1 = aug1.augment_batch(&batch, 2, &mut Xorshift64::new(42));
|
||||
let res2 = aug2.augment_batch(&batch, 2, &mut Xorshift64::new(42));
|
||||
assert_eq!(res1.len(), res2.len());
|
||||
for ((f1, id1), (f2, id2)) in res1.iter().zip(res2.iter()) {
|
||||
assert_eq!(id1, id2);
|
||||
for (a, b) in f1.iter().zip(f2.iter()) {
|
||||
assert!((a - b).abs() < 1e-7, "same seed must produce identical output");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_ids_are_sequential() {
|
||||
let mut aug = VirtualDomainAugmentor::default();
|
||||
let mut rng = Xorshift64::new(7);
|
||||
for i in 0..10_u32 { assert_eq!(aug.generate_domain(&mut rng).domain_id, i); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xorshift64_deterministic() {
|
||||
let mut a = Xorshift64::new(999);
|
||||
let mut b = Xorshift64::new(999);
|
||||
for _ in 0..100 { assert_eq!(a.next_u64(), b.next_u64()); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xorshift64_f32_in_unit_interval() {
|
||||
let mut rng = Xorshift64::new(42);
|
||||
for _ in 0..1000 {
|
||||
let v = rng.next_f32();
|
||||
assert!(v >= 0.0 && v < 1.0, "f32 sample {v} not in [0, 1)");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn augment_frame_empty_and_batch_k_zero() {
|
||||
let aug = VirtualDomainAugmentor::default();
|
||||
assert!(aug.augment_frame(&[], &make_domain(1.5, 0.5, 2, 0.05, 0)).is_empty());
|
||||
let mut aug2 = VirtualDomainAugmentor::default();
|
||||
assert!(aug2.augment_batch(&[vec![0.5; 56]], 0, &mut Xorshift64::new(1)).is_empty());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user