diff --git a/rust-port/wifi-densepose-rs/Cargo.lock b/rust-port/wifi-densepose-rs/Cargo.lock index a6e2a2e..80b0c34 100644 --- a/rust-port/wifi-densepose-rs/Cargo.lock +++ b/rust-port/wifi-densepose-rs/Cargo.lock @@ -4115,6 +4115,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "wifi-densepose-wifiscan", ] [[package]] @@ -4176,6 +4177,15 @@ dependencies = [ "wifi-densepose-signal", ] +[[package]] +name = "wifi-densepose-vitals" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "tracing", +] + [[package]] name = "wifi-densepose-wasm" version = "0.1.0" @@ -4203,6 +4213,7 @@ name = "wifi-densepose-wifiscan" version = "0.1.0" dependencies = [ "serde", + "tokio", "tracing", ] diff --git a/rust-port/wifi-densepose-rs/Cargo.toml b/rust-port/wifi-densepose-rs/Cargo.toml index 9505874..2c0e448 100644 --- a/rust-port/wifi-densepose-rs/Cargo.toml +++ b/rust-port/wifi-densepose-rs/Cargo.toml @@ -14,6 +14,7 @@ members = [ "crates/wifi-densepose-train", "crates/wifi-densepose-sensing-server", "crates/wifi-densepose-wifiscan", + "crates/wifi-densepose-vitals", ] [workspace.package] diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml index f75ba17..64539f9 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/Cargo.toml @@ -34,5 +34,8 @@ chrono = { version = "0.4", features = ["serde"] } # CLI clap = { workspace = true } +# Multi-BSSID WiFi scanning pipeline (ADR-022 Phase 3) +wifi-densepose-wifiscan = { path = "../wifi-densepose-wifiscan" } + [dev-dependencies] tempfile = "3.10" diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/dataset.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/dataset.rs new file mode 100644 index 0000000..93cf9bf --- /dev/null +++ 
b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/dataset.rs @@ -0,0 +1,850 @@ +//! Dataset loaders for WiFi-to-DensePose training pipeline (ADR-023 Phase 1). +//! +//! Provides unified data loading for MM-Fi (NeurIPS 2023) and Wi-Pose datasets, +//! with from-scratch .npy/.mat v5 parsers, subcarrier resampling, and a unified +//! `DataPipeline` for normalized, windowed training samples. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; +use std::io; +use std::path::{Path, PathBuf}; + +// ── Error type ─────────────────────────────────────────────────────────────── + +#[derive(Debug)] +pub enum DatasetError { + Io(io::Error), + Format(String), + Missing(String), + Shape(String), +} + +impl fmt::Display for DatasetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(e) => write!(f, "I/O error: {e}"), + Self::Format(s) => write!(f, "format error: {s}"), + Self::Missing(s) => write!(f, "missing: {s}"), + Self::Shape(s) => write!(f, "shape error: {s}"), + } + } +} + +impl std::error::Error for DatasetError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + if let Self::Io(e) = self { Some(e) } else { None } + } +} + +impl From for DatasetError { + fn from(e: io::Error) -> Self { Self::Io(e) } +} + +pub type Result = std::result::Result; + +// ── NpyArray ───────────────────────────────────────────────────────────────── + +/// Dense array from .npy: flat f32 data with shape metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NpyArray { + pub shape: Vec, + pub data: Vec, +} + +impl NpyArray { + pub fn len(&self) -> usize { self.data.len() } + pub fn is_empty(&self) -> bool { self.data.is_empty() } + pub fn ndim(&self) -> usize { self.shape.len() } +} + +// ── NpyReader ──────────────────────────────────────────────────────────────── + +/// Minimal NumPy .npy format reader (f32/f64, v1/v2). 
+pub struct NpyReader; + +impl NpyReader { + pub fn read_file(path: &Path) -> Result { + Self::parse(&std::fs::read(path)?) + } + + pub fn parse(buf: &[u8]) -> Result { + if buf.len() < 10 { return Err(DatasetError::Format("file too small for .npy".into())); } + if &buf[0..6] != b"\x93NUMPY" { + return Err(DatasetError::Format("missing .npy magic".into())); + } + let major = buf[6]; + let (header_len, header_start) = match major { + 1 => (u16::from_le_bytes([buf[8], buf[9]]) as usize, 10usize), + 2 | 3 => { + if buf.len() < 12 { return Err(DatasetError::Format("truncated v2 header".into())); } + (u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]) as usize, 12) + } + _ => return Err(DatasetError::Format(format!("unsupported .npy version {major}"))), + }; + let header_end = header_start + header_len; + if header_end > buf.len() { return Err(DatasetError::Format("header past EOF".into())); } + let hdr = std::str::from_utf8(&buf[header_start..header_end]) + .map_err(|_| DatasetError::Format("non-UTF8 header".into()))?; + + let dtype = Self::extract_field(hdr, "descr")?; + let is_f64 = dtype.contains("f8") || dtype.contains("float64"); + let is_f32 = dtype.contains("f4") || dtype.contains("float32"); + let is_big = dtype.starts_with('>'); + if !is_f32 && !is_f64 { + return Err(DatasetError::Format(format!("unsupported dtype '{dtype}'"))); + } + let fortran = Self::extract_field(hdr, "fortran_order") + .unwrap_or_else(|_| "False".into()).contains("True"); + let shape = Self::parse_shape(hdr)?; + let elem_sz: usize = if is_f64 { 8 } else { 4 }; + let total: usize = shape.iter().product::().max(1); + if header_end + total * elem_sz > buf.len() { + return Err(DatasetError::Format("data truncated".into())); + } + let raw = &buf[header_end..header_end + total * elem_sz]; + let mut data: Vec = if is_f64 { + raw.chunks_exact(8).map(|c| { + let v = if is_big { f64::from_be_bytes(c.try_into().unwrap()) } + else { f64::from_le_bytes(c.try_into().unwrap()) }; + v as f32 + 
}).collect() + } else { + raw.chunks_exact(4).map(|c| { + if is_big { f32::from_be_bytes(c.try_into().unwrap()) } + else { f32::from_le_bytes(c.try_into().unwrap()) } + }).collect() + }; + if fortran && shape.len() == 2 { + let (r, c) = (shape[0], shape[1]); + let mut cd = vec![0.0f32; data.len()]; + for ri in 0..r { for ci in 0..c { cd[ri*c+ci] = data[ci*r+ri]; } } + data = cd; + } + let shape = if shape.is_empty() { vec![1] } else { shape }; + Ok(NpyArray { shape, data }) + } + + fn extract_field(hdr: &str, field: &str) -> Result { + for pat in &[format!("'{field}': "), format!("'{field}':"), format!("\"{field}\": ")] { + if let Some(s) = hdr.find(pat.as_str()) { + let rest = &hdr[s + pat.len()..]; + let end = rest.find(',').or_else(|| rest.find('}')).unwrap_or(rest.len()); + return Ok(rest[..end].trim().trim_matches('\'').trim_matches('"').into()); + } + } + Err(DatasetError::Format(format!("field '{field}' not found"))) + } + + fn parse_shape(hdr: &str) -> Result> { + let si = hdr.find("'shape'").or_else(|| hdr.find("\"shape\"")) + .ok_or_else(|| DatasetError::Format("no 'shape'".into()))?; + let rest = &hdr[si..]; + let ps = rest.find('(').ok_or_else(|| DatasetError::Format("no '('".into()))?; + let pe = rest[ps..].find(')').ok_or_else(|| DatasetError::Format("no ')'".into()))?; + let inner = rest[ps+1..ps+pe].trim(); + if inner.is_empty() { return Ok(vec![]); } + inner.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()) + .map(|s| s.parse::().map_err(|_| DatasetError::Format(format!("bad dim: '{s}'")))) + .collect() + } +} + +// ── MatReader ──────────────────────────────────────────────────────────────── + +/// Minimal MATLAB .mat v5 reader for numeric arrays. 
+pub struct MatReader; + +const MI_INT8: u32 = 1; +#[allow(dead_code)] const MI_UINT8: u32 = 2; +#[allow(dead_code)] const MI_INT16: u32 = 3; +#[allow(dead_code)] const MI_UINT16: u32 = 4; +const MI_INT32: u32 = 5; +const MI_UINT32: u32 = 6; +const MI_SINGLE: u32 = 7; +const MI_DOUBLE: u32 = 9; +const MI_MATRIX: u32 = 14; + +impl MatReader { + pub fn read_file(path: &Path) -> Result> { + Self::parse(&std::fs::read(path)?) + } + + pub fn parse(buf: &[u8]) -> Result> { + if buf.len() < 128 { return Err(DatasetError::Format("too small for .mat v5".into())); } + let swap = u16::from_le_bytes([buf[126], buf[127]]) == 0x4D49; + let mut result = HashMap::new(); + let mut off = 128; + while off + 8 <= buf.len() { + let (dt, ds, ts) = Self::read_tag(buf, off, swap)?; + let el_start = off + ts; + let el_end = el_start + ds; + if el_end > buf.len() { break; } + if dt == MI_MATRIX { + if let Ok((n, a)) = Self::parse_matrix(&buf[el_start..el_end], swap) { + result.insert(n, a); + } + } + off = (el_end + 7) & !7; + } + Ok(result) + } + + fn read_tag(buf: &[u8], off: usize, swap: bool) -> Result<(u32, usize, usize)> { + if off + 4 > buf.len() { return Err(DatasetError::Format("truncated tag".into())); } + let raw = Self::u32(buf, off, swap); + let upper = (raw >> 16) & 0xFFFF; + if upper != 0 && upper <= 4 { return Ok((raw & 0xFFFF, upper as usize, 4)); } + if off + 8 > buf.len() { return Err(DatasetError::Format("truncated tag".into())); } + Ok((raw, Self::u32(buf, off + 4, swap) as usize, 8)) + } + + fn parse_matrix(buf: &[u8], swap: bool) -> Result<(String, NpyArray)> { + let (mut name, mut shape, mut data) = (String::new(), Vec::new(), Vec::new()); + let mut off = 0; + while off + 4 <= buf.len() { + let (st, ss, ts) = Self::read_tag(buf, off, swap)?; + let ss_start = off + ts; + let ss_end = (ss_start + ss).min(buf.len()); + match st { + MI_UINT32 if shape.is_empty() && ss == 8 => {} + MI_INT32 if shape.is_empty() => { + for i in 0..ss / 4 { shape.push(Self::i32(buf, ss_start 
+ i*4, swap) as usize); } + } + MI_INT8 if name.is_empty() && ss_end <= buf.len() => { + name = String::from_utf8_lossy(&buf[ss_start..ss_end]) + .trim_end_matches('\0').to_string(); + } + MI_DOUBLE => { + for i in 0..ss / 8 { + let p = ss_start + i * 8; + if p + 8 <= buf.len() { data.push(Self::f64(buf, p, swap) as f32); } + } + } + MI_SINGLE => { + for i in 0..ss / 4 { + let p = ss_start + i * 4; + if p + 4 <= buf.len() { data.push(Self::f32(buf, p, swap)); } + } + } + _ => {} + } + off = (ss_end + 7) & !7; + } + if name.is_empty() { name = "unnamed".into(); } + if shape.is_empty() && !data.is_empty() { shape = vec![data.len()]; } + // Transpose column-major to row-major for 2D + if shape.len() == 2 { + let (r, c) = (shape[0], shape[1]); + if r * c == data.len() { + let mut cd = vec![0.0f32; data.len()]; + for ri in 0..r { for ci in 0..c { cd[ri*c+ci] = data[ci*r+ri]; } } + data = cd; + } + } + Ok((name, NpyArray { shape, data })) + } + + fn u32(b: &[u8], o: usize, s: bool) -> u32 { + let v = [b[o], b[o+1], b[o+2], b[o+3]]; + if s { u32::from_be_bytes(v) } else { u32::from_le_bytes(v) } + } + fn i32(b: &[u8], o: usize, s: bool) -> i32 { + let v = [b[o], b[o+1], b[o+2], b[o+3]]; + if s { i32::from_be_bytes(v) } else { i32::from_le_bytes(v) } + } + fn f64(b: &[u8], o: usize, s: bool) -> f64 { + let v: [u8; 8] = b[o..o+8].try_into().unwrap(); + if s { f64::from_be_bytes(v) } else { f64::from_le_bytes(v) } + } + fn f32(b: &[u8], o: usize, s: bool) -> f32 { + let v = [b[o], b[o+1], b[o+2], b[o+3]]; + if s { f32::from_be_bytes(v) } else { f32::from_le_bytes(v) } + } +} + +// ── Core data types ────────────────────────────────────────────────────────── + +/// A single CSI (Channel State Information) sample. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CsiSample { + pub amplitude: Vec, + pub phase: Vec, + pub timestamp_ms: u64, +} + +/// UV coordinate map for a body part in DensePose representation. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BodyPartUV { + pub part_id: u8, + pub u_coords: Vec, + pub v_coords: Vec, +} + +/// Pose label: 17 COCO keypoints + optional DensePose body-part UVs. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PoseLabel { + pub keypoints: [(f32, f32, f32); 17], + pub body_parts: Vec, + pub confidence: f32, +} + +impl Default for PoseLabel { + fn default() -> Self { + Self { keypoints: [(0.0, 0.0, 0.0); 17], body_parts: Vec::new(), confidence: 0.0 } + } +} + +// ── SubcarrierResampler ────────────────────────────────────────────────────── + +/// Resamples subcarrier data via linear interpolation or zero-padding. +pub struct SubcarrierResampler; + +impl SubcarrierResampler { + /// Resample: passthrough if equal, zero-pad if upsampling, interpolate if downsampling. + pub fn resample(input: &[f32], from: usize, to: usize) -> Vec { + if from == to || from == 0 || to == 0 { return input.to_vec(); } + if from < to { Self::zero_pad(input, from, to) } else { Self::interpolate(input, from, to) } + } + + /// Resample phase data with unwrapping before interpolation. 
+ pub fn resample_phase(input: &[f32], from: usize, to: usize) -> Vec { + if from == to || from == 0 || to == 0 { return input.to_vec(); } + let unwrapped = Self::phase_unwrap(input); + let resampled = if from < to { Self::zero_pad(&unwrapped, from, to) } + else { Self::interpolate(&unwrapped, from, to) }; + let pi = std::f32::consts::PI; + resampled.iter().map(|&p| { + let mut w = p % (2.0 * pi); + if w > pi { w -= 2.0 * pi; } + if w < -pi { w += 2.0 * pi; } + w + }).collect() + } + + fn zero_pad(input: &[f32], from: usize, to: usize) -> Vec { + let pad_left = (to - from) / 2; + let mut out = vec![0.0f32; to]; + for i in 0..from.min(input.len()) { + if pad_left + i < to { out[pad_left + i] = input[i]; } + } + out + } + + fn interpolate(input: &[f32], from: usize, to: usize) -> Vec { + let n = input.len().min(from); + if n <= 1 { return vec![input.first().copied().unwrap_or(0.0); to]; } + (0..to).map(|i| { + let pos = i as f64 * (n - 1) as f64 / (to - 1).max(1) as f64; + let lo = pos.floor() as usize; + let hi = (lo + 1).min(n - 1); + let f = (pos - lo as f64) as f32; + input[lo] * (1.0 - f) + input[hi] * f + }).collect() + } + + fn phase_unwrap(phase: &[f32]) -> Vec { + let pi = std::f32::consts::PI; + let mut out = vec![0.0f32; phase.len()]; + if phase.is_empty() { return out; } + out[0] = phase[0]; + for i in 1..phase.len() { + let mut d = phase[i] - phase[i - 1]; + while d > pi { d -= 2.0 * pi; } + while d < -pi { d += 2.0 * pi; } + out[i] = out[i - 1] + d; + } + out + } +} + +// ── MmFiDataset ────────────────────────────────────────────────────────────── + +/// MM-Fi (NeurIPS 2023) dataset loader with 56 subcarriers and 17 COCO keypoints. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MmFiDataset { + pub csi_frames: Vec, + pub labels: Vec, + pub sample_rate_hz: f32, + pub n_subcarriers: usize, +} + +impl MmFiDataset { + pub const SUBCARRIERS: usize = 56; + + /// Load from directory with csi_amplitude.npy/csi.npy and labels.npy/keypoints.npy. 
+ pub fn load_from_directory(path: &Path) -> Result { + if !path.is_dir() { + return Err(DatasetError::Missing(format!("directory not found: {}", path.display()))); + } + let amp = NpyReader::read_file(&Self::find(path, &["csi_amplitude.npy", "csi.npy"])?)?; + let n = amp.shape.first().copied().unwrap_or(0); + let raw_sc = if amp.shape.len() >= 2 { amp.shape[1] } else { amp.data.len() / n.max(1) }; + let phase_arr = Self::find(path, &["csi_phase.npy"]).ok() + .and_then(|p| NpyReader::read_file(&p).ok()); + let lab = NpyReader::read_file(&Self::find(path, &["labels.npy", "keypoints.npy"])?)?; + + let mut csi_frames = Vec::with_capacity(n); + let mut labels = Vec::with_capacity(n); + for i in 0..n { + let s = i * raw_sc; + if s + raw_sc > amp.data.len() { break; } + let amplitude = SubcarrierResampler::resample(&.data[s..s+raw_sc], raw_sc, Self::SUBCARRIERS); + let phase = phase_arr.as_ref().map(|pa| { + let ps = i * raw_sc; + if ps + raw_sc <= pa.data.len() { + SubcarrierResampler::resample_phase(&pa.data[ps..ps+raw_sc], raw_sc, Self::SUBCARRIERS) + } else { vec![0.0; Self::SUBCARRIERS] } + }).unwrap_or_else(|| vec![0.0; Self::SUBCARRIERS]); + + csi_frames.push(CsiSample { amplitude, phase, timestamp_ms: i as u64 * 50 }); + + let ks = i * 17 * 3; + let label = if ks + 51 <= lab.data.len() { + let d = &lab.data[ks..ks + 51]; + let mut kp = [(0.0f32, 0.0, 0.0); 17]; + for k in 0..17 { kp[k] = (d[k*3], d[k*3+1], d[k*3+2]); } + PoseLabel { keypoints: kp, body_parts: Vec::new(), confidence: 1.0 } + } else { PoseLabel::default() }; + labels.push(label); + } + Ok(Self { csi_frames, labels, sample_rate_hz: 20.0, n_subcarriers: Self::SUBCARRIERS }) + } + + pub fn resample_subcarriers(&mut self, from: usize, to: usize) { + for f in &mut self.csi_frames { + f.amplitude = SubcarrierResampler::resample(&f.amplitude, from, to); + f.phase = SubcarrierResampler::resample_phase(&f.phase, from, to); + } + self.n_subcarriers = to; + } + + pub fn iter_windows(&self, ws: usize, stride: 
usize) -> impl Iterator { + let stride = stride.max(1); + let n = self.csi_frames.len(); + (0..n).step_by(stride).filter(move |&s| s + ws <= n) + .map(move |s| (&self.csi_frames[s..s+ws], &self.labels[s..s+ws])) + } + + pub fn split_train_val(self, ratio: f32) -> (Self, Self) { + let split = (self.csi_frames.len() as f32 * ratio.clamp(0.0, 1.0)) as usize; + let (tc, vc) = self.csi_frames.split_at(split); + let (tl, vl) = self.labels.split_at(split); + let mk = |c: &[CsiSample], l: &[PoseLabel]| Self { + csi_frames: c.to_vec(), labels: l.to_vec(), + sample_rate_hz: self.sample_rate_hz, n_subcarriers: self.n_subcarriers, + }; + (mk(tc, tl), mk(vc, vl)) + } + + pub fn len(&self) -> usize { self.csi_frames.len() } + pub fn is_empty(&self) -> bool { self.csi_frames.is_empty() } + pub fn get(&self, idx: usize) -> Option<(&CsiSample, &PoseLabel)> { + self.csi_frames.get(idx).zip(self.labels.get(idx)) + } + + fn find(dir: &Path, names: &[&str]) -> Result { + for n in names { let p = dir.join(n); if p.exists() { return Ok(p); } } + Err(DatasetError::Missing(format!("none of {names:?} in {}", dir.display()))) + } +} + +// ── WiPoseDataset ──────────────────────────────────────────────────────────── + +/// Wi-Pose dataset loader: .mat v5, 30 subcarriers (-> 56), 18 keypoints (-> 17 COCO). 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WiPoseDataset { + pub csi_frames: Vec, + pub labels: Vec, + pub sample_rate_hz: f32, + pub n_subcarriers: usize, +} + +impl WiPoseDataset { + pub const RAW_SUBCARRIERS: usize = 30; + pub const TARGET_SUBCARRIERS: usize = 56; + pub const RAW_KEYPOINTS: usize = 18; + pub const COCO_KEYPOINTS: usize = 17; + + pub fn load_from_mat(path: &Path) -> Result { + let arrays = MatReader::read_file(path)?; + let csi = arrays.get("csi").or_else(|| arrays.get("csi_data")).or_else(|| arrays.get("CSI")) + .ok_or_else(|| DatasetError::Missing("no CSI variable in .mat".into()))?; + let n = csi.shape.first().copied().unwrap_or(0); + let raw = if csi.shape.len() >= 2 { csi.shape[1] } else { Self::RAW_SUBCARRIERS }; + let lab = arrays.get("keypoints").or_else(|| arrays.get("labels")).or_else(|| arrays.get("pose")); + + let mut csi_frames = Vec::with_capacity(n); + let mut labels = Vec::with_capacity(n); + for i in 0..n { + let s = i * raw; + if s + raw > csi.data.len() { break; } + let amp = SubcarrierResampler::resample(&csi.data[s..s+raw], raw, Self::TARGET_SUBCARRIERS); + csi_frames.push(CsiSample { amplitude: amp, phase: vec![0.0; Self::TARGET_SUBCARRIERS], timestamp_ms: i as u64 * 100 }); + let label = lab.and_then(|la| { + let ks = i * Self::RAW_KEYPOINTS * 3; + if ks + Self::RAW_KEYPOINTS * 3 <= la.data.len() { + Some(Self::map_18_to_17(&la.data[ks..ks + Self::RAW_KEYPOINTS * 3])) + } else { None } + }).unwrap_or_default(); + labels.push(label); + } + Ok(Self { csi_frames, labels, sample_rate_hz: 10.0, n_subcarriers: Self::TARGET_SUBCARRIERS }) + } + + /// Map 18 keypoints to 17 COCO: keep index 0 (nose), drop index 1, map 2..18 -> 1..16. 
+ fn map_18_to_17(data: &[f32]) -> PoseLabel { + let mut kp = [(0.0f32, 0.0, 0.0); 17]; + if data.len() >= 18 * 3 { + kp[0] = (data[0], data[1], data[2]); + for i in 1..17 { let s = (i + 1) * 3; kp[i] = (data[s], data[s+1], data[s+2]); } + } + PoseLabel { keypoints: kp, body_parts: Vec::new(), confidence: 1.0 } + } + + pub fn len(&self) -> usize { self.csi_frames.len() } + pub fn is_empty(&self) -> bool { self.csi_frames.is_empty() } +} + +// ── DataPipeline ───────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum DataSource { + MmFi(PathBuf), + WiPose(PathBuf), + Combined(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DataConfig { + pub source: DataSource, + pub window_size: usize, + pub stride: usize, + pub target_subcarriers: usize, + pub normalize: bool, +} + +impl Default for DataConfig { + fn default() -> Self { + Self { source: DataSource::Combined(Vec::new()), window_size: 10, stride: 5, + target_subcarriers: 56, normalize: true } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingSample { + pub csi_window: Vec>, + pub pose_label: PoseLabel, + pub source: &'static str, +} + +/// Unified pipeline: loads, resamples, windows, and normalizes training data. 
+pub struct DataPipeline { config: DataConfig } + +impl DataPipeline { + pub fn new(config: DataConfig) -> Self { Self { config } } + + pub fn load(&self) -> Result> { + let mut out = Vec::new(); + self.load_source(&self.config.source, &mut out)?; + if self.config.normalize && !out.is_empty() { Self::normalize_samples(&mut out); } + Ok(out) + } + + fn load_source(&self, src: &DataSource, out: &mut Vec) -> Result<()> { + match src { + DataSource::MmFi(p) => { + let mut ds = MmFiDataset::load_from_directory(p)?; + if ds.n_subcarriers != self.config.target_subcarriers { + let f = ds.n_subcarriers; + ds.resample_subcarriers(f, self.config.target_subcarriers); + } + self.extract_windows(&ds.csi_frames, &ds.labels, "mmfi", out); + } + DataSource::WiPose(p) => { + let ds = WiPoseDataset::load_from_mat(p)?; + self.extract_windows(&ds.csi_frames, &ds.labels, "wipose", out); + } + DataSource::Combined(srcs) => { for s in srcs { self.load_source(s, out)?; } } + } + Ok(()) + } + + fn extract_windows(&self, frames: &[CsiSample], labels: &[PoseLabel], + source: &'static str, out: &mut Vec) { + let (ws, stride) = (self.config.window_size, self.config.stride.max(1)); + let mut s = 0; + while s + ws <= frames.len() { + let window: Vec> = frames[s..s+ws].iter().map(|f| f.amplitude.clone()).collect(); + let label = labels.get(s + ws / 2).cloned().unwrap_or_default(); + out.push(TrainingSample { csi_window: window, pose_label: label, source }); + s += stride; + } + } + + fn normalize_samples(samples: &mut [TrainingSample]) { + let ns = samples.first().and_then(|s| s.csi_window.first()).map(|f| f.len()).unwrap_or(0); + if ns == 0 { return; } + let (mut sum, mut sq) = (vec![0.0f64; ns], vec![0.0f64; ns]); + let mut cnt = 0u64; + for s in samples.iter() { + for f in &s.csi_window { + for (j, &v) in f.iter().enumerate().take(ns) { + let v = v as f64; sum[j] += v; sq[j] += v * v; + } + cnt += 1; + } + } + if cnt == 0 { return; } + let mean: Vec = sum.iter().map(|s| s / cnt as 
f64).collect(); + let std: Vec = sq.iter().zip(mean.iter()) + .map(|(&s, &m)| (s / cnt as f64 - m * m).max(0.0).sqrt().max(1e-8)).collect(); + for s in samples.iter_mut() { + for f in &mut s.csi_window { + for (j, v) in f.iter_mut().enumerate().take(ns) { + *v = ((*v as f64 - mean[j]) / std[j]) as f32; + } + } + } + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn make_npy_f32(shape: &[usize], data: &[f32]) -> Vec { + let ss = if shape.len() == 1 { format!("({},)", shape[0]) } + else { format!("({})", shape.iter().map(|d| d.to_string()).collect::>().join(", ")) }; + let hdr = format!("{{'descr': ' Vec { + let ss = if shape.len() == 1 { format!("({},)", shape[0]) } + else { format!("({})", shape.iter().map(|d| d.to_string()).collect::>().join(", ")) }; + let hdr = format!("{{'descr': ' = (0..12).map(|i| i as f32).collect(); + let buf = make_npy_f32(&[3, 4], &data); + let arr = NpyReader::parse(&buf).unwrap(); + assert_eq!(arr.shape, vec![3, 4]); + assert_eq!(arr.ndim(), 2); + assert_eq!(arr.len(), 12); + } + + #[test] + fn npy_header_parse_3d() { + let data: Vec = (0..24).map(|i| i as f64 * 0.5).collect(); + let buf = make_npy_f64(&[2, 3, 4], &data); + let arr = NpyReader::parse(&buf).unwrap(); + assert_eq!(arr.shape, vec![2, 3, 4]); + assert_eq!(arr.ndim(), 3); + assert_eq!(arr.len(), 24); + assert!((arr.data[23] - 11.5).abs() < 1e-5); + } + + #[test] + fn subcarrier_resample_passthrough() { + let input: Vec = (0..56).map(|i| i as f32).collect(); + let output = SubcarrierResampler::resample(&input, 56, 56); + assert_eq!(output, input); + } + + #[test] + fn subcarrier_resample_upsample() { + let input: Vec = (0..30).map(|i| (i + 1) as f32).collect(); + let out = SubcarrierResampler::resample(&input, 30, 56); + assert_eq!(out.len(), 56); + // pad_left = 13, leading zeros + for i in 0..13 { assert!(out[i].abs() < f32::EPSILON, "expected zero at {i}"); } + // original data in 
middle + for i in 0..30 { assert!((out[13+i] - input[i]).abs() < f32::EPSILON); } + // trailing zeros + for i in 43..56 { assert!(out[i].abs() < f32::EPSILON, "expected zero at {i}"); } + } + + #[test] + fn subcarrier_resample_downsample() { + let input: Vec = (0..114).map(|i| i as f32).collect(); + let out = SubcarrierResampler::resample(&input, 114, 56); + assert_eq!(out.len(), 56); + assert!((out[0]).abs() < f32::EPSILON); + assert!((out[55] - 113.0).abs() < 0.1); + for i in 1..56 { assert!(out[i] >= out[i-1], "not monotonic at {i}"); } + } + + #[test] + fn subcarrier_resample_preserves_dc() { + let out = SubcarrierResampler::resample(&vec![42.0f32; 114], 114, 56); + assert_eq!(out.len(), 56); + for (i, &v) in out.iter().enumerate() { + assert!((v - 42.0).abs() < 1e-5, "DC not preserved at {i}: {v}"); + } + } + + #[test] + fn mmfi_sample_structure() { + let s = CsiSample { amplitude: vec![0.0; 56], phase: vec![0.0; 56], timestamp_ms: 100 }; + assert_eq!(s.amplitude.len(), 56); + assert_eq!(s.phase.len(), 56); + } + + #[test] + fn wipose_zero_pad() { + let raw: Vec = (1..=30).map(|i| i as f32).collect(); + let p = SubcarrierResampler::resample(&raw, 30, 56); + assert_eq!(p.len(), 56); + assert!(p[0].abs() < f32::EPSILON); + assert!((p[13] - 1.0).abs() < f32::EPSILON); + assert!((p[42] - 30.0).abs() < f32::EPSILON); + assert!(p[55].abs() < f32::EPSILON); + } + + #[test] + fn wipose_keypoint_mapping() { + let mut kp = vec![0.0f32; 18 * 3]; + kp[0] = 1.0; kp[1] = 2.0; kp[2] = 1.0; // nose + kp[3] = 99.0; kp[4] = 99.0; kp[5] = 99.0; // extra (dropped) + kp[6] = 3.0; kp[7] = 4.0; kp[8] = 1.0; // left eye -> COCO 1 + let label = WiPoseDataset::map_18_to_17(&kp); + assert_eq!(label.keypoints.len(), 17); + assert!((label.keypoints[0].0 - 1.0).abs() < f32::EPSILON); + assert!((label.keypoints[1].0 - 3.0).abs() < f32::EPSILON); // not 99 + } + + #[test] + fn train_val_split_ratio() { + let mk = |n: usize| MmFiDataset { + csi_frames: (0..n).map(|i| CsiSample { amplitude: 
vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(), + labels: (0..n).map(|_| PoseLabel::default()).collect(), + sample_rate_hz: 20.0, n_subcarriers: 56, + }; + let (train, val) = mk(100).split_train_val(0.8); + assert_eq!(train.len(), 80); + assert_eq!(val.len(), 20); + assert_eq!(train.len() + val.len(), 100); + } + + #[test] + fn sliding_window_count() { + let ds = MmFiDataset { + csi_frames: (0..20).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(), + labels: (0..20).map(|_| PoseLabel::default()).collect(), + sample_rate_hz: 20.0, n_subcarriers: 56, + }; + assert_eq!(ds.iter_windows(5, 5).count(), 4); + assert_eq!(ds.iter_windows(5, 1).count(), 16); + } + + #[test] + fn sliding_window_overlap() { + let ds = MmFiDataset { + csi_frames: (0..10).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(), + labels: (0..10).map(|_| PoseLabel::default()).collect(), + sample_rate_hz: 20.0, n_subcarriers: 56, + }; + let w: Vec<_> = ds.iter_windows(4, 2).collect(); + assert_eq!(w.len(), 4); + assert!((w[0].0[0].amplitude[0]).abs() < f32::EPSILON); + assert!((w[1].0[0].amplitude[0] - 2.0).abs() < f32::EPSILON); + assert_eq!(w[0].0[2].amplitude[0], w[1].0[0].amplitude[0]); // overlap + } + + #[test] + fn data_pipeline_normalize() { + let mut samples = vec![ + TrainingSample { csi_window: vec![vec![10.0, 20.0, 30.0]; 2], pose_label: PoseLabel::default(), source: "test" }, + TrainingSample { csi_window: vec![vec![30.0, 40.0, 50.0]; 2], pose_label: PoseLabel::default(), source: "test" }, + ]; + DataPipeline::normalize_samples(&mut samples); + for j in 0..3 { + let (mut s, mut c) = (0.0f64, 0u64); + for sam in &samples { for f in &sam.csi_window { s += f[j] as f64; c += 1; } } + assert!(( s / c as f64).abs() < 1e-5, "mean not ~0 for sub {j}"); + let mut vs = 0.0f64; + let m = s / c as f64; + for sam in &samples { for f in &sam.csi_window { vs 
+= (f[j] as f64 - m).powi(2); } } + assert!(((vs / c as f64).sqrt() - 1.0).abs() < 0.1, "std not ~1 for sub {j}"); + } + } + + #[test] + fn pose_label_default() { + let l = PoseLabel::default(); + assert_eq!(l.keypoints.len(), 17); + assert!(l.body_parts.is_empty()); + assert!(l.confidence.abs() < f32::EPSILON); + for (i, kp) in l.keypoints.iter().enumerate() { + assert!(kp.0.abs() < f32::EPSILON && kp.1.abs() < f32::EPSILON, "kp {i} not zero"); + } + } + + #[test] + fn body_part_uv_round_trip() { + let bpu = BodyPartUV { part_id: 5, u_coords: vec![0.1, 0.2, 0.3], v_coords: vec![0.4, 0.5, 0.6] }; + let json = serde_json::to_string(&bpu).unwrap(); + let r: BodyPartUV = serde_json::from_str(&json).unwrap(); + assert_eq!(r.part_id, 5); + assert_eq!(r.u_coords.len(), 3); + assert!((r.u_coords[0] - 0.1).abs() < f32::EPSILON); + assert!((r.v_coords[2] - 0.6).abs() < f32::EPSILON); + } + + #[test] + fn combined_source_merges_datasets() { + let mk = |n: usize, base: f32| -> (Vec, Vec) { + let f: Vec = (0..n).map(|i| CsiSample { amplitude: vec![base + i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 * 50 }).collect(); + let l: Vec = (0..n).map(|_| PoseLabel::default()).collect(); + (f, l) + }; + let pipe = DataPipeline::new(DataConfig { source: DataSource::Combined(Vec::new()), + window_size: 3, stride: 1, target_subcarriers: 56, normalize: false }); + let mut all = Vec::new(); + let (fa, la) = mk(5, 0.0); + pipe.extract_windows(&fa, &la, "mmfi", &mut all); + assert_eq!(all.len(), 3); + let (fb, lb) = mk(4, 100.0); + pipe.extract_windows(&fb, &lb, "wipose", &mut all); + assert_eq!(all.len(), 5); + assert_eq!(all[0].source, "mmfi"); + assert_eq!(all[3].source, "wipose"); + assert!(all[0].csi_window[0][0] < 10.0); + assert!(all[4].csi_window[0][0] > 90.0); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs new file 
mode 100644 index 0000000..e46e5ce --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/graph_transformer.rs @@ -0,0 +1,589 @@ +//! Graph Transformer + GNN for WiFi CSI-to-Pose estimation (ADR-023 Phase 2). +//! +//! Cross-attention bottleneck between antenna-space CSI features and COCO 17-keypoint +//! body graph, followed by GCN message passing. All math is pure `std`. + +/// Xorshift64 PRNG for deterministic weight initialization. +#[derive(Debug, Clone)] +struct Rng64 { state: u64 } + +impl Rng64 { + fn new(seed: u64) -> Self { + Self { state: if seed == 0 { 0xDEAD_BEEF_CAFE_1234 } else { seed } } + } + fn next_u64(&mut self) -> u64 { + let mut x = self.state; + x ^= x << 13; x ^= x >> 7; x ^= x << 17; + self.state = x; x + } + /// Uniform f32 in (-1, 1). + fn next_f32(&mut self) -> f32 { + let f = (self.next_u64() >> 11) as f32 / (1u64 << 53) as f32; + f * 2.0 - 1.0 + } +} + +#[inline] +fn relu(x: f32) -> f32 { if x > 0.0 { x } else { 0.0 } } + +#[inline] +fn sigmoid(x: f32) -> f32 { + if x >= 0.0 { 1.0 / (1.0 + (-x).exp()) } + else { let ex = x.exp(); ex / (1.0 + ex) } +} + +/// Numerically stable softmax. Writes normalised weights into `out`. +fn softmax(scores: &[f32], out: &mut [f32]) { + debug_assert_eq!(scores.len(), out.len()); + if scores.is_empty() { return; } + let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f32; + for (o, &s) in out.iter_mut().zip(scores) { + let e = (s - max).exp(); *o = e; sum += e; + } + let inv = if sum > 1e-10 { 1.0 / sum } else { 0.0 }; + for o in out.iter_mut() { *o *= inv; } +} + +// ── Linear layer ───────────────────────────────────────────────────────── + +/// Dense linear transformation y = Wx + b (row-major weights). +#[derive(Debug, Clone)] +pub struct Linear { + in_features: usize, + out_features: usize, + weights: Vec>, + bias: Vec, +} + +impl Linear { + /// Xavier/Glorot uniform init with default seed. 
+ pub fn new(in_features: usize, out_features: usize) -> Self { + Self::with_seed(in_features, out_features, 42) + } + /// Xavier/Glorot uniform init with explicit seed. + pub fn with_seed(in_features: usize, out_features: usize, seed: u64) -> Self { + let mut rng = Rng64::new(seed); + let limit = (6.0 / (in_features + out_features) as f32).sqrt(); + let weights = (0..out_features) + .map(|_| (0..in_features).map(|_| rng.next_f32() * limit).collect()) + .collect(); + Self { in_features, out_features, weights, bias: vec![0.0; out_features] } + } + /// All-zero weights (for testing). + pub fn zeros(in_features: usize, out_features: usize) -> Self { + Self { + in_features, out_features, + weights: vec![vec![0.0; in_features]; out_features], + bias: vec![0.0; out_features], + } + } + /// Forward pass: y = Wx + b. + pub fn forward(&self, input: &[f32]) -> Vec { + assert_eq!(input.len(), self.in_features, + "Linear input mismatch: expected {}, got {}", self.in_features, input.len()); + let mut out = vec![0.0f32; self.out_features]; + for (i, row) in self.weights.iter().enumerate() { + let mut s = self.bias[i]; + for (w, x) in row.iter().zip(input) { s += w * x; } + out[i] = s; + } + out + } + pub fn weights(&self) -> &[Vec] { &self.weights } + pub fn set_weights(&mut self, w: Vec>) { + assert_eq!(w.len(), self.out_features); + for row in &w { assert_eq!(row.len(), self.in_features); } + self.weights = w; + } + pub fn set_bias(&mut self, b: Vec) { + assert_eq!(b.len(), self.out_features); + self.bias = b; + } +} + +// ── AntennaGraph ───────────────────────────────────────────────────────── + +/// Spatial topology graph over TX-RX antenna pairs. Nodes = pairs, edges connect +/// pairs sharing a TX or RX antenna. +#[derive(Debug, Clone)] +pub struct AntennaGraph { + n_tx: usize, n_rx: usize, n_pairs: usize, + adjacency: Vec>, +} + +impl AntennaGraph { + /// Build antenna graph. pair_id = tx * n_rx + rx. Adjacent if shared TX or RX. 
+ pub fn new(n_tx: usize, n_rx: usize) -> Self { + let n_pairs = n_tx * n_rx; + let mut adj = vec![vec![0.0f32; n_pairs]; n_pairs]; + for i in 0..n_pairs { + let (tx_i, rx_i) = (i / n_rx, i % n_rx); + adj[i][i] = 1.0; + for j in (i + 1)..n_pairs { + let (tx_j, rx_j) = (j / n_rx, j % n_rx); + if tx_i == tx_j || rx_i == rx_j { + adj[i][j] = 1.0; adj[j][i] = 1.0; + } + } + } + Self { n_tx, n_rx, n_pairs, adjacency: adj } + } + pub fn n_nodes(&self) -> usize { self.n_pairs } + pub fn adjacency_matrix(&self) -> &Vec> { &self.adjacency } + pub fn n_tx(&self) -> usize { self.n_tx } + pub fn n_rx(&self) -> usize { self.n_rx } +} + +// ── BodyGraph ──────────────────────────────────────────────────────────── + +/// COCO 17-keypoint skeleton graph with 16 anatomical edges. +/// +/// Indices: 0=nose 1=l_eye 2=r_eye 3=l_ear 4=r_ear 5=l_shoulder 6=r_shoulder +/// 7=l_elbow 8=r_elbow 9=l_wrist 10=r_wrist 11=l_hip 12=r_hip 13=l_knee +/// 14=r_knee 15=l_ankle 16=r_ankle +#[derive(Debug, Clone)] +pub struct BodyGraph { + adjacency: [[f32; 17]; 17], + edges: Vec<(usize, usize)>, +} + +pub const COCO_KEYPOINT_NAMES: [&str; 17] = [ + "nose","left_eye","right_eye","left_ear","right_ear", + "left_shoulder","right_shoulder","left_elbow","right_elbow", + "left_wrist","right_wrist","left_hip","right_hip", + "left_knee","right_knee","left_ankle","right_ankle", +]; + +const COCO_EDGES: [(usize, usize); 16] = [ + (0,1),(0,2),(1,3),(2,4),(5,6),(5,7),(7,9),(6,8), + (8,10),(5,11),(6,12),(11,12),(11,13),(13,15),(12,14),(14,16), +]; + +impl BodyGraph { + pub fn new() -> Self { + let mut adjacency = [[0.0f32; 17]; 17]; + for i in 0..17 { adjacency[i][i] = 1.0; } + for &(u, v) in &COCO_EDGES { adjacency[u][v] = 1.0; adjacency[v][u] = 1.0; } + Self { adjacency, edges: COCO_EDGES.to_vec() } + } + pub fn adjacency_matrix(&self) -> &[[f32; 17]; 17] { &self.adjacency } + pub fn edge_list(&self) -> &Vec<(usize, usize)> { &self.edges } + pub fn n_nodes(&self) -> usize { 17 } + pub fn n_edges(&self) -> 
usize { self.edges.len() } + + /// Degree of each node (including self-loop). + pub fn degrees(&self) -> [f32; 17] { + let mut deg = [0.0f32; 17]; + for i in 0..17 { for j in 0..17 { deg[i] += self.adjacency[i][j]; } } + deg + } + /// Symmetric normalised adjacency D^{-1/2} A D^{-1/2}. + pub fn normalized_adjacency(&self) -> [[f32; 17]; 17] { + let deg = self.degrees(); + let inv_sqrt: Vec = deg.iter() + .map(|&d| if d > 0.0 { 1.0 / d.sqrt() } else { 0.0 }).collect(); + let mut norm = [[0.0f32; 17]; 17]; + for i in 0..17 { for j in 0..17 { + norm[i][j] = inv_sqrt[i] * self.adjacency[i][j] * inv_sqrt[j]; + }} + norm + } +} + +impl Default for BodyGraph { fn default() -> Self { Self::new() } } + +// ── CrossAttention ─────────────────────────────────────────────────────── + +/// Multi-head scaled dot-product cross-attention. +/// Attn(Q,K,V) = softmax(QK^T / sqrt(d_k)) V, split into n_heads. +#[derive(Debug, Clone)] +pub struct CrossAttention { + d_model: usize, n_heads: usize, d_k: usize, + w_q: Linear, w_k: Linear, w_v: Linear, w_o: Linear, +} + +impl CrossAttention { + pub fn new(d_model: usize, n_heads: usize) -> Self { + assert!(d_model % n_heads == 0, + "d_model ({d_model}) must be divisible by n_heads ({n_heads})"); + let d_k = d_model / n_heads; + let s = 123u64; + Self { d_model, n_heads, d_k, + w_q: Linear::with_seed(d_model, d_model, s), + w_k: Linear::with_seed(d_model, d_model, s+1), + w_v: Linear::with_seed(d_model, d_model, s+2), + w_o: Linear::with_seed(d_model, d_model, s+3), + } + } + /// query [n_q, d_model], key/value [n_kv, d_model] -> [n_q, d_model]. 
+ pub fn forward(&self, query: &[Vec], key: &[Vec], value: &[Vec]) -> Vec> { + let (n_q, n_kv) = (query.len(), key.len()); + if n_q == 0 || n_kv == 0 { return vec![vec![0.0; self.d_model]; n_q]; } + + let q_proj: Vec> = query.iter().map(|q| self.w_q.forward(q)).collect(); + let k_proj: Vec> = key.iter().map(|k| self.w_k.forward(k)).collect(); + let v_proj: Vec> = value.iter().map(|v| self.w_v.forward(v)).collect(); + + let scale = (self.d_k as f32).sqrt(); + let mut output = vec![vec![0.0f32; self.d_model]; n_q]; + + for qi in 0..n_q { + let mut concat = Vec::with_capacity(self.d_model); + for h in 0..self.n_heads { + let (start, end) = (h * self.d_k, (h + 1) * self.d_k); + let q_h = &q_proj[qi][start..end]; + let mut scores = vec![0.0f32; n_kv]; + for ki in 0..n_kv { + let dot: f32 = q_h.iter().zip(&k_proj[ki][start..end]).map(|(a,b)| a*b).sum(); + scores[ki] = dot / scale; + } + let mut wts = vec![0.0f32; n_kv]; + softmax(&scores, &mut wts); + let mut head_out = vec![0.0f32; self.d_k]; + for ki in 0..n_kv { + for (o, &v) in head_out.iter_mut().zip(&v_proj[ki][start..end]) { + *o += wts[ki] * v; + } + } + concat.extend_from_slice(&head_out); + } + output[qi] = self.w_o.forward(&concat); + } + output + } + pub fn d_model(&self) -> usize { self.d_model } + pub fn n_heads(&self) -> usize { self.n_heads } +} + +// ── GraphMessagePassing ────────────────────────────────────────────────── + +/// GCN layer: H' = ReLU(A_norm H W) where A_norm = D^{-1/2} A D^{-1/2}. +#[derive(Debug, Clone)] +pub struct GraphMessagePassing { + in_features: usize, out_features: usize, + weight: Linear, norm_adj: [[f32; 17]; 17], +} + +impl GraphMessagePassing { + pub fn new(in_features: usize, out_features: usize, graph: &BodyGraph) -> Self { + Self { in_features, out_features, + weight: Linear::with_seed(in_features, out_features, 777), + norm_adj: graph.normalized_adjacency() } + } + /// node_features [17, in_features] -> [17, out_features]. 
+ pub fn forward(&self, node_features: &[Vec]) -> Vec> { + assert_eq!(node_features.len(), 17, "expected 17 nodes, got {}", node_features.len()); + let mut agg = vec![vec![0.0f32; self.in_features]; 17]; + for i in 0..17 { for j in 0..17 { + let a = self.norm_adj[i][j]; + if a.abs() > 1e-10 { + for (ag, &f) in agg[i].iter_mut().zip(&node_features[j]) { *ag += a * f; } + } + }} + agg.iter().map(|a| self.weight.forward(a).into_iter().map(relu).collect()).collect() + } + pub fn in_features(&self) -> usize { self.in_features } + pub fn out_features(&self) -> usize { self.out_features } +} + +/// Stack of GCN layers. +#[derive(Debug, Clone)] +struct GnnStack { layers: Vec } + +impl GnnStack { + fn new(in_f: usize, out_f: usize, n: usize, g: &BodyGraph) -> Self { + assert!(n >= 1); + let mut layers = vec![GraphMessagePassing::new(in_f, out_f, g)]; + for _ in 1..n { layers.push(GraphMessagePassing::new(out_f, out_f, g)); } + Self { layers } + } + fn forward(&self, feats: &[Vec]) -> Vec> { + let mut h = feats.to_vec(); + for l in &self.layers { h = l.forward(&h); } + h + } +} + +// ── Transformer config / output / pipeline ─────────────────────────────── + +/// Configuration for the CSI-to-Pose transformer. +#[derive(Debug, Clone)] +pub struct TransformerConfig { + pub n_subcarriers: usize, + pub n_keypoints: usize, + pub d_model: usize, + pub n_heads: usize, + pub n_gnn_layers: usize, +} + +impl Default for TransformerConfig { + fn default() -> Self { + Self { n_subcarriers: 56, n_keypoints: 17, d_model: 64, n_heads: 4, n_gnn_layers: 2 } + } +} + +/// Output of the CSI-to-Pose transformer. +#[derive(Debug, Clone)] +pub struct PoseOutput { + /// Predicted (x, y, z) per keypoint. + pub keypoints: Vec<(f32, f32, f32)>, + /// Per-keypoint confidence in [0, 1]. + pub confidences: Vec, + /// Per-keypoint GNN features for downstream use. + pub body_part_features: Vec>, +} + +/// Full CSI-to-Pose pipeline: CSI embed -> cross-attention -> GNN -> regression heads. 
+#[derive(Debug, Clone)] +pub struct CsiToPoseTransformer { + config: TransformerConfig, + csi_embed: Linear, + keypoint_queries: Vec>, + cross_attn: CrossAttention, + gnn: GnnStack, + xyz_head: Linear, + conf_head: Linear, +} + +impl CsiToPoseTransformer { + pub fn new(config: TransformerConfig) -> Self { + let d = config.d_model; + let bg = BodyGraph::new(); + let mut rng = Rng64::new(999); + let limit = (6.0 / (config.n_keypoints + d) as f32).sqrt(); + let kq: Vec> = (0..config.n_keypoints) + .map(|_| (0..d).map(|_| rng.next_f32() * limit).collect()).collect(); + Self { + csi_embed: Linear::with_seed(config.n_subcarriers, d, 500), + keypoint_queries: kq, + cross_attn: CrossAttention::new(d, config.n_heads), + gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg), + xyz_head: Linear::with_seed(d, 3, 600), + conf_head: Linear::with_seed(d, 1, 700), + config, + } + } + /// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints. + pub fn forward(&self, csi_features: &[Vec]) -> PoseOutput { + let embedded: Vec> = csi_features.iter() + .map(|f| self.csi_embed.forward(f)).collect(); + let attended = self.cross_attn.forward(&self.keypoint_queries, &embedded, &embedded); + let gnn_out = self.gnn.forward(&attended); + let mut kps = Vec::with_capacity(self.config.n_keypoints); + let mut confs = Vec::with_capacity(self.config.n_keypoints); + for nf in &gnn_out { + let xyz = self.xyz_head.forward(nf); + kps.push((xyz[0], xyz[1], xyz[2])); + confs.push(sigmoid(self.conf_head.forward(nf)[0])); + } + PoseOutput { keypoints: kps, confidences: confs, body_part_features: gnn_out } + } + pub fn config(&self) -> &TransformerConfig { &self.config } +} + +// ── Tests ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn body_graph_has_17_nodes() { + assert_eq!(BodyGraph::new().n_nodes(), 17); + } + + #[test] + fn body_graph_has_16_edges() { + let g = BodyGraph::new(); + assert_eq!(g.n_edges(), 
16); + assert_eq!(g.edge_list().len(), 16); + } + + #[test] + fn body_graph_adjacency_symmetric() { + let bg = BodyGraph::new(); + let adj = bg.adjacency_matrix(); + for i in 0..17 { for j in 0..17 { + assert_eq!(adj[i][j], adj[j][i], "asymmetric at ({i},{j})"); + }} + } + + #[test] + fn body_graph_self_loops_and_specific_edges() { + let bg = BodyGraph::new(); + let adj = bg.adjacency_matrix(); + for i in 0..17 { assert_eq!(adj[i][i], 1.0); } + assert_eq!(adj[0][1], 1.0); // nose-left_eye + assert_eq!(adj[5][6], 1.0); // l_shoulder-r_shoulder + assert_eq!(adj[14][16], 1.0); // r_knee-r_ankle + assert_eq!(adj[0][15], 0.0); // nose should NOT connect to l_ankle + } + + #[test] + fn antenna_graph_node_count() { + assert_eq!(AntennaGraph::new(3, 3).n_nodes(), 9); + } + + #[test] + fn antenna_graph_adjacency() { + let ag = AntennaGraph::new(2, 2); + let adj = ag.adjacency_matrix(); + assert_eq!(adj[0][1], 1.0); // share tx=0 + assert_eq!(adj[0][2], 1.0); // share rx=0 + assert_eq!(adj[0][3], 0.0); // share neither + } + + #[test] + fn cross_attention_output_shape() { + let ca = CrossAttention::new(16, 4); + let out = ca.forward(&vec![vec![0.5; 16]; 5], &vec![vec![0.3; 16]; 3], &vec![vec![0.7; 16]; 3]); + assert_eq!(out.len(), 5); + for r in &out { assert_eq!(r.len(), 16); } + } + + #[test] + fn cross_attention_single_head_vs_multi() { + let (q, k, v) = (vec![vec![1.0f32; 8]; 2], vec![vec![0.5; 8]; 3], vec![vec![0.5; 8]; 3]); + let o1 = CrossAttention::new(8, 1).forward(&q, &k, &v); + let o2 = CrossAttention::new(8, 2).forward(&q, &k, &v); + assert_eq!(o1.len(), o2.len()); + assert_eq!(o1[0].len(), o2[0].len()); + } + + #[test] + fn scaled_dot_product_softmax_sums_to_one() { + let scores = vec![1.0f32, 2.0, 3.0, 0.5]; + let mut w = vec![0.0f32; 4]; + softmax(&scores, &mut w); + assert!((w.iter().sum::() - 1.0).abs() < 1e-5); + for &wi in &w { assert!(wi > 0.0); } + assert!(w[2] > w[0] && w[2] > w[1] && w[2] > w[3]); + } + + #[test] + fn gnn_message_passing_shape() { + 
let g = BodyGraph::new(); + let out = GraphMessagePassing::new(32, 16, &g).forward(&vec![vec![1.0; 32]; 17]); + assert_eq!(out.len(), 17); + for r in &out { assert_eq!(r.len(), 16); } + } + + #[test] + fn gnn_preserves_isolated_node() { + let g = BodyGraph::new(); + let gmp = GraphMessagePassing::new(8, 8, &g); + let mut feats: Vec> = vec![vec![0.0; 8]; 17]; + feats[0] = vec![1.0; 8]; // only nose has signal + let out = gmp.forward(&feats); + let ankle_e: f32 = out[15].iter().map(|x| x*x).sum(); + let nose_e: f32 = out[0].iter().map(|x| x*x).sum(); + assert!(nose_e > ankle_e, "nose ({nose_e}) should > ankle ({ankle_e})"); + } + + #[test] + fn linear_layer_output_size() { + assert_eq!(Linear::new(10, 5).forward(&vec![1.0; 10]).len(), 5); + } + + #[test] + fn linear_layer_zero_weights() { + let out = Linear::zeros(4, 3).forward(&[1.0, 2.0, 3.0, 4.0]); + for &v in &out { assert_eq!(v, 0.0); } + } + + #[test] + fn linear_layer_set_weights_identity() { + let mut lin = Linear::zeros(2, 2); + lin.set_weights(vec![vec![1.0, 0.0], vec![0.0, 1.0]]); + let out = lin.forward(&[3.0, 7.0]); + assert!((out[0] - 3.0).abs() < 1e-6 && (out[1] - 7.0).abs() < 1e-6); + } + + #[test] + fn transformer_config_defaults() { + let c = TransformerConfig::default(); + assert_eq!((c.n_subcarriers, c.n_keypoints, c.d_model, c.n_heads, c.n_gnn_layers), + (56, 17, 64, 4, 2)); + } + + #[test] + fn transformer_forward_output_17_keypoints() { + let t = CsiToPoseTransformer::new(TransformerConfig { + n_subcarriers: 16, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1, + }); + let out = t.forward(&vec![vec![0.5; 16]; 4]); + assert_eq!(out.keypoints.len(), 17); + assert_eq!(out.confidences.len(), 17); + assert_eq!(out.body_part_features.len(), 17); + } + + #[test] + fn transformer_keypoints_are_finite() { + let t = CsiToPoseTransformer::new(TransformerConfig { + n_subcarriers: 8, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 2, + }); + let out = t.forward(&vec![vec![1.0; 8]; 6]); + for 
(i, &(x, y, z)) in out.keypoints.iter().enumerate() { + assert!(x.is_finite() && y.is_finite() && z.is_finite(), "kp {i} not finite"); + } + for (i, &c) in out.confidences.iter().enumerate() { + assert!(c.is_finite() && (0.0..=1.0).contains(&c), "conf {i} invalid: {c}"); + } + } + + #[test] + fn relu_activation() { + assert_eq!(relu(-5.0), 0.0); + assert_eq!(relu(-0.001), 0.0); + assert_eq!(relu(0.0), 0.0); + assert_eq!(relu(3.14), 3.14); + assert_eq!(relu(100.0), 100.0); + } + + #[test] + fn sigmoid_bounds() { + assert!((sigmoid(0.0) - 0.5).abs() < 1e-6); + assert!(sigmoid(100.0) > 0.999); + assert!(sigmoid(-100.0) < 0.001); + } + + #[test] + fn deterministic_rng_and_linear() { + let (mut r1, mut r2) = (Rng64::new(42), Rng64::new(42)); + for _ in 0..100 { assert_eq!(r1.next_u64(), r2.next_u64()); } + let inp = vec![1.0, 2.0, 3.0, 4.0]; + assert_eq!(Linear::with_seed(4, 3, 99).forward(&inp), + Linear::with_seed(4, 3, 99).forward(&inp)); + } + + #[test] + fn body_graph_normalized_adjacency_finite() { + let norm = BodyGraph::new().normalized_adjacency(); + for i in 0..17 { + let s: f32 = norm[i].iter().sum(); + assert!(s.is_finite() && s > 0.0, "row {i} sum={s}"); + } + } + + #[test] + fn cross_attention_empty_keys() { + let out = CrossAttention::new(8, 2).forward( + &vec![vec![1.0; 8]; 3], &vec![], &vec![]); + assert_eq!(out.len(), 3); + for r in &out { for &v in r { assert_eq!(v, 0.0); } } + } + + #[test] + fn softmax_edge_cases() { + let mut w1 = vec![0.0f32; 1]; + softmax(&[42.0], &mut w1); + assert!((w1[0] - 1.0).abs() < 1e-6); + + let mut w3 = vec![0.0f32; 3]; + softmax(&[1000.0, 1001.0, 999.0], &mut w3); + let sum: f32 = w3.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5); + for &wi in &w3 { assert!(wi.is_finite()); } + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/lib.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/lib.rs index 6ef4e67..9ee67b5 100644 --- 
a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/lib.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/lib.rs @@ -6,3 +6,9 @@ pub mod vital_signs; pub mod rvf_container; +pub mod rvf_pipeline; +pub mod graph_transformer; +pub mod trainer; +pub mod dataset; +pub mod sona; +pub mod sparse_inference; diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs index 7aac855..c3bdc14 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs @@ -9,6 +9,7 @@ //! Replaces both ws_server.py and the Python HTTP server. mod rvf_container; +mod rvf_pipeline; mod vital_signs; use std::collections::VecDeque; @@ -23,7 +24,7 @@ use axum::{ State, }, response::{Html, IntoResponse, Json}, - routing::get, + routing::{get, post}, Router, }; use clap::Parser; @@ -37,8 +38,15 @@ use axum::http::HeaderValue; use tracing::{info, warn, debug, error}; use rvf_container::{RvfBuilder, RvfContainerInfo, RvfReader, VitalSignConfig}; +use rvf_pipeline::ProgressiveLoader; use vital_signs::{VitalSignDetector, VitalSigns}; +// ADR-022 Phase 3: Multi-BSSID pipeline integration +use wifi_densepose_wifiscan::{ + BssidRegistry, WindowsWifiPipeline, +}; +use wifi_densepose_wifiscan::parse_netsh_output as parse_netsh_bssid_output; + // ── CLI ────────────────────────────────────────────────────────────────────── #[derive(Parser, Debug)] @@ -79,6 +87,14 @@ struct Args { /// Save current model state as an RVF container on shutdown #[arg(long, value_name = "PATH")] save_rvf: Option, + + /// Load a trained .rvf model for inference + #[arg(long, value_name = "PATH")] + model: Option, + + /// Enable progressive loading (Layer A instant start) + #[arg(long)] + progressive: bool, } // ── Data types 
─────────────────────────────────────────────────────────────── @@ -114,6 +130,32 @@ struct SensingUpdate { /// Vital sign estimates (breathing rate, heart rate, confidence). #[serde(skip_serializing_if = "Option::is_none")] vital_signs: Option, + // ── ADR-022 Phase 3: Enhanced multi-BSSID pipeline fields ── + /// Enhanced motion estimate from multi-BSSID pipeline. + #[serde(skip_serializing_if = "Option::is_none")] + enhanced_motion: Option, + /// Enhanced breathing estimate from multi-BSSID pipeline. + #[serde(skip_serializing_if = "Option::is_none")] + enhanced_breathing: Option, + /// Posture classification from BSSID fingerprint matching. + #[serde(skip_serializing_if = "Option::is_none")] + posture: Option, + /// Signal quality score from multi-BSSID quality gate [0.0, 1.0]. + #[serde(skip_serializing_if = "Option::is_none")] + signal_quality_score: Option, + /// Quality gate verdict: "Permit", "Warn", or "Deny". + #[serde(skip_serializing_if = "Option::is_none")] + quality_verdict: Option, + /// Number of BSSIDs used in the enhanced sensing cycle. + #[serde(skip_serializing_if = "Option::is_none")] + bssid_count: Option, + // ── ADR-023 Phase 7-8: Model inference fields ── + /// Pose keypoints when a trained model is loaded (x, y, z, confidence). + #[serde(skip_serializing_if = "Option::is_none")] + pose_keypoints: Option>, + /// Model status when a trained model is loaded. + #[serde(skip_serializing_if = "Option::is_none")] + model_status: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -194,6 +236,12 @@ struct AppStateInner { rvf_info: Option, /// Path to save RVF container on shutdown (set via `--save-rvf`). save_rvf_path: Option, + /// Progressive loader for a trained model (set via `--model`). + progressive_loader: Option, + /// Active SONA profile name. + active_sona_profile: Option, + /// Whether a trained model is loaded. 
+ model_loaded: bool, } type SharedState = Arc>; @@ -376,7 +424,7 @@ fn extract_features_from_frame(frame: &Esp32Frame) -> (FeatureInfo, Classificati // ── Windows WiFi RSSI collector ────────────────────────────────────────────── /// Parse `netsh wlan show interfaces` output for RSSI and signal quality -fn parse_netsh_output(output: &str) -> Option<(f64, f64, String)> { +fn parse_netsh_interfaces_output(output: &str) -> Option<(f64, f64, String)> { let mut rssi = None; let mut signal = None; let mut ssid = None; @@ -411,52 +459,126 @@ fn parse_netsh_output(output: &str) -> Option<(f64, f64, String)> { async fn windows_wifi_task(state: SharedState, tick_ms: u64) { let mut interval = tokio::time::interval(Duration::from_millis(tick_ms)); let mut seq: u32 = 0; - info!("Windows WiFi RSSI collector active (tick={}ms)", tick_ms); + + // ADR-022 Phase 3: Multi-BSSID pipeline state (kept across ticks) + let mut registry = BssidRegistry::new(32, 30); + let mut pipeline = WindowsWifiPipeline::new(); + + info!( + "Windows WiFi multi-BSSID pipeline active (tick={}ms, max_bssids=32)", + tick_ms + ); loop { interval.tick().await; seq += 1; - // Run netsh to get WiFi info - let output = match tokio::process::Command::new("netsh") - .args(["wlan", "show", "interfaces"]) - .output() - .await - { - Ok(o) => String::from_utf8_lossy(&o.stdout).to_string(), - Err(e) => { - warn!("netsh failed: {e}"); + // ── Step 1: Run multi-BSSID scan via spawn_blocking ────────── + // NetshBssidScanner is not Send, so we run `netsh` and parse + // the output inside a blocking closure. 
+ let bssid_scan_result = tokio::task::spawn_blocking(|| { + let output = std::process::Command::new("netsh") + .args(["wlan", "show", "networks", "mode=bssid"]) + .output() + .map_err(|e| format!("netsh bssid scan failed: {e}"))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(format!( + "netsh exited with {}: {}", + output.status, + stderr.trim() + )); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + parse_netsh_bssid_output(&stdout).map_err(|e| format!("parse error: {e}")) + }) + .await; + + // Unwrap the JoinHandle result, then the inner Result. + let observations = match bssid_scan_result { + Ok(Ok(obs)) if !obs.is_empty() => obs, + Ok(Ok(_empty)) => { + debug!("Multi-BSSID scan returned 0 observations, falling back"); + windows_wifi_fallback_tick(&state, seq).await; + continue; + } + Ok(Err(e)) => { + warn!("Multi-BSSID scan error: {e}, falling back"); + windows_wifi_fallback_tick(&state, seq).await; + continue; + } + Err(join_err) => { + error!("spawn_blocking panicked: {join_err}"); continue; } }; - let (rssi_dbm, signal_pct, ssid) = match parse_netsh_output(&output) { - Some(v) => v, - None => { - debug!("No WiFi interface connected"); - continue; - } - }; + let obs_count = observations.len(); + + // Derive SSID from the first observation for the source label. 
+ let ssid = observations + .first() + .map(|o| o.ssid.clone()) + .unwrap_or_else(|| "Unknown".into()); + + // ── Step 2: Feed observations into registry ────────────────── + registry.update(&observations); + let multi_ap_frame = registry.to_multi_ap_frame(); + + // ── Step 3: Run enhanced pipeline ──────────────────────────── + let enhanced = pipeline.process(&multi_ap_frame); + + // ── Step 4: Build backward-compatible Esp32Frame ───────────── + let first_rssi = observations + .first() + .map(|o| o.rssi_dbm) + .unwrap_or(-80.0); + let _first_signal_pct = observations + .first() + .map(|o| o.signal_pct) + .unwrap_or(40.0); - // Create a pseudo-frame from RSSI (single subcarrier) let frame = Esp32Frame { magic: 0xC511_0001, node_id: 0, n_antennas: 1, - n_subcarriers: 1, + n_subcarriers: obs_count.min(255) as u8, freq_mhz: 2437, sequence: seq, - rssi: rssi_dbm as i8, + rssi: first_rssi.clamp(-128.0, 127.0) as i8, noise_floor: -90, - amplitudes: vec![signal_pct], - phases: vec![0.0], + amplitudes: multi_ap_frame.amplitudes.clone(), + phases: multi_ap_frame.phases.clone(), }; let (features, classification) = extract_features_from_frame(&frame); + // ── Step 5: Build enhanced fields from pipeline result ─────── + let enhanced_motion = Some(serde_json::json!({ + "score": enhanced.motion.score, + "level": format!("{:?}", enhanced.motion.level), + "contributing_bssids": enhanced.motion.contributing_bssids, + })); + + let enhanced_breathing = enhanced.breathing.as_ref().map(|b| { + serde_json::json!({ + "rate_bpm": b.rate_bpm, + "confidence": b.confidence, + "bssid_count": b.bssid_count, + }) + }); + + let posture_str = enhanced.posture.map(|p| format!("{p:?}")); + let sig_quality_score = Some(enhanced.signal_quality.score); + let verdict_str = Some(format!("{:?}", enhanced.verdict)); + let bssid_n = Some(enhanced.bssid_count); + + // ── Step 6: Update shared state ────────────────────────────── let mut s = state.write().await; s.source = format!("wifi:{ssid}"); - 
s.rssi_history.push_back(rssi_dbm); + s.rssi_history.push_back(first_rssi); if s.rssi_history.len() > 60 { s.rssi_history.pop_front(); } @@ -464,14 +586,15 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) { s.tick += 1; let tick = s.tick; - let motion_score = if classification.motion_level == "active" { 0.8 } - else if classification.motion_level == "present_still" { 0.3 } - else { 0.05 }; + let motion_score = if classification.motion_level == "active" { + 0.8 + } else if classification.motion_level == "present_still" { + 0.3 + } else { + 0.05 + }; - let vitals = s.vital_detector.process_frame( - &frame.amplitudes, - &frame.phases, - ); + let vitals = s.vital_detector.process_frame(&frame.amplitudes, &frame.phases); s.latest_vitals = vitals.clone(); let update = SensingUpdate { @@ -481,24 +604,129 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) { tick, nodes: vec![NodeInfo { node_id: 0, - rssi_dbm, + rssi_dbm: first_rssi, position: [0.0, 0.0, 0.0], - amplitude: vec![signal_pct], - subcarrier_count: 1, + amplitude: multi_ap_frame.amplitudes, + subcarrier_count: obs_count, }], features, classification, - signal_field: generate_signal_field(rssi_dbm, 1.0, motion_score, tick), + signal_field: generate_signal_field(first_rssi, 1.0, motion_score, tick), vital_signs: Some(vitals), + enhanced_motion, + enhanced_breathing, + posture: posture_str, + signal_quality_score: sig_quality_score, + quality_verdict: verdict_str, + bssid_count: bssid_n, + pose_keypoints: None, + model_status: None, }; if let Ok(json) = serde_json::to_string(&update) { let _ = s.tx.send(json); } s.latest_update = Some(update); + + debug!( + "Multi-BSSID tick #{tick}: {obs_count} BSSIDs, quality={:.2}, verdict={:?}", + enhanced.signal_quality.score, enhanced.verdict + ); } } +/// Fallback: single-RSSI collection via `netsh wlan show interfaces`. +/// +/// Used when the multi-BSSID scan fails or returns 0 observations. 
+async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) { + let output = match tokio::process::Command::new("netsh") + .args(["wlan", "show", "interfaces"]) + .output() + .await + { + Ok(o) => String::from_utf8_lossy(&o.stdout).to_string(), + Err(e) => { + warn!("netsh interfaces fallback failed: {e}"); + return; + } + }; + + let (rssi_dbm, signal_pct, ssid) = match parse_netsh_interfaces_output(&output) { + Some(v) => v, + None => { + debug!("Fallback: no WiFi interface connected"); + return; + } + }; + + let frame = Esp32Frame { + magic: 0xC511_0001, + node_id: 0, + n_antennas: 1, + n_subcarriers: 1, + freq_mhz: 2437, + sequence: seq, + rssi: rssi_dbm as i8, + noise_floor: -90, + amplitudes: vec![signal_pct], + phases: vec![0.0], + }; + + let (features, classification) = extract_features_from_frame(&frame); + + let mut s = state.write().await; + s.source = format!("wifi:{ssid}"); + s.rssi_history.push_back(rssi_dbm); + if s.rssi_history.len() > 60 { + s.rssi_history.pop_front(); + } + + s.tick += 1; + let tick = s.tick; + + let motion_score = if classification.motion_level == "active" { + 0.8 + } else if classification.motion_level == "present_still" { + 0.3 + } else { + 0.05 + }; + + let vitals = s.vital_detector.process_frame(&frame.amplitudes, &frame.phases); + s.latest_vitals = vitals.clone(); + + let update = SensingUpdate { + msg_type: "sensing_update".to_string(), + timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0, + source: format!("wifi:{ssid}"), + tick, + nodes: vec![NodeInfo { + node_id: 0, + rssi_dbm, + position: [0.0, 0.0, 0.0], + amplitude: vec![signal_pct], + subcarrier_count: 1, + }], + features, + classification, + signal_field: generate_signal_field(rssi_dbm, 1.0, motion_score, tick), + vital_signs: Some(vitals), + enhanced_motion: None, + enhanced_breathing: None, + posture: None, + signal_quality_score: None, + quality_verdict: None, + bssid_count: None, + pose_keypoints: None, + model_status: None, + }; + + if let 
Ok(json) = serde_json::to_string(&update) { + let _ = s.tx.send(json); + } + s.latest_update = Some(update); +} + /// Probe if Windows WiFi is connected async fn probe_windows_wifi() -> bool { match tokio::process::Command::new("netsh") @@ -508,7 +736,7 @@ async fn probe_windows_wifi() -> bool { { Ok(o) => { let out = String::from_utf8_lossy(&o.stdout); - parse_netsh_output(&out).is_some() + parse_netsh_interfaces_output(&out).is_some() } Err(_) => false, } @@ -932,6 +1160,75 @@ async fn model_info(State(state): State) -> Json } } +async fn model_layers(State(state): State) -> Json { + let s = state.read().await; + match &s.progressive_loader { + Some(loader) => { + let (a, b, c) = loader.layer_status(); + Json(serde_json::json!({ + "layer_a": a, + "layer_b": b, + "layer_c": c, + "progress": loader.loading_progress(), + })) + } + None => Json(serde_json::json!({ + "layer_a": false, + "layer_b": false, + "layer_c": false, + "progress": 0.0, + "message": "No model loaded with progressive loading", + })), + } +} + +async fn model_segments(State(state): State) -> Json { + let s = state.read().await; + match &s.progressive_loader { + Some(loader) => Json(serde_json::json!({ "segments": loader.segment_list() })), + None => Json(serde_json::json!({ "segments": [] })), + } +} + +async fn sona_profiles(State(state): State) -> Json { + let s = state.read().await; + let names = s + .progressive_loader + .as_ref() + .map(|l| l.sona_profile_names()) + .unwrap_or_default(); + let active = s.active_sona_profile.clone().unwrap_or_default(); + Json(serde_json::json!({ "profiles": names, "active": active })) +} + +async fn sona_activate( + State(state): State, + Json(body): Json, +) -> Json { + let profile = body + .get("profile") + .and_then(|p| p.as_str()) + .unwrap_or("") + .to_string(); + + let mut s = state.write().await; + let available = s + .progressive_loader + .as_ref() + .map(|l| l.sona_profile_names()) + .unwrap_or_default(); + + if available.contains(&profile) { + 
s.active_sona_profile = Some(profile.clone()); + Json(serde_json::json!({ "status": "activated", "profile": profile })) + } else { + Json(serde_json::json!({ + "status": "error", + "message": format!("Profile '{}' not found. Available: {:?}", profile, available), + })) + } +} + async fn info_page() -> Html { Html(format!( "\ @@ -1012,6 +1309,14 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) { features.mean_rssi, features.variance, motion_score, tick, ), vital_signs: Some(vitals), + enhanced_motion: None, + enhanced_breathing: None, + posture: None, + signal_quality_score: None, + quality_verdict: None, + bssid_count: None, + pose_keypoints: None, + model_status: None, }; if let Ok(json) = serde_json::to_string(&update) { @@ -1077,6 +1382,24 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) { features.mean_rssi, features.variance, motion_score, tick, ), vital_signs: Some(vitals), + enhanced_motion: None, + enhanced_breathing: None, + posture: None, + signal_quality_score: None, + quality_verdict: None, + bssid_count: None, + pose_keypoints: None, + model_status: if s.model_loaded { + Some(serde_json::json!({ + "loaded": true, + "layers": s.progressive_loader.as_ref() + .map(|l| { let (a,b,c) = l.layer_status(); a as u8 + b as u8 + c as u8 }) + .unwrap_or(0), + "sona_profile": s.active_sona_profile.as_deref().unwrap_or("default"), + })) + } else { + None + }, }; if update.classification.presence { @@ -1208,6 +1531,30 @@ async fn main() { None }; + // Load trained model via --model (uses progressive loading if --progressive set) + let model_path = args.model.as_ref().or(args.load_rvf.as_ref()); + let mut progressive_loader: Option = None; + let mut model_loaded = false; + if let Some(mp) = model_path { + if args.progressive || args.model.is_some() { + info!("Loading trained model (progressive) from {}", mp.display()); + match std::fs::read(mp) { + Ok(data) => match ProgressiveLoader::new(&data) { + Ok(mut loader) => { + if let Ok(la) 
= loader.load_layer_a() { + info!(" Layer A ready: model={} v{} ({} segments)", + la.model_name, la.version, la.n_segments); + } + model_loaded = true; + progressive_loader = Some(loader); + } + Err(e) => error!("Progressive loader init failed: {e}"), + }, + Err(e) => error!("Failed to read model file: {e}"), + } + } + } + let (tx, _) = broadcast::channel::(256); let state: SharedState = Arc::new(RwLock::new(AppStateInner { latest_update: None, @@ -1221,6 +1568,9 @@ async fn main() { latest_vitals: VitalSigns::default(), rvf_info, save_rvf_path: args.save_rvf.clone(), + progressive_loader, + active_sona_profile: None, + model_loaded, })); // Start background tasks based on source @@ -1274,6 +1624,11 @@ async fn main() { .route("/api/v1/vital-signs", get(vital_signs_endpoint)) // RVF model container info .route("/api/v1/model/info", get(model_info)) + // Progressive loading & SONA endpoints (Phase 7-8) + .route("/api/v1/model/layers", get(model_layers)) + .route("/api/v1/model/segments", get(model_segments)) + .route("/api/v1/model/sona/profiles", get(sona_profiles)) + .route("/api/v1/model/sona/activate", post(sona_activate)) // Pose endpoints (WiFi-derived) .route("/api/v1/pose/current", get(pose_current)) .route("/api/v1/pose/stats", get(pose_stats)) diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_container.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_container.rs index 1473121..4b168f7 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_container.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_container.rs @@ -298,6 +298,12 @@ impl RvfBuilder { self.push_segment(SEG_QUANT, &payload); } + /// Add a raw segment with arbitrary type and payload. + /// Used by `rvf_pipeline` for extended segment types. 
+ pub fn add_raw_segment(&mut self, seg_type: u8, payload: &[u8]) { + self.push_segment(seg_type, payload); + } + /// Add witness/proof data as a Witness segment. pub fn add_witness(&mut self, training_hash: &str, metrics: &serde_json::Value) { let witness = serde_json::json!({ diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_pipeline.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_pipeline.rs new file mode 100644 index 0000000..d8bcf82 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/rvf_pipeline.rs @@ -0,0 +1,1027 @@ +//! Extended RVF build pipeline — ADR-023 Phases 7-8. +//! +//! Adds HNSW index, overlay graph, SONA profile, and progressive loading +//! segments on top of the base `rvf_container` module. + +use std::path::Path; + +use crate::rvf_container::{RvfBuilder, RvfReader}; + +// ── Additional segment type discriminators ────────────────────────────────── + +/// HNSW index layers for sparse neuron routing. +pub const SEG_INDEX: u8 = 0x02; +/// Pre-computed min-cut graph structures. +pub const SEG_OVERLAY: u8 = 0x03; +/// SONA LoRA deltas per environment. +pub const SEG_AGGREGATE_WEIGHTS: u8 = 0x36; +/// Integrity signatures. +pub const SEG_CRYPTO: u8 = 0x0C; +/// WASM inference engine bytes. +pub const SEG_WASM: u8 = 0x10; +/// Embedded UI dashboard assets. +pub const SEG_DASHBOARD: u8 = 0x11; + +// ── HnswIndex ─────────────────────────────────────────────────────────────── + +/// A single node in an HNSW layer. +#[derive(Debug, Clone)] +pub struct HnswNode { + pub id: usize, + pub neighbors: Vec, + pub vector: Vec, +} + +/// One layer of the HNSW graph. +#[derive(Debug, Clone)] +pub struct HnswLayer { + pub nodes: Vec, +} + +/// Serializable HNSW index used for sparse inference neuron routing. 
+#[derive(Debug, Clone)] +pub struct HnswIndex { + pub layers: Vec, + pub entry_point: usize, + pub ef_construction: usize, + pub m: usize, +} + +impl HnswIndex { + /// Serialize the index to a byte vector. + /// + /// Wire format (all little-endian): + /// ```text + /// [entry_point: u64][ef_construction: u64][m: u64][n_layers: u32] + /// per layer: + /// [n_nodes: u32] + /// per node: + /// [id: u64][n_neighbors: u32][neighbors: u64*n][vec_len: u32][vector: f32*vec_len] + /// ``` + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&(self.entry_point as u64).to_le_bytes()); + buf.extend_from_slice(&(self.ef_construction as u64).to_le_bytes()); + buf.extend_from_slice(&(self.m as u64).to_le_bytes()); + buf.extend_from_slice(&(self.layers.len() as u32).to_le_bytes()); + + for layer in &self.layers { + buf.extend_from_slice(&(layer.nodes.len() as u32).to_le_bytes()); + for node in &layer.nodes { + buf.extend_from_slice(&(node.id as u64).to_le_bytes()); + buf.extend_from_slice(&(node.neighbors.len() as u32).to_le_bytes()); + for &n in &node.neighbors { + buf.extend_from_slice(&(n as u64).to_le_bytes()); + } + buf.extend_from_slice(&(node.vector.len() as u32).to_le_bytes()); + for &v in &node.vector { + buf.extend_from_slice(&v.to_le_bytes()); + } + } + } + buf + } + + /// Deserialize an HNSW index from bytes. 
+ pub fn from_bytes(data: &[u8]) -> Result { + let mut off = 0usize; + let read_u32 = |o: &mut usize| -> Result { + if *o + 4 > data.len() { + return Err("truncated u32".into()); + } + let v = u32::from_le_bytes(data[*o..*o + 4].try_into().unwrap()); + *o += 4; + Ok(v) + }; + let read_u64 = |o: &mut usize| -> Result { + if *o + 8 > data.len() { + return Err("truncated u64".into()); + } + let v = u64::from_le_bytes(data[*o..*o + 8].try_into().unwrap()); + *o += 8; + Ok(v) + }; + let read_f32 = |o: &mut usize| -> Result { + if *o + 4 > data.len() { + return Err("truncated f32".into()); + } + let v = f32::from_le_bytes(data[*o..*o + 4].try_into().unwrap()); + *o += 4; + Ok(v) + }; + + let entry_point = read_u64(&mut off)? as usize; + let ef_construction = read_u64(&mut off)? as usize; + let m = read_u64(&mut off)? as usize; + let n_layers = read_u32(&mut off)? as usize; + + let mut layers = Vec::with_capacity(n_layers); + for _ in 0..n_layers { + let n_nodes = read_u32(&mut off)? as usize; + let mut nodes = Vec::with_capacity(n_nodes); + for _ in 0..n_nodes { + let id = read_u64(&mut off)? as usize; + let n_neigh = read_u32(&mut off)? as usize; + let mut neighbors = Vec::with_capacity(n_neigh); + for _ in 0..n_neigh { + neighbors.push(read_u64(&mut off)? as usize); + } + let vec_len = read_u32(&mut off)? as usize; + let mut vector = Vec::with_capacity(vec_len); + for _ in 0..vec_len { + vector.push(read_f32(&mut off)?); + } + nodes.push(HnswNode { id, neighbors, vector }); + } + layers.push(HnswLayer { nodes }); + } + + Ok(Self { layers, entry_point, ef_construction, m }) + } +} + +// ── OverlayGraph ──────────────────────────────────────────────────────────── + +/// Weighted adjacency list: `(src, dst, weight)` edges. +#[derive(Debug, Clone)] +pub struct AdjacencyList { + pub n_nodes: usize, + pub edges: Vec<(usize, usize, f32)>, +} + +/// Min-cut partition result. 
+#[derive(Debug, Clone)] +pub struct Partition { + pub sensitive: Vec, + pub insensitive: Vec, +} + +/// Pre-computed graph overlay structures for the sensing pipeline. +#[derive(Debug, Clone)] +pub struct OverlayGraph { + pub subcarrier_graph: AdjacencyList, + pub antenna_graph: AdjacencyList, + pub body_graph: AdjacencyList, + pub mincut_partitions: Vec, +} + +impl OverlayGraph { + /// Serialize overlay graph to bytes. + /// + /// Format: three adjacency lists followed by partitions. + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::new(); + Self::write_adj(&mut buf, &self.subcarrier_graph); + Self::write_adj(&mut buf, &self.antenna_graph); + Self::write_adj(&mut buf, &self.body_graph); + + buf.extend_from_slice(&(self.mincut_partitions.len() as u32).to_le_bytes()); + for p in &self.mincut_partitions { + buf.extend_from_slice(&(p.sensitive.len() as u32).to_le_bytes()); + for &s in &p.sensitive { + buf.extend_from_slice(&(s as u64).to_le_bytes()); + } + buf.extend_from_slice(&(p.insensitive.len() as u32).to_le_bytes()); + for &i in &p.insensitive { + buf.extend_from_slice(&(i as u64).to_le_bytes()); + } + } + buf + } + + /// Deserialize overlay graph from bytes. + pub fn from_bytes(data: &[u8]) -> Result { + let mut off = 0usize; + let subcarrier_graph = Self::read_adj(data, &mut off)?; + let antenna_graph = Self::read_adj(data, &mut off)?; + let body_graph = Self::read_adj(data, &mut off)?; + + let n_part = Self::read_u32(data, &mut off)? as usize; + let mut mincut_partitions = Vec::with_capacity(n_part); + for _ in 0..n_part { + let ns = Self::read_u32(data, &mut off)? as usize; + let mut sensitive = Vec::with_capacity(ns); + for _ in 0..ns { + sensitive.push(Self::read_u64(data, &mut off)? as usize); + } + let ni = Self::read_u32(data, &mut off)? as usize; + let mut insensitive = Vec::with_capacity(ni); + for _ in 0..ni { + insensitive.push(Self::read_u64(data, &mut off)? 
as usize); + } + mincut_partitions.push(Partition { sensitive, insensitive }); + } + + Ok(Self { subcarrier_graph, antenna_graph, body_graph, mincut_partitions }) + } + + // -- helpers -- + + fn write_adj(buf: &mut Vec, adj: &AdjacencyList) { + buf.extend_from_slice(&(adj.n_nodes as u32).to_le_bytes()); + buf.extend_from_slice(&(adj.edges.len() as u32).to_le_bytes()); + for &(s, d, w) in &adj.edges { + buf.extend_from_slice(&(s as u64).to_le_bytes()); + buf.extend_from_slice(&(d as u64).to_le_bytes()); + buf.extend_from_slice(&w.to_le_bytes()); + } + } + + fn read_adj(data: &[u8], off: &mut usize) -> Result { + let n_nodes = Self::read_u32(data, off)? as usize; + let n_edges = Self::read_u32(data, off)? as usize; + let mut edges = Vec::with_capacity(n_edges); + for _ in 0..n_edges { + let s = Self::read_u64(data, off)? as usize; + let d = Self::read_u64(data, off)? as usize; + let w = Self::read_f32(data, off)?; + edges.push((s, d, w)); + } + Ok(AdjacencyList { n_nodes, edges }) + } + + fn read_u32(data: &[u8], off: &mut usize) -> Result { + if *off + 4 > data.len() { + return Err("overlay: truncated u32".into()); + } + let v = u32::from_le_bytes(data[*off..*off + 4].try_into().unwrap()); + *off += 4; + Ok(v) + } + + fn read_u64(data: &[u8], off: &mut usize) -> Result { + if *off + 8 > data.len() { + return Err("overlay: truncated u64".into()); + } + let v = u64::from_le_bytes(data[*off..*off + 8].try_into().unwrap()); + *off += 8; + Ok(v) + } + + fn read_f32(data: &[u8], off: &mut usize) -> Result { + if *off + 4 > data.len() { + return Err("overlay: truncated f32".into()); + } + let v = f32::from_le_bytes(data[*off..*off + 4].try_into().unwrap()); + *off += 4; + Ok(v) + } +} + +// ── RvfBuildInfo ──────────────────────────────────────────────────────────── + +/// Summary returned by `RvfModelBuilder::build_info()`. 
+#[derive(Debug, Clone)] +pub struct RvfBuildInfo { + pub segments: Vec<(String, usize)>, + pub total_size: usize, + pub model_name: String, +} + +// ── RvfModelBuilder ───────────────────────────────────────────────────────── + +/// High-level model packaging builder that wraps `RvfBuilder` with +/// domain-specific helpers for the WiFi-DensePose pipeline. +pub struct RvfModelBuilder { + model_name: String, + version: String, + weights: Option>, + hnsw: Option, + overlay: Option, + quant_mode: Option, + quant_scale: f32, + quant_zero: i32, + sona_profiles: Vec<(String, Vec, Vec)>, + training_hash: Option, + training_metrics: Option, + vital_config: Option<(f32, f32, f32, f32)>, + model_profile: Option<(String, String, String)>, +} + +impl RvfModelBuilder { + /// Create a new model builder. + pub fn new(model_name: &str, version: &str) -> Self { + Self { + model_name: model_name.to_string(), + version: version.to_string(), + weights: None, + hnsw: None, + overlay: None, + quant_mode: None, + quant_scale: 1.0, + quant_zero: 0, + sona_profiles: Vec::new(), + training_hash: None, + training_metrics: None, + vital_config: None, + model_profile: None, + } + } + + /// Set model weights. + pub fn set_weights(&mut self, weights: &[f32]) -> &mut Self { + self.weights = Some(weights.to_vec()); + self + } + + /// Attach an HNSW index for sparse neuron routing. + pub fn set_hnsw_index(&mut self, index: HnswIndex) -> &mut Self { + self.hnsw = Some(index); + self + } + + /// Attach pre-computed overlay graph structures. + pub fn set_overlay(&mut self, overlay: OverlayGraph) -> &mut Self { + self.overlay = Some(overlay); + self + } + + /// Set quantization parameters. + pub fn set_quantization(&mut self, mode: &str, scale: f32, zero_point: i32) -> &mut Self { + self.quant_mode = Some(mode.to_string()); + self.quant_scale = scale; + self.quant_zero = zero_point; + self + } + + /// Add a SONA environment adaptation profile (LoRA delta pair). 
+ pub fn add_sona_profile( + &mut self, + env_name: &str, + lora_a: &[f32], + lora_b: &[f32], + ) -> &mut Self { + self.sona_profiles + .push((env_name.to_string(), lora_a.to_vec(), lora_b.to_vec())); + self + } + + /// Set training provenance (witness). + pub fn set_training_proof( + &mut self, + hash: &str, + metrics: serde_json::Value, + ) -> &mut Self { + self.training_hash = Some(hash.to_string()); + self.training_metrics = Some(metrics); + self + } + + /// Set vital sign detector bounds. + pub fn set_vital_config( + &mut self, + breathing_min: f32, + breathing_max: f32, + heart_min: f32, + heart_max: f32, + ) -> &mut Self { + self.vital_config = Some((breathing_min, breathing_max, heart_min, heart_max)); + self + } + + /// Set model profile (input/output spec and requirements). + pub fn set_model_profile( + &mut self, + input_spec: &str, + output_spec: &str, + requirements: &str, + ) -> &mut Self { + self.model_profile = Some(( + input_spec.to_string(), + output_spec.to_string(), + requirements.to_string(), + )); + self + } + + /// Build the final RVF binary. 
+ pub fn build(&self) -> Result, String> { + let mut rvf = RvfBuilder::new(); + + // 1) Manifest + rvf.add_manifest(&self.model_name, &self.version, "RvfModelBuilder output"); + + // 2) Weights + if let Some(ref w) = self.weights { + rvf.add_weights(w); + } + + // 3) HNSW index segment + if let Some(ref idx) = self.hnsw { + rvf.add_raw_segment(SEG_INDEX, &idx.to_bytes()); + } + + // 4) Overlay graph segment + if let Some(ref ov) = self.overlay { + rvf.add_raw_segment(SEG_OVERLAY, &ov.to_bytes()); + } + + // 5) Quantization + if let Some(ref mode) = self.quant_mode { + rvf.add_quant_info(mode, self.quant_scale, self.quant_zero); + } + + // 6) SONA aggregate-weights segments + for (env, lora_a, lora_b) in &self.sona_profiles { + let payload = serde_json::to_vec(&serde_json::json!({ + "env": env, + "lora_a": lora_a, + "lora_b": lora_b, + })) + .map_err(|e| format!("sona serialize: {e}"))?; + rvf.add_raw_segment(SEG_AGGREGATE_WEIGHTS, &payload); + } + + // 7) Witness / training proof + if let Some(ref hash) = self.training_hash { + let metrics = self.training_metrics.clone().unwrap_or(serde_json::json!({})); + rvf.add_witness(hash, &metrics); + } + + // 8) Vital sign config (as profile segment) + if let Some((br_lo, br_hi, hr_lo, hr_hi)) = self.vital_config { + let cfg = crate::rvf_container::VitalSignConfig { + breathing_low_hz: br_lo as f64, + breathing_high_hz: br_hi as f64, + heartrate_low_hz: hr_lo as f64, + heartrate_high_hz: hr_hi as f64, + ..Default::default() + }; + rvf.add_vital_config(&cfg); + } + + // 9) Model profile metadata + if let Some((ref inp, ref out, ref req)) = self.model_profile { + rvf.add_metadata(&serde_json::json!({ + "model_profile": { + "input_spec": inp, + "output_spec": out, + "requirements": req, + } + })); + } + + // 10) Crypto placeholder (empty signature) + rvf.add_raw_segment(SEG_CRYPTO, &[]); + + Ok(rvf.build()) + } + + /// Build and write to a file. 
+ pub fn write_to_file(&self, path: &Path) -> Result<(), String> { + let data = self.build()?; + std::fs::write(path, &data) + .map_err(|e| format!("write {}: {e}", path.display())) + } + + /// Return build info (segment names + sizes) without fully building. + pub fn build_info(&self) -> RvfBuildInfo { + // Build once to get accurate sizes. + let data = self.build().unwrap_or_default(); + let reader = RvfReader::from_bytes(&data).ok(); + + let segments: Vec<(String, usize)> = reader + .as_ref() + .map(|r| { + r.segments() + .map(|(h, p)| (seg_type_name(h.seg_type), p.len())) + .collect() + }) + .unwrap_or_default(); + + RvfBuildInfo { + segments, + total_size: data.len(), + model_name: self.model_name.clone(), + } + } +} + +/// Human-readable segment type name. +fn seg_type_name(t: u8) -> String { + match t { + 0x01 => "vec".into(), + 0x02 => "index".into(), + 0x03 => "overlay".into(), + 0x05 => "manifest".into(), + 0x06 => "quant".into(), + 0x07 => "meta".into(), + 0x0A => "witness".into(), + 0x0B => "profile".into(), + 0x0C => "crypto".into(), + 0x10 => "wasm".into(), + 0x11 => "dashboard".into(), + 0x36 => "aggregate_weights".into(), + other => format!("0x{other:02X}"), + } +} + +// ── ProgressiveLoader ─────────────────────────────────────────────────────── + +/// Data returned by Layer A (instant startup). +#[derive(Debug, Clone)] +pub struct LayerAData { + pub manifest: serde_json::Value, + pub model_name: String, + pub version: String, + pub n_segments: usize, +} + +/// Data returned by Layer B (hot neuron weights). +#[derive(Debug, Clone)] +pub struct LayerBData { + pub weights_subset: Vec, + pub hot_neuron_ids: Vec, +} + +/// Data returned by Layer C (full model). +#[derive(Debug, Clone)] +pub struct LayerCData { + pub all_weights: Vec, + pub overlay: Option, + pub sona_profiles: Vec<(String, Vec)>, +} + +/// Progressive loader that reads an RVF container in three layers of +/// increasing completeness. 
+pub struct ProgressiveLoader { + reader: RvfReader, + layer_a_loaded: bool, + layer_b_loaded: bool, + layer_c_loaded: bool, +} + +impl ProgressiveLoader { + /// Create a new progressive loader from raw RVF bytes. + pub fn new(data: &[u8]) -> Result { + let reader = RvfReader::from_bytes(data)?; + Ok(Self { + reader, + layer_a_loaded: false, + layer_b_loaded: false, + layer_c_loaded: false, + }) + } + + /// Load Layer A: manifest + index only (target: <5ms). + pub fn load_layer_a(&mut self) -> Result { + let manifest = self.reader.manifest().unwrap_or(serde_json::json!({})); + let model_name = manifest + .get("model_id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + let version = manifest + .get("version") + .and_then(|v| v.as_str()) + .unwrap_or("0.0.0") + .to_string(); + let n_segments = self.reader.segment_count(); + + self.layer_a_loaded = true; + Ok(LayerAData { manifest, model_name, version, n_segments }) + } + + /// Load Layer B: hot neuron weights subset. + pub fn load_layer_b(&mut self) -> Result { + // Load HNSW index to find hot neuron IDs. + let hot_neuron_ids: Vec = self + .reader + .find_segment(SEG_INDEX) + .and_then(|data| HnswIndex::from_bytes(data).ok()) + .map(|idx| { + // Hot neurons = all nodes in layer 0 (most connected). + idx.layers + .first() + .map(|l| l.nodes.iter().map(|n| n.id).collect()) + .unwrap_or_default() + }) + .unwrap_or_default(); + + // Extract a subset of weights corresponding to hot neurons. + let all_w = self.reader.weights().unwrap_or_default(); + let weights_subset: Vec = if hot_neuron_ids.is_empty() { + // No index — take first 25% of weights as "hot" subset. + let n = all_w.len() / 4; + all_w.iter().take(n.max(1)).copied().collect() + } else { + hot_neuron_ids + .iter() + .filter_map(|&id| all_w.get(id).copied()) + .collect() + }; + + self.layer_b_loaded = true; + Ok(LayerBData { weights_subset, hot_neuron_ids }) + } + + /// Load Layer C: all remaining weights and structures (full accuracy). 
+ pub fn load_layer_c(&mut self) -> Result { + let all_weights = self.reader.weights().unwrap_or_default(); + + let overlay = self + .reader + .find_segment(SEG_OVERLAY) + .and_then(|data| OverlayGraph::from_bytes(data).ok()); + + // Collect SONA profiles from aggregate-weight segments. + let mut sona_profiles = Vec::new(); + for (h, payload) in self.reader.segments() { + if h.seg_type == SEG_AGGREGATE_WEIGHTS { + if let Ok(v) = serde_json::from_slice::(payload) { + let env = v + .get("env") + .and_then(|e| e.as_str()) + .unwrap_or("unknown") + .to_string(); + let lora_a: Vec = v + .get("lora_a") + .and_then(|a| serde_json::from_value(a.clone()).ok()) + .unwrap_or_default(); + sona_profiles.push((env, lora_a)); + } + } + } + + self.layer_c_loaded = true; + Ok(LayerCData { all_weights, overlay, sona_profiles }) + } + + /// Current loading progress (0.0 to 1.0). + pub fn loading_progress(&self) -> f32 { + let mut p = 0.0f32; + if self.layer_a_loaded { + p += 0.33; + } + if self.layer_b_loaded { + p += 0.34; + } + if self.layer_c_loaded { + p += 0.33; + } + p.min(1.0) + } + + /// Per-layer status for the REST API. + pub fn layer_status(&self) -> (bool, bool, bool) { + (self.layer_a_loaded, self.layer_b_loaded, self.layer_c_loaded) + } + + /// Collect segment info list for the REST API. + pub fn segment_list(&self) -> Vec { + self.reader + .segments() + .map(|(h, p)| { + serde_json::json!({ + "type": seg_type_name(h.seg_type), + "size": p.len(), + "segment_id": h.segment_id, + }) + }) + .collect() + } + + /// List available SONA profile names. 
+ pub fn sona_profile_names(&self) -> Vec { + let mut names = Vec::new(); + for (h, payload) in self.reader.segments() { + if h.seg_type == SEG_AGGREGATE_WEIGHTS { + if let Ok(v) = serde_json::from_slice::(payload) { + if let Some(env) = v.get("env").and_then(|e| e.as_str()) { + names.push(env.to_string()); + } + } + } + } + names + } +} + +// ── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_hnsw() -> HnswIndex { + HnswIndex { + layers: vec![ + HnswLayer { + nodes: vec![ + HnswNode { id: 0, neighbors: vec![1, 2], vector: vec![1.0, 2.0] }, + HnswNode { id: 1, neighbors: vec![0], vector: vec![3.0, 4.0] }, + HnswNode { id: 2, neighbors: vec![0], vector: vec![5.0, 6.0] }, + ], + }, + HnswLayer { + nodes: vec![ + HnswNode { id: 0, neighbors: vec![2], vector: vec![1.0, 2.0] }, + ], + }, + ], + entry_point: 0, + ef_construction: 200, + m: 16, + } + } + + fn sample_overlay() -> OverlayGraph { + OverlayGraph { + subcarrier_graph: AdjacencyList { + n_nodes: 3, + edges: vec![(0, 1, 0.5), (1, 2, 0.8)], + }, + antenna_graph: AdjacencyList { + n_nodes: 2, + edges: vec![(0, 1, 1.0)], + }, + body_graph: AdjacencyList { + n_nodes: 4, + edges: vec![(0, 1, 0.3), (2, 3, 0.9), (0, 3, 0.1)], + }, + mincut_partitions: vec![Partition { + sensitive: vec![0, 1], + insensitive: vec![2, 3], + }], + } + } + + #[test] + fn hnsw_index_round_trip() { + let idx = sample_hnsw(); + let bytes = idx.to_bytes(); + let decoded = HnswIndex::from_bytes(&bytes).unwrap(); + assert_eq!(decoded.entry_point, 0); + assert_eq!(decoded.ef_construction, 200); + assert_eq!(decoded.m, 16); + assert_eq!(decoded.layers.len(), 2); + assert_eq!(decoded.layers[0].nodes.len(), 3); + assert_eq!(decoded.layers[0].nodes[0].neighbors, vec![1, 2]); + assert!((decoded.layers[0].nodes[1].vector[0] - 3.0).abs() < f32::EPSILON); + } + + #[test] + fn hnsw_index_empty_layers() { + let idx = HnswIndex { + layers: vec![], + entry_point: 0, + 
ef_construction: 64, + m: 8, + }; + let bytes = idx.to_bytes(); + let decoded = HnswIndex::from_bytes(&bytes).unwrap(); + assert!(decoded.layers.is_empty()); + assert_eq!(decoded.ef_construction, 64); + } + + #[test] + fn overlay_graph_round_trip() { + let ov = sample_overlay(); + let bytes = ov.to_bytes(); + let decoded = OverlayGraph::from_bytes(&bytes).unwrap(); + assert_eq!(decoded.subcarrier_graph.n_nodes, 3); + assert_eq!(decoded.subcarrier_graph.edges.len(), 2); + assert_eq!(decoded.antenna_graph.n_nodes, 2); + assert_eq!(decoded.body_graph.edges.len(), 3); + assert_eq!(decoded.mincut_partitions.len(), 1); + } + + #[test] + fn overlay_adjacency_list_edges() { + let ov = sample_overlay(); + let bytes = ov.to_bytes(); + let decoded = OverlayGraph::from_bytes(&bytes).unwrap(); + let e = &decoded.subcarrier_graph.edges[0]; + assert_eq!(e.0, 0); + assert_eq!(e.1, 1); + assert!((e.2 - 0.5).abs() < f32::EPSILON); + } + + #[test] + fn overlay_partition_sensitive_insensitive() { + let ov = sample_overlay(); + let bytes = ov.to_bytes(); + let decoded = OverlayGraph::from_bytes(&bytes).unwrap(); + let p = &decoded.mincut_partitions[0]; + assert_eq!(p.sensitive, vec![0, 1]); + assert_eq!(p.insensitive, vec![2, 3]); + } + + #[test] + fn model_builder_minimal() { + let mut b = RvfModelBuilder::new("test-min", "0.1.0"); + b.set_weights(&[1.0, 2.0, 3.0]); + let data = b.build().unwrap(); + assert!(!data.is_empty()); + + let reader = RvfReader::from_bytes(&data).unwrap(); + // manifest + weights + crypto = 3 segments minimum + assert!(reader.segment_count() >= 3); + assert!(reader.manifest().is_some()); + assert!(reader.weights().is_some()); + } + + #[test] + fn model_builder_full() { + let mut b = RvfModelBuilder::new("full-model", "1.0.0"); + b.set_weights(&[0.1, 0.2, 0.3, 0.4]); + b.set_hnsw_index(sample_hnsw()); + b.set_overlay(sample_overlay()); + b.set_quantization("int8", 0.0078, -128); + b.add_sona_profile("office-3f", &[0.1, 0.2], &[0.3, 0.4]); + 
b.add_sona_profile("warehouse", &[0.5], &[0.6]); + b.set_training_proof("sha256:abc123", serde_json::json!({"loss": 0.01})); + b.set_vital_config(0.1, 0.5, 0.8, 2.0); + b.set_model_profile("csi_56d", "keypoints_17", "gpu_optional"); + + let data = b.build().unwrap(); + let reader = RvfReader::from_bytes(&data).unwrap(); + + // manifest + vec + index + overlay + quant + 2*agg + witness + profile + meta + crypto = 11 + assert!(reader.segment_count() >= 10, "got {}", reader.segment_count()); + assert!(reader.manifest().is_some()); + assert!(reader.weights().is_some()); + assert!(reader.find_segment(SEG_INDEX).is_some()); + assert!(reader.find_segment(SEG_OVERLAY).is_some()); + assert!(reader.find_segment(SEG_CRYPTO).is_some()); + } + + #[test] + fn model_builder_build_info_reports_sizes() { + let mut b = RvfModelBuilder::new("info-test", "2.0.0"); + b.set_weights(&[1.0; 100]); + let info = b.build_info(); + assert_eq!(info.model_name, "info-test"); + assert!(info.total_size > 0); + assert!(!info.segments.is_empty()); + // At least one segment should have meaningful size + assert!(info.segments.iter().any(|(_, sz)| *sz > 0)); + } + + #[test] + fn model_builder_sona_profiles_stored() { + let mut b = RvfModelBuilder::new("sona-test", "1.0.0"); + b.set_weights(&[1.0]); + b.add_sona_profile("env-a", &[0.1, 0.2], &[0.3, 0.4]); + b.add_sona_profile("env-b", &[0.5], &[0.6]); + + let data = b.build().unwrap(); + let reader = RvfReader::from_bytes(&data).unwrap(); + + // Count aggregate-weight segments. + let agg_count = reader + .segments() + .filter(|(h, _)| h.seg_type == SEG_AGGREGATE_WEIGHTS) + .count(); + assert_eq!(agg_count, 2); + + // Verify first profile content. 
+ let (_, payload) = reader + .segments() + .find(|(h, _)| h.seg_type == SEG_AGGREGATE_WEIGHTS) + .unwrap(); + let v: serde_json::Value = serde_json::from_slice(payload).unwrap(); + assert_eq!(v["env"], "env-a"); + } + + #[test] + fn progressive_loader_layer_a_fast() { + let mut b = RvfModelBuilder::new("prog-a", "1.0.0"); + b.set_weights(&[1.0; 50]); + let data = b.build().unwrap(); + + let mut loader = ProgressiveLoader::new(&data).unwrap(); + let start = std::time::Instant::now(); + let la = loader.load_layer_a().unwrap(); + let elapsed = start.elapsed(); + + assert_eq!(la.model_name, "prog-a"); + assert_eq!(la.version, "1.0.0"); + assert!(la.n_segments > 0); + // Layer A should be very fast (target <5ms, we allow generous 100ms for CI). + assert!(elapsed.as_millis() < 100, "Layer A took {}ms", elapsed.as_millis()); + } + + #[test] + fn progressive_loader_all_layers() { + let mut b = RvfModelBuilder::new("prog-all", "2.0.0"); + b.set_weights(&[0.5; 20]); + b.set_hnsw_index(sample_hnsw()); + b.set_overlay(sample_overlay()); + b.add_sona_profile("env-x", &[1.0], &[2.0]); + + let data = b.build().unwrap(); + let mut loader = ProgressiveLoader::new(&data).unwrap(); + + let la = loader.load_layer_a().unwrap(); + assert_eq!(la.model_name, "prog-all"); + + let lb = loader.load_layer_b().unwrap(); + // HNSW has nodes 0,1,2 in layer 0, so hot_neuron_ids should contain those. 
+ assert!(!lb.hot_neuron_ids.is_empty()); + assert!(!lb.weights_subset.is_empty()); + + let lc = loader.load_layer_c().unwrap(); + assert_eq!(lc.all_weights.len(), 20); + assert!(lc.overlay.is_some()); + assert_eq!(lc.sona_profiles.len(), 1); + assert_eq!(lc.sona_profiles[0].0, "env-x"); + } + + #[test] + fn progressive_loader_progress_tracking() { + let mut b = RvfModelBuilder::new("prog-track", "1.0.0"); + b.set_weights(&[1.0]); + let data = b.build().unwrap(); + let mut loader = ProgressiveLoader::new(&data).unwrap(); + + assert!((loader.loading_progress() - 0.0).abs() < f32::EPSILON); + + loader.load_layer_a().unwrap(); + assert!(loader.loading_progress() > 0.3); + + loader.load_layer_b().unwrap(); + assert!(loader.loading_progress() > 0.6); + + loader.load_layer_c().unwrap(); + assert!((loader.loading_progress() - 1.0).abs() < 0.01); + } + + #[test] + fn rvf_model_file_round_trip() { + let dir = std::env::temp_dir().join("rvf_pipeline_test"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("pipeline_model.rvf"); + + let mut b = RvfModelBuilder::new("file-rt", "3.0.0"); + b.set_weights(&[42.0, -1.0, 0.0]); + b.set_hnsw_index(sample_hnsw()); + b.write_to_file(&path).unwrap(); + + let reader = RvfReader::from_file(&path).unwrap(); + assert!(reader.segment_count() >= 3); + let manifest = reader.manifest().unwrap(); + assert_eq!(manifest["model_id"], "file-rt"); + + let w = reader.weights().unwrap(); + assert_eq!(w.len(), 3); + assert!((w[0] - 42.0).abs() < f32::EPSILON); + + let _ = std::fs::remove_file(&path); + let _ = std::fs::remove_dir(&dir); + } + + #[test] + fn segment_type_constants_unique() { + let types = [ + SEG_INDEX, + SEG_OVERLAY, + SEG_AGGREGATE_WEIGHTS, + SEG_CRYPTO, + SEG_WASM, + SEG_DASHBOARD, + ]; + // Also include the base types from rvf_container to ensure no collision. 
+ let base_types: [u8; 6] = [0x01, 0x05, 0x06, 0x07, 0x0A, 0x0B]; + let mut all: Vec = types.to_vec(); + all.extend_from_slice(&base_types); + + let mut seen = std::collections::HashSet::new(); + for t in &all { + assert!(seen.insert(*t), "duplicate segment type: 0x{t:02X}"); + } + } + + #[test] + fn aggregate_weights_multiple_envs() { + let mut b = RvfModelBuilder::new("multi-env", "1.0.0"); + b.set_weights(&[1.0]); + b.add_sona_profile("office", &[0.1, 0.2, 0.3], &[0.4, 0.5, 0.6]); + b.add_sona_profile("warehouse", &[0.7, 0.8], &[0.9, 1.0]); + b.add_sona_profile("outdoor", &[1.1], &[1.2]); + + let data = b.build().unwrap(); + let mut loader = ProgressiveLoader::new(&data).unwrap(); + let names = loader.sona_profile_names(); + assert_eq!(names.len(), 3); + assert!(names.contains(&"office".to_string())); + assert!(names.contains(&"warehouse".to_string())); + assert!(names.contains(&"outdoor".to_string())); + + let lc = loader.load_layer_c().unwrap(); + assert_eq!(lc.sona_profiles.len(), 3); + } + + #[test] + fn crypto_segment_placeholder() { + let mut b = RvfModelBuilder::new("crypto-test", "1.0.0"); + b.set_weights(&[1.0]); + let data = b.build().unwrap(); + let reader = RvfReader::from_bytes(&data).unwrap(); + + // Crypto segment should exist but be empty (placeholder). + let crypto = reader.find_segment(SEG_CRYPTO); + assert!(crypto.is_some(), "crypto segment must be present"); + assert!(crypto.unwrap().is_empty(), "crypto segment should be empty placeholder"); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sona.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sona.rs new file mode 100644 index 0000000..6223f26 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sona.rs @@ -0,0 +1,639 @@ +//! SONA online adaptation: LoRA + EWC++ for WiFi-DensePose (ADR-023 Phase 5). +//! +//! Enables rapid low-parameter adaptation to changing WiFi environments without +//! 
//! catastrophic forgetting. All arithmetic uses `f32`, no external dependencies.

