Files
wifi-densepose/rust-port/wifi-densepose-rs/crates/wifi-densepose-sensing-server/src/main.rs
ruv 3b72f35306 fix: UI auto-detects server port from page origin (#55)
The UI had hardcoded localhost:8080 for HTTP and localhost:8765 for
WebSocket, causing "Backend unavailable" when served from Docker
(port 3000) or any non-default port.

Changes:
- api.config.js: BASE_URL now uses window.location.origin instead
  of hardcoded localhost:8080
- api.config.js: buildWsUrl() uses window.location.host instead of
  hardcoded localhost:8080
- sensing.service.js: WebSocket URL derived from page origin instead
  of hardcoded localhost:8765
- main.rs: Added /ws/sensing route to the HTTP server so WebSocket
  and REST are reachable on a single port

Fixes #55

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-03-01 02:09:23 -05:00

2157 lines
77 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! WiFi-DensePose Sensing Server
//!
//! Lightweight Axum server that:
//! - Receives ESP32 CSI frames via UDP (default port 5005, `--udp-port`)
//! - Processes signals using the RuVector-powered wifi-densepose-signal crate
//! - Broadcasts sensing updates via WebSocket (default ws://localhost:8765/ws/sensing, `--ws-port`)
//! - Serves the static UI files (default port 8080, `--http-port`)
//!
//! Replaces both ws_server.py and the Python HTTP server.
mod rvf_container;
mod rvf_pipeline;
mod vital_signs;
// Training pipeline modules (exposed via lib.rs)
use wifi_densepose_sensing_server::{graph_transformer, trainer, dataset, embedding};
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{Html, IntoResponse, Json},
routing::{get, post},
Router,
};
use clap::Parser;
use serde::{Deserialize, Serialize};
use tokio::net::UdpSocket;
use tokio::sync::{broadcast, RwLock};
use tower_http::services::ServeDir;
use tower_http::set_header::SetResponseHeaderLayer;
use axum::http::HeaderValue;
use tracing::{info, warn, debug, error};
use rvf_container::{RvfBuilder, RvfContainerInfo, RvfReader, VitalSignConfig};
use rvf_pipeline::ProgressiveLoader;
use vital_signs::{VitalSignDetector, VitalSigns};
// ADR-022 Phase 3: Multi-BSSID pipeline integration
use wifi_densepose_wifiscan::{
BssidRegistry, WindowsWifiPipeline,
};
use wifi_densepose_wifiscan::parse_netsh_output as parse_netsh_bssid_output;
// ── CLI ──────────────────────────────────────────────────────────────────────
// Command-line options for the sensing server, parsed with clap's derive API.
//
// NOTE: the `///` comments on the fields below are used by clap as the
// `--help` text of each flag, so editing them changes user-visible output.
// Maintainer notes are therefore written as plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "sensing-server", about = "WiFi-DensePose sensing server")]
struct Args {
// ── Network / runtime configuration ──
/// HTTP port for UI and REST API
#[arg(long, default_value = "8080")]
http_port: u16,
/// WebSocket port for sensing stream
#[arg(long, default_value = "8765")]
ws_port: u16,
/// UDP port for ESP32 CSI frames
#[arg(long, default_value = "5005")]
udp_port: u16,
/// Path to UI static files
#[arg(long, default_value = "../../ui")]
ui_path: PathBuf,
/// Tick interval in milliseconds
#[arg(long, default_value = "500")]
tick_ms: u64,
// Free-form string rather than an enum; the accepted values are listed in
// the help text below.
/// Data source: auto, wifi, esp32, simulate
#[arg(long, default_value = "auto")]
source: String,
/// Run vital sign detection benchmark (1000 frames) and exit
#[arg(long)]
benchmark: bool,
// ── RVF container / model loading ──
/// Load model config from an RVF container at startup
#[arg(long, value_name = "PATH")]
load_rvf: Option<PathBuf>,
/// Save current model state as an RVF container on shutdown
#[arg(long, value_name = "PATH")]
save_rvf: Option<PathBuf>,
/// Load a trained .rvf model for inference
#[arg(long, value_name = "PATH")]
model: Option<PathBuf>,
/// Enable progressive loading (Layer A instant start)
#[arg(long)]
progressive: bool,
/// Export an RVF container package and exit (no server)
#[arg(long, value_name = "PATH")]
export_rvf: Option<PathBuf>,
// ── Training / pretraining / embedding modes ──
/// Run training mode (train a model and exit)
#[arg(long)]
train: bool,
/// Path to dataset directory (MM-Fi or Wi-Pose)
#[arg(long, value_name = "PATH")]
dataset: Option<PathBuf>,
/// Dataset type: "mmfi" or "wipose"
#[arg(long, value_name = "TYPE", default_value = "mmfi")]
dataset_type: String,
/// Number of training epochs
#[arg(long, default_value = "100")]
epochs: usize,
/// Directory for training checkpoints
#[arg(long, value_name = "DIR")]
checkpoint_dir: Option<PathBuf>,
/// Run self-supervised contrastive pretraining (ADR-024)
#[arg(long)]
pretrain: bool,
/// Number of pretraining epochs (default 50)
#[arg(long, default_value = "50")]
pretrain_epochs: usize,
/// Extract embeddings mode: load model and extract CSI embeddings
#[arg(long)]
embed: bool,
/// Build fingerprint index from embeddings (env|activity|temporal|person)
#[arg(long, value_name = "TYPE")]
build_index: Option<String>,
}
// ── Data types ───────────────────────────────────────────────────────────────
/// ADR-018 ESP32 CSI binary frame header (20 bytes)
///
/// Besides the header fields, the struct also carries the per-pair amplitude
/// and phase values that `parse_esp32_frame` derives from the I/Q payload.
/// The WiFi and simulation sources in this file build the same struct by hand
/// so that all sources share one feature-extraction path.
#[derive(Debug, Clone)]
// NOTE(review): presumably needed because some header fields are stored but
// never read — confirm before removing.
#[allow(dead_code)]
struct Esp32Frame {
/// Magic word; the parser reads it little-endian from bytes 0..4 and only
/// accepts `0xC511_0001`.
magic: u32,
/// Sender node identifier (header byte 4).
node_id: u8,
/// Antenna count (header byte 5).
n_antennas: u8,
/// Subcarrier count (header byte 6).
n_subcarriers: u8,
/// Channel frequency in MHz (little-endian, header bytes 8..10).
freq_mhz: u16,
/// Frame sequence number (little-endian, header bytes 10..14).
sequence: u32,
/// RSSI, header byte 14 reinterpreted as a signed byte.
rssi: i8,
/// Noise floor, header byte 15 reinterpreted as a signed byte.
noise_floor: i8,
/// One magnitude `sqrt(i² + q²)` per I/Q pair.
amplitudes: Vec<f64>,
/// One angle `atan2(q, i)` per I/Q pair.
phases: Vec<f64>,
}
/// Sensing update broadcast to WebSocket clients
///
/// Serialized to JSON before being pushed on the broadcast channel. All
/// `Option` fields are omitted from the JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SensingUpdate {
/// Message discriminator, serialized under the JSON key `type`; the
/// producers in this file set it to `"sensing_update"`.
#[serde(rename = "type")]
msg_type: String,
/// Wall-clock time in seconds (the producers divide a UTC millisecond
/// timestamp by 1000).
timestamp: f64,
/// Label of the data source, e.g. `wifi:<ssid>` for the Windows WiFi tasks.
source: String,
/// Value of the shared tick counter when this update was produced.
tick: u64,
/// Per-node signal snapshots.
nodes: Vec<NodeInfo>,
/// Features extracted from the current frame.
features: FeatureInfo,
/// Motion/presence classification derived from the features.
classification: ClassificationInfo,
/// 2-D field of values in [0, 1] produced by `generate_signal_field`.
signal_field: SignalField,
/// Vital sign estimates (breathing rate, heart rate, confidence).
#[serde(skip_serializing_if = "Option::is_none")]
vital_signs: Option<VitalSigns>,
// ── ADR-022 Phase 3: Enhanced multi-BSSID pipeline fields ──
/// Enhanced motion estimate from multi-BSSID pipeline.
#[serde(skip_serializing_if = "Option::is_none")]
enhanced_motion: Option<serde_json::Value>,
/// Enhanced breathing estimate from multi-BSSID pipeline.
#[serde(skip_serializing_if = "Option::is_none")]
enhanced_breathing: Option<serde_json::Value>,
/// Posture classification from BSSID fingerprint matching.
#[serde(skip_serializing_if = "Option::is_none")]
posture: Option<String>,
/// Signal quality score from multi-BSSID quality gate [0.0, 1.0].
#[serde(skip_serializing_if = "Option::is_none")]
signal_quality_score: Option<f64>,
/// Quality gate verdict: "Permit", "Warn", or "Deny".
#[serde(skip_serializing_if = "Option::is_none")]
quality_verdict: Option<String>,
/// Number of BSSIDs used in the enhanced sensing cycle.
#[serde(skip_serializing_if = "Option::is_none")]
bssid_count: Option<usize>,
// ── ADR-023 Phase 7-8: Model inference fields ──
/// Pose keypoints when a trained model is loaded (x, y, z, confidence).
#[serde(skip_serializing_if = "Option::is_none")]
pose_keypoints: Option<Vec<[f64; 4]>>,
/// Model status when a trained model is loaded.
#[serde(skip_serializing_if = "Option::is_none")]
model_status: Option<serde_json::Value>,
}
/// Per-node signal snapshot embedded in a `SensingUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct NodeInfo {
/// Identifier of the reporting node (the WiFi tasks in this file use 0).
node_id: u8,
/// Received signal strength in dBm.
rssi_dbm: f64,
/// Node position as three coordinates; the WiFi tasks in this file use the
/// origin. (Axis convention is not visible here — TODO confirm.)
position: [f64; 3],
/// Amplitude vector of the frame this snapshot was built from.
amplitude: Vec<f64>,
/// Number of subcarriers (or BSSID observations for the WiFi source).
subcarrier_count: usize,
}
/// Scalar features computed by `extract_features_from_frame` from one frame's
/// amplitude vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FeatureInfo {
/// The frame's RSSI value converted to `f64`.
mean_rssi: f64,
/// Population variance of the amplitudes.
variance: f64,
/// Mean squared deviation (from the overall mean) over the upper half of
/// the amplitude vector.
motion_band_power: f64,
/// Mean squared deviation (from the overall mean) over the lower half of
/// the amplitude vector.
breathing_band_power: f64,
/// Index of the largest amplitude multiplied by 0.05.
dominant_freq_hz: f64,
/// Number of adjacent amplitude pairs that cross the `1.2 * mean` threshold.
change_points: usize,
/// Mean of the squared amplitudes.
spectral_power: f64,
}
/// Motion/presence verdict produced by `extract_features_from_frame`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ClassificationInfo {
/// One of `"active"`, `"present_still"` or `"absent"`.
motion_level: String,
/// `true` for `"active"` and `"present_still"`, `false` for `"absent"`.
presence: bool,
/// Computed as `0.5 + motion_score * 0.5`, so it lies in [0.5, 1.0].
confidence: f64,
}
/// Grid of signal values sent to clients as part of a `SensingUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SignalField {
/// Grid dimensions; `generate_signal_field` emits `[20, 1, 20]`.
grid_size: [usize; 3],
/// Cell values; `generate_signal_field` stores cell (x, z) at index
/// `z * grid + x` and clamps every value to [0, 1].
values: Vec<f64>,
}
/// WiFi-derived pose keypoint (17 COCO keypoints)
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PoseKeypoint {
/// Keypoint name, e.g. `"nose"` or `"left_ankle"`.
name: String,
/// Horizontal coordinate (`derive_pose_from_sensing` centres the skeleton
/// around 320).
x: f64,
/// Vertical coordinate (`derive_pose_from_sensing` centres the skeleton
/// around 240).
y: f64,
/// Depth coordinate; `derive_pose_from_sensing` always sets it to 0.
z: f64,
/// Per-keypoint confidence, scaled from the classification confidence.
confidence: f64,
}
/// Person detection from WiFi sensing
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersonDetection {
/// Person identifier (`derive_pose_from_sensing` always emits 1).
id: u32,
/// Detection confidence, copied from the classification confidence.
confidence: f64,
/// Skeleton keypoints for this person.
keypoints: Vec<PoseKeypoint>,
/// Bounding box around the skeleton.
bbox: BoundingBox,
/// Zone label (`derive_pose_from_sensing` always emits `"zone_1"`).
zone: String,
}
/// Axis-aligned rectangle given by an origin point and its extent.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BoundingBox {
/// X coordinate of the box origin.
x: f64,
/// Y coordinate of the box origin.
y: f64,
/// Horizontal extent.
width: f64,
/// Vertical extent.
height: f64,
}
/// Shared application state
///
/// Always accessed through the `SharedState` alias below, i.e. behind an
/// async `RwLock` shared by the data-source tasks and the HTTP/WebSocket
/// handlers.
struct AppStateInner {
/// Most recent update, served by the REST handlers.
latest_update: Option<SensingUpdate>,
/// Recent RSSI readings; the WiFi tasks cap it at 60 entries.
rssi_history: VecDeque<f64>,
/// Counter incremented once per produced update.
tick: u64,
/// Label of the active data source (e.g. `wifi:<ssid>`).
source: String,
/// Broadcast channel carrying each update serialized as a JSON string.
tx: broadcast::Sender<String>,
/// Detection counter reported by the pose statistics handler.
total_detections: u64,
/// Instant used by the health handlers to compute uptime.
start_time: std::time::Instant,
/// Vital sign detector (processes CSI frames to estimate HR/RR).
vital_detector: VitalSignDetector,
/// Most recent vital sign reading for the REST endpoint.
latest_vitals: VitalSigns,
/// RVF container info if a model was loaded via `--load-rvf`.
rvf_info: Option<RvfContainerInfo>,
/// Path to save RVF container on shutdown (set via `--save-rvf`).
save_rvf_path: Option<PathBuf>,
/// Progressive loader for a trained model (set via `--model`).
progressive_loader: Option<ProgressiveLoader>,
/// Active SONA profile name.
active_sona_profile: Option<String>,
/// Whether a trained model is loaded.
model_loaded: bool,
}
/// Handle to the application state: reference-counted and guarded by an
/// async read/write lock.
type SharedState = Arc<RwLock<AppStateInner>>;
// ── ESP32 UDP frame parser ───────────────────────────────────────────────────
/// Decodes one ESP32 CSI datagram into an `Esp32Frame`.
///
/// Returns `None` when the buffer is shorter than the 20-byte header, when the
/// little-endian magic word is not `0xC511_0001`, or when the payload holds
/// fewer than `n_antennas * n_subcarriers` two-byte I/Q pairs.
fn parse_esp32_frame(buf: &[u8]) -> Option<Esp32Frame> {
    const HEADER_LEN: usize = 20;
    const MAGIC: u32 = 0xC511_0001;

    if buf.len() < HEADER_LEN {
        return None;
    }
    let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
    if magic != MAGIC {
        return None;
    }

    // Header byte 7 and bytes 16..20 are not read.
    let (node_id, n_antennas, n_subcarriers) = (buf[4], buf[5], buf[6]);
    let freq_mhz = u16::from_le_bytes([buf[8], buf[9]]);
    let sequence = u32::from_le_bytes([buf[10], buf[11], buf[12], buf[13]]);
    let rssi = buf[14] as i8;
    let noise_floor = buf[15] as i8;

    // Each I/Q pair is two signed bytes: I first, then Q.
    let n_pairs = usize::from(n_antennas) * usize::from(n_subcarriers);
    let payload = buf.get(HEADER_LEN..HEADER_LEN + n_pairs * 2)?;

    let (amplitudes, phases): (Vec<f64>, Vec<f64>) = payload
        .chunks_exact(2)
        .map(|iq| {
            let i_val = iq[0] as i8 as f64;
            let q_val = iq[1] as i8 as f64;
            ((i_val * i_val + q_val * q_val).sqrt(), q_val.atan2(i_val))
        })
        .unzip();

    Some(Esp32Frame {
        magic,
        node_id,
        n_antennas,
        n_subcarriers,
        freq_mhz,
        sequence,
        rssi,
        noise_floor,
        amplitudes,
        phases,
    })
}
// ── Signal field generation ──────────────────────────────────────────────────
/// Builds the 20×20 signal field shown by the UI.
///
/// Every cell is the sum of three terms, clamped to [0, 1]:
/// a radial fall-off around the grid centre, a moving "body" blob whose
/// strength scales with `motion_score`, and — only when `variance > 1.0` —
/// a pulsating ring at radius 5. The blob position and the ring phase are
/// driven by `tick`. `_mean_rssi` is accepted but not used.
fn generate_signal_field(
    _mean_rssi: f64,
    variance: f64,
    motion_score: f64,
    tick: u64,
) -> SignalField {
    const GRID: usize = 20;

    let center = GRID as f64 / 2.0;
    let phase = tick as f64;
    // The blob centre only depends on the tick, so compute it once.
    let blob_x = center + 3.0 * (phase * 0.02).sin();
    let blob_z = center + 2.0 * (phase * 0.015).cos();
    let ring_enabled = variance > 1.0;

    let mut values = Vec::with_capacity(GRID * GRID);
    for z in 0..GRID {
        let zf = z as f64;
        for x in 0..GRID {
            let xf = x as f64;
            let dx = xf - center;
            let dz = zf - center;
            let radius = (dx * dx + dz * dz).sqrt();

            // Radial attenuation from the router at the centre.
            let attenuation = (-radius * 0.15).exp();

            // Disruption caused by the body blob.
            let blob_dist = ((xf - blob_x).powi(2) + (zf - blob_z).powi(2)).sqrt();
            let disruption = motion_score * 0.6 * (-blob_dist * 0.4).exp();

            // Breathing ring modulation.
            let ring = if ring_enabled {
                0.1 * (phase * 0.3).sin() * (-((radius - 5.0).powi(2)) * 0.1).exp()
            } else {
                0.0
            };

            // Rows are pushed in z-major order, i.e. index = z * GRID + x.
            values.push((attenuation + disruption + ring).clamp(0.0, 1.0));
        }
    }

    SignalField {
        grid_size: [GRID, 1, GRID],
        values,
    }
}
// ── Feature extraction from ESP32 frame ──────────────────────────────────────
/// Computes scalar features and a motion/presence classification from the
/// amplitude vector of a single frame.
///
/// The amplitude vector is split at its midpoint: the upper half feeds
/// `motion_band_power`, the lower half `breathing_band_power` (both are mean
/// squared deviations from the overall mean, or 0 when the vector has fewer
/// than two entries). The classification is derived from `variance / 10`
/// clamped to [0, 1].
fn extract_features_from_frame(frame: &Esp32Frame) -> (FeatureInfo, ClassificationInfo) {
    let amps = frame.amplitudes.as_slice();
    // Guard against division by zero for an empty vector.
    let count = amps.len().max(1) as f64;
    let mean_amp: f64 = amps.iter().sum::<f64>() / count;

    // Sum of squared deviations from the overall mean over a sub-slice.
    let sq_dev_sum = |xs: &[f64]| xs.iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>();

    let variance: f64 = sq_dev_sum(amps) / count;
    let spectral_power: f64 = amps.iter().map(|a| a * a).sum::<f64>() / count;

    let half = amps.len() / 2;
    let (lower, upper) = amps.split_at(half);
    let (motion_band_power, breathing_band_power) = if half > 0 {
        (
            sq_dev_sum(upper) / upper.len() as f64,
            sq_dev_sum(lower) / half as f64,
        )
    } else {
        (0.0, 0.0)
    };

    // Dominant frequency estimate (peak subcarrier index → Hz).
    let peak_idx = amps
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0);
    let dominant_freq_hz = peak_idx as f64 * 0.05;

    // Change points: neighbouring samples on opposite sides of the threshold.
    let threshold = mean_amp * 1.2;
    let change_points = amps
        .windows(2)
        .filter(|pair| (pair[0] < threshold) != (pair[1] < threshold))
        .count();

    let motion_score = (variance / 10.0).clamp(0.0, 1.0);
    let (label, presence) = if motion_score > 0.5 {
        ("active", true)
    } else if motion_score > 0.1 {
        ("present_still", true)
    } else {
        ("absent", false)
    };

    let features = FeatureInfo {
        mean_rssi: frame.rssi as f64,
        variance,
        motion_band_power,
        breathing_band_power,
        dominant_freq_hz,
        change_points,
        spectral_power,
    };
    let classification = ClassificationInfo {
        motion_level: label.to_string(),
        presence,
        confidence: 0.5 + motion_score * 0.5,
    };
    (features, classification)
}
// ── Windows WiFi RSSI collector ──────────────────────────────────────────────
/// Parse `netsh wlan show interfaces` output for RSSI and signal quality.
///
/// Returns `Some((rssi_dbm, signal_pct, ssid))` when a `Signal : NN%` line is
/// present, where `rssi_dbm` is the approximation `-100 + signal_pct * 0.6`.
/// The SSID falls back to `"Unknown"` when no `SSID` line is found. Returns
/// `None` when no parsable `Signal` line exists.
///
/// Each line is split at its *first* colon only, so SSIDs that themselves
/// contain a colon (e.g. `Cafe:Guest`) are returned in full. When a key
/// appears several times, the last occurrence wins.
fn parse_netsh_interfaces_output(output: &str) -> Option<(f64, f64, String)> {
    let mut signal_pct: Option<f64> = None;
    let mut ssid: Option<String> = None;

    for line in output.lines().map(str::trim) {
        let value = match line.split_once(':') {
            Some((_, rest)) => rest.trim(),
            None => continue,
        };
        if line.starts_with("Signal") {
            // "Signal : 89%"
            if let Ok(pct) = value.trim_end_matches('%').parse::<f64>() {
                signal_pct = Some(pct);
            }
        }
        // "BSSID ..." lines do not start with "SSID", so no extra check is needed.
        if line.starts_with("SSID") {
            ssid = Some(value.to_string());
        }
    }

    let pct = signal_pct?;
    // Convert signal% to approximate dBm: -100 + (signal% * 0.6)
    let rssi_dbm = -100.0 + pct * 0.6;
    Some((rssi_dbm, pct, ssid.unwrap_or_else(|| "Unknown".into())))
}
/// Long-running data-source task for the Windows WiFi source.
///
/// Every `tick_ms` milliseconds it runs `netsh wlan show networks mode=bssid`
/// on a blocking thread, feeds the parsed observations through the BSSID
/// registry and the multi-BSSID pipeline, builds a `SensingUpdate`, stores it
/// in the shared state and broadcasts it as JSON. When the scan fails or
/// returns no observations, the tick is delegated to
/// `windows_wifi_fallback_tick`. The function never returns.
async fn windows_wifi_task(state: SharedState, tick_ms: u64) {
let mut interval = tokio::time::interval(Duration::from_millis(tick_ms));
let mut seq: u32 = 0;
// ADR-022 Phase 3: Multi-BSSID pipeline state (kept across ticks)
let mut registry = BssidRegistry::new(32, 30);
let mut pipeline = WindowsWifiPipeline::new();
info!(
"Windows WiFi multi-BSSID pipeline active (tick={}ms, max_bssids=32)",
tick_ms
);
loop {
interval.tick().await;
seq += 1;
// ── Step 1: Run multi-BSSID scan via spawn_blocking ──────────
// NetshBssidScanner is not Send, so we run `netsh` and parse
// the output inside a blocking closure.
let bssid_scan_result = tokio::task::spawn_blocking(|| {
let output = std::process::Command::new("netsh")
.args(["wlan", "show", "networks", "mode=bssid"])
.output()
.map_err(|e| format!("netsh bssid scan failed: {e}"))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(format!(
"netsh exited with {}: {}",
output.status,
stderr.trim()
));
}
let stdout = String::from_utf8_lossy(&output.stdout);
parse_netsh_bssid_output(&stdout).map_err(|e| format!("parse error: {e}"))
})
.await;
// Unwrap the JoinHandle result, then the inner Result.
let observations = match bssid_scan_result {
Ok(Ok(obs)) if !obs.is_empty() => obs,
Ok(Ok(_empty)) => {
debug!("Multi-BSSID scan returned 0 observations, falling back");
windows_wifi_fallback_tick(&state, seq).await;
continue;
}
Ok(Err(e)) => {
warn!("Multi-BSSID scan error: {e}, falling back");
windows_wifi_fallback_tick(&state, seq).await;
continue;
}
// The blocking task itself failed: skip this tick entirely, without
// running the single-RSSI fallback.
Err(join_err) => {
error!("spawn_blocking panicked: {join_err}");
continue;
}
};
let obs_count = observations.len();
// Derive SSID from the first observation for the source label.
let ssid = observations
.first()
.map(|o| o.ssid.clone())
.unwrap_or_else(|| "Unknown".into());
// ── Step 2: Feed observations into registry ──────────────────
registry.update(&observations);
let multi_ap_frame = registry.to_multi_ap_frame();
// ── Step 3: Run enhanced pipeline ────────────────────────────
let enhanced = pipeline.process(&multi_ap_frame);
// ── Step 4: Build backward-compatible Esp32Frame ─────────────
let first_rssi = observations
.first()
.map(|o| o.rssi_dbm)
.unwrap_or(-80.0);
// NOTE(review): computed but never read (hence the underscore prefix).
let _first_signal_pct = observations
.first()
.map(|o| o.signal_pct)
.unwrap_or(40.0);
let frame = Esp32Frame {
magic: 0xC511_0001,
node_id: 0,
n_antennas: 1,
// Saturate at 255 because the header field is a single byte.
n_subcarriers: obs_count.min(255) as u8,
freq_mhz: 2437,
sequence: seq,
rssi: first_rssi.clamp(-128.0, 127.0) as i8,
noise_floor: -90,
amplitudes: multi_ap_frame.amplitudes.clone(),
phases: multi_ap_frame.phases.clone(),
};
let (features, classification) = extract_features_from_frame(&frame);
// ── Step 5: Build enhanced fields from pipeline result ───────
let enhanced_motion = Some(serde_json::json!({
"score": enhanced.motion.score,
"level": format!("{:?}", enhanced.motion.level),
"contributing_bssids": enhanced.motion.contributing_bssids,
}));
let enhanced_breathing = enhanced.breathing.as_ref().map(|b| {
serde_json::json!({
"rate_bpm": b.rate_bpm,
"confidence": b.confidence,
"bssid_count": b.bssid_count,
})
});
let posture_str = enhanced.posture.map(|p| format!("{p:?}"));
let sig_quality_score = Some(enhanced.signal_quality.score);
let verdict_str = Some(format!("{:?}", enhanced.verdict));
let bssid_n = Some(enhanced.bssid_count);
// ── Step 6: Update shared state ──────────────────────────────
// The write guard stays alive until the end of this loop iteration, so the
// broadcast send below happens while the state is still locked.
let mut s = state.write().await;
s.source = format!("wifi:{ssid}");
s.rssi_history.push_back(first_rssi);
// Keep at most the 60 most recent RSSI readings.
if s.rssi_history.len() > 60 {
s.rssi_history.pop_front();
}
s.tick += 1;
let tick = s.tick;
// Coarse score used only for the signal-field visualisation, mapped from
// the classification label.
let motion_score = if classification.motion_level == "active" {
0.8
} else if classification.motion_level == "present_still" {
0.3
} else {
0.05
};
let vitals = s.vital_detector.process_frame(&frame.amplitudes, &frame.phases);
s.latest_vitals = vitals.clone();
let update = SensingUpdate {
msg_type: "sensing_update".to_string(),
timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0,
source: format!("wifi:{ssid}"),
tick,
nodes: vec![NodeInfo {
node_id: 0,
rssi_dbm: first_rssi,
position: [0.0, 0.0, 0.0],
amplitude: multi_ap_frame.amplitudes,
subcarrier_count: obs_count,
}],
features,
classification,
// The variance argument is fixed at 1.0; generate_signal_field only draws
// its breathing ring when variance > 1.0, so the ring is off here.
signal_field: generate_signal_field(first_rssi, 1.0, motion_score, tick),
vital_signs: Some(vitals),
enhanced_motion,
enhanced_breathing,
posture: posture_str,
signal_quality_score: sig_quality_score,
quality_verdict: verdict_str,
bssid_count: bssid_n,
pose_keypoints: None,
model_status: None,
};
// Serialization and send failures are both ignored (presumably the send
// only fails when no client is subscribed — TODO confirm).
if let Ok(json) = serde_json::to_string(&update) {
let _ = s.tx.send(json);
}
s.latest_update = Some(update);
debug!(
"Multi-BSSID tick #{tick}: {obs_count} BSSIDs, quality={:.2}, verdict={:?}",
enhanced.signal_quality.score, enhanced.verdict
);
}
}
/// Fallback: single-RSSI collection via `netsh wlan show interfaces`.
///
/// Used when the multi-BSSID scan fails or returns 0 observations. Produces
/// one `SensingUpdate` from a single-sample frame (the signal percentage is
/// the only amplitude), stores it in the shared state and broadcasts it.
/// Returns without touching the state when `netsh` cannot be run or its
/// output contains no usable interface information.
async fn windows_wifi_fallback_tick(state: &SharedState, seq: u32) {
    let netsh_result = tokio::process::Command::new("netsh")
        .args(["wlan", "show", "interfaces"])
        .output()
        .await;
    let stdout = match netsh_result {
        Ok(out) => String::from_utf8_lossy(&out.stdout).to_string(),
        Err(e) => {
            warn!("netsh interfaces fallback failed: {e}");
            return;
        }
    };

    let reading = parse_netsh_interfaces_output(&stdout);
    let (rssi_dbm, signal_pct, ssid) = if let Some(values) = reading {
        values
    } else {
        debug!("Fallback: no WiFi interface connected");
        return;
    };
    let source_label = format!("wifi:{ssid}");

    // Single-sample frame so the shared feature extraction can be reused.
    let frame = Esp32Frame {
        magic: 0xC511_0001,
        node_id: 0,
        n_antennas: 1,
        n_subcarriers: 1,
        freq_mhz: 2437,
        sequence: seq,
        rssi: rssi_dbm as i8,
        noise_floor: -90,
        amplitudes: vec![signal_pct],
        phases: vec![0.0],
    };
    let (features, classification) = extract_features_from_frame(&frame);

    // Coarse score used only for the signal-field visualisation.
    let motion_score = match classification.motion_level.as_str() {
        "active" => 0.8,
        "present_still" => 0.3,
        _ => 0.05,
    };

    let mut guard = state.write().await;
    guard.source = source_label.clone();
    guard.rssi_history.push_back(rssi_dbm);
    // Keep at most the 60 most recent RSSI readings.
    if guard.rssi_history.len() > 60 {
        guard.rssi_history.pop_front();
    }
    guard.tick += 1;
    let tick = guard.tick;

    let vitals = guard
        .vital_detector
        .process_frame(&frame.amplitudes, &frame.phases);
    guard.latest_vitals = vitals.clone();

    let node = NodeInfo {
        node_id: 0,
        rssi_dbm,
        position: [0.0, 0.0, 0.0],
        amplitude: vec![signal_pct],
        subcarrier_count: 1,
    };
    let update = SensingUpdate {
        msg_type: "sensing_update".to_string(),
        timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0,
        source: source_label,
        tick,
        nodes: vec![node],
        features,
        classification,
        signal_field: generate_signal_field(rssi_dbm, 1.0, motion_score, tick),
        vital_signs: Some(vitals),
        // The multi-BSSID and model fields are not available in fallback mode.
        enhanced_motion: None,
        enhanced_breathing: None,
        posture: None,
        signal_quality_score: None,
        quality_verdict: None,
        bssid_count: None,
        pose_keypoints: None,
        model_status: None,
    };

    if let Ok(json) = serde_json::to_string(&update) {
        let _ = guard.tx.send(json);
    }
    guard.latest_update = Some(update);
}
/// Probe if Windows WiFi is connected.
///
/// Runs `netsh wlan show interfaces` and reports `true` only when its output
/// yields a signal reading; any failure to run the command counts as `false`.
async fn probe_windows_wifi() -> bool {
    let netsh_result = tokio::process::Command::new("netsh")
        .args(["wlan", "show", "interfaces"])
        .output()
        .await;

    if let Ok(out) = netsh_result {
        let text = String::from_utf8_lossy(&out.stdout);
        parse_netsh_interfaces_output(&text).is_some()
    } else {
        false
    }
}
/// Probe if ESP32 is streaming on UDP port.
///
/// Binds `0.0.0.0:<port>`, waits up to two seconds for one datagram and
/// returns `true` only when that datagram parses as a valid ESP32 CSI frame.
/// Bind failures, timeouts and receive errors all yield `false`. The socket
/// is dropped when the function returns.
async fn probe_esp32(port: u16) -> bool {
    // A CSI frame is 20 header bytes plus two bytes per antenna×subcarrier
    // pair, so it easily exceeds a few hundred bytes (e.g. 1×128 pairs is 276
    // bytes). A short receive buffer truncates or rejects such a datagram and
    // the probe then misreports a live ESP32 as absent, so the buffer is sized
    // for the largest possible UDP datagram.
    const MAX_DATAGRAM_LEN: usize = 65_535;

    let addr = format!("0.0.0.0:{port}");
    let sock = match UdpSocket::bind(&addr).await {
        Ok(sock) => sock,
        Err(_) => return false,
    };

    // Heap-allocated so the async state machine stays small.
    let mut buf = vec![0u8; MAX_DATAGRAM_LEN];
    match tokio::time::timeout(Duration::from_secs(2), sock.recv_from(&mut buf)).await {
        Ok(Ok((len, _))) => parse_esp32_frame(&buf[..len]).is_some(),
        _ => false,
    }
}
// ── Simulated data generator ─────────────────────────────────────────────────
/// Produces a deterministic synthetic frame for the given tick.
///
/// The frame has 56 subcarriers on one antenna. Amplitudes are a slow
/// sinusoid around 15 plus a faster pseudo-noise sinusoid, floored at 0.1;
/// phases are a sinusoid scaled to ±π; the RSSI oscillates around -40.
fn generate_simulated_frame(tick: u64) -> Esp32Frame {
    const N_SUB: usize = 56;
    let t = tick as f64 * 0.1;

    let (amplitudes, phases): (Vec<f64>, Vec<f64>) = (0..N_SUB)
        .map(|i| {
            let k = i as f64;
            let base = 15.0 + 5.0 * (k * 0.1 + t * 0.3).sin();
            let noise = (k * 7.3 + t * 13.7).sin() * 2.0;
            let amplitude = (base + noise).max(0.1);
            let phase = (k * 0.2 + t * 0.5).sin() * std::f64::consts::PI;
            (amplitude, phase)
        })
        .unzip();

    let rssi = (-40.0 + 5.0 * (t * 0.2).sin()) as i8;

    Esp32Frame {
        magic: 0xC511_0001,
        node_id: 1,
        n_antennas: 1,
        n_subcarriers: N_SUB as u8,
        freq_mhz: 2437,
        // Truncating cast: the sequence field is 32 bits wide.
        sequence: tick as u32,
        rssi,
        noise_floor: -90,
        amplitudes,
        phases,
    }
}
// ── WebSocket handler ────────────────────────────────────────────────────────
/// Upgrades the request to a WebSocket and hands the socket to
/// `handle_ws_client`, which streams sensing updates to it.
async fn ws_sensing_handler(
    ws: WebSocketUpgrade,
    State(state): State<SharedState>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| async move {
        handle_ws_client(socket, state).await;
    })
}
/// Streams every broadcast sensing update to one WebSocket client.
///
/// Subscribes to the shared broadcast channel and forwards each JSON string
/// as a text message until the client closes the connection, a socket error
/// occurs, or the broadcast channel is closed. Messages sent by the client
/// are ignored.
///
/// A client that falls behind the broadcast channel is *not* disconnected:
/// the skipped updates are logged and streaming resumes with the oldest
/// update still buffered.
async fn handle_ws_client(mut socket: WebSocket, state: SharedState) {
    // The read guard is a temporary and is released at the end of this statement.
    let mut rx = state.read().await.tx.subscribe();
    info!("WebSocket client connected (sensing)");

    loop {
        tokio::select! {
            update = rx.recv() => {
                match update {
                    Ok(json) => {
                        if socket.send(Message::Text(json.into())).await.is_err() {
                            break;
                        }
                    }
                    // Lagging is recoverable: the receiver simply missed some
                    // updates and can keep receiving.
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        warn!("WebSocket client (sensing) lagged, skipped {skipped} update(s)");
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
            incoming = socket.recv() => {
                match incoming {
                    // A receive error means the connection is no longer usable;
                    // treat it like a close instead of polling the socket again.
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    Some(Ok(_)) => {} // ignore client messages
                }
            }
        }
    }
    info!("WebSocket client disconnected (sensing)");
}
// ── Pose WebSocket handler (sends pose_data messages for Live Demo) ──────────
/// Upgrades the request to a WebSocket and hands the socket to
/// `handle_ws_pose_client`, which streams `pose_data` messages to it.
async fn ws_pose_handler(
    ws: WebSocketUpgrade,
    State(state): State<SharedState>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| async move {
        handle_ws_pose_client(socket, state).await;
    })
}
/// Streams `pose_data` messages to one WebSocket client.
///
/// After an initial `connection_established` message, every broadcast
/// sensing update is parsed back into a `SensingUpdate`, converted to pose
/// detections via `derive_pose_from_sensing` and sent as a `pose_data` JSON
/// message. Broadcast payloads that do not parse, or whose type is not
/// `sensing_update`, are skipped. Client text messages of the form
/// `{"type": "ping"}` are answered with `{"type": "pong"}`.
///
/// The session ends when the client closes the connection, a socket error
/// occurs, or the broadcast channel is closed. A client that falls behind the
/// broadcast channel is *not* disconnected: the skipped updates are logged
/// and streaming resumes.
async fn handle_ws_pose_client(mut socket: WebSocket, state: SharedState) {
    // The read guard is a temporary and is released at the end of this statement.
    let mut rx = state.read().await.tx.subscribe();
    info!("WebSocket client connected (pose)");

    // Send connection established message. A failure here is not fatal on its
    // own; a dead socket is detected by the first send/recv in the loop.
    let conn_msg = serde_json::json!({
        "type": "connection_established",
        "payload": { "status": "connected", "backend": "rust+ruvector" }
    });
    let _ = socket.send(Message::Text(conn_msg.to_string().into())).await;

    loop {
        tokio::select! {
            update = rx.recv() => {
                let json = match update {
                    Ok(json) => json,
                    // Lagging is recoverable: the receiver simply missed some
                    // updates and can keep receiving.
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        warn!("WebSocket client (pose) lagged, skipped {skipped} update(s)");
                        continue;
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                };

                // Parse the sensing update and convert to pose format
                let sensing = match serde_json::from_str::<SensingUpdate>(&json) {
                    Ok(sensing) if sensing.msg_type == "sensing_update" => sensing,
                    _ => continue,
                };
                let persons = derive_pose_from_sensing(&sensing);
                let pose_msg = serde_json::json!({
                    "type": "pose_data",
                    "zone_id": "zone_1",
                    "timestamp": sensing.timestamp,
                    "payload": {
                        "pose": {
                            "persons": persons,
                        },
                        "confidence": if sensing.classification.presence { sensing.classification.confidence } else { 0.0 },
                        "activity": sensing.classification.motion_level,
                        "metadata": {
                            "frame_id": format!("rust_frame_{}", sensing.tick),
                            "processing_time_ms": 1,
                            "source": sensing.source,
                            "tick": sensing.tick,
                            "signal_strength": sensing.features.mean_rssi,
                        }
                    }
                });
                if socket.send(Message::Text(pose_msg.to_string().into())).await.is_err() {
                    break;
                }
            }
            incoming = socket.recv() => {
                match incoming {
                    Some(Ok(Message::Text(text))) => {
                        // Handle ping/pong
                        if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text) {
                            if v.get("type").and_then(|t| t.as_str()) == Some("ping") {
                                let pong = serde_json::json!({"type": "pong"});
                                let _ = socket.send(Message::Text(pong.to_string().into())).await;
                            }
                        }
                    }
                    // A receive error means the connection is no longer usable;
                    // treat it like a close instead of polling the socket again.
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    Some(Ok(_)) => {}
                }
            }
        }
    }
    info!("WebSocket client disconnected (pose)");
}
// ── REST endpoints ───────────────────────────────────────────────────────────
/// Reports overall server health: a fixed `"ok"` status plus the current
/// data source, tick counter and number of subscribed broadcast receivers.
async fn health(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let body = serde_json::json!({
        "status": "ok",
        "source": app.source,
        "tick": app.tick,
        "clients": app.tx.receiver_count(),
    });
    Json(body)
}
/// Returns the most recent sensing update as JSON, or a
/// `{"status": "no data yet"}` placeholder before the first update arrives.
async fn latest(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let body = if let Some(update) = app.latest_update.as_ref() {
        // A serialization failure degrades to JSON `null`.
        serde_json::to_value(update).unwrap_or_default()
    } else {
        serde_json::json!({"status": "no data yet"})
    };
    Json(body)
}
/// Generate WiFi-derived pose keypoints from sensing data.
///
/// Returns an empty list when the classification reports no presence.
/// Otherwise returns exactly one person: a fixed 17-keypoint COCO skeleton
/// placed around (320, 240) whose sway and per-joint jitter are driven by the
/// update's tick and scaled by the motion level.
fn derive_pose_from_sensing(update: &SensingUpdate) -> Vec<PersonDetection> {
    // (name, x offset, y offset) of each COCO keypoint relative to the body centre.
    const SKELETON: [(&str, f64, f64); 17] = [
        ("nose", 0.0, -80.0),
        ("left_eye", -8.0, -88.0),
        ("right_eye", 8.0, -88.0),
        ("left_ear", -16.0, -82.0),
        ("right_ear", 16.0, -82.0),
        ("left_shoulder", -30.0, -50.0),
        ("right_shoulder", 30.0, -50.0),
        ("left_elbow", -45.0, -15.0),
        ("right_elbow", 45.0, -15.0),
        ("left_wrist", -50.0, 20.0),
        ("right_wrist", 50.0, 20.0),
        ("left_hip", -20.0, 20.0),
        ("right_hip", 20.0, 20.0),
        ("left_knee", -22.0, 70.0),
        ("right_knee", 22.0, 70.0),
        ("left_ankle", -24.0, 120.0),
        ("right_ankle", 24.0, 120.0),
    ];

    let cls = &update.classification;
    if !cls.presence {
        return Vec::new();
    }

    let t = update.tick as f64 * 0.05;
    let motion: f64 = match cls.motion_level.as_str() {
        "active" => 1.0,
        "present_still" => 0.3,
        _ => 0.0,
    };

    // Body centre sways with time, proportionally to the motion level.
    let base_x = 320.0 + 30.0 * t.sin() * motion;
    let base_y = 240.0 + 15.0 * (t * 0.7).cos() * motion;

    let keypoints: Vec<PoseKeypoint> = SKELETON
        .iter()
        .enumerate()
        .map(|(i, &(name, dx, dy))| {
            let idx = i as f64;
            let jitter = motion * 3.0 * (t * 2.0 + idx).sin();
            PoseKeypoint {
                name: name.to_string(),
                x: base_x + dx + jitter,
                y: base_y + dy + jitter * 0.5,
                z: 0.0,
                confidence: cls.confidence * (0.85 + 0.15 * (idx * 0.3).cos()),
            }
        })
        .collect();

    let bbox = BoundingBox {
        x: base_x - 60.0,
        y: base_y - 90.0,
        width: 120.0,
        height: 220.0,
    };

    vec![PersonDetection {
        id: 1,
        confidence: cls.confidence,
        keypoints,
        bbox,
        zone: "zone_1".into(),
    }]
}
// ── DensePose-compatible REST endpoints ─────────────────────────────────────
/// Liveness report: a fixed `"alive"` status and the uptime in whole seconds.
async fn health_live(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let uptime_secs = app.start_time.elapsed().as_secs();
    Json(serde_json::json!({
        "status": "alive",
        "uptime": uptime_secs,
    }))
}
/// Readiness report: a fixed `"ready"` status and the current data source.
async fn health_ready(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let body = serde_json::json!({
        "status": "ready",
        "source": app.source,
    });
    Json(body)
}
/// Detailed system health report.
///
/// Component statuses are fixed strings except for `stream`, which is
/// `"healthy"` when at least one broadcast receiver is subscribed and
/// `"idle"` otherwise. The CPU, memory and disk percentages are hard-coded
/// constants, not measurements; only `uptime_seconds` is live.
async fn health_system(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let uptime = app.start_time.elapsed().as_secs();

    let components = serde_json::json!({
        "api": { "status": "healthy", "message": "Rust Axum server" },
        "hardware": { "status": "healthy", "message": format!("Source: {}", app.source) },
        "pose": { "status": "healthy", "message": "WiFi-derived pose estimation" },
        "stream": { "status": if app.tx.receiver_count() > 0 { "healthy" } else { "idle" },
                    "message": format!("{} client(s)", app.tx.receiver_count()) },
    });
    let metrics = serde_json::json!({
        "cpu_percent": 2.5,
        "memory_percent": 1.8,
        "disk_percent": 15.0,
        "uptime_seconds": uptime,
    });

    Json(serde_json::json!({
        "status": "healthy",
        "components": components,
        "metrics": metrics,
    }))
}
/// Reports the crate version (taken from `CARGO_PKG_VERSION` at compile
/// time) together with the fixed service name and backend label.
async fn health_version() -> Json<serde_json::Value> {
    let body = serde_json::json!({
        "version": env!("CARGO_PKG_VERSION"),
        "name": "wifi-densepose-sensing-server",
        "backend": "rust+axum+ruvector",
    });
    Json(body)
}
/// Metrics report. The CPU, memory and disk figures are hard-coded
/// constants, not measurements; only `tick` reflects live state.
async fn health_metrics(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let system_metrics = serde_json::json!({
        "cpu": { "percent": 2.5 },
        "memory": { "percent": 1.8, "used_mb": 5 },
        "disk": { "percent": 15.0 },
    });
    Json(serde_json::json!({
        "system_metrics": system_metrics,
        "tick": app.tick,
    }))
}
/// Describes the API: crate version, fixed environment/backend labels, the
/// current data source and a static feature-flag map (all flags `true`).
async fn api_info(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let features = serde_json::json!({
        "wifi_sensing": true,
        "pose_estimation": true,
        "signal_processing": true,
        "ruvector": true,
        "streaming": true,
    });
    Json(serde_json::json!({
        "version": env!("CARGO_PKG_VERSION"),
        "environment": "production",
        "backend": "rust",
        "source": app.source,
        "features": features,
    }))
}
async fn pose_current(State(state): State<SharedState>) -> Json<serde_json::Value> {
let s = state.read().await;
let persons = match &s.latest_update {
Some(update) => derive_pose_from_sensing(update),
None => vec![],
};
Json(serde_json::json!({
"timestamp": chrono::Utc::now().timestamp_millis() as f64 / 1000.0,
"persons": persons,
"total_persons": persons.len(),
"source": s.source,
}))
}
/// Pose statistics: the detection counter, the tick counter (reported as
/// processed frames) and the data source. `average_confidence` is a
/// hard-coded constant, not a computed average.
async fn pose_stats(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let body = serde_json::json!({
        "total_detections": app.total_detections,
        "average_confidence": 0.87,
        "frames_processed": app.tick,
        "source": app.source,
    });
    Json(body)
}
/// Per-zone occupancy summary. Only `zone_1` is live: it reports one person
/// when the latest update indicates presence. Zones 2–4 are always reported
/// as clear.
async fn pose_zones_summary(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let occupied = match app.latest_update.as_ref() {
        Some(update) => update.classification.presence,
        None => false,
    };
    let zone_1_count = if occupied { 1 } else { 0 };
    Json(serde_json::json!({
        "zones": {
            "zone_1": { "person_count": zone_1_count, "status": "monitored" },
            "zone_2": { "person_count": 0, "status": "clear" },
            "zone_3": { "person_count": 0, "status": "clear" },
            "zone_4": { "person_count": 0, "status": "clear" },
        }
    }))
}
/// Streaming status: number of subscribed broadcast receivers and the data
/// source. `active` and `fps` are hard-coded (`true` and `2`).
async fn stream_status(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let body = serde_json::json!({
        "active": true,
        "clients": app.tx.receiver_count(),
        "fps": 2,
        "source": app.source,
    });
    Json(body)
}
/// Returns the most recent vital-sign reading together with the fill level
/// of the detector's breathing and heartbeat buffers.
async fn vital_signs_endpoint(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let vitals = &app.latest_vitals;
    let (breath_len, breath_cap, heart_len, heart_cap) = app.vital_detector.buffer_status();

    let vital_signs = serde_json::json!({
        "breathing_rate_bpm": vitals.breathing_rate_bpm,
        "heart_rate_bpm": vitals.heart_rate_bpm,
        "breathing_confidence": vitals.breathing_confidence,
        "heartbeat_confidence": vitals.heartbeat_confidence,
        "signal_quality": vitals.signal_quality,
    });
    let buffer_status = serde_json::json!({
        "breathing_samples": breath_len,
        "breathing_capacity": breath_cap,
        "heartbeat_samples": heart_len,
        "heartbeat_capacity": heart_cap,
    });

    Json(serde_json::json!({
        "vital_signs": vital_signs,
        "buffer_status": buffer_status,
        "source": app.source,
        "tick": app.tick,
    }))
}
/// Reports the loaded RVF container, or a `no_model` status with a hint on
/// how to load one when none is present.
async fn model_info(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let app = state.read().await;
    let body = if let Some(info) = &app.rvf_info {
        serde_json::json!({
            "status": "loaded",
            "container": info,
        })
    } else {
        serde_json::json!({
            "status": "no_model",
            "message": "No RVF container loaded. Use --load-rvf <path> to load one.",
        })
    };
    Json(body)
}
/// Returns the progressive loader's per-layer status and loading progress.
///
/// Without a progressive loader every layer is reported as `false`, progress
/// is `0.0`, and an explanatory `message` is included.
async fn model_layers(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let guard = state.read().await;
    let Some(loader) = guard.progressive_loader.as_ref() else {
        return Json(serde_json::json!({
            "layer_a": false,
            "layer_b": false,
            "layer_c": false,
            "progress": 0.0,
            "message": "No model loaded with progressive loading",
        }));
    };

    let (layer_a, layer_b, layer_c) = loader.layer_status();
    Json(serde_json::json!({
        "layer_a": layer_a,
        "layer_b": layer_b,
        "layer_c": layer_c,
        "progress": loader.loading_progress(),
    }))
}
/// Returns the progressive loader's segment list as JSON, or an empty
/// `segments` array when no progressive loader is present.
async fn model_segments(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let guard = state.read().await;
    let segments = match guard.progressive_loader.as_ref() {
        Some(loader) => serde_json::json!(loader.segment_list()),
        None => serde_json::json!([]),
    };
    Json(serde_json::json!({ "segments": segments }))
}
/// Lists the SONA profile names known to the progressive loader and the
/// currently active profile.
///
/// `profiles` is the default (empty) value when no loader is present, and
/// `active` is an empty string when no profile has been activated.
async fn sona_profiles(State(state): State<SharedState>) -> Json<serde_json::Value> {
    let guard = state.read().await;
    let profiles = match guard.progressive_loader.as_ref() {
        Some(loader) => loader.sona_profile_names(),
        None => Default::default(),
    };
    let active = guard.active_sona_profile.as_deref().unwrap_or("");
    Json(serde_json::json!({ "profiles": profiles, "active": active }))
}
/// Activates the SONA profile named by the `profile` field of the JSON body.
///
/// A missing or non-string `profile` is treated as the empty string. The
/// profile is stored as active only when the progressive loader lists it;
/// otherwise a `status: "error"` payload naming the available profiles is
/// returned.
async fn sona_activate(
    State(state): State<SharedState>,
    Json(body): Json<serde_json::Value>,
) -> Json<serde_json::Value> {
    let requested = body
        .get("profile")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
        .unwrap_or_default();

    let mut guard = state.write().await;
    let known = match guard.progressive_loader.as_ref() {
        Some(loader) => loader.sona_profile_names(),
        None => Default::default(),
    };

    // Guard clause: reject names the loader does not know about.
    if !known.contains(&requested) {
        return Json(serde_json::json!({
            "status": "error",
            "message": format!("Profile '{}' not found. Available: {:?}", requested, known),
        }));
    }

    guard.active_sona_profile = Some(requested.clone());
    Json(serde_json::json!({ "status": "activated", "profile": requested }))
}
/// Serves a static HTML landing page linking to the main endpoints.
///
/// The page content is a fixed string; in particular the WebSocket URL shown
/// is hard-coded to `ws://localhost:8765/ws/sensing` and does not reflect a
/// non-default WebSocket port.
async fn info_page() -> Html<String> {
    // The page has no interpolated values, so build the `String` directly
    // instead of routing a literal through `format!`.
    Html(String::from(
        "<html><body>\
         <h1>WiFi-DensePose Sensing Server</h1>\
         <p>Rust + Axum + RuVector</p>\
         <ul>\
         <li><a href='/health'>/health</a> — Server health</li>\
         <li><a href='/api/v1/sensing/latest'>/api/v1/sensing/latest</a> — Latest sensing data</li>\
         <li><a href='/api/v1/vital-signs'>/api/v1/vital-signs</a> — Vital sign estimates (HR/RR)</li>\
         <li><a href='/api/v1/model/info'>/api/v1/model/info</a> — RVF model container info</li>\
         <li>ws://localhost:8765/ws/sensing — WebSocket stream</li>\
         </ul>\
         </body></html>",
    ))
}
// ── UDP receiver task ────────────────────────────────────────────────────────
/// Receives ESP32 CSI frames over UDP and publishes one `SensingUpdate` per
/// parsed frame.
///
/// Binds `0.0.0.0:<udp_port>`; if the bind fails the error is logged and the
/// task returns. Every datagram accepted by `parse_esp32_frame` is converted
/// into features, a classification and vital-sign estimates, then serialized
/// onto the broadcast channel and stored as `latest_update`. Datagrams that
/// do not parse are ignored. Receive errors are logged and followed by a
/// 100 ms pause before the next attempt.
async fn udp_receiver_task(state: SharedState, udp_port: u16) {
    let addr = format!("0.0.0.0:{udp_port}");
    let socket = match UdpSocket::bind(&addr).await {
        Ok(s) => {
            info!("UDP listening on {addr} for ESP32 CSI frames");
            s
        }
        Err(e) => {
            error!("Failed to bind UDP {addr}: {e}");
            return;
        }
    };

    let mut buf = [0u8; 2048];
    loop {
        match socket.recv_from(&mut buf).await {
            Ok((len, src)) => {
                let Some(frame) = parse_esp32_frame(&buf[..len]) else {
                    continue;
                };
                debug!("ESP32 frame from {src}: node={}, subs={}, seq={}",
                    frame.node_id, frame.n_subcarriers, frame.sequence);

                // Feature extraction happens before the write lock is taken.
                let (features, classification) = extract_features_from_frame(&frame);

                let mut s = state.write().await;
                s.source = "esp32".to_string();

                // Update RSSI history (bounded to the most recent 60 samples)
                s.rssi_history.push_back(features.mean_rssi);
                if s.rssi_history.len() > 60 {
                    s.rssi_history.pop_front();
                }

                s.tick += 1;
                let tick = s.tick;

                let motion_score = if classification.motion_level == "active" { 0.8 }
                    else if classification.motion_level == "present_still" { 0.3 }
                    else { 0.05 };

                let vitals = s.vital_detector.process_frame(
                    &frame.amplitudes,
                    &frame.phases,
                );
                s.latest_vitals = vitals.clone();

                let update = SensingUpdate {
                    msg_type: "sensing_update".to_string(),
                    timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0,
                    source: "esp32".to_string(),
                    tick,
                    nodes: vec![NodeInfo {
                        node_id: frame.node_id,
                        rssi_dbm: features.mean_rssi,
                        // Node position is a fixed value, not measured.
                        position: [2.0, 0.0, 1.5],
                        // Only the first 56 amplitudes are forwarded.
                        amplitude: frame.amplitudes.iter().take(56).cloned().collect(),
                        subcarrier_count: frame.n_subcarriers as usize,
                    }],
                    features: features.clone(),
                    classification,
                    signal_field: generate_signal_field(
                        features.mean_rssi, features.variance, motion_score, tick,
                    ),
                    vital_signs: Some(vitals),
                    enhanced_motion: None,
                    enhanced_breathing: None,
                    posture: None,
                    signal_quality_score: None,
                    quality_verdict: None,
                    bssid_count: None,
                    pose_keypoints: None,
                    model_status: None,
                };

                // Count frames with detected presence, as the simulated data
                // source does, so `total_detections` (reported by the pose
                // stats endpoint) is not stuck at zero in ESP32 mode.
                if update.classification.presence {
                    s.total_detections += 1;
                }

                // A send error only means there are no subscribers right now.
                if let Ok(json) = serde_json::to_string(&update) {
                    let _ = s.tx.send(json);
                }
                s.latest_update = Some(update);
            }
            Err(e) => {
                warn!("UDP recv error: {e}");
                // Back off briefly so a persistent error does not spin the loop.
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        }
    }
}
// ── Simulated data task ──────────────────────────────────────────────────────
/// Generates a simulated CSI frame every `tick_ms` milliseconds and publishes
/// the resulting `SensingUpdate`.
///
/// Each tick increments the shared tick counter, derives features, a
/// classification and vital signs from the simulated frame, counts the tick
/// as a detection when presence is reported, broadcasts the serialized update
/// and stores it as `latest_update`. The state write lock is held for the
/// whole per-tick computation.
async fn simulated_data_task(state: SharedState, tick_ms: u64) {
    let mut ticker = tokio::time::interval(Duration::from_millis(tick_ms));
    info!("Simulated data source active (tick={}ms)", tick_ms);

    loop {
        ticker.tick().await;

        let mut guard = state.write().await;
        guard.tick += 1;
        let tick = guard.tick;

        let frame = generate_simulated_frame(tick);
        let (features, classification) = extract_features_from_frame(&frame);

        // Keep only the most recent 60 RSSI samples.
        guard.rssi_history.push_back(features.mean_rssi);
        if guard.rssi_history.len() > 60 {
            guard.rssi_history.pop_front();
        }

        let motion_score = if classification.motion_level == "active" {
            0.8
        } else if classification.motion_level == "present_still" {
            0.3
        } else {
            0.05
        };

        let vitals = guard
            .vital_detector
            .process_frame(&frame.amplitudes, &frame.phases);
        guard.latest_vitals = vitals.clone();

        // Model status is attached only when a model has been loaded; the
        // layer count is the number of layers the loader reports as ready.
        let model_status = if guard.model_loaded {
            let ready_layers = guard
                .progressive_loader
                .as_ref()
                .map(|l| { let (a, b, c) = l.layer_status(); a as u8 + b as u8 + c as u8 })
                .unwrap_or(0);
            Some(serde_json::json!({
                "loaded": true,
                "layers": ready_layers,
                "sona_profile": guard.active_sona_profile.as_deref().unwrap_or("default"),
            }))
        } else {
            None
        };

        let update = SensingUpdate {
            msg_type: "sensing_update".to_string(),
            timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0,
            source: "simulated".to_string(),
            tick,
            nodes: vec![NodeInfo {
                node_id: 1,
                rssi_dbm: features.mean_rssi,
                position: [2.0, 0.0, 1.5],
                amplitude: frame.amplitudes,
                subcarrier_count: frame.n_subcarriers as usize,
            }],
            features: features.clone(),
            classification,
            signal_field: generate_signal_field(
                features.mean_rssi,
                features.variance,
                motion_score,
                tick,
            ),
            vital_signs: Some(vitals),
            enhanced_motion: None,
            enhanced_breathing: None,
            posture: None,
            signal_quality_score: None,
            quality_verdict: None,
            bssid_count: None,
            pose_keypoints: None,
            model_status,
        };

        if update.classification.presence {
            guard.total_detections += 1;
        }

        // A failed send is ignored on purpose.
        if let Ok(payload) = serde_json::to_string(&update) {
            let _ = guard.tx.send(payload);
        }
        guard.latest_update = Some(update);
    }
}
// ── Broadcast tick task (for ESP32 mode, sends buffered state) ───────────────
/// Re-broadcasts the stored `latest_update` every `tick_ms` milliseconds.
///
/// Nothing is sent while no update has been stored yet or while the broadcast
/// channel has no receivers. This keeps WebSocket clients supplied with data
/// even when the ESP32 pauses between frames.
async fn broadcast_tick_task(state: SharedState, tick_ms: u64) {
    let mut ticker = tokio::time::interval(Duration::from_millis(tick_ms));
    loop {
        ticker.tick().await;

        let guard = state.read().await;
        let Some(update) = guard.latest_update.as_ref() else {
            continue;
        };
        if guard.tx.receiver_count() == 0 {
            continue;
        }
        // Serialization or send failures are ignored; the next tick retries.
        if let Ok(payload) = serde_json::to_string(update) {
            let _ = guard.tx.send(payload);
        }
    }
}
// ── Main ─────────────────────────────────────────────────────────────────────
#[tokio::main]
async fn main() {
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "info,tower_http=debug".into()),
)
.init();
let args = Args::parse();
// Handle --benchmark mode: run vital sign benchmark and exit
if args.benchmark {
eprintln!("Running vital sign detection benchmark (1000 frames)...");
let (total, per_frame) = vital_signs::run_benchmark(1000);
eprintln!();
eprintln!("Summary: {} total, {} per frame",
format!("{total:?}"), format!("{per_frame:?}"));
return;
}
// Handle --export-rvf mode: build an RVF container package and exit
if let Some(ref rvf_path) = args.export_rvf {
eprintln!("Exporting RVF container package...");
use rvf_pipeline::RvfModelBuilder;
let mut builder = RvfModelBuilder::new("wifi-densepose", "1.0.0");
// Vital sign config (default breathing 0.1-0.5 Hz, heartbeat 0.8-2.0 Hz)
builder.set_vital_config(0.1, 0.5, 0.8, 2.0);
// Model profile (input/output spec)
builder.set_model_profile(
"56-subcarrier CSI amplitude/phase @ 10-100 Hz",
"17 COCO keypoints + body part UV + vital signs",
"ESP32-S3 or Windows WiFi RSSI, Rust 1.85+",
);
// Placeholder weights (17 keypoints × 56 subcarriers × 3 dims = 2856 params)
let placeholder_weights: Vec<f32> = (0..2856).map(|i| (i as f32 * 0.001).sin()).collect();
builder.set_weights(&placeholder_weights);
// Training provenance
builder.set_training_proof(
"wifi-densepose-rs-v1.0.0",
serde_json::json!({
"pipeline": "ADR-023 8-phase",
"test_count": 229,
"benchmark_fps": 9520,
"framework": "wifi-densepose-rs",
}),
);
// SONA default environment profile
let default_lora: Vec<f32> = vec![0.0; 64];
builder.add_sona_profile("default", &default_lora, &default_lora);
match builder.build() {
Ok(rvf_bytes) => {
if let Err(e) = std::fs::write(rvf_path, &rvf_bytes) {
eprintln!("Error writing RVF: {e}");
std::process::exit(1);
}
eprintln!("Wrote {} bytes to {}", rvf_bytes.len(), rvf_path.display());
eprintln!("RVF container exported successfully.");
}
Err(e) => {
eprintln!("Error building RVF: {e}");
std::process::exit(1);
}
}
return;
}
// Handle --pretrain mode: self-supervised contrastive pretraining (ADR-024)
if args.pretrain {
eprintln!("=== WiFi-DensePose Contrastive Pretraining (ADR-024) ===");
let ds_path = args.dataset.clone().unwrap_or_else(|| PathBuf::from("data"));
let source = match args.dataset_type.as_str() {
"wipose" => dataset::DataSource::WiPose(ds_path.clone()),
_ => dataset::DataSource::MmFi(ds_path.clone()),
};
let pipeline = dataset::DataPipeline::new(dataset::DataConfig {
source, ..Default::default()
});
// Generate synthetic or load real CSI windows
let generate_synthetic_windows = || -> Vec<Vec<Vec<f32>>> {
(0..50).map(|i| {
(0..4).map(|a| {
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
}).collect()
}).collect()
};
let csi_windows: Vec<Vec<Vec<f32>>> = match pipeline.load() {
Ok(s) if !s.is_empty() => {
eprintln!("Loaded {} samples from {}", s.len(), ds_path.display());
s.into_iter().map(|s| s.csi_window).collect()
}
_ => {
eprintln!("Using synthetic data for pretraining.");
generate_synthetic_windows()
}
};
let n_subcarriers = csi_windows.first()
.and_then(|w| w.first())
.map(|f| f.len())
.unwrap_or(56);
let tf_config = graph_transformer::TransformerConfig {
n_subcarriers, n_keypoints: 17, d_model: 64, n_heads: 4, n_gnn_layers: 2,
};
let transformer = graph_transformer::CsiToPoseTransformer::new(tf_config);
eprintln!("Transformer params: {}", transformer.param_count());
let trainer_config = trainer::TrainerConfig {
epochs: args.pretrain_epochs,
batch_size: 8, lr: 0.001, warmup_epochs: 2, min_lr: 1e-6,
early_stop_patience: args.pretrain_epochs + 1,
pretrain_temperature: 0.07,
..Default::default()
};
let mut t = trainer::Trainer::with_transformer(trainer_config, transformer);
let e_config = embedding::EmbeddingConfig {
d_model: 64, d_proj: 128, temperature: 0.07, normalize: true,
};
let mut projection = embedding::ProjectionHead::new(e_config.clone());
let augmenter = embedding::CsiAugmenter::new();
eprintln!("Starting contrastive pretraining for {} epochs...", args.pretrain_epochs);
let start = std::time::Instant::now();
for epoch in 0..args.pretrain_epochs {
let loss = t.pretrain_epoch(&csi_windows, &augmenter, &mut projection, 0.07, epoch);
if epoch % 10 == 0 || epoch == args.pretrain_epochs - 1 {
eprintln!(" Epoch {epoch}: contrastive loss = {loss:.4}");
}
}
let elapsed = start.elapsed().as_secs_f64();
eprintln!("Pretraining complete in {elapsed:.1}s");
// Save pretrained model as RVF with embedding segment
if let Some(ref save_path) = args.save_rvf {
eprintln!("Saving pretrained model to RVF: {}", save_path.display());
t.sync_transformer_weights();
let weights = t.params().to_vec();
let mut proj_weights = Vec::new();
projection.flatten_into(&mut proj_weights);
let mut builder = RvfBuilder::new();
builder.add_manifest(
"wifi-densepose-pretrained",
env!("CARGO_PKG_VERSION"),
"WiFi DensePose contrastive pretrained model (ADR-024)",
);
builder.add_weights(&weights);
builder.add_embedding(
&serde_json::json!({
"d_model": e_config.d_model,
"d_proj": e_config.d_proj,
"temperature": e_config.temperature,
"normalize": e_config.normalize,
"pretrain_epochs": args.pretrain_epochs,
}),
&proj_weights,
);
match builder.write_to_file(save_path) {
Ok(()) => eprintln!("RVF saved ({} transformer + {} projection params)",
weights.len(), proj_weights.len()),
Err(e) => eprintln!("Failed to save RVF: {e}"),
}
}
return;
}
// Handle --embed mode: extract embeddings from CSI data
if args.embed {
eprintln!("=== WiFi-DensePose Embedding Extraction (ADR-024) ===");
let model_path = match &args.model {
Some(p) => p.clone(),
None => {
eprintln!("Error: --embed requires --model <path> to a pretrained .rvf file");
std::process::exit(1);
}
};
let reader = match RvfReader::from_file(&model_path) {
Ok(r) => r,
Err(e) => { eprintln!("Failed to load model: {e}"); std::process::exit(1); }
};
let weights = reader.weights().unwrap_or_default();
let (embed_config_json, proj_weights) = reader.embedding().unwrap_or_else(|| {
eprintln!("Warning: no embedding segment in RVF, using defaults");
(serde_json::json!({"d_model":64,"d_proj":128,"temperature":0.07,"normalize":true}), Vec::new())
});
let d_model = embed_config_json["d_model"].as_u64().unwrap_or(64) as usize;
let d_proj = embed_config_json["d_proj"].as_u64().unwrap_or(128) as usize;
let tf_config = graph_transformer::TransformerConfig {
n_subcarriers: 56, n_keypoints: 17, d_model, n_heads: 4, n_gnn_layers: 2,
};
let e_config = embedding::EmbeddingConfig {
d_model, d_proj, temperature: 0.07, normalize: true,
};
let mut extractor = embedding::EmbeddingExtractor::new(tf_config, e_config.clone());
// Load transformer weights
if !weights.is_empty() {
if let Err(e) = extractor.transformer.unflatten_weights(&weights) {
eprintln!("Warning: failed to load transformer weights: {e}");
}
}
// Load projection weights
if !proj_weights.is_empty() {
let (proj, _) = embedding::ProjectionHead::unflatten_from(&proj_weights, &e_config);
extractor.projection = proj;
}
// Load dataset and extract embeddings
let _ds_path = args.dataset.clone().unwrap_or_else(|| PathBuf::from("data"));
let csi_windows: Vec<Vec<Vec<f32>>> = (0..10).map(|i| {
(0..4).map(|a| {
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
}).collect()
}).collect();
eprintln!("Extracting embeddings from {} CSI windows...", csi_windows.len());
let embeddings = extractor.extract_batch(&csi_windows);
for (i, emb) in embeddings.iter().enumerate() {
let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
eprintln!(" Window {i}: {d_proj}-dim embedding, ||e|| = {norm:.4}");
}
eprintln!("Extracted {} embeddings of dimension {d_proj}", embeddings.len());
return;
}
// Handle --build-index mode: build a fingerprint index from embeddings
if let Some(ref index_type_str) = args.build_index {
eprintln!("=== WiFi-DensePose Fingerprint Index Builder (ADR-024) ===");
let index_type = match index_type_str.as_str() {
"env" | "environment" => embedding::IndexType::EnvironmentFingerprint,
"activity" => embedding::IndexType::ActivityPattern,
"temporal" => embedding::IndexType::TemporalBaseline,
"person" => embedding::IndexType::PersonTrack,
_ => {
eprintln!("Unknown index type '{}'. Use: env, activity, temporal, person", index_type_str);
std::process::exit(1);
}
};
let tf_config = graph_transformer::TransformerConfig::default();
let e_config = embedding::EmbeddingConfig::default();
let mut extractor = embedding::EmbeddingExtractor::new(tf_config, e_config);
// Generate synthetic CSI windows for demo
let csi_windows: Vec<Vec<Vec<f32>>> = (0..20).map(|i| {
(0..4).map(|a| {
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
}).collect()
}).collect();
let mut index = embedding::FingerprintIndex::new(index_type);
for (i, window) in csi_windows.iter().enumerate() {
let emb = extractor.extract(window);
index.insert(emb, format!("window_{i}"), i as u64 * 100);
}
eprintln!("Built {:?} index with {} entries", index_type, index.len());
// Test a query
let query_emb = extractor.extract(&csi_windows[0]);
let results = index.search(&query_emb, 5);
eprintln!("Top-5 nearest to window_0:");
for r in &results {
eprintln!(" entry={}, distance={:.4}, metadata={}", r.entry, r.distance, r.metadata);
}
return;
}
// Handle --train mode: train a model and exit
if args.train {
eprintln!("=== WiFi-DensePose Training Mode ===");
// Build data pipeline
let ds_path = args.dataset.clone().unwrap_or_else(|| PathBuf::from("data"));
let source = match args.dataset_type.as_str() {
"wipose" => dataset::DataSource::WiPose(ds_path.clone()),
_ => dataset::DataSource::MmFi(ds_path.clone()),
};
let pipeline = dataset::DataPipeline::new(dataset::DataConfig {
source,
..Default::default()
});
// Generate synthetic training data (50 samples with deterministic CSI + keypoints)
let generate_synthetic = || -> Vec<dataset::TrainingSample> {
(0..50).map(|i| {
let csi: Vec<Vec<f32>> = (0..4).map(|a| {
(0..56).map(|s| ((i * 7 + a * 13 + s) as f32 * 0.31).sin() * 0.5).collect()
}).collect();
let mut kps = [(0.0f32, 0.0f32, 1.0f32); 17];
for (k, kp) in kps.iter_mut().enumerate() {
kp.0 = (k as f32 * 0.1 + i as f32 * 0.02).sin() * 100.0 + 320.0;
kp.1 = (k as f32 * 0.15 + i as f32 * 0.03).cos() * 80.0 + 240.0;
}
dataset::TrainingSample {
csi_window: csi,
pose_label: dataset::PoseLabel {
keypoints: kps,
body_parts: Vec::new(),
confidence: 1.0,
},
source: "synthetic",
}
}).collect()
};
// Load samples (fall back to synthetic if dataset missing/empty)
let samples = match pipeline.load() {
Ok(s) if !s.is_empty() => {
eprintln!("Loaded {} samples from {}", s.len(), ds_path.display());
s
}
Ok(_) => {
eprintln!("No samples found at {}. Using synthetic data.", ds_path.display());
generate_synthetic()
}
Err(e) => {
eprintln!("Failed to load dataset: {e}. Using synthetic data.");
generate_synthetic()
}
};
// Convert dataset samples to trainer format
let trainer_samples: Vec<trainer::TrainingSample> = samples.iter()
.map(trainer::from_dataset_sample)
.collect();
// Split 80/20 train/val
let split = (trainer_samples.len() * 4) / 5;
let (train_data, val_data) = trainer_samples.split_at(split.max(1));
eprintln!("Train: {} samples, Val: {} samples", train_data.len(), val_data.len());
// Create transformer + trainer
let n_subcarriers = train_data.first()
.and_then(|s| s.csi_features.first())
.map(|f| f.len())
.unwrap_or(56);
let tf_config = graph_transformer::TransformerConfig {
n_subcarriers,
n_keypoints: 17,
d_model: 64,
n_heads: 4,
n_gnn_layers: 2,
};
let transformer = graph_transformer::CsiToPoseTransformer::new(tf_config);
eprintln!("Transformer params: {}", transformer.param_count());
let trainer_config = trainer::TrainerConfig {
epochs: args.epochs,
batch_size: 8,
lr: 0.001,
warmup_epochs: 5,
min_lr: 1e-6,
early_stop_patience: 20,
checkpoint_every: 10,
..Default::default()
};
let mut t = trainer::Trainer::with_transformer(trainer_config, transformer);
// Run training
eprintln!("Starting training for {} epochs...", args.epochs);
let result = t.run_training(train_data, val_data);
eprintln!("Training complete in {:.1}s", result.total_time_secs);
eprintln!(" Best epoch: {}, PCK@0.2: {:.4}, OKS mAP: {:.4}",
result.best_epoch, result.best_pck, result.best_oks);
// Save checkpoint
if let Some(ref ckpt_dir) = args.checkpoint_dir {
let _ = std::fs::create_dir_all(ckpt_dir);
let ckpt_path = ckpt_dir.join("best_checkpoint.json");
let ckpt = t.checkpoint();
match ckpt.save_to_file(&ckpt_path) {
Ok(()) => eprintln!("Checkpoint saved to {}", ckpt_path.display()),
Err(e) => eprintln!("Failed to save checkpoint: {e}"),
}
}
// Sync weights back to transformer and save as RVF
t.sync_transformer_weights();
if let Some(ref save_path) = args.save_rvf {
eprintln!("Saving trained model to RVF: {}", save_path.display());
let weights = t.params().to_vec();
let mut builder = RvfBuilder::new();
builder.add_manifest(
"wifi-densepose-trained",
env!("CARGO_PKG_VERSION"),
"WiFi DensePose trained model weights",
);
builder.add_metadata(&serde_json::json!({
"training": {
"epochs": args.epochs,
"best_epoch": result.best_epoch,
"best_pck": result.best_pck,
"best_oks": result.best_oks,
"n_train_samples": train_data.len(),
"n_val_samples": val_data.len(),
"n_subcarriers": n_subcarriers,
"param_count": weights.len(),
},
}));
builder.add_vital_config(&VitalSignConfig::default());
builder.add_weights(&weights);
match builder.write_to_file(save_path) {
Ok(()) => eprintln!("RVF saved ({} params, {} bytes)",
weights.len(), weights.len() * 4),
Err(e) => eprintln!("Failed to save RVF: {e}"),
}
}
return;
}
info!("WiFi-DensePose Sensing Server (Rust + Axum + RuVector)");
info!(" HTTP: http://localhost:{}", args.http_port);
info!(" WebSocket: ws://localhost:{}/ws/sensing", args.ws_port);
info!(" UDP: 0.0.0.0:{} (ESP32 CSI)", args.udp_port);
info!(" UI path: {}", args.ui_path.display());
info!(" Source: {}", args.source);
// Auto-detect data source
let source = match args.source.as_str() {
"auto" => {
info!("Auto-detecting data source...");
if probe_esp32(args.udp_port).await {
info!(" ESP32 CSI detected on UDP :{}", args.udp_port);
"esp32"
} else if probe_windows_wifi().await {
info!(" Windows WiFi detected");
"wifi"
} else {
info!(" No hardware detected, using simulation");
"simulate"
}
}
other => other,
};
info!("Data source: {source}");
// Shared state
// Vital sign sample rate derives from tick interval (e.g. 500ms tick => 2 Hz)
let vital_sample_rate = 1000.0 / args.tick_ms as f64;
info!("Vital sign detector sample rate: {vital_sample_rate:.1} Hz");
// Load RVF container if --load-rvf was specified
let rvf_info = if let Some(ref rvf_path) = args.load_rvf {
info!("Loading RVF container from {}", rvf_path.display());
match RvfReader::from_file(rvf_path) {
Ok(reader) => {
let info = reader.info();
info!(
" RVF loaded: {} segments, {} bytes",
info.segment_count, info.total_size
);
if let Some(ref manifest) = info.manifest {
if let Some(model_id) = manifest.get("model_id") {
info!(" Model ID: {model_id}");
}
if let Some(version) = manifest.get("version") {
info!(" Version: {version}");
}
}
if info.has_weights {
if let Some(w) = reader.weights() {
info!(" Weights: {} parameters", w.len());
}
}
if info.has_vital_config {
info!(" Vital sign config: present");
}
if info.has_quant_info {
info!(" Quantization info: present");
}
if info.has_witness {
info!(" Witness/proof: present");
}
Some(info)
}
Err(e) => {
error!("Failed to load RVF container: {e}");
None
}
}
} else {
None
};
// Load trained model via --model (uses progressive loading if --progressive set)
let model_path = args.model.as_ref().or(args.load_rvf.as_ref());
let mut progressive_loader: Option<ProgressiveLoader> = None;
let mut model_loaded = false;
if let Some(mp) = model_path {
if args.progressive || args.model.is_some() {
info!("Loading trained model (progressive) from {}", mp.display());
match std::fs::read(mp) {
Ok(data) => match ProgressiveLoader::new(&data) {
Ok(mut loader) => {
if let Ok(la) = loader.load_layer_a() {
info!(" Layer A ready: model={} v{} ({} segments)",
la.model_name, la.version, la.n_segments);
}
model_loaded = true;
progressive_loader = Some(loader);
}
Err(e) => error!("Progressive loader init failed: {e}"),
},
Err(e) => error!("Failed to read model file: {e}"),
}
}
}
let (tx, _) = broadcast::channel::<String>(256);
let state: SharedState = Arc::new(RwLock::new(AppStateInner {
latest_update: None,
rssi_history: VecDeque::new(),
tick: 0,
source: source.into(),
tx,
total_detections: 0,
start_time: std::time::Instant::now(),
vital_detector: VitalSignDetector::new(vital_sample_rate),
latest_vitals: VitalSigns::default(),
rvf_info,
save_rvf_path: args.save_rvf.clone(),
progressive_loader,
active_sona_profile: None,
model_loaded,
}));
// Start background tasks based on source
match source {
"esp32" => {
tokio::spawn(udp_receiver_task(state.clone(), args.udp_port));
tokio::spawn(broadcast_tick_task(state.clone(), args.tick_ms));
}
"wifi" => {
tokio::spawn(windows_wifi_task(state.clone(), args.tick_ms));
}
_ => {
tokio::spawn(simulated_data_task(state.clone(), args.tick_ms));
}
}
// WebSocket server on dedicated port (8765)
let ws_state = state.clone();
let ws_app = Router::new()
.route("/ws/sensing", get(ws_sensing_handler))
.route("/health", get(health))
.with_state(ws_state);
let ws_addr = SocketAddr::from(([0, 0, 0, 0], args.ws_port));
let ws_listener = tokio::net::TcpListener::bind(ws_addr).await
.expect("Failed to bind WebSocket port");
info!("WebSocket server listening on {ws_addr}");
tokio::spawn(async move {
axum::serve(ws_listener, ws_app).await.unwrap();
});
// HTTP server (serves UI + full DensePose-compatible REST API)
let ui_path = args.ui_path.clone();
let http_app = Router::new()
.route("/", get(info_page))
// Health endpoints (DensePose-compatible)
.route("/health", get(health))
.route("/health/health", get(health_system))
.route("/health/live", get(health_live))
.route("/health/ready", get(health_ready))
.route("/health/version", get(health_version))
.route("/health/metrics", get(health_metrics))
// API info
.route("/api/v1/info", get(api_info))
.route("/api/v1/status", get(health_ready))
.route("/api/v1/metrics", get(health_metrics))
// Sensing endpoints
.route("/api/v1/sensing/latest", get(latest))
// Vital sign endpoints
.route("/api/v1/vital-signs", get(vital_signs_endpoint))
// RVF model container info
.route("/api/v1/model/info", get(model_info))
// Progressive loading & SONA endpoints (Phase 7-8)
.route("/api/v1/model/layers", get(model_layers))
.route("/api/v1/model/segments", get(model_segments))
.route("/api/v1/model/sona/profiles", get(sona_profiles))
.route("/api/v1/model/sona/activate", post(sona_activate))
// Pose endpoints (WiFi-derived)
.route("/api/v1/pose/current", get(pose_current))
.route("/api/v1/pose/stats", get(pose_stats))
.route("/api/v1/pose/zones/summary", get(pose_zones_summary))
// Stream endpoints
.route("/api/v1/stream/status", get(stream_status))
.route("/api/v1/stream/pose", get(ws_pose_handler))
// Sensing WebSocket on the HTTP port so the UI can reach it without a second port
.route("/ws/sensing", get(ws_sensing_handler))
// Static UI files
.nest_service("/ui", ServeDir::new(&ui_path))
.layer(SetResponseHeaderLayer::overriding(
axum::http::header::CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-store, must-revalidate"),
))
.with_state(state.clone());
let http_addr = SocketAddr::from(([0, 0, 0, 0], args.http_port));
let http_listener = tokio::net::TcpListener::bind(http_addr).await
.expect("Failed to bind HTTP port");
info!("HTTP server listening on {http_addr}");
info!("Open http://localhost:{}/ui/index.html in your browser", args.http_port);
// Run the HTTP server with graceful shutdown support
let shutdown_state = state.clone();
let server = axum::serve(http_listener, http_app)
.with_graceful_shutdown(async {
tokio::signal::ctrl_c()
.await
.expect("failed to install CTRL+C handler");
info!("Shutdown signal received");
});
server.await.unwrap();
// Save RVF container on shutdown if --save-rvf was specified
let s = shutdown_state.read().await;
if let Some(ref save_path) = s.save_rvf_path {
info!("Saving RVF container to {}", save_path.display());
let mut builder = RvfBuilder::new();
builder.add_manifest(
"wifi-densepose-sensing",
env!("CARGO_PKG_VERSION"),
"WiFi DensePose sensing model state",
);
builder.add_metadata(&serde_json::json!({
"source": s.source,
"total_ticks": s.tick,
"total_detections": s.total_detections,
"uptime_secs": s.start_time.elapsed().as_secs(),
}));
builder.add_vital_config(&VitalSignConfig::default());
// Save transformer weights if a model is loaded, otherwise empty
let weights: Vec<f32> = if s.model_loaded {
// If we loaded via --model, the progressive loader has the weights
// For now, save runtime state placeholder
let tf = graph_transformer::CsiToPoseTransformer::new(Default::default());
tf.flatten_weights()
} else {
Vec::new()
};
builder.add_weights(&weights);
match builder.write_to_file(save_path) {
Ok(()) => info!(" RVF saved ({} weight params)", weights.len()),
Err(e) => error!(" Failed to save RVF: {e}"),
}
}
info!("Server shut down cleanly");
}