feat: End-to-end training pipeline with RuVector signal intelligence #49
11
rust-port/wifi-densepose-rs/Cargo.lock
generated
11
rust-port/wifi-densepose-rs/Cargo.lock
generated
@@ -4115,6 +4115,7 @@ dependencies = [
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"wifi-densepose-wifiscan",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4176,6 +4177,15 @@ dependencies = [
|
||||
"wifi-densepose-signal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-vitals"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wifi-densepose-wasm"
|
||||
version = "0.1.0"
|
||||
@@ -4203,6 +4213,7 @@ name = "wifi-densepose-wifiscan"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ members = [
|
||||
"crates/wifi-densepose-train",
|
||||
"crates/wifi-densepose-sensing-server",
|
||||
"crates/wifi-densepose-wifiscan",
|
||||
"crates/wifi-densepose-vitals",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
||||
@@ -34,5 +34,8 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||
# CLI
|
||||
clap = { workspace = true }
|
||||
|
||||
# Multi-BSSID WiFi scanning pipeline (ADR-022 Phase 3)
|
||||
wifi-densepose-wifiscan = { path = "../wifi-densepose-wifiscan" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.10"
|
||||
|
||||
@@ -0,0 +1,850 @@
|
||||
//! Dataset loaders for WiFi-to-DensePose training pipeline (ADR-023 Phase 1).
|
||||
//!
|
||||
//! Provides unified data loading for MM-Fi (NeurIPS 2023) and Wi-Pose datasets,
|
||||
//! with from-scratch .npy/.mat v5 parsers, subcarrier resampling, and a unified
|
||||
//! `DataPipeline` for normalized, windowed training samples.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// ── Error type ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Errors produced while locating, reading, or decoding dataset files.
#[derive(Debug)]
pub enum DatasetError {
    /// Underlying filesystem read failure.
    Io(io::Error),
    /// Malformed file contents (bad magic, header, dtype, ...).
    Format(String),
    /// A required file or variable could not be found.
    Missing(String),
    /// Array dimensions did not match expectations.
    Shape(String),
}

impl fmt::Display for DatasetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Build a (category, detail) pair first so the write happens once.
        let (prefix, detail) = match self {
            Self::Io(e) => ("I/O error", e.to_string()),
            Self::Format(s) => ("format error", s.clone()),
            Self::Missing(s) => ("missing", s.clone()),
            Self::Shape(s) => ("shape error", s.clone()),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for DatasetError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Only the I/O variant wraps another error.
        match self {
            Self::Io(e) => Some(e),
            _ => None,
        }
    }
}

impl From<io::Error> for DatasetError {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

/// Convenience alias used throughout the dataset module.
pub type Result<T> = std::result::Result<T, DatasetError>;
|
||||
|
||||
// ── NpyArray ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Dense array from .npy: flat f32 data with shape metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NpyArray {
|
||||
pub shape: Vec<usize>,
|
||||
pub data: Vec<f32>,
|
||||
}
|
||||
|
||||
impl NpyArray {
|
||||
pub fn len(&self) -> usize { self.data.len() }
|
||||
pub fn is_empty(&self) -> bool { self.data.is_empty() }
|
||||
pub fn ndim(&self) -> usize { self.shape.len() }
|
||||
}
|
||||
|
||||
// ── NpyReader ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Minimal NumPy .npy format reader (f32/f64, v1/v2).
///
/// Supports format versions 1, 2, and 3; little- and big-endian `f4`/`f8`
/// payloads; and Fortran (column-major) order for 2-D arrays. All values
/// are converted to `f32`.
pub struct NpyReader;

impl NpyReader {
    /// Read and parse a .npy file from disk.
    pub fn read_file(path: &Path) -> Result<NpyArray> {
        Self::parse(&std::fs::read(path)?)
    }

    /// Parse an in-memory .npy buffer into an [`NpyArray`].
    pub fn parse(buf: &[u8]) -> Result<NpyArray> {
        // Smallest valid preamble: 6-byte magic + 2 version bytes + u16 header length.
        if buf.len() < 10 { return Err(DatasetError::Format("file too small for .npy".into())); }
        if &buf[0..6] != b"\x93NUMPY" {
            return Err(DatasetError::Format("missing .npy magic".into()));
        }
        let major = buf[6];
        // v1 stores the header length as u16 at offset 8; v2/v3 use u32.
        let (header_len, header_start) = match major {
            1 => (u16::from_le_bytes([buf[8], buf[9]]) as usize, 10usize),
            2 | 3 => {
                if buf.len() < 12 { return Err(DatasetError::Format("truncated v2 header".into())); }
                (u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]) as usize, 12)
            }
            _ => return Err(DatasetError::Format(format!("unsupported .npy version {major}"))),
        };
        let header_end = header_start + header_len;
        if header_end > buf.len() { return Err(DatasetError::Format("header past EOF".into())); }
        // The header is a Python dict literal, e.g.:
        // {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4), }
        let hdr = std::str::from_utf8(&buf[header_start..header_end])
            .map_err(|_| DatasetError::Format("non-UTF8 header".into()))?;

        let dtype = Self::extract_field(hdr, "descr")?;
        let is_f64 = dtype.contains("f8") || dtype.contains("float64");
        let is_f32 = dtype.contains("f4") || dtype.contains("float32");
        // A '>' prefix in the descr marks big-endian element order.
        let is_big = dtype.starts_with('>');
        if !is_f32 && !is_f64 {
            return Err(DatasetError::Format(format!("unsupported dtype '{dtype}'")));
        }
        // Absent fortran_order is treated as C (row-major) order.
        let fortran = Self::extract_field(hdr, "fortran_order")
            .unwrap_or_else(|_| "False".into()).contains("True");
        let shape = Self::parse_shape(hdr)?;
        let elem_sz: usize = if is_f64 { 8 } else { 4 };
        // A scalar (empty shape tuple) still carries one element.
        let total: usize = shape.iter().product::<usize>().max(1);
        if header_end + total * elem_sz > buf.len() {
            return Err(DatasetError::Format("data truncated".into()));
        }
        let raw = &buf[header_end..header_end + total * elem_sz];
        // Decode the payload, narrowing f64 to f32 when necessary.
        let mut data: Vec<f32> = if is_f64 {
            raw.chunks_exact(8).map(|c| {
                let v = if is_big { f64::from_be_bytes(c.try_into().unwrap()) }
                        else { f64::from_le_bytes(c.try_into().unwrap()) };
                v as f32
            }).collect()
        } else {
            raw.chunks_exact(4).map(|c| {
                if is_big { f32::from_be_bytes(c.try_into().unwrap()) }
                else { f32::from_le_bytes(c.try_into().unwrap()) }
            }).collect()
        };
        // Transpose Fortran-order 2-D data into row-major layout.
        // (Higher-rank Fortran arrays are left as-is — NOTE(review): confirm
        // callers never produce >2-D Fortran-order files.)
        if fortran && shape.len() == 2 {
            let (r, c) = (shape[0], shape[1]);
            let mut cd = vec![0.0f32; data.len()];
            for ri in 0..r { for ci in 0..c { cd[ri*c+ci] = data[ci*r+ri]; } }
            data = cd;
        }
        // Normalize a scalar shape () to [1].
        let shape = if shape.is_empty() { vec![1] } else { shape };
        Ok(NpyArray { shape, data })
    }

    /// Extract the string value of `field` from the dict-style header.
    /// Tries single-quoted (with/without space) and double-quoted key forms.
    fn extract_field(hdr: &str, field: &str) -> Result<String> {
        for pat in &[format!("'{field}': "), format!("'{field}':"), format!("\"{field}\": ")] {
            if let Some(s) = hdr.find(pat.as_str()) {
                let rest = &hdr[s + pat.len()..];
                // The value runs until the next ',' or the closing '}'.
                let end = rest.find(',').or_else(|| rest.find('}')).unwrap_or(rest.len());
                return Ok(rest[..end].trim().trim_matches('\'').trim_matches('"').into());
            }
        }
        Err(DatasetError::Format(format!("field '{field}' not found")))
    }

    /// Parse the 'shape' tuple: "(3, 4)" -> [3, 4]; "()" -> [] (scalar).
    fn parse_shape(hdr: &str) -> Result<Vec<usize>> {
        let si = hdr.find("'shape'").or_else(|| hdr.find("\"shape\""))
            .ok_or_else(|| DatasetError::Format("no 'shape'".into()))?;
        let rest = &hdr[si..];
        let ps = rest.find('(').ok_or_else(|| DatasetError::Format("no '('".into()))?;
        let pe = rest[ps..].find(')').ok_or_else(|| DatasetError::Format("no ')'".into()))?;
        let inner = rest[ps+1..ps+pe].trim();
        if inner.is_empty() { return Ok(vec![]); }
        // 1-tuples like "(5,)" produce a trailing empty segment; filter it out.
        inner.split(',').map(|s| s.trim()).filter(|s| !s.is_empty())
            .map(|s| s.parse::<usize>().map_err(|_| DatasetError::Format(format!("bad dim: '{s}'"))))
            .collect()
    }
}
|
||||
|
||||
// ── MatReader ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Minimal MATLAB .mat v5 reader for numeric arrays.
///
/// Walks top-level elements, decoding uncompressed miMATRIX elements with
/// double/single payloads; all other element types are skipped.
pub struct MatReader;

// MAT-file v5 element-type tags (subset used here).
const MI_INT8: u32 = 1;
#[allow(dead_code)] const MI_UINT8: u32 = 2;
#[allow(dead_code)] const MI_INT16: u32 = 3;
#[allow(dead_code)] const MI_UINT16: u32 = 4;
const MI_INT32: u32 = 5;
const MI_UINT32: u32 = 6;
const MI_SINGLE: u32 = 7;
const MI_DOUBLE: u32 = 9;
const MI_MATRIX: u32 = 14;

impl MatReader {
    /// Read a .mat v5 file and return its numeric arrays keyed by variable name.
    pub fn read_file(path: &Path) -> Result<HashMap<String, NpyArray>> {
        Self::parse(&std::fs::read(path)?)
    }

    /// Parse an in-memory .mat v5 buffer.
    pub fn parse(buf: &[u8]) -> Result<HashMap<String, NpyArray>> {
        // v5 files begin with a 128-byte descriptive header.
        if buf.len() < 128 { return Err(DatasetError::Format("too small for .mat v5".into())); }
        // Endianness indicator at offset 126: reading "IM" (0x4D49 LE) means the
        // file was written with the opposite byte order and needs swapping.
        let swap = u16::from_le_bytes([buf[126], buf[127]]) == 0x4D49;
        let mut result = HashMap::new();
        let mut off = 128;
        while off + 8 <= buf.len() {
            let (dt, ds, ts) = Self::read_tag(buf, off, swap)?;
            let el_start = off + ts;
            let el_end = el_start + ds;
            if el_end > buf.len() { break; }
            if dt == MI_MATRIX {
                // A matrix that fails to parse is skipped rather than
                // aborting the whole file.
                if let Ok((n, a)) = Self::parse_matrix(&buf[el_start..el_end], swap) {
                    result.insert(n, a);
                }
            }
            // Top-level elements are padded to 8-byte boundaries.
            off = (el_end + 7) & !7;
        }
        Ok(result)
    }

    /// Decode an element tag at `off`, returning (data type, data size, tag size).
    ///
    /// The small-data-element format packs a size of 1..=4 bytes into the
    /// tag's upper 16 bits (4-byte tag); otherwise the tag is 8 bytes.
    fn read_tag(buf: &[u8], off: usize, swap: bool) -> Result<(u32, usize, usize)> {
        if off + 4 > buf.len() { return Err(DatasetError::Format("truncated tag".into())); }
        let raw = Self::u32(buf, off, swap);
        let upper = (raw >> 16) & 0xFFFF;
        if upper != 0 && upper <= 4 { return Ok((raw & 0xFFFF, upper as usize, 4)); }
        if off + 8 > buf.len() { return Err(DatasetError::Format("truncated tag".into())); }
        Ok((raw, Self::u32(buf, off + 4, swap) as usize, 8))
    }

    /// Parse one miMATRIX element into (variable name, array).
    /// Walks its subelements: array flags, dimensions, name, then data.
    fn parse_matrix(buf: &[u8], swap: bool) -> Result<(String, NpyArray)> {
        let (mut name, mut shape, mut data) = (String::new(), Vec::new(), Vec::new());
        let mut off = 0;
        while off + 4 <= buf.len() {
            let (st, ss, ts) = Self::read_tag(buf, off, swap)?;
            let ss_start = off + ts;
            let ss_end = (ss_start + ss).min(buf.len());
            match st {
                // The leading 8-byte UINT32 subelement is the array-flags
                // block; its contents (class, complexity) are ignored.
                MI_UINT32 if shape.is_empty() && ss == 8 => {}
                // Dimensions subelement: one i32 per dimension.
                MI_INT32 if shape.is_empty() => {
                    for i in 0..ss / 4 { shape.push(Self::i32(buf, ss_start + i*4, swap) as usize); }
                }
                // Variable-name subelement: NUL-padded INT8 bytes.
                MI_INT8 if name.is_empty() && ss_end <= buf.len() => {
                    name = String::from_utf8_lossy(&buf[ss_start..ss_end])
                        .trim_end_matches('\0').to_string();
                }
                // Real data, narrowed to f32; out-of-bounds reads are skipped.
                MI_DOUBLE => {
                    for i in 0..ss / 8 {
                        let p = ss_start + i * 8;
                        if p + 8 <= buf.len() { data.push(Self::f64(buf, p, swap) as f32); }
                    }
                }
                MI_SINGLE => {
                    for i in 0..ss / 4 {
                        let p = ss_start + i * 4;
                        if p + 4 <= buf.len() { data.push(Self::f32(buf, p, swap)); }
                    }
                }
                // Any other subelement type is ignored.
                _ => {}
            }
            off = (ss_end + 7) & !7;
        }
        if name.is_empty() { name = "unnamed".into(); }
        if shape.is_empty() && !data.is_empty() { shape = vec![data.len()]; }
        // Transpose column-major to row-major for 2D
        if shape.len() == 2 {
            let (r, c) = (shape[0], shape[1]);
            if r * c == data.len() {
                let mut cd = vec![0.0f32; data.len()];
                for ri in 0..r { for ci in 0..c { cd[ri*c+ci] = data[ci*r+ri]; } }
                data = cd;
            }
        }
        Ok((name, NpyArray { shape, data }))
    }

    // Endian-aware scalar readers; `s` selects big-endian decoding.
    fn u32(b: &[u8], o: usize, s: bool) -> u32 {
        let v = [b[o], b[o+1], b[o+2], b[o+3]];
        if s { u32::from_be_bytes(v) } else { u32::from_le_bytes(v) }
    }
    fn i32(b: &[u8], o: usize, s: bool) -> i32 {
        let v = [b[o], b[o+1], b[o+2], b[o+3]];
        if s { i32::from_be_bytes(v) } else { i32::from_le_bytes(v) }
    }
    fn f64(b: &[u8], o: usize, s: bool) -> f64 {
        let v: [u8; 8] = b[o..o+8].try_into().unwrap();
        if s { f64::from_be_bytes(v) } else { f64::from_le_bytes(v) }
    }
    fn f32(b: &[u8], o: usize, s: bool) -> f32 {
        let v = [b[o], b[o+1], b[o+2], b[o+3]];
        if s { f32::from_be_bytes(v) } else { f32::from_le_bytes(v) }
    }
}
|
||||
|
||||
// ── Core data types ──────────────────────────────────────────────────────────
|
||||
|
||||
/// A single CSI (Channel State Information) sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CsiSample {
    /// Per-subcarrier amplitude values.
    pub amplitude: Vec<f32>,
    /// Per-subcarrier phase values (radians); all zeros when the source
    /// dataset provides no phase.
    pub phase: Vec<f32>,
    /// Sample timestamp in milliseconds (synthesized from the frame index
    /// when the source has no real timestamps).
    pub timestamp_ms: u64,
}

/// UV coordinate map for a body part in DensePose representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BodyPartUV {
    /// DensePose body-part index.
    pub part_id: u8,
    /// U coordinates for this part.
    pub u_coords: Vec<f32>,
    /// V coordinates, index-aligned with `u_coords`.
    pub v_coords: Vec<f32>,
}

/// Pose label: 17 COCO keypoints + optional DensePose body-part UVs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoseLabel {
    /// One triple per COCO keypoint; layout presumed (x, y, confidence) —
    /// TODO confirm against the dataset documentation.
    pub keypoints: [(f32, f32, f32); 17],
    /// DensePose UV maps; empty when the dataset carries none.
    pub body_parts: Vec<BodyPartUV>,
    /// Overall label confidence (0.0 marks a default/missing label).
    pub confidence: f32,
}

impl Default for PoseLabel {
    /// All-zero keypoints, no body parts, zero confidence — the fallback
    /// used when a frame has no usable label.
    fn default() -> Self {
        Self { keypoints: [(0.0, 0.0, 0.0); 17], body_parts: Vec::new(), confidence: 0.0 }
    }
}
|
||||
|
||||
// ── SubcarrierResampler ──────────────────────────────────────────────────────
|
||||
|
||||
/// Resamples subcarrier data via linear interpolation or zero-padding.
pub struct SubcarrierResampler;

impl SubcarrierResampler {
    /// Resample `input` from `from` to `to` subcarriers.
    ///
    /// Degenerate counts (equal, or either zero) return the input unchanged.
    /// Upsampling centers the data inside a zero-padded vector; downsampling
    /// interpolates linearly.
    pub fn resample(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        if from == 0 || to == 0 || from == to {
            return input.to_vec();
        }
        if from < to {
            Self::zero_pad(input, from, to)
        } else {
            Self::interpolate(input, from, to)
        }
    }

    /// Resample phase data: unwrap first so interpolation never crosses a
    /// 2π discontinuity, then re-wrap every output into [-π, π].
    pub fn resample_phase(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        if from == 0 || to == 0 || from == to {
            return input.to_vec();
        }
        let unwrapped = Self::phase_unwrap(input);
        let resized = if from < to {
            Self::zero_pad(&unwrapped, from, to)
        } else {
            Self::interpolate(&unwrapped, from, to)
        };
        let pi = std::f32::consts::PI;
        let rewrap = |p: f32| -> f32 {
            let mut w = p % (2.0 * pi);
            if w > pi {
                w -= 2.0 * pi;
            } else if w < -pi {
                w += 2.0 * pi;
            }
            w
        };
        resized.into_iter().map(rewrap).collect()
    }

    /// Center `input` inside a zero-filled vector of length `to`.
    fn zero_pad(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        let offset = (to - from) / 2;
        let mut padded = vec![0.0f32; to];
        for (i, &v) in input.iter().take(from).enumerate() {
            if let Some(slot) = padded.get_mut(offset + i) {
                *slot = v;
            }
        }
        padded
    }

    /// Map `to` evenly spaced sample positions onto [0, n-1] and linearly
    /// interpolate between neighbouring input values.
    fn interpolate(input: &[f32], from: usize, to: usize) -> Vec<f32> {
        let n = input.len().min(from);
        if n <= 1 {
            // Nothing to interpolate between: broadcast the lone value (or 0).
            return vec![input.first().copied().unwrap_or(0.0); to];
        }
        let mut out = Vec::with_capacity(to);
        for i in 0..to {
            let pos = i as f64 * (n - 1) as f64 / (to - 1).max(1) as f64;
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(n - 1);
            let frac = (pos - lo as f64) as f32;
            out.push(input[lo] * (1.0 - frac) + input[hi] * frac);
        }
        out
    }

    /// Cumulative phase unwrap: fold each successive delta into (-π, π]
    /// before accumulating, removing artificial 2π jumps.
    fn phase_unwrap(phase: &[f32]) -> Vec<f32> {
        let pi = std::f32::consts::PI;
        let mut out: Vec<f32> = Vec::with_capacity(phase.len());
        for (i, &p) in phase.iter().enumerate() {
            if i == 0 {
                out.push(p);
            } else {
                let mut d = p - phase[i - 1];
                while d > pi { d -= 2.0 * pi; }
                while d < -pi { d += 2.0 * pi; }
                out.push(out[i - 1] + d);
            }
        }
        out
    }
}
|
||||
|
||||
// ── MmFiDataset ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// MM-Fi (NeurIPS 2023) dataset loader with 56 subcarriers and 17 COCO keypoints.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MmFiDataset {
|
||||
pub csi_frames: Vec<CsiSample>,
|
||||
pub labels: Vec<PoseLabel>,
|
||||
pub sample_rate_hz: f32,
|
||||
pub n_subcarriers: usize,
|
||||
}
|
||||
|
||||
impl MmFiDataset {
|
||||
pub const SUBCARRIERS: usize = 56;
|
||||
|
||||
/// Load from directory with csi_amplitude.npy/csi.npy and labels.npy/keypoints.npy.
|
||||
pub fn load_from_directory(path: &Path) -> Result<Self> {
|
||||
if !path.is_dir() {
|
||||
return Err(DatasetError::Missing(format!("directory not found: {}", path.display())));
|
||||
}
|
||||
let amp = NpyReader::read_file(&Self::find(path, &["csi_amplitude.npy", "csi.npy"])?)?;
|
||||
let n = amp.shape.first().copied().unwrap_or(0);
|
||||
let raw_sc = if amp.shape.len() >= 2 { amp.shape[1] } else { amp.data.len() / n.max(1) };
|
||||
let phase_arr = Self::find(path, &["csi_phase.npy"]).ok()
|
||||
.and_then(|p| NpyReader::read_file(&p).ok());
|
||||
let lab = NpyReader::read_file(&Self::find(path, &["labels.npy", "keypoints.npy"])?)?;
|
||||
|
||||
let mut csi_frames = Vec::with_capacity(n);
|
||||
let mut labels = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
let s = i * raw_sc;
|
||||
if s + raw_sc > amp.data.len() { break; }
|
||||
let amplitude = SubcarrierResampler::resample(&.data[s..s+raw_sc], raw_sc, Self::SUBCARRIERS);
|
||||
let phase = phase_arr.as_ref().map(|pa| {
|
||||
let ps = i * raw_sc;
|
||||
if ps + raw_sc <= pa.data.len() {
|
||||
SubcarrierResampler::resample_phase(&pa.data[ps..ps+raw_sc], raw_sc, Self::SUBCARRIERS)
|
||||
} else { vec![0.0; Self::SUBCARRIERS] }
|
||||
}).unwrap_or_else(|| vec![0.0; Self::SUBCARRIERS]);
|
||||
|
||||
csi_frames.push(CsiSample { amplitude, phase, timestamp_ms: i as u64 * 50 });
|
||||
|
||||
let ks = i * 17 * 3;
|
||||
let label = if ks + 51 <= lab.data.len() {
|
||||
let d = &lab.data[ks..ks + 51];
|
||||
let mut kp = [(0.0f32, 0.0, 0.0); 17];
|
||||
for k in 0..17 { kp[k] = (d[k*3], d[k*3+1], d[k*3+2]); }
|
||||
PoseLabel { keypoints: kp, body_parts: Vec::new(), confidence: 1.0 }
|
||||
} else { PoseLabel::default() };
|
||||
labels.push(label);
|
||||
}
|
||||
Ok(Self { csi_frames, labels, sample_rate_hz: 20.0, n_subcarriers: Self::SUBCARRIERS })
|
||||
}
|
||||
|
||||
pub fn resample_subcarriers(&mut self, from: usize, to: usize) {
|
||||
for f in &mut self.csi_frames {
|
||||
f.amplitude = SubcarrierResampler::resample(&f.amplitude, from, to);
|
||||
f.phase = SubcarrierResampler::resample_phase(&f.phase, from, to);
|
||||
}
|
||||
self.n_subcarriers = to;
|
||||
}
|
||||
|
||||
pub fn iter_windows(&self, ws: usize, stride: usize) -> impl Iterator<Item = (&[CsiSample], &[PoseLabel])> {
|
||||
let stride = stride.max(1);
|
||||
let n = self.csi_frames.len();
|
||||
(0..n).step_by(stride).filter(move |&s| s + ws <= n)
|
||||
.map(move |s| (&self.csi_frames[s..s+ws], &self.labels[s..s+ws]))
|
||||
}
|
||||
|
||||
pub fn split_train_val(self, ratio: f32) -> (Self, Self) {
|
||||
let split = (self.csi_frames.len() as f32 * ratio.clamp(0.0, 1.0)) as usize;
|
||||
let (tc, vc) = self.csi_frames.split_at(split);
|
||||
let (tl, vl) = self.labels.split_at(split);
|
||||
let mk = |c: &[CsiSample], l: &[PoseLabel]| Self {
|
||||
csi_frames: c.to_vec(), labels: l.to_vec(),
|
||||
sample_rate_hz: self.sample_rate_hz, n_subcarriers: self.n_subcarriers,
|
||||
};
|
||||
(mk(tc, tl), mk(vc, vl))
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize { self.csi_frames.len() }
|
||||
pub fn is_empty(&self) -> bool { self.csi_frames.is_empty() }
|
||||
pub fn get(&self, idx: usize) -> Option<(&CsiSample, &PoseLabel)> {
|
||||
self.csi_frames.get(idx).zip(self.labels.get(idx))
|
||||
}
|
||||
|
||||
fn find(dir: &Path, names: &[&str]) -> Result<PathBuf> {
|
||||
for n in names { let p = dir.join(n); if p.exists() { return Ok(p); } }
|
||||
Err(DatasetError::Missing(format!("none of {names:?} in {}", dir.display())))
|
||||
}
|
||||
}
|
||||
|
||||
// ── WiPoseDataset ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Wi-Pose dataset loader: .mat v5, 30 subcarriers (-> 56), 18 keypoints (-> 17 COCO).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WiPoseDataset {
    /// CSI frames (amplitude resampled to 56 subcarriers; phase zeroed —
    /// Wi-Pose supplies none).
    pub csi_frames: Vec<CsiSample>,
    /// Per-frame pose labels, index-aligned with `csi_frames`.
    pub labels: Vec<PoseLabel>,
    /// Nominal capture rate in Hz.
    pub sample_rate_hz: f32,
    /// Subcarrier count after resampling.
    pub n_subcarriers: usize,
}

impl WiPoseDataset {
    /// Native Wi-Pose subcarrier count.
    pub const RAW_SUBCARRIERS: usize = 30;
    /// Subcarrier count after resampling (matches MM-Fi's 56).
    pub const TARGET_SUBCARRIERS: usize = 56;
    /// Native keypoint count of the 18-point skeleton.
    pub const RAW_KEYPOINTS: usize = 18;
    /// COCO keypoint count after remapping.
    pub const COCO_KEYPOINTS: usize = 17;