use std::collections::VecDeque;

// ── LoRA Adapter ──────────────────────────────────────────────────────────

/// Low-Rank Adaptation layer storing factorised delta `scale * A * B`.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Factor A, row-major, shape (in_features, rank).
    pub a: Vec<Vec<f32>>,
    /// Factor B, row-major, shape (rank, out_features).
    pub b: Vec<Vec<f32>>,
    /// Effective scaling `alpha / rank`.
    pub scale: f32,
    pub in_features: usize,
    pub out_features: usize,
    pub rank: usize,
}

impl LoraAdapter {
    /// Zero-initialised adapter. `rank` is clamped to >= 1 only for the
    /// scale computation so `rank == 0` does not divide by zero.
    pub fn new(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
        Self {
            a: vec![vec![0.0f32; rank]; in_features],
            b: vec![vec![0.0f32; out_features]; rank],
            scale: alpha / rank.max(1) as f32,
            in_features,
            out_features,
            rank,
        }
    }

    /// Compute `scale * input * A * B`, returning a vector of length `out_features`.
    ///
    /// # Panics
    /// Panics if `input.len() != in_features`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.in_features);
        // hidden = input * A  (length = rank)
        let mut hidden = vec![0.0f32; self.rank];
        for (i, &x) in input.iter().enumerate() {
            for r in 0..self.rank {
                hidden[r] += x * self.a[i][r];
            }
        }
        // output = hidden * B, then apply scale.
        let mut output = vec![0.0f32; self.out_features];
        for r in 0..self.rank {
            for j in 0..self.out_features {
                output[j] += hidden[r] * self.b[r][j];
            }
        }
        for v in output.iter_mut() {
            *v *= self.scale;
        }
        output
    }

    /// Full delta weight matrix `scale * A * B`, shape (in_features, out_features).
    pub fn delta_weights(&self) -> Vec<Vec<f32>> {
        let mut delta = vec![vec![0.0f32; self.out_features]; self.in_features];
        for i in 0..self.in_features {
            for r in 0..self.rank {
                let a_val = self.a[i][r];
                for j in 0..self.out_features {
                    delta[i][j] += a_val * self.b[r][j];
                }
            }
        }
        for row in delta.iter_mut() {
            for v in row.iter_mut() {
                *v *= self.scale;
            }
        }
        delta
    }

    /// Add LoRA delta to base weights in place. Rows/columns beyond the
    /// shorter of the two shapes are silently ignored (zip semantics).
    pub fn merge_into(&self, base_weights: &mut [Vec<f32>]) {
        let delta = self.delta_weights();
        for (rb, rd) in base_weights.iter_mut().zip(delta.iter()) {
            for (w, &d) in rb.iter_mut().zip(rd.iter()) {
                *w += d;
            }
        }
    }

    /// Subtract LoRA delta from base weights in place (inverse of `merge_into`).
    pub fn unmerge_from(&self, base_weights: &mut [Vec<f32>]) {
        let delta = self.delta_weights();
        for (rb, rd) in base_weights.iter_mut().zip(delta.iter()) {
            for (w, &d) in rb.iter_mut().zip(rd.iter()) {
                *w -= d;
            }
        }
    }

    /// Trainable parameter count: `rank * (in_features + out_features)`.
    pub fn n_params(&self) -> usize {
        self.rank * (self.in_features + self.out_features)
    }

    /// Reset A and B to zero (delta becomes the zero matrix).
    pub fn reset(&mut self) {
        for row in self.a.iter_mut() {
            for v in row.iter_mut() { *v = 0.0; }
        }
        for row in self.b.iter_mut() {
            for v in row.iter_mut() { *v = 0.0; }
        }
    }
}

// ── EWC++ Regularizer ─────────────────────────────────────────────────────

/// Elastic Weight Consolidation++ regularizer with running Fisher average.
#[derive(Debug, Clone)]
pub struct EwcRegularizer {
    /// Penalty strength.
    pub lambda: f32,
    /// Exponential decay for the online Fisher update.
    pub decay: f32,
    /// Diagonal Fisher information estimate (empty until first update).
    pub fisher_diag: Vec<f32>,
    /// Consolidated reference parameters theta* (empty until `consolidate`).
    pub reference_params: Vec<f32>,
}

impl EwcRegularizer {
    pub fn new(lambda: f32, decay: f32) -> Self {
        Self { lambda, decay, fisher_diag: Vec::new(), reference_params: Vec::new() }
    }

    /// Diagonal Fisher via numerical central differences: F_i = grad_i^2,
    /// averaged over `n_samples` repetitions (clamped to >= 1).
    pub fn compute_fisher(
        params: &[f32],
        loss_fn: impl Fn(&[f32]) -> f32,
        n_samples: usize,
    ) -> Vec<f32> {
        let eps = 1e-4f32;
        let n = params.len();
        let mut fisher = vec![0.0f32; n];
        let samples = n_samples.max(1);
        for _ in 0..samples {
            let mut p = params.to_vec();
            for i in 0..n {
                let orig = p[i];
                p[i] = orig + eps;
                let lp = loss_fn(&p);
                p[i] = orig - eps;
                let lm = loss_fn(&p);
                p[i] = orig;
                let g = (lp - lm) / (2.0 * eps);
                fisher[i] += g * g;
            }
        }
        for f in fisher.iter_mut() {
            *f /= samples as f32;
        }
        fisher
    }