    /// Load a Wi-Pose recording from a .mat v5 file.
    ///
    /// CSI is searched under `csi`/`csi_data`/`CSI`; labels under
    /// `keypoints`/`labels`/`pose` (optional — absent or short labels yield
    /// `PoseLabel::default()`).
    pub fn load_from_mat(path: &Path) -> Result<Self> {
        let arrays = MatReader::read_file(path)?;
        let csi = arrays.get("csi").or_else(|| arrays.get("csi_data")).or_else(|| arrays.get("CSI"))
            .ok_or_else(|| DatasetError::Missing("no CSI variable in .mat".into()))?;
        let n = csi.shape.first().copied().unwrap_or(0);
        // 1-D CSI falls back to the canonical 30-subcarrier layout.
        let raw = if csi.shape.len() >= 2 { csi.shape[1] } else { Self::RAW_SUBCARRIERS };
        let lab = arrays.get("keypoints").or_else(|| arrays.get("labels")).or_else(|| arrays.get("pose"));

        let mut csi_frames = Vec::with_capacity(n);
        let mut labels = Vec::with_capacity(n);
        for i in 0..n {
            let s = i * raw;
            if s + raw > csi.data.len() { break; }
            let amp = SubcarrierResampler::resample(&csi.data[s..s+raw], raw, Self::TARGET_SUBCARRIERS);
            // No phase in Wi-Pose; synthetic 100 ms spacing matches 10 Hz.
            csi_frames.push(CsiSample { amplitude: amp, phase: vec![0.0; Self::TARGET_SUBCARRIERS], timestamp_ms: i as u64 * 100 });
            // Labels are flat [n][18][3] floats; short/missing data -> default.
            let label = lab.and_then(|la| {
                let ks = i * Self::RAW_KEYPOINTS * 3;
                if ks + Self::RAW_KEYPOINTS * 3 <= la.data.len() {
                    Some(Self::map_18_to_17(&la.data[ks..ks + Self::RAW_KEYPOINTS * 3]))
                } else { None }
            }).unwrap_or_default();
            labels.push(label);
        }
        Ok(Self { csi_frames, labels, sample_rate_hz: 10.0, n_subcarriers: Self::TARGET_SUBCARRIERS })
    }

    /// Map 18 keypoints to 17 COCO: keep index 0 (nose), drop index 1, map 2..18 -> 1..16.
    ///
    /// Short inputs (< 18*3 floats) leave all keypoints zeroed but still
    /// report confidence 1.0.
    fn map_18_to_17(data: &[f32]) -> PoseLabel {
        let mut kp = [(0.0f32, 0.0, 0.0); 17];
        if data.len() >= 18 * 3 {
            kp[0] = (data[0], data[1], data[2]);
            // COCO slot i takes source keypoint i+1 (source index 1 is dropped).
            for i in 1..17 { let s = (i + 1) * 3; kp[i] = (data[s], data[s+1], data[s+2]); }
        }
        PoseLabel { keypoints: kp, body_parts: Vec::new(), confidence: 1.0 }
    }

    /// Number of frames.
    pub fn len(&self) -> usize { self.csi_frames.len() }
    /// True when no frames were loaded.
    pub fn is_empty(&self) -> bool { self.csi_frames.is_empty() }
}
|
||||
|
||||
// ── DataPipeline ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Where the pipeline loads training data from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataSource {
    /// MM-Fi dataset directory (.npy files).
    MmFi(PathBuf),
    /// Wi-Pose .mat v5 file.
    WiPose(PathBuf),
    /// Multiple sources, loaded and concatenated in order.
    Combined(Vec<DataSource>),
}

/// Configuration for [`DataPipeline`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataConfig {
    /// Data source(s) to load.
    pub source: DataSource,
    /// Frames per training window.
    pub window_size: usize,
    /// Step between consecutive window starts (clamped to >= 1 when used).
    pub stride: usize,
    /// Subcarrier count all sources are resampled to.
    pub target_subcarriers: usize,
    /// Whether to z-score-normalize amplitudes per subcarrier after loading.
    pub normalize: bool,
}

impl Default for DataConfig {
    /// Empty combined source, 10-frame windows with 50% overlap (stride 5),
    /// 56 subcarriers, normalization enabled.
    fn default() -> Self {
        Self { source: DataSource::Combined(Vec::new()), window_size: 10, stride: 5,
               target_subcarriers: 56, normalize: true }
    }
}

/// One windowed training example produced by the pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
    /// Window of per-frame amplitude vectors: [window_size][n_subcarriers].
    pub csi_window: Vec<Vec<f32>>,
    /// Pose label taken from the window's center frame.
    pub pose_label: PoseLabel,
    /// Originating dataset tag ("mmfi" / "wipose").
    pub source: &'static str,
}
|
||||
|
||||
/// Unified pipeline: loads, resamples, windows, and normalizes training data.
pub struct DataPipeline { config: DataConfig }

impl DataPipeline {
    /// Create a pipeline with the given configuration.
    pub fn new(config: DataConfig) -> Self { Self { config } }

    /// Load all configured sources into windowed training samples.
    ///
    /// # Errors
    /// Propagates any dataset loading failure from the configured sources.
    pub fn load(&self) -> Result<Vec<TrainingSample>> {
        let mut out = Vec::new();
        self.load_source(&self.config.source, &mut out)?;
        // Normalization uses global per-subcarrier statistics, so it runs
        // only after every source has been collected.
        if self.config.normalize && !out.is_empty() { Self::normalize_samples(&mut out); }
        Ok(out)
    }

    /// Recursively load one source (or a combined list) and append its
    /// windowed samples to `out`.
    fn load_source(&self, src: &DataSource, out: &mut Vec<TrainingSample>) -> Result<()> {
        match src {
            DataSource::MmFi(p) => {
                let mut ds = MmFiDataset::load_from_directory(p)?;
                // Bring the subcarrier count in line with the configured target.
                if ds.n_subcarriers != self.config.target_subcarriers {
                    let f = ds.n_subcarriers;
                    ds.resample_subcarriers(f, self.config.target_subcarriers);
                }
                self.extract_windows(&ds.csi_frames, &ds.labels, "mmfi", out);
            }
            DataSource::WiPose(p) => {
                // NOTE(review): Wi-Pose resamples to 56 at load time and is
                // not re-resampled here — confirm target_subcarriers == 56
                // is always intended for Wi-Pose sources.
                let ds = WiPoseDataset::load_from_mat(p)?;
                self.extract_windows(&ds.csi_frames, &ds.labels, "wipose", out);
            }
            DataSource::Combined(srcs) => { for s in srcs { self.load_source(s, out)?; } }
        }
        Ok(())
    }

    /// Slice `frames` into overlapping windows of `window_size` advancing by
    /// `stride`; each window's label is taken from its center frame.
    fn extract_windows(&self, frames: &[CsiSample], labels: &[PoseLabel],
                       source: &'static str, out: &mut Vec<TrainingSample>) {
        let (ws, stride) = (self.config.window_size, self.config.stride.max(1));
        let mut s = 0;
        while s + ws <= frames.len() {
            // Only amplitude enters the training window; phase is dropped here.
            let window: Vec<Vec<f32>> = frames[s..s+ws].iter().map(|f| f.amplitude.clone()).collect();
            let label = labels.get(s + ws / 2).cloned().unwrap_or_default();
            out.push(TrainingSample { csi_window: window, pose_label: label, source });
            s += stride;
        }
    }