    /// Online update: `F = decay * F_old + (1-decay) * F_new`.
    /// First call simply adopts `new_fisher`.
    ///
    /// # Panics
    /// Panics if a subsequent update has a different length.
    pub fn update_fisher(&mut self, new_fisher: &[f32]) {
        if self.fisher_diag.is_empty() {
            self.fisher_diag = new_fisher.to_vec();
            return;
        }
        assert_eq!(self.fisher_diag.len(), new_fisher.len());
        for (old, &nv) in self.fisher_diag.iter_mut().zip(new_fisher.iter()) {
            *old = self.decay * *old + (1.0 - self.decay) * nv;
        }
    }

    /// Penalty: `0.5 * lambda * sum(F_i * (theta_i - theta_i*)^2)`.
    /// Returns 0 until both a reference and a Fisher estimate exist.
    pub fn penalty(&self, current_params: &[f32]) -> f32 {
        if self.reference_params.is_empty() || self.fisher_diag.is_empty() {
            return 0.0;
        }
        let n = current_params
            .len()
            .min(self.reference_params.len())
            .min(self.fisher_diag.len());
        let mut sum = 0.0f32;
        for i in 0..n {
            let d = current_params[i] - self.reference_params[i];
            sum += self.fisher_diag[i] * d * d;
        }
        0.5 * self.lambda * sum
    }