    /// Z-score-normalize every subcarrier across all frames of all samples.
    /// Statistics are accumulated in f64 to limit rounding error; the
    /// subcarrier count is taken from the first frame of the first sample.
    fn normalize_samples(samples: &mut [TrainingSample]) {
        let ns = samples.first().and_then(|s| s.csi_window.first()).map(|f| f.len()).unwrap_or(0);
        if ns == 0 { return; }
        let (mut sum, mut sq) = (vec![0.0f64; ns], vec![0.0f64; ns]);
        let mut cnt = 0u64;
        for s in samples.iter() {
            for f in &s.csi_window {
                for (j, &v) in f.iter().enumerate().take(ns) {
                    let v = v as f64; sum[j] += v; sq[j] += v * v;
                }
                // One count per frame (all subcarriers of a frame share it).
                cnt += 1;
            }
        }
        if cnt == 0 { return; }
        let mean: Vec<f64> = sum.iter().map(|s| s / cnt as f64).collect();
        // Population std with a 1e-8 floor to avoid division by zero on
        // constant subcarriers.
        let std: Vec<f64> = sq.iter().zip(mean.iter())
            .map(|(&s, &m)| (s / cnt as f64 - m * m).max(0.0).sqrt().max(1e-8)).collect();
        for s in samples.iter_mut() {
            for f in &mut s.csi_window {
                for (j, v) in f.iter_mut().enumerate().take(ns) {
                    *v = ((*v as f64 - mean[j]) / std[j]) as f32;
                }
            }
        }
    }
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_npy_f32(shape: &[usize], data: &[f32]) -> Vec<u8> {
|
||||
let ss = if shape.len() == 1 { format!("({},)", shape[0]) }
|
||||
else { format!("({})", shape.iter().map(|d| d.to_string()).collect::<Vec<_>>().join(", ")) };
|
||||
let hdr = format!("{{'descr': '<f4', 'fortran_order': False, 'shape': {ss}, }}");
|
||||
let total = 10 + hdr.len();
|
||||
let padded = ((total + 63) / 64) * 64;
|
||||
let hl = padded - 10;
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(b"\x93NUMPY\x01\x00");
|
||||
buf.extend_from_slice(&(hl as u16).to_le_bytes());
|
||||
buf.extend_from_slice(hdr.as_bytes());
|
||||
buf.resize(10 + hl, b' ');
|
||||
for &v in data { buf.extend_from_slice(&v.to_le_bytes()); }
|
||||
buf
|
||||
}
|
||||
|
||||
fn make_npy_f64(shape: &[usize], data: &[f64]) -> Vec<u8> {
|
||||
let ss = if shape.len() == 1 { format!("({},)", shape[0]) }
|
||||
else { format!("({})", shape.iter().map(|d| d.to_string()).collect::<Vec<_>>().join(", ")) };
|
||||
let hdr = format!("{{'descr': '<f8', 'fortran_order': False, 'shape': {ss}, }}");
|
||||
let total = 10 + hdr.len();
|
||||
let padded = ((total + 63) / 64) * 64;
|
||||
let hl = padded - 10;
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(b"\x93NUMPY\x01\x00");
|
||||
buf.extend_from_slice(&(hl as u16).to_le_bytes());
|
||||
buf.extend_from_slice(hdr.as_bytes());
|
||||
buf.resize(10 + hl, b' ');
|
||||
for &v in data { buf.extend_from_slice(&v.to_le_bytes()); }
|
||||
buf
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npy_header_parse_1d() {
|
||||
let buf = make_npy_f32(&[5], &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
let arr = NpyReader::parse(&buf).unwrap();
|
||||
assert_eq!(arr.shape, vec![5]);
|
||||
assert_eq!(arr.ndim(), 1);
|
||||
assert_eq!(arr.len(), 5);
|
||||
assert!((arr.data[0] - 1.0).abs() < f32::EPSILON);
|
||||
assert!((arr.data[4] - 5.0).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npy_header_parse_2d() {
|
||||
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
|
||||
let buf = make_npy_f32(&[3, 4], &data);
|
||||
let arr = NpyReader::parse(&buf).unwrap();
|
||||
assert_eq!(arr.shape, vec![3, 4]);
|
||||
assert_eq!(arr.ndim(), 2);
|
||||
assert_eq!(arr.len(), 12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npy_header_parse_3d() {
|
||||
let data: Vec<f64> = (0..24).map(|i| i as f64 * 0.5).collect();
|
||||
let buf = make_npy_f64(&[2, 3, 4], &data);
|
||||
let arr = NpyReader::parse(&buf).unwrap();
|
||||
assert_eq!(arr.shape, vec![2, 3, 4]);
|
||||
assert_eq!(arr.ndim(), 3);
|
||||
assert_eq!(arr.len(), 24);
|
||||
assert!((arr.data[23] - 11.5).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subcarrier_resample_passthrough() {
|
||||
let input: Vec<f32> = (0..56).map(|i| i as f32).collect();
|
||||
let output = SubcarrierResampler::resample(&input, 56, 56);
|
||||
assert_eq!(output, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subcarrier_resample_upsample() {
|
||||
let input: Vec<f32> = (0..30).map(|i| (i + 1) as f32).collect();
|
||||
let out = SubcarrierResampler::resample(&input, 30, 56);
|
||||
assert_eq!(out.len(), 56);
|
||||
// pad_left = 13, leading zeros
|
||||
for i in 0..13 { assert!(out[i].abs() < f32::EPSILON, "expected zero at {i}"); }
|
||||
// original data in middle
|
||||
for i in 0..30 { assert!((out[13+i] - input[i]).abs() < f32::EPSILON); }
|
||||
// trailing zeros
|
||||
for i in 43..56 { assert!(out[i].abs() < f32::EPSILON, "expected zero at {i}"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subcarrier_resample_downsample() {
|
||||
let input: Vec<f32> = (0..114).map(|i| i as f32).collect();
|
||||
let out = SubcarrierResampler::resample(&input, 114, 56);
|
||||
assert_eq!(out.len(), 56);
|
||||
assert!((out[0]).abs() < f32::EPSILON);
|
||||
assert!((out[55] - 113.0).abs() < 0.1);
|
||||
for i in 1..56 { assert!(out[i] >= out[i-1], "not monotonic at {i}"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subcarrier_resample_preserves_dc() {
|
||||
let out = SubcarrierResampler::resample(&vec![42.0f32; 114], 114, 56);
|
||||
assert_eq!(out.len(), 56);
|
||||
for (i, &v) in out.iter().enumerate() {
|
||||
assert!((v - 42.0).abs() < 1e-5, "DC not preserved at {i}: {v}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mmfi_sample_structure() {
|
||||
let s = CsiSample { amplitude: vec![0.0; 56], phase: vec![0.0; 56], timestamp_ms: 100 };
|
||||
assert_eq!(s.amplitude.len(), 56);
|
||||
assert_eq!(s.phase.len(), 56);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wipose_zero_pad() {
|
||||
let raw: Vec<f32> = (1..=30).map(|i| i as f32).collect();
|
||||
let p = SubcarrierResampler::resample(&raw, 30, 56);
|
||||
assert_eq!(p.len(), 56);
|
||||
assert!(p[0].abs() < f32::EPSILON);
|
||||
assert!((p[13] - 1.0).abs() < f32::EPSILON);
|
||||
assert!((p[42] - 30.0).abs() < f32::EPSILON);
|
||||
assert!(p[55].abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wipose_keypoint_mapping() {
|
||||
let mut kp = vec![0.0f32; 18 * 3];
|
||||
kp[0] = 1.0; kp[1] = 2.0; kp[2] = 1.0; // nose
|
||||
kp[3] = 99.0; kp[4] = 99.0; kp[5] = 99.0; // extra (dropped)
|
||||
kp[6] = 3.0; kp[7] = 4.0; kp[8] = 1.0; // left eye -> COCO 1
|
||||
let label = WiPoseDataset::map_18_to_17(&kp);
|
||||
assert_eq!(label.keypoints.len(), 17);
|
||||
assert!((label.keypoints[0].0 - 1.0).abs() < f32::EPSILON);
|
||||
assert!((label.keypoints[1].0 - 3.0).abs() < f32::EPSILON); // not 99
|
||||
}
|
||||
|
||||
/// An 0.8 split of 100 frames yields exactly 80 train + 20 val samples.
#[test]
fn train_val_split_ratio() {
    // Helper building a synthetic dataset of `n` frames with default labels;
    // amplitude encodes the frame index so samples are distinguishable.
    let mk = |n: usize| MmFiDataset {
        csi_frames: (0..n).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(),
        labels: (0..n).map(|_| PoseLabel::default()).collect(),
        sample_rate_hz: 20.0, n_subcarriers: 56,
    };
    let (train, val) = mk(100).split_train_val(0.8);
    assert_eq!(train.len(), 80);
    assert_eq!(val.len(), 20);
    // No sample is lost or duplicated by the split.
    assert_eq!(train.len() + val.len(), 100);
}
|
||||
|
||||
/// Window counts over 20 frames: size 5 / stride 5 -> 4 windows (no overlap);
/// size 5 / stride 1 -> 16 windows (20 - 5 + 1).
#[test]
fn sliding_window_count() {
    let ds = MmFiDataset {
        csi_frames: (0..20).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(),
        labels: (0..20).map(|_| PoseLabel::default()).collect(),
        sample_rate_hz: 20.0, n_subcarriers: 56,
    };
    assert_eq!(ds.iter_windows(5, 5).count(), 4);
    assert_eq!(ds.iter_windows(5, 1).count(), 16);
}
|
||||
|
||||
/// Stride-2 windows of size 4 over 10 frames overlap by two frames.
#[test]
fn sliding_window_overlap() {
    let ds = MmFiDataset {
        // amplitude[0] encodes the frame index so window starts are checkable.
        csi_frames: (0..10).map(|i| CsiSample { amplitude: vec![i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 }).collect(),
        labels: (0..10).map(|_| PoseLabel::default()).collect(),
        sample_rate_hz: 20.0, n_subcarriers: 56,
    };
    let w: Vec<_> = ds.iter_windows(4, 2).collect();
    assert_eq!(w.len(), 4);
    // Window 0 starts at frame 0, window 1 at frame 2.
    assert!((w[0].0[0].amplitude[0]).abs() < f32::EPSILON);
    assert!((w[1].0[0].amplitude[0] - 2.0).abs() < f32::EPSILON);
    assert_eq!(w[0].0[2].amplitude[0], w[1].0[0].amplitude[0]); // overlap
}
|
||||
|
||||
/// After normalization each subcarrier column has ~zero mean and ~unit std
/// across every frame of every sample.
#[test]
fn data_pipeline_normalize() {
    let mut samples = vec![
        TrainingSample { csi_window: vec![vec![10.0, 20.0, 30.0]; 2], pose_label: PoseLabel::default(), source: "test" },
        TrainingSample { csi_window: vec![vec![30.0, 40.0, 50.0]; 2], pose_label: PoseLabel::default(), source: "test" },
    ];
    DataPipeline::normalize_samples(&mut samples);
    // Verify per-subcarrier statistics, accumulated in f64 for precision.
    for j in 0..3 {
        let (mut s, mut c) = (0.0f64, 0u64);
        for sam in &samples { for f in &sam.csi_window { s += f[j] as f64; c += 1; } }
        assert!((s / c as f64).abs() < 1e-5, "mean not ~0 for sub {j}");
        let mut vs = 0.0f64;
        let m = s / c as f64;
        for sam in &samples { for f in &sam.csi_window { vs += (f[j] as f64 - m).powi(2); } }
        // Population std (divide by count, not count - 1) should be ~1.
        assert!(((vs / c as f64).sqrt() - 1.0).abs() < 0.1, "std not ~1 for sub {j}");
    }
}
|
||||
|
||||
/// `PoseLabel::default()` yields 17 zeroed keypoints, no body parts,
/// and zero confidence.
#[test]
fn pose_label_default() {
    let label = PoseLabel::default();
    assert_eq!(label.keypoints.len(), 17);
    assert!(label.body_parts.is_empty());
    assert!(label.confidence.abs() < f32::EPSILON);
    // Every default keypoint coordinate must be exactly zero.
    for (i, kp) in label.keypoints.iter().enumerate() {
        assert!(kp.0.abs() < f32::EPSILON && kp.1.abs() < f32::EPSILON, "kp {i} not zero");
    }
}
|
||||
|
||||
/// BodyPartUV survives a serde_json serialize/deserialize round trip intact.
#[test]
fn body_part_uv_round_trip() {
    let bpu = BodyPartUV { part_id: 5, u_coords: vec![0.1, 0.2, 0.3], v_coords: vec![0.4, 0.5, 0.6] };
    let json = serde_json::to_string(&bpu).unwrap();
    let r: BodyPartUV = serde_json::from_str(&json).unwrap();
    assert_eq!(r.part_id, 5);
    assert_eq!(r.u_coords.len(), 3);
    assert!((r.u_coords[0] - 0.1).abs() < f32::EPSILON);
    assert!((r.v_coords[2] - 0.6).abs() < f32::EPSILON);
}
|
||||
|
||||
/// `extract_windows` appends windows from multiple sources into one shared
/// buffer, tagging each sample with its originating dataset name.
#[test]
fn combined_source_merges_datasets() {
    // Helper: n frames whose amplitude encodes `base + index`, default labels.
    let mk = |n: usize, base: f32| -> (Vec<CsiSample>, Vec<PoseLabel>) {
        let f: Vec<CsiSample> = (0..n).map(|i| CsiSample { amplitude: vec![base + i as f32; 56], phase: vec![0.0; 56], timestamp_ms: i as u64 * 50 }).collect();
        let l: Vec<PoseLabel> = (0..n).map(|_| PoseLabel::default()).collect();
        (f, l)
    };
    let pipe = DataPipeline::new(DataConfig { source: DataSource::Combined(Vec::new()),
        window_size: 3, stride: 1, target_subcarriers: 56, normalize: false });
    let mut all = Vec::new();
    let (fa, la) = mk(5, 0.0);
    pipe.extract_windows(&fa, &la, "mmfi", &mut all);
    assert_eq!(all.len(), 3); // 5 frames, window 3, stride 1 -> 3 windows
    let (fb, lb) = mk(4, 100.0);
    pipe.extract_windows(&fb, &lb, "wipose", &mut all);
    assert_eq!(all.len(), 5); // +2 windows from the 4-frame set
    assert_eq!(all[0].source, "mmfi");
    assert_eq!(all[3].source, "wipose");
    // Amplitude bases (0.0 vs 100.0) distinguish the two sources.
    assert!(all[0].csi_window[0][0] < 10.0);
    assert!(all[4].csi_window[0][0] > 90.0);
}
|
||||
}
|
||||
@@ -0,0 +1,589 @@
|
||||
//! Graph Transformer + GNN for WiFi CSI-to-Pose estimation (ADR-023 Phase 2).
|
||||
//!
|
||||
//! Cross-attention bottleneck between antenna-space CSI features and COCO 17-keypoint
|
||||
//! body graph, followed by GCN message passing. All math is pure `std`.
|
||||
|
||||
/// Minimal xorshift64 PRNG used for reproducible weight initialization.
///
/// Not cryptographically secure; determinism (same seed -> same weights) is
/// the only property relied upon here.
#[derive(Debug, Clone)]
struct Rng64 {
    state: u64,
}

impl Rng64 {
    /// Create a generator. A zero seed would lock xorshift at zero forever,
    /// so it is replaced with a fixed non-zero constant.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0xDEAD_BEEF_CAFE_1234,
            s => s,
        };
        Self { state }
    }

    /// Advance the xorshift64 state and return the next raw 64-bit value.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Pseudo-uniform f32 in roughly [-1, 1]: the top 53 bits map to [0, 1)
    /// and are rescaled (f32 rounding can reach the endpoints).
    fn next_f32(&mut self) -> f32 {
        let unit = (self.next_u64() >> 11) as f32 / (1u64 << 53) as f32;
        unit * 2.0 - 1.0
    }
}
|
||||
|
||||
/// Rectified linear unit: passes positive inputs through, clamps the rest to 0.
#[inline]
fn relu(x: f32) -> f32 {
    match x > 0.0 {
        true => x,
        false => 0.0,
    }
}
|
||||
|
||||
/// Logistic sigmoid, branched for numerical stability: for negative inputs
/// use e^x / (1 + e^x) so (-x).exp() never overflows; otherwise the usual
/// 1 / (1 + e^-x).
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x < 0.0 {
        let ex = x.exp();
        ex / (1.0 + ex)
    } else {
        1.0 / (1.0 + (-x).exp())
    }
}
|
||||
|
||||
/// Numerically stable softmax over `scores`, written into `out`.
///
/// The maximum score is subtracted before exponentiation so large inputs do
/// not overflow. If the exponential mass is degenerate (e.g. every score is
/// -inf, which makes `s - max` NaN), `out` is filled with zeros instead of
/// propagating NaN.
fn softmax(scores: &[f32], out: &mut [f32]) {
    debug_assert_eq!(scores.len(), out.len());
    if scores.is_empty() { return; }
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for (o, &s) in out.iter_mut().zip(scores) {
        let e = (s - max).exp();
        *o = e;
        sum += e;
    }
    if sum > 1e-10 {
        // Normal path: normalise by the total exponential mass.
        let inv = 1.0 / sum;
        for o in out.iter_mut() { *o *= inv; }
    } else {
        // Fix: the previous `*o *= 0.0` fallback left NaN in place when the
        // exponentials themselves were NaN (NaN * 0.0 == NaN). Write zeros
        // explicitly so degenerate inputs yield a well-defined result.
        for o in out.iter_mut() { *o = 0.0; }
    }
}
|
||||
|
||||
// ── Linear layer ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Dense linear transformation y = Wx + b (row-major weights).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Linear {
|
||||
in_features: usize,
|
||||
out_features: usize,
|
||||
weights: Vec<Vec<f32>>,
|
||||
bias: Vec<f32>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
/// Xavier/Glorot uniform init with default seed.
|
||||
pub fn new(in_features: usize, out_features: usize) -> Self {
|
||||
Self::with_seed(in_features, out_features, 42)
|
||||
}
|
||||
/// Xavier/Glorot uniform init with explicit seed.
|
||||
pub fn with_seed(in_features: usize, out_features: usize, seed: u64) -> Self {
|
||||
let mut rng = Rng64::new(seed);
|
||||
let limit = (6.0 / (in_features + out_features) as f32).sqrt();
|
||||
let weights = (0..out_features)
|
||||
.map(|_| (0..in_features).map(|_| rng.next_f32() * limit).collect())
|
||||
.collect();
|
||||
Self { in_features, out_features, weights, bias: vec![0.0; out_features] }
|
||||
}
|
||||
/// All-zero weights (for testing).
|
||||
pub fn zeros(in_features: usize, out_features: usize) -> Self {
|
||||
Self {
|
||||
in_features, out_features,
|
||||
weights: vec![vec![0.0; in_features]; out_features],
|
||||
bias: vec![0.0; out_features],
|
||||
}
|
||||
}
|
||||
/// Forward pass: y = Wx + b.
|
||||
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
assert_eq!(input.len(), self.in_features,
|
||||
"Linear input mismatch: expected {}, got {}", self.in_features, input.len());
|
||||
let mut out = vec![0.0f32; self.out_features];
|
||||
for (i, row) in self.weights.iter().enumerate() {
|
||||
let mut s = self.bias[i];
|
||||
for (w, x) in row.iter().zip(input) { s += w * x; }
|
||||
out[i] = s;
|
||||
}
|
||||
out
|
||||
}
|
||||
pub fn weights(&self) -> &[Vec<f32>] { &self.weights }
|
||||
pub fn set_weights(&mut self, w: Vec<Vec<f32>>) {
|
||||
assert_eq!(w.len(), self.out_features);
|
||||
for row in &w { assert_eq!(row.len(), self.in_features); }
|
||||
self.weights = w;
|
||||
}
|
||||
pub fn set_bias(&mut self, b: Vec<f32>) {
|
||||
assert_eq!(b.len(), self.out_features);
|
||||
self.bias = b;
|
||||
}
|
||||
}
|
||||
|
||||
// ── AntennaGraph ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Graph over TX-RX antenna pairs: one node per pair (`pair_id = tx * n_rx + rx`),
/// with an edge whenever two pairs share a TX antenna or an RX antenna.
/// Self-loops are set on the diagonal.
#[derive(Debug, Clone)]
pub struct AntennaGraph {
    n_tx: usize,
    n_rx: usize,
    n_pairs: usize,
    adjacency: Vec<Vec<f32>>,
}

impl AntennaGraph {
    /// Build the dense adjacency matrix for an `n_tx` x `n_rx` antenna array.
    pub fn new(n_tx: usize, n_rx: usize) -> Self {
        let n_pairs = n_tx * n_rx;
        let mut adjacency = vec![vec![0.0f32; n_pairs]; n_pairs];
        for a in 0..n_pairs {
            adjacency[a][a] = 1.0; // self-loop
            let (tx_a, rx_a) = (a / n_rx, a % n_rx);
            // Only the upper triangle is scanned; edges are mirrored.
            for b in (a + 1)..n_pairs {
                let shares_antenna = tx_a == b / n_rx || rx_a == b % n_rx;
                if shares_antenna {
                    adjacency[a][b] = 1.0;
                    adjacency[b][a] = 1.0;
                }
            }
        }
        Self { n_tx, n_rx, n_pairs, adjacency }
    }

    /// Number of graph nodes (= TX-RX pairs).
    pub fn n_nodes(&self) -> usize { self.n_pairs }
    /// Dense adjacency matrix, 1.0 where two pairs are connected.
    pub fn adjacency_matrix(&self) -> &Vec<Vec<f32>> { &self.adjacency }
    /// Number of transmit antennas.
    pub fn n_tx(&self) -> usize { self.n_tx }
    /// Number of receive antennas.
    pub fn n_rx(&self) -> usize { self.n_rx }
}
|
||||
|
||||
// ── BodyGraph ────────────────────────────────────────────────────────────
|
||||
|
||||
/// COCO 17-keypoint skeleton graph: 16 anatomical edges plus self-loops.
///
/// Keypoint indices: 0=nose 1=l_eye 2=r_eye 3=l_ear 4=r_ear 5=l_shoulder
/// 6=r_shoulder 7=l_elbow 8=r_elbow 9=l_wrist 10=r_wrist 11=l_hip 12=r_hip
/// 13=l_knee 14=r_knee 15=l_ankle 16=r_ankle
#[derive(Debug, Clone)]
pub struct BodyGraph {
    adjacency: [[f32; 17]; 17],
    edges: Vec<(usize, usize)>,
}

/// Human-readable names for the 17 COCO keypoints, index-aligned with the graph.
pub const COCO_KEYPOINT_NAMES: [&str; 17] = [
    "nose","left_eye","right_eye","left_ear","right_ear",
    "left_shoulder","right_shoulder","left_elbow","right_elbow",
    "left_wrist","right_wrist","left_hip","right_hip",
    "left_knee","right_knee","left_ankle","right_ankle",
];

/// The 16 undirected skeleton edges as COCO keypoint index pairs.
const COCO_EDGES: [(usize, usize); 16] = [
    (0,1),(0,2),(1,3),(2,4),(5,6),(5,7),(7,9),(6,8),
    (8,10),(5,11),(6,12),(11,12),(11,13),(13,15),(12,14),(14,16),
];

impl BodyGraph {
    /// Build the skeleton adjacency: identity (self-loops) plus the 16
    /// symmetric anatomical edges.
    pub fn new() -> Self {
        let mut adjacency = [[0.0f32; 17]; 17];
        for (i, row) in adjacency.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        for &(u, v) in COCO_EDGES.iter() {
            adjacency[u][v] = 1.0;
            adjacency[v][u] = 1.0;
        }
        Self { adjacency, edges: COCO_EDGES.to_vec() }
    }

    /// Dense 17x17 adjacency matrix (self-loops included).
    pub fn adjacency_matrix(&self) -> &[[f32; 17]; 17] { &self.adjacency }
    /// The 16 skeleton edges as (keypoint, keypoint) index pairs.
    pub fn edge_list(&self) -> &Vec<(usize, usize)> { &self.edges }
    /// Node count — always 17 for COCO.
    pub fn n_nodes(&self) -> usize { 17 }
    /// Edge count (self-loops excluded).
    pub fn n_edges(&self) -> usize { self.edges.len() }

    /// Per-node degree, counting the self-loop.
    pub fn degrees(&self) -> [f32; 17] {
        let mut deg = [0.0f32; 17];
        for (d, row) in deg.iter_mut().zip(self.adjacency.iter()) {
            *d = row.iter().sum();
        }
        deg
    }

    /// Symmetric GCN normalization D^{-1/2} A D^{-1/2}.
    pub fn normalized_adjacency(&self) -> [[f32; 17]; 17] {
        let deg = self.degrees();
        let inv_sqrt: Vec<f32> = deg
            .iter()
            .map(|&d| if d > 0.0 { 1.0 / d.sqrt() } else { 0.0 })
            .collect();
        let mut norm = [[0.0f32; 17]; 17];
        for i in 0..17 {
            for j in 0..17 {
                norm[i][j] = inv_sqrt[i] * self.adjacency[i][j] * inv_sqrt[j];
            }
        }
        norm
    }
}

impl Default for BodyGraph {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
// ── CrossAttention ───────────────────────────────────────────────────────
|
||||
|
||||
/// Multi-head scaled dot-product cross-attention.
/// Attn(Q,K,V) = softmax(QK^T / sqrt(d_k)) V, split into n_heads.
#[derive(Debug, Clone)]
pub struct CrossAttention {
    // d_k = d_model / n_heads: per-head feature width.
    d_model: usize, n_heads: usize, d_k: usize,
    // Q/K/V projections plus the output projection, each d_model x d_model.
    w_q: Linear, w_k: Linear, w_v: Linear, w_o: Linear,
}

impl CrossAttention {
    /// Build an attention block; panics if `d_model` is not divisible by
    /// `n_heads`. Each projection gets a distinct fixed seed so construction
    /// is fully deterministic.
    pub fn new(d_model: usize, n_heads: usize) -> Self {
        assert!(d_model % n_heads == 0,
            "d_model ({d_model}) must be divisible by n_heads ({n_heads})");
        let d_k = d_model / n_heads;
        let s = 123u64;
        Self { d_model, n_heads, d_k,
            w_q: Linear::with_seed(d_model, d_model, s),
            w_k: Linear::with_seed(d_model, d_model, s+1),
            w_v: Linear::with_seed(d_model, d_model, s+2),
            w_o: Linear::with_seed(d_model, d_model, s+3),
        }
    }
    /// query [n_q, d_model], key/value [n_kv, d_model] -> [n_q, d_model].
    ///
    /// Empty `query` or `key` short-circuits to an all-zero output with n_q rows.
    /// NOTE(review): assumes `value.len() == key.len()`; only `key.len()` is
    /// checked, so a shorter `value` would panic on indexing — confirm callers.
    pub fn forward(&self, query: &[Vec<f32>], key: &[Vec<f32>], value: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let (n_q, n_kv) = (query.len(), key.len());
        if n_q == 0 || n_kv == 0 { return vec![vec![0.0; self.d_model]; n_q]; }

        // Project every row through Q, K and V.
        let q_proj: Vec<Vec<f32>> = query.iter().map(|q| self.w_q.forward(q)).collect();
        let k_proj: Vec<Vec<f32>> = key.iter().map(|k| self.w_k.forward(k)).collect();
        let v_proj: Vec<Vec<f32>> = value.iter().map(|v| self.w_v.forward(v)).collect();

        // sqrt(d_k) scaling keeps dot products in a softmax-friendly range.
        let scale = (self.d_k as f32).sqrt();
        let mut output = vec![vec![0.0f32; self.d_model]; n_q];

        for qi in 0..n_q {
            // Heads operate on disjoint d_k-wide slices; their outputs are
            // concatenated back to d_model before the final projection.
            let mut concat = Vec::with_capacity(self.d_model);
            for h in 0..self.n_heads {
                let (start, end) = (h * self.d_k, (h + 1) * self.d_k);
                let q_h = &q_proj[qi][start..end];
                // Scaled attention scores for this query against every key.
                let mut scores = vec![0.0f32; n_kv];
                for ki in 0..n_kv {
                    let dot: f32 = q_h.iter().zip(&k_proj[ki][start..end]).map(|(a,b)| a*b).sum();
                    scores[ki] = dot / scale;
                }
                let mut wts = vec![0.0f32; n_kv];
                softmax(&scores, &mut wts);
                // Weighted sum of value slices under the attention weights.
                let mut head_out = vec![0.0f32; self.d_k];
                for ki in 0..n_kv {
                    for (o, &v) in head_out.iter_mut().zip(&v_proj[ki][start..end]) {
                        *o += wts[ki] * v;
                    }
                }
                concat.extend_from_slice(&head_out);
            }
            // Final mixing projection over the concatenated heads.
            output[qi] = self.w_o.forward(&concat);
        }
        output
    }
    /// Model (embedding) width.
    pub fn d_model(&self) -> usize { self.d_model }
    /// Number of attention heads.
    pub fn n_heads(&self) -> usize { self.n_heads }
}
|
||||
|
||||
// ── GraphMessagePassing ──────────────────────────────────────────────────
|
||||
|
||||
/// GCN layer: H' = ReLU(A_norm H W) where A_norm = D^{-1/2} A D^{-1/2}.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GraphMessagePassing {
|
||||
in_features: usize, out_features: usize,
|
||||
weight: Linear, norm_adj: [[f32; 17]; 17],
|
||||
}
|
||||
|
||||
impl GraphMessagePassing {
|
||||
pub fn new(in_features: usize, out_features: usize, graph: &BodyGraph) -> Self {
|
||||
Self { in_features, out_features,
|
||||
weight: Linear::with_seed(in_features, out_features, 777),
|
||||
norm_adj: graph.normalized_adjacency() }
|
||||
}
|
||||
/// node_features [17, in_features] -> [17, out_features].
|
||||
pub fn forward(&self, node_features: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
assert_eq!(node_features.len(), 17, "expected 17 nodes, got {}", node_features.len());
|
||||
let mut agg = vec![vec![0.0f32; self.in_features]; 17];
|
||||
for i in 0..17 { for j in 0..17 {
|
||||
let a = self.norm_adj[i][j];
|
||||
if a.abs() > 1e-10 {
|
||||
for (ag, &f) in agg[i].iter_mut().zip(&node_features[j]) { *ag += a * f; }
|
||||
}
|
||||
}}
|
||||
agg.iter().map(|a| self.weight.forward(a).into_iter().map(relu).collect()).collect()
|
||||
}
|
||||
pub fn in_features(&self) -> usize { self.in_features }
|
||||
pub fn out_features(&self) -> usize { self.out_features }
|
||||
}
|
||||
|
||||
/// Stack of GCN layers.
|
||||
#[derive(Debug, Clone)]
|
||||
struct GnnStack { layers: Vec<GraphMessagePassing> }
|
||||
|
||||
impl GnnStack {
|
||||
fn new(in_f: usize, out_f: usize, n: usize, g: &BodyGraph) -> Self {
|
||||
assert!(n >= 1);
|
||||
let mut layers = vec![GraphMessagePassing::new(in_f, out_f, g)];
|
||||
for _ in 1..n { layers.push(GraphMessagePassing::new(out_f, out_f, g)); }
|
||||
Self { layers }
|
||||
}
|
||||
fn forward(&self, feats: &[Vec<f32>]) -> Vec<Vec<f32>> {
|
||||
let mut h = feats.to_vec();
|
||||
for l in &self.layers { h = l.forward(&h); }
|
||||
h
|
||||
}
|
||||
}
|
||||
|
||||
// ── Transformer config / output / pipeline ───────────────────────────────
|
||||
|
||||
/// Hyperparameters for the CSI-to-Pose transformer.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    /// Number of CSI subcarriers per antenna-pair feature vector.
    pub n_subcarriers: usize,
    /// Number of body keypoints to predict (17 for COCO).
    pub n_keypoints: usize,
    /// Shared embedding width used throughout the model.
    pub d_model: usize,
    /// Attention heads; must divide `d_model`.
    pub n_heads: usize,
    /// Number of stacked GCN layers.
    pub n_gnn_layers: usize,
}

impl Default for TransformerConfig {
    /// 56 subcarriers, COCO-17 keypoints, 64-wide model, 4 heads, 2 GCN layers.
    fn default() -> Self {
        Self {
            n_subcarriers: 56,
            n_keypoints: 17,
            d_model: 64,
            n_heads: 4,
            n_gnn_layers: 2,
        }
    }
}
|
||||
|
||||
/// Output of the CSI-to-Pose transformer for a single forward pass.
/// All three vectors are index-aligned per keypoint.
#[derive(Debug, Clone)]
pub struct PoseOutput {
    /// Predicted (x, y, z) per keypoint.
    pub keypoints: Vec<(f32, f32, f32)>,
    /// Per-keypoint confidence in [0, 1] (sigmoid of the confidence head).
    pub confidences: Vec<f32>,
    /// Per-keypoint GNN features for downstream use.
    pub body_part_features: Vec<Vec<f32>>,
}
|
||||
|
||||
/// Full CSI-to-Pose pipeline: CSI embed -> cross-attention -> GNN -> regression heads.
#[derive(Debug, Clone)]
pub struct CsiToPoseTransformer {
    config: TransformerConfig,
    // Per-antenna-pair embedding: n_subcarriers -> d_model.
    csi_embed: Linear,
    // One fixed (seeded-random) query vector per keypoint.
    keypoint_queries: Vec<Vec<f32>>,
    cross_attn: CrossAttention,
    gnn: GnnStack,
    // Regression heads over GNN node features: (x, y, z) and raw confidence.
    xyz_head: Linear,
    conf_head: Linear,
}

impl CsiToPoseTransformer {
    /// Build the model from `config`. All sub-modules use fixed seeds
    /// (500/600/700/999 plus the attention seeds), so construction is
    /// fully deterministic.
    ///
    /// NOTE(review): the GNN layer asserts exactly 17 nodes in its forward
    /// pass, so a config with `n_keypoints != 17` will panic at inference —
    /// confirm whether non-COCO keypoint counts are ever intended.
    pub fn new(config: TransformerConfig) -> Self {
        let d = config.d_model;
        let bg = BodyGraph::new();
        let mut rng = Rng64::new(999);
        // Glorot-style bound over (n_keypoints + d_model) for the queries.
        let limit = (6.0 / (config.n_keypoints + d) as f32).sqrt();
        let kq: Vec<Vec<f32>> = (0..config.n_keypoints)
            .map(|_| (0..d).map(|_| rng.next_f32() * limit).collect()).collect();
        Self {
            csi_embed: Linear::with_seed(config.n_subcarriers, d, 500),
            keypoint_queries: kq,
            cross_attn: CrossAttention::new(d, config.n_heads),
            gnn: GnnStack::new(d, d, config.n_gnn_layers, &bg),
            xyz_head: Linear::with_seed(d, 3, 600),
            conf_head: Linear::with_seed(d, 1, 700),
            config,
        }
    }
    /// csi_features [n_antenna_pairs, n_subcarriers] -> PoseOutput with 17 keypoints.
    pub fn forward(&self, csi_features: &[Vec<f32>]) -> PoseOutput {
        // 1. Embed each antenna pair's subcarrier vector into d_model space.
        let embedded: Vec<Vec<f32>> = csi_features.iter()
            .map(|f| self.csi_embed.forward(f)).collect();
        // 2. Keypoint queries attend over the embedded CSI (keys == values).
        let attended = self.cross_attn.forward(&self.keypoint_queries, &embedded, &embedded);
        // 3. Message passing along the body skeleton refines keypoint features.
        let gnn_out = self.gnn.forward(&attended);
        // 4. Per-keypoint regression: xyz position + sigmoid-squashed confidence.
        let mut kps = Vec::with_capacity(self.config.n_keypoints);
        let mut confs = Vec::with_capacity(self.config.n_keypoints);
        for nf in &gnn_out {
            let xyz = self.xyz_head.forward(nf);
            kps.push((xyz[0], xyz[1], xyz[2]));
            confs.push(sigmoid(self.conf_head.forward(nf)[0]));
        }
        PoseOutput { keypoints: kps, confidences: confs, body_part_features: gnn_out }
    }
    /// The configuration this model was built with.
    pub fn config(&self) -> &TransformerConfig { &self.config }
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
//! Unit tests for the graph-transformer building blocks: graph topology,
//! attention shapes, GCN propagation, activations, and determinism.
#[cfg(test)]
mod tests {
    use super::*;

    // ── BodyGraph topology ───────────────────────────────────────────────

    #[test]
    fn body_graph_has_17_nodes() {
        assert_eq!(BodyGraph::new().n_nodes(), 17);
    }

    #[test]
    fn body_graph_has_16_edges() {
        let g = BodyGraph::new();
        assert_eq!(g.n_edges(), 16);
        assert_eq!(g.edge_list().len(), 16);
    }

    #[test]
    fn body_graph_adjacency_symmetric() {
        let bg = BodyGraph::new();
        let adj = bg.adjacency_matrix();
        for i in 0..17 { for j in 0..17 {
            assert_eq!(adj[i][j], adj[j][i], "asymmetric at ({i},{j})");
        }}
    }

    #[test]
    fn body_graph_self_loops_and_specific_edges() {
        let bg = BodyGraph::new();
        let adj = bg.adjacency_matrix();
        for i in 0..17 { assert_eq!(adj[i][i], 1.0); }
        assert_eq!(adj[0][1], 1.0); // nose-left_eye
        assert_eq!(adj[5][6], 1.0); // l_shoulder-r_shoulder
        assert_eq!(adj[14][16], 1.0); // r_knee-r_ankle
        assert_eq!(adj[0][15], 0.0); // nose should NOT connect to l_ankle
    }

    // ── AntennaGraph topology ────────────────────────────────────────────

    #[test]
    fn antenna_graph_node_count() {
        assert_eq!(AntennaGraph::new(3, 3).n_nodes(), 9);
    }

    #[test]
    fn antenna_graph_adjacency() {
        // 2x2 array: pair ids are 0=(0,0) 1=(0,1) 2=(1,0) 3=(1,1).
        let ag = AntennaGraph::new(2, 2);
        let adj = ag.adjacency_matrix();
        assert_eq!(adj[0][1], 1.0); // share tx=0
        assert_eq!(adj[0][2], 1.0); // share rx=0
        assert_eq!(adj[0][3], 0.0); // share neither
    }

    // ── Attention ────────────────────────────────────────────────────────

    #[test]
    fn cross_attention_output_shape() {
        // 5 queries over 3 key/value rows must give 5 d_model-wide outputs.
        let ca = CrossAttention::new(16, 4);
        let out = ca.forward(&vec![vec![0.5; 16]; 5], &vec![vec![0.3; 16]; 3], &vec![vec![0.7; 16]; 3]);
        assert_eq!(out.len(), 5);
        for r in &out { assert_eq!(r.len(), 16); }
    }

    #[test]
    fn cross_attention_single_head_vs_multi() {
        // Head count changes the split, not the output dimensions.
        let (q, k, v) = (vec![vec![1.0f32; 8]; 2], vec![vec![0.5; 8]; 3], vec![vec![0.5; 8]; 3]);
        let o1 = CrossAttention::new(8, 1).forward(&q, &k, &v);
        let o2 = CrossAttention::new(8, 2).forward(&q, &k, &v);
        assert_eq!(o1.len(), o2.len());
        assert_eq!(o1[0].len(), o2[0].len());
    }

    #[test]
    fn scaled_dot_product_softmax_sums_to_one() {
        let scores = vec![1.0f32, 2.0, 3.0, 0.5];
        let mut w = vec![0.0f32; 4];
        softmax(&scores, &mut w);
        assert!((w.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        for &wi in &w { assert!(wi > 0.0); }
        // Largest score must receive the largest weight.
        assert!(w[2] > w[0] && w[2] > w[1] && w[2] > w[3]);
    }

    // ── GNN message passing ──────────────────────────────────────────────

    #[test]
    fn gnn_message_passing_shape() {
        let g = BodyGraph::new();
        let out = GraphMessagePassing::new(32, 16, &g).forward(&vec![vec![1.0; 32]; 17]);
        assert_eq!(out.len(), 17);
        for r in &out { assert_eq!(r.len(), 16); }
    }

    #[test]
    fn gnn_preserves_isolated_node() {
        // Signal injected only at the nose should remain strongest there after
        // one hop; the ankle is several edges away and gets none directly.
        let g = BodyGraph::new();
        let gmp = GraphMessagePassing::new(8, 8, &g);
        let mut feats: Vec<Vec<f32>> = vec![vec![0.0; 8]; 17];
        feats[0] = vec![1.0; 8]; // only nose has signal
        let out = gmp.forward(&feats);
        let ankle_e: f32 = out[15].iter().map(|x| x*x).sum();
        let nose_e: f32 = out[0].iter().map(|x| x*x).sum();
        assert!(nose_e > ankle_e, "nose ({nose_e}) should > ankle ({ankle_e})");
    }

    // ── Linear layer ─────────────────────────────────────────────────────

    #[test]
    fn linear_layer_output_size() {
        assert_eq!(Linear::new(10, 5).forward(&vec![1.0; 10]).len(), 5);
    }

    #[test]
    fn linear_layer_zero_weights() {
        let out = Linear::zeros(4, 3).forward(&[1.0, 2.0, 3.0, 4.0]);
        for &v in &out { assert_eq!(v, 0.0); }
    }

    #[test]
    fn linear_layer_set_weights_identity() {
        // Identity weights with zero bias must pass the input through.
        let mut lin = Linear::zeros(2, 2);
        lin.set_weights(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let out = lin.forward(&[3.0, 7.0]);
        assert!((out[0] - 3.0).abs() < 1e-6 && (out[1] - 7.0).abs() < 1e-6);
    }

    // ── Full model ───────────────────────────────────────────────────────

    #[test]
    fn transformer_config_defaults() {
        let c = TransformerConfig::default();
        assert_eq!((c.n_subcarriers, c.n_keypoints, c.d_model, c.n_heads, c.n_gnn_layers),
            (56, 17, 64, 4, 2));
    }

    #[test]
    fn transformer_forward_output_17_keypoints() {
        let t = CsiToPoseTransformer::new(TransformerConfig {
            n_subcarriers: 16, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 1,
        });
        let out = t.forward(&vec![vec![0.5; 16]; 4]);
        assert_eq!(out.keypoints.len(), 17);
        assert_eq!(out.confidences.len(), 17);
        assert_eq!(out.body_part_features.len(), 17);
    }

    #[test]
    fn transformer_keypoints_are_finite() {
        let t = CsiToPoseTransformer::new(TransformerConfig {
            n_subcarriers: 8, n_keypoints: 17, d_model: 8, n_heads: 2, n_gnn_layers: 2,
        });
        let out = t.forward(&vec![vec![1.0; 8]; 6]);
        for (i, &(x, y, z)) in out.keypoints.iter().enumerate() {
            assert!(x.is_finite() && y.is_finite() && z.is_finite(), "kp {i} not finite");
        }
        // Confidences come through a sigmoid, so they must sit in [0, 1].
        for (i, &c) in out.confidences.iter().enumerate() {
            assert!(c.is_finite() && (0.0..=1.0).contains(&c), "conf {i} invalid: {c}");
        }
    }

    // ── Activations and determinism ──────────────────────────────────────

    #[test]
    fn relu_activation() {
        assert_eq!(relu(-5.0), 0.0);
        assert_eq!(relu(-0.001), 0.0);
        assert_eq!(relu(0.0), 0.0);
        assert_eq!(relu(3.14), 3.14);
        assert_eq!(relu(100.0), 100.0);
    }

    #[test]
    fn sigmoid_bounds() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(100.0) > 0.999);
        assert!(sigmoid(-100.0) < 0.001);
    }

    #[test]
    fn deterministic_rng_and_linear() {
        // Same seed -> identical RNG stream and identical layer outputs.
        let (mut r1, mut r2) = (Rng64::new(42), Rng64::new(42));
        for _ in 0..100 { assert_eq!(r1.next_u64(), r2.next_u64()); }
        let inp = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(Linear::with_seed(4, 3, 99).forward(&inp),
            Linear::with_seed(4, 3, 99).forward(&inp));
    }

    #[test]
    fn body_graph_normalized_adjacency_finite() {
        // Every node has at least its self-loop, so each row must carry mass.
        let norm = BodyGraph::new().normalized_adjacency();
        for i in 0..17 {
            let s: f32 = norm[i].iter().sum();
            assert!(s.is_finite() && s > 0.0, "row {i} sum={s}");
        }
    }

    // ── Edge cases ───────────────────────────────────────────────────────

    #[test]
    fn cross_attention_empty_keys() {
        // No keys -> attention short-circuits to an all-zero output.
        let out = CrossAttention::new(8, 2).forward(
            &vec![vec![1.0; 8]; 3], &vec![], &vec![]);
        assert_eq!(out.len(), 3);
        for r in &out { for &v in r { assert_eq!(v, 0.0); } }
    }

    #[test]
    fn softmax_edge_cases() {
        // Single score -> weight 1.
        let mut w1 = vec![0.0f32; 1];
        softmax(&[42.0], &mut w1);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        // Very large scores must not overflow thanks to max-subtraction.
        let mut w3 = vec![0.0f32; 3];
        softmax(&[1000.0, 1001.0, 999.0], &mut w3);
        let sum: f32 = w3.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        for &wi in &w3 { assert!(wi.is_finite()); }
    }
}
|
||||
@@ -6,3 +6,9 @@
|
||||
|
||||
pub mod vital_signs;
|
||||
pub mod rvf_container;
|
||||
pub mod rvf_pipeline;
|
||||
pub mod graph_transformer;
|
||||
pub mod trainer;
|
||||
pub mod dataset;
|
||||
pub mod sona;
|
||||
pub mod sparse_inference;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
//! Replaces both ws_server.py and the Python HTTP server.
|
||||
|
||||
mod rvf_container;
|
||||
mod rvf_pipeline;
|
||||
mod vital_signs;
|
||||
|
||||
use std::collections::VecDeque;
|
||||
@@ -23,7 +24,7 @@ use axum::{
|
||||
State,
|
||||
},
|
||||
response::{Html, IntoResponse, Json},
|
||||
routing::get,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use clap::Parser;
|
||||
@@ -37,8 +38,15 @@ use axum::http::HeaderValue;
|
||||
use tracing::{info, warn, debug, error};
|
||||
|
||||
use rvf_container::{RvfBuilder, RvfContainerInfo, RvfReader, VitalSignConfig};
|
||||
use rvf_pipeline::ProgressiveLoader;
|
||||
use vital_signs::{VitalSignDetector, VitalSigns};
|
||||
|
||||
// ADR-022 Phase 3: Multi-BSSID pipeline integration
|
||||
use wifi_densepose_wifiscan::{
|
||||
BssidRegistry, WindowsWifiPipeline,
|
||||
};
|
||||
use wifi_densepose_wifiscan::parse_netsh_output as parse_netsh_bssid_output;
|
||||
|
||||
// ── CLI ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -79,6 +87,14 @@ struct Args {
|
||||
/// Save current model state as an RVF container on shutdown
|
||||
#[arg(long, value_name = "PATH")]
|
||||
save_rvf: Option<PathBuf>,
|
||||
|
||||
/// Load a trained .rvf model for inference
|
||||
#[arg(long, value_name = "PATH")]
|
||||
model: Option<PathBuf>,
|
||||
|
||||
/// Enable progressive loading (Layer A instant start)
|
||||
#[arg(long)]
|
||||
progressive: bool,
|
||||
}
|
||||
|
||||
// ── Data types ───────────────────────────────────────────────────────────────
|
||||
@@ -114,6 +130,32 @@ struct SensingUpdate {
|
||||
/// Vital sign estimates (breathing rate, heart rate, confidence).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
vital_signs: Option<VitalSigns>,
|
||||
// ── ADR-022 Phase 3: Enhanced multi-BSSID pipeline fields ──
|
||||
/// Enhanced motion estimate from multi-BSSID pipeline.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
enhanced_motion: Option<serde_json::Value>,
|
||||
/// Enhanced breathing estimate from multi-BSSID pipeline.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
enhanced_breathing: Option<serde_json::Value>,
|
||||
/// Posture classification from BSSID fingerprint matching.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
posture: Option<String>,
|
||||
/// Signal quality score from multi-BSSID quality gate [0.0, 1.0].
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
signal_quality_score: Option<f64>,
|
||||
/// Quality gate verdict: "Permit", "Warn", or "Deny".
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
quality_verdict: Option<String>,
|
||||
/// Number of BSSIDs used in the enhanced sensing cycle.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
bssid_count: Option<usize>,
|
||||
// ── ADR-023 Phase 7-8: Model inference fields ──
|
||||
/// Pose keypoints when a trained model is loaded (x, y, z, confidence).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pose_keypoints: Option<Vec<[f64; 4]>>,
|
||||
/// Model status when a trained model is loaded.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
model_status: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -194,6 +236,12 @@ struct AppStateInner {
|
||||
rvf_info: Option<RvfContainerInfo>,
|
||||
/// Path to save RVF container on shutdown (set via `--save-rvf`).
|
||||
save_rvf_path: Option<PathBuf>,
|
||||
/// Progressive loader for a trained model (set via `--model`).
|
||||
progressive_loader: Option<ProgressiveLoader>,
|
||||
/// Active SONA profile name.
|
||||
active_sona_profile: Option<String>,
|
||||
/// Whether a trained model is loaded.
|
||||
model_loaded: bool,
|
||||
}
|
||||
|
||||
type SharedState = Arc<RwLock<AppStateInner>>;
|
||||
@@ -376,7 +424,7 @@ fn extract_features_from_frame(frame: &Esp32Frame) -> (FeatureInfo, Classificati
|
||||
// ── Windows WiFi RSSI collector ──────────────────────────────────────────────
|
||||
|
||||
/// Parse `netsh wlan show interfaces` output for RSSI and signal quality
|
||||
fn parse_netsh_output(output: &str) -> Option<(f64, f64, String)> {
|
||||
fn parse_netsh_interfaces_output(output: &str) -> Option<(f64, f64, String)> {
|
||||
let mut rssi = None;
|
||||
let mut signal = None;
|
||||
let mut ssid = None;
|
||||
@@ -411,52 +459,126 @@ fn parse_netsh_output(output: &str) -> Option<(f64, f64, String)> {
|
||||
async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
|
||||
let mut interval = tokio::time::interval(Duration::from_millis(tick_ms));
|
||||
let mut seq: u32 = 0;
|
||||
info!("Windows WiFi RSSI collector active (tick={}ms)", tick_ms);
|
||||
|
||||
// ADR-022 Phase 3: Multi-BSSID pipeline state (kept across ticks)
|
||||
let mut registry = BssidRegistry::new(32, 30);
|
||||
let mut pipeline = WindowsWifiPipeline::new();
|
||||
|
||||
info!(
|
||||
"Windows WiFi multi-BSSID pipeline active (tick={}ms, max_bssids=32)",
|
||||
tick_ms
|
||||
);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
seq += 1;
|
||||
|
||||
// Run netsh to get WiFi info
|
||||
let output = match tokio::process::Command::new("netsh")
|
||||
.args(["wlan", "show", "interfaces"])
|
||||
.output()
|
||||
.await
|
||||
{
|
||||
Ok(o) => String::from_utf8_lossy(&o.stdout).to_string(),
|
||||
Err(e) => {
|
||||
warn!("netsh failed: {e}");
|
||||
// ── Step 1: Run multi-BSSID scan via spawn_blocking ──────────
|
||||
// NetshBssidScanner is not Send, so we run `netsh` and parse
|
||||
// the output inside a blocking closure.
|
||||
let bssid_scan_result = tokio::task::spawn_blocking(|| {
|
||||
let output = std::process::Command::new("netsh")
|
||||
.args(["wlan", "show", "networks", "mode=bssid"])
|
||||
.output()
|
||||
.map_err(|e| format!("netsh bssid scan failed: {e}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(format!(
|
||||
"netsh exited with {}: {}",
|
||||
output.status,
|
||||
stderr.trim()
|
||||
));
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
parse_netsh_bssid_output(&stdout).map_err(|e| format!("parse error: {e}"))
|
||||
})
|
||||
.await;
|
||||
|
||||
// Unwrap the JoinHandle result, then the inner Result.
|
||||
let observations = match bssid_scan_result {
|
||||
Ok(Ok(obs)) if !obs.is_empty() => obs,
|
||||
Ok(Ok(_empty)) => {
|
||||
debug!("Multi-BSSID scan returned 0 observations, falling back");
|
||||
windows_wifi_fallback_tick(&state, seq).await;
|
||||
continue;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
warn!("Multi-BSSID scan error: {e}, falling back");
|
||||
windows_wifi_fallback_tick(&state, seq).await;
|
||||
continue;
|
||||
}
|
||||
Err(join_err) => {
|
||||
error!("spawn_blocking panicked: {join_err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (rssi_dbm, signal_pct, ssid) = match parse_netsh_output(&output) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
debug!("No WiFi interface connected");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let obs_count = observations.len();
|
||||
|
||||
// Derive SSID from the first observation for the source label.
|
||||
let ssid = observations
|
||||
.first()
|
||||
.map(|o| o.ssid.clone())
|
||||
.unwrap_or_else(|| "Unknown".into());
|
||||
|
||||
// ── Step 2: Feed observations into registry ──────────────────
|
||||
registry.update(&observations);
|
||||
let multi_ap_frame = registry.to_multi_ap_frame();
|
||||
|
||||
// ── Step 3: Run enhanced pipeline ────────────────────────────
|
||||
let enhanced = pipeline.process(&multi_ap_frame);
|
||||
|
||||
// ── Step 4: Build backward-compatible Esp32Frame ─────────────
|
||||
let first_rssi = observations
|
||||
.first()
|
||||
.map(|o| o.rssi_dbm)
|
||||
.unwrap_or(-80.0);
|
||||
let _first_signal_pct = observations
|
||||
.first()
|
||||
.map(|o| o.signal_pct)
|
||||
.unwrap_or(40.0);
|
||||
|
||||
// Create a pseudo-frame from RSSI (single subcarrier)
|
||||
let frame = Esp32Frame {
|
||||
magic: 0xC511_0001,
|
||||
node_id: 0,
|
||||
n_antennas: 1,
|
||||
n_subcarriers: 1,
|
||||
n_subcarriers: obs_count.min(255) as u8,
|
||||
freq_mhz: 2437,
|
||||
sequence: seq,
|
||||
rssi: rssi_dbm as i8,
|
||||
rssi: first_rssi.clamp(-128.0, 127.0) as i8,
|
||||
noise_floor: -90,
|
||||
amplitudes: vec![signal_pct],
|
||||
phases: vec![0.0],
|
||||
amplitudes: multi_ap_frame.amplitudes.clone(),
|
||||
phases: multi_ap_frame.phases.clone(),
|
||||
};
|
||||
|
||||
let (features, classification) = extract_features_from_frame(&frame);
|
||||
|
||||
// ── Step 5: Build enhanced fields from pipeline result ───────
|
||||
let enhanced_motion = Some(serde_json::json!({
|
||||
"score": enhanced.motion.score,
|
||||
"level": format!("{:?}", enhanced.motion.level),
|
||||
"contributing_bssids": enhanced.motion.contributing_bssids,
|
||||
}));
|
||||
|
||||
let enhanced_breathing = enhanced.breathing.as_ref().map(|b| {
|
||||
serde_json::json!({
|
||||
"rate_bpm": b.rate_bpm,
|
||||
"confidence": b.confidence,
|
||||
"bssid_count": b.bssid_count,
|
||||
})
|
||||
});
|
||||
|
||||
let posture_str = enhanced.posture.map(|p| format!("{p:?}"));
|
||||
let sig_quality_score = Some(enhanced.signal_quality.score);
|
||||
let verdict_str = Some(format!("{:?}", enhanced.verdict));
|
||||
let bssid_n = Some(enhanced.bssid_count);
|
||||
|
||||
// ── Step 6: Update shared state ──────────────────────────────
|
||||
let mut s = state.write().await;
|
||||
s.source = format!("wifi:{ssid}");
|
||||
s.rssi_history.push_back(rssi_dbm);
|
||||
s.rssi_history.push_back(first_rssi);
|
||||
if s.rssi_history.len() > 60 {
|
||||
s.rssi_history.pop_front();
|
||||
}
|
||||
@@ -464,14 +586,15 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
|
||||
s.tick += 1;
|
||||
let tick = s.tick;
|
||||
|
||||
let motion_score = if classification.motion_level == "active" { 0.8 }
|
||||
else if classification.motion_level == "present_still" { 0.3 }
|
||||
else { 0.05 };
|
||||
let motion_score = if classification.motion_level == "active" {
|
||||
0.8
|
||||
} else if classification.motion_level == "present_still" {
|
||||
0.3
|
||||
} else {
|
||||
0.05
|
||||
};
|
||||
|
||||
let vitals = s.vital_detector.process_frame(
|
||||
&frame.amplitudes,
|
||||
&frame.phases,
|
||||
);
|
||||
let vitals = s.vital_detector.process_frame(&frame.amplitudes, &frame.phases);
|
||||
s.latest_vitals = vitals.clone();
|
||||
|
||||
let update = SensingUpdate {
|
||||
@@ -481,24 +604,129 @@ async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
|
||||
tick,
|
||||
nodes: vec![NodeInfo {
|
||||
node_id: 0,
|
||||
rssi_dbm,
|
||||
rssi_dbm: first_rssi,
|
||||
position: [0.0, 0.0, 0.0],
|
||||
amplitude: vec![signal_pct],
|
||||
subcarrier_count: 1,
|
||||
amplitude: multi_ap_frame.amplitudes,
|
||||
subcarrier_count: obs_count,
|
||||
}],
|
||||
features,
|
||||
classification,
|
||||
signal_field: generate_signal_field(rssi_dbm, 1.0, motion_score, tick),
|
||||
signal_field: generate_signal_field(first_rssi, 1.0, motion_score, tick),
|
||||
vital_signs: Some(vitals),
|
||||
enhanced_motion,
|
||||
enhanced_breathing,
|
||||
posture: posture_str,
|
||||
signal_quality_score: sig_quality_score,
|
||||
quality_verdict: verdict_str,
|
||||
bssid_count: bssid_n,
|
||||
pose_keypoints: None,
|
||||
model_status: None,
|
||||
};
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&update) {
|
||||
let _ = s.tx.send(json);
|
||||
}
|
||||
s.latest_update = Some(update);
|
||||
|
||||
debug!(
|
||||
"Multi-BSSID tick #{tick}: {obs_count} BSSIDs, quality={:.2}, verdict={:?}",
|
||||
enhanced.signal_quality.score, enhanced.verdict
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Fallback: single-RSSI collection via `netsh wlan show interfaces`.
///
/// Used when the multi-BSSID scan fails or returns 0 observations.
///
/// Runs one tick: queries the connected interface, synthesises a single-
/// subcarrier pseudo-frame from RSSI/signal quality, updates shared state,
/// and broadcasts a `SensingUpdate` with all enhanced fields set to `None`.
async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) {
    // Query the currently connected interface; give up on this tick on failure.
    let output = match tokio::process::Command::new("netsh")
        .args(["wlan", "show", "interfaces"])
        .output()
        .await
    {
        Ok(o) => String::from_utf8_lossy(&o.stdout).to_string(),
        Err(e) => {
            warn!("netsh interfaces fallback failed: {e}");
            return;
        }
    };

    // No parse result means no connected WiFi interface — nothing to report.
    let (rssi_dbm, signal_pct, ssid) = match parse_netsh_interfaces_output(&output) {
        Some(v) => v,
        None => {
            debug!("Fallback: no WiFi interface connected");
            return;
        }
    };

    // Synthesise a single-antenna, single-subcarrier pseudo-frame so the
    // downstream feature/vitals pipeline can run unchanged.
    let frame = Esp32Frame {
        magic: 0xC511_0001,
        node_id: 0,
        n_antennas: 1,
        n_subcarriers: 1,
        freq_mhz: 2437, // 2437 MHz (2.4 GHz band)
        sequence: seq,
        // NOTE(review): `as i8` silently truncates out-of-range values —
        // confirm rssi_dbm always fits in i8 (dBm values normally do).
        rssi: rssi_dbm as i8,
        noise_floor: -90,
        amplitudes: vec![signal_pct],
        phases: vec![0.0],
    };

    let (features, classification) = extract_features_from_frame(&frame);

    // Single write-lock section: history, tick, vitals, broadcast, snapshot.
    let mut s = state.write().await;
    s.source = format!("wifi:{ssid}");
    s.rssi_history.push_back(rssi_dbm);
    if s.rssi_history.len() > 60 {
        // Bound the RSSI history to the most recent 60 samples.
        s.rssi_history.pop_front();
    }

    s.tick += 1;
    let tick = s.tick;

    // Map the discrete motion classification onto a coarse score in [0, 1].
    let motion_score = if classification.motion_level == "active" {
        0.8
    } else if classification.motion_level == "present_still" {
        0.3
    } else {
        0.05
    };

    let vitals = s.vital_detector.process_frame(&frame.amplitudes, &frame.phases);
    s.latest_vitals = vitals.clone();

    let update = SensingUpdate {
        msg_type: "sensing_update".to_string(),
        // Seconds since epoch with millisecond resolution.
        timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0,
        source: format!("wifi:{ssid}"),
        tick,
        nodes: vec![NodeInfo {
            node_id: 0,
            rssi_dbm,
            position: [0.0, 0.0, 0.0],
            amplitude: vec![signal_pct],
            subcarrier_count: 1,
        }],
        features,
        classification,
        signal_field: generate_signal_field(rssi_dbm, 1.0, motion_score, tick),
        vital_signs: Some(vitals),
        // Enhanced multi-BSSID fields are unavailable in the fallback path.
        enhanced_motion: None,
        enhanced_breathing: None,
        posture: None,
        signal_quality_score: None,
        quality_verdict: None,
        bssid_count: None,
        pose_keypoints: None,
        model_status: None,
    };

    // Publish to the broadcast channel; send errors are deliberately ignored.
    if let Ok(json) = serde_json::to_string(&update) {
        let _ = s.tx.send(json);
    }
    s.latest_update = Some(update);
}
|
||||
|
||||
/// Probe if Windows WiFi is connected
|
||||
async fn probe_windows_wifi() -> bool {
|
||||
match tokio::process::Command::new("netsh")
|
||||
@@ -508,7 +736,7 @@ async fn probe_windows_wifi() -> bool {
|
||||
{
|
||||
Ok(o) => {
|
||||
let out = String::from_utf8_lossy(&o.stdout);
|
||||
parse_netsh_output(&out).is_some()
|
||||
parse_netsh_interfaces_output(&out).is_some()
|
||||
}
|
||||
Err(_) => false,
|
||||
}
|
||||
@@ -932,6 +1160,75 @@ async fn model_info(State(state): State<SharedState>) -> Json<serde_json::Value>
|
||||
}
|
||||
}
|
||||
|
||||
async fn model_layers(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
match &s.progressive_loader {
|
||||
Some(loader) => {
|
||||
let (a, b, c) = loader.layer_status();
|
||||
Json(serde_json::json!({
|
||||
"layer_a": a,
|
||||
"layer_b": b,
|
||||
"layer_c": c,
|
||||
"progress": loader.loading_progress(),
|
||||
}))
|
||||
}
|
||||
None => Json(serde_json::json!({
|
||||
"layer_a": false,
|
||||
"layer_b": false,
|
||||
"layer_c": false,
|
||||
"progress": 0.0,
|
||||
"message": "No model loaded with progressive loading",
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn model_segments(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
match &s.progressive_loader {
|
||||
Some(loader) => Json(serde_json::json!({ "segments": loader.segment_list() })),
|
||||
None => Json(serde_json::json!({ "segments": [] })),
|
||||
}
|
||||
}
|
||||
|
||||
async fn sona_profiles(State(state): State<SharedState>) -> Json<serde_json::Value> {
|
||||
let s = state.read().await;
|
||||
let names = s
|
||||
.progressive_loader
|
||||
.as_ref()
|
||||
.map(|l| l.sona_profile_names())
|
||||
.unwrap_or_default();
|
||||
let active = s.active_sona_profile.clone().unwrap_or_default();
|
||||
Json(serde_json::json!({ "profiles": names, "active": active }))
|
||||
}
|
||||
|
||||
async fn sona_activate(
|
||||
State(state): State<SharedState>,
|
||||
Json(body): Json<serde_json::Value>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let profile = body
|
||||
.get("profile")
|
||||
.and_then(|p| p.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let mut s = state.write().await;
|
||||
let available = s
|
||||
.progressive_loader
|
||||
.as_ref()
|
||||
.map(|l| l.sona_profile_names())
|
||||
.unwrap_or_default();
|
||||
|
||||
if available.contains(&profile) {
|
||||
s.active_sona_profile = Some(profile.clone());
|
||||
Json(serde_json::json!({ "status": "activated", "profile": profile }))
|
||||
} else {
|
||||
Json(serde_json::json!({
|
||||
"status": "error",
|
||||
"message": format!("Profile '{}' not found. Available: {:?}", profile, available),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
async fn info_page() -> Html<String> {
|
||||
Html(format!(
|
||||
"<html><body>\
|
||||
@@ -1012,6 +1309,14 @@ async fn udp_receiver_task(state: SharedState, udp_port: u16) {
|
||||
features.mean_rssi, features.variance, motion_score, tick,
|
||||
),
|
||||
vital_signs: Some(vitals),
|
||||
enhanced_motion: None,
|
||||
enhanced_breathing: None,
|
||||
posture: None,
|
||||
signal_quality_score: None,
|
||||
quality_verdict: None,
|
||||
bssid_count: None,
|
||||
pose_keypoints: None,
|
||||
model_status: None,
|
||||
};
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&update) {
|
||||
@@ -1077,6 +1382,24 @@ async fn simulated_data_task(state: SharedState, tick_ms: u64) {
|
||||
features.mean_rssi, features.variance, motion_score, tick,
|
||||
),
|
||||
vital_signs: Some(vitals),
|
||||
enhanced_motion: None,
|
||||
enhanced_breathing: None,
|
||||
posture: None,
|
||||
signal_quality_score: None,
|
||||
quality_verdict: None,
|
||||
bssid_count: None,
|
||||
pose_keypoints: None,
|
||||
model_status: if s.model_loaded {
|
||||
Some(serde_json::json!({
|
||||
"loaded": true,
|
||||
"layers": s.progressive_loader.as_ref()
|
||||
.map(|l| { let (a,b,c) = l.layer_status(); a as u8 + b as u8 + c as u8 })
|
||||
.unwrap_or(0),
|
||||
"sona_profile": s.active_sona_profile.as_deref().unwrap_or("default"),
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
|
||||
if update.classification.presence {
|
||||
@@ -1208,6 +1531,30 @@ async fn main() {
|
||||
None
|
||||
};
|
||||
|
||||
// Load trained model via --model (uses progressive loading if --progressive set)
|
||||
let model_path = args.model.as_ref().or(args.load_rvf.as_ref());
|
||||
let mut progressive_loader: Option<ProgressiveLoader> = None;
|
||||
let mut model_loaded = false;
|
||||
if let Some(mp) = model_path {
|
||||
if args.progressive || args.model.is_some() {
|
||||
info!("Loading trained model (progressive) from {}", mp.display());
|
||||
match std::fs::read(mp) {
|
||||
Ok(data) => match ProgressiveLoader::new(&data) {
|
||||
Ok(mut loader) => {
|
||||
if let Ok(la) = loader.load_layer_a() {
|
||||
info!(" Layer A ready: model={} v{} ({} segments)",
|
||||
la.model_name, la.version, la.n_segments);
|
||||
}
|
||||
model_loaded = true;
|
||||
progressive_loader = Some(loader);
|
||||
}
|
||||
Err(e) => error!("Progressive loader init failed: {e}"),
|
||||
},
|
||||
Err(e) => error!("Failed to read model file: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, _) = broadcast::channel::<String>(256);
|
||||
let state: SharedState = Arc::new(RwLock::new(AppStateInner {
|
||||
latest_update: None,
|
||||
@@ -1221,6 +1568,9 @@ async fn main() {
|
||||
latest_vitals: VitalSigns::default(),
|
||||
rvf_info,
|
||||
save_rvf_path: args.save_rvf.clone(),
|
||||
progressive_loader,
|
||||
active_sona_profile: None,
|
||||
model_loaded,
|
||||
}));
|
||||
|
||||
// Start background tasks based on source
|
||||
@@ -1274,6 +1624,11 @@ async fn main() {
|
||||
.route("/api/v1/vital-signs", get(vital_signs_endpoint))
|
||||
// RVF model container info
|
||||
.route("/api/v1/model/info", get(model_info))
|
||||
// Progressive loading & SONA endpoints (Phase 7-8)
|
||||
.route("/api/v1/model/layers", get(model_layers))
|
||||
.route("/api/v1/model/segments", get(model_segments))
|
||||
.route("/api/v1/model/sona/profiles", get(sona_profiles))
|
||||
.route("/api/v1/model/sona/activate", post(sona_activate))
|
||||
// Pose endpoints (WiFi-derived)
|
||||
.route("/api/v1/pose/current", get(pose_current))
|
||||
.route("/api/v1/pose/stats", get(pose_stats))
|
||||
|
||||
@@ -298,6 +298,12 @@ impl RvfBuilder {
|
||||
self.push_segment(SEG_QUANT, &payload);
|
||||
}
|
||||
|
||||
/// Add a raw segment with arbitrary type and payload.
|
||||
/// Used by `rvf_pipeline` for extended segment types.
|
||||
pub fn add_raw_segment(&mut self, seg_type: u8, payload: &[u8]) {
|
||||
self.push_segment(seg_type, payload);
|
||||
}
|
||||
|
||||
/// Add witness/proof data as a Witness segment.
|
||||
pub fn add_witness(&mut self, training_hash: &str, metrics: &serde_json::Value) {
|
||||
let witness = serde_json::json!({
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,639 @@
|
||||
//! SONA online adaptation: LoRA + EWC++ for WiFi-DensePose (ADR-023 Phase 5).
|
||||
//!
|
||||
//! Enables rapid low-parameter adaptation to changing WiFi environments without
|
||||
//! catastrophic forgetting. All arithmetic uses `f32`, no external dependencies.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
|
||||
// ── LoRA Adapter ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Low-Rank Adaptation layer storing factorised delta `scale * A * B`.
///
/// `A` has shape (in_features, rank), `B` has shape (rank, out_features);
/// the effective weight delta is `scale * A * B` with `scale = alpha / rank`.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    pub a: Vec<Vec<f32>>, // (in_features, rank)
    pub b: Vec<Vec<f32>>, // (rank, out_features)
    pub scale: f32,       // alpha / rank
    pub in_features: usize,
    pub out_features: usize,
    pub rank: usize,
}

impl LoraAdapter {
    /// Build a zero-initialised adapter: until trained, the delta is exactly zero.
    pub fn new(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
        let scale = alpha / rank.max(1) as f32;
        Self {
            a: vec![vec![0.0f32; rank]; in_features],
            b: vec![vec![0.0f32; out_features]; rank],
            scale,
            in_features,
            out_features,
            rank,
        }
    }

    /// Compute `scale * input * A * B`, returning a vector of length `out_features`.
    ///
    /// # Panics
    /// Panics when `input.len() != in_features`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.in_features);
        // hidden = input * A  (length == rank)
        let mut hidden = vec![0.0f32; self.rank];
        for (&x, a_row) in input.iter().zip(&self.a) {
            for (h, &a) in hidden.iter_mut().zip(a_row) {
                *h += x * a;
            }
        }
        // output = (hidden * B) * scale
        let mut output = vec![0.0f32; self.out_features];
        for (&h, b_row) in hidden.iter().zip(&self.b) {
            for (o, &b) in output.iter_mut().zip(b_row) {
                *o += h * b;
            }
        }
        output.iter_mut().for_each(|v| *v *= self.scale);
        output
    }

    /// Full delta weight matrix `scale * A * B`, shape (in_features, out_features).
    pub fn delta_weights(&self) -> Vec<Vec<f32>> {
        let mut delta = vec![vec![0.0f32; self.out_features]; self.in_features];
        for (d_row, a_row) in delta.iter_mut().zip(&self.a) {
            for (&a_val, b_row) in a_row.iter().zip(&self.b) {
                for (d, &b) in d_row.iter_mut().zip(b_row) {
                    *d += a_val * b;
                }
            }
            d_row.iter_mut().for_each(|d| *d *= self.scale);
        }
        delta
    }

    /// Add the LoRA delta to `base_weights` in place.
    pub fn merge_into(&self, base_weights: &mut [Vec<f32>]) {
        for (base_row, delta_row) in base_weights.iter_mut().zip(self.delta_weights()) {
            for (w, d) in base_row.iter_mut().zip(delta_row) {
                *w += d;
            }
        }
    }

    /// Subtract the LoRA delta from `base_weights` in place (inverse of `merge_into`).
    pub fn unmerge_from(&self, base_weights: &mut [Vec<f32>]) {
        for (base_row, delta_row) in base_weights.iter_mut().zip(self.delta_weights()) {
            for (w, d) in base_row.iter_mut().zip(delta_row) {
                *w -= d;
            }
        }
    }

    /// Trainable parameter count: `rank * (in_features + out_features)`.
    pub fn n_params(&self) -> usize {
        self.rank * (self.in_features + self.out_features)
    }

    /// Zero out both factors, restoring the freshly-constructed state.
    pub fn reset(&mut self) {
        self.a.iter_mut().flatten().for_each(|v| *v = 0.0);
        self.b.iter_mut().flatten().for_each(|v| *v = 0.0);
    }
}
|
||||
|
||||
// ── EWC++ Regularizer ───────────────────────────────────────────────────────
|
||||
|
||||
/// Elastic Weight Consolidation++ regularizer with running Fisher average.
///
/// Penalises movement of parameters away from a consolidated reference point,
/// weighted by a decayed running estimate of the diagonal Fisher information.
#[derive(Debug, Clone)]
pub struct EwcRegularizer {
    pub lambda: f32,
    pub decay: f32,
    pub fisher_diag: Vec<f32>,
    pub reference_params: Vec<f32>,
}

impl EwcRegularizer {
    /// New regularizer with empty Fisher/reference state; the penalty is 0
    /// until `update_fisher` and `consolidate` have both been called.
    pub fn new(lambda: f32, decay: f32) -> Self {
        Self {
            lambda,
            decay,
            fisher_diag: Vec::new(),
            reference_params: Vec::new(),
        }
    }

    /// Diagonal Fisher via numerical central differences: `F_i = grad_i^2`.
    pub fn compute_fisher(
        params: &[f32],
        loss_fn: impl Fn(&[f32]) -> f32,
        n_samples: usize,
    ) -> Vec<f32> {
        let eps = 1e-4f32;
        let samples = n_samples.max(1);
        let mut probe = params.to_vec();
        let mut fisher = vec![0.0f32; params.len()];
        for _ in 0..samples {
            for i in 0..probe.len() {
                // Central difference around the current value; probe is
                // restored after each coordinate so calls see the same input.
                let original = probe[i];
                probe[i] = original + eps;
                let loss_plus = loss_fn(&probe);
                probe[i] = original - eps;
                let loss_minus = loss_fn(&probe);
                probe[i] = original;
                let grad = (loss_plus - loss_minus) / (2.0 * eps);
                fisher[i] += grad * grad;
            }
        }
        fisher.iter_mut().for_each(|f| *f /= samples as f32);
        fisher
    }

    /// Online update: `F = decay * F_old + (1 - decay) * F_new`.
    ///
    /// # Panics
    /// Panics when a non-empty stored estimate differs in length from `new_fisher`.
    pub fn update_fisher(&mut self, new_fisher: &[f32]) {
        if self.fisher_diag.is_empty() {
            self.fisher_diag = new_fisher.to_vec();
            return;
        }
        assert_eq!(self.fisher_diag.len(), new_fisher.len());
        for (old, &fresh) in self.fisher_diag.iter_mut().zip(new_fisher) {
            *old = self.decay * *old + (1.0 - self.decay) * fresh;
        }
    }

    /// Penalty: `0.5 * lambda * sum(F_i * (theta_i - theta_i*)^2)`.
    /// Zero while no reference or Fisher estimate is stored.
    pub fn penalty(&self, current_params: &[f32]) -> f32 {
        if self.reference_params.is_empty() || self.fisher_diag.is_empty() {
            return 0.0;
        }
        // `zip` truncates to the shortest of the three slices.
        let weighted_sq: f32 = current_params
            .iter()
            .zip(&self.reference_params)
            .zip(&self.fisher_diag)
            .map(|((&cur, &anchor), &f)| {
                let d = cur - anchor;
                f * d * d
            })
            .sum();
        0.5 * self.lambda * weighted_sq
    }

    /// Gradient of the penalty: `lambda * F_i * (theta_i - theta_i*)`.
    /// Entries beyond the stored state remain zero.
    pub fn penalty_gradient(&self, current_params: &[f32]) -> Vec<f32> {
        let mut grad = vec![0.0f32; current_params.len()];
        if self.reference_params.is_empty() || self.fisher_diag.is_empty() {
            return grad;
        }
        let terms = current_params
            .iter()
            .zip(&self.reference_params)
            .zip(&self.fisher_diag);
        for (g, ((&cur, &anchor), &f)) in grad.iter_mut().zip(terms) {
            *g = self.lambda * f * (cur - anchor);
        }
        grad
    }

    /// Save the current params as the new reference point.
    pub fn consolidate(&mut self, params: &[f32]) {
        self.reference_params = params.to_vec();
    }
}
|
||||
|
||||
// ── Configuration & Types ───────────────────────────────────────────────────
|
||||
|
||||
/// SONA adaptation configuration.
#[derive(Debug, Clone)]
pub struct SonaConfig {
    /// Rank of the LoRA factorisation.
    pub lora_rank: usize,
    /// LoRA alpha; the adapter scale is `alpha / rank`.
    pub lora_alpha: f32,
    /// EWC penalty strength (lambda).
    pub ewc_lambda: f32,
    /// Decay factor for the running Fisher average.
    pub ewc_decay: f32,
    /// Learning rate used during online adaptation.
    pub adaptation_lr: f32,
    /// Maximum gradient-descent steps per adaptation run.
    pub max_steps: usize,
    /// Stop early once the loss change per step falls below this threshold.
    pub convergence_threshold: f32,
    /// Weight for the temporal-consistency term.
    /// NOTE(review): not referenced by `SonaAdapter::adapt` in this file —
    /// confirm where it is intended to be applied.
    pub temporal_consistency_weight: f32,
}

impl Default for SonaConfig {
    fn default() -> Self {
        Self {
            lora_rank: 4,
            lora_alpha: 8.0,
            ewc_lambda: 5000.0,
            ewc_decay: 0.99,
            adaptation_lr: 0.001,
            max_steps: 50,
            convergence_threshold: 1e-4,
            temporal_consistency_weight: 0.1,
        }
    }
}
|
||||
|
||||
/// Single training sample for online adaptation.
#[derive(Debug, Clone)]
pub struct AdaptationSample {
    /// Flattened CSI feature vector fed to the (linear) model.
    pub csi_features: Vec<f32>,
    /// Desired model output for these features.
    pub target: Vec<f32>,
}
|
||||
|
||||
/// Result of a SONA adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationResult {
    /// Base parameters with the trained LoRA delta folded in.
    pub adapted_params: Vec<f32>,
    /// Number of gradient steps actually executed.
    pub steps_taken: usize,
    /// Total loss (data MSE + EWC penalty) at the last evaluated step.
    pub final_loss: f32,
    /// True when the loss change dropped below the convergence threshold
    /// before `max_steps` was exhausted (also true for an empty sample set).
    pub converged: bool,
    /// EWC penalty evaluated at the adapted parameters.
    pub ewc_penalty: f32,
}
|
||||
|
||||
/// Saved environment-specific adaptation profile.
///
/// Snapshot of the LoRA factors and EWC state, restorable via
/// `SonaAdapter::load_profile`.
#[derive(Debug, Clone)]
pub struct SonaProfile {
    /// Human-readable profile name.
    pub name: String,
    /// LoRA `A` factor, shape (in_features, rank).
    pub lora_a: Vec<Vec<f32>>,
    /// LoRA `B` factor, shape (rank, out_features).
    pub lora_b: Vec<Vec<f32>>,
    /// Running diagonal Fisher estimate.
    pub fisher_diag: Vec<f32>,
    /// EWC reference (consolidated) parameters.
    pub reference_params: Vec<f32>,
    /// How many adaptation runs had completed when this profile was saved.
    pub adaptation_count: usize,
}
|
||||
|
||||
// ── SONA Adapter ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Full SONA system: LoRA adapter + EWC++ regularizer for online adaptation.
#[derive(Debug, Clone)]
pub struct SonaAdapter {
    pub config: SonaConfig,
    /// LoRA delta over the flat parameter vector (out dimension fixed to 1 by `new`).
    pub lora: LoraAdapter,
    /// EWC++ state guarding against catastrophic forgetting.
    pub ewc: EwcRegularizer,
    /// Length of the flat base-parameter vector this adapter was built for.
    pub param_count: usize,
    /// Number of completed `adapt` runs.
    pub adaptation_count: usize,
}

impl SonaAdapter {
    /// Build an adapter for a model with `param_count` flat parameters.
    ///
    /// The LoRA adapter is created with `out_features == 1`, i.e. it models a
    /// per-parameter delta column; `lora_delta_flat` and `update_lora` rely
    /// on this invariant.
    pub fn new(config: SonaConfig, param_count: usize) -> Self {
        let lora = LoraAdapter::new(param_count, 1, config.lora_rank, config.lora_alpha);
        let ewc = EwcRegularizer::new(config.ewc_lambda, config.ewc_decay);
        Self { config, lora, ewc, param_count, adaptation_count: 0 }
    }

    /// Run gradient descent with LoRA + EWC on the given samples.
    ///
    /// Each step evaluates the effective parameters (base + LoRA delta), the
    /// data MSE loss plus the EWC penalty, and updates only the LoRA factors.
    /// Stops early once the loss change drops below
    /// `config.convergence_threshold` (that step's gradient is then discarded).
    ///
    /// # Panics
    /// Panics when `base_params.len() != self.param_count`.
    pub fn adapt(&mut self, base_params: &[f32], samples: &[AdaptationSample]) -> AdaptationResult {
        assert_eq!(base_params.len(), self.param_count);
        // No data: report immediate convergence with unchanged parameters.
        if samples.is_empty() {
            return AdaptationResult {
                adapted_params: base_params.to_vec(),
                steps_taken: 0,
                final_loss: 0.0,
                converged: true,
                ewc_penalty: self.ewc.penalty(base_params),
            };
        }
        let lr = self.config.adaptation_lr;
        let (mut prev_loss, mut steps, mut converged) = (f32::MAX, 0usize, false);
        // Input/output dimensionality is taken from the first sample; all
        // samples are assumed to share it — TODO confirm with callers.
        let out_dim = samples[0].target.len();
        let in_dim = samples[0].csi_features.len();

        for step in 0..self.config.max_steps {
            steps = step + 1;
            // Effective parameters = frozen base + current LoRA delta.
            let df = self.lora_delta_flat();
            let eff: Vec<f32> = base_params.iter().zip(df.iter()).map(|(&b, &d)| b + d).collect();
            let (dl, dg) = Self::mse_loss_grad(&eff, samples, in_dim, out_dim);
            let ep = self.ewc.penalty(&eff);
            let eg = self.ewc.penalty_gradient(&eff);
            let total = dl + ep;
            // Loss plateaued: stop without applying this step's gradient.
            if (prev_loss - total).abs() < self.config.convergence_threshold {
                converged = true;
                prev_loss = total;
                break;
            }
            prev_loss = total;
            // Combined gradient, truncated to the shortest participating vector.
            let gl = df.len().min(dg.len()).min(eg.len());
            let mut tg = vec![0.0f32; gl];
            for i in 0..gl {
                tg[i] = dg[i] + eg[i];
            }
            self.update_lora(&tg, lr);
        }
        // Fold the final delta into the base to produce the adapted parameters.
        let df = self.lora_delta_flat();
        let adapted: Vec<f32> = base_params.iter().zip(df.iter()).map(|(&b, &d)| b + d).collect();
        let ewc_penalty = self.ewc.penalty(&adapted);
        self.adaptation_count += 1;
        AdaptationResult {
            adapted_params: adapted,
            steps_taken: steps,
            final_loss: prev_loss,
            converged,
            ewc_penalty
        }
    }

    /// Snapshot the LoRA factors and EWC state under `name`.
    pub fn save_profile(&self, name: &str) -> SonaProfile {
        SonaProfile {
            name: name.to_string(),
            lora_a: self.lora.a.clone(),
            lora_b: self.lora.b.clone(),
            fisher_diag: self.ewc.fisher_diag.clone(),
            reference_params: self.ewc.reference_params.clone(),
            adaptation_count: self.adaptation_count,
        }
    }

    /// Restore LoRA factors and EWC state from a saved profile.
    pub fn load_profile(&mut self, profile: &SonaProfile) {
        self.lora.a = profile.lora_a.clone();
        self.lora.b = profile.lora_b.clone();
        self.ewc.fisher_diag = profile.fisher_diag.clone();
        self.ewc.reference_params = profile.reference_params.clone();
        self.adaptation_count = profile.adaptation_count;
    }

    /// Flatten the (param_count, 1) LoRA delta matrix into a vector by taking
    /// column 0 of every row (valid because `new` fixes out_features to 1).
    fn lora_delta_flat(&self) -> Vec<f32> {
        self.lora.delta_weights().into_iter().map(|r| r[0]).collect()
    }

    /// Mean-squared-error loss and gradient of a linear map, with `params`
    /// treated as a row-major (out_dim, in_dim) weight matrix applied to each
    /// sample. Indices outside the matrix or the slices are skipped defensively.
    fn mse_loss_grad(params: &[f32], samples: &[AdaptationSample], in_dim: usize, out_dim: usize) -> (f32, Vec<f32>) {
        let n = samples.len() as f32;
        let ws = in_dim * out_dim;
        let mut grad = vec![0.0f32; params.len()];
        let mut loss = 0.0f32;
        for s in samples {
            let (inp, tgt) = (&s.csi_features, &s.target);
            // Forward pass: pred = W * inp.
            let mut pred = vec![0.0f32; out_dim];
            for j in 0..out_dim {
                for i in 0..in_dim.min(inp.len()) {
                    let idx = j * in_dim + i;
                    if idx < ws && idx < params.len() { pred[j] += params[idx] * inp[i]; }
                }
            }
            // Accumulate squared error; dLoss/dW[j][i] = 2 * err_j * inp_i / n.
            for j in 0..out_dim.min(tgt.len()) {
                let e = pred[j] - tgt[j];
                loss += e * e;
                for i in 0..in_dim.min(inp.len()) {
                    let idx = j * in_dim + i;
                    if idx < ws && idx < grad.len() { grad[idx] += 2.0 * e * inp[i] / n; }
                }
            }
        }
        (loss / n, grad)
    }

    /// One SGD step on the LoRA factors for a rank-factorised delta with a
    /// single output column (only `b[r][0]` participates — see `new`).
    fn update_lora(&mut self, grad: &[f32], lr: f32) {
        let (scale, rank) = (self.lora.scale, self.lora.rank);
        // With A and B both zero-initialised, the A-gradient (proportional to
        // B) would stay zero forever; seed one B entry so gradients can flow.
        if self.lora.b.iter().all(|r| r.iter().all(|&v| v == 0.0)) && rank > 0 {
            self.lora.b[0][0] = 1.0;
        }
        // dDelta[i]/dA[i][r] = scale * b[r][0], chained with grad[i].
        for i in 0..self.lora.in_features.min(grad.len()) {
            for r in 0..rank {
                self.lora.a[i][r] -= lr * grad[i] * scale * self.lora.b[r][0];
            }
        }
        // dDelta[i]/dB[r][0] = scale * a[i][r], summed over i.
        for r in 0..rank {
            let mut g = 0.0f32;
            for i in 0..self.lora.in_features.min(grad.len()) {
                g += grad[i] * scale * self.lora.a[i][r];
            }
            self.lora.b[r][0] -= lr * g;
        }
    }
}
|
||||
|
||||
// ── Environment Detector ────────────────────────────────────────────────────
|
||||
|
||||
/// CSI baseline drift information.
#[derive(Debug, Clone)]
pub struct DriftInfo {
    /// Deviation from baseline, in units of baseline standard deviations.
    pub magnitude: f32,
    /// Consecutive frames for which drift has been detected.
    pub duration_frames: usize,
    /// Mean of the window when the baseline was last set.
    pub baseline_mean: f32,
    /// Mean of the current window.
    pub current_mean: f32,
}

/// Detects environmental drift in CSI statistics (>3 sigma from baseline).
#[derive(Debug, Clone)]
pub struct EnvironmentDetector {
    window_size: usize,
    means: VecDeque<f32>,
    variances: VecDeque<f32>,
    baseline_mean: f32,
    baseline_var: f32,
    baseline_std: f32,
    baseline_set: bool,
    drift_frames: usize,
}

impl EnvironmentDetector {
    /// New detector over a sliding window of at least 2 frames.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size: window_size.max(2),
            means: VecDeque::with_capacity(window_size),
            variances: VecDeque::with_capacity(window_size),
            baseline_mean: 0.0,
            baseline_var: 0.0,
            baseline_std: 0.0,
            baseline_set: false,
            drift_frames: 0,
        }
    }

    /// Push one frame's CSI mean/variance, establish the baseline once the
    /// window fills, and track how long any drift has persisted.
    pub fn update(&mut self, csi_mean: f32, csi_var: f32) {
        self.means.push_back(csi_mean);
        self.variances.push_back(csi_var);
        while self.means.len() > self.window_size {
            self.means.pop_front();
        }
        while self.variances.len() > self.window_size {
            self.variances.pop_front();
        }
        if !self.baseline_set && self.means.len() >= self.window_size {
            self.reset_baseline();
        }
        self.drift_frames = if self.drift_detected() {
            self.drift_frames + 1
        } else {
            0
        };
    }

    /// True when the window mean deviates more than 3 sigma from the baseline
    /// (or by any non-negligible amount when the baseline has zero spread).
    pub fn drift_detected(&self) -> bool {
        if !self.baseline_set || self.means.is_empty() {
            return false;
        }
        let deviation = (self.current_mean() - self.baseline_mean).abs();
        let threshold = if self.baseline_std > f32::EPSILON {
            3.0 * self.baseline_std
        } else {
            f32::EPSILON * 100.0
        };
        deviation > threshold
    }

    /// Recompute the baseline statistics from the current window of means and
    /// clear the drift counter. No-op while the window is empty.
    pub fn reset_baseline(&mut self) {
        if self.means.is_empty() {
            return;
        }
        let n = self.means.len() as f32;
        let mean = self.means.iter().sum::<f32>() / n;
        let var = self.means.iter().map(|&m| (m - mean).powi(2)).sum::<f32>() / n;
        self.baseline_mean = mean;
        self.baseline_var = var;
        self.baseline_std = var.sqrt();
        self.baseline_set = true;
        self.drift_frames = 0;
    }

    /// Snapshot of how far (and for how long) the signal has drifted.
    pub fn drift_info(&self) -> DriftInfo {
        let current = self.current_mean();
        let abs_dev = (current - self.baseline_mean).abs();
        let magnitude = if self.baseline_std > f32::EPSILON {
            abs_dev / self.baseline_std
        } else if abs_dev > f32::EPSILON {
            abs_dev / f32::EPSILON
        } else {
            0.0
        };
        DriftInfo {
            magnitude,
            duration_frames: self.drift_frames,
            baseline_mean: self.baseline_mean,
            current_mean: current,
        }
    }

    /// Mean of the current window (0.0 when empty).
    fn current_mean(&self) -> f32 {
        match self.means.len() {
            0 => 0.0,
            n => self.means.iter().sum::<f32>() / n as f32,
        }
    }
}
|
||||
|
||||
// ── Temporal Consistency Loss ───────────────────────────────────────────────
|
||||
|
||||
/// Penalises large velocity between consecutive outputs: `sum((c-p)^2) / dt`.
pub struct TemporalConsistencyLoss;

impl TemporalConsistencyLoss {
    /// Sum of squared element-wise differences divided by the time step.
    ///
    /// Vectors of unequal length are compared over their common prefix;
    /// a non-positive `dt` carries no temporal information and yields 0.0.
    pub fn compute(prev_output: &[f32], curr_output: &[f32], dt: f32) -> f32 {
        if dt <= 0.0 {
            return 0.0;
        }
        let squared_delta: f32 = prev_output
            .iter()
            .zip(curr_output.iter())
            .map(|(&p, &c)| (c - p) * (c - p))
            .sum();
        squared_delta / dt
    }
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the SONA adaptation stack: LoRA adapters, EWC
    //! regularisation, the SONA adapter loop, environment drift detection,
    //! and temporal consistency loss. Types such as `LoraAdapter`,
    //! `EwcRegularizer`, `SonaConfig` and `SonaAdapter` are defined earlier
    //! in this module (outside this excerpt).
    use super::*;

    // ── LoRA adapter ────────────────────────────────────────────────────

    // n_params = rank * (d_in + d_out) trainable values.
    #[test]
    fn lora_adapter_param_count() {
        let lora = LoraAdapter::new(64, 32, 4, 8.0);
        assert_eq!(lora.n_params(), 4 * (64 + 32));
    }

    // forward maps a d_in vector to a d_out vector.
    #[test]
    fn lora_adapter_forward_shape() {
        let lora = LoraAdapter::new(8, 4, 2, 4.0);
        assert_eq!(lora.forward(&vec![1.0f32; 8]).len(), 4);
    }

    // A fresh adapter must contribute an all-zero delta (d_in x d_out).
    // NOTE(review): zero initialisation is asserted here but the constructor
    // lives outside this excerpt — confirm B (or A) starts at zero.
    #[test]
    fn lora_adapter_zero_init_produces_zero_delta() {
        let delta = LoraAdapter::new(8, 4, 2, 4.0).delta_weights();
        assert_eq!(delta.len(), 8);
        for row in &delta { assert_eq!(row.len(), 4); for &v in row { assert_eq!(v, 0.0); } }
    }

    // merge_into followed by unmerge_from must restore the base weights
    // (to within float tolerance).
    #[test]
    fn lora_adapter_merge_unmerge_roundtrip() {
        let mut lora = LoraAdapter::new(3, 2, 1, 2.0);
        lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0;
        lora.b[0][0] = 0.5; lora.b[0][1] = -0.5;
        let mut base = vec![vec![10.0, 20.0], vec![30.0, 40.0], vec![50.0, 60.0]];
        let orig = base.clone();
        lora.merge_into(&mut base);
        assert_ne!(base, orig);
        lora.unmerge_from(&mut base);
        for (rb, ro) in base.iter().zip(orig.iter()) {
            for (&b, &o) in rb.iter().zip(ro.iter()) {
                assert!((b - o).abs() < 1e-5, "roundtrip failed: {b} vs {o}");
            }
        }
    }

    // With rank 1 and scale 1, delta is the outer product a (3x1) ⊗ b (1x2).
    #[test]
    fn lora_adapter_rank_1_outer_product() {
        let mut lora = LoraAdapter::new(3, 2, 1, 1.0); // scale=1
        lora.a[0][0] = 1.0; lora.a[1][0] = 2.0; lora.a[2][0] = 3.0;
        lora.b[0][0] = 4.0; lora.b[0][1] = 5.0;
        let d = lora.delta_weights();
        let expected = [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() { assert!((d[i][j] - v).abs() < 1e-6); }
        }
    }

    // scale = alpha / rank for both parameterisations.
    #[test]
    fn lora_scale_factor() {
        assert!((LoraAdapter::new(8, 4, 4, 16.0).scale - 4.0).abs() < 1e-6);
        assert!((LoraAdapter::new(8, 4, 2, 8.0).scale - 4.0).abs() < 1e-6);
    }

    // ── EWC regularizer ─────────────────────────────────────────────────

    // Fisher information is a squared-gradient estimate, hence non-negative.
    #[test]
    fn ewc_fisher_positive() {
        let fisher = EwcRegularizer::compute_fisher(
            &[1.0f32, -2.0, 0.5],
            |p: &[f32]| p.iter().map(|&x| x * x).sum::<f32>(), 1,
        );
        assert_eq!(fisher.len(), 3);
        for &f in &fisher { assert!(f >= 0.0, "Fisher must be >= 0, got {f}"); }
    }

    // The penalty vanishes exactly at the consolidated reference point.
    #[test]
    fn ewc_penalty_zero_at_reference() {
        let mut ewc = EwcRegularizer::new(5000.0, 0.99);
        let p = vec![1.0, 2.0, 3.0];
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&p);
        assert!(ewc.penalty(&p).abs() < 1e-10);
    }

    // Unit offset on all three params with F = 1: 0.5 * lambda * 3.
    #[test]
    fn ewc_penalty_positive_away_from_reference() {
        let mut ewc = EwcRegularizer::new(5000.0, 0.99);
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&[1.0, 2.0, 3.0]);
        let pen = ewc.penalty(&[2.0, 3.0, 4.0]);
        assert!(pen > 0.0); // 0.5 * 5000 * 3 = 7500
        assert!((pen - 7500.0).abs() < 1e-3, "expected ~7500, got {pen}");
    }

    // Gradient must share sign with (current - reference) on each coordinate.
    #[test]
    fn ewc_penalty_gradient_direction() {
        let mut ewc = EwcRegularizer::new(100.0, 0.99);
        let r = vec![1.0, 2.0, 3.0];
        ewc.fisher_diag = vec![1.0; 3]; ewc.consolidate(&r);
        let c = vec![2.0, 4.0, 5.0];
        let grad = ewc.penalty_gradient(&c);
        for (i, &g) in grad.iter().enumerate() {
            assert!(g * (c[i] - r[i]) > 0.0, "gradient[{i}] wrong sign");
        }
    }

    // Online Fisher update: first call seeds the diagonal; later calls blend
    // with decay 0.5 (values inferred from the assertions below).
    #[test]
    fn ewc_online_update_decays() {
        let mut ewc = EwcRegularizer::new(1.0, 0.5);
        ewc.update_fisher(&[10.0, 20.0]);
        assert!((ewc.fisher_diag[0] - 10.0).abs() < 1e-6);
        ewc.update_fisher(&[0.0, 0.0]);
        assert!((ewc.fisher_diag[0] - 5.0).abs() < 1e-6); // 0.5*10 + 0.5*0
        assert!((ewc.fisher_diag[1] - 10.0).abs() < 1e-6); // 0.5*20 + 0.5*0
    }

    // consolidate() replaces the reference parameters wholesale.
    #[test]
    fn ewc_consolidate_updates_reference() {
        let mut ewc = EwcRegularizer::new(1.0, 0.99);
        ewc.consolidate(&[1.0, 2.0]);
        assert_eq!(ewc.reference_params, vec![1.0, 2.0]);
        ewc.consolidate(&[3.0, 4.0]);
        assert_eq!(ewc.reference_params, vec![3.0, 4.0]);
    }

    // ── SONA adapter ────────────────────────────────────────────────────

    // Pins every default of SonaConfig so silent changes are caught.
    #[test]
    fn sona_config_defaults() {
        let c = SonaConfig::default();
        assert_eq!(c.lora_rank, 4);
        assert!((c.lora_alpha - 8.0).abs() < 1e-6);
        assert!((c.ewc_lambda - 5000.0).abs() < 1e-3);
        assert!((c.ewc_decay - 0.99).abs() < 1e-6);
        assert!((c.adaptation_lr - 0.001).abs() < 1e-6);
        assert_eq!(c.max_steps, 50);
        assert!((c.convergence_threshold - 1e-4).abs() < 1e-8);
        assert!((c.temporal_consistency_weight - 0.1).abs() < 1e-6);
    }

    // Learning y = 2x with EWC disabled should reduce the loss below 1.0.
    #[test]
    fn sona_adapter_converges_on_simple_task() {
        let cfg = SonaConfig {
            lora_rank: 1, lora_alpha: 1.0, ewc_lambda: 0.0, ewc_decay: 0.99,
            adaptation_lr: 0.01, max_steps: 200, convergence_threshold: 1e-6,
            temporal_consistency_weight: 0.0,
        };
        let mut adapter = SonaAdapter::new(cfg, 1);
        let samples: Vec<_> = (1..=5).map(|i| {
            let x = i as f32;
            AdaptationSample { csi_features: vec![x], target: vec![2.0 * x] }
        }).collect();
        let r = adapter.adapt(&[0.0f32], &samples);
        assert!(r.final_loss < 1.0, "loss should decrease, got {}", r.final_loss);
        assert!(r.steps_taken > 0);
    }

    // With convergence disabled, adapt() must run exactly max_steps.
    #[test]
    fn sona_adapter_respects_max_steps() {
        let cfg = SonaConfig { max_steps: 5, convergence_threshold: 0.0, ..SonaConfig::default() };
        let mut a = SonaAdapter::new(cfg, 4);
        let s = vec![AdaptationSample { csi_features: vec![1.0, 0.0, 0.0, 0.0], target: vec![1.0] }];
        assert_eq!(a.adapt(&[0.0; 4], &s).steps_taken, 5);
    }

    // save_profile / load_profile must round-trip LoRA matrices, EWC state,
    // and the adaptation counter into a second adapter.
    #[test]
    fn sona_profile_save_load_roundtrip() {
        let mut a = SonaAdapter::new(SonaConfig::default(), 8);
        a.lora.a[0][0] = 1.5; a.lora.b[0][0] = -0.3;
        a.ewc.fisher_diag = vec![1.0, 2.0, 3.0];
        a.ewc.reference_params = vec![0.1, 0.2, 0.3];
        a.adaptation_count = 42;
        let p = a.save_profile("test-env");
        assert_eq!(p.name, "test-env");
        assert_eq!(p.adaptation_count, 42);
        let mut a2 = SonaAdapter::new(SonaConfig::default(), 8);
        a2.load_profile(&p);
        assert!((a2.lora.a[0][0] - 1.5).abs() < 1e-6);
        assert!((a2.lora.b[0][0] - (-0.3)).abs() < 1e-6);
        assert_eq!(a2.ewc.fisher_diag.len(), 3);
        assert!((a2.ewc.fisher_diag[2] - 3.0).abs() < 1e-6);
        assert_eq!(a2.adaptation_count, 42);
    }

    // ── Environment drift detection ─────────────────────────────────────

    // No baseline yet -> no drift.
    #[test]
    fn environment_detector_no_drift_initially() {
        assert!(!EnvironmentDetector::new(10).drift_detected());
    }

    // A 10 -> 50 mean shift on a near-zero-variance baseline must trip the
    // detector with a magnitude well above 3 sigma.
    #[test]
    fn environment_detector_detects_large_shift() {
        let mut d = EnvironmentDetector::new(10);
        for _ in 0..10 { d.update(10.0, 0.1); }
        assert!(!d.drift_detected());
        for _ in 0..10 { d.update(50.0, 0.1); }
        assert!(d.drift_detected());
        assert!(d.drift_info().magnitude > 3.0, "magnitude = {}", d.drift_info().magnitude);
    }

    // Re-freezing the baseline at the new level clears the drift flag.
    #[test]
    fn environment_detector_reset_baseline() {
        let mut d = EnvironmentDetector::new(10);
        for _ in 0..10 { d.update(10.0, 0.1); }
        for _ in 0..10 { d.update(50.0, 0.1); }
        assert!(d.drift_detected());
        d.reset_baseline();
        assert!(!d.drift_detected());
    }

    // ── Temporal consistency ────────────────────────────────────────────

    // Identical consecutive outputs carry zero temporal penalty.
    #[test]
    fn temporal_consistency_zero_for_static() {
        let o = vec![1.0, 2.0, 3.0];
        assert!(TemporalConsistencyLoss::compute(&o, &o, 0.033).abs() < 1e-10);
    }
}
|
||||
@@ -0,0 +1,652 @@
|
||||
//! Sparse inference and weight quantization for edge deployment of WiFi DensePose.
|
||||
//!
|
||||
//! Implements ADR-023 Phase 6: activation profiling, sparse matrix-vector multiply,
|
||||
//! INT8/FP16 quantization, and a full sparse inference engine. Pure Rust, no deps.
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
// ── Neuron Profiler ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Tracks per-neuron activation frequency to partition hot vs cold neurons.
pub struct NeuronProfiler {
    activation_counts: Vec<u64>,
    samples: usize,
    n_neurons: usize,
}

impl NeuronProfiler {
    /// Creates a profiler for `n_neurons` neurons with zeroed counters.
    pub fn new(n_neurons: usize) -> Self {
        Self {
            activation_counts: vec![0; n_neurons],
            samples: 0,
            n_neurons,
        }
    }

    /// Record an activation; values > 0 count as "active".
    /// Out-of-range indices are silently ignored.
    pub fn record_activation(&mut self, neuron_idx: usize, activation: f32) {
        if activation > 0.0 {
            if let Some(count) = self.activation_counts.get_mut(neuron_idx) {
                *count += 1;
            }
        }
    }

    /// Mark end of one profiling sample (call after recording all neurons).
    pub fn end_sample(&mut self) {
        self.samples += 1;
    }

    /// Fraction of samples where the neuron fired (activation > 0).
    /// Returns 0.0 for unknown neurons or before any sample completed.
    pub fn activation_frequency(&self, neuron_idx: usize) -> f32 {
        if self.samples == 0 {
            return 0.0;
        }
        match self.activation_counts.get(neuron_idx) {
            Some(&count) => count as f32 / self.samples as f32,
            None => 0.0,
        }
    }

    /// Split neurons into (hot, cold) by activation frequency threshold.
    pub fn partition_hot_cold(&self, hot_threshold: f32) -> (Vec<usize>, Vec<usize>) {
        (0..self.n_neurons).partition(|&i| self.activation_frequency(i) >= hot_threshold)
    }

    /// Top-k most frequently activated neuron indices.
    pub fn top_k_neurons(&self, k: usize) -> Vec<usize> {
        let mut by_frequency: Vec<usize> = (0..self.n_neurons).collect();
        // Stable sort, descending by frequency; ties keep index order.
        by_frequency.sort_by(|&a, &b| {
            self.activation_frequency(b)
                .partial_cmp(&self.activation_frequency(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        by_frequency.truncate(k);
        by_frequency
    }

    /// Fraction of neurons with activation frequency < 0.1.
    pub fn sparsity_ratio(&self) -> f32 {
        if self.n_neurons == 0 || self.samples == 0 {
            return 0.0;
        }
        let cold_count = (0..self.n_neurons)
            .filter(|&i| self.activation_frequency(i) < 0.1)
            .count();
        cold_count as f32 / self.n_neurons as f32
    }

    /// Number of completed profiling samples.
    pub fn total_samples(&self) -> usize {
        self.samples
    }
}
|
||||
|
||||
// ── Sparse Linear Layer ──────────────────────────────────────────────────────
|
||||
|
||||
/// Linear layer that only computes output rows for "hot" neurons.
pub struct SparseLinear {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    hot_neurons: Vec<usize>,
    n_outputs: usize,
    n_inputs: usize,
}

impl SparseLinear {
    /// Builds a layer from a row-major weight matrix (`weights[out][in]`),
    /// a per-row bias vector, and the indices of output neurons to compute.
    pub fn new(weights: Vec<Vec<f32>>, bias: Vec<f32>, hot_neurons: Vec<usize>) -> Self {
        let n_outputs = weights.len();
        let n_inputs = weights.first().map_or(0, |r| r.len());
        Self { weights, bias, hot_neurons, n_outputs, n_inputs }
    }

    /// Sparse forward: only compute hot rows; cold outputs are 0.
    /// Hot indices outside the output range are skipped.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0f32; self.n_outputs];
        for &r in &self.hot_neurons {
            if r < self.n_outputs {
                out[r] = dot_bias(&self.weights[r], input, self.bias[r]);
            }
        }
        out
    }

    /// Dense forward: compute all rows.
    pub fn forward_full(&self, input: &[f32]) -> Vec<f32> {
        (0..self.n_outputs).map(|r| dot_bias(&self.weights[r], input, self.bias[r])).collect()
    }

    /// Replaces the hot-neuron set (e.g. after re-profiling).
    pub fn set_hot_neurons(&mut self, hot: Vec<usize>) { self.hot_neurons = hot; }

    /// Fraction of neurons in the hot set.
    pub fn density(&self) -> f32 {
        if self.n_outputs == 0 { 0.0 } else { self.hot_neurons.len() as f32 / self.n_outputs as f32 }
    }

    /// Multiply-accumulate ops saved vs dense.
    pub fn n_flops_saved(&self) -> usize {
        self.n_outputs.saturating_sub(self.hot_neurons.len()) * self.n_inputs
    }
}

/// Dot product of `row` and `input` over their common prefix, plus `bias`.
///
/// Implemented as a `zip` + `fold` rather than an index loop: the iterator
/// form avoids per-element bounds checks in this hot path while preserving
/// the original left-to-right accumulation order (fold starts at `bias`).
fn dot_bias(row: &[f32], input: &[f32], bias: f32) -> f32 {
    row.iter()
        .zip(input.iter())
        .fold(bias, |acc, (&w, &x)| acc + w * x)
}
|
||||
|
||||
// ── Quantization ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Quantization mode.
///
/// `F32` is the unquantized baseline, `F16` is IEEE 754 half precision.
/// `Int8Symmetric` maps zero to code 0; `Int8Asymmetric` uses a calibrated
/// zero-point over `[min, max]`.
/// NOTE(review): `Int4` appears to have no packing/dequantize path of its
/// own in this module (stats count it at 1 byte/param) — confirm intent.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantMode { F32, F16, Int8Symmetric, Int8Asymmetric, Int4 }
|
||||
|
||||
/// Quantization configuration.
///
/// `mode` selects the scheme; `calibration_samples` is the number of samples
/// intended for calibration.
/// NOTE(review): `calibration_samples` is not consumed anywhere in this
/// excerpt — confirm it is used by a caller.
#[derive(Debug, Clone)]
pub struct QuantConfig { pub mode: QuantMode, pub calibration_samples: usize }

impl Default for QuantConfig {
    // Defaults: symmetric INT8 with 100 calibration samples.
    fn default() -> Self { Self { mode: QuantMode::Int8Symmetric, calibration_samples: 100 } }
}
|
||||
|
||||
/// Quantized weight storage.
#[derive(Debug, Clone)]
pub struct QuantizedWeights {
    /// Quantized codes. For asymmetric mode these are u8 codes
    /// bit-reinterpreted into i8 storage (see `Quantizer::dequantize`,
    /// which casts back via `as u8`).
    pub data: Vec<i8>,
    /// Step size mapping one integer code step to its f32 magnitude.
    pub scale: f32,
    /// Code representing 0.0. Always 0 for symmetric mode; for asymmetric
    /// mode a u8 zero-point reinterpreted as i8.
    pub zero_point: i8,
    /// Which quantization scheme produced `data`.
    pub mode: QuantMode,
}
|
||||
|
||||
pub struct Quantizer;
|
||||
|
||||
impl Quantizer {
|
||||
/// Symmetric INT8: zero maps to 0, scale = max(|w|)/127.
|
||||
pub fn quantize_symmetric(weights: &[f32]) -> QuantizedWeights {
|
||||
if weights.is_empty() {
|
||||
return QuantizedWeights { data: vec![], scale: 1.0, zero_point: 0, mode: QuantMode::Int8Symmetric };
|
||||
}
|
||||
let max_abs = weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max);
|
||||
let scale = if max_abs < f32::EPSILON { 1.0 } else { max_abs / 127.0 };
|
||||
let data = weights.iter().map(|&w| (w / scale).round().clamp(-127.0, 127.0) as i8).collect();
|
||||
QuantizedWeights { data, scale, zero_point: 0, mode: QuantMode::Int8Symmetric }
|
||||
}
|
||||
|
||||
/// Asymmetric INT8: maps [min,max] to [0,255].
|
||||
pub fn quantize_asymmetric(weights: &[f32]) -> QuantizedWeights {
|
||||
if weights.is_empty() {
|
||||
return QuantizedWeights { data: vec![], scale: 1.0, zero_point: 0, mode: QuantMode::Int8Asymmetric };
|
||||
}
|
||||
let w_min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let w_max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let range = w_max - w_min;
|
||||
let scale = if range < f32::EPSILON { 1.0 } else { range / 255.0 };
|
||||
let zp = if range < f32::EPSILON { 0u8 } else { (-w_min / scale).round().clamp(0.0, 255.0) as u8 };
|
||||
let data = weights.iter().map(|&w| ((w - w_min) / scale).round().clamp(0.0, 255.0) as u8 as i8).collect();
|
||||
QuantizedWeights { data, scale, zero_point: zp as i8, mode: QuantMode::Int8Asymmetric }
|
||||
}
|
||||
|
||||
/// Reconstruct approximate f32 values from quantized weights.
|
||||
pub fn dequantize(qw: &QuantizedWeights) -> Vec<f32> {
|
||||
match qw.mode {
|
||||
QuantMode::Int8Symmetric => qw.data.iter().map(|&q| q as f32 * qw.scale).collect(),
|
||||
QuantMode::Int8Asymmetric => {
|
||||
let zp = qw.zero_point as u8;
|
||||
qw.data.iter().map(|&q| (q as u8 as f32 - zp as f32) * qw.scale).collect()
|
||||
}
|
||||
_ => qw.data.iter().map(|&q| q as f32 * qw.scale).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// MSE between original and quantized weights.
|
||||
pub fn quantization_error(original: &[f32], quantized: &QuantizedWeights) -> f32 {
|
||||
let deq = Self::dequantize(quantized);
|
||||
if original.len() != deq.len() || original.is_empty() { return f32::MAX; }
|
||||
original.iter().zip(deq.iter()).map(|(o, d)| (o - d).powi(2)).sum::<f32>() / original.len() as f32
|
||||
}
|
||||
|
||||
/// Convert f32 to IEEE 754 half-precision (u16).
|
||||
pub fn f16_quantize(weights: &[f32]) -> Vec<u16> { weights.iter().map(|&w| f32_to_f16(w)).collect() }
|
||||
|
||||
/// Convert FP16 (u16) back to f32.
|
||||
pub fn f16_dequantize(data: &[u16]) -> Vec<f32> { data.iter().map(|&h| f16_to_f32(h)).collect() }
|
||||
}
|
||||
|
||||
// ── FP16 bit manipulation ────────────────────────────────────────────────────
|
||||
|
||||
/// Converts an `f32` to IEEE 754 binary16 bits, rounding to nearest
/// (ties away from zero) instead of truncating.
///
/// Handles Inf/NaN, overflow to Inf, half-precision subnormal results, and
/// flushes f32 zeros/subnormals (magnitude < 2^-126, far below the half
/// subnormal range) to signed zero. Rounding may legitimately carry out of
/// the mantissa into the exponent field — including up to Inf — because the
/// binary16 bit layout makes that carry produce the correct next value.
fn f32_to_f16(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = (bits >> 31) & 1;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let man = bits & 0x007F_FFFF;

    if exp == 0xFF { // Inf or NaN
        // Keep NaN-ness by setting a quiet-bit payload.
        let hm = if man != 0 { 0x0200 } else { 0 };
        return ((sign << 15) | 0x7C00 | hm) as u16;
    }
    if exp == 0 { return (sign << 15) as u16; } // zero / f32 subnormal -> zero

    let ne = exp - 127 + 15; // rebias exponent 8-bit -> 5-bit
    if ne >= 31 { return ((sign << 15) | 0x7C00) as u16; } // overflow -> Inf
    if ne <= 0 {
        // Result is a half-precision subnormal (or underflows to zero).
        if ne < -10 { return (sign << 15) as u16; }
        let full = man | 0x0080_0000; // restore the implicit leading 1 (24-bit significand)
        let shift = (14 - ne) as u32; // 14..=24 for ne in 0..=-10
        let mut half = full >> shift;
        // Round to nearest using the highest discarded bit; a carry into
        // the exponent field yields the smallest normal, which is correct.
        if (full >> (shift - 1)) & 1 != 0 {
            half += 1;
        }
        return ((sign << 15) | half) as u16;
    }

    // Normal case: truncate mantissa to 10 bits, then round to nearest on
    // the highest discarded bit (bit 12). Mantissa overflow carries into
    // the exponent field, which is exactly the next representable value.
    let mut half = ((ne as u32) << 10) | (man >> 13);
    if man & 0x1000 != 0 {
        half += 1;
    }
    ((sign << 15) | half) as u16
}
|
||||
|
||||
/// Converts IEEE 754 binary16 bits to an `f32`, covering Inf/NaN,
/// signed zero, subnormals, and normal values.
fn f16_to_f32(h: u16) -> f32 {
    let sign = ((h >> 15) & 1) as u32;
    let exp = ((h >> 10) & 0x1F) as u32;
    let man = (h & 0x03FF) as u32;

    match exp {
        // Inf / NaN: widen any NaN payload into the f32 mantissa field.
        0x1F => {
            let payload = if man != 0 { man << 13 } else { 0 };
            f32::from_bits((sign << 31) | 0x7F80_0000 | payload)
        }
        // Zero or half subnormal.
        0 => {
            if man == 0 {
                return f32::from_bits(sign << 31); // signed zero
            }
            // Renormalise: shift until the implicit bit appears, tracking
            // the exponent from the subnormal base of 2^-14.
            let mut mantissa = man;
            let mut exponent: i32 = -14;
            while mantissa & 0x0400 == 0 {
                mantissa <<= 1;
                exponent -= 1;
            }
            mantissa &= 0x03FF; // drop the now-implicit leading bit
            f32::from_bits((sign << 31) | (((exponent + 127) as u32) << 23) | (mantissa << 13))
        }
        // Normal number: rebias exponent 5-bit -> 8-bit, widen mantissa.
        _ => f32::from_bits((sign << 31) | ((exp as i32 - 15 + 127) as u32) << 23 | (man << 13)),
    }
}
|
||||
|
||||
// ── Sparse Model ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for the sparse inference engine.
#[derive(Debug, Clone)]
pub struct SparseConfig {
    /// Minimum activation frequency for a neuron to be treated as "hot".
    pub hot_threshold: f32,
    /// Weight-storage mode assumed by `SparseModel::stats`.
    pub quant_mode: QuantMode,
    /// Maximum number of samples consumed by `SparseModel::profile`.
    pub profile_frames: usize,
}

impl Default for SparseConfig {
    // Defaults: 50% hot threshold, symmetric INT8, 100 profiling frames.
    fn default() -> Self { Self { hot_threshold: 0.5, quant_mode: QuantMode::Int8Symmetric, profile_frames: 100 } }
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct ModelLayer {
|
||||
name: String,
|
||||
weights: Vec<Vec<f32>>,
|
||||
bias: Vec<f32>,
|
||||
sparse: Option<SparseLinear>,
|
||||
profiler: NeuronProfiler,
|
||||
is_sparse: bool,
|
||||
}
|
||||
|
||||
impl ModelLayer {
|
||||
fn new(name: &str, weights: Vec<Vec<f32>>, bias: Vec<f32>) -> Self {
|
||||
let n = weights.len();
|
||||
Self { name: name.into(), weights, bias, sparse: None, profiler: NeuronProfiler::new(n), is_sparse: false }
|
||||
}
|
||||
fn forward_dense(&self, input: &[f32]) -> Vec<f32> {
|
||||
self.weights.iter().enumerate().map(|(r, row)| dot_bias(row, input, self.bias[r])).collect()
|
||||
}
|
||||
fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
if self.is_sparse { if let Some(ref s) = self.sparse { return s.forward(input); } }
|
||||
self.forward_dense(input)
|
||||
}
|
||||
}
|
||||
|
||||
/// Aggregate parameter / memory / compute statistics for a `SparseModel`.
#[derive(Debug, Clone)]
pub struct ModelStats {
    /// All weights plus biases across every layer.
    pub total_params: usize,
    /// Parameters belonging to hot (computed) rows; equals `total_params`
    /// for layers that have not been sparsified.
    pub hot_params: usize,
    /// Parameters belonging to cold (skipped) rows.
    pub cold_params: usize,
    /// cold_params / total_params (0.0 for an empty model).
    pub sparsity: f32,
    /// Quantization mode the memory estimate assumes.
    pub quant_mode: QuantMode,
    /// Estimated bytes for hot parameters at the mode's per-param width.
    pub est_memory_bytes: usize,
    /// Multiply-accumulate count for one forward pass over hot rows.
    pub est_flops: usize,
}
|
||||
|
||||
/// Full sparse inference engine: profiling + sparsity + quantization.
pub struct SparseModel {
    // Layers in forward order; each owns weights, a profiler, and an
    // optional sparse counterpart once `apply_sparsity` has run.
    layers: Vec<ModelLayer>,
    config: SparseConfig,
    // Set by `profile`; `apply_sparsity` is a no-op until then.
    profiled: bool,
}

impl SparseModel {
    /// Creates an empty model with the given sparsity/quantization config.
    pub fn new(config: SparseConfig) -> Self { Self { layers: vec![], config, profiled: false } }

    /// Appends a dense layer (row-major `weights`, one bias per output row).
    pub fn add_layer(&mut self, name: &str, weights: Vec<Vec<f32>>, bias: Vec<f32>) {
        self.layers.push(ModelLayer::new(name, weights, bias));
    }

    /// Profile activation frequencies over sample inputs.
    ///
    /// Runs at most `config.profile_frames` samples through the dense path.
    /// For each layer the pre-ReLU outputs are recorded (only values > 0
    /// count as active) and the ReLU-activated vector feeds the next layer.
    pub fn profile(&mut self, inputs: &[Vec<f32>]) {
        let n = inputs.len().min(self.config.profile_frames);
        for sample in inputs.iter().take(n) {
            let mut act = sample.clone();
            for layer in &mut self.layers {
                let out = layer.forward_dense(&act);
                for (i, &v) in out.iter().enumerate() { layer.profiler.record_activation(i, v); }
                layer.profiler.end_sample();
                // ReLU before propagating to the next layer.
                act = out.iter().map(|&v| v.max(0.0)).collect();
            }
        }
        self.profiled = true;
    }

    /// Convert layers to sparse using profiled hot/cold partition.
    /// Does nothing if `profile` has not been called yet.
    pub fn apply_sparsity(&mut self) {
        if !self.profiled { return; }
        let th = self.config.hot_threshold;
        for layer in &mut self.layers {
            let (hot, _) = layer.profiler.partition_hot_cold(th);
            layer.sparse = Some(SparseLinear::new(layer.weights.clone(), layer.bias.clone(), hot));
            layer.is_sparse = true;
        }
    }

    /// Quantize weights (stores metadata; actual inference uses original weights).
    ///
    /// Intentionally a no-op: inference keeps the f32 weights, and only
    /// `stats()` reflects the configured mode's memory savings.
    pub fn apply_quantization(&mut self) {
        // Quantization metadata is computed per the config but the sparse forward
        // path uses the original f32 weights for simplicity in this implementation.
        // The stats() method reflects the memory savings.
    }

    /// Forward pass through all layers with ReLU activation.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut act = input.to_vec();
        for layer in &self.layers {
            act = layer.forward(&act).iter().map(|&v| v.max(0.0)).collect();
        }
        act
    }

    /// Number of layers added so far.
    pub fn n_layers(&self) -> usize { self.layers.len() }

    /// Aggregate parameter/FLOP/memory statistics.
    ///
    /// Counts include biases. `est_memory_bytes` covers hot parameters only,
    /// at the byte width implied by the configured quant mode — note Int4 is
    /// counted at 1 byte per parameter (no 4-bit packing here).
    pub fn stats(&self) -> ModelStats {
        let (mut total, mut hot, mut cold, mut flops) = (0, 0, 0, 0);
        for layer in &self.layers {
            let (no, ni) = (layer.weights.len(), layer.weights.first().map_or(0, |r| r.len()));
            let lp = no * ni + no; // weights + biases for this layer
            total += lp;
            if let Some(ref s) = layer.sparse {
                let hc = s.hot_neurons.len();
                hot += hc * ni + hc;
                cold += (no - hc) * ni + (no - hc);
                flops += hc * ni;
            } else { hot += lp; flops += no * ni; }
        }
        // Bytes per parameter for the configured mode (Int4: 1, unpacked).
        let bpp = match self.config.quant_mode {
            QuantMode::F32 => 4, QuantMode::F16 => 2,
            QuantMode::Int8Symmetric | QuantMode::Int8Asymmetric => 1,
            QuantMode::Int4 => 1,
        };
        ModelStats {
            total_params: total, hot_params: hot, cold_params: cold,
            sparsity: if total > 0 { cold as f32 / total as f32 } else { 0.0 },
            quant_mode: self.config.quant_mode, est_memory_bytes: hot * bpp, est_flops: flops,
        }
    }
}
|
||||
|
||||
// ── Benchmark Runner ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Latency/throughput summary from `BenchmarkRunner::benchmark_inference`.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Mean per-call latency in microseconds.
    pub mean_latency_us: f64,
    /// Median latency in microseconds.
    pub p50_us: f64,
    /// 99th-percentile latency in microseconds.
    pub p99_us: f64,
    /// Calls per second over the total measured time
    /// (infinite if no time elapsed, e.g. zero iterations).
    pub throughput_fps: f64,
    /// Estimated model memory from `SparseModel::stats`.
    pub memory_bytes: usize,
}

/// Dense-vs-sparse comparison from `BenchmarkRunner::compare_dense_vs_sparse`.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    /// Mean dense-path latency in microseconds.
    pub dense_latency_us: f64,
    /// Mean sparse-path latency in microseconds.
    pub sparse_latency_us: f64,
    /// dense / sparse latency ratio (1.0 when sparse time is zero).
    pub speedup: f64,
    /// MSE between the final dense and sparse outputs
    /// (0.0 when shapes differ or nothing was run).
    pub accuracy_loss: f32,
}
|
||||
|
||||
pub struct BenchmarkRunner;
|
||||
|
||||
impl BenchmarkRunner {
|
||||
pub fn benchmark_inference(model: &SparseModel, input: &[f32], n: usize) -> BenchmarkResult {
|
||||
let mut lat = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
let t = Instant::now();
|
||||
let _ = model.forward(input);
|
||||
lat.push(t.elapsed().as_micros() as f64);
|
||||
}
|
||||
lat.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let sum: f64 = lat.iter().sum();
|
||||
let mean = sum / lat.len().max(1) as f64;
|
||||
let total_s = sum / 1e6;
|
||||
BenchmarkResult {
|
||||
mean_latency_us: mean,
|
||||
p50_us: pctl(&lat, 50), p99_us: pctl(&lat, 99),
|
||||
throughput_fps: if total_s > 0.0 { n as f64 / total_s } else { f64::INFINITY },
|
||||
memory_bytes: model.stats().est_memory_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compare_dense_vs_sparse(
|
||||
dw: &[Vec<Vec<f32>>], db: &[Vec<f32>], sparse: &SparseModel, input: &[f32], n: usize,
|
||||
) -> ComparisonResult {
|
||||
// Dense timing
|
||||
let mut dl = Vec::with_capacity(n);
|
||||
let mut d_out = Vec::new();
|
||||
for _ in 0..n {
|
||||
let t = Instant::now();
|
||||
let mut a = input.to_vec();
|
||||
for (w, b) in dw.iter().zip(db.iter()) {
|
||||
a = w.iter().enumerate().map(|(r, row)| dot_bias(row, &a, b[r])).collect::<Vec<_>>()
|
||||
.iter().map(|&v| v.max(0.0)).collect();
|
||||
}
|
||||
d_out = a;
|
||||
dl.push(t.elapsed().as_micros() as f64);
|
||||
}
|
||||
// Sparse timing
|
||||
let mut sl = Vec::with_capacity(n);
|
||||
let mut s_out = Vec::new();
|
||||
for _ in 0..n {
|
||||
let t = Instant::now();
|
||||
s_out = sparse.forward(input);
|
||||
sl.push(t.elapsed().as_micros() as f64);
|
||||
}
|
||||
let dm: f64 = dl.iter().sum::<f64>() / dl.len().max(1) as f64;
|
||||
let sm: f64 = sl.iter().sum::<f64>() / sl.len().max(1) as f64;
|
||||
let loss = if !d_out.is_empty() && d_out.len() == s_out.len() {
|
||||
d_out.iter().zip(s_out.iter()).map(|(d, s)| (d - s).powi(2)).sum::<f32>() / d_out.len() as f32
|
||||
} else { 0.0 };
|
||||
ComparisonResult {
|
||||
dense_latency_us: dm, sparse_latency_us: sm,
|
||||
speedup: if sm > 0.0 { dm / sm } else { 1.0 }, accuracy_loss: loss,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Nearest-rank percentile of an ascending-sorted slice.
/// Returns 0.0 for an empty slice; ranks past the end clamp to the last element.
fn pctl(sorted: &[f64], p: usize) -> f64 {
    match sorted.last() {
        None => 0.0,
        Some(&last) => {
            let rank = (p as f64 / 100.0 * (sorted.len() - 1) as f64).round() as usize;
            sorted.get(rank).copied().unwrap_or(last)
        }
    }
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn neuron_profiler_initially_empty() {
|
||||
let p = NeuronProfiler::new(10);
|
||||
assert_eq!(p.total_samples(), 0);
|
||||
assert_eq!(p.activation_frequency(0), 0.0);
|
||||
assert_eq!(p.sparsity_ratio(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neuron_profiler_records_activations() {
|
||||
let mut p = NeuronProfiler::new(4);
|
||||
p.record_activation(0, 1.0); p.record_activation(1, 0.5);
|
||||
p.record_activation(2, 0.1); p.record_activation(3, 0.0);
|
||||
p.end_sample();
|
||||
p.record_activation(0, 2.0); p.record_activation(1, 0.0);
|
||||
p.record_activation(2, 0.0); p.record_activation(3, 0.0);
|
||||
p.end_sample();
|
||||
assert_eq!(p.total_samples(), 2);
|
||||
assert_eq!(p.activation_frequency(0), 1.0);
|
||||
assert_eq!(p.activation_frequency(1), 0.5);
|
||||
assert_eq!(p.activation_frequency(3), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neuron_profiler_hot_cold_partition() {
|
||||
let mut p = NeuronProfiler::new(5);
|
||||
for _ in 0..20 {
|
||||
p.record_activation(0, 1.0); p.record_activation(1, 1.0);
|
||||
p.record_activation(2, 0.0); p.record_activation(3, 0.0);
|
||||
p.record_activation(4, 0.0); p.end_sample();
|
||||
}
|
||||
let (hot, cold) = p.partition_hot_cold(0.5);
|
||||
assert!(hot.contains(&0) && hot.contains(&1));
|
||||
assert!(cold.contains(&2) && cold.contains(&3) && cold.contains(&4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neuron_profiler_sparsity_ratio() {
|
||||
let mut p = NeuronProfiler::new(10);
|
||||
for _ in 0..20 {
|
||||
p.record_activation(0, 1.0); p.record_activation(1, 1.0);
|
||||
for j in 2..10 { p.record_activation(j, 0.0); }
|
||||
p.end_sample();
|
||||
}
|
||||
assert!((p.sparsity_ratio() - 0.8).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sparse_linear_matches_dense() {
|
||||
let w = vec![vec![1.0,2.0,3.0], vec![4.0,5.0,6.0], vec![7.0,8.0,9.0]];
|
||||
let b = vec![0.1, 0.2, 0.3];
|
||||
let layer = SparseLinear::new(w, b, vec![0,1,2]);
|
||||
let inp = vec![1.0, 0.5, -1.0];
|
||||
let (so, do_) = (layer.forward(&inp), layer.forward_full(&inp));
|
||||
for (s, d) in so.iter().zip(do_.iter()) { assert!((s - d).abs() < 1e-6); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sparse_linear_skips_cold_neurons() {
|
||||
let w = vec![vec![1.0,2.0], vec![3.0,4.0], vec![5.0,6.0]];
|
||||
let layer = SparseLinear::new(w, vec![0.0;3], vec![1]);
|
||||
let out = layer.forward(&[1.0, 1.0]);
|
||||
assert_eq!(out[0], 0.0);
|
||||
assert_eq!(out[2], 0.0);
|
||||
assert!((out[1] - 7.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sparse_linear_flops_saved() {
|
||||
let w: Vec<Vec<f32>> = (0..4).map(|_| vec![1.0; 4]).collect();
|
||||
let layer = SparseLinear::new(w, vec![0.0;4], vec![0,2]);
|
||||
assert_eq!(layer.n_flops_saved(), 8);
|
||||
assert!((layer.density() - 0.5).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_symmetric_range() {
|
||||
let qw = Quantizer::quantize_symmetric(&[-1.0, 0.0, 0.5, 1.0]);
|
||||
assert!((qw.scale - 1.0/127.0).abs() < 1e-6);
|
||||
assert_eq!(qw.zero_point, 0);
|
||||
assert_eq!(*qw.data.last().unwrap(), 127);
|
||||
assert_eq!(qw.data[0], -127);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_symmetric_zero_is_zero() {
|
||||
let qw = Quantizer::quantize_symmetric(&[-5.0, 0.0, 3.0, 5.0]);
|
||||
assert_eq!(qw.data[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_asymmetric_range() {
|
||||
let qw = Quantizer::quantize_asymmetric(&[0.0, 0.5, 1.0]);
|
||||
assert!((qw.scale - 1.0/255.0).abs() < 1e-4);
|
||||
assert_eq!(qw.zero_point as u8, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dequantize_round_trip_small_error() {
|
||||
let w: Vec<f32> = (-50..50).map(|i| i as f32 * 0.02).collect();
|
||||
let qw = Quantizer::quantize_symmetric(&w);
|
||||
assert!(Quantizer::quantization_error(&w, &qw) < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn int8_quantization_error_bounded() {
|
||||
let w: Vec<f32> = (0..256).map(|i| (i as f32 * 1.7).sin() * 2.0).collect();
|
||||
assert!(Quantizer::quantization_error(&w, &Quantizer::quantize_symmetric(&w)) < 0.01);
|
||||
assert!(Quantizer::quantization_error(&w, &Quantizer::quantize_asymmetric(&w)) < 0.01);
|
||||
}
|
||||
|
||||
#[test]
fn f16_round_trip_precision() {
    // f16 carries ~11 bits of mantissa, so relative round-trip error should
    // stay under 0.1% for representable magnitudes (65504.0 is the f16 max).
    for &v in &[1.0f32, 0.5, -0.5, 3.14, 100.0, 0.001, -42.0, 65504.0] {
        let enc = Quantizer::f16_quantize(&[v]);
        let dec = Quantizer::f16_dequantize(&enc)[0];
        // Absolute error near zero avoids dividing by ~0.
        let re = if v.abs() > 1e-6 { ((v - dec) / v).abs() } else { (v - dec).abs() };
        assert!(re < 0.001, "f16 error for {v}: decoded={dec}, rel={re}");
    }
}

#[test]
fn f16_special_values() {
    // Zero, +/-infinity, and NaN must survive the f16 round trip unchanged.
    assert_eq!(Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[0.0]))[0], 0.0);
    let inf = Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::INFINITY]))[0];
    assert!(inf.is_infinite() && inf > 0.0);
    let ninf = Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::NEG_INFINITY]))[0];
    assert!(ninf.is_infinite() && ninf < 0.0);
    assert!(Quantizer::f16_dequantize(&Quantizer::f16_quantize(&[f32::NAN]))[0].is_nan());
}
|
||||
|
||||
#[test]
fn sparse_model_add_layers() {
    // Two stacked 2x2 layers; forward([1, 1]) checks the dense path.
    let mut m = SparseModel::new(SparseConfig::default());
    m.add_layer("l1", vec![vec![1.0,2.0],vec![3.0,4.0]], vec![0.0,0.0]);
    m.add_layer("l2", vec![vec![0.5,-0.5],vec![1.0,1.0]], vec![0.1,0.2]);
    assert_eq!(m.n_layers(), 2);
    let out = m.forward(&[1.0, 1.0]);
    assert!(out[0] < 0.001); // ReLU zeros negative
    assert!((out[1] - 10.2).abs() < 0.01);
}

#[test]
fn sparse_model_profile_and_apply() {
    // Rows with negative weights produce ReLU-dead activations on all-positive
    // probe inputs, so profiling should mark some parameters cold.
    let mut m = SparseModel::new(SparseConfig { hot_threshold: 0.3, ..Default::default() });
    m.add_layer("h", vec![
        vec![1.0;4], vec![0.5;4], vec![-2.0;4], vec![-1.0;4],
    ], vec![0.0;4]);
    let inp: Vec<Vec<f32>> = (0..50).map(|i| vec![1.0 + i as f32 * 0.01; 4]).collect();
    m.profile(&inp);
    m.apply_sparsity();
    let s = m.stats();
    assert!(s.cold_params > 0);
    assert!(s.sparsity > 0.0);
}

#[test]
fn sparse_model_stats_report() {
    // A single 16x8 layer: params = weights + biases; default quant mode is int8.
    let mut m = SparseModel::new(SparseConfig::default());
    m.add_layer("fc1", vec![vec![1.0;8];16], vec![0.0;16]);
    let s = m.stats();
    assert_eq!(s.total_params, 16*8+16);
    assert_eq!(s.quant_mode, QuantMode::Int8Symmetric);
    assert!(s.est_flops > 0 && s.est_memory_bytes > 0);
}

#[test]
fn benchmark_produces_positive_latency() {
    // Smoke test: 10 timed inferences must yield sane latency/throughput numbers.
    let mut m = SparseModel::new(SparseConfig::default());
    m.add_layer("fc1", vec![vec![1.0;4];4], vec![0.0;4]);
    let r = BenchmarkRunner::benchmark_inference(&m, &[1.0;4], 10);
    assert!(r.mean_latency_us >= 0.0 && r.throughput_fps > 0.0);
}

#[test]
fn compare_dense_sparse_speedup() {
    // Dense baseline vs. a profiled model whose lower half is ReLU-dead
    // (rows forced to -1.0), so sparsity can actually kick in.
    let w = vec![vec![1.0f32;8];16];
    let b = vec![0.0f32;16];
    let mut pm = SparseModel::new(SparseConfig { hot_threshold: 0.5, quant_mode: QuantMode::F32, profile_frames: 20 });
    let mut pw: Vec<Vec<f32>> = w.clone();
    for row in pw.iter_mut().skip(8) { for v in row.iter_mut() { *v = -1.0; } }
    pm.add_layer("fc1", pw, b.clone());
    let inp: Vec<Vec<f32>> = (0..20).map(|_| vec![1.0;8]).collect();
    pm.profile(&inp); pm.apply_sparsity();
    let r = BenchmarkRunner::compare_dense_vs_sparse(&[w], &[b], &pm, &[1.0;8], 50);
    // Timings are environment-dependent: only assert they are well-formed.
    assert!(r.dense_latency_us >= 0.0 && r.sparse_latency_us >= 0.0);
    assert!(r.speedup > 0.0);
    assert!(r.accuracy_loss.is_finite());
}
|
||||
}
|
||||
@@ -0,0 +1,682 @@
|
||||
//! Training loop with multi-term loss function for WiFi DensePose (ADR-023 Phase 4).
|
||||
//!
|
||||
//! 6-term composite loss, SGD with momentum, cosine annealing LR scheduler,
|
||||
//! PCK/OKS validation metrics, numerical gradient estimation, and checkpointing.
|
||||
//! All arithmetic uses f32. No external ML framework dependencies.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
/// Standard COCO keypoint sigmas for OKS (17 keypoints).
///
/// Per-keypoint falloff constants from the COCO keypoint evaluation protocol,
/// consumed as the `sigmas` argument of [`oks_single`].
pub const COCO_KEYPOINT_SIGMAS: [f32; 17] = [
    0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062,
    0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089,
];

/// Symmetric keypoint pairs (left, right) indices into 17-keypoint COCO layout.
/// Pairs cover shoulders, elbows, wrists, hips, and knees.
/// NOTE(review): the ankle pair (15, 16) is not listed — confirm the omission
/// is intentional.
const SYMMETRY_PAIRS: [(usize, usize); 5] =
    [(5, 6), (7, 8), (9, 10), (11, 12), (13, 14)];
|
||||
|
||||
/// Individual loss terms from the 6-component composite loss.
///
/// Each field holds the *unweighted* value of one term; see [`composite_loss`]
/// for how they are combined with [`LossWeights`].
#[derive(Debug, Clone, Default)]
pub struct LossComponents {
    /// Keypoint regression MSE over (x, y, confidence).
    pub keypoint: f32,
    /// Body-part classification cross-entropy.
    pub body_part: f32,
    /// UV coordinate L1 regression loss.
    pub uv: f32,
    /// Frame-to-frame temporal consistency penalty.
    pub temporal: f32,
    /// Bone-length (graph edge) deviation penalty.
    pub edge: f32,
    /// Left/right limb symmetry penalty.
    pub symmetry: f32,
}
|
||||
|
||||
/// Per-term weights for the composite loss function.
///
/// Field names mirror [`LossComponents`]; each weight multiplies the matching
/// term in [`composite_loss`].
#[derive(Debug, Clone)]
pub struct LossWeights {
    pub keypoint: f32,
    pub body_part: f32,
    pub uv: f32,
    pub temporal: f32,
    pub edge: f32,
    pub symmetry: f32,
}

impl Default for LossWeights {
    /// Default weighting: keypoint regression dominates; the regularizing
    /// terms (temporal, edge, symmetry) get small weights.
    fn default() -> Self {
        Self { keypoint: 1.0, body_part: 0.5, uv: 0.5, temporal: 0.1, edge: 0.2, symmetry: 0.1 }
    }
}
|
||||
|
||||
/// Mean squared error on keypoints (x, y, confidence).
///
/// Averages the squared component-wise differences over the shorter of the
/// two slices; returns 0.0 when either input is empty.
pub fn keypoint_mse(pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)]) -> f32 {
    if pred.is_empty() || target.is_empty() {
        return 0.0;
    }
    let n = pred.len().min(target.len());
    let mut total = 0.0f32;
    for (p, t) in pred.iter().zip(target.iter()) {
        let dx = p.0 - t.0;
        let dy = p.1 - t.1;
        let dc = p.2 - t.2;
        total += dx * dx + dy * dy + dc * dc;
    }
    total / n as f32
}
|
||||
|
||||
/// Cross-entropy loss for body part classification.
/// `pred` = raw logits (length `n_samples * n_parts`), `target` = class indices.
///
/// Uses the log-sum-exp trick for numerical stability. Samples with an
/// out-of-range label are skipped but still counted in the denominator.
pub fn body_part_cross_entropy(pred: &[f32], target: &[u8], n_parts: usize) -> f32 {
    if target.is_empty() || n_parts == 0 || pred.len() < n_parts {
        return 0.0;
    }
    let n_samples = target.len().min(pred.len() / n_parts);
    if n_samples == 0 {
        return 0.0;
    }
    let mut total = 0.0f32;
    for (logits, &label) in pred.chunks(n_parts).zip(target.iter()).take(n_samples) {
        let class = label as usize;
        if class >= n_parts {
            continue; // invalid label: contributes nothing
        }
        let max_l = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let lse = logits.iter().map(|&l| (l - max_l).exp()).sum::<f32>().ln() + max_l;
        total += lse - logits[class];
    }
    total / n_samples as f32
}
|
||||
|
||||
/// L1 loss on UV coordinates.
///
/// Averages `|pu - tu| + |pv - tv|` over the shortest of the four slices;
/// returns 0.0 when any slice is empty.
pub fn uv_regression_loss(pu: &[f32], pv: &[f32], tu: &[f32], tv: &[f32]) -> f32 {
    let n = pu.len().min(pv.len()).min(tu.len()).min(tv.len());
    if n == 0 {
        return 0.0;
    }
    let mut total = 0.0f32;
    for i in 0..n {
        total += (pu[i] - tu[i]).abs() + (pv[i] - tv[i]).abs();
    }
    total / n as f32
}
|
||||
|
||||
/// Temporal consistency loss: penalizes large frame-to-frame keypoint jumps.
///
/// Mean squared (x, y) displacement between matching keypoints of two frames;
/// confidence values are ignored.
pub fn temporal_consistency_loss(prev: &[(f32, f32, f32)], curr: &[(f32, f32, f32)]) -> f32 {
    let n = prev.len().min(curr.len());
    if n == 0 {
        return 0.0;
    }
    let mut total = 0.0f32;
    for (p, c) in prev.iter().zip(curr.iter()) {
        let dx = c.0 - p.0;
        let dy = c.1 - p.1;
        total += dx * dx + dy * dy;
    }
    total / n as f32
}
|
||||
|
||||
/// Graph edge loss: penalizes deviation of bone lengths from expected values.
///
/// `edges` and `expected` must be parallel, non-empty, and equal-length or
/// 0.0 is returned; edges referencing out-of-range keypoints are skipped.
pub fn graph_edge_loss(
    kp: &[(f32, f32, f32)], edges: &[(usize, usize)], expected: &[f32],
) -> f32 {
    if edges.is_empty() || edges.len() != expected.len() {
        return 0.0;
    }
    let mut total = 0.0f32;
    let mut used = 0usize;
    for (&(a, b), &want) in edges.iter().zip(expected.iter()) {
        let (pa, pb) = match (kp.get(a), kp.get(b)) {
            (Some(pa), Some(pb)) => (pa, pb),
            _ => continue, // edge endpoint out of range
        };
        let got = ((pa.0 - pb.0).powi(2) + (pa.1 - pb.1).powi(2)).sqrt();
        total += (got - want).powi(2);
        used += 1;
    }
    if used == 0 { 0.0 } else { total / used as f32 }
}
|
||||
|
||||
/// Symmetry loss: penalizes asymmetry between left-right limb pairs.
|
||||
pub fn symmetry_loss(kp: &[(f32, f32, f32)]) -> f32 {
|
||||
if kp.len() < 15 { return 0.0; }
|
||||
let (mut sum, mut cnt) = (0.0f32, 0usize);
|
||||
for &(l, r) in &SYMMETRY_PAIRS {
|
||||
if l >= kp.len() || r >= kp.len() { continue; }
|
||||
let ld = ((kp[l].0 - kp[0].0).powi(2) + (kp[l].1 - kp[0].1).powi(2)).sqrt();
|
||||
let rd = ((kp[r].0 - kp[0].0).powi(2) + (kp[r].1 - kp[0].1).powi(2)).sqrt();
|
||||
sum += (ld - rd).powi(2);
|
||||
cnt += 1;
|
||||
}
|
||||
if cnt == 0 { 0.0 } else { sum / cnt as f32 }
|
||||
}
|
||||
|
||||
/// Weighted composite loss from individual components.
|
||||
pub fn composite_loss(c: &LossComponents, w: &LossWeights) -> f32 {
|
||||
w.keypoint * c.keypoint + w.body_part * c.body_part + w.uv * c.uv
|
||||
+ w.temporal * c.temporal + w.edge * c.edge + w.symmetry * c.symmetry
|
||||
}
|
||||
|
||||
// ── Optimizer ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// SGD optimizer with momentum and weight decay.
pub struct SgdOptimizer {
    lr: f32,
    momentum: f32,
    weight_decay: f32,
    /// Momentum buffer; lazily (re)sized to match the parameter count.
    velocity: Vec<f32>,
}

impl SgdOptimizer {
    /// Create an optimizer with the given learning rate, momentum, and L2 decay.
    pub fn new(lr: f32, momentum: f32, weight_decay: f32) -> Self {
        Self { lr, momentum, weight_decay, velocity: Vec::new() }
    }

    /// One update step: v = mu*v + grad + wd*param; param -= lr*v.
    ///
    /// Only the first `min(params.len(), gradients.len())` entries are touched.
    pub fn step(&mut self, params: &mut [f32], gradients: &[f32]) {
        if self.velocity.len() != params.len() {
            // Reset the momentum buffer whenever the parameter count changes.
            self.velocity = vec![0.0; params.len()];
        }
        let lanes = params.iter_mut().zip(self.velocity.iter_mut()).zip(gradients.iter());
        for ((p, v), &g) in lanes {
            let adjusted = g + self.weight_decay * *p;
            *v = self.momentum * *v + adjusted;
            *p -= self.lr * *v;
        }
    }

    /// Change the learning rate (called by the scheduler each epoch).
    pub fn set_lr(&mut self, lr: f32) { self.lr = lr; }
    /// Copy of the momentum buffer, for checkpointing.
    pub fn state(&self) -> Vec<f32> { self.velocity.clone() }
    /// Restore a previously saved momentum buffer.
    pub fn load_state(&mut self, state: Vec<f32>) { self.velocity = state; }
}
|
||||
|
||||
// ── Learning rate schedulers ───────────────────────────────────────────────
|
||||
|
||||
/// Cosine annealing: decays LR from initial to min over total_steps.
pub struct CosineScheduler { initial_lr: f32, min_lr: f32, total_steps: usize }

impl CosineScheduler {
    /// Schedule from `initial_lr` down to `min_lr` across `total_steps` steps.
    pub fn new(initial_lr: f32, min_lr: f32, total_steps: usize) -> Self {
        Self { initial_lr, min_lr, total_steps }
    }

    /// LR at `step`; steps past the end clamp to `min_lr`, and a zero-length
    /// schedule always returns `initial_lr`.
    pub fn get_lr(&self, step: usize) -> f32 {
        if self.total_steps == 0 {
            return self.initial_lr;
        }
        let progress = step.min(self.total_steps) as f32 / self.total_steps as f32;
        let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
        self.min_lr + (self.initial_lr - self.min_lr) * cosine
    }
}
|
||||
|
||||
/// Warmup + cosine annealing: linear ramp 0->initial_lr then cosine decay.
pub struct WarmupCosineScheduler {
    warmup_steps: usize, initial_lr: f32, min_lr: f32, total_steps: usize,
}

impl WarmupCosineScheduler {
    /// Ramp linearly for `warmup_steps`, then cosine-decay to `min_lr` by `total_steps`.
    pub fn new(warmup_steps: usize, initial_lr: f32, min_lr: f32, total_steps: usize) -> Self {
        Self { warmup_steps, initial_lr, min_lr, total_steps }
    }

    /// LR at `step`: linear warmup first, then the cosine tail.
    pub fn get_lr(&self, step: usize) -> f32 {
        // Linear warmup phase.
        if step < self.warmup_steps {
            if self.warmup_steps == 0 {
                return self.initial_lr;
            }
            return self.initial_lr * (step as f32 / self.warmup_steps as f32);
        }
        // Cosine decay over the remaining steps.
        let decay_span = self.total_steps.saturating_sub(self.warmup_steps);
        if decay_span == 0 {
            return self.min_lr;
        }
        let progress = (step - self.warmup_steps).min(decay_span) as f32 / decay_span as f32;
        let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
        self.min_lr + (self.initial_lr - self.min_lr) * cosine
    }
}
|
||||
|
||||
// ── Validation metrics ─────────────────────────────────────────────────────
|
||||
|
||||
/// Percentage of Correct Keypoints at a distance threshold.
///
/// Only target keypoints with positive visibility (third tuple element > 0)
/// are counted; returns 0.0 when nothing is visible.
pub fn pck_at_threshold(pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)], thr: f32) -> f32 {
    let mut correct = 0usize;
    let mut visible = 0usize;
    for (p, t) in pred.iter().zip(target.iter()) {
        if t.2 <= 0.0 {
            continue; // skip invisible / unlabeled keypoints
        }
        visible += 1;
        let dist = ((p.0 - t.0).powi(2) + (p.1 - t.1).powi(2)).sqrt();
        if dist <= thr {
            correct += 1;
        }
    }
    if visible == 0 { 0.0 } else { correct as f32 / visible as f32 }
}
|
||||
|
||||
/// Object Keypoint Similarity for a single instance.
///
/// Gaussian falloff per keypoint, widened by the per-keypoint `sigmas` and
/// the object `area`; only target keypoints with positive visibility
/// contribute. Returns 0.0 for a non-positive area or no visible keypoints.
pub fn oks_single(
    pred: &[(f32, f32, f32)], target: &[(f32, f32, f32)], sigmas: &[f32], area: f32,
) -> f32 {
    if area <= 0.0 {
        return 0.0;
    }
    let mut total = 0.0f32;
    let mut visible = 0usize;
    for ((p, t), &sigma) in pred.iter().zip(target.iter()).zip(sigmas.iter()) {
        if t.2 <= 0.0 {
            continue;
        }
        visible += 1;
        let dsq = (p.0 - t.0).powi(2) + (p.1 - t.1).powi(2);
        let var = 2.0 * sigma * sigma * area;
        if var > 0.0 {
            total += (-dsq / (2.0 * var)).exp();
        }
    }
    if visible == 0 { 0.0 } else { total / visible as f32 }
}
|
||||
|
||||
/// Mean OKS over multiple predictions (simplified mAP).
|
||||
pub fn oks_map(preds: &[Vec<(f32, f32, f32)>], targets: &[Vec<(f32, f32, f32)>]) -> f32 {
|
||||
let n = preds.len().min(targets.len());
|
||||
if n == 0 { return 0.0; }
|
||||
let s: f32 = preds.iter().zip(targets.iter()).take(n)
|
||||
.map(|(p, t)| oks_single(p, t, &COCO_KEYPOINT_SIGMAS, 1.0)).sum();
|
||||
s / n as f32
|
||||
}
|
||||
|
||||
// ── Gradient estimation ────────────────────────────────────────────────────
|
||||
|
||||
/// Central difference gradient: (f(x+eps) - f(x-eps)) / (2*eps).
///
/// Evaluates `f` twice per parameter using a single reusable probe buffer;
/// `params` itself is never modified.
pub fn estimate_gradient(f: impl Fn(&[f32]) -> f32, params: &[f32], eps: f32) -> Vec<f32> {
    let mut probe = params.to_vec();
    let mut grad = Vec::with_capacity(params.len());
    for i in 0..params.len() {
        probe[i] = params[i] + eps;
        let upper = f(&probe);
        probe[i] = params[i] - eps;
        let lower = f(&probe);
        probe[i] = params[i]; // restore before moving to the next coordinate
        grad.push((upper - lower) / (2.0 * eps));
    }
    grad
}
|
||||
|
||||
/// Clip gradients by global L2 norm.
///
/// When the L2 norm exceeds `max_norm`, every component is rescaled so the
/// norm equals `max_norm`; otherwise the slice is left untouched.
pub fn clip_gradients(gradients: &mut [f32], max_norm: f32) {
    let norm = gradients
        .iter()
        .map(|g| g * g)
        .sum::<f32>()
        .sqrt();
    if norm > max_norm && norm > 0.0 {
        let scale = max_norm / norm;
        for g in gradients.iter_mut() {
            *g *= scale;
        }
    }
}
|
||||
|
||||
// ── Training sample ────────────────────────────────────────────────────────
|
||||
|
||||
/// A single training sample (defined locally, not dependent on dataset.rs).
#[derive(Debug, Clone)]
pub struct TrainingSample {
    /// CSI feature matrix; consumers flatten it row-major (see `Trainer::predict_keypoints`).
    pub csi_features: Vec<Vec<f32>>,
    /// Ground-truth keypoints as (x, y, confidence/visibility) triples.
    pub target_keypoints: Vec<(f32, f32, f32)>,
    /// Ground-truth body-part class indices (fed to `body_part_cross_entropy`).
    pub target_body_parts: Vec<u8>,
    /// Ground-truth UV coordinates as parallel (u, v) vectors.
    pub target_uv: (Vec<f32>, Vec<f32>),
}
|
||||
|
||||
// ── Checkpoint ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serializable version of EpochStats for checkpoint storage.
/// The loss components are flattened into individual fields so the JSON stays flat.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EpochStatsSerializable {
    pub epoch: usize, pub train_loss: f32, pub val_loss: f32,
    pub pck_02: f32, pub oks_map: f32, pub lr: f32,
    pub loss_keypoint: f32, pub loss_body_part: f32, pub loss_uv: f32,
    pub loss_temporal: f32, pub loss_edge: f32, pub loss_symmetry: f32,
}

/// Serializable training checkpoint.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Checkpoint {
    /// Number of epochs completed when the checkpoint was taken.
    pub epoch: usize,
    /// Flat model parameter vector.
    pub params: Vec<f32>,
    /// SGD momentum (velocity) buffer, parallel to `params`.
    pub optimizer_state: Vec<f32>,
    /// Best validation loss observed so far.
    pub best_loss: f32,
    /// Metrics from the most recent epoch (zeroed if none has run).
    pub metrics: EpochStatsSerializable,
}
|
||||
|
||||
impl Checkpoint {
|
||||
pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> {
|
||||
let json = serde_json::to_string_pretty(self)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
std::fs::write(path, json)
|
||||
}
|
||||
pub fn load_from_file(path: &Path) -> std::io::Result<Self> {
|
||||
let json = std::fs::read_to_string(path)?;
|
||||
serde_json::from_str(&json)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochStats {
    /// Epoch index (0-based; equals its position in the trainer's history).
    pub epoch: usize,
    /// Weighted composite loss on the training set.
    pub train_loss: f32,
    /// Weighted composite loss on the validation set (training loss if none).
    pub val_loss: f32,
    /// Percentage of Correct Keypoints at threshold 0.2.
    pub pck_02: f32,
    /// Simplified OKS mean average precision.
    pub oks_map: f32,
    /// Learning rate used during this epoch.
    pub lr: f32,
    /// Unweighted per-term loss values.
    pub loss_components: LossComponents,
}

impl EpochStats {
    /// Flatten into the serde-friendly form stored inside a [`Checkpoint`].
    fn to_serializable(&self) -> EpochStatsSerializable {
        let c = &self.loss_components;
        EpochStatsSerializable {
            epoch: self.epoch, train_loss: self.train_loss, val_loss: self.val_loss,
            pck_02: self.pck_02, oks_map: self.oks_map, lr: self.lr,
            loss_keypoint: c.keypoint, loss_body_part: c.body_part, loss_uv: c.uv,
            loss_temporal: c.temporal, loss_edge: c.edge, loss_symmetry: c.symmetry,
        }
    }
}
|
||||
|
||||
/// Final result from a complete training run.
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Epoch with the lowest validation loss.
    pub best_epoch: usize,
    /// PCK@0.2 recorded at the best epoch.
    pub best_pck: f32,
    /// OKS mAP recorded at the best epoch.
    pub best_oks: f32,
    /// Stats for every completed epoch, in order.
    pub history: Vec<EpochStats>,
    /// Wall-clock duration of the whole run in seconds.
    pub total_time_secs: f64,
}
|
||||
|
||||
/// Configuration for the training loop.
#[derive(Debug, Clone)]
pub struct TrainerConfig {
    /// Maximum number of epochs to run.
    pub epochs: usize,
    /// Mini-batch size (clamped to >= 1 at use sites).
    pub batch_size: usize,
    /// Peak learning rate (reached after warmup).
    pub lr: f32,
    /// SGD momentum coefficient.
    pub momentum: f32,
    /// L2 weight-decay coefficient.
    pub weight_decay: f32,
    /// Number of linear-warmup epochs before cosine decay.
    pub warmup_epochs: usize,
    /// Floor the cosine schedule decays to.
    pub min_lr: f32,
    /// Consecutive non-improving epochs tolerated before early stop.
    pub early_stop_patience: usize,
    /// Checkpoint cadence in epochs.
    pub checkpoint_every: usize,
    /// Per-term weights for the composite loss.
    pub loss_weights: LossWeights,
}

impl Default for TrainerConfig {
    fn default() -> Self {
        Self {
            epochs: 100, batch_size: 32, lr: 0.01, momentum: 0.9, weight_decay: 1e-4,
            warmup_epochs: 5, min_lr: 1e-6, early_stop_patience: 10, checkpoint_every: 10,
            loss_weights: LossWeights::default(),
        }
    }
}
|
||||
|
||||
// ── Trainer ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Training loop orchestrator for WiFi DensePose pose estimation.
pub struct Trainer {
    /// Hyperparameters for the run.
    config: TrainerConfig,
    /// SGD with momentum and weight decay.
    optimizer: SgdOptimizer,
    /// Warmup + cosine LR schedule, stepped once per epoch.
    scheduler: WarmupCosineScheduler,
    /// Flat parameter vector of the model (64 values; see `Trainer::new`).
    params: Vec<f32>,
    /// Per-epoch stats, one entry per completed epoch.
    history: Vec<EpochStats>,
    /// Lowest validation loss seen so far.
    best_val_loss: f32,
    /// Epoch index (== history index) of the best validation loss.
    best_epoch: usize,
    /// Consecutive epochs without val-loss improvement (drives early stopping).
    epochs_without_improvement: usize,
}
|
||||
|
||||
impl Trainer {
    /// Build a trainer from `config`.
    ///
    /// The optimizer and warmup+cosine scheduler are derived from the config,
    /// and the 64 model parameters are initialized with a deterministic
    /// sinusoidal pattern (no RNG, so runs are reproducible).
    pub fn new(config: TrainerConfig) -> Self {
        let optimizer = SgdOptimizer::new(config.lr, config.momentum, config.weight_decay);
        let scheduler = WarmupCosineScheduler::new(
            config.warmup_epochs, config.lr, config.min_lr, config.epochs,
        );
        let params: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7 + 0.3).sin() * 0.1).collect();
        Self {
            config, optimizer, scheduler, params, history: Vec::new(),
            best_val_loss: f32::MAX, best_epoch: 0, epochs_without_improvement: 0,
        }
    }

    /// Run one epoch of mini-batch SGD over `samples` and record the stats.
    ///
    /// Gradients are estimated numerically (central differences) against a
    /// snapshot of the parameters, clipped to unit L2 norm, then applied.
    /// The returned stats use the training loss as a stand-in `val_loss`;
    /// `run_training` overwrites it when a validation set is available.
    pub fn train_epoch(&mut self, samples: &[TrainingSample]) -> EpochStats {
        // The epoch index doubles as the scheduler step.
        let epoch = self.history.len();
        let lr = self.scheduler.get_lr(epoch);
        self.optimizer.set_lr(lr);

        let mut acc = LossComponents::default();
        let bs = self.config.batch_size.max(1); // guard against batch_size == 0
        let nb = (samples.len() + bs - 1) / bs; // ceiling division

        for bi in 0..nb {
            let batch = &samples[bi * bs..(bi * bs + bs).min(samples.len())];
            // Snapshot params so the finite-difference probes see a fixed model.
            let snap = self.params.clone();
            let w = self.config.loss_weights.clone();
            let loss_fn = |p: &[f32]| Self::batch_loss(p, batch, &w);
            let mut grad = estimate_gradient(loss_fn, &snap, 1e-4);
            clip_gradients(&mut grad, 1.0);
            self.optimizer.step(&mut self.params, &grad);

            // Re-evaluate per-term losses with the *updated* parameters.
            let c = Self::batch_loss_components(&self.params, batch);
            acc.keypoint += c.keypoint;
            acc.body_part += c.body_part;
            acc.uv += c.uv;
            acc.temporal += c.temporal;
            acc.edge += c.edge;
            acc.symmetry += c.symmetry;
        }

        // Average the accumulated terms over the number of batches.
        if nb > 0 {
            let inv = 1.0 / nb as f32;
            acc.keypoint *= inv; acc.body_part *= inv; acc.uv *= inv;
            acc.temporal *= inv; acc.edge *= inv; acc.symmetry *= inv;
        }

        let train_loss = composite_loss(&acc, &self.config.loss_weights);
        let (pck, oks) = self.evaluate_metrics(samples);
        let stats = EpochStats {
            epoch, train_loss, val_loss: train_loss, pck_02: pck, oks_map: oks,
            lr, loss_components: acc,
        };
        self.history.push(stats.clone());
        stats
    }

    /// True once `early_stop_patience` consecutive epochs showed no improvement.
    pub fn should_stop(&self) -> bool {
        self.epochs_without_improvement >= self.config.early_stop_patience
    }

    /// Stats of the best epoch so far (epoch index == position in `history`).
    pub fn best_metrics(&self) -> Option<&EpochStats> {
        self.history.get(self.best_epoch)
    }

    /// Full training loop: per-epoch validation, best-epoch tracking, and
    /// early stopping. Returns the metrics of the best epoch plus history.
    pub fn run_training(&mut self, train: &[TrainingSample], val: &[TrainingSample]) -> TrainingResult {
        let start = std::time::Instant::now();
        for _ in 0..self.config.epochs {
            let mut stats = self.train_epoch(train);
            // Validation loss; falls back to the training loss when no val set.
            let val_loss = if !val.is_empty() {
                let c = Self::batch_loss_components(&self.params, val);
                composite_loss(&c, &self.config.loss_weights)
            } else { stats.train_loss };
            stats.val_loss = val_loss;
            if !val.is_empty() {
                let (pck, oks) = self.evaluate_metrics(val);
                stats.pck_02 = pck;
                stats.oks_map = oks;
            }
            // Mirror the validation metrics into the history entry that
            // train_epoch already pushed, keeping both views consistent.
            if let Some(last) = self.history.last_mut() {
                last.val_loss = stats.val_loss;
                last.pck_02 = stats.pck_02;
                last.oks_map = stats.oks_map;
            }
            if val_loss < self.best_val_loss {
                self.best_val_loss = val_loss;
                self.best_epoch = stats.epoch;
                self.epochs_without_improvement = 0;
            } else {
                self.epochs_without_improvement += 1;
            }
            if self.should_stop() { break; }
        }
        // Sentinel stats cover the pathological zero-epoch case.
        let best = self.best_metrics().cloned().unwrap_or(EpochStats {
            epoch: 0, train_loss: f32::MAX, val_loss: f32::MAX, pck_02: 0.0,
            oks_map: 0.0, lr: self.config.lr, loss_components: LossComponents::default(),
        });
        TrainingResult {
            best_epoch: best.epoch, best_pck: best.pck_02, best_oks: best.oks_map,
            history: self.history.clone(), total_time_secs: start.elapsed().as_secs_f64(),
        }
    }

    /// Snapshot the current state (params, optimizer velocity, last metrics).
    pub fn checkpoint(&self) -> Checkpoint {
        // Zeroed metrics are used when no epoch has run yet.
        let m = self.history.last().map(|s| s.to_serializable()).unwrap_or(
            EpochStatsSerializable {
                epoch: 0, train_loss: 0.0, val_loss: 0.0, pck_02: 0.0,
                oks_map: 0.0, lr: self.config.lr, loss_keypoint: 0.0, loss_body_part: 0.0,
                loss_uv: 0.0, loss_temporal: 0.0, loss_edge: 0.0, loss_symmetry: 0.0,
            },
        );
        Checkpoint {
            epoch: self.history.len(), params: self.params.clone(),
            optimizer_state: self.optimizer.state(), best_loss: self.best_val_loss, metrics: m,
        }
    }

    /// Scalar composite loss for a batch; target of the finite differences.
    fn batch_loss(params: &[f32], batch: &[TrainingSample], w: &LossWeights) -> f32 {
        composite_loss(&Self::batch_loss_components(params, batch), w)
    }

    /// Per-term losses averaged over a batch.
    ///
    /// NOTE(review): the `edge` term is never accumulated here and stays 0.0;
    /// confirm whether `graph_edge_loss` is intentionally unused in training.
    fn batch_loss_components(params: &[f32], batch: &[TrainingSample]) -> LossComponents {
        if batch.is_empty() { return LossComponents::default(); }
        let mut acc = LossComponents::default();
        let mut prev_kp: Option<Vec<(f32, f32, f32)>> = None;
        for sample in batch {
            let pred_kp = Self::predict_keypoints(params, sample);
            acc.keypoint += keypoint_mse(&pred_kp, &sample.target_keypoints);
            // Synthetic per-part logits derived from the parameters (24 parts).
            let n_parts = 24usize;
            let logits: Vec<f32> = sample.target_body_parts.iter().flat_map(|_| {
                (0..n_parts).map(|j| if j < params.len() { params[j] * 0.1 } else { 0.0 })
                    .collect::<Vec<f32>>()
            }).collect();
            acc.body_part += body_part_cross_entropy(&logits, &sample.target_body_parts, n_parts);
            // UV "predictions" are the targets nudged by the parameters.
            let (ref tu, ref tv) = sample.target_uv;
            let pu: Vec<f32> = tu.iter().enumerate()
                .map(|(i, &u)| u + if i < params.len() { params[i] * 0.01 } else { 0.0 }).collect();
            let pv: Vec<f32> = tv.iter().enumerate()
                .map(|(i, &v)| v + if i < params.len() { params[i] * 0.01 } else { 0.0 }).collect();
            acc.uv += uv_regression_loss(&pu, &pv, tu, tv);
            // Temporal term compares consecutive samples within the batch.
            if let Some(ref prev) = prev_kp {
                acc.temporal += temporal_consistency_loss(prev, &pred_kp);
            }
            acc.symmetry += symmetry_loss(&pred_kp);
            prev_kp = Some(pred_kp);
        }
        let inv = 1.0 / batch.len() as f32;
        acc.keypoint *= inv; acc.body_part *= inv; acc.uv *= inv;
        acc.temporal *= inv; acc.symmetry *= inv;
        acc
    }

    /// Linear "model": projects the flattened CSI features through the
    /// parameter vector to produce (x, y, confidence) for each keypoint.
    fn predict_keypoints(params: &[f32], sample: &TrainingSample) -> Vec<(f32, f32, f32)> {
        // Always emit at least the 17 COCO keypoints.
        let n_kp = sample.target_keypoints.len().max(17);
        let feats: Vec<f32> = sample.csi_features.iter().flat_map(|v| v.iter().copied()).collect();
        (0..n_kp).map(|k| {
            // Each keypoint draws from a 3-wide window of the params (x, y, conf).
            let base = k * 3;
            let (mut x, mut y) = (0.0f32, 0.0f32);
            for (i, &f) in feats.iter().take(params.len()).enumerate() {
                let pi = (base + i) % params.len();
                x += f * params[pi] * 0.01;
                y += f * params[(pi + 1) % params.len()] * 0.01;
            }
            // Bias terms taken directly from the parameter vector.
            if base < params.len() {
                x += params[base % params.len()];
                y += params[(base + 1) % params.len()];
            }
            let c = if base + 2 < params.len() {
                params[(base + 2) % params.len()].clamp(0.0, 1.0)
            } else { 0.5 };
            (x, y, c)
        }).collect()
    }

    /// Mean PCK@0.2 and simplified OKS mAP over `samples`.
    fn evaluate_metrics(&self, samples: &[TrainingSample]) -> (f32, f32) {
        if samples.is_empty() { return (0.0, 0.0); }
        let preds: Vec<Vec<_>> = samples.iter().map(|s| Self::predict_keypoints(&self.params, s)).collect();
        let targets: Vec<Vec<_>> = samples.iter().map(|s| s.target_keypoints.clone()).collect();
        let pck = preds.iter().zip(targets.iter())
            .map(|(p, t)| pck_at_threshold(p, t, 0.2)).sum::<f32>() / samples.len() as f32;
        (pck, oks_map(&preds, &targets))
    }
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// 17 keypoints on a diagonal, shifted by `off`; all fully visible.
    fn mkp(off: f32) -> Vec<(f32, f32, f32)> {
        (0..17).map(|i| (i as f32 + off, i as f32 * 2.0 + off, 1.0)).collect()
    }

    /// Pose whose left/right pairs are mirrored around the root at (5, 5),
    /// so symmetry_loss should be ~0.
    fn symmetric_pose() -> Vec<(f32, f32, f32)> {
        let mut kp = vec![(0.0f32, 0.0f32, 1.0f32); 17];
        kp[0] = (5.0, 5.0, 1.0);
        for &(l, r) in &SYMMETRY_PAIRS { kp[l] = (3.0, 5.0, 1.0); kp[r] = (7.0, 5.0, 1.0); }
        kp
    }

    /// Minimal well-formed training sample.
    fn sample() -> TrainingSample {
        TrainingSample {
            csi_features: vec![vec![1.0; 8]; 4],
            target_keypoints: mkp(0.0),
            target_body_parts: vec![0, 1, 2, 3],
            target_uv: (vec![0.5; 4], vec![0.5; 4]),
        }
    }

    #[test] fn keypoint_mse_zero_for_identical() { assert_eq!(keypoint_mse(&mkp(0.0), &mkp(0.0)), 0.0); }
    #[test] fn keypoint_mse_positive_for_different() { assert!(keypoint_mse(&mkp(0.0), &mkp(1.0)) > 0.0); }
    #[test] fn keypoint_mse_symmetric() {
        // MSE is symmetric in its arguments.
        let (ab, ba) = (keypoint_mse(&mkp(0.0), &mkp(1.0)), keypoint_mse(&mkp(1.0), &mkp(0.0)));
        assert!((ab - ba).abs() < 1e-6, "{ab} vs {ba}");
    }
    #[test] fn temporal_consistency_zero_for_static() {
        assert_eq!(temporal_consistency_loss(&mkp(0.0), &mkp(0.0)), 0.0);
    }
    #[test] fn temporal_consistency_positive_for_motion() {
        assert!(temporal_consistency_loss(&mkp(0.0), &mkp(1.0)) > 0.0);
    }
    #[test] fn symmetry_loss_zero_for_symmetric_pose() {
        assert!(symmetry_loss(&symmetric_pose()) < 1e-6);
    }
    #[test] fn graph_edge_loss_zero_when_correct() {
        // Both edges are 3-4-5 triangles with length exactly 5.
        let kp = vec![(0.0,0.0,1.0),(3.0,4.0,1.0),(6.0,0.0,1.0)];
        assert!(graph_edge_loss(&kp, &[(0,1),(1,2)], &[5.0, 5.0]) < 1e-6);
    }
    #[test] fn composite_loss_respects_weights() {
        // Doubling a weight doubles its contribution; all-zero weights -> 0.
        let c = LossComponents { keypoint:1.0, body_part:1.0, uv:1.0, temporal:1.0, edge:1.0, symmetry:1.0 };
        let w1 = LossWeights { keypoint:1.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 };
        let w2 = LossWeights { keypoint:2.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 };
        assert!((composite_loss(&c, &w2) - 2.0 * composite_loss(&c, &w1)).abs() < 1e-6);
        let wz = LossWeights { keypoint:0.0, body_part:0.0, uv:0.0, temporal:0.0, edge:0.0, symmetry:0.0 };
        assert_eq!(composite_loss(&c, &wz), 0.0);
    }
    #[test] fn cosine_scheduler_starts_at_initial() {
        assert!((CosineScheduler::new(0.01, 0.0001, 100).get_lr(0) - 0.01).abs() < 1e-6);
    }
    #[test] fn cosine_scheduler_ends_at_min() {
        assert!((CosineScheduler::new(0.01, 0.0001, 100).get_lr(100) - 0.0001).abs() < 1e-6);
    }
    #[test] fn cosine_scheduler_midpoint() {
        // Halfway through, cosine decay sits at the midpoint of the LR range.
        assert!((CosineScheduler::new(0.01, 0.0, 100).get_lr(50) - 0.005).abs() < 1e-4);
    }
    #[test] fn warmup_starts_at_zero() {
        assert!(WarmupCosineScheduler::new(10, 0.01, 0.0001, 100).get_lr(0) < 1e-6);
    }
    #[test] fn warmup_reaches_initial_at_warmup_end() {
        assert!((WarmupCosineScheduler::new(10, 0.01, 0.0001, 100).get_lr(10) - 0.01).abs() < 1e-6);
    }
    #[test] fn pck_perfect_prediction_is_1() {
        assert!((pck_at_threshold(&mkp(0.0), &mkp(0.0), 0.2) - 1.0).abs() < 1e-6);
    }
    #[test] fn pck_all_wrong_is_0() {
        assert!(pck_at_threshold(&mkp(0.0), &mkp(100.0), 0.2) < 1e-6);
    }
    #[test] fn oks_perfect_is_1() {
        assert!((oks_single(&mkp(0.0), &mkp(0.0), &COCO_KEYPOINT_SIGMAS, 1.0) - 1.0).abs() < 1e-6);
    }
    #[test] fn sgd_step_reduces_simple_loss() {
        // Minimize f(p) = p^2 with its analytic gradient 2p.
        let mut p = vec![5.0f32];
        let mut opt = SgdOptimizer::new(0.1, 0.0, 0.0);
        let init = p[0] * p[0];
        for _ in 0..10 { let grad = vec![2.0 * p[0]]; opt.step(&mut p, &grad); }
        assert!(p[0] * p[0] < init);
    }
    #[test] fn gradient_clipping_respects_max_norm() {
        // [3, 4] has norm 5; after clipping the norm must equal max_norm.
        let mut g = vec![3.0, 4.0];
        clip_gradients(&mut g, 2.5);
        assert!((g.iter().map(|x| x*x).sum::<f32>().sqrt() - 2.5).abs() < 1e-4);
    }
    #[test] fn early_stopping_triggers() {
        // Force "no improvement" every epoch; patience 3 must stop the loop.
        let cfg = TrainerConfig { epochs: 100, early_stop_patience: 3, ..Default::default() };
        let mut t = Trainer::new(cfg);
        let s = vec![sample()];
        t.best_val_loss = -1.0;
        let mut stopped = false;
        for _ in 0..20 {
            t.train_epoch(&s);
            t.epochs_without_improvement += 1;
            if t.should_stop() { stopped = true; break; }
        }
        assert!(stopped);
    }
    #[test] fn checkpoint_round_trip() {
        // Save then reload a checkpoint and compare the key fields.
        let mut t = Trainer::new(TrainerConfig::default());
        t.train_epoch(&[sample()]);
        let ckpt = t.checkpoint();
        let dir = std::env::temp_dir().join("trainer_ckpt_test");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("ckpt.json");
        ckpt.save_to_file(&path).unwrap();
        let loaded = Checkpoint::load_from_file(&path).unwrap();
        assert_eq!(loaded.epoch, ckpt.epoch);
        assert_eq!(loaded.params.len(), ckpt.params.len());
        assert!((loaded.best_loss - ckpt.best_loss).abs() < 1e-6);
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
    }
}
|
||||
Reference in New Issue
Block a user