    /// Gradient of penalty: `lambda * F_i * (theta_i - theta_i*)`.
    /// Entries beyond the reference/Fisher length stay zero.
    pub fn penalty_gradient(&self, current_params: &[f32]) -> Vec<f32> {
        if self.reference_params.is_empty() || self.fisher_diag.is_empty() {
            return vec![0.0f32; current_params.len()];
        }
        let n = current_params
            .len()
            .min(self.reference_params.len())
            .min(self.fisher_diag.len());
        let mut grad = vec![0.0f32; current_params.len()];
        for i in 0..n {
            grad[i] = self.lambda * self.fisher_diag[i]
                * (current_params[i] - self.reference_params[i]);
        }
        grad
    }

    /// Save current params as the new reference point theta*.
    pub fn consolidate(&mut self, params: &[f32]) {
        self.reference_params = params.to_vec();
    }
}

// ── Configuration & Types ─────────────────────────────────────────────────

/// SONA adaptation configuration.
#[derive(Debug, Clone)]
pub struct SonaConfig {
    pub lora_rank: usize,
    pub lora_alpha: f32,
    pub ewc_lambda: f32,
    pub ewc_decay: f32,
    pub adaptation_lr: f32,
    pub max_steps: usize,
    pub convergence_threshold: f32,
    pub temporal_consistency_weight: f32,
}

impl Default for SonaConfig {
    fn default() -> Self {
        Self {
            lora_rank: 4,
            lora_alpha: 8.0,
            ewc_lambda: 5000.0,
            ewc_decay: 0.99,
            adaptation_lr: 0.001,
            max_steps: 50,
            convergence_threshold: 1e-4,
            temporal_consistency_weight: 0.1,
        }
    }
}

/// Single training sample for online adaptation.
#[derive(Debug, Clone)]
pub struct AdaptationSample {
    pub csi_features: Vec<f32>,
    pub target: Vec<f32>,
}

/// Result of a SONA adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationResult {
    pub adapted_params: Vec<f32>,
    pub steps_taken: usize,
    pub final_loss: f32,
    pub converged: bool,
    pub ewc_penalty: f32,
}
+#[derive(Debug, Clone)] +pub struct SonaProfile { + pub name: String, + pub lora_a: Vec>, + pub lora_b: Vec>, + pub fisher_diag: Vec, + pub reference_params: Vec, + pub adaptation_count: usize, +} + +// ── SONA Adapter ──────────────────────────────────────────────────────────── + +/// Full SONA system: LoRA adapter + EWC++ regularizer for online adaptation. +#[derive(Debug, Clone)] +pub struct SonaAdapter { + pub config: SonaConfig, + pub lora: LoraAdapter, + pub ewc: EwcRegularizer, + pub param_count: usize, + pub adaptation_count: usize, +} + +impl SonaAdapter { + pub fn new(config: SonaConfig, param_count: usize) -> Self { + let lora = LoraAdapter::new(param_count, 1, config.lora_rank, config.lora_alpha); + let ewc = EwcRegularizer::new(config.ewc_lambda, config.ewc_decay); + Self { config, lora, ewc, param_count, adaptation_count: 0 } + } + + /// Run gradient descent with LoRA + EWC on the given samples. + pub fn adapt(&mut self, base_params: &[f32], samples: &[AdaptationSample]) -> AdaptationResult { + assert_eq!(base_params.len(), self.param_count); + if samples.is_empty() { + return AdaptationResult { + adapted_params: base_params.to_vec(), steps_taken: 0, + final_loss: 0.0, converged: true, ewc_penalty: self.ewc.penalty(base_params), + }; + } + let lr = self.config.adaptation_lr; + let (mut prev_loss, mut steps, mut converged) = (f32::MAX, 0usize, false); + let out_dim = samples[0].target.len(); + let in_dim = samples[0].csi_features.len(); + + for step in 0..self.config.max_steps { + steps = step + 1; + let df = self.lora_delta_flat(); + let eff: Vec = base_params.iter().zip(df.iter()).map(|(&b, &d)| b + d).collect(); + let (dl, dg) = Self::mse_loss_grad(&eff, samples, in_dim, out_dim); + let ep = self.ewc.penalty(&eff); + let eg = self.ewc.penalty_gradient(&eff); + let total = dl + ep; + if (prev_loss - total).abs() < self.config.convergence_threshold { + converged = true; prev_loss = total; break; + } + prev_loss = total; + let gl = 
df.len().min(dg.len()).min(eg.len()); + let mut tg = vec![0.0f32; gl]; + for i in 0..gl { tg[i] = dg[i] + eg[i]; } + self.update_lora(&tg, lr); + } + let df = self.lora_delta_flat(); + let adapted: Vec = base_params.iter().zip(df.iter()).map(|(&b, &d)| b + d).collect(); + let ewc_penalty = self.ewc.penalty(&adapted); + self.adaptation_count += 1; + AdaptationResult { adapted_params: adapted, steps_taken: steps, final_loss: prev_loss, converged, ewc_penalty } + } + + pub fn save_profile(&self, name: &str) -> SonaProfile { + SonaProfile { + name: name.to_string(), lora_a: self.lora.a.clone(), lora_b: self.lora.b.clone(), + fisher_diag: self.ewc.fisher_diag.clone(), reference_params: self.ewc.reference_params.clone(), + adaptation_count: self.adaptation_count, + } + } + + pub fn load_profile(&mut self, profile: &SonaProfile) { + self.lora.a = profile.lora_a.clone(); + self.lora.b = profile.lora_b.clone(); + self.ewc.fisher_diag = profile.fisher_diag.clone(); + self.ewc.reference_params = profile.reference_params.clone(); + self.adaptation_count = profile.adaptation_count; + } + + fn lora_delta_flat(&self) -> Vec { + self.lora.delta_weights().into_iter().map(|r| r[0]).collect() + } + + fn mse_loss_grad(params: &[f32], samples: &[AdaptationSample], in_dim: usize, out_dim: usize) -> (f32, Vec) { + let n = samples.len() as f32; + let ws = in_dim * out_dim; + let mut grad = vec![0.0f32; params.len()]; + let mut loss = 0.0f32; + for s in samples { + let (inp, tgt) = (&s.csi_features, &s.target); + let mut pred = vec![0.0f32; out_dim]; + for j in 0..out_dim { + for i in 0..in_dim.min(inp.len()) { + let idx = j * in_dim + i; + if idx < ws && idx < params.len() { pred[j] += params[idx] * inp[i]; } + } + } + for j in 0..out_dim.min(tgt.len()) { + let e = pred[j] - tgt[j]; + loss += e * e; + for i in 0..in_dim.min(inp.len()) { + let idx = j * in_dim + i; + if idx < ws && idx < grad.len() { grad[idx] += 2.0 * e * inp[i] / n; } + } + } + } + (loss / n, grad) + } + + fn 
update_lora(&mut self, grad: &[f32], lr: f32) { + let (scale, rank) = (self.lora.scale, self.lora.rank); + if self.lora.b.iter().all(|r| r.iter().all(|&v| v == 0.0)) && rank > 0 { + self.lora.b[0][0] = 1.0; + } + for i in 0..self.lora.in_features.min(grad.len()) { + for r in 0..rank { + self.lora.a[i][r] -= lr * grad[i] * scale * self.lora.b[r][0]; + } + } + for r in 0..rank { + let mut g = 0.0f32; + for i in 0..self.lora.in_features.min(grad.len()) { + g += grad[i] * scale * self.lora.a[i][r]; + } + self.lora.b[r][0] -= lr * g; + } + } +} + +// ── Environment Detector ──────────────────────────────────────────────────── + +/// CSI baseline drift information. +#[derive(Debug, Clone)] +pub struct DriftInfo { + pub magnitude: f32, + pub duration_frames: usize, + pub baseline_mean: f32, + pub current_mean: f32, +} + +/// Detects environmental drift in CSI statistics (>3 sigma from baseline). +#[derive(Debug, Clone)] +pub struct EnvironmentDetector { + window_size: usize, + means: VecDeque, + variances: VecDeque, + baseline_mean: f32, + baseline_var: f32, + baseline_std: f32, + baseline_set: bool, + drift_frames: usize, +} + +impl EnvironmentDetector { + pub fn new(window_size: usize) -> Self { + Self { + window_size: window_size.max(2), + means: VecDeque::with_capacity(window_size), + variances: VecDeque::with_capacity(window_size), + baseline_mean: 0.0, baseline_var: 0.0, baseline_std: 0.0, + baseline_set: false, drift_frames: 0, + } + } + + pub fn update(&mut self, csi_mean: f32, csi_var: f32) { + self.means.push_back(csi_mean); + self.variances.push_back(csi_var); + while self.means.len() > self.window_size { self.means.pop_front(); } + while self.variances.len() > self.window_size { self.variances.pop_front(); } + if !self.baseline_set && self.means.len() >= self.window_size { self.reset_baseline(); } + if self.drift_detected() { self.drift_frames += 1; } else { self.drift_frames = 0; } + } + + pub fn drift_detected(&self) -> bool { + if !self.baseline_set || 
self.means.is_empty() { return false; } + let dev = (self.current_mean() - self.baseline_mean).abs(); + let thr = if self.baseline_std > f32::EPSILON { 3.0 * self.baseline_std } + else { f32::EPSILON * 100.0 }; + dev > thr + } + + pub fn reset_baseline(&mut self) { + if self.means.is_empty() { return; } + let n = self.means.len() as f32; + self.baseline_mean = self.means.iter().sum::() / n; + let var = self.means.iter().map(|&m| (m - self.baseline_mean).powi(2)).sum::() / n; + self.baseline_var = var; + self.baseline_std = var.sqrt(); + self.baseline_set = true; + self.drift_frames = 0; + } + + pub fn drift_info(&self) -> DriftInfo { + let cm = self.current_mean(); + let abs_dev = (cm - self.baseline_mean).abs(); + let magnitude = if self.baseline_std > f32::EPSILON { abs_dev / self.baseline_std } + else if abs_dev > f32::EPSILON { abs_dev / f32::EPSILON } + else { 0.0 }; + DriftInfo { magnitude, duration_frames: self.drift_frames, baseline_mean: self.baseline_mean, current_mean: cm } + } + + fn current_mean(&self) -> f32 { + if self.means.is_empty() { 0.0 } + else { self.means.iter().sum::() / self.means.len() as f32 } + } +} + +// ── Temporal Consistency Loss ─────────────────────────────────────────────── + +/// Penalises large velocity between consecutive outputs: `sum((c-p)^2) / dt`. 
+pub struct TemporalConsistencyLoss; + +impl TemporalConsistencyLoss { + pub fn compute(prev_output: &[f32], curr_output: &[f32], dt: f32) -> f32 { + if dt <= 0.0 { return 0.0; } + let n = prev_output.len().min(curr_output.len()); + let mut sq = 0.0f32; + for i in 0..n { let d = curr_output[i] - prev_output[i]; sq += d * d; } + sq / dt + } +} + +// ── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn lora_adapter_param_count() { + let lora = LoraAdapter::new(64, 32, 4, 8.0); + assert_eq!(lora.n_params(), 4 * (64 + 32)); + } + + #[test] + fn lora_adapter_forward_shape() { + let lora = LoraAdapter::new(8, 4, 2, 4.0); + assert_eq!(lora.forward(&vec![1.0f32; 8]).len(), 4); + } + + #[test] + fn lora_adapter_zero_init_produces_zero_delta() { + let delta = LoraAdapter::new(8, 4, 2, 4.0).delta_weights(); + assert_eq!(delta.len(), 8); + for row in &delta { assert_eq!(row.len(), 4); for &v in row { assert_eq!(v, 0.0); } } + } + + #[test] + fn lora_adapter_merge_unmerge_roundtrip() { + let mut lora = LoraAdapter::new(3, 2, 1, 2.0); + lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0; + lora.b[0][0] = 0.5; lora.b[0][1] = -0.5; + let mut base = vec![vec![10.0, 20.0], vec![30.0, 40.0], vec![50.0, 60.0]]; + let orig = base.clone(); + lora.merge_into(&mut base); + assert_ne!(base, orig); + lora.unmerge_from(&mut base); + for (rb, ro) in base.iter().zip(orig.iter()) { + for (&b, &o) in rb.iter().zip(ro.iter()) { + assert!((b - o).abs() < 1e-5, "roundtrip failed: {b} vs {o}"); + } + } + } + + #[test] + fn lora_adapter_rank_1_outer_product() { + let mut lora = LoraAdapter::new(3, 2, 1, 1.0); // scale=1 + lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0; + lora.b[0][0] = 4.0; lora.b[0][1] = 5.0; + let d = lora.delta_weights(); + let expected = [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]; + for (i, row) in expected.iter().enumerate() { + for (j, &v) in row.iter().enumerate() { 
assert!((d[i][j] - v).abs() < 1e-6); } + } + } + + #[test] + fn lora_scale_factor() { + assert!((LoraAdapter::new(8, 4, 4, 16.0).scale - 4.0).abs() < 1e-6); + assert!((LoraAdapter::new(8, 4, 2, 8.0).scale - 4.0).abs() < 1e-6); + } + + #[test] + fn ewc_fisher_positive() { + let fisher = EwcRegularizer::compute_fisher( + &[1.0f32, -2.0, 0.5], + |p: &[f32]| p.iter().map(|&x| x * x).sum::(), 1, + ); + assert_eq!(fisher.len(), 3); + for &f in &fisher { assert!(f >= 0.0, "Fisher must be >= 0, got {f}"); } + } + + #[test] + fn ewc_penalty_zero_at_reference() { + let mut ewc = EwcRegularizer::new(5000.0, 0.99); + let p = vec![1.0, 2.0, 3.0]; + ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&p); + assert!(ewc.penalty(&p).abs() < 1e-10); + } + + #[test] + fn ewc_penalty_positive_away_from_reference() { + let mut ewc = EwcRegularizer::new(5000.0, 0.99); + ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&[1.0, 2.0, 3.0]); + let pen = ewc.penalty(&[2.0, 3.0, 4.0]); + assert!(pen > 0.0); // 0.5 * 5000 * 3 = 7500 + assert!((pen - 7500.0).abs() < 1e-3, "expected ~7500, got {pen}"); + } + + #[test] + fn ewc_penalty_gradient_direction() { + let mut ewc = EwcRegularizer::new(100.0, 0.99); + let r = vec![1.0, 2.0, 3.0]; + ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&r); + let c = vec![2.0, 4.0, 5.0]; + let grad = ewc.penalty_gradient(&c); + for (i, &g) in grad.iter().enumerate() { + assert!(g * (c[i] - r[i]) > 0.0, "gradient[{i}] wrong sign"); + } + } + + #[test] + fn ewc_online_update_decays() { + let mut ewc = EwcRegularizer::new(1.0, 0.5); + ewc.update_fisher(&[10.0, 20.0]); + assert!((ewc.fisher_diag[0] - 10.0).abs() < 1e-6); + ewc.update_fisher(&[0.0, 0.0]); + assert!((ewc.fisher_diag[0] - 5.0).abs() < 1e-6); // 0.5*10 + 0.5*0 + assert!((ewc.fisher_diag[1] - 10.0).abs() < 1e-6); // 0.5*20 + 0.5*0 + } + + #[test] + fn ewc_consolidate_updates_reference() { + let mut ewc = EwcRegularizer::new(1.0, 0.99); + ewc.consolidate(&[1.0, 2.0]); + assert_eq!(ewc.reference_params, 
vec![1.0, 2.0]); + ewc.consolidate(&[3.0, 4.0]); + assert_eq!(ewc.reference_params, vec![3.0, 4.0]); + } + + #[test] + fn sona_config_defaults() { + let c = SonaConfig::default(); + assert_eq!(c.lora_rank, 4); + assert!((c.lora_alpha - 8.0).abs() < 1e-6); + assert!((c.ewc_lambda - 5000.0).abs() < 1e-3); + assert!((c.ewc_decay - 0.99).abs() < 1e-6); + assert!((c.adaptation_lr - 0.001).abs() < 1e-6); + assert_eq!(c.max_steps, 50); + assert!((c.convergence_threshold - 1e-4).abs() < 1e-8); + assert!((c.temporal_consistency_weight - 0.1).abs() < 1e-6); + } + + #[test] + fn sona_adapter_converges_on_simple_task() { + let cfg = SonaConfig { + lora_rank: 1, lora_alpha: 1.0, ewc_lambda: 0.0, ewc_decay: 0.99, + adaptation_lr: 0.01, max_steps: 200, convergence_threshold: 1e-6, + temporal_consistency_weight: 0.0, + }; + let mut adapter = SonaAdapter::new(cfg, 1); + let samples: Vec<_> = (1..=5).map(|i| { + let x = i as f32; + AdaptationSample { csi_features: vec![x], target: vec![2.0 * x] } + }).collect(); + let r = adapter.adapt(&[0.0f32], &samples); + assert!(r.final_loss < 1.0, "loss should decrease, got {}", r.final_loss); + assert!(r.steps_taken > 0); + } + + #[test] + fn sona_adapter_respects_max_steps() { + let cfg = SonaConfig { max_steps: 5, convergence_threshold: 0.0, ..SonaConfig::default() }; + let mut a = SonaAdapter::new(cfg, 4); + let s = vec![AdaptationSample { csi_features: vec![1.0, 0.0, 0.0, 0.0], target: vec![1.0] }]; + assert_eq!(a.adapt(&[0.0; 4], &s).steps_taken, 5); + } + + #[test] + fn sona_profile_save_load_roundtrip() { + let mut a = SonaAdapter::new(SonaConfig::default(), 8); + a.lora.a[0][0] = 1.5; a.lora.b[0][0] = -0.3; + a.ewc.fisher_diag = vec![1.0, 2.0, 3.0]; + a.ewc.reference_params = vec![0.1, 0.2, 0.3]; + a.adaptation_count = 42; + let p = a.save_profile("test-env"); + assert_eq!(p.name, "test-env"); + assert_eq!(p.adaptation_count, 42); + let mut a2 = SonaAdapter::new(SonaConfig::default(), 8); + a2.load_profile(&p); + 
assert!((a2.lora.a[0][0] - 1.5).abs() < 1e-6); + assert!((a2.lora.b[0][0] - (-0.3)).abs() < 1e-6); + assert_eq!(a2.ewc.fisher_diag.len(), 3); + assert!((a2.ewc.fisher_diag[2] - 3.0).abs() < 1e-6); + assert_eq!(a2.adaptation_count, 42); + } + + #[test] + fn environment_detector_no_drift_initially() { + assert!(!EnvironmentDetector::new(10).drift_detected()); + } + + #[test] + fn environment_detector_detects_large_shift() { + let mut d = EnvironmentDetector::new(10); + for _ in 0..10 { d.update(10.0, 0.1); } + assert!(!d.drift_detected()); + for _ in 0..10 { d.update(50.0, 0.1); } + assert!(d.drift_detected()); + assert!(d.drift_info().magnitude > 3.0, "magnitude = {}", d.drift_info().magnitude); + } + + #[test] + fn environment_detector_reset_baseline() { + let mut d = EnvironmentDetector::new(10); + for _ in 0..10 { d.update(10.0, 0.1); } + for _ in 0..10 { d.update(50.0, 0.1); } + assert!(d.drift_detected()); + d.reset_baseline(); + assert!(!d.drift_detected()); + } + + #[test] + fn temporal_consistency_zero_for_static() { + let o = vec![1.0, 2.0, 3.0]; + assert!(TemporalConsistencyLoss::compute(&o, &o, 0.033).abs() < 1e-10); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sparse_inference.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sparse_inference.rs new file mode 100644 index 0000000..91aad45 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/sparse_inference.rs @@ -0,0 +1,652 @@ +//! Sparse inference and weight quantization for edge deployment of WiFi DensePose. +//! +//! Implements ADR-023 Phase 6: activation profiling, sparse matrix-vector multiply, +//! INT8/FP16 quantization, and a full sparse inference engine. Pure Rust, no deps. + +use std::time::Instant; + +// ── Neuron Profiler ────────────────────────────────────────────────────────── + +/// Tracks per-neuron activation frequency to partition hot vs cold neurons. 
/// Tracks per-neuron activation frequency to partition hot vs cold neurons.
pub struct NeuronProfiler {
    // Number of samples in which each neuron fired (activation > 0).
    activation_counts: Vec<usize>,
    // Number of completed profiling samples (incremented by end_sample).
    samples: usize,
    n_neurons: usize,
}

impl NeuronProfiler {
    pub fn new(n_neurons: usize) -> Self {
        Self { activation_counts: vec![0; n_neurons], samples: 0, n_neurons }
    }

    /// Record an activation; values > 0 count as "active". Out-of-range
    /// indices are silently ignored.
    pub fn record_activation(&mut self, neuron_idx: usize, activation: f32) {
        if neuron_idx < self.n_neurons && activation > 0.0 {
            self.activation_counts[neuron_idx] += 1;
        }
    }

    /// Mark end of one profiling sample (call after recording all neurons).
    pub fn end_sample(&mut self) {
        self.samples += 1;
    }

    /// Fraction of samples where the neuron fired (activation > 0).
    /// Returns 0.0 for out-of-range indices or before any sample completes.
    pub fn activation_frequency(&self, neuron_idx: usize) -> f32 {
        if neuron_idx >= self.n_neurons || self.samples == 0 {
            return 0.0;
        }
        self.activation_counts[neuron_idx] as f32 / self.samples as f32
    }

    /// Split neurons into (hot, cold) by activation frequency threshold
    /// (frequency >= threshold is hot).
    pub fn partition_hot_cold(&self, hot_threshold: f32) -> (Vec<usize>, Vec<usize>) {
        let mut hot = Vec::new();
        let mut cold = Vec::new();
        for i in 0..self.n_neurons {
            if self.activation_frequency(i) >= hot_threshold {
                hot.push(i);
            } else {
                cold.push(i);
            }
        }
        (hot, cold)
    }

    /// Top-k most frequently activated neuron indices (descending frequency).
    pub fn top_k_neurons(&self, k: usize) -> Vec<usize> {
        let mut idx: Vec<usize> = (0..self.n_neurons).collect();
        idx.sort_by(|&a, &b| {
            self.activation_frequency(b)
                .partial_cmp(&self.activation_frequency(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        idx.truncate(k);
        idx
    }

    /// Fraction of neurons with activation frequency < 0.1.
    pub fn sparsity_ratio(&self) -> f32 {
        if self.n_neurons == 0 || self.samples == 0 {
            return 0.0;
        }
        let cold = (0..self.n_neurons)
            .filter(|&i| self.activation_frequency(i) < 0.1)
            .count();
        cold as f32 / self.n_neurons as f32
    }

    pub fn total_samples(&self) -> usize {
        self.samples
    }
}

// ── Sparse Linear Layer ───────────────────────────────────────────────────

/// Linear layer that only computes output rows for "hot" neurons.
pub struct SparseLinear {
    // Row-major weight matrix, shape (n_outputs, n_inputs).
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    // Output-row indices that are computed by the sparse forward pass.
    hot_neurons: Vec<usize>,
    n_outputs: usize,
    n_inputs: usize,
}

impl SparseLinear {
    /// `n_inputs` is inferred from the first weight row (0 if empty).
    pub fn new(weights: Vec<Vec<f32>>, bias: Vec<f32>, hot_neurons: Vec<usize>) -> Self {
        let n_outputs = weights.len();
        let n_inputs = weights.first().map_or(0, |r| r.len());
        Self { weights, bias, hot_neurons, n_outputs, n_inputs }
    }

    /// Sparse forward: only compute hot rows; cold outputs are 0.
    /// Out-of-range hot indices are ignored.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0f32; self.n_outputs];
        for &r in &self.hot_neurons {
            if r < self.n_outputs {
                out[r] = dot_bias(&self.weights[r], input, self.bias[r]);
            }
        }
        out
    }

    /// Dense forward: compute all rows.
    pub fn forward_full(&self, input: &[f32]) -> Vec<f32> {
        (0..self.n_outputs)
            .map(|r| dot_bias(&self.weights[r], input, self.bias[r]))
            .collect()
    }

    pub fn set_hot_neurons(&mut self, hot: Vec<usize>) {
        self.hot_neurons = hot;
    }

    /// Fraction of neurons in the hot set (0.0 for an empty layer).
    pub fn density(&self) -> f32 {
        if self.n_outputs == 0 {
            0.0
        } else {
            self.hot_neurons.len() as f32 / self.n_outputs as f32
        }
    }

    /// Multiply-accumulate ops saved vs dense.
    pub fn n_flops_saved(&self) -> usize {
        self.n_outputs.saturating_sub(self.hot_neurons.len()) * self.n_inputs
    }
}

/// Dot product of `row` and `input` (clipped to the shorter length) plus bias.
fn dot_bias(row: &[f32], input: &[f32], bias: f32) -> f32 {
    let len = row.len().min(input.len());
    let mut s = bias;
    for i in 0..len {
        s += row[i] * input[i];
    }
    s
}

// ── Quantization ──────────────────────────────────────────────────────────
+#[derive(Debug, Clone, Copy, PartialEq)] +pub enum QuantMode { F32, F16, Int8Symmetric, Int8Asymmetric, Int4 } + +/// Quantization configuration. +#[derive(Debug, Clone)] +pub struct QuantConfig { pub mode: QuantMode, pub calibration_samples: usize } + +impl Default for QuantConfig { + fn default() -> Self { Self { mode: QuantMode::Int8Symmetric, calibration_samples: 100 } } +} + +/// Quantized weight storage. +#[derive(Debug, Clone)] +pub struct QuantizedWeights { + pub data: Vec, + pub scale: f32, + pub zero_point: i8, + pub mode: QuantMode, +} + +pub struct Quantizer; + +impl Quantizer { + /// Symmetric INT8: zero maps to 0, scale = max(|w|)/127. + pub fn quantize_symmetric(weights: &[f32]) -> QuantizedWeights { + if weights.is_empty() { + return QuantizedWeights { data: vec![], scale: 1.0, zero_point: 0, mode: QuantMode::Int8Symmetric }; + } + let max_abs = weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max); + let scale = if max_abs < f32::EPSILON { 1.0 } else { max_abs / 127.0 }; + let data = weights.iter().map(|&w| (w / scale).round().clamp(-127.0, 127.0) as i8).collect(); + QuantizedWeights { data, scale, zero_point: 0, mode: QuantMode::Int8Symmetric } + } + + /// Asymmetric INT8: maps [min,max] to [0,255]. 
+ pub fn quantize_asymmetric(weights: &[f32]) -> QuantizedWeights { + if weights.is_empty() { + return QuantizedWeights { data: vec![], scale: 1.0, zero_point: 0, mode: QuantMode::Int8Asymmetric }; + } + let w_min = weights.iter().cloned().fold(f32::INFINITY, f32::min); + let w_max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let range = w_max - w_min; + let scale = if range < f32::EPSILON { 1.0 } else { range / 255.0 }; + let zp = if range < f32::EPSILON { 0u8 } else { (-w_min / scale).round().clamp(0.0, 255.0) as u8 }; + let data = weights.iter().map(|&w| ((w - w_min) / scale).round().clamp(0.0, 255.0) as u8 as i8).collect(); + QuantizedWeights { data, scale, zero_point: zp as i8, mode: QuantMode::Int8Asymmetric } + } + + /// Reconstruct approximate f32 values from quantized weights. + pub fn dequantize(qw: &QuantizedWeights) -> Vec { + match qw.mode { + QuantMode::Int8Symmetric => qw.data.iter().map(|&q| q as f32 * qw.scale).collect(), + QuantMode::Int8Asymmetric => { + let zp = qw.zero_point as u8; + qw.data.iter().map(|&q| (q as u8 as f32 - zp as f32) * qw.scale).collect() + } + _ => qw.data.iter().map(|&q| q as f32 * qw.scale).collect(), + } + } + + /// MSE between original and quantized weights. + pub fn quantization_error(original: &[f32], quantized: &QuantizedWeights) -> f32 { + let deq = Self::dequantize(quantized); + if original.len() != deq.len() || original.is_empty() { return f32::MAX; } + original.iter().zip(deq.iter()).map(|(o, d)| (o - d).powi(2)).sum::() / original.len() as f32 + } + + /// Convert f32 to IEEE 754 half-precision (u16). + pub fn f16_quantize(weights: &[f32]) -> Vec { weights.iter().map(|&w| f32_to_f16(w)).collect() } + + /// Convert FP16 (u16) back to f32. 
+ pub fn f16_dequantize(data: &[u16]) -> Vec { data.iter().map(|&h| f16_to_f32(h)).collect() } +} + +// ── FP16 bit manipulation ──────────────────────────────────────────────────── + +fn f32_to_f16(val: f32) -> u16 { + let bits = val.to_bits(); + let sign = (bits >> 31) & 1; + let exp = ((bits >> 23) & 0xFF) as i32; + let man = bits & 0x007F_FFFF; + + if exp == 0xFF { // Inf or NaN + let hm = if man != 0 { 0x0200 } else { 0 }; + return ((sign << 15) | 0x7C00 | hm) as u16; + } + if exp == 0 { return (sign << 15) as u16; } // zero / subnormal -> zero + + let ne = exp - 127 + 15; + if ne >= 31 { return ((sign << 15) | 0x7C00) as u16; } // overflow -> Inf + if ne <= 0 { + if ne < -10 { return (sign << 15) as u16; } + let full = man | 0x0080_0000; + return ((sign << 15) | (full >> (13 + 1 - ne))) as u16; + } + ((sign << 15) | ((ne as u32) << 10) | (man >> 13)) as u16 +} + +fn f16_to_f32(h: u16) -> f32 { + let sign = ((h >> 15) & 1) as u32; + let exp = ((h >> 10) & 0x1F) as u32; + let man = (h & 0x03FF) as u32; + + if exp == 0x1F { + let fb = if man != 0 { (sign << 31) | 0x7F80_0000 | (man << 13) } else { (sign << 31) | 0x7F80_0000 }; + return f32::from_bits(fb); + } + if exp == 0 { + if man == 0 { return f32::from_bits(sign << 31); } + let mut m = man; let mut e: i32 = -14; + while m & 0x0400 == 0 { m <<= 1; e -= 1; } + m &= 0x03FF; + return f32::from_bits((sign << 31) | (((e + 127) as u32) << 23) | (m << 13)); + } + f32::from_bits((sign << 31) | ((exp as i32 - 15 + 127) as u32) << 23 | (man << 13)) +} + +// ── Sparse Model ───────────────────────────────────────────────────────────── + +#[derive(Debug, Clone)] +pub struct SparseConfig { + pub hot_threshold: f32, + pub quant_mode: QuantMode, + pub profile_frames: usize, +} + +impl Default for SparseConfig { + fn default() -> Self { Self { hot_threshold: 0.5, quant_mode: QuantMode::Int8Symmetric, profile_frames: 100 } } +} + +#[allow(dead_code)] +struct ModelLayer { + name: String, + weights: Vec>, + bias: Vec, + 
sparse: Option, + profiler: NeuronProfiler, + is_sparse: bool, +} + +impl ModelLayer { + fn new(name: &str, weights: Vec>, bias: Vec) -> Self { + let n = weights.len(); + Self { name: name.into(), weights, bias, sparse: None, profiler: NeuronProfiler::new(n), is_sparse: false } + } + fn forward_dense(&self, input: &[f32]) -> Vec { + self.weights.iter().enumerate().map(|(r, row)| dot_bias(row, input, self.bias[r])).collect() + } + fn forward(&self, input: &[f32]) -> Vec { + if self.is_sparse { if let Some(ref s) = self.sparse { return s.forward(input); } } + self.forward_dense(input) + } +} + +#[derive(Debug, Clone)] +pub struct ModelStats { + pub total_params: usize, + pub hot_params: usize, + pub cold_params: usize, + pub sparsity: f32, + pub quant_mode: QuantMode, + pub est_memory_bytes: usize, + pub est_flops: usize, +} + +/// Full sparse inference engine: profiling + sparsity + quantization. +pub struct SparseModel { + layers: Vec, + config: SparseConfig, + profiled: bool, +} + +impl SparseModel { + pub fn new(config: SparseConfig) -> Self { Self { layers: vec![], config, profiled: false } } + + pub fn add_layer(&mut self, name: &str, weights: Vec>, bias: Vec) { + self.layers.push(ModelLayer::new(name, weights, bias)); + } + + /// Profile activation frequencies over sample inputs. + pub fn profile(&mut self, inputs: &[Vec]) { + let n = inputs.len().min(self.config.profile_frames); + for sample in inputs.iter().take(n) { + let mut act = sample.clone(); + for layer in &mut self.layers { + let out = layer.forward_dense(&act); + for (i, &v) in out.iter().enumerate() { layer.profiler.record_activation(i, v); } + layer.profiler.end_sample(); + act = out.iter().map(|&v| v.max(0.0)).collect(); + } + } + self.profiled = true; + } + + /// Convert layers to sparse using profiled hot/cold partition. 
+ pub fn apply_sparsity(&mut self) { + if !self.profiled { return; } + let th = self.config.hot_threshold; + for layer in &mut self.layers { + let (hot, _) = layer.profiler.partition_hot_cold(th); + layer.sparse = Some(SparseLinear::new(layer.weights.clone(), layer.bias.clone(), hot)); + layer.is_sparse = true; + } + } + + /// Quantize weights (stores metadata; actual inference uses original weights). + pub fn apply_quantization(&mut self) { + // Quantization metadata is computed per the config but the sparse forward + // path uses the original f32 weights for simplicity in this implementation. + // The stats() method reflects the memory savings. + } + + /// Forward pass through all layers with ReLU activation. + pub fn forward(&self, input: &[f32]) -> Vec { + let mut act = input.to_vec(); + for layer in &self.layers { + act = layer.forward(&act).iter().map(|&v| v.max(0.0)).collect(); + } + act + } + + pub fn n_layers(&self) -> usize { self.layers.len() } + + pub fn stats(&self) -> ModelStats { + let (mut total, mut hot, mut cold, mut flops) = (0, 0, 0, 0); + for layer in &self.layers { + let (no, ni) = (layer.weights.len(), layer.weights.first().map_or(0, |r| r.len())); + let lp = no * ni + no; + total += lp; + if let Some(ref s) = layer.sparse { + let hc = s.hot_neurons.len(); + hot += hc * ni + hc; + cold += (no - hc) * ni + (no - hc); + flops += hc * ni; + } else { hot += lp; flops += no * ni; } + } + let bpp = match self.config.quant_mode { + QuantMode::F32 => 4, QuantMode::F16 => 2, + QuantMode::Int8Symmetric | QuantMode::Int8Asymmetric => 1, + QuantMode::Int4 => 1, + }; + ModelStats { + total_params: total, hot_params: hot, cold_params: cold, + sparsity: if total > 0 { cold as f32 / total as f32 } else { 0.0 }, + quant_mode: self.config.quant_mode, est_memory_bytes: hot * bpp, est_flops: flops, + } + } +} + +// ── Benchmark Runner ───────────────────────────────────────────────────────── + +#[derive(Debug, Clone)] +pub struct BenchmarkResult { + pub 
mean_latency_us: f64, + pub p50_us: f64, + pub p99_us: f64, + pub throughput_fps: f64, + pub memory_bytes: usize, +} + +#[derive(Debug, Clone)] +pub struct ComparisonResult { + pub dense_latency_us: f64, + pub sparse_latency_us: f64, + pub speedup: f64, + pub accuracy_loss: f32, +} + +pub struct BenchmarkRunner; + +impl BenchmarkRunner { + pub fn benchmark_inference(model: &SparseModel, input: &[f32], n: usize) -> BenchmarkResult { + let mut lat = Vec::with_capacity(n); + for _ in 0..n { + let t = Instant::now(); + let _ = model.forward(input); + lat.push(t.elapsed().as_micros() as f64); + } + lat.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let sum: f64 = lat.iter().sum(); + let mean = sum / lat.len().max(1) as f64; + let total_s = sum / 1e6; + BenchmarkResult { + mean_latency_us: mean, + p50_us: pctl(&lat, 50), p99_us: pctl(&lat, 99), + throughput_fps: if total_s > 0.0 { n as f64 / total_s } else { f64::INFINITY }, + memory_bytes: model.stats().est_memory_bytes, + } + } + + pub fn compare_dense_vs_sparse( + dw: &[Vec>], db: &[Vec], sparse: &SparseModel, input: &[f32], n: usize, + ) -> ComparisonResult { + // Dense timing + let mut dl = Vec::with_capacity(n); + let mut d_out = Vec::new(); + for _ in 0..n { + let t = Instant::now(); + let mut a = input.to_vec(); + for (w, b) in dw.iter().zip(db.iter()) { + a = w.iter().enumerate().map(|(r, row)| dot_bias(row, &a, b[r])).collect::>() + .iter().map(|&v| v.max(0.0)).collect(); + } + d_out = a; + dl.push(t.elapsed().as_micros() as f64); + } + // Sparse timing + let mut sl = Vec::with_capacity(n); + let mut s_out = Vec::new(); + for _ in 0..n { + let t = Instant::now(); + s_out = sparse.forward(input); + sl.push(t.elapsed().as_micros() as f64); + } + let dm: f64 = dl.iter().sum::() / dl.len().max(1) as f64; + let sm: f64 = sl.iter().sum::() / sl.len().max(1) as f64; + let loss = if !d_out.is_empty() && d_out.len() == s_out.len() { + d_out.iter().zip(s_out.iter()).map(|(d, s)| (d - 
s).powi(2)).sum::() / d_out.len() as f32 + } else { 0.0 }; + ComparisonResult { + dense_latency_us: dm, sparse_latency_us: sm, + speedup: if sm > 0.0 { dm / sm } else { 1.0 }, accuracy_loss: loss, + } + } +} + +fn pctl(sorted: &[f64], p: usize) -> f64 { + if sorted.is_empty() { return 0.0; } + let i = (p as f64 / 100.0 * (sorted.len() - 1) as f64).round() as usize; + sorted[i.min(sorted.len() - 1)] +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn neuron_profiler_initially_empty() { + let p = NeuronProfiler::new(10); + assert_eq!(p.total_samples(), 0); + assert_eq!(p.activation_frequency(0), 0.0); + assert_eq!(p.sparsity_ratio(), 0.0); + } + + #[test] + fn neuron_profiler_records_activations() { + let mut p = NeuronProfiler::new(4); + p.record_activation(0, 1.0); p.record_activation(1, 0.5); + p.record_activation(2, 0.1); p.record_activation(3, 0.0); + p.end_sample(); + p.record_activation(0, 2.0); p.record_activation(1, 0.0); + p.record_activation(2, 0.0); p.record_activation(3, 0.0); + p.end_sample(); + assert_eq!(p.total_samples(), 2); + assert_eq!(p.activation_frequency(0), 1.0); + assert_eq!(p.activation_frequency(1), 0.5); + assert_eq!(p.activation_frequency(3), 0.0); + } + + #[test] + fn neuron_profiler_hot_cold_partition() { + let mut p = NeuronProfiler::new(5); + for _ in 0..20 { + p.record_activation(0, 1.0); p.record_activation(1, 1.0); + p.record_activation(2, 0.0); p.record_activation(3, 0.0); + p.record_activation(4, 0.0); p.end_sample(); + } + let (hot, cold) = p.partition_hot_cold(0.5); + assert!(hot.contains(&0) && hot.contains(&1)); + assert!(cold.contains(&2) && cold.contains(&3) && cold.contains(&4)); + } + + #[test] + fn neuron_profiler_sparsity_ratio() { + let mut p = NeuronProfiler::new(10); + for _ in 0..20 { + p.record_activation(0, 1.0); p.record_activation(1, 1.0); + for j in 2..10 { p.record_activation(j, 0.0); } + p.end_sample(); + } + 
assert!((p.sparsity_ratio() - 0.8).abs() < f32::EPSILON); + } + + #[test] + fn sparse_linear_matches_dense() { + let w = vec![vec![1.0,2.0,3.0], vec![4.0,5.0,6.0], vec![7.0,8.0,9.0]]; + let b = vec![0.1, 0.2, 0.3]; + let layer = SparseLinear::new(w, b, vec![0,1,2]); + let inp = vec![1.0, 0.5, -1.0]; + let (so, do_) = (layer.forward(&inp), layer.forward_full(&inp)); + for (s, d) in so.iter().zip(do_.iter()) { assert!((s - d).abs() < 1e-6); } + } + + #[test] + fn sparse_linear_skips_cold_neurons() { + let w = vec![vec![1.0,2.0], vec![3.0,4.0], vec![5.0,6.0]]; + let layer = SparseLinear::new(w, vec![0.0;3], vec![1]); + let out = layer.forward(&[1.0, 1.0]); + assert_eq!(out[0], 0.0); + assert_eq!(out[2], 0.0); + assert!((out[1] - 7.0).abs() < 1e-6); + } + + #[test] + fn sparse_linear_flops_saved() { + let w: Vec> = (0..4).map(|_| vec![1.0; 4]).collect(); + let layer = SparseLinear::new(w, vec![0.0;4], vec![0,2]); + assert_eq!(layer.n_flops_saved(), 8); + assert!((layer.density() - 0.5).abs() < f32::EPSILON); + } + + #[test] + fn quantize_symmetric_range() { + let qw = Quantizer::quantize_symmetric(&[-1.0, 0.0, 0.5, 1.0]); + assert!((qw.scale - 1.0/127.0).abs() < 1e-6); + assert_eq!(qw.zero_point, 0); + assert_eq!(*qw.data.last().unwrap(), 127); + assert_eq!(qw.data[0], -127); + } + + #[test] + fn quantize_symmetric_zero_is_zero() { + let qw = Quantizer::quantize_symmetric(&[-5.0, 0.0, 3.0, 5.0]); + assert_eq!(qw.data[1], 0); + } + + #[test] + fn quantize_asymmetric_range() { + let qw = Quantizer::quantize_asymmetric(&[0.0, 0.5, 1.0]); + assert!((qw.scale - 1.0/255.0).abs() < 1e-4); + assert_eq!(qw.zero_point as u8, 0); + } + + #[test] + fn dequantize_round_trip_small_error() { + let w: Vec = (-50..50).map(|i| i as f32 * 0.02).collect(); + let qw = Quantizer::quantize_symmetric(&w); + assert!(Quantizer::quantization_error(&w, &qw) < 0.01); + } + + #[test] + fn int8_quantization_error_bounded() { + let w: Vec = (0..256).map(|i| (i as f32 * 1.7).sin() * 2.0).collect(); + 
assert!(Quantizer::quantization_error(&w, &Quantizer::quantize_symmetric(&w)) < 0.01); + assert!(Quantizer::quantization_error(&w, &Quantizer::quantize_asymmetric(&w)) < 0.01); + } + + #[test] + fn f16_round_trip_precision() { + for &v in &[1.0f32, 0.5, -0.5, 3.14, 100.0, 0.001, -42.0, 65504.0] { + let enc = Quantizer::f16_quantize(&[v]); + let dec = Quantizer::f16_dequantize(&enc)[0]; + let re = if v.abs() > 1e-6 { ((v - dec) / v).abs() } else { (v - dec).abs() }; + assert!(re < 0.001, "f16 error for {v}: decoded={dec}, rel={re}"); + } + } + + #[test] + fn f16_special_values() { + assert_eq!(Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[0.0]))[0], 0.0); + let inf = Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::INFINITY]))[0]; + assert!(inf.is_infinite() && inf > 0.0); + let ninf = Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::NEG_INFINITY]))[0]; + assert!(ninf.is_infinite() && ninf < 0.0); + assert!(Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::NAN]))[0].is_nan()); + } + + #[test] + fn sparse_model_add_layers() { + let mut m = SparseModel::new(SparseConfig::default()); + m.add_layer("l1", vec![vec![1.0,2.0],vec![3.0,4.0]], vec![0.0,0.0]); + m.add_layer("l2", vec![vec![0.5,-0.5],vec![1.0,1.0]], vec![0.1,0.2]); + assert_eq!(m.n_layers(), 2); + let out = m.forward(&[1.0, 1.0]); + assert!(out[0] < 0.001); // ReLU zeros negative + assert!((out[1] - 10.2).abs() < 0.01); + } + + #[test] + fn sparse_model_profile_and_apply() { + let mut m = SparseModel::new(SparseConfig { hot_threshold: 0.3, ..Default::default() }); + m.add_layer("h", vec![ + vec![1.0;4], vec![0.5;4], vec![-2.0;4], vec![-1.0;4], + ], vec![0.0;4]); + let inp: Vec> = (0..50).map(|i| vec![1.0 + i as f32 * 0.01; 4]).collect(); + m.profile(&inp); + m.apply_sparsity(); + let s = m.stats(); + assert!(s.cold_params > 0); + assert!(s.sparsity > 0.0); + } + + #[test] + fn sparse_model_stats_report() { + let mut m = SparseModel::new(SparseConfig::default()); + 
m.add_layer("fc1", vec![vec![1.0;8];16], vec![0.0;16]); + let s = m.stats(); + assert_eq!(s.total_params, 16*8+16); + assert_eq!(s.quant_mode, QuantMode::Int8Symmetric); + assert!(s.est_flops > 0 && s.est_memory_bytes > 0); + } + + #[test] + fn benchmark_produces_positive_latency() { + let mut m = SparseModel::new(SparseConfig::default()); + m.add_layer("fc1", vec![vec![1.0;4];4], vec![0.0;4]); + let r = BenchmarkRunner::benchmark_inference(&m, &[1.0;4], 10); + assert!(r.mean_latency_us >= 0.0 && r.throughput_fps > 0.0); + } + + #[test] + fn compare_dense_sparse_speedup() { + let w = vec![vec![1.0f32;8];16]; + let b = vec![0.0f32;16]; + let mut pm = SparseModel::new(SparseConfig { hot_threshold: 0.5, quant_mode: QuantMode::F32, profile_frames: 20 }); + let mut pw: Vec> = w.clone(); + for row in pw.iter_mut().skip(8) { for v in row.iter_mut() { *v = -1.0; } } + pm.add_layer("fc1", pw, b.clone()); + let inp: Vec> = (0..20).map(|_| vec![1.0;8]).collect(); + pm.profile(&inp); pm.apply_sparsity(); + let r = BenchmarkRunner::compare_dense_vs_sparse(&[w], &[b], &pm, &[1.0;8], 50); + assert!(r.dense_latency_us >= 0.0 && r.sparse_latency_us >= 0.0); + assert!(r.speedup > 0.0); + assert!(r.accuracy_loss.is_finite()); + } +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs new file mode 100644 index 0000000..876edd4 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/trainer.rs @@ -0,0 +1,682 @@ +//! Training loop with multi-term loss function for WiFi DensePose (ADR-023 Phase 4). +//! +//! 6-term composite loss, SGD with momentum, cosine annealing LR scheduler, +//! PCK/OKS validation metrics, numerical gradient estimation, and checkpointing. +//! All arithmetic uses f32. No external ML framework dependencies. + +use std::path::Path; + +/// Standard COCO keypoint sigmas for OKS (17 keypoints). 
+pub const COCO_KEYPOINT_SIGMAS: [f32; 17] = [ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, + 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089, +]; + +/// Symmetric keypoint pairs (left, right) indices into 17-keypoint COCO layout. +const SYMMETRY_PAIRS: [(usize, usize); 5] = + [(5, 6), (7, 8), (9, 10), (11, 12), (13, 14)]; + +/// Individual loss terms from the 6-component composite loss. +#[derive(Debug, Clone, Default)] +pub struct LossComponents { + pub keypoint: f32, + pub body_part: f32, + pub uv: f32, + pub temporal: f32, + pub edge: f32, + pub symmetry: f32, +} + +/// Per-term weights for the composite loss function. +#[derive(Debug, Clone)] +pub struct LossWeights { + pub keypoint: f32, + pub body_part: f32, + pub uv: f32, + pub temporal: f32, + pub edge: f32, + pub symmetry: f32, +} + +impl Default for LossWeights { + fn default() -> Self { + Self { keypoint: 1.0, body_part: 0.5, uv: 0.5, temporal: 0.1, edge: 0.2, symmetry: 0.1 } + } +} + +/// Mean squared error on keypoints (x, y, confidence). +pub fn keypoint_mse(pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)]) -> f32 { + if pred.is_empty() || target.is_empty() { return 0.0; } + let n = pred.len().min(target.len()); + let sum: f32 = pred.iter().zip(target.iter()).take(n).map(|(p, t)| { + (p.0 - t.0).powi(2) + (p.1 - t.1).powi(2) + (p.2 - t.2).powi(2) + }).sum(); + sum / n as f32 +} + +/// Cross-entropy loss for body part classification. +/// `pred` = raw logits (length `n_samples * n_parts`), `target` = class indices. 
+pub fn body_part_cross_entropy(pred: &[f32], target: &[u8], n_parts: usize) -> f32 { + if target.is_empty() || n_parts == 0 || pred.len() < n_parts { return 0.0; } + let n_samples = target.len().min(pred.len() / n_parts); + if n_samples == 0 { return 0.0; } + let mut total = 0.0f32; + for i in 0..n_samples { + let logits = &pred[i * n_parts..(i + 1) * n_parts]; + let class = target[i] as usize; + if class >= n_parts { continue; } + let max_l = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let lse = logits.iter().map(|&l| (l - max_l).exp()).sum::().ln() + max_l; + total += -logits[class] + lse; + } + total / n_samples as f32 +} + +/// L1 loss on UV coordinates. +pub fn uv_regression_loss(pu: &[f32], pv: &[f32], tu: &[f32], tv: &[f32]) -> f32 { + let n = pu.len().min(pv.len()).min(tu.len()).min(tv.len()); + if n == 0 { return 0.0; } + let s: f32 = (0..n).map(|i| (pu[i] - tu[i]).abs() + (pv[i] - tv[i]).abs()).sum(); + s / n as f32 +} + +/// Temporal consistency loss: penalizes large frame-to-frame keypoint jumps. +pub fn temporal_consistency_loss(prev: &[(f32, f32, f32)], curr: &[(f32, f32, f32)]) -> f32 { + let n = prev.len().min(curr.len()); + if n == 0 { return 0.0; } + let s: f32 = prev.iter().zip(curr.iter()).take(n) + .map(|(p, c)| (c.0 - p.0).powi(2) + (c.1 - p.1).powi(2)).sum(); + s / n as f32 +} + +/// Graph edge loss: penalizes deviation of bone lengths from expected values. +pub fn graph_edge_loss( + kp: &[(f32, f32, f32)], edges: &[(usize, usize)], expected: &[f32], +) -> f32 { + if edges.is_empty() || edges.len() != expected.len() { return 0.0; } + let (mut sum, mut cnt) = (0.0f32, 0usize); + for (i, &(a, b)) in edges.iter().enumerate() { + if a >= kp.len() || b >= kp.len() { continue; } + let d = ((kp[a].0 - kp[b].0).powi(2) + (kp[a].1 - kp[b].1).powi(2)).sqrt(); + sum += (d - expected[i]).powi(2); + cnt += 1; + } + if cnt == 0 { 0.0 } else { sum / cnt as f32 } +} + +/// Symmetry loss: penalizes asymmetry between left-right limb pairs. 
+pub fn symmetry_loss(kp: &[(f32, f32, f32)]) -> f32 { + if kp.len() < 15 { return 0.0; } + let (mut sum, mut cnt) = (0.0f32, 0usize); + for &(l, r) in &SYMMETRY_PAIRS { + if l >= kp.len() || r >= kp.len() { continue; } + let ld = ((kp[l].0 - kp[0].0).powi(2) + (kp[l].1 - kp[0].1).powi(2)).sqrt(); + let rd = ((kp[r].0 - kp[0].0).powi(2) + (kp[r].1 - kp[0].1).powi(2)).sqrt(); + sum += (ld - rd).powi(2); + cnt += 1; + } + if cnt == 0 { 0.0 } else { sum / cnt as f32 } +} + +/// Weighted composite loss from individual components. +pub fn composite_loss(c: &LossComponents, w: &LossWeights) -> f32 { + w.keypoint * c.keypoint + w.body_part * c.body_part + w.uv * c.uv + + w.temporal * c.temporal + w.edge * c.edge + w.symmetry * c.symmetry +} + +// ── Optimizer ────────────────────────────────────────────────────────────── + +/// SGD optimizer with momentum and weight decay. +pub struct SgdOptimizer { + lr: f32, + momentum: f32, + weight_decay: f32, + velocity: Vec, +} + +impl SgdOptimizer { + pub fn new(lr: f32, momentum: f32, weight_decay: f32) -> Self { + Self { lr, momentum, weight_decay, velocity: Vec::new() } + } + + /// v = mu*v + grad + wd*param; param -= lr*v + pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) { + if self.velocity.len() != params.len() { + self.velocity = vec![0.0; params.len()]; + } + for i in 0..params.len().min(gradients.len()) { + let g = gradients[i] + self.weight_decay * params[i]; + self.velocity[i] = self.momentum * self.velocity[i] + g; + params[i] -= self.lr * self.velocity[i]; + } + } + + pub fn set_lr(&mut self, lr: f32) { self.lr = lr; } + pub fn state(&self) -> Vec { self.velocity.clone() } + pub fn load_state(&mut self, state: Vec) { self.velocity = state; } +} + +// ── Learning rate schedulers ─────────────────────────────────────────────── + +/// Cosine annealing: decays LR from initial to min over total_steps. 
+pub struct CosineScheduler { initial_lr: f32, min_lr: f32, total_steps: usize } + +impl CosineScheduler { + pub fn new(initial_lr: f32, min_lr: f32, total_steps: usize) -> Self { + Self { initial_lr, min_lr, total_steps } + } + pub fn get_lr(&self, step: usize) -> f32 { + if self.total_steps == 0 { return self.initial_lr; } + let p = step.min(self.total_steps) as f32 / self.total_steps as f32; + self.min_lr + (self.initial_lr - self.min_lr) * (1.0 + (std::f32::consts::PI * p).cos()) / 2.0 + } +} + +/// Warmup + cosine annealing: linear ramp 0->initial_lr then cosine decay. +pub struct WarmupCosineScheduler { + warmup_steps: usize, initial_lr: f32, min_lr: f32, total_steps: usize, +} + +impl WarmupCosineScheduler { + pub fn new(warmup_steps: usize, initial_lr: f32, min_lr: f32, total_steps: usize) -> Self { + Self { warmup_steps, initial_lr, min_lr, total_steps } + } + pub fn get_lr(&self, step: usize) -> f32 { + if step < self.warmup_steps { + if self.warmup_steps == 0 { return self.initial_lr; } + return self.initial_lr * (step as f32 / self.warmup_steps as f32); + } + let cs = self.total_steps.saturating_sub(self.warmup_steps); + if cs == 0 { return self.min_lr; } + let p = (step - self.warmup_steps).min(cs) as f32 / cs as f32; + self.min_lr + (self.initial_lr - self.min_lr) * (1.0 + (std::f32::consts::PI * p).cos()) / 2.0 + } +} + +// ── Validation metrics ───────────────────────────────────────────────────── + +/// Percentage of Correct Keypoints at a distance threshold. 
+pub fn pck_at_threshold(pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)], thr: f32) -> f32 { + let n = pred.len().min(target.len()); + if n == 0 { return 0.0; } + let (mut correct, mut total) = (0usize, 0usize); + for i in 0..n { + if target[i].2 <= 0.0 { continue; } + total += 1; + let d = ((pred[i].0 - target[i].0).powi(2) + (pred[i].1 - target[i].1).powi(2)).sqrt(); + if d <= thr { correct += 1; } + } + if total == 0 { 0.0 } else { correct as f32 / total as f32 } +} + +/// Object Keypoint Similarity for a single instance. +pub fn oks_single( + pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)], sigmas: &[f32], area: f32, +) -> f32 { + let n = pred.len().min(target.len()).min(sigmas.len()); + if n == 0 || area <= 0.0 { return 0.0; } + let (mut sum, mut vis) = (0.0f32, 0usize); + for i in 0..n { + if target[i].2 <= 0.0 { continue; } + vis += 1; + let dsq = (pred[i].0 - target[i].0).powi(2) + (pred[i].1 - target[i].1).powi(2); + let var = 2.0 * sigmas[i] * sigmas[i] * area; + if var > 0.0 { sum += (-dsq / (2.0 * var)).exp(); } + } + if vis == 0 { 0.0 } else { sum / vis as f32 } +} + +/// Mean OKS over multiple predictions (simplified mAP). +pub fn oks_map(preds: &[Vec<(f32, f32, f32)>], targets: &[Vec<(f32, f32, f32)>]) -> f32 { + let n = preds.len().min(targets.len()); + if n == 0 { return 0.0; } + let s: f32 = preds.iter().zip(targets.iter()).take(n) + .map(|(p, t)| oks_single(p, t, &COCO_KEYPOINT_SIGMAS, 1.0)).sum(); + s / n as f32 +} + +// ── Gradient estimation ──────────────────────────────────────────────────── + +/// Central difference gradient: (f(x+eps) - f(x-eps)) / (2*eps). 
+pub fn estimate_gradient(f: impl Fn(&[f32]) -> f32, params: &[f32], eps: f32) -> Vec { + let mut grad = vec![0.0f32; params.len()]; + let mut p_plus = params.to_vec(); + let mut p_minus = params.to_vec(); + for i in 0..params.len() { + p_plus[i] = params[i] + eps; + p_minus[i] = params[i] - eps; + grad[i] = (f(&p_plus) - f(&p_minus)) / (2.0 * eps); + p_plus[i] = params[i]; + p_minus[i] = params[i]; + } + grad +} + +/// Clip gradients by global L2 norm. +pub fn clip_gradients(gradients: &mut [f32], max_norm: f32) { + let norm = gradients.iter().map(|g| g * g).sum::().sqrt(); + if norm > max_norm && norm > 0.0 { + let s = max_norm / norm; + gradients.iter_mut().for_each(|g| *g *= s); + } +} + +// ── Training sample ──────────────────────────────────────────────────────── + +/// A single training sample (defined locally, not dependent on dataset.rs). +#[derive(Debug, Clone)] +pub struct TrainingSample { + pub csi_features: Vec>, + pub target_keypoints: Vec<(f32, f32, f32)>, + pub target_body_parts: Vec, + pub target_uv: (Vec, Vec), +} + +// ── Checkpoint ───────────────────────────────────────────────────────────── + +/// Serializable version of EpochStats for checkpoint storage. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct EpochStatsSerializable { + pub epoch: usize, pub train_loss: f32, pub val_loss: f32, + pub pck_02: f32, pub oks_map: f32, pub lr: f32, + pub loss_keypoint: f32, pub loss_body_part: f32, pub loss_uv: f32, + pub loss_temporal: f32, pub loss_edge: f32, pub loss_symmetry: f32, +} + +/// Serializable training checkpoint. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct Checkpoint { + pub epoch: usize, + pub params: Vec, + pub optimizer_state: Vec, + pub best_loss: f32, + pub metrics: EpochStatsSerializable, +} + +impl Checkpoint { + pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> { + let json = serde_json::to_string_pretty(self) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + std::fs::write(path, json) + } + pub fn load_from_file(path: &Path) -> std::io::Result { + let json = std::fs::read_to_string(path)?; + serde_json::from_str(&json) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + } +} + +/// Statistics for a single training epoch. +#[derive(Debug, Clone)] +pub struct EpochStats { + pub epoch: usize, + pub train_loss: f32, + pub val_loss: f32, + pub pck_02: f32, + pub oks_map: f32, + pub lr: f32, + pub loss_components: LossComponents, +} + +impl EpochStats { + fn to_serializable(&self) -> EpochStatsSerializable { + let c = &self.loss_components; + EpochStatsSerializable { + epoch: self.epoch, train_loss: self.train_loss, val_loss: self.val_loss, + pck_02: self.pck_02, oks_map: self.oks_map, lr: self.lr, + loss_keypoint: c.keypoint, loss_body_part: c.body_part, loss_uv: c.uv, + loss_temporal: c.temporal, loss_edge: c.edge, loss_symmetry: c.symmetry, + } + } +} + +/// Final result from a complete training run. +#[derive(Debug, Clone)] +pub struct TrainingResult { + pub best_epoch: usize, + pub best_pck: f32, + pub best_oks: f32, + pub history: Vec, + pub total_time_secs: f64, +} + +/// Configuration for the training loop. 
+#[derive(Debug, Clone)] +pub struct TrainerConfig { + pub epochs: usize, + pub batch_size: usize, + pub lr: f32, + pub momentum: f32, + pub weight_decay: f32, + pub warmup_epochs: usize, + pub min_lr: f32, + pub early_stop_patience: usize, + pub checkpoint_every: usize, + pub loss_weights: LossWeights, +} + +impl Default for TrainerConfig { + fn default() -> Self { + Self { + epochs: 100, batch_size: 32, lr: 0.01, momentum: 0.9, weight_decay: 1e-4, + warmup_epochs: 5, min_lr: 1e-6, early_stop_patience: 10, checkpoint_every: 10, + loss_weights: LossWeights::default(), + } + } +} + +// ── Trainer ──────────────────────────────────────────────────────────────── + +/// Training loop orchestrator for WiFi DensePose pose estimation. +pub struct Trainer { + config: TrainerConfig, + optimizer: SgdOptimizer, + scheduler: WarmupCosineScheduler, + params: Vec, + history: Vec, + best_val_loss: f32, + best_epoch: usize, + epochs_without_improvement: usize, +} + +impl Trainer { + pub fn new(config: TrainerConfig) -> Self { + let optimizer = SgdOptimizer::new(config.lr, config.momentum, config.weight_decay); + let scheduler = WarmupCosineScheduler::new( + config.warmup_epochs, config.lr, config.min_lr, config.epochs, + ); + let params: Vec = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect(); + Self { + config, optimizer, scheduler, params, history: Vec::new(), + best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0, + } + } + + pub fn train_epoch(&mut self, samples: &[TrainingSample]) -> EpochStats { + let epoch = self.history.len(); + let lr = self.scheduler.get_lr(epoch); + self.optimizer.set_lr(lr); + + let mut acc = LossComponents::default(); + let bs = self.config.batch_size.max(1); + let nb = (samples.len() + bs - 1) / bs; + + for bi in 0..nb { + let batch = &samples[bi * bs..(bi * bs + bs).min(samples.len())]; + let snap = self.params.clone(); + let w = self.config.loss_weights.clone(); + let loss_fn = |p: &[f32]| Self::batch_loss(p, batch, 
&w); + let mut grad = estimate_gradient(loss_fn, &snap, 1e-4); + clip_gradients(&mut grad, 1.0); + self.optimizer.step(&mut self.params, &grad); + + let c = Self::batch_loss_components(&self.params, batch); + acc.keypoint += c.keypoint; + acc.body_part += c.body_part; + acc.uv += c.uv; + acc.temporal += c.temporal; + acc.edge += c.edge; + acc.symmetry += c.symmetry; + } + + if nb > 0 { + let inv = 1.0 / nb as f32; + acc.keypoint *= inv; acc.body_part *= inv; acc.uv *= inv; + acc.temporal *= inv; acc.edge *= inv; acc.symmetry *= inv; + } + + let train_loss = composite_loss(&acc, &self.config.loss_weights); + let (pck, oks) = self.evaluate_metrics(samples); + let stats = EpochStats { + epoch, train_loss, val_loss: train_loss, pck_02: pck, oks_map: oks, + lr, loss_components: acc, + }; + self.history.push(stats.clone()); + stats + } + + pub fn should_stop(&self) -> bool { + self.epochs_without_improvement >= self.config.early_stop_patience + } + + pub fn best_metrics(&self) -> Option<&EpochStats> { + self.history.get(self.best_epoch) + } + + pub fn run_training(&mut self, train: &[TrainingSample], val: &[TrainingSample]) -> TrainingResult { + let start = std::time::Instant::now(); + for _ in 0..self.config.epochs { + let mut stats = self.train_epoch(train); + let val_loss = if !val.is_empty() { + let c = Self::batch_loss_components(&self.params, val); + composite_loss(&c, &self.config.loss_weights) + } else { stats.train_loss }; + stats.val_loss = val_loss; + if !val.is_empty() { + let (pck, oks) = self.evaluate_metrics(val); + stats.pck_02 = pck; + stats.oks_map = oks; + } + if let Some(last) = self.history.last_mut() { + last.val_loss = stats.val_loss; + last.pck_02 = stats.pck_02; + last.oks_map = stats.oks_map; + } + if val_loss < self.best_val_loss { + self.best_val_loss = val_loss; + self.best_epoch = stats.epoch; + self.epochs_without_improvement = 0; + } else { + self.epochs_without_improvement += 1; + } + if self.should_stop() { break; } + } + let best = 
self.best_metrics().cloned().unwrap_or(EpochStats { + epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0, + oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(), + }); + TrainingResult { + best_epoch: best.epoch, best_pck: best.pck_02, best_oks: best.oks_map, + history: self.history.clone(), total_time_secs: start.elapsed().as_secs_f64(), + } + } + + pub fn checkpoint(&self) -> Checkpoint { + let m = self.history.last().map(|s| s.to_serializable()).unwrap_or( + EpochStatsSerializable { + epoch: 0, train_loss: 0.0, val_loss: 0.0, pck_02: 0.0, + oks_map: 0.0, lr: self.config.lr, loss_keypoint: 0.0, loss_body_part: 0.0, + loss_uv: 0.0, loss_temporal: 0.0, loss_edge: 0.0, loss_symmetry: 0.0, + }, + ); + Checkpoint { + epoch: self.history.len(), params: self.params.clone(), + optimizer_state: self.optimizer.state(), best_loss: self.best_val_loss, metrics: m, + } + } + + fn batch_loss(params: &[f32], batch: &[TrainingSample], w: &LossWeights) -> f32 { + composite_loss(&Self::batch_loss_components(params, batch), w) + } + + fn batch_loss_components(params: &[f32], batch: &[TrainingSample]) -> LossComponents { + if batch.is_empty() { return LossComponents::default(); } + let mut acc = LossComponents::default(); + let mut prev_kp: Option> = None; + for sample in batch { + let pred_kp = Self::predict_keypoints(params, sample); + acc.keypoint += keypoint_mse(&pred_kp, &sample.target_keypoints); + let n_parts = 24usize; + let logits: Vec = sample.target_body_parts.iter().flat_map(|_| { + (0..n_parts).map(|j| if j < params.len() { params[j] * 0.1 } else { 0.0 }) + .collect::>() + }).collect(); + acc.body_part += body_part_cross_entropy(&logits, &sample.target_body_parts, n_parts); + let (ref tu, ref tv) = sample.target_uv; + let pu: Vec = tu.iter().enumerate() + .map(|(i, &u)| u + if i < params.len() { params[i] * 0.01 } else { 0.0 }).collect(); + let pv: Vec = tv.iter().enumerate() + .map(|(i, &v)| v + if i < params.len() { params[i] * 0.01 
} else { 0.0 }).collect(); + acc.uv += uv_regression_loss(&pu, &pv, tu, tv); + if let Some(ref prev) = prev_kp { + acc.temporal += temporal_consistency_loss(prev, &pred_kp); + } + acc.symmetry += symmetry_loss(&pred_kp); + prev_kp = Some(pred_kp); + } + let inv = 1.0 / batch.len() as f32; + acc.keypoint *= inv; acc.body_part *= inv; acc.uv *= inv; + acc.temporal *= inv; acc.symmetry *= inv; + acc + } + + fn predict_keypoints(params: &[f32], sample: &TrainingSample) -> Vec<(f32, f32, f32)> { + let n_kp = sample.target_keypoints.len().max(17); + let feats: Vec = sample.csi_features.iter().flat_map(|v| v.iter().copied()).collect(); + (0..n_kp).map(|k| { + let base = k * 3; + let (mut x, mut y) = (0.0f32, 0.0f32); + for (i, &f) in feats.iter().take(params.len()).enumerate() { + let pi = (base + i) % params.len(); + x += f * params[pi] * 0.01; + y += f * params[(pi + 1) % params.len()] * 0.01; + } + if base < params.len() { + x += params[base % params.len()]; + y += params[(base + 1) % params.len()]; + } + let c = if base + 2 < params.len() { + params[(base + 2) % params.len()].clamp(0.0, 1.0) + } else { 0.5 }; + (x, y, c) + }).collect() + } + + fn evaluate_metrics(&self, samples: &[TrainingSample]) -> (f32, f32) { + if samples.is_empty() { return (0.0, 0.0); } + let preds: Vec> = samples.iter().map(|s| Self::predict_keypoints(&self.params, s)).collect(); + let targets: Vec> = samples.iter().map(|s| s.target_keypoints.clone()).collect(); + let pck = preds.iter().zip(targets.iter()) + .map(|(p, t)| pck_at_threshold(p, t, 0.2)).sum::() / samples.len() as f32; + (pck, oks_map(&preds, &targets)) + } +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn mkp(off: f32) -> Vec<(f32, f32, f32)> { + (0..17).map(|i| (i as f32 + off, i as f32 * 2.0 + off, 1.0)).collect() + } + + fn symmetric_pose() -> Vec<(f32, f32, f32)> { + let mut kp = vec![(0.0f32, 0.0f32, 1.0f32); 17]; + kp[0] = (5.0, 5.0, 1.0); 
+ for &(l, r) in &SYMMETRY_PAIRS { kp[l] = (3.0, 5.0, 1.0); kp[r] = (7.0, 5.0, 1.0); } + kp + } + + fn sample() -> TrainingSample { + TrainingSample { + csi_features: vec![vec![1.0; 8]; 4], + target_keypoints: mkp(0.0), + target_body_parts: vec![0, 1, 2, 3], + target_uv: (vec![0.5; 4], vec![0.5; 4]), + } + } + + #[test] fn keypoint_mse_zero_for_identical() { assert_eq!(keypoint_mse(&mkp(0.0), &mkp(0.0)), 0.0); } + #[test] fn keypoint_mse_positive_for_different() { assert!(keypoint_mse(&mkp(0.0), &mkp(1.0)) > 0.0); } + #[test] fn keypoint_mse_symmetric() { + let (ab, ba) = (keypoint_mse(&mkp(0.0), &mkp(1.0)), keypoint_mse(&mkp(1.0), &mkp(0.0))); + assert!((ab - ba).abs() < 1e-6, "{ab} vs {ba}"); + } + #[test] fn temporal_consistency_zero_for_static() { + assert_eq!(temporal_consistency_loss(&mkp(0.0), &mkp(0.0)), 0.0); + } + #[test] fn temporal_consistency_positive_for_motion() { + assert!(temporal_consistency_loss(&mkp(0.0), &mkp(1.0)) > 0.0); + } + #[test] fn symmetry_loss_zero_for_symmetric_pose() { + assert!(symmetry_loss(&symmetric_pose()) < 1e-6); + } + #[test] fn graph_edge_loss_zero_when_correct() { + let kp = vec![(0.0,0.0,1.0),(3.0,4.0,1.0),(6.0,0.0,1.0)]; + assert!(graph_edge_loss(&kp, &[(0,1),(1,2)], &[5.0, 5.0]) < 1e-6); + } + #[test] fn composite_loss_respects_weights() { + let c = LossComponents { keypoint:1.0, body_part:1.0, uv:1.0, temporal:1.0, edge:1.0, symmetry:1.0 }; + let w1 = LossWeights { keypoint:1.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 }; + let w2 = LossWeights { keypoint:2.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 }; + assert!((composite_loss(&c, &w2) - 2.0 * composite_loss(&c, &w1)).abs() < 1e-6); + let wz = LossWeights { keypoint:0.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 }; + assert_eq!(composite_loss(&c, &wz), 0.0); + } + #[test] fn cosine_scheduler_starts_at_initial() { + assert!((CosineScheduler::new(0.01, 0.0001, 100).get_lr(0) - 0.01).abs() < 1e-6); + } + #[test] fn 
cosine_scheduler_ends_at_min() { + assert!((CosineScheduler::new(0.01, 0.0001, 100).get_lr(100) - 0.0001).abs() < 1e-6); + } + #[test] fn cosine_scheduler_midpoint() { + assert!((CosineScheduler::new(0.01, 0.0, 100).get_lr(50) - 0.005).abs() < 1e-4); + } + #[test] fn warmup_starts_at_zero() { + assert!(WarmupCosineScheduler::new(10, 0.01, 0.0001, 100).get_lr(0) < 1e-6); + } + #[test] fn warmup_reaches_initial_at_warmup_end() { + assert!((WarmupCosineScheduler::new(10, 0.01, 0.0001, 100).get_lr(10) - 0.01).abs() < 1e-6); + } + #[test] fn pck_perfect_prediction_is_1() { + assert!((pck_at_threshold(&mkp(0.0), &mkp(0.0), 0.2) - 1.0).abs() < 1e-6); + } + #[test] fn pck_all_wrong_is_0() { + assert!(pck_at_threshold(&mkp(0.0), &mkp(100.0), 0.2) < 1e-6); + } + #[test] fn oks_perfect_is_1() { + assert!((oks_single(&mkp(0.0), &mkp(0.0), &COCO_KEYPOINT_SIGMAS, 1.0) - 1.0).abs() < 1e-6); + } + #[test] fn sgd_step_reduces_simple_loss() { + let mut p = vec![5.0f32]; + let mut opt = SgdOptimizer::new(0.1, 0.0, 0.0); + let init = p[0] * p[0]; + for _ in 0..10 { let grad = vec![2.0 * p[0]]; opt.step(&mut p, &grad); } + assert!(p[0] * p[0] < init); + } + #[test] fn gradient_clipping_respects_max_norm() { + let mut g = vec![3.0, 4.0]; + clip_gradients(&mut g, 2.5); + assert!((g.iter().map(|x| x*x).sum::().sqrt() - 2.5).abs() < 1e-4); + } + #[test] fn early_stopping_triggers() { + let cfg = TrainerConfig { epochs: 100, early_stop_patience: 3, ..Default::default() }; + let mut t = Trainer::new(cfg); + let s = vec![sample()]; + t.best_val_loss = -1.0; + let mut stopped = false; + for _ in 0..20 { + t.train_epoch(&s); + t.epochs_without_improvement += 1; + if t.should_stop() { stopped = true; break; } + } + assert!(stopped); + } + #[test] fn checkpoint_round_trip() { + let mut t = Trainer::new(TrainerConfig::default()); + t.train_epoch(&[sample()]); + let ckpt = t.checkpoint(); + let dir = std::env::temp_dir().join("trainer_ckpt_test"); + std::fs::create_dir_all(&dir).unwrap(); + let 
path = dir.join("ckpt.json"); + ckpt.save_to_file(&path).unwrap(); + let loaded = Checkpoint::load_from_file(&path).unwrap(); + assert_eq!(loaded.epoch, ckpt.epoch); + assert_eq!(loaded.params.len(), ckpt.params.len()); + assert!((loaded.best_loss - ckpt.best_loss).abs() < 1e-6); + let _ = std::fs::remove_file(&path); + let _ = std::fs::remove_dir(&dir); + } +}