feat: Add wifi-Mat disaster detection enhancements
Implement 6 optional enhancements for the wifi-Mat module: 1. Hardware Integration (csi_receiver.rs + hardware_adapter.rs) - ESP32 CSI support via serial/UDP - Intel 5300 BFEE file parsing - Atheros CSI Tool integration - Live UDP packet streaming - PCAP replay capability 2. CLI Commands (wifi-densepose-cli/src/mat.rs) - `wifi-mat scan` - Run disaster detection scan - `wifi-mat status` - Check event status - `wifi-mat zones` - Manage scan zones - `wifi-mat survivors` - List detected survivors - `wifi-mat alerts` - View and acknowledge alerts - `wifi-mat export` - Export data in various formats 3. REST API (wifi-densepose-mat/src/api/) - Full CRUD for disaster events - Zone management endpoints - Survivor and alert queries - WebSocket streaming for real-time updates - Comprehensive DTOs and error handling 4. WASM Build (wifi-densepose-wasm/src/mat.rs) - Browser-based disaster dashboard - Real-time survivor tracking - Zone visualization - Alert management - JavaScript API bindings 5. Detection Benchmarks (benches/detection_bench.rs) - Single survivor detection - Multi-survivor detection - Full pipeline benchmarks - Signal processing benchmarks - Hardware adapter benchmarks 6. ML Models for Debris Penetration (ml/) - DebrisModel for material analysis - VitalSignsClassifier for triage - FFT-based feature extraction - Bandpass filtering - Monte Carlo dropout for uncertainty All 134 unit tests pass. Compilation verified for: - wifi-densepose-mat - wifi-densepose-cli - wifi-densepose-wasm (with mat feature)
This commit is contained in:
1240
rust-port/wifi-densepose-rs/Cargo.lock
generated
1240
rust-port/wifi-densepose-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -3,5 +3,54 @@ name = "wifi-densepose-cli"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description = "CLI for WiFi-DensePose"
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "wifi-densepose"
|
||||
path = "src/main.rs"
|
||||
|
||||
[features]
|
||||
default = ["mat"]
|
||||
mat = []
|
||||
|
||||
[dependencies]
|
||||
# Internal crates
|
||||
wifi-densepose-mat = { path = "../wifi-densepose-mat" }
|
||||
|
||||
# CLI framework
|
||||
clap = { version = "4.4", features = ["derive", "env", "cargo"] }
|
||||
|
||||
# Output formatting
|
||||
colored = "2.1"
|
||||
tabled = { version = "0.15", features = ["ansi"] }
|
||||
indicatif = "0.17"
|
||||
console = "0.15"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1.35", features = ["full"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
csv = "1.3"
|
||||
|
||||
# Error handling
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
# Time
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
# UUID
|
||||
uuid = { version = "1.6", features = ["v4", "serde"] }
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2.0"
|
||||
predicates = "3.0"
|
||||
tempfile = "3.9"
|
||||
|
||||
@@ -1 +1,51 @@
|
||||
//! WiFi-DensePose CLI (stub)
|
||||
//! WiFi-DensePose CLI
|
||||
//!
|
||||
//! Command-line interface for WiFi-DensePose system, including the
|
||||
//! Mass Casualty Assessment Tool (MAT) for disaster response.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - **mat**: Disaster survivor detection and triage management
|
||||
//! - **version**: Display version information
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Start scanning for survivors
|
||||
//! wifi-densepose mat scan --zone "Building A"
|
||||
//!
|
||||
//! # View current scan status
|
||||
//! wifi-densepose mat status
|
||||
//!
|
||||
//! # List detected survivors
|
||||
//! wifi-densepose mat survivors --sort-by triage
|
||||
//!
|
||||
//! # View and manage alerts
|
||||
//! wifi-densepose mat alerts
|
||||
//! ```
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
pub mod mat;
|
||||
|
||||
/// WiFi-DensePose Command Line Interface
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "wifi-densepose")]
|
||||
#[command(author, version, about = "WiFi-based pose estimation and disaster response")]
|
||||
#[command(propagate_version = true)]
|
||||
pub struct Cli {
|
||||
/// Command to execute
|
||||
#[command(subcommand)]
|
||||
pub command: Commands,
|
||||
}
|
||||
|
||||
/// Top-level commands
|
||||
#[derive(Subcommand, Debug)]
|
||||
pub enum Commands {
|
||||
/// Mass Casualty Assessment Tool commands
|
||||
#[command(subcommand)]
|
||||
Mat(mat::MatCommand),
|
||||
|
||||
/// Display version information
|
||||
Version,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
//! WiFi-DensePose CLI Entry Point
|
||||
//!
|
||||
//! This is the main entry point for the wifi-densepose command-line tool.
|
||||
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
use wifi_densepose_cli::{Cli, Commands};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Initialize logging
|
||||
tracing_subscriber::registry()
|
||||
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
|
||||
.with(tracing_subscriber::fmt::layer().with_target(false))
|
||||
.init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.command {
|
||||
Commands::Mat(mat_cmd) => {
|
||||
wifi_densepose_cli::mat::execute(mat_cmd).await?;
|
||||
}
|
||||
Commands::Version => {
|
||||
println!("wifi-densepose {}", env!("CARGO_PKG_VERSION"));
|
||||
println!("MAT module version: {}", wifi_densepose_mat::VERSION);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
1235
rust-port/wifi-densepose-rs/crates/wifi-densepose-cli/src/mat.rs
Normal file
1235
rust-port/wifi-densepose-rs/crates/wifi-densepose-cli/src/mat.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -10,13 +10,14 @@ keywords = ["wifi", "disaster", "rescue", "detection", "vital-signs"]
|
||||
categories = ["science", "algorithms"]
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
default = ["std", "api"]
|
||||
std = []
|
||||
api = ["dep:serde", "chrono/serde", "geo/use-serde"]
|
||||
portable = ["low-power"]
|
||||
low-power = []
|
||||
distributed = ["tokio/sync"]
|
||||
drone = ["distributed"]
|
||||
serde = ["dep:serde", "chrono/serde"]
|
||||
serde = ["dep:serde", "chrono/serde", "geo/use-serde"]
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
@@ -28,6 +29,10 @@ wifi-densepose-nn = { path = "../wifi-densepose-nn" }
|
||||
tokio = { version = "1.35", features = ["rt", "sync", "time"] }
|
||||
async-trait = "0.1"
|
||||
|
||||
# Web framework (REST API)
|
||||
axum = { version = "0.7", features = ["ws"] }
|
||||
futures-util = "0.3"
|
||||
|
||||
# Error handling
|
||||
thiserror = "1.0"
|
||||
anyhow = "1.0"
|
||||
@@ -58,6 +63,10 @@ criterion = { version = "0.5", features = ["html_reports"] }
|
||||
proptest = "1.4"
|
||||
approx = "0.5"
|
||||
|
||||
[[bench]]
|
||||
name = "detection_bench"
|
||||
harness = false
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
@@ -0,0 +1,906 @@
|
||||
//! Performance benchmarks for wifi-densepose-mat detection algorithms.
|
||||
//!
|
||||
//! Run with: cargo bench --package wifi-densepose-mat
|
||||
//!
|
||||
//! Benchmarks cover:
|
||||
//! - Breathing detection at various signal lengths
|
||||
//! - Heartbeat detection performance
|
||||
//! - Movement classification
|
||||
//! - Full detection pipeline
|
||||
//! - Localization algorithms (triangulation, depth estimation)
|
||||
//! - Alert generation
|
||||
|
||||
use criterion::{
|
||||
black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput,
|
||||
};
|
||||
use std::f64::consts::PI;
|
||||
|
||||
use wifi_densepose_mat::{
|
||||
// Detection types
|
||||
BreathingDetector, BreathingDetectorConfig,
|
||||
HeartbeatDetector, HeartbeatDetectorConfig,
|
||||
MovementClassifier, MovementClassifierConfig,
|
||||
DetectionConfig, DetectionPipeline, VitalSignsDetector,
|
||||
// Localization types
|
||||
Triangulator, DepthEstimator,
|
||||
// Alerting types
|
||||
AlertGenerator,
|
||||
// Domain types exported at crate root
|
||||
BreathingPattern, BreathingType, VitalSignsReading,
|
||||
MovementProfile, ScanZoneId, Survivor,
|
||||
};
|
||||
|
||||
// Types that need to be accessed from submodules
|
||||
use wifi_densepose_mat::detection::CsiDataBuffer;
|
||||
use wifi_densepose_mat::domain::{
|
||||
ConfidenceScore, SensorPosition, SensorType,
|
||||
DebrisProfile, DebrisMaterial, MoistureLevel, MetalContent,
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
|
||||
// =============================================================================
|
||||
// Test Data Generators
|
||||
// =============================================================================
|
||||
|
||||
/// Generate a clean breathing signal at specified rate
|
||||
fn generate_breathing_signal(rate_bpm: f64, sample_rate: f64, duration_secs: f64) -> Vec<f64> {
|
||||
let num_samples = (sample_rate * duration_secs) as usize;
|
||||
let freq = rate_bpm / 60.0;
|
||||
|
||||
(0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
(2.0 * PI * freq * t).sin()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate a breathing signal with noise
|
||||
fn generate_noisy_breathing_signal(
|
||||
rate_bpm: f64,
|
||||
sample_rate: f64,
|
||||
duration_secs: f64,
|
||||
noise_level: f64,
|
||||
) -> Vec<f64> {
|
||||
let num_samples = (sample_rate * duration_secs) as usize;
|
||||
let freq = rate_bpm / 60.0;
|
||||
|
||||
(0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
let signal = (2.0 * PI * freq * t).sin();
|
||||
// Simple pseudo-random noise based on sample index
|
||||
let noise = ((i as f64 * 12345.6789).sin() * 2.0 - 1.0) * noise_level;
|
||||
signal + noise
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate heartbeat signal with micro-Doppler characteristics
|
||||
fn generate_heartbeat_signal(rate_bpm: f64, sample_rate: f64, duration_secs: f64) -> Vec<f64> {
|
||||
let num_samples = (sample_rate * duration_secs) as usize;
|
||||
let freq = rate_bpm / 60.0;
|
||||
|
||||
(0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
let phase = 2.0 * PI * freq * t;
|
||||
// Heartbeat is more pulse-like than sinusoidal
|
||||
0.3 * phase.sin() + 0.1 * (2.0 * phase).sin() + 0.05 * (3.0 * phase).sin()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate combined breathing + heartbeat signal
|
||||
fn generate_combined_vital_signal(
|
||||
breathing_rate: f64,
|
||||
heart_rate: f64,
|
||||
sample_rate: f64,
|
||||
duration_secs: f64,
|
||||
) -> (Vec<f64>, Vec<f64>) {
|
||||
let num_samples = (sample_rate * duration_secs) as usize;
|
||||
let br_freq = breathing_rate / 60.0;
|
||||
let hr_freq = heart_rate / 60.0;
|
||||
|
||||
let amplitudes: Vec<f64> = (0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
// Breathing dominates amplitude
|
||||
(2.0 * PI * br_freq * t).sin()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let phases: Vec<f64> = (0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
// Phase captures both but heartbeat is more prominent
|
||||
let breathing = 0.3 * (2.0 * PI * br_freq * t).sin();
|
||||
let heartbeat = 0.5 * (2.0 * PI * hr_freq * t).sin();
|
||||
breathing + heartbeat
|
||||
})
|
||||
.collect();
|
||||
|
||||
(amplitudes, phases)
|
||||
}
|
||||
|
||||
/// Generate multi-person scenario with overlapping signals
|
||||
fn generate_multi_person_signal(
|
||||
person_count: usize,
|
||||
sample_rate: f64,
|
||||
duration_secs: f64,
|
||||
) -> Vec<f64> {
|
||||
let num_samples = (sample_rate * duration_secs) as usize;
|
||||
|
||||
// Different breathing rates for each person
|
||||
let base_rates: Vec<f64> = (0..person_count)
|
||||
.map(|i| 12.0 + (i as f64 * 3.5)) // 12, 15.5, 19, 22.5... BPM
|
||||
.collect();
|
||||
|
||||
(0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
base_rates.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, &rate)| {
|
||||
let freq = rate / 60.0;
|
||||
let amplitude = 1.0 / (idx + 1) as f64; // Distance-based attenuation
|
||||
let phase_offset = idx as f64 * PI / 4.0; // Different phases
|
||||
amplitude * (2.0 * PI * freq * t + phase_offset).sin()
|
||||
})
|
||||
.sum::<f64>()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate movement signal with specified characteristics
|
||||
fn generate_movement_signal(
|
||||
movement_type: &str,
|
||||
sample_rate: f64,
|
||||
duration_secs: f64,
|
||||
) -> Vec<f64> {
|
||||
let num_samples = (sample_rate * duration_secs) as usize;
|
||||
|
||||
match movement_type {
|
||||
"gross" => {
|
||||
// Large, irregular movements
|
||||
let mut signal = vec![0.0; num_samples];
|
||||
for i in (num_samples / 4)..(num_samples / 2) {
|
||||
signal[i] = 2.0;
|
||||
}
|
||||
for i in (3 * num_samples / 4)..(4 * num_samples / 5) {
|
||||
signal[i] = -1.5;
|
||||
}
|
||||
signal
|
||||
}
|
||||
"tremor" => {
|
||||
// High-frequency tremor (8-12 Hz)
|
||||
(0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
0.3 * (2.0 * PI * 10.0 * t).sin()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
"periodic" => {
|
||||
// Low-frequency periodic (breathing-like)
|
||||
(0..num_samples)
|
||||
.map(|i| {
|
||||
let t = i as f64 / sample_rate;
|
||||
0.5 * (2.0 * PI * 0.25 * t).sin()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
_ => vec![0.0; num_samples], // No movement
|
||||
}
|
||||
}
|
||||
|
||||
/// Create test sensor positions in a triangular configuration
|
||||
fn create_test_sensors(count: usize) -> Vec<SensorPosition> {
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
let angle = 2.0 * PI * i as f64 / count as f64;
|
||||
SensorPosition {
|
||||
id: format!("sensor_{}", i),
|
||||
x: 10.0 * angle.cos(),
|
||||
y: 10.0 * angle.sin(),
|
||||
z: 1.5,
|
||||
sensor_type: SensorType::Transceiver,
|
||||
is_operational: true,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create test debris profile
|
||||
fn create_test_debris() -> DebrisProfile {
|
||||
DebrisProfile {
|
||||
primary_material: DebrisMaterial::Mixed,
|
||||
void_fraction: 0.25,
|
||||
moisture_content: MoistureLevel::Dry,
|
||||
metal_content: MetalContent::Low,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create test survivor for alert generation
|
||||
fn create_test_survivor() -> Survivor {
|
||||
let vitals = VitalSignsReading {
|
||||
breathing: Some(BreathingPattern {
|
||||
rate_bpm: 18.0,
|
||||
amplitude: 0.8,
|
||||
regularity: 0.9,
|
||||
pattern_type: BreathingType::Normal,
|
||||
}),
|
||||
heartbeat: None,
|
||||
movement: MovementProfile::default(),
|
||||
timestamp: Utc::now(),
|
||||
confidence: ConfidenceScore::new(0.85),
|
||||
};
|
||||
|
||||
Survivor::new(ScanZoneId::new(), vitals, None)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Breathing Detection Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_breathing_detection(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("breathing_detection");
|
||||
|
||||
let detector = BreathingDetector::with_defaults();
|
||||
let sample_rate = 100.0; // 100 Hz
|
||||
|
||||
// Benchmark different signal lengths
|
||||
for duration in [5.0, 10.0, 30.0, 60.0] {
|
||||
let signal = generate_breathing_signal(16.0, sample_rate, duration);
|
||||
let num_samples = signal.len();
|
||||
|
||||
group.throughput(Throughput::Elements(num_samples as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("clean_signal", format!("{}s", duration as u32)),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark different noise levels
|
||||
for noise_level in [0.0, 0.1, 0.3, 0.5] {
|
||||
let signal = generate_noisy_breathing_signal(16.0, sample_rate, 30.0, noise_level);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("noisy_signal", format!("noise_{}", (noise_level * 10.0) as u32)),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark different breathing rates
|
||||
for rate in [8.0, 16.0, 25.0, 35.0] {
|
||||
let signal = generate_breathing_signal(rate, sample_rate, 30.0);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("rate_variation", format!("{}bpm", rate as u32)),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark with custom config (high sensitivity)
|
||||
let high_sensitivity_config = BreathingDetectorConfig {
|
||||
min_rate_bpm: 2.0,
|
||||
max_rate_bpm: 50.0,
|
||||
min_amplitude: 0.05,
|
||||
window_size: 1024,
|
||||
window_overlap: 0.75,
|
||||
confidence_threshold: 0.2,
|
||||
};
|
||||
let sensitive_detector = BreathingDetector::new(high_sensitivity_config);
|
||||
let signal = generate_noisy_breathing_signal(16.0, sample_rate, 30.0, 0.3);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("high_sensitivity", "30s_noisy"),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| sensitive_detector.detect(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Heartbeat Detection Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_heartbeat_detection(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("heartbeat_detection");
|
||||
|
||||
let detector = HeartbeatDetector::with_defaults();
|
||||
let sample_rate = 1000.0; // 1 kHz for micro-Doppler
|
||||
|
||||
// Benchmark different signal lengths
|
||||
for duration in [5.0, 10.0, 30.0] {
|
||||
let signal = generate_heartbeat_signal(72.0, sample_rate, duration);
|
||||
let num_samples = signal.len();
|
||||
|
||||
group.throughput(Throughput::Elements(num_samples as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("clean_signal", format!("{}s", duration as u32)),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate), None))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark with known breathing rate (improves filtering)
|
||||
let signal = generate_heartbeat_signal(72.0, sample_rate, 30.0);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("with_breathing_rate", "72bpm_known_br"),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| {
|
||||
detector.detect(
|
||||
black_box(signal),
|
||||
black_box(sample_rate),
|
||||
black_box(Some(16.0)), // Known breathing rate
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
// Benchmark different heart rates
|
||||
for rate in [50.0, 72.0, 100.0, 150.0] {
|
||||
let signal = generate_heartbeat_signal(rate, sample_rate, 10.0);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("rate_variation", format!("{}bpm", rate as u32)),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| detector.detect(black_box(signal), black_box(sample_rate), None))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark enhanced processing config
|
||||
let enhanced_config = HeartbeatDetectorConfig {
|
||||
min_rate_bpm: 30.0,
|
||||
max_rate_bpm: 200.0,
|
||||
min_signal_strength: 0.02,
|
||||
window_size: 2048,
|
||||
enhanced_processing: true,
|
||||
confidence_threshold: 0.3,
|
||||
};
|
||||
let enhanced_detector = HeartbeatDetector::new(enhanced_config);
|
||||
let signal = generate_heartbeat_signal(72.0, sample_rate, 10.0);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("enhanced_processing", "2048_window"),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| enhanced_detector.detect(black_box(signal), black_box(sample_rate), None))
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Movement Classification Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_movement_classification(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("movement_classification");
|
||||
|
||||
let classifier = MovementClassifier::with_defaults();
|
||||
let sample_rate = 100.0;
|
||||
|
||||
// Benchmark different movement types
|
||||
for movement_type in ["none", "gross", "tremor", "periodic"] {
|
||||
let signal = generate_movement_signal(movement_type, sample_rate, 10.0);
|
||||
let num_samples = signal.len();
|
||||
|
||||
group.throughput(Throughput::Elements(num_samples as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("movement_type", movement_type),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| classifier.classify(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark different signal lengths
|
||||
for duration in [2.0, 5.0, 10.0, 30.0] {
|
||||
let signal = generate_movement_signal("gross", sample_rate, duration);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("signal_length", format!("{}s", duration as u32)),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| classifier.classify(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark with custom sensitivity
|
||||
let sensitive_config = MovementClassifierConfig {
|
||||
movement_threshold: 0.05,
|
||||
gross_movement_threshold: 0.3,
|
||||
window_size: 200,
|
||||
periodicity_threshold: 0.2,
|
||||
};
|
||||
let sensitive_classifier = MovementClassifier::new(sensitive_config);
|
||||
let signal = generate_movement_signal("tremor", sample_rate, 10.0);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("high_sensitivity", "tremor_detection"),
|
||||
&signal,
|
||||
|b, signal| {
|
||||
b.iter(|| sensitive_classifier.classify(black_box(signal), black_box(sample_rate)))
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Full Detection Pipeline Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_detection_pipeline(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("detection_pipeline");
|
||||
group.sample_size(50); // Reduce sample size for slower benchmarks
|
||||
|
||||
let sample_rate = 100.0;
|
||||
|
||||
// Standard pipeline (breathing + movement)
|
||||
let standard_config = DetectionConfig {
|
||||
sample_rate,
|
||||
enable_heartbeat: false,
|
||||
min_confidence: 0.3,
|
||||
..Default::default()
|
||||
};
|
||||
let standard_pipeline = DetectionPipeline::new(standard_config);
|
||||
|
||||
// Full pipeline (breathing + heartbeat + movement)
|
||||
let full_config = DetectionConfig {
|
||||
sample_rate: 1000.0,
|
||||
enable_heartbeat: true,
|
||||
min_confidence: 0.3,
|
||||
..Default::default()
|
||||
};
|
||||
let full_pipeline = DetectionPipeline::new(full_config);
|
||||
|
||||
// Benchmark standard pipeline at different data sizes
|
||||
for duration in [5.0, 10.0, 30.0] {
|
||||
let (amplitudes, phases) = generate_combined_vital_signal(16.0, 72.0, sample_rate, duration);
|
||||
let mut buffer = CsiDataBuffer::new(sample_rate);
|
||||
buffer.add_samples(&litudes, &phases);
|
||||
|
||||
group.throughput(Throughput::Elements(amplitudes.len() as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("standard_pipeline", format!("{}s", duration as u32)),
|
||||
&buffer,
|
||||
|b, buffer| {
|
||||
b.iter(|| standard_pipeline.detect(black_box(buffer)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark full pipeline
|
||||
for duration in [5.0, 10.0] {
|
||||
let (amplitudes, phases) = generate_combined_vital_signal(16.0, 72.0, 1000.0, duration);
|
||||
let mut buffer = CsiDataBuffer::new(1000.0);
|
||||
buffer.add_samples(&litudes, &phases);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("full_pipeline", format!("{}s", duration as u32)),
|
||||
&buffer,
|
||||
|b, buffer| {
|
||||
b.iter(|| full_pipeline.detect(black_box(buffer)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark multi-person scenarios
|
||||
for person_count in [1, 2, 3, 5] {
|
||||
let signal = generate_multi_person_signal(person_count, sample_rate, 30.0);
|
||||
let mut buffer = CsiDataBuffer::new(sample_rate);
|
||||
buffer.add_samples(&signal, &signal);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("multi_person", format!("{}_people", person_count)),
|
||||
&buffer,
|
||||
|b, buffer| {
|
||||
b.iter(|| standard_pipeline.detect(black_box(buffer)))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Triangulation Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_triangulation(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("triangulation");
|
||||
|
||||
let triangulator = Triangulator::with_defaults();
|
||||
|
||||
// Benchmark with different sensor counts
|
||||
for sensor_count in [3, 4, 5, 8, 12] {
|
||||
let sensors = create_test_sensors(sensor_count);
|
||||
|
||||
// Generate RSSI values (simulate target at center)
|
||||
let rssi_values: Vec<(String, f64)> = sensors.iter()
|
||||
.map(|s| {
|
||||
let distance = (s.x * s.x + s.y * s.y).sqrt();
|
||||
let rssi = -30.0 - 20.0 * distance.log10(); // Path loss model
|
||||
(s.id.clone(), rssi)
|
||||
})
|
||||
.collect();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("rssi_position", format!("{}_sensors", sensor_count)),
|
||||
&(sensors.clone(), rssi_values.clone()),
|
||||
|b, (sensors, rssi)| {
|
||||
b.iter(|| {
|
||||
triangulator.estimate_position(black_box(sensors), black_box(rssi))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark ToA-based positioning
|
||||
for sensor_count in [3, 4, 5, 8] {
|
||||
let sensors = create_test_sensors(sensor_count);
|
||||
|
||||
// Generate ToA values (time in nanoseconds)
|
||||
let toa_values: Vec<(String, f64)> = sensors.iter()
|
||||
.map(|s| {
|
||||
let distance = (s.x * s.x + s.y * s.y).sqrt();
|
||||
// Round trip time: 2 * distance / speed_of_light
|
||||
let toa_ns = 2.0 * distance / 299_792_458.0 * 1e9;
|
||||
(s.id.clone(), toa_ns)
|
||||
})
|
||||
.collect();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("toa_position", format!("{}_sensors", sensor_count)),
|
||||
&(sensors.clone(), toa_values.clone()),
|
||||
|b, (sensors, toa)| {
|
||||
b.iter(|| {
|
||||
triangulator.estimate_from_toa(black_box(sensors), black_box(toa))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark with noisy measurements
|
||||
let sensors = create_test_sensors(5);
|
||||
for noise_pct in [0, 5, 10, 20] {
|
||||
let rssi_values: Vec<(String, f64)> = sensors.iter()
|
||||
.enumerate()
|
||||
.map(|(i, s)| {
|
||||
let distance = (s.x * s.x + s.y * s.y).sqrt();
|
||||
let rssi = -30.0 - 20.0 * distance.log10();
|
||||
// Add noise based on index for determinism
|
||||
let noise = (i as f64 / 10.0) * noise_pct as f64 / 100.0 * 10.0;
|
||||
(s.id.clone(), rssi + noise)
|
||||
})
|
||||
.collect();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("noisy_rssi", format!("{}pct_noise", noise_pct)),
|
||||
&(sensors.clone(), rssi_values.clone()),
|
||||
|b, (sensors, rssi)| {
|
||||
b.iter(|| {
|
||||
triangulator.estimate_position(black_box(sensors), black_box(rssi))
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Depth Estimation Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_depth_estimation(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("depth_estimation");
|
||||
|
||||
let estimator = DepthEstimator::with_defaults();
|
||||
let debris = create_test_debris();
|
||||
|
||||
// Benchmark single-path depth estimation
|
||||
for attenuation in [10.0, 20.0, 40.0, 60.0] {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("single_path", format!("{}dB", attenuation as u32)),
|
||||
&attenuation,
|
||||
|b, &attenuation| {
|
||||
b.iter(|| {
|
||||
estimator.estimate_depth(
|
||||
black_box(attenuation),
|
||||
black_box(5.0), // 5m horizontal distance
|
||||
black_box(&debris),
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark different debris types
|
||||
let debris_types = [
|
||||
("snow", DebrisMaterial::Snow),
|
||||
("wood", DebrisMaterial::Wood),
|
||||
("light_concrete", DebrisMaterial::LightConcrete),
|
||||
("heavy_concrete", DebrisMaterial::HeavyConcrete),
|
||||
("mixed", DebrisMaterial::Mixed),
|
||||
];
|
||||
|
||||
for (name, material) in debris_types {
|
||||
let debris = DebrisProfile {
|
||||
primary_material: material,
|
||||
void_fraction: 0.25,
|
||||
moisture_content: MoistureLevel::Dry,
|
||||
metal_content: MetalContent::Low,
|
||||
};
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("debris_type", name),
|
||||
&debris,
|
||||
|b, debris| {
|
||||
b.iter(|| {
|
||||
estimator.estimate_depth(
|
||||
black_box(30.0),
|
||||
black_box(5.0),
|
||||
black_box(debris),
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark multipath depth estimation
|
||||
for path_count in [1, 2, 4, 8] {
|
||||
let reflected_paths: Vec<(f64, f64)> = (0..path_count)
|
||||
.map(|i| {
|
||||
(
|
||||
30.0 + i as f64 * 5.0, // attenuation
|
||||
1e-9 * (i + 1) as f64, // delay in seconds
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("multipath", format!("{}_paths", path_count)),
|
||||
&reflected_paths,
|
||||
|b, paths| {
|
||||
b.iter(|| {
|
||||
estimator.estimate_from_multipath(
|
||||
black_box(25.0),
|
||||
black_box(paths),
|
||||
black_box(&debris),
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark debris profile estimation
|
||||
for (variance, multipath, moisture) in [
|
||||
(0.2, 0.3, 0.2),
|
||||
(0.5, 0.5, 0.5),
|
||||
(0.7, 0.8, 0.8),
|
||||
] {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("profile_estimation", format!("v{}_m{}", (variance * 10.0) as u32, (multipath * 10.0) as u32)),
|
||||
&(variance, multipath, moisture),
|
||||
|b, &(v, m, mo)| {
|
||||
b.iter(|| {
|
||||
estimator.estimate_debris_profile(
|
||||
black_box(v),
|
||||
black_box(m),
|
||||
black_box(mo),
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Alert Generation Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_alert_generation(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("alert_generation");
|
||||
|
||||
// Benchmark basic alert generation
|
||||
let generator = AlertGenerator::new();
|
||||
let survivor = create_test_survivor();
|
||||
|
||||
group.bench_function("generate_basic_alert", |b| {
|
||||
b.iter(|| generator.generate(black_box(&survivor)))
|
||||
});
|
||||
|
||||
// Benchmark escalation alert
|
||||
group.bench_function("generate_escalation_alert", |b| {
|
||||
b.iter(|| {
|
||||
generator.generate_escalation(
|
||||
black_box(&survivor),
|
||||
black_box("Vital signs deteriorating"),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark status change alert
|
||||
use wifi_densepose_mat::domain::TriageStatus;
|
||||
group.bench_function("generate_status_change_alert", |b| {
|
||||
b.iter(|| {
|
||||
generator.generate_status_change(
|
||||
black_box(&survivor),
|
||||
black_box(&TriageStatus::Minor),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark with zone registration
|
||||
let mut generator_with_zones = AlertGenerator::new();
|
||||
for i in 0..100 {
|
||||
generator_with_zones.register_zone(ScanZoneId::new(), format!("Zone {}", i));
|
||||
}
|
||||
|
||||
group.bench_function("generate_with_zones_lookup", |b| {
|
||||
b.iter(|| generator_with_zones.generate(black_box(&survivor)))
|
||||
});
|
||||
|
||||
// Benchmark batch alert generation
|
||||
let survivors: Vec<Survivor> = (0..10).map(|_| create_test_survivor()).collect();
|
||||
|
||||
group.bench_function("batch_generate_10_alerts", |b| {
|
||||
b.iter(|| {
|
||||
survivors.iter()
|
||||
.map(|s| generator.generate(black_box(s)))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CSI Buffer Operations Benchmarks
|
||||
// =============================================================================
|
||||
|
||||
fn bench_csi_buffer(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("csi_buffer");
|
||||
|
||||
let sample_rate = 100.0;
|
||||
|
||||
// Benchmark buffer creation and addition
|
||||
for sample_count in [1000, 5000, 10000, 30000] {
|
||||
let amplitudes: Vec<f64> = (0..sample_count)
|
||||
.map(|i| (i as f64 / 100.0).sin())
|
||||
.collect();
|
||||
let phases: Vec<f64> = (0..sample_count)
|
||||
.map(|i| (i as f64 / 50.0).cos())
|
||||
.collect();
|
||||
|
||||
group.throughput(Throughput::Elements(sample_count as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("add_samples", format!("{}_samples", sample_count)),
|
||||
&(amplitudes.clone(), phases.clone()),
|
||||
|b, (amp, phase)| {
|
||||
b.iter(|| {
|
||||
let mut buffer = CsiDataBuffer::new(sample_rate);
|
||||
buffer.add_samples(black_box(amp), black_box(phase));
|
||||
buffer
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark incremental addition (simulating real-time data)
|
||||
let chunk_size = 100;
|
||||
let total_samples = 10000;
|
||||
let amplitudes: Vec<f64> = (0..chunk_size).map(|i| (i as f64 / 100.0).sin()).collect();
|
||||
let phases: Vec<f64> = (0..chunk_size).map(|i| (i as f64 / 50.0).cos()).collect();
|
||||
|
||||
group.bench_function("incremental_add_100_chunks", |b| {
|
||||
b.iter(|| {
|
||||
let mut buffer = CsiDataBuffer::new(sample_rate);
|
||||
for _ in 0..(total_samples / chunk_size) {
|
||||
buffer.add_samples(black_box(&litudes), black_box(&phases));
|
||||
}
|
||||
buffer
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark has_sufficient_data check
|
||||
let mut buffer = CsiDataBuffer::new(sample_rate);
|
||||
let amplitudes: Vec<f64> = (0..3000).map(|i| (i as f64 / 100.0).sin()).collect();
|
||||
let phases: Vec<f64> = (0..3000).map(|i| (i as f64 / 50.0).cos()).collect();
|
||||
buffer.add_samples(&litudes, &phases);
|
||||
|
||||
group.bench_function("check_sufficient_data", |b| {
|
||||
b.iter(|| buffer.has_sufficient_data(black_box(10.0)))
|
||||
});
|
||||
|
||||
group.bench_function("calculate_duration", |b| {
|
||||
b.iter(|| black_box(&buffer).duration())
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Criterion Groups and Main
|
||||
// =============================================================================
|
||||
|
||||
criterion_group!(
|
||||
name = detection_benches;
|
||||
config = Criterion::default()
|
||||
.warm_up_time(std::time::Duration::from_millis(500))
|
||||
.measurement_time(std::time::Duration::from_secs(2));
|
||||
targets =
|
||||
bench_breathing_detection,
|
||||
bench_heartbeat_detection,
|
||||
bench_movement_classification
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
name = pipeline_benches;
|
||||
config = Criterion::default()
|
||||
.warm_up_time(std::time::Duration::from_millis(500))
|
||||
.measurement_time(std::time::Duration::from_secs(3))
|
||||
.sample_size(50);
|
||||
targets = bench_detection_pipeline
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
name = localization_benches;
|
||||
config = Criterion::default()
|
||||
.warm_up_time(std::time::Duration::from_millis(500))
|
||||
.measurement_time(std::time::Duration::from_secs(2));
|
||||
targets =
|
||||
bench_triangulation,
|
||||
bench_depth_estimation
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
name = alerting_benches;
|
||||
config = Criterion::default()
|
||||
.warm_up_time(std::time::Duration::from_millis(300))
|
||||
.measurement_time(std::time::Duration::from_secs(1));
|
||||
targets = bench_alert_generation
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
name = buffer_benches;
|
||||
config = Criterion::default()
|
||||
.warm_up_time(std::time::Duration::from_millis(300))
|
||||
.measurement_time(std::time::Duration::from_secs(1));
|
||||
targets = bench_csi_buffer
|
||||
);
|
||||
|
||||
criterion_main!(
|
||||
detection_benches,
|
||||
pipeline_benches,
|
||||
localization_benches,
|
||||
alerting_benches,
|
||||
buffer_benches
|
||||
);
|
||||
@@ -0,0 +1,892 @@
|
||||
//! Data Transfer Objects (DTOs) for the MAT REST API.
|
||||
//!
|
||||
//! These types are used for serializing/deserializing API requests and responses.
|
||||
//! They provide a clean separation between domain models and API contracts.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::domain::{
|
||||
DisasterType, EventStatus, ZoneStatus, TriageStatus, Priority,
|
||||
AlertStatus, SurvivorStatus,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Event DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// Request body for creating a new disaster event.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "event_type": "Earthquake",
|
||||
/// "latitude": 37.7749,
|
||||
/// "longitude": -122.4194,
|
||||
/// "description": "Magnitude 6.8 earthquake in San Francisco",
|
||||
/// "estimated_occupancy": 500,
|
||||
/// "lead_agency": "SF Fire Department"
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct CreateEventRequest {
|
||||
/// Type of disaster event
|
||||
pub event_type: DisasterTypeDto,
|
||||
/// Latitude of disaster epicenter
|
||||
pub latitude: f64,
|
||||
/// Longitude of disaster epicenter
|
||||
pub longitude: f64,
|
||||
/// Human-readable description of the event
|
||||
pub description: String,
|
||||
/// Estimated number of people in the affected area
|
||||
#[serde(default)]
|
||||
pub estimated_occupancy: Option<u32>,
|
||||
/// Lead responding agency
|
||||
#[serde(default)]
|
||||
pub lead_agency: Option<String>,
|
||||
}
|
||||
|
||||
/// Response body for disaster event details.
|
||||
///
|
||||
/// ## Example Response
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
/// "event_type": "Earthquake",
|
||||
/// "status": "Active",
|
||||
/// "start_time": "2024-01-15T14:30:00Z",
|
||||
/// "latitude": 37.7749,
|
||||
/// "longitude": -122.4194,
|
||||
/// "description": "Magnitude 6.8 earthquake",
|
||||
/// "zone_count": 5,
|
||||
/// "survivor_count": 12,
|
||||
/// "triage_summary": {
|
||||
/// "immediate": 3,
|
||||
/// "delayed": 5,
|
||||
/// "minor": 4,
|
||||
/// "deceased": 0
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct EventResponse {
|
||||
/// Unique event identifier
|
||||
pub id: Uuid,
|
||||
/// Type of disaster
|
||||
pub event_type: DisasterTypeDto,
|
||||
/// Current event status
|
||||
pub status: EventStatusDto,
|
||||
/// When the event was created/started
|
||||
pub start_time: DateTime<Utc>,
|
||||
/// Latitude of epicenter
|
||||
pub latitude: f64,
|
||||
/// Longitude of epicenter
|
||||
pub longitude: f64,
|
||||
/// Event description
|
||||
pub description: String,
|
||||
/// Number of scan zones
|
||||
pub zone_count: usize,
|
||||
/// Number of detected survivors
|
||||
pub survivor_count: usize,
|
||||
/// Summary of triage classifications
|
||||
pub triage_summary: TriageSummary,
|
||||
/// Metadata about the event
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<EventMetadataDto>,
|
||||
}
|
||||
|
||||
/// Summary of triage counts across all survivors.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct TriageSummary {
|
||||
/// Immediate (Red) - life-threatening
|
||||
pub immediate: u32,
|
||||
/// Delayed (Yellow) - serious but stable
|
||||
pub delayed: u32,
|
||||
/// Minor (Green) - walking wounded
|
||||
pub minor: u32,
|
||||
/// Deceased (Black)
|
||||
pub deceased: u32,
|
||||
/// Unknown status
|
||||
pub unknown: u32,
|
||||
}
|
||||
|
||||
/// Event metadata DTO
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct EventMetadataDto {
|
||||
/// Estimated number of people in area at time of disaster
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub estimated_occupancy: Option<u32>,
|
||||
/// Known survivors (already rescued)
|
||||
#[serde(default)]
|
||||
pub confirmed_rescued: u32,
|
||||
/// Known fatalities
|
||||
#[serde(default)]
|
||||
pub confirmed_deceased: u32,
|
||||
/// Weather conditions
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub weather: Option<String>,
|
||||
/// Lead agency
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lead_agency: Option<String>,
|
||||
}
|
||||
|
||||
/// Paginated list of events.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct EventListResponse {
|
||||
/// List of events
|
||||
pub events: Vec<EventResponse>,
|
||||
/// Total count of events
|
||||
pub total: usize,
|
||||
/// Current page number (0-indexed)
|
||||
pub page: usize,
|
||||
/// Number of items per page
|
||||
pub page_size: usize,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Zone DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// Request body for adding a scan zone to an event.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "name": "Building A - North Wing",
|
||||
/// "bounds": {
|
||||
/// "type": "rectangle",
|
||||
/// "min_x": 0.0,
|
||||
/// "min_y": 0.0,
|
||||
/// "max_x": 50.0,
|
||||
/// "max_y": 30.0
|
||||
/// },
|
||||
/// "parameters": {
|
||||
/// "sensitivity": 0.85,
|
||||
/// "max_depth": 5.0,
|
||||
/// "heartbeat_detection": true
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct CreateZoneRequest {
|
||||
/// Human-readable zone name
|
||||
pub name: String,
|
||||
/// Geographic bounds of the zone
|
||||
pub bounds: ZoneBoundsDto,
|
||||
/// Optional scan parameters
|
||||
#[serde(default)]
|
||||
pub parameters: Option<ScanParametersDto>,
|
||||
}
|
||||
|
||||
/// Zone boundary definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ZoneBoundsDto {
|
||||
/// Rectangular boundary
|
||||
Rectangle {
|
||||
min_x: f64,
|
||||
min_y: f64,
|
||||
max_x: f64,
|
||||
max_y: f64,
|
||||
},
|
||||
/// Circular boundary
|
||||
Circle {
|
||||
center_x: f64,
|
||||
center_y: f64,
|
||||
radius: f64,
|
||||
},
|
||||
/// Polygon boundary (list of vertices)
|
||||
Polygon {
|
||||
vertices: Vec<(f64, f64)>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Scan parameters for a zone.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ScanParametersDto {
|
||||
/// Detection sensitivity (0.0-1.0)
|
||||
#[serde(default = "default_sensitivity")]
|
||||
pub sensitivity: f64,
|
||||
/// Maximum depth to scan in meters
|
||||
#[serde(default = "default_max_depth")]
|
||||
pub max_depth: f64,
|
||||
/// Scan resolution level
|
||||
#[serde(default)]
|
||||
pub resolution: ScanResolutionDto,
|
||||
/// Enable enhanced breathing detection
|
||||
#[serde(default = "default_true")]
|
||||
pub enhanced_breathing: bool,
|
||||
/// Enable heartbeat detection (slower but more accurate)
|
||||
#[serde(default)]
|
||||
pub heartbeat_detection: bool,
|
||||
}
|
||||
|
||||
fn default_sensitivity() -> f64 { 0.8 }
|
||||
fn default_max_depth() -> f64 { 5.0 }
|
||||
fn default_true() -> bool { true }
|
||||
|
||||
impl Default for ScanParametersDto {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sensitivity: default_sensitivity(),
|
||||
max_depth: default_max_depth(),
|
||||
resolution: ScanResolutionDto::default(),
|
||||
enhanced_breathing: default_true(),
|
||||
heartbeat_detection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scan resolution levels.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ScanResolutionDto {
|
||||
Quick,
|
||||
#[default]
|
||||
Standard,
|
||||
High,
|
||||
Maximum,
|
||||
}
|
||||
|
||||
/// Response for zone details.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ZoneResponse {
|
||||
/// Zone identifier
|
||||
pub id: Uuid,
|
||||
/// Zone name
|
||||
pub name: String,
|
||||
/// Zone status
|
||||
pub status: ZoneStatusDto,
|
||||
/// Zone boundaries
|
||||
pub bounds: ZoneBoundsDto,
|
||||
/// Zone area in square meters
|
||||
pub area: f64,
|
||||
/// Scan parameters
|
||||
pub parameters: ScanParametersDto,
|
||||
/// Last scan time
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub last_scan: Option<DateTime<Utc>>,
|
||||
/// Total scan count
|
||||
pub scan_count: u32,
|
||||
/// Number of detections in this zone
|
||||
pub detections_count: u32,
|
||||
}
|
||||
|
||||
/// List of zones response.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ZoneListResponse {
|
||||
/// List of zones
|
||||
pub zones: Vec<ZoneResponse>,
|
||||
/// Total count
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Survivor DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// Response for survivor details.
|
||||
///
|
||||
/// ## Example Response
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
/// "zone_id": "550e8400-e29b-41d4-a716-446655440002",
|
||||
/// "status": "Active",
|
||||
/// "triage_status": "Immediate",
|
||||
/// "location": {
|
||||
/// "x": 25.5,
|
||||
/// "y": 12.3,
|
||||
/// "z": -2.1,
|
||||
/// "uncertainty_radius": 1.5
|
||||
/// },
|
||||
/// "vital_signs": {
|
||||
/// "breathing_rate": 22.5,
|
||||
/// "has_heartbeat": true,
|
||||
/// "has_movement": false
|
||||
/// },
|
||||
/// "confidence": 0.87,
|
||||
/// "first_detected": "2024-01-15T14:32:00Z",
|
||||
/// "last_updated": "2024-01-15T14:45:00Z"
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct SurvivorResponse {
|
||||
/// Survivor identifier
|
||||
pub id: Uuid,
|
||||
/// Zone where survivor was detected
|
||||
pub zone_id: Uuid,
|
||||
/// Current survivor status
|
||||
pub status: SurvivorStatusDto,
|
||||
/// Triage classification
|
||||
pub triage_status: TriageStatusDto,
|
||||
/// Location information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub location: Option<LocationDto>,
|
||||
/// Latest vital signs summary
|
||||
pub vital_signs: VitalSignsSummaryDto,
|
||||
/// Detection confidence (0.0-1.0)
|
||||
pub confidence: f64,
|
||||
/// When survivor was first detected
|
||||
pub first_detected: DateTime<Utc>,
|
||||
/// Last update time
|
||||
pub last_updated: DateTime<Utc>,
|
||||
/// Whether survivor is deteriorating
|
||||
pub is_deteriorating: bool,
|
||||
/// Metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<SurvivorMetadataDto>,
|
||||
}
|
||||
|
||||
/// Location information DTO.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct LocationDto {
|
||||
/// X coordinate (east-west, meters)
|
||||
pub x: f64,
|
||||
/// Y coordinate (north-south, meters)
|
||||
pub y: f64,
|
||||
/// Z coordinate (depth, negative is below surface)
|
||||
pub z: f64,
|
||||
/// Estimated depth below surface (positive meters)
|
||||
pub depth: f64,
|
||||
/// Horizontal uncertainty radius in meters
|
||||
pub uncertainty_radius: f64,
|
||||
/// Location confidence score
|
||||
pub confidence: f64,
|
||||
}
|
||||
|
||||
/// Summary of vital signs for API response.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct VitalSignsSummaryDto {
|
||||
/// Breathing rate (breaths per minute)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub breathing_rate: Option<f32>,
|
||||
/// Breathing pattern type
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub breathing_type: Option<String>,
|
||||
/// Heart rate if detected (bpm)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub heart_rate: Option<f32>,
|
||||
/// Whether heartbeat is detected
|
||||
pub has_heartbeat: bool,
|
||||
/// Whether movement is detected
|
||||
pub has_movement: bool,
|
||||
/// Movement type if present
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub movement_type: Option<String>,
|
||||
/// Timestamp of reading
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Survivor metadata DTO.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct SurvivorMetadataDto {
|
||||
/// Estimated age category
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub estimated_age_category: Option<String>,
|
||||
/// Assigned rescue team
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub assigned_team: Option<String>,
|
||||
/// Notes
|
||||
pub notes: Vec<String>,
|
||||
/// Tags
|
||||
pub tags: Vec<String>,
|
||||
}
|
||||
|
||||
/// List of survivors response.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct SurvivorListResponse {
|
||||
/// List of survivors
|
||||
pub survivors: Vec<SurvivorResponse>,
|
||||
/// Total count
|
||||
pub total: usize,
|
||||
/// Triage summary
|
||||
pub triage_summary: TriageSummary,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Alert DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// Response for alert details.
|
||||
///
|
||||
/// ## Example Response
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "id": "550e8400-e29b-41d4-a716-446655440003",
|
||||
/// "survivor_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
/// "priority": "Critical",
|
||||
/// "status": "Pending",
|
||||
/// "title": "Immediate: Survivor detected with abnormal breathing",
|
||||
/// "message": "Survivor in Zone A showing signs of respiratory distress",
|
||||
/// "triage_status": "Immediate",
|
||||
/// "location": { ... },
|
||||
/// "created_at": "2024-01-15T14:35:00Z"
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct AlertResponse {
|
||||
/// Alert identifier
|
||||
pub id: Uuid,
|
||||
/// Related survivor ID
|
||||
pub survivor_id: Uuid,
|
||||
/// Alert priority
|
||||
pub priority: PriorityDto,
|
||||
/// Alert status
|
||||
pub status: AlertStatusDto,
|
||||
/// Alert title
|
||||
pub title: String,
|
||||
/// Detailed message
|
||||
pub message: String,
|
||||
/// Associated triage status
|
||||
pub triage_status: TriageStatusDto,
|
||||
/// Location if available
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub location: Option<LocationDto>,
|
||||
/// Recommended action
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub recommended_action: Option<String>,
|
||||
/// When alert was created
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When alert was acknowledged
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acknowledged_at: Option<DateTime<Utc>>,
|
||||
/// Who acknowledged the alert
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acknowledged_by: Option<String>,
|
||||
/// Escalation count
|
||||
pub escalation_count: u32,
|
||||
}
|
||||
|
||||
/// Request to acknowledge an alert.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "acknowledged_by": "Team Alpha",
|
||||
/// "notes": "En route to location"
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct AcknowledgeAlertRequest {
|
||||
/// Who is acknowledging the alert
|
||||
pub acknowledged_by: String,
|
||||
/// Optional notes
|
||||
#[serde(default)]
|
||||
pub notes: Option<String>,
|
||||
}
|
||||
|
||||
/// Response after acknowledging an alert.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct AcknowledgeAlertResponse {
|
||||
/// Whether acknowledgement was successful
|
||||
pub success: bool,
|
||||
/// Updated alert
|
||||
pub alert: AlertResponse,
|
||||
}
|
||||
|
||||
/// List of alerts response.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct AlertListResponse {
|
||||
/// List of alerts
|
||||
pub alerts: Vec<AlertResponse>,
|
||||
/// Total count
|
||||
pub total: usize,
|
||||
/// Count by priority
|
||||
pub priority_counts: PriorityCounts,
|
||||
}
|
||||
|
||||
/// Count of alerts by priority.
|
||||
#[derive(Debug, Clone, Serialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct PriorityCounts {
|
||||
pub critical: usize,
|
||||
pub high: usize,
|
||||
pub medium: usize,
|
||||
pub low: usize,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket DTOs
|
||||
// ============================================================================
|
||||
|
||||
/// WebSocket message types for real-time streaming.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum WebSocketMessage {
|
||||
/// New survivor detected
|
||||
SurvivorDetected {
|
||||
event_id: Uuid,
|
||||
survivor: SurvivorResponse,
|
||||
},
|
||||
/// Survivor status updated
|
||||
SurvivorUpdated {
|
||||
event_id: Uuid,
|
||||
survivor: SurvivorResponse,
|
||||
},
|
||||
/// Survivor lost (signal lost)
|
||||
SurvivorLost {
|
||||
event_id: Uuid,
|
||||
survivor_id: Uuid,
|
||||
},
|
||||
/// New alert generated
|
||||
AlertCreated {
|
||||
event_id: Uuid,
|
||||
alert: AlertResponse,
|
||||
},
|
||||
/// Alert status changed
|
||||
AlertUpdated {
|
||||
event_id: Uuid,
|
||||
alert: AlertResponse,
|
||||
},
|
||||
/// Zone scan completed
|
||||
ZoneScanComplete {
|
||||
event_id: Uuid,
|
||||
zone_id: Uuid,
|
||||
detections: u32,
|
||||
},
|
||||
/// Event status changed
|
||||
EventStatusChanged {
|
||||
event_id: Uuid,
|
||||
old_status: EventStatusDto,
|
||||
new_status: EventStatusDto,
|
||||
},
|
||||
/// Heartbeat/keep-alive
|
||||
Heartbeat {
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
/// Error message
|
||||
Error {
|
||||
code: String,
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// WebSocket subscription request.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "action", rename_all = "snake_case")]
|
||||
pub enum WebSocketRequest {
|
||||
/// Subscribe to events for a disaster event
|
||||
Subscribe {
|
||||
event_id: Uuid,
|
||||
},
|
||||
/// Unsubscribe from events
|
||||
Unsubscribe {
|
||||
event_id: Uuid,
|
||||
},
|
||||
/// Subscribe to all events
|
||||
SubscribeAll,
|
||||
/// Request current state
|
||||
GetState {
|
||||
event_id: Uuid,
|
||||
},
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Enum DTOs (mirroring domain enums with serde)
|
||||
// ============================================================================
|
||||
|
||||
/// Disaster type DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum DisasterTypeDto {
|
||||
BuildingCollapse,
|
||||
Earthquake,
|
||||
Landslide,
|
||||
Avalanche,
|
||||
Flood,
|
||||
MineCollapse,
|
||||
Industrial,
|
||||
TunnelCollapse,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl From<DisasterType> for DisasterTypeDto {
|
||||
fn from(dt: DisasterType) -> Self {
|
||||
match dt {
|
||||
DisasterType::BuildingCollapse => DisasterTypeDto::BuildingCollapse,
|
||||
DisasterType::Earthquake => DisasterTypeDto::Earthquake,
|
||||
DisasterType::Landslide => DisasterTypeDto::Landslide,
|
||||
DisasterType::Avalanche => DisasterTypeDto::Avalanche,
|
||||
DisasterType::Flood => DisasterTypeDto::Flood,
|
||||
DisasterType::MineCollapse => DisasterTypeDto::MineCollapse,
|
||||
DisasterType::Industrial => DisasterTypeDto::Industrial,
|
||||
DisasterType::TunnelCollapse => DisasterTypeDto::TunnelCollapse,
|
||||
DisasterType::Unknown => DisasterTypeDto::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DisasterTypeDto> for DisasterType {
|
||||
fn from(dt: DisasterTypeDto) -> Self {
|
||||
match dt {
|
||||
DisasterTypeDto::BuildingCollapse => DisasterType::BuildingCollapse,
|
||||
DisasterTypeDto::Earthquake => DisasterType::Earthquake,
|
||||
DisasterTypeDto::Landslide => DisasterType::Landslide,
|
||||
DisasterTypeDto::Avalanche => DisasterType::Avalanche,
|
||||
DisasterTypeDto::Flood => DisasterType::Flood,
|
||||
DisasterTypeDto::MineCollapse => DisasterType::MineCollapse,
|
||||
DisasterTypeDto::Industrial => DisasterType::Industrial,
|
||||
DisasterTypeDto::TunnelCollapse => DisasterType::TunnelCollapse,
|
||||
DisasterTypeDto::Unknown => DisasterType::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Event status DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum EventStatusDto {
|
||||
Initializing,
|
||||
Active,
|
||||
Suspended,
|
||||
SecondarySearch,
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl From<EventStatus> for EventStatusDto {
|
||||
fn from(es: EventStatus) -> Self {
|
||||
match es {
|
||||
EventStatus::Initializing => EventStatusDto::Initializing,
|
||||
EventStatus::Active => EventStatusDto::Active,
|
||||
EventStatus::Suspended => EventStatusDto::Suspended,
|
||||
EventStatus::SecondarySearch => EventStatusDto::SecondarySearch,
|
||||
EventStatus::Closed => EventStatusDto::Closed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Zone status DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum ZoneStatusDto {
|
||||
Active,
|
||||
Paused,
|
||||
Complete,
|
||||
Inaccessible,
|
||||
Deactivated,
|
||||
}
|
||||
|
||||
impl From<ZoneStatus> for ZoneStatusDto {
|
||||
fn from(zs: ZoneStatus) -> Self {
|
||||
match zs {
|
||||
ZoneStatus::Active => ZoneStatusDto::Active,
|
||||
ZoneStatus::Paused => ZoneStatusDto::Paused,
|
||||
ZoneStatus::Complete => ZoneStatusDto::Complete,
|
||||
ZoneStatus::Inaccessible => ZoneStatusDto::Inaccessible,
|
||||
ZoneStatus::Deactivated => ZoneStatusDto::Deactivated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Triage status DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum TriageStatusDto {
|
||||
Immediate,
|
||||
Delayed,
|
||||
Minor,
|
||||
Deceased,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl From<TriageStatus> for TriageStatusDto {
|
||||
fn from(ts: TriageStatus) -> Self {
|
||||
match ts {
|
||||
TriageStatus::Immediate => TriageStatusDto::Immediate,
|
||||
TriageStatus::Delayed => TriageStatusDto::Delayed,
|
||||
TriageStatus::Minor => TriageStatusDto::Minor,
|
||||
TriageStatus::Deceased => TriageStatusDto::Deceased,
|
||||
TriageStatus::Unknown => TriageStatusDto::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Priority DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum PriorityDto {
|
||||
Critical,
|
||||
High,
|
||||
Medium,
|
||||
Low,
|
||||
}
|
||||
|
||||
impl From<Priority> for PriorityDto {
|
||||
fn from(p: Priority) -> Self {
|
||||
match p {
|
||||
Priority::Critical => PriorityDto::Critical,
|
||||
Priority::High => PriorityDto::High,
|
||||
Priority::Medium => PriorityDto::Medium,
|
||||
Priority::Low => PriorityDto::Low,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Alert status DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum AlertStatusDto {
|
||||
Pending,
|
||||
Acknowledged,
|
||||
InProgress,
|
||||
Resolved,
|
||||
Cancelled,
|
||||
Expired,
|
||||
}
|
||||
|
||||
impl From<AlertStatus> for AlertStatusDto {
|
||||
fn from(as_: AlertStatus) -> Self {
|
||||
match as_ {
|
||||
AlertStatus::Pending => AlertStatusDto::Pending,
|
||||
AlertStatus::Acknowledged => AlertStatusDto::Acknowledged,
|
||||
AlertStatus::InProgress => AlertStatusDto::InProgress,
|
||||
AlertStatus::Resolved => AlertStatusDto::Resolved,
|
||||
AlertStatus::Cancelled => AlertStatusDto::Cancelled,
|
||||
AlertStatus::Expired => AlertStatusDto::Expired,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Survivor status DTO.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum SurvivorStatusDto {
|
||||
Active,
|
||||
Rescued,
|
||||
Lost,
|
||||
Deceased,
|
||||
FalsePositive,
|
||||
}
|
||||
|
||||
impl From<SurvivorStatus> for SurvivorStatusDto {
|
||||
fn from(ss: SurvivorStatus) -> Self {
|
||||
match ss {
|
||||
SurvivorStatus::Active => SurvivorStatusDto::Active,
|
||||
SurvivorStatus::Rescued => SurvivorStatusDto::Rescued,
|
||||
SurvivorStatus::Lost => SurvivorStatusDto::Lost,
|
||||
SurvivorStatus::Deceased => SurvivorStatusDto::Deceased,
|
||||
SurvivorStatus::FalsePositive => SurvivorStatusDto::FalsePositive,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Query Parameters
|
||||
// ============================================================================
|
||||
|
||||
/// Query parameters for listing events.
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ListEventsQuery {
|
||||
/// Filter by status
|
||||
pub status: Option<EventStatusDto>,
|
||||
/// Filter by disaster type
|
||||
pub event_type: Option<DisasterTypeDto>,
|
||||
/// Page number (0-indexed)
|
||||
#[serde(default)]
|
||||
pub page: usize,
|
||||
/// Page size (default 20, max 100)
|
||||
#[serde(default = "default_page_size")]
|
||||
pub page_size: usize,
|
||||
}
|
||||
|
||||
fn default_page_size() -> usize { 20 }
|
||||
|
||||
/// Query parameters for listing survivors.
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ListSurvivorsQuery {
|
||||
/// Filter by triage status
|
||||
pub triage_status: Option<TriageStatusDto>,
|
||||
/// Filter by zone ID
|
||||
pub zone_id: Option<Uuid>,
|
||||
/// Filter by minimum confidence
|
||||
pub min_confidence: Option<f64>,
|
||||
/// Include only deteriorating
|
||||
#[serde(default)]
|
||||
pub deteriorating_only: bool,
|
||||
}
|
||||
|
||||
/// Query parameters for listing alerts.
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ListAlertsQuery {
|
||||
/// Filter by priority
|
||||
pub priority: Option<PriorityDto>,
|
||||
/// Filter by status
|
||||
pub status: Option<AlertStatusDto>,
|
||||
/// Only pending alerts
|
||||
#[serde(default)]
|
||||
pub pending_only: bool,
|
||||
/// Only active alerts
|
||||
#[serde(default)]
|
||||
pub active_only: bool,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_event_request_deserialize() {
|
||||
let json = r#"{
|
||||
"event_type": "Earthquake",
|
||||
"latitude": 37.7749,
|
||||
"longitude": -122.4194,
|
||||
"description": "Test earthquake"
|
||||
}"#;
|
||||
|
||||
let req: CreateEventRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.event_type, DisasterTypeDto::Earthquake);
|
||||
assert!((req.latitude - 37.7749).abs() < 0.0001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zone_bounds_dto_deserialize() {
|
||||
let rect_json = r#"{
|
||||
"type": "rectangle",
|
||||
"min_x": 0.0,
|
||||
"min_y": 0.0,
|
||||
"max_x": 10.0,
|
||||
"max_y": 10.0
|
||||
}"#;
|
||||
|
||||
let bounds: ZoneBoundsDto = serde_json::from_str(rect_json).unwrap();
|
||||
assert!(matches!(bounds, ZoneBoundsDto::Rectangle { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_websocket_message_serialize() {
|
||||
let msg = WebSocketMessage::Heartbeat {
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains("\"type\":\"heartbeat\""));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
//! API error types and handling for the MAT REST API.
|
||||
//!
|
||||
//! This module provides a unified error type that maps to appropriate HTTP status codes
|
||||
//! and JSON error responses for the API.
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// API error type that converts to HTTP responses.
|
||||
///
|
||||
/// All errors include:
|
||||
/// - An HTTP status code
|
||||
/// - A machine-readable error code
|
||||
/// - A human-readable message
|
||||
/// - Optional additional details
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
/// Resource not found (404)
|
||||
#[error("Resource not found: {resource_type} with id {id}")]
|
||||
NotFound {
|
||||
resource_type: String,
|
||||
id: String,
|
||||
},
|
||||
|
||||
/// Invalid request data (400)
|
||||
#[error("Bad request: {message}")]
|
||||
BadRequest {
|
||||
message: String,
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync>>,
|
||||
},
|
||||
|
||||
/// Validation error (422)
|
||||
#[error("Validation failed: {message}")]
|
||||
ValidationError {
|
||||
message: String,
|
||||
field: Option<String>,
|
||||
},
|
||||
|
||||
/// Conflict with existing resource (409)
|
||||
#[error("Conflict: {message}")]
|
||||
Conflict {
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Resource is in invalid state for operation (409)
|
||||
#[error("Invalid state: {message}")]
|
||||
InvalidState {
|
||||
message: String,
|
||||
current_state: String,
|
||||
},
|
||||
|
||||
/// Internal server error (500)
|
||||
#[error("Internal error: {message}")]
|
||||
Internal {
|
||||
message: String,
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync>>,
|
||||
},
|
||||
|
||||
/// Service unavailable (503)
|
||||
#[error("Service unavailable: {message}")]
|
||||
ServiceUnavailable {
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Domain error from business logic
|
||||
#[error("Domain error: {0}")]
|
||||
Domain(#[from] crate::MatError),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Create a not found error for an event.
|
||||
pub fn event_not_found(id: Uuid) -> Self {
|
||||
Self::NotFound {
|
||||
resource_type: "DisasterEvent".to_string(),
|
||||
id: id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a not found error for a zone.
|
||||
pub fn zone_not_found(id: Uuid) -> Self {
|
||||
Self::NotFound {
|
||||
resource_type: "ScanZone".to_string(),
|
||||
id: id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a not found error for a survivor.
|
||||
pub fn survivor_not_found(id: Uuid) -> Self {
|
||||
Self::NotFound {
|
||||
resource_type: "Survivor".to_string(),
|
||||
id: id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a not found error for an alert.
|
||||
pub fn alert_not_found(id: Uuid) -> Self {
|
||||
Self::NotFound {
|
||||
resource_type: "Alert".to_string(),
|
||||
id: id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a bad request error.
|
||||
pub fn bad_request(message: impl Into<String>) -> Self {
|
||||
Self::BadRequest {
|
||||
message: message.into(),
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a validation error.
|
||||
pub fn validation(message: impl Into<String>, field: Option<String>) -> Self {
|
||||
Self::ValidationError {
|
||||
message: message.into(),
|
||||
field,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an internal error.
|
||||
pub fn internal(message: impl Into<String>) -> Self {
|
||||
Self::Internal {
|
||||
message: message.into(),
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the HTTP status code for this error.
|
||||
pub fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Self::NotFound { .. } => StatusCode::NOT_FOUND,
|
||||
Self::BadRequest { .. } => StatusCode::BAD_REQUEST,
|
||||
Self::ValidationError { .. } => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Self::Conflict { .. } => StatusCode::CONFLICT,
|
||||
Self::InvalidState { .. } => StatusCode::CONFLICT,
|
||||
Self::Internal { .. } => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::ServiceUnavailable { .. } => StatusCode::SERVICE_UNAVAILABLE,
|
||||
Self::Domain(_) => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the error code for this error.
|
||||
pub fn error_code(&self) -> &'static str {
|
||||
match self {
|
||||
Self::NotFound { .. } => "NOT_FOUND",
|
||||
Self::BadRequest { .. } => "BAD_REQUEST",
|
||||
Self::ValidationError { .. } => "VALIDATION_ERROR",
|
||||
Self::Conflict { .. } => "CONFLICT",
|
||||
Self::InvalidState { .. } => "INVALID_STATE",
|
||||
Self::Internal { .. } => "INTERNAL_ERROR",
|
||||
Self::ServiceUnavailable { .. } => "SERVICE_UNAVAILABLE",
|
||||
Self::Domain(_) => "DOMAIN_ERROR",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON error response body.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
/// Machine-readable error code
|
||||
pub code: String,
|
||||
/// Human-readable error message
|
||||
pub message: String,
|
||||
/// Additional error details
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<ErrorDetails>,
|
||||
/// Request ID for tracing (if available)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Additional error details.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ErrorDetails {
|
||||
/// Resource type involved
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub resource_type: Option<String>,
|
||||
/// Resource ID involved
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub resource_id: Option<String>,
|
||||
/// Field that caused the error
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub field: Option<String>,
|
||||
/// Current state (for state errors)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub current_state: Option<String>,
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status_code();
|
||||
let code = self.error_code().to_string();
|
||||
let message = self.to_string();
|
||||
|
||||
let details = match &self {
|
||||
ApiError::NotFound { resource_type, id } => Some(ErrorDetails {
|
||||
resource_type: Some(resource_type.clone()),
|
||||
resource_id: Some(id.clone()),
|
||||
field: None,
|
||||
current_state: None,
|
||||
}),
|
||||
ApiError::ValidationError { field, .. } => Some(ErrorDetails {
|
||||
resource_type: None,
|
||||
resource_id: None,
|
||||
field: field.clone(),
|
||||
current_state: None,
|
||||
}),
|
||||
ApiError::InvalidState { current_state, .. } => Some(ErrorDetails {
|
||||
resource_type: None,
|
||||
resource_id: None,
|
||||
field: None,
|
||||
current_state: Some(current_state.clone()),
|
||||
}),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Log errors
|
||||
match &self {
|
||||
ApiError::Internal { source, .. } | ApiError::BadRequest { source, .. } => {
|
||||
if let Some(src) = source {
|
||||
tracing::error!(error = %self, source = %src, "API error");
|
||||
} else {
|
||||
tracing::error!(error = %self, "API error");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(error = %self, "API error");
|
||||
}
|
||||
}
|
||||
|
||||
let body = ErrorResponse {
|
||||
code,
|
||||
message,
|
||||
details,
|
||||
request_id: None, // Would be populated from request extension
|
||||
};
|
||||
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for API handlers.
|
||||
pub type ApiResult<T> = Result<T, ApiError>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_error_status_codes() {
|
||||
let not_found = ApiError::event_not_found(Uuid::new_v4());
|
||||
assert_eq!(not_found.status_code(), StatusCode::NOT_FOUND);
|
||||
|
||||
let bad_request = ApiError::bad_request("test");
|
||||
assert_eq!(bad_request.status_code(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let internal = ApiError::internal("test");
|
||||
assert_eq!(internal.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_codes() {
|
||||
let not_found = ApiError::event_not_found(Uuid::new_v4());
|
||||
assert_eq!(not_found.error_code(), "NOT_FOUND");
|
||||
|
||||
let validation = ApiError::validation("test", Some("field".to_string()));
|
||||
assert_eq!(validation.error_code(), "VALIDATION_ERROR");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,886 @@
|
||||
//! Axum request handlers for the MAT REST API.
|
||||
//!
|
||||
//! This module contains all the HTTP endpoint handlers for disaster response operations.
|
||||
//! Each handler is documented with OpenAPI-style documentation comments.
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use geo::Point;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::dto::*;
|
||||
use super::error::{ApiError, ApiResult};
|
||||
use super::state::AppState;
|
||||
use crate::domain::{
|
||||
DisasterEvent, DisasterType, ScanZone, ZoneBounds,
|
||||
ScanParameters, ScanResolution, MovementType,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Event Handlers
|
||||
// ============================================================================
|
||||
|
||||
/// List all disaster events.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events:
|
||||
/// get:
|
||||
/// summary: List disaster events
|
||||
/// description: Returns a paginated list of disaster events with optional filtering
|
||||
/// tags: [Events]
|
||||
/// parameters:
|
||||
/// - name: status
|
||||
/// in: query
|
||||
/// description: Filter by event status
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// enum: [Initializing, Active, Suspended, SecondarySearch, Closed]
|
||||
/// - name: event_type
|
||||
/// in: query
|
||||
/// description: Filter by disaster type
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// - name: page
|
||||
/// in: query
|
||||
/// description: Page number (0-indexed)
|
||||
/// schema:
|
||||
/// type: integer
|
||||
/// default: 0
|
||||
/// - name: page_size
|
||||
/// in: query
|
||||
/// description: Items per page (max 100)
|
||||
/// schema:
|
||||
/// type: integer
|
||||
/// default: 20
|
||||
/// responses:
|
||||
/// 200:
|
||||
/// description: List of events
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/EventListResponse'
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn list_events(
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<ListEventsQuery>,
|
||||
) -> ApiResult<Json<EventListResponse>> {
|
||||
let all_events = state.list_events();
|
||||
|
||||
// Apply filters
|
||||
let filtered: Vec<_> = all_events
|
||||
.into_iter()
|
||||
.filter(|e| {
|
||||
if let Some(ref status) = query.status {
|
||||
let event_status: EventStatusDto = e.status().clone().into();
|
||||
if !matches_status(&event_status, status) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(ref event_type) = query.event_type {
|
||||
let et: DisasterTypeDto = e.event_type().clone().into();
|
||||
if et != *event_type {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
let total = filtered.len();
|
||||
|
||||
// Apply pagination
|
||||
let page_size = query.page_size.min(100).max(1);
|
||||
let start = query.page * page_size;
|
||||
let events: Vec<_> = filtered
|
||||
.into_iter()
|
||||
.skip(start)
|
||||
.take(page_size)
|
||||
.map(event_to_response)
|
||||
.collect();
|
||||
|
||||
Ok(Json(EventListResponse {
|
||||
events,
|
||||
total,
|
||||
page: query.page,
|
||||
page_size,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Create a new disaster event.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events:
|
||||
/// post:
|
||||
/// summary: Create a new disaster event
|
||||
/// description: Creates a new disaster event for search and rescue operations
|
||||
/// tags: [Events]
|
||||
/// requestBody:
|
||||
/// required: true
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/CreateEventRequest'
|
||||
/// responses:
|
||||
/// 201:
|
||||
/// description: Event created successfully
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/EventResponse'
|
||||
/// 400:
|
||||
/// description: Invalid request data
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/ErrorResponse'
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn create_event(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<CreateEventRequest>,
|
||||
) -> ApiResult<(StatusCode, Json<EventResponse>)> {
|
||||
// Validate coordinates
|
||||
if request.latitude < -90.0 || request.latitude > 90.0 {
|
||||
return Err(ApiError::validation(
|
||||
"Latitude must be between -90 and 90",
|
||||
Some("latitude".to_string()),
|
||||
));
|
||||
}
|
||||
if request.longitude < -180.0 || request.longitude > 180.0 {
|
||||
return Err(ApiError::validation(
|
||||
"Longitude must be between -180 and 180",
|
||||
Some("longitude".to_string()),
|
||||
));
|
||||
}
|
||||
|
||||
let disaster_type: DisasterType = request.event_type.into();
|
||||
let location = Point::new(request.longitude, request.latitude);
|
||||
let mut event = DisasterEvent::new(disaster_type, location, &request.description);
|
||||
|
||||
// Set metadata if provided
|
||||
if let Some(occupancy) = request.estimated_occupancy {
|
||||
event.metadata_mut().estimated_occupancy = Some(occupancy);
|
||||
}
|
||||
if let Some(agency) = request.lead_agency {
|
||||
event.metadata_mut().lead_agency = Some(agency);
|
||||
}
|
||||
|
||||
let response = event_to_response(event.clone());
|
||||
let event_id = *event.id().as_uuid();
|
||||
state.store_event(event);
|
||||
|
||||
// Broadcast event creation
|
||||
state.broadcast(WebSocketMessage::EventStatusChanged {
|
||||
event_id,
|
||||
old_status: EventStatusDto::Initializing,
|
||||
new_status: response.status,
|
||||
});
|
||||
|
||||
tracing::info!(event_id = %event_id, "Created new disaster event");
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
/// Get a specific disaster event by ID.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events/{event_id}:
|
||||
/// get:
|
||||
/// summary: Get event details
|
||||
/// description: Returns detailed information about a specific disaster event
|
||||
/// tags: [Events]
|
||||
/// parameters:
|
||||
/// - name: event_id
|
||||
/// in: path
|
||||
/// required: true
|
||||
/// description: Event UUID
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// responses:
|
||||
/// 200:
|
||||
/// description: Event details
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/EventResponse'
|
||||
/// 404:
|
||||
/// description: Event not found
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn get_event(
|
||||
State(state): State<AppState>,
|
||||
Path(event_id): Path<Uuid>,
|
||||
) -> ApiResult<Json<EventResponse>> {
|
||||
let event = state
|
||||
.get_event(event_id)
|
||||
.ok_or_else(|| ApiError::event_not_found(event_id))?;
|
||||
|
||||
Ok(Json(event_to_response(event)))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Zone Handlers
|
||||
// ============================================================================
|
||||
|
||||
/// List all zones for a disaster event.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events/{event_id}/zones:
|
||||
/// get:
|
||||
/// summary: List zones for an event
|
||||
/// description: Returns all scan zones configured for a disaster event
|
||||
/// tags: [Zones]
|
||||
/// parameters:
|
||||
/// - name: event_id
|
||||
/// in: path
|
||||
/// required: true
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// responses:
|
||||
/// 200:
|
||||
/// description: List of zones
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/ZoneListResponse'
|
||||
/// 404:
|
||||
/// description: Event not found
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn list_zones(
|
||||
State(state): State<AppState>,
|
||||
Path(event_id): Path<Uuid>,
|
||||
) -> ApiResult<Json<ZoneListResponse>> {
|
||||
let event = state
|
||||
.get_event(event_id)
|
||||
.ok_or_else(|| ApiError::event_not_found(event_id))?;
|
||||
|
||||
let zones: Vec<_> = event.zones().iter().map(zone_to_response).collect();
|
||||
let total = zones.len();
|
||||
|
||||
Ok(Json(ZoneListResponse { zones, total }))
|
||||
}
|
||||
|
||||
/// Add a scan zone to a disaster event.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events/{event_id}/zones:
|
||||
/// post:
|
||||
/// summary: Add a scan zone
|
||||
/// description: Creates a new scan zone within a disaster event area
|
||||
/// tags: [Zones]
|
||||
/// parameters:
|
||||
/// - name: event_id
|
||||
/// in: path
|
||||
/// required: true
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// requestBody:
|
||||
/// required: true
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/CreateZoneRequest'
|
||||
/// responses:
|
||||
/// 201:
|
||||
/// description: Zone created successfully
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/ZoneResponse'
|
||||
/// 404:
|
||||
/// description: Event not found
|
||||
/// 400:
|
||||
/// description: Invalid zone configuration
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn add_zone(
|
||||
State(state): State<AppState>,
|
||||
Path(event_id): Path<Uuid>,
|
||||
Json(request): Json<CreateZoneRequest>,
|
||||
) -> ApiResult<(StatusCode, Json<ZoneResponse>)> {
|
||||
// Convert DTO to domain
|
||||
let bounds = match request.bounds {
|
||||
ZoneBoundsDto::Rectangle { min_x, min_y, max_x, max_y } => {
|
||||
if max_x <= min_x || max_y <= min_y {
|
||||
return Err(ApiError::validation(
|
||||
"max coordinates must be greater than min coordinates",
|
||||
Some("bounds".to_string()),
|
||||
));
|
||||
}
|
||||
ZoneBounds::rectangle(min_x, min_y, max_x, max_y)
|
||||
}
|
||||
ZoneBoundsDto::Circle { center_x, center_y, radius } => {
|
||||
if radius <= 0.0 {
|
||||
return Err(ApiError::validation(
|
||||
"radius must be positive",
|
||||
Some("bounds.radius".to_string()),
|
||||
));
|
||||
}
|
||||
ZoneBounds::circle(center_x, center_y, radius)
|
||||
}
|
||||
ZoneBoundsDto::Polygon { vertices } => {
|
||||
if vertices.len() < 3 {
|
||||
return Err(ApiError::validation(
|
||||
"polygon must have at least 3 vertices",
|
||||
Some("bounds.vertices".to_string()),
|
||||
));
|
||||
}
|
||||
ZoneBounds::polygon(vertices)
|
||||
}
|
||||
};
|
||||
|
||||
let params = if let Some(p) = request.parameters {
|
||||
ScanParameters {
|
||||
sensitivity: p.sensitivity.clamp(0.0, 1.0),
|
||||
max_depth: p.max_depth.max(0.0),
|
||||
resolution: match p.resolution {
|
||||
ScanResolutionDto::Quick => ScanResolution::Quick,
|
||||
ScanResolutionDto::Standard => ScanResolution::Standard,
|
||||
ScanResolutionDto::High => ScanResolution::High,
|
||||
ScanResolutionDto::Maximum => ScanResolution::Maximum,
|
||||
},
|
||||
enhanced_breathing: p.enhanced_breathing,
|
||||
heartbeat_detection: p.heartbeat_detection,
|
||||
}
|
||||
} else {
|
||||
ScanParameters::default()
|
||||
};
|
||||
|
||||
let zone = ScanZone::with_parameters(&request.name, bounds, params);
|
||||
let zone_response = zone_to_response(&zone);
|
||||
let zone_id = *zone.id().as_uuid();
|
||||
|
||||
// Add zone to event
|
||||
let added = state.update_event(event_id, move |e| {
|
||||
e.add_zone(zone);
|
||||
true
|
||||
});
|
||||
|
||||
if added.is_none() {
|
||||
return Err(ApiError::event_not_found(event_id));
|
||||
}
|
||||
|
||||
tracing::info!(event_id = %event_id, zone_id = %zone_id, "Added scan zone");
|
||||
|
||||
Ok((StatusCode::CREATED, Json(zone_response)))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Survivor Handlers
|
||||
// ============================================================================
|
||||
|
||||
/// List survivors detected in a disaster event.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events/{event_id}/survivors:
|
||||
/// get:
|
||||
/// summary: List survivors
|
||||
/// description: Returns all detected survivors in a disaster event
|
||||
/// tags: [Survivors]
|
||||
/// parameters:
|
||||
/// - name: event_id
|
||||
/// in: path
|
||||
/// required: true
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// - name: triage_status
|
||||
/// in: query
|
||||
/// description: Filter by triage status
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// enum: [Immediate, Delayed, Minor, Deceased, Unknown]
|
||||
/// - name: zone_id
|
||||
/// in: query
|
||||
/// description: Filter by zone
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// - name: min_confidence
|
||||
/// in: query
|
||||
/// description: Minimum confidence threshold
|
||||
/// schema:
|
||||
/// type: number
|
||||
/// - name: deteriorating_only
|
||||
/// in: query
|
||||
/// description: Only return deteriorating survivors
|
||||
/// schema:
|
||||
/// type: boolean
|
||||
/// responses:
|
||||
/// 200:
|
||||
/// description: List of survivors
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/SurvivorListResponse'
|
||||
/// 404:
|
||||
/// description: Event not found
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn list_survivors(
|
||||
State(state): State<AppState>,
|
||||
Path(event_id): Path<Uuid>,
|
||||
Query(query): Query<ListSurvivorsQuery>,
|
||||
) -> ApiResult<Json<SurvivorListResponse>> {
|
||||
let event = state
|
||||
.get_event(event_id)
|
||||
.ok_or_else(|| ApiError::event_not_found(event_id))?;
|
||||
|
||||
let mut triage_summary = TriageSummary::default();
|
||||
let survivors: Vec<_> = event
|
||||
.survivors()
|
||||
.into_iter()
|
||||
.filter(|s| {
|
||||
// Update triage counts for all survivors
|
||||
update_triage_summary(&mut triage_summary, s.triage_status());
|
||||
|
||||
// Apply filters
|
||||
if let Some(ref ts) = query.triage_status {
|
||||
let survivor_triage: TriageStatusDto = s.triage_status().clone().into();
|
||||
if !matches_triage_status(&survivor_triage, ts) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(zone_id) = query.zone_id {
|
||||
if s.zone_id().as_uuid() != &zone_id {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(min_conf) = query.min_confidence {
|
||||
if s.confidence() < min_conf {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if query.deteriorating_only && !s.is_deteriorating() {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
.map(survivor_to_response)
|
||||
.collect();
|
||||
|
||||
let total = survivors.len();
|
||||
|
||||
Ok(Json(SurvivorListResponse {
|
||||
survivors,
|
||||
total,
|
||||
triage_summary,
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Alert Handlers
|
||||
// ============================================================================
|
||||
|
||||
/// List alerts for a disaster event.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/events/{event_id}/alerts:
|
||||
/// get:
|
||||
/// summary: List alerts
|
||||
/// description: Returns all alerts generated for a disaster event
|
||||
/// tags: [Alerts]
|
||||
/// parameters:
|
||||
/// - name: event_id
|
||||
/// in: path
|
||||
/// required: true
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// - name: priority
|
||||
/// in: query
|
||||
/// description: Filter by priority
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// enum: [Critical, High, Medium, Low]
|
||||
/// - name: status
|
||||
/// in: query
|
||||
/// description: Filter by status
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// - name: pending_only
|
||||
/// in: query
|
||||
/// description: Only return pending alerts
|
||||
/// schema:
|
||||
/// type: boolean
|
||||
/// - name: active_only
|
||||
/// in: query
|
||||
/// description: Only return active alerts
|
||||
/// schema:
|
||||
/// type: boolean
|
||||
/// responses:
|
||||
/// 200:
|
||||
/// description: List of alerts
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/AlertListResponse'
|
||||
/// 404:
|
||||
/// description: Event not found
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn list_alerts(
|
||||
State(state): State<AppState>,
|
||||
Path(event_id): Path<Uuid>,
|
||||
Query(query): Query<ListAlertsQuery>,
|
||||
) -> ApiResult<Json<AlertListResponse>> {
|
||||
// Verify event exists
|
||||
if state.get_event(event_id).is_none() {
|
||||
return Err(ApiError::event_not_found(event_id));
|
||||
}
|
||||
|
||||
let all_alerts = state.list_alerts_for_event(event_id);
|
||||
let mut priority_counts = PriorityCounts::default();
|
||||
|
||||
let alerts: Vec<_> = all_alerts
|
||||
.into_iter()
|
||||
.filter(|a| {
|
||||
// Update priority counts
|
||||
update_priority_counts(&mut priority_counts, a.priority());
|
||||
|
||||
// Apply filters
|
||||
if let Some(ref priority) = query.priority {
|
||||
let alert_priority: PriorityDto = a.priority().into();
|
||||
if !matches_priority(&alert_priority, priority) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(ref status) = query.status {
|
||||
let alert_status: AlertStatusDto = a.status().clone().into();
|
||||
if !matches_alert_status(&alert_status, status) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if query.pending_only && !a.is_pending() {
|
||||
return false;
|
||||
}
|
||||
if query.active_only && !a.is_active() {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
.map(|a| alert_to_response(&a))
|
||||
.collect();
|
||||
|
||||
let total = alerts.len();
|
||||
|
||||
Ok(Json(AlertListResponse {
|
||||
alerts,
|
||||
total,
|
||||
priority_counts,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Acknowledge an alert.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /api/v1/mat/alerts/{alert_id}/acknowledge:
|
||||
/// post:
|
||||
/// summary: Acknowledge an alert
|
||||
/// description: Marks an alert as acknowledged by a rescue team
|
||||
/// tags: [Alerts]
|
||||
/// parameters:
|
||||
/// - name: alert_id
|
||||
/// in: path
|
||||
/// required: true
|
||||
/// schema:
|
||||
/// type: string
|
||||
/// format: uuid
|
||||
/// requestBody:
|
||||
/// required: true
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/AcknowledgeAlertRequest'
|
||||
/// responses:
|
||||
/// 200:
|
||||
/// description: Alert acknowledged
|
||||
/// content:
|
||||
/// application/json:
|
||||
/// schema:
|
||||
/// $ref: '#/components/schemas/AcknowledgeAlertResponse'
|
||||
/// 404:
|
||||
/// description: Alert not found
|
||||
/// 409:
|
||||
/// description: Alert already acknowledged
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state))]
|
||||
pub async fn acknowledge_alert(
|
||||
State(state): State<AppState>,
|
||||
Path(alert_id): Path<Uuid>,
|
||||
Json(request): Json<AcknowledgeAlertRequest>,
|
||||
) -> ApiResult<Json<AcknowledgeAlertResponse>> {
|
||||
let alert_data = state
|
||||
.get_alert(alert_id)
|
||||
.ok_or_else(|| ApiError::alert_not_found(alert_id))?;
|
||||
|
||||
if !alert_data.alert.is_pending() {
|
||||
return Err(ApiError::InvalidState {
|
||||
message: "Alert is not in pending state".to_string(),
|
||||
current_state: format!("{:?}", alert_data.alert.status()),
|
||||
});
|
||||
}
|
||||
|
||||
let event_id = alert_data.event_id;
|
||||
|
||||
// Acknowledge the alert
|
||||
state.update_alert(alert_id, |a| {
|
||||
a.acknowledge(&request.acknowledged_by);
|
||||
});
|
||||
|
||||
// Get updated alert
|
||||
let updated = state
|
||||
.get_alert(alert_id)
|
||||
.ok_or_else(|| ApiError::alert_not_found(alert_id))?;
|
||||
|
||||
let response = alert_to_response(&updated.alert);
|
||||
|
||||
// Broadcast update
|
||||
state.broadcast(WebSocketMessage::AlertUpdated {
|
||||
event_id,
|
||||
alert: response.clone(),
|
||||
});
|
||||
|
||||
tracing::info!(
|
||||
alert_id = %alert_id,
|
||||
acknowledged_by = %request.acknowledged_by,
|
||||
"Alert acknowledged"
|
||||
);
|
||||
|
||||
Ok(Json(AcknowledgeAlertResponse {
|
||||
success: true,
|
||||
alert: response,
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
fn event_to_response(event: DisasterEvent) -> EventResponse {
|
||||
let triage_counts = event.triage_counts();
|
||||
|
||||
EventResponse {
|
||||
id: *event.id().as_uuid(),
|
||||
event_type: event.event_type().clone().into(),
|
||||
status: event.status().clone().into(),
|
||||
start_time: *event.start_time(),
|
||||
latitude: event.location().y(),
|
||||
longitude: event.location().x(),
|
||||
description: event.description().to_string(),
|
||||
zone_count: event.zones().len(),
|
||||
survivor_count: event.survivors().len(),
|
||||
triage_summary: TriageSummary {
|
||||
immediate: triage_counts.immediate,
|
||||
delayed: triage_counts.delayed,
|
||||
minor: triage_counts.minor,
|
||||
deceased: triage_counts.deceased,
|
||||
unknown: triage_counts.unknown,
|
||||
},
|
||||
metadata: Some(EventMetadataDto {
|
||||
estimated_occupancy: event.metadata().estimated_occupancy,
|
||||
confirmed_rescued: event.metadata().confirmed_rescued,
|
||||
confirmed_deceased: event.metadata().confirmed_deceased,
|
||||
weather: event.metadata().weather.clone(),
|
||||
lead_agency: event.metadata().lead_agency.clone(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn zone_to_response(zone: &ScanZone) -> ZoneResponse {
|
||||
let bounds = match zone.bounds() {
|
||||
ZoneBounds::Rectangle { min_x, min_y, max_x, max_y } => {
|
||||
ZoneBoundsDto::Rectangle {
|
||||
min_x: *min_x,
|
||||
min_y: *min_y,
|
||||
max_x: *max_x,
|
||||
max_y: *max_y,
|
||||
}
|
||||
}
|
||||
ZoneBounds::Circle { center_x, center_y, radius } => {
|
||||
ZoneBoundsDto::Circle {
|
||||
center_x: *center_x,
|
||||
center_y: *center_y,
|
||||
radius: *radius,
|
||||
}
|
||||
}
|
||||
ZoneBounds::Polygon { vertices } => {
|
||||
ZoneBoundsDto::Polygon {
|
||||
vertices: vertices.clone(),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let params = zone.parameters();
|
||||
let parameters = ScanParametersDto {
|
||||
sensitivity: params.sensitivity,
|
||||
max_depth: params.max_depth,
|
||||
resolution: match params.resolution {
|
||||
ScanResolution::Quick => ScanResolutionDto::Quick,
|
||||
ScanResolution::Standard => ScanResolutionDto::Standard,
|
||||
ScanResolution::High => ScanResolutionDto::High,
|
||||
ScanResolution::Maximum => ScanResolutionDto::Maximum,
|
||||
},
|
||||
enhanced_breathing: params.enhanced_breathing,
|
||||
heartbeat_detection: params.heartbeat_detection,
|
||||
};
|
||||
|
||||
ZoneResponse {
|
||||
id: *zone.id().as_uuid(),
|
||||
name: zone.name().to_string(),
|
||||
status: zone.status().clone().into(),
|
||||
bounds,
|
||||
area: zone.area(),
|
||||
parameters,
|
||||
last_scan: zone.last_scan().cloned(),
|
||||
scan_count: zone.scan_count(),
|
||||
detections_count: zone.detections_count(),
|
||||
}
|
||||
}
|
||||
|
||||
fn survivor_to_response(survivor: &crate::Survivor) -> SurvivorResponse {
|
||||
let location = survivor.location().map(|loc| LocationDto {
|
||||
x: loc.x,
|
||||
y: loc.y,
|
||||
z: loc.z,
|
||||
depth: loc.depth(),
|
||||
uncertainty_radius: loc.uncertainty.horizontal_error,
|
||||
confidence: loc.uncertainty.confidence,
|
||||
});
|
||||
|
||||
let latest_vitals = survivor.vital_signs().latest();
|
||||
let vital_signs = VitalSignsSummaryDto {
|
||||
breathing_rate: latest_vitals.and_then(|v| v.breathing.as_ref().map(|b| b.rate_bpm)),
|
||||
breathing_type: latest_vitals.and_then(|v| v.breathing.as_ref().map(|b| format!("{:?}", b.pattern_type))),
|
||||
heart_rate: latest_vitals.and_then(|v| v.heartbeat.as_ref().map(|h| h.rate_bpm)),
|
||||
has_heartbeat: latest_vitals.map(|v| v.has_heartbeat()).unwrap_or(false),
|
||||
has_movement: latest_vitals.map(|v| v.has_movement()).unwrap_or(false),
|
||||
movement_type: latest_vitals.and_then(|v| {
|
||||
if v.movement.movement_type != MovementType::None {
|
||||
Some(format!("{:?}", v.movement.movement_type))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
timestamp: latest_vitals.map(|v| v.timestamp).unwrap_or_else(chrono::Utc::now),
|
||||
};
|
||||
|
||||
let metadata = {
|
||||
let m = survivor.metadata();
|
||||
if m.notes.is_empty() && m.tags.is_empty() && m.assigned_team.is_none() {
|
||||
None
|
||||
} else {
|
||||
Some(SurvivorMetadataDto {
|
||||
estimated_age_category: m.estimated_age_category.as_ref().map(|a| format!("{:?}", a)),
|
||||
assigned_team: m.assigned_team.clone(),
|
||||
notes: m.notes.clone(),
|
||||
tags: m.tags.clone(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
SurvivorResponse {
|
||||
id: *survivor.id().as_uuid(),
|
||||
zone_id: *survivor.zone_id().as_uuid(),
|
||||
status: survivor.status().clone().into(),
|
||||
triage_status: survivor.triage_status().clone().into(),
|
||||
location,
|
||||
vital_signs,
|
||||
confidence: survivor.confidence(),
|
||||
first_detected: *survivor.first_detected(),
|
||||
last_updated: *survivor.last_updated(),
|
||||
is_deteriorating: survivor.is_deteriorating(),
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
|
||||
fn alert_to_response(alert: &crate::Alert) -> AlertResponse {
|
||||
let location = alert.payload().location.as_ref().map(|loc| LocationDto {
|
||||
x: loc.x,
|
||||
y: loc.y,
|
||||
z: loc.z,
|
||||
depth: loc.depth(),
|
||||
uncertainty_radius: loc.uncertainty.horizontal_error,
|
||||
confidence: loc.uncertainty.confidence,
|
||||
});
|
||||
|
||||
AlertResponse {
|
||||
id: *alert.id().as_uuid(),
|
||||
survivor_id: *alert.survivor_id().as_uuid(),
|
||||
priority: alert.priority().into(),
|
||||
status: alert.status().clone().into(),
|
||||
title: alert.payload().title.clone(),
|
||||
message: alert.payload().message.clone(),
|
||||
triage_status: alert.payload().triage_status.clone().into(),
|
||||
location,
|
||||
recommended_action: if alert.payload().recommended_action.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(alert.payload().recommended_action.clone())
|
||||
},
|
||||
created_at: *alert.created_at(),
|
||||
acknowledged_at: alert.acknowledged_at().cloned(),
|
||||
acknowledged_by: alert.acknowledged_by().map(String::from),
|
||||
escalation_count: alert.escalation_count(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_triage_summary(summary: &mut TriageSummary, status: &crate::TriageStatus) {
|
||||
match status {
|
||||
crate::TriageStatus::Immediate => summary.immediate += 1,
|
||||
crate::TriageStatus::Delayed => summary.delayed += 1,
|
||||
crate::TriageStatus::Minor => summary.minor += 1,
|
||||
crate::TriageStatus::Deceased => summary.deceased += 1,
|
||||
crate::TriageStatus::Unknown => summary.unknown += 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_priority_counts(counts: &mut PriorityCounts, priority: crate::Priority) {
|
||||
match priority {
|
||||
crate::Priority::Critical => counts.critical += 1,
|
||||
crate::Priority::High => counts.high += 1,
|
||||
crate::Priority::Medium => counts.medium += 1,
|
||||
crate::Priority::Low => counts.low += 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Match helper functions (avoiding PartialEq on DTOs for flexibility)
|
||||
fn matches_status(a: &EventStatusDto, b: &EventStatusDto) -> bool {
|
||||
std::mem::discriminant(a) == std::mem::discriminant(b)
|
||||
}
|
||||
|
||||
fn matches_triage_status(a: &TriageStatusDto, b: &TriageStatusDto) -> bool {
|
||||
std::mem::discriminant(a) == std::mem::discriminant(b)
|
||||
}
|
||||
|
||||
fn matches_priority(a: &PriorityDto, b: &PriorityDto) -> bool {
|
||||
std::mem::discriminant(a) == std::mem::discriminant(b)
|
||||
}
|
||||
|
||||
fn matches_alert_status(a: &AlertStatusDto, b: &AlertStatusDto) -> bool {
|
||||
std::mem::discriminant(a) == std::mem::discriminant(b)
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
//! REST API endpoints for WiFi-DensePose MAT disaster response monitoring.
|
||||
//!
|
||||
//! This module provides a complete REST API and WebSocket interface for
|
||||
//! managing disaster events, zones, survivors, and alerts in real-time.
|
||||
//!
|
||||
//! ## Endpoints
|
||||
//!
|
||||
//! ### Disaster Events
|
||||
//! - `GET /api/v1/mat/events` - List all disaster events
|
||||
//! - `POST /api/v1/mat/events` - Create new disaster event
|
||||
//! - `GET /api/v1/mat/events/{id}` - Get event details
|
||||
//!
|
||||
//! ### Zones
|
||||
//! - `GET /api/v1/mat/events/{id}/zones` - List zones for event
|
||||
//! - `POST /api/v1/mat/events/{id}/zones` - Add zone to event
|
||||
//!
|
||||
//! ### Survivors
|
||||
//! - `GET /api/v1/mat/events/{id}/survivors` - List survivors in event
|
||||
//!
|
||||
//! ### Alerts
|
||||
//! - `GET /api/v1/mat/events/{id}/alerts` - List alerts for event
|
||||
//! - `POST /api/v1/mat/alerts/{id}/acknowledge` - Acknowledge alert
|
||||
//!
|
||||
//! ### WebSocket
|
||||
//! - `WS /ws/mat/stream` - Real-time survivor and alert stream
|
||||
|
||||
pub mod dto;
|
||||
pub mod handlers;
|
||||
pub mod error;
|
||||
pub mod state;
|
||||
pub mod websocket;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
|
||||
pub use dto::*;
|
||||
pub use error::ApiError;
|
||||
pub use state::AppState;
|
||||
|
||||
/// Create the MAT API router with all endpoints.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use wifi_densepose_mat::api::{create_router, AppState};
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let state = AppState::new();
|
||||
/// let app = create_router(state);
|
||||
/// // ... serve with axum
|
||||
/// }
|
||||
/// ```
|
||||
pub fn create_router(state: AppState) -> Router {
|
||||
Router::new()
|
||||
// Event endpoints
|
||||
.route("/api/v1/mat/events", get(handlers::list_events).post(handlers::create_event))
|
||||
.route("/api/v1/mat/events/:event_id", get(handlers::get_event))
|
||||
// Zone endpoints
|
||||
.route("/api/v1/mat/events/:event_id/zones", get(handlers::list_zones).post(handlers::add_zone))
|
||||
// Survivor endpoints
|
||||
.route("/api/v1/mat/events/:event_id/survivors", get(handlers::list_survivors))
|
||||
// Alert endpoints
|
||||
.route("/api/v1/mat/events/:event_id/alerts", get(handlers::list_alerts))
|
||||
.route("/api/v1/mat/alerts/:alert_id/acknowledge", post(handlers::acknowledge_alert))
|
||||
// WebSocket endpoint
|
||||
.route("/ws/mat/stream", get(websocket::ws_handler))
|
||||
.with_state(state)
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
//! Application state for the MAT REST API.
|
||||
//!
|
||||
//! This module provides the shared state that is passed to all API handlers.
|
||||
//! It contains repositories, services, and real-time event broadcasting.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::domain::{
|
||||
DisasterEvent, Alert,
|
||||
};
|
||||
use super::dto::WebSocketMessage;
|
||||
|
||||
/// Shared application state for the API.
|
||||
///
|
||||
/// This is cloned for each request handler and provides thread-safe
|
||||
/// access to shared resources.
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
inner: Arc<AppStateInner>,
|
||||
}
|
||||
|
||||
/// Inner state (not cloned, shared via Arc).
|
||||
struct AppStateInner {
|
||||
/// In-memory event repository
|
||||
events: RwLock<HashMap<Uuid, DisasterEvent>>,
|
||||
/// In-memory alert repository
|
||||
alerts: RwLock<HashMap<Uuid, AlertWithEventId>>,
|
||||
/// Broadcast channel for real-time updates
|
||||
broadcast_tx: broadcast::Sender<WebSocketMessage>,
|
||||
/// Configuration
|
||||
config: ApiConfig,
|
||||
}
|
||||
|
||||
/// Alert with its associated event ID for lookup.
|
||||
#[derive(Clone)]
|
||||
pub struct AlertWithEventId {
|
||||
pub alert: Alert,
|
||||
pub event_id: Uuid,
|
||||
}
|
||||
|
||||
/// API configuration.
|
||||
#[derive(Clone)]
|
||||
pub struct ApiConfig {
|
||||
/// Maximum number of events to store
|
||||
pub max_events: usize,
|
||||
/// Maximum survivors per event
|
||||
pub max_survivors_per_event: usize,
|
||||
/// Broadcast channel capacity
|
||||
pub broadcast_capacity: usize,
|
||||
}
|
||||
|
||||
impl Default for ApiConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_events: 1000,
|
||||
max_survivors_per_event: 10000,
|
||||
broadcast_capacity: 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
/// Create a new application state with default configuration.
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(ApiConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new application state with custom configuration.
|
||||
pub fn with_config(config: ApiConfig) -> Self {
|
||||
let (broadcast_tx, _) = broadcast::channel(config.broadcast_capacity);
|
||||
|
||||
Self {
|
||||
inner: Arc::new(AppStateInner {
|
||||
events: RwLock::new(HashMap::new()),
|
||||
alerts: RwLock::new(HashMap::new()),
|
||||
broadcast_tx,
|
||||
config,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Event Operations
|
||||
// ========================================================================
|
||||
|
||||
/// Store a disaster event.
|
||||
pub fn store_event(&self, event: DisasterEvent) -> Uuid {
|
||||
let id = *event.id().as_uuid();
|
||||
let mut events = self.inner.events.write();
|
||||
|
||||
// Check capacity
|
||||
if events.len() >= self.inner.config.max_events {
|
||||
// Remove oldest closed event
|
||||
let oldest_closed = events
|
||||
.iter()
|
||||
.filter(|(_, e)| matches!(e.status(), crate::EventStatus::Closed))
|
||||
.min_by_key(|(_, e)| e.start_time())
|
||||
.map(|(id, _)| *id);
|
||||
|
||||
if let Some(old_id) = oldest_closed {
|
||||
events.remove(&old_id);
|
||||
}
|
||||
}
|
||||
|
||||
events.insert(id, event);
|
||||
id
|
||||
}
|
||||
|
||||
/// Get an event by ID.
|
||||
pub fn get_event(&self, id: Uuid) -> Option<DisasterEvent> {
|
||||
self.inner.events.read().get(&id).cloned()
|
||||
}
|
||||
|
||||
/// Get mutable access to an event (for updates).
|
||||
pub fn update_event<F, R>(&self, id: Uuid, f: F) -> Option<R>
|
||||
where
|
||||
F: FnOnce(&mut DisasterEvent) -> R,
|
||||
{
|
||||
let mut events = self.inner.events.write();
|
||||
events.get_mut(&id).map(f)
|
||||
}
|
||||
|
||||
/// List all events.
|
||||
pub fn list_events(&self) -> Vec<DisasterEvent> {
|
||||
self.inner.events.read().values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get event count.
|
||||
pub fn event_count(&self) -> usize {
|
||||
self.inner.events.read().len()
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Alert Operations
|
||||
// ========================================================================
|
||||
|
||||
/// Store an alert.
|
||||
pub fn store_alert(&self, alert: Alert, event_id: Uuid) -> Uuid {
|
||||
let id = *alert.id().as_uuid();
|
||||
let mut alerts = self.inner.alerts.write();
|
||||
alerts.insert(id, AlertWithEventId { alert, event_id });
|
||||
id
|
||||
}
|
||||
|
||||
/// Get an alert by ID.
|
||||
pub fn get_alert(&self, id: Uuid) -> Option<AlertWithEventId> {
|
||||
self.inner.alerts.read().get(&id).cloned()
|
||||
}
|
||||
|
||||
/// Update an alert.
|
||||
pub fn update_alert<F, R>(&self, id: Uuid, f: F) -> Option<R>
|
||||
where
|
||||
F: FnOnce(&mut Alert) -> R,
|
||||
{
|
||||
let mut alerts = self.inner.alerts.write();
|
||||
alerts.get_mut(&id).map(|a| f(&mut a.alert))
|
||||
}
|
||||
|
||||
/// List alerts for an event.
|
||||
pub fn list_alerts_for_event(&self, event_id: Uuid) -> Vec<Alert> {
|
||||
self.inner
|
||||
.alerts
|
||||
.read()
|
||||
.values()
|
||||
.filter(|a| a.event_id == event_id)
|
||||
.map(|a| a.alert.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Broadcasting
|
||||
// ========================================================================
|
||||
|
||||
/// Get a receiver for real-time updates.
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<WebSocketMessage> {
|
||||
self.inner.broadcast_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Broadcast a message to all subscribers.
|
||||
pub fn broadcast(&self, message: WebSocketMessage) {
|
||||
// Ignore send errors (no subscribers)
|
||||
let _ = self.inner.broadcast_tx.send(message);
|
||||
}
|
||||
|
||||
/// Get the number of active subscribers.
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.inner.broadcast_tx.receiver_count()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::domain::{DisasterType, DisasterEvent};
|
||||
use geo::Point;
|
||||
|
||||
#[test]
|
||||
fn test_store_and_get_event() {
|
||||
let state = AppState::new();
|
||||
let event = DisasterEvent::new(
|
||||
DisasterType::Earthquake,
|
||||
Point::new(-122.4194, 37.7749),
|
||||
"Test earthquake",
|
||||
);
|
||||
let id = *event.id().as_uuid();
|
||||
|
||||
state.store_event(event);
|
||||
|
||||
let retrieved = state.get_event(id);
|
||||
assert!(retrieved.is_some());
|
||||
assert_eq!(retrieved.unwrap().id().as_uuid(), &id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_event() {
|
||||
let state = AppState::new();
|
||||
let event = DisasterEvent::new(
|
||||
DisasterType::Earthquake,
|
||||
Point::new(0.0, 0.0),
|
||||
"Test",
|
||||
);
|
||||
let id = *event.id().as_uuid();
|
||||
state.store_event(event);
|
||||
|
||||
let result = state.update_event(id, |e| {
|
||||
e.set_status(crate::EventStatus::Suspended);
|
||||
true
|
||||
});
|
||||
|
||||
assert!(result.unwrap());
|
||||
let updated = state.get_event(id).unwrap();
|
||||
assert!(matches!(updated.status(), crate::EventStatus::Suspended));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_subscribe() {
|
||||
let state = AppState::new();
|
||||
let mut rx = state.subscribe();
|
||||
|
||||
state.broadcast(WebSocketMessage::Heartbeat {
|
||||
timestamp: chrono::Utc::now(),
|
||||
});
|
||||
|
||||
// Try to receive (in async context this would work)
|
||||
assert_eq!(state.subscriber_count(), 1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
//! WebSocket handler for real-time survivor and alert streaming.
|
||||
//!
|
||||
//! This module provides a WebSocket endpoint that streams real-time updates
|
||||
//! for survivor detections, status changes, and alerts.
|
||||
//!
|
||||
//! ## Protocol
|
||||
//!
|
||||
//! Clients connect to `/ws/mat/stream` and receive JSON-formatted messages.
|
||||
//!
|
||||
//! ### Message Types
|
||||
//!
|
||||
//! - `survivor_detected` - New survivor found
|
||||
//! - `survivor_updated` - Survivor status/vitals changed
|
||||
//! - `survivor_lost` - Survivor signal lost
|
||||
//! - `alert_created` - New alert generated
|
||||
//! - `alert_updated` - Alert status changed
|
||||
//! - `zone_scan_complete` - Zone scan finished
|
||||
//! - `event_status_changed` - Event status changed
|
||||
//! - `heartbeat` - Keep-alive ping
|
||||
//! - `error` - Error message
|
||||
//!
|
||||
//! ### Client Commands
|
||||
//!
|
||||
//! Clients can send JSON commands:
|
||||
//! - `{"action": "subscribe", "event_id": "..."}`
|
||||
//! - `{"action": "unsubscribe", "event_id": "..."}`
|
||||
//! - `{"action": "subscribe_all"}`
|
||||
//! - `{"action": "get_state", "event_id": "..."}`
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
State,
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::dto::{WebSocketMessage, WebSocketRequest};
|
||||
use super::state::AppState;
|
||||
|
||||
/// WebSocket connection handler.
|
||||
///
|
||||
/// # OpenAPI Specification
|
||||
///
|
||||
/// ```yaml
|
||||
/// /ws/mat/stream:
|
||||
/// get:
|
||||
/// summary: Real-time event stream
|
||||
/// description: |
|
||||
/// WebSocket endpoint for real-time updates on survivors and alerts.
|
||||
///
|
||||
/// ## Connection
|
||||
///
|
||||
/// Connect using a WebSocket client to receive real-time updates.
|
||||
///
|
||||
/// ## Messages
|
||||
///
|
||||
/// All messages are JSON-formatted with a "type" field indicating
|
||||
/// the message type.
|
||||
///
|
||||
/// ## Subscriptions
|
||||
///
|
||||
/// By default, clients receive updates for all events. Send a
|
||||
/// subscribe/unsubscribe command to filter to specific events.
|
||||
/// tags: [WebSocket]
|
||||
/// responses:
|
||||
/// 101:
|
||||
/// description: WebSocket connection established
|
||||
/// ```
|
||||
#[tracing::instrument(skip(state, ws))]
|
||||
pub async fn ws_handler(
|
||||
State(state): State<AppState>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state))
|
||||
}
|
||||
|
||||
/// Handle an established WebSocket connection.
|
||||
async fn handle_socket(socket: WebSocket, state: AppState) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Subscription state for this connection
|
||||
let subscriptions: Arc<Mutex<SubscriptionState>> = Arc::new(Mutex::new(SubscriptionState::new()));
|
||||
|
||||
// Subscribe to broadcast channel
|
||||
let mut broadcast_rx = state.subscribe();
|
||||
|
||||
// Spawn task to forward broadcast messages to client
|
||||
let subs_clone = subscriptions.clone();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive from broadcast channel
|
||||
result = broadcast_rx.recv() => {
|
||||
match result {
|
||||
Ok(msg) => {
|
||||
// Check if this message matches subscription filter
|
||||
if subs_clone.lock().should_receive(&msg) {
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
if sender.send(Message::Text(json)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
tracing::warn!(lagged = n, "WebSocket client lagged, messages dropped");
|
||||
// Send error notification
|
||||
let error = WebSocketMessage::Error {
|
||||
code: "MESSAGES_DROPPED".to_string(),
|
||||
message: format!("{} messages were dropped due to slow client", n),
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&error) {
|
||||
if sender.send(Message::Text(json)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Periodic heartbeat
|
||||
_ = tokio::time::sleep(Duration::from_secs(30)) => {
|
||||
let heartbeat = WebSocketMessage::Heartbeat {
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&heartbeat) {
|
||||
if sender.send(Message::Ping(json.into_bytes())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Handle incoming messages from client
|
||||
let subs_clone = subscriptions.clone();
|
||||
let state_clone = state.clone();
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
// Parse and handle client command
|
||||
if let Err(e) = handle_client_message(&text, &subs_clone, &state_clone).await {
|
||||
tracing::warn!(error = %e, "Failed to handle WebSocket message");
|
||||
}
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
// Binary messages not supported
|
||||
tracing::debug!("Ignoring binary WebSocket message");
|
||||
}
|
||||
Message::Ping(data) => {
|
||||
// Pong handled automatically by axum
|
||||
tracing::trace!(len = data.len(), "Received ping");
|
||||
}
|
||||
Message::Pong(_) => {
|
||||
// Heartbeat response
|
||||
tracing::trace!("Received pong");
|
||||
}
|
||||
Message::Close(_) => {
|
||||
tracing::debug!("Client closed WebSocket connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
forward_task.abort();
|
||||
tracing::debug!("WebSocket connection closed");
|
||||
}
|
||||
|
||||
/// Handle a client message (subscription commands).
|
||||
async fn handle_client_message(
|
||||
text: &str,
|
||||
subscriptions: &Arc<Mutex<SubscriptionState>>,
|
||||
state: &AppState,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let request: WebSocketRequest = serde_json::from_str(text)?;
|
||||
|
||||
match request {
|
||||
WebSocketRequest::Subscribe { event_id } => {
|
||||
// Verify event exists
|
||||
if state.get_event(event_id).is_some() {
|
||||
subscriptions.lock().subscribe(event_id);
|
||||
tracing::debug!(event_id = %event_id, "Client subscribed to event");
|
||||
}
|
||||
}
|
||||
WebSocketRequest::Unsubscribe { event_id } => {
|
||||
subscriptions.lock().unsubscribe(&event_id);
|
||||
tracing::debug!(event_id = %event_id, "Client unsubscribed from event");
|
||||
}
|
||||
WebSocketRequest::SubscribeAll => {
|
||||
subscriptions.lock().subscribe_all();
|
||||
tracing::debug!("Client subscribed to all events");
|
||||
}
|
||||
WebSocketRequest::GetState { event_id } => {
|
||||
// This would send current state - simplified for now
|
||||
tracing::debug!(event_id = %event_id, "Client requested state");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Tracks subscription state for a WebSocket connection.
|
||||
struct SubscriptionState {
|
||||
/// Subscribed event IDs (empty = all events)
|
||||
event_ids: HashSet<Uuid>,
|
||||
/// Whether subscribed to all events
|
||||
all_events: bool,
|
||||
}
|
||||
|
||||
impl SubscriptionState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
event_ids: HashSet::new(),
|
||||
all_events: true, // Default to receiving all events
|
||||
}
|
||||
}
|
||||
|
||||
fn subscribe(&mut self, event_id: Uuid) {
|
||||
self.all_events = false;
|
||||
self.event_ids.insert(event_id);
|
||||
}
|
||||
|
||||
fn unsubscribe(&mut self, event_id: &Uuid) {
|
||||
self.event_ids.remove(event_id);
|
||||
if self.event_ids.is_empty() {
|
||||
self.all_events = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn subscribe_all(&mut self) {
|
||||
self.all_events = true;
|
||||
self.event_ids.clear();
|
||||
}
|
||||
|
||||
fn should_receive(&self, msg: &WebSocketMessage) -> bool {
|
||||
if self.all_events {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Extract event_id from message and check subscription
|
||||
let event_id = match msg {
|
||||
WebSocketMessage::SurvivorDetected { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::SurvivorUpdated { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::SurvivorLost { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::AlertCreated { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::AlertUpdated { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::ZoneScanComplete { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::EventStatusChanged { event_id, .. } => Some(*event_id),
|
||||
WebSocketMessage::Heartbeat { .. } => None, // Always receive
|
||||
WebSocketMessage::Error { .. } => None, // Always receive
|
||||
};
|
||||
|
||||
match event_id {
|
||||
Some(id) => self.event_ids.contains(&id),
|
||||
None => true, // Non-event-specific messages always sent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_subscription_state() {
|
||||
let mut state = SubscriptionState::new();
|
||||
|
||||
// Default is all events
|
||||
assert!(state.all_events);
|
||||
|
||||
// Subscribe to specific event
|
||||
let event_id = Uuid::new_v4();
|
||||
state.subscribe(event_id);
|
||||
assert!(!state.all_events);
|
||||
assert!(state.event_ids.contains(&event_id));
|
||||
|
||||
// Unsubscribe returns to all events
|
||||
state.unsubscribe(&event_id);
|
||||
assert!(state.all_events);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_receive() {
|
||||
let mut state = SubscriptionState::new();
|
||||
let event_id = Uuid::new_v4();
|
||||
let other_id = Uuid::new_v4();
|
||||
|
||||
// All events mode - receive everything
|
||||
let msg = WebSocketMessage::Heartbeat {
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
assert!(state.should_receive(&msg));
|
||||
|
||||
// Subscribe to specific event
|
||||
state.subscribe(event_id);
|
||||
|
||||
// Should receive messages for subscribed event
|
||||
let msg = WebSocketMessage::SurvivorLost {
|
||||
event_id,
|
||||
survivor_id: Uuid::new_v4(),
|
||||
};
|
||||
assert!(state.should_receive(&msg));
|
||||
|
||||
// Should not receive messages for other events
|
||||
let msg = WebSocketMessage::SurvivorLost {
|
||||
event_id: other_id,
|
||||
survivor_id: Uuid::new_v4(),
|
||||
};
|
||||
assert!(!state.should_receive(&msg));
|
||||
|
||||
// Heartbeats always received
|
||||
let msg = WebSocketMessage::Heartbeat {
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
assert!(state.should_receive(&msg));
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
//! Detection pipeline combining all vital signs detectors.
|
||||
//!
|
||||
//! This module provides both traditional signal-processing-based detection
|
||||
//! and optional ML-enhanced detection for improved accuracy.
|
||||
|
||||
use crate::domain::{ScanZone, VitalSignsReading, ConfidenceScore};
|
||||
use crate::ml::{MlDetectionConfig, MlDetectionPipeline, MlDetectionResult};
|
||||
use crate::{DisasterConfig, MatError};
|
||||
use super::{
|
||||
BreathingDetector, BreathingDetectorConfig,
|
||||
@@ -23,6 +27,10 @@ pub struct DetectionConfig {
|
||||
pub enable_heartbeat: bool,
|
||||
/// Minimum overall confidence to report detection
|
||||
pub min_confidence: f64,
|
||||
/// Enable ML-enhanced detection
|
||||
pub enable_ml: bool,
|
||||
/// ML detection configuration (if enabled)
|
||||
pub ml_config: Option<MlDetectionConfig>,
|
||||
}
|
||||
|
||||
impl Default for DetectionConfig {
|
||||
@@ -34,6 +42,8 @@ impl Default for DetectionConfig {
|
||||
sample_rate: 1000.0,
|
||||
enable_heartbeat: false,
|
||||
min_confidence: 0.3,
|
||||
enable_ml: false,
|
||||
ml_config: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -53,6 +63,20 @@ impl DetectionConfig {
|
||||
|
||||
detection_config
|
||||
}
|
||||
|
||||
/// Enable ML-enhanced detection with the given configuration
|
||||
pub fn with_ml(mut self, ml_config: MlDetectionConfig) -> Self {
|
||||
self.enable_ml = true;
|
||||
self.ml_config = Some(ml_config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable ML-enhanced detection with default configuration
|
||||
pub fn with_default_ml(mut self) -> Self {
|
||||
self.enable_ml = true;
|
||||
self.ml_config = Some(MlDetectionConfig::default());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for vital signs detection
|
||||
@@ -123,20 +147,42 @@ pub struct DetectionPipeline {
|
||||
heartbeat_detector: HeartbeatDetector,
|
||||
movement_classifier: MovementClassifier,
|
||||
data_buffer: parking_lot::RwLock<CsiDataBuffer>,
|
||||
/// Optional ML detection pipeline
|
||||
ml_pipeline: Option<MlDetectionPipeline>,
|
||||
}
|
||||
|
||||
impl DetectionPipeline {
|
||||
/// Create a new detection pipeline
|
||||
pub fn new(config: DetectionConfig) -> Self {
|
||||
let ml_pipeline = if config.enable_ml {
|
||||
config.ml_config.clone().map(MlDetectionPipeline::new)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
breathing_detector: BreathingDetector::new(config.breathing.clone()),
|
||||
heartbeat_detector: HeartbeatDetector::new(config.heartbeat.clone()),
|
||||
movement_classifier: MovementClassifier::new(config.movement.clone()),
|
||||
data_buffer: parking_lot::RwLock::new(CsiDataBuffer::new(config.sample_rate)),
|
||||
ml_pipeline,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize ML models asynchronously (if enabled)
|
||||
pub async fn initialize_ml(&mut self) -> Result<(), MatError> {
|
||||
if let Some(ref mut ml) = self.ml_pipeline {
|
||||
ml.initialize().await.map_err(MatError::from)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if ML pipeline is ready
|
||||
pub fn ml_ready(&self) -> bool {
|
||||
self.ml_pipeline.as_ref().map_or(true, |ml| ml.is_ready())
|
||||
}
|
||||
|
||||
/// Process a scan zone and return detected vital signs
|
||||
pub async fn process_zone(&self, zone: &ScanZone) -> Result<Option<VitalSignsReading>, MatError> {
|
||||
// In a real implementation, this would:
|
||||
@@ -152,17 +198,66 @@ impl DetectionPipeline {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Detect vital signs
|
||||
// Detect vital signs using traditional pipeline
|
||||
let reading = self.detect_from_buffer(&buffer, zone)?;
|
||||
|
||||
// If ML is enabled and ready, enhance with ML predictions
|
||||
let enhanced_reading = if self.config.enable_ml && self.ml_ready() {
|
||||
self.enhance_with_ml(reading, &buffer).await?
|
||||
} else {
|
||||
reading
|
||||
};
|
||||
|
||||
// Check minimum confidence
|
||||
if let Some(ref r) = reading {
|
||||
if let Some(ref r) = enhanced_reading {
|
||||
if r.confidence.value() < self.config.min_confidence {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(reading)
|
||||
Ok(enhanced_reading)
|
||||
}
|
||||
|
||||
/// Enhance detection results with ML predictions
|
||||
async fn enhance_with_ml(
|
||||
&self,
|
||||
traditional_reading: Option<VitalSignsReading>,
|
||||
buffer: &CsiDataBuffer,
|
||||
) -> Result<Option<VitalSignsReading>, MatError> {
|
||||
let ml_pipeline = match &self.ml_pipeline {
|
||||
Some(ml) => ml,
|
||||
None => return Ok(traditional_reading),
|
||||
};
|
||||
|
||||
// Get ML predictions
|
||||
let ml_result = ml_pipeline.process(buffer).await.map_err(MatError::from)?;
|
||||
|
||||
// If we have ML vital classification, use it to enhance or replace traditional
|
||||
if let Some(ref ml_vital) = ml_result.vital_classification {
|
||||
if let Some(vital_reading) = ml_vital.to_vital_signs_reading() {
|
||||
// If ML result has higher confidence, prefer it
|
||||
if let Some(ref traditional) = traditional_reading {
|
||||
if ml_result.overall_confidence() > traditional.confidence.value() as f32 {
|
||||
return Ok(Some(vital_reading));
|
||||
}
|
||||
} else {
|
||||
// No traditional reading, use ML result
|
||||
return Ok(Some(vital_reading));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(traditional_reading)
|
||||
}
|
||||
|
||||
/// Get the latest ML detection results (if ML is enabled)
|
||||
pub async fn get_ml_results(&self) -> Option<MlDetectionResult> {
|
||||
let buffer = self.data_buffer.read();
|
||||
if let Some(ref ml) = self.ml_pipeline {
|
||||
ml.process(&buffer).await.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Add CSI data to the processing buffer
|
||||
@@ -236,8 +331,23 @@ impl DetectionPipeline {
|
||||
self.breathing_detector = BreathingDetector::new(config.breathing.clone());
|
||||
self.heartbeat_detector = HeartbeatDetector::new(config.heartbeat.clone());
|
||||
self.movement_classifier = MovementClassifier::new(config.movement.clone());
|
||||
|
||||
// Update ML pipeline if configuration changed
|
||||
if config.enable_ml != self.config.enable_ml || config.ml_config != self.config.ml_config {
|
||||
self.ml_pipeline = if config.enable_ml {
|
||||
config.ml_config.clone().map(MlDetectionPipeline::new)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
self.config = config;
|
||||
}
|
||||
|
||||
/// Get the ML pipeline (if enabled)
|
||||
pub fn ml_pipeline(&self) -> Option<&MlDetectionPipeline> {
|
||||
self.ml_pipeline.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl VitalSignsDetector for DetectionPipeline {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,14 +4,102 @@
|
||||
//! - wifi-densepose-signal types and wifi-Mat domain types
|
||||
//! - wifi-densepose-nn inference results and detection results
|
||||
//! - wifi-densepose-hardware interfaces and sensor abstractions
|
||||
//!
|
||||
//! # Hardware Support
|
||||
//!
|
||||
//! The integration layer supports multiple WiFi CSI hardware platforms:
|
||||
//!
|
||||
//! - **ESP32**: Via serial communication using ESP-CSI firmware
|
||||
//! - **Intel 5300 NIC**: Using Linux CSI Tool (iwlwifi driver)
|
||||
//! - **Atheros NICs**: Using ath9k/ath10k/ath11k CSI patches
|
||||
//! - **Nexmon**: For Broadcom chips with CSI firmware
|
||||
//!
|
||||
//! # Example Usage
|
||||
//!
|
||||
//! ```ignore
|
||||
//! use wifi_densepose_mat::integration::{
|
||||
//! HardwareAdapter, HardwareConfig, AtherosDriver,
|
||||
//! csi_receiver::{UdpCsiReceiver, ReceiverConfig},
|
||||
//! };
|
||||
//!
|
||||
//! // Configure for ESP32
|
||||
//! let config = HardwareConfig::esp32("/dev/ttyUSB0", 921600);
|
||||
//! let mut adapter = HardwareAdapter::with_config(config);
|
||||
//! adapter.initialize().await?;
|
||||
//!
|
||||
//! // Or configure for Intel 5300
|
||||
//! let config = HardwareConfig::intel_5300("wlan0");
|
||||
//! let mut adapter = HardwareAdapter::with_config(config);
|
||||
//!
|
||||
//! // Or use UDP receiver for network streaming
|
||||
//! let config = ReceiverConfig::udp("0.0.0.0", 5500);
|
||||
//! let mut receiver = UdpCsiReceiver::new(config).await?;
|
||||
//! ```
|
||||
|
||||
mod signal_adapter;
|
||||
mod neural_adapter;
|
||||
mod hardware_adapter;
|
||||
pub mod csi_receiver;
|
||||
|
||||
pub use signal_adapter::SignalAdapter;
|
||||
pub use neural_adapter::NeuralAdapter;
|
||||
pub use hardware_adapter::HardwareAdapter;
|
||||
pub use hardware_adapter::{
|
||||
// Main adapter
|
||||
HardwareAdapter,
|
||||
// Configuration types
|
||||
HardwareConfig,
|
||||
DeviceType,
|
||||
DeviceSettings,
|
||||
AtherosDriver,
|
||||
ChannelConfig,
|
||||
Bandwidth,
|
||||
// Serial settings
|
||||
SerialSettings,
|
||||
Parity,
|
||||
FlowControl,
|
||||
// Network interface settings
|
||||
NetworkInterfaceSettings,
|
||||
AntennaConfig,
|
||||
// UDP settings
|
||||
UdpSettings,
|
||||
// PCAP settings
|
||||
PcapSettings,
|
||||
// Sensor types
|
||||
SensorInfo,
|
||||
SensorStatus,
|
||||
// CSI data types
|
||||
CsiReadings,
|
||||
CsiMetadata,
|
||||
SensorCsiReading,
|
||||
FrameControlType,
|
||||
CsiStream,
|
||||
// Health and stats
|
||||
HardwareHealth,
|
||||
HealthStatus,
|
||||
StreamingStats,
|
||||
};
|
||||
|
||||
pub use csi_receiver::{
|
||||
// Receiver types
|
||||
UdpCsiReceiver,
|
||||
SerialCsiReceiver,
|
||||
PcapCsiReader,
|
||||
// Configuration
|
||||
ReceiverConfig,
|
||||
CsiSource,
|
||||
UdpSourceConfig,
|
||||
SerialSourceConfig,
|
||||
PcapSourceConfig,
|
||||
SerialParity,
|
||||
// Packet types
|
||||
CsiPacket,
|
||||
CsiPacketMetadata,
|
||||
CsiPacketFormat,
|
||||
// Parser
|
||||
CsiParser,
|
||||
// Stats
|
||||
ReceiverStats,
|
||||
};
|
||||
|
||||
/// Configuration for integration layer
|
||||
#[derive(Debug, Clone, Default)]
|
||||
@@ -22,6 +110,40 @@ pub struct IntegrationConfig {
|
||||
pub batch_size: usize,
|
||||
/// Enable signal preprocessing optimizations
|
||||
pub optimize_signal: bool,
|
||||
/// Hardware configuration
|
||||
pub hardware: Option<HardwareConfig>,
|
||||
}
|
||||
|
||||
impl IntegrationConfig {
|
||||
/// Create configuration for real-time processing
|
||||
pub fn realtime() -> Self {
|
||||
Self {
|
||||
use_gpu: true,
|
||||
batch_size: 1,
|
||||
optimize_signal: true,
|
||||
hardware: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create configuration for batch processing
|
||||
pub fn batch(batch_size: usize) -> Self {
|
||||
Self {
|
||||
use_gpu: true,
|
||||
batch_size,
|
||||
optimize_signal: true,
|
||||
hardware: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create configuration with specific hardware
|
||||
pub fn with_hardware(hardware: HardwareConfig) -> Self {
|
||||
Self {
|
||||
use_gpu: true,
|
||||
batch_size: 1,
|
||||
optimize_signal: true,
|
||||
hardware: Some(hardware),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error type for integration layer
|
||||
@@ -46,4 +168,68 @@ pub enum AdapterError {
|
||||
/// Data format error
|
||||
#[error("Data format error: {0}")]
|
||||
DataFormat(String),
|
||||
|
||||
/// I/O error
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// Timeout error
|
||||
#[error("Timeout error: {0}")]
|
||||
Timeout(String),
|
||||
}
|
||||
|
||||
/// Prelude module for convenient imports
|
||||
pub mod prelude {
|
||||
pub use super::{
|
||||
AdapterError,
|
||||
HardwareAdapter,
|
||||
HardwareConfig,
|
||||
DeviceType,
|
||||
AtherosDriver,
|
||||
Bandwidth,
|
||||
CsiReadings,
|
||||
CsiPacket,
|
||||
CsiPacketFormat,
|
||||
IntegrationConfig,
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_integration_config_defaults() {
|
||||
let config = IntegrationConfig::default();
|
||||
assert!(!config.use_gpu);
|
||||
assert_eq!(config.batch_size, 0);
|
||||
assert!(!config.optimize_signal);
|
||||
assert!(config.hardware.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_integration_config_realtime() {
|
||||
let config = IntegrationConfig::realtime();
|
||||
assert!(config.use_gpu);
|
||||
assert_eq!(config.batch_size, 1);
|
||||
assert!(config.optimize_signal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_integration_config_batch() {
|
||||
let config = IntegrationConfig::batch(32);
|
||||
assert!(config.use_gpu);
|
||||
assert_eq!(config.batch_size, 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_integration_config_with_hardware() {
|
||||
let hw_config = HardwareConfig::esp32("/dev/ttyUSB0", 921600);
|
||||
let config = IntegrationConfig::with_hardware(hw_config);
|
||||
assert!(config.hardware.is_some());
|
||||
assert!(matches!(
|
||||
config.hardware.as_ref().unwrap().device_type,
|
||||
DeviceType::Esp32
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,10 +78,12 @@
|
||||
#![warn(rustdoc::missing_crate_level_docs)]
|
||||
|
||||
pub mod alerting;
|
||||
pub mod api;
|
||||
pub mod detection;
|
||||
pub mod domain;
|
||||
pub mod integration;
|
||||
pub mod localization;
|
||||
pub mod ml;
|
||||
|
||||
// Re-export main types
|
||||
pub use domain::{
|
||||
@@ -121,6 +123,23 @@ pub use integration::{
|
||||
AdapterError, IntegrationConfig,
|
||||
};
|
||||
|
||||
pub use api::{
|
||||
create_router, AppState,
|
||||
};
|
||||
|
||||
pub use ml::{
|
||||
// Core ML types
|
||||
MlError, MlResult, MlDetectionConfig, MlDetectionPipeline, MlDetectionResult,
|
||||
// Debris penetration model
|
||||
DebrisPenetrationModel, DebrisFeatures, DepthEstimate as MlDepthEstimate,
|
||||
DebrisModel, DebrisModelConfig, DebrisFeatureExtractor,
|
||||
MaterialType, DebrisClassification, AttenuationPrediction,
|
||||
// Vital signs classifier
|
||||
VitalSignsClassifier, VitalSignsClassifierConfig,
|
||||
BreathingClassification, HeartbeatClassification,
|
||||
UncertaintyEstimate, ClassifierOutput,
|
||||
};
|
||||
|
||||
/// Library version
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
@@ -165,6 +184,10 @@ pub enum MatError {
|
||||
/// I/O error
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// Machine learning error
|
||||
#[error("ML error: {0}")]
|
||||
Ml(#[from] ml::MlError),
|
||||
}
|
||||
|
||||
/// Configuration for the disaster response system
|
||||
@@ -417,6 +440,10 @@ pub mod prelude {
|
||||
LocalizationService,
|
||||
// Alerting
|
||||
AlertDispatcher,
|
||||
// ML types
|
||||
MlDetectionConfig, MlDetectionPipeline, MlDetectionResult,
|
||||
DebrisModel, MaterialType, DebrisClassification,
|
||||
VitalSignsClassifier, UncertaintyEstimate,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,765 @@
|
||||
//! ONNX-based debris penetration model for material classification and depth prediction.
|
||||
//!
|
||||
//! This module provides neural network models for analyzing debris characteristics
|
||||
//! from WiFi CSI signals. Key capabilities include:
|
||||
//!
|
||||
//! - Material type classification (concrete, wood, metal, etc.)
|
||||
//! - Signal attenuation prediction based on material properties
|
||||
//! - Penetration depth estimation with uncertainty quantification
|
||||
//!
|
||||
//! ## Model Architecture
|
||||
//!
|
||||
//! The debris model uses a multi-head architecture:
|
||||
//! - Shared feature encoder (CNN-based)
|
||||
//! - Material classification head (softmax output)
|
||||
//! - Attenuation regression head (linear output)
|
||||
//! - Depth estimation head with uncertainty (mean + variance output)
|
||||
|
||||
use super::{DebrisFeatures, DepthEstimate, MlError, MlResult};
|
||||
use ndarray::{Array1, Array2, Array4, s};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::RwLock;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
||||
|
||||
/// Errors specific to debris model operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DebrisModelError {
|
||||
/// Model file not found
|
||||
#[error("Model file not found: {0}")]
|
||||
FileNotFound(String),
|
||||
|
||||
/// Invalid model format
|
||||
#[error("Invalid model format: {0}")]
|
||||
InvalidFormat(String),
|
||||
|
||||
/// Inference error
|
||||
#[error("Inference failed: {0}")]
|
||||
InferenceFailed(String),
|
||||
|
||||
/// Feature extraction error
|
||||
#[error("Feature extraction failed: {0}")]
|
||||
FeatureExtractionFailed(String),
|
||||
}
|
||||
|
||||
/// Types of materials that can be detected in debris
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum MaterialType {
|
||||
/// Reinforced concrete (high attenuation)
|
||||
Concrete,
|
||||
/// Wood/timber (moderate attenuation)
|
||||
Wood,
|
||||
/// Metal/steel (very high attenuation, reflective)
|
||||
Metal,
|
||||
/// Glass (low attenuation)
|
||||
Glass,
|
||||
/// Brick/masonry (high attenuation)
|
||||
Brick,
|
||||
/// Drywall/plasterboard (low attenuation)
|
||||
Drywall,
|
||||
/// Mixed/composite materials
|
||||
Mixed,
|
||||
/// Unknown material type
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl MaterialType {
|
||||
/// Get typical attenuation coefficient (dB/m)
|
||||
pub fn typical_attenuation(&self) -> f32 {
|
||||
match self {
|
||||
MaterialType::Concrete => 25.0,
|
||||
MaterialType::Wood => 8.0,
|
||||
MaterialType::Metal => 50.0,
|
||||
MaterialType::Glass => 3.0,
|
||||
MaterialType::Brick => 18.0,
|
||||
MaterialType::Drywall => 4.0,
|
||||
MaterialType::Mixed => 15.0,
|
||||
MaterialType::Unknown => 12.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get typical delay spread (nanoseconds)
|
||||
pub fn typical_delay_spread(&self) -> f32 {
|
||||
match self {
|
||||
MaterialType::Concrete => 150.0,
|
||||
MaterialType::Wood => 50.0,
|
||||
MaterialType::Metal => 200.0,
|
||||
MaterialType::Glass => 20.0,
|
||||
MaterialType::Brick => 100.0,
|
||||
MaterialType::Drywall => 30.0,
|
||||
MaterialType::Mixed => 80.0,
|
||||
MaterialType::Unknown => 60.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// From class index
|
||||
pub fn from_index(index: usize) -> Self {
|
||||
match index {
|
||||
0 => MaterialType::Concrete,
|
||||
1 => MaterialType::Wood,
|
||||
2 => MaterialType::Metal,
|
||||
3 => MaterialType::Glass,
|
||||
4 => MaterialType::Brick,
|
||||
5 => MaterialType::Drywall,
|
||||
6 => MaterialType::Mixed,
|
||||
_ => MaterialType::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
/// To class index
|
||||
pub fn to_index(&self) -> usize {
|
||||
match self {
|
||||
MaterialType::Concrete => 0,
|
||||
MaterialType::Wood => 1,
|
||||
MaterialType::Metal => 2,
|
||||
MaterialType::Glass => 3,
|
||||
MaterialType::Brick => 4,
|
||||
MaterialType::Drywall => 5,
|
||||
MaterialType::Mixed => 6,
|
||||
MaterialType::Unknown => 7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of material classes
|
||||
pub const NUM_CLASSES: usize = 8;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MaterialType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MaterialType::Concrete => write!(f, "Concrete"),
|
||||
MaterialType::Wood => write!(f, "Wood"),
|
||||
MaterialType::Metal => write!(f, "Metal"),
|
||||
MaterialType::Glass => write!(f, "Glass"),
|
||||
MaterialType::Brick => write!(f, "Brick"),
|
||||
MaterialType::Drywall => write!(f, "Drywall"),
|
||||
MaterialType::Mixed => write!(f, "Mixed"),
|
||||
MaterialType::Unknown => write!(f, "Unknown"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of debris material classification
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DebrisClassification {
|
||||
/// Primary material type detected
|
||||
pub material_type: MaterialType,
|
||||
/// Confidence score for the classification (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// Per-class probabilities
|
||||
pub class_probabilities: Vec<f32>,
|
||||
/// Estimated layer count
|
||||
pub estimated_layers: u8,
|
||||
/// Whether multiple materials detected
|
||||
pub is_composite: bool,
|
||||
}
|
||||
|
||||
impl DebrisClassification {
|
||||
/// Create a new debris classification
|
||||
pub fn new(probabilities: Vec<f32>) -> Self {
|
||||
let (max_idx, &max_prob) = probabilities.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.unwrap_or((7, &0.0));
|
||||
|
||||
// Check for composite materials (multiple high probabilities)
|
||||
let high_prob_count = probabilities.iter()
|
||||
.filter(|&&p| p > 0.2)
|
||||
.count();
|
||||
|
||||
let is_composite = high_prob_count > 1 && max_prob < 0.7;
|
||||
let material_type = if is_composite {
|
||||
MaterialType::Mixed
|
||||
} else {
|
||||
MaterialType::from_index(max_idx)
|
||||
};
|
||||
|
||||
// Estimate layer count from delay spread characteristics
|
||||
let estimated_layers = Self::estimate_layers(&probabilities);
|
||||
|
||||
Self {
|
||||
material_type,
|
||||
confidence: max_prob,
|
||||
class_probabilities: probabilities,
|
||||
estimated_layers,
|
||||
is_composite,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate number of debris layers from probability distribution
|
||||
fn estimate_layers(probabilities: &[f32]) -> u8 {
|
||||
// More uniform distribution suggests more layers
|
||||
let entropy: f32 = probabilities.iter()
|
||||
.filter(|&&p| p > 0.01)
|
||||
.map(|&p| -p * p.ln())
|
||||
.sum();
|
||||
|
||||
let max_entropy = (probabilities.len() as f32).ln();
|
||||
let normalized_entropy = entropy / max_entropy;
|
||||
|
||||
// Map entropy to layer count (1-5)
|
||||
(1.0 + normalized_entropy * 4.0).round() as u8
|
||||
}
|
||||
|
||||
/// Get secondary material if composite
|
||||
pub fn secondary_material(&self) -> Option<MaterialType> {
|
||||
if !self.is_composite {
|
||||
return None;
|
||||
}
|
||||
|
||||
let primary_idx = self.material_type.to_index();
|
||||
self.class_probabilities.iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| *i != primary_idx)
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(i, _)| MaterialType::from_index(i))
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal attenuation prediction result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttenuationPrediction {
|
||||
/// Predicted attenuation in dB
|
||||
pub attenuation_db: f32,
|
||||
/// Attenuation per meter (dB/m)
|
||||
pub attenuation_per_meter: f32,
|
||||
/// Uncertainty in the prediction
|
||||
pub uncertainty_db: f32,
|
||||
/// Frequency-dependent attenuation profile
|
||||
pub frequency_profile: Vec<f32>,
|
||||
/// Confidence in the prediction
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
impl AttenuationPrediction {
|
||||
/// Create new attenuation prediction
|
||||
pub fn new(attenuation: f32, depth: f32, uncertainty: f32) -> Self {
|
||||
let attenuation_per_meter = if depth > 0.0 {
|
||||
attenuation / depth
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Self {
|
||||
attenuation_db: attenuation,
|
||||
attenuation_per_meter,
|
||||
uncertainty_db: uncertainty,
|
||||
frequency_profile: vec![],
|
||||
confidence: (1.0 - uncertainty / attenuation.abs().max(1.0)).max(0.0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Predict signal at given depth
|
||||
pub fn predict_signal_at_depth(&self, depth_m: f32) -> f32 {
|
||||
-self.attenuation_per_meter * depth_m
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for debris model
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DebrisModelConfig {
|
||||
/// Use GPU for inference
|
||||
pub use_gpu: bool,
|
||||
/// Number of inference threads
|
||||
pub num_threads: usize,
|
||||
/// Minimum confidence threshold
|
||||
pub confidence_threshold: f32,
|
||||
}
|
||||
|
||||
impl Default for DebrisModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
use_gpu: false,
|
||||
num_threads: 4,
|
||||
confidence_threshold: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Feature extractor for debris classification
|
||||
pub struct DebrisFeatureExtractor {
|
||||
/// Number of subcarriers to analyze
|
||||
num_subcarriers: usize,
|
||||
/// Window size for temporal analysis
|
||||
window_size: usize,
|
||||
/// Whether to use advanced features
|
||||
use_advanced_features: bool,
|
||||
}
|
||||
|
||||
impl Default for DebrisFeatureExtractor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_subcarriers: 64,
|
||||
window_size: 100,
|
||||
use_advanced_features: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DebrisFeatureExtractor {
|
||||
/// Create new feature extractor
|
||||
pub fn new(num_subcarriers: usize, window_size: usize) -> Self {
|
||||
Self {
|
||||
num_subcarriers,
|
||||
window_size,
|
||||
use_advanced_features: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract features from debris features for model input
|
||||
pub fn extract(&self, features: &DebrisFeatures) -> MlResult<Array2<f32>> {
|
||||
let feature_vector = features.to_feature_vector();
|
||||
|
||||
// Reshape to 2D for model input (batch_size=1, features)
|
||||
let arr = Array2::from_shape_vec(
|
||||
(1, feature_vector.len()),
|
||||
feature_vector,
|
||||
).map_err(|e| MlError::FeatureExtraction(e.to_string()))?;
|
||||
|
||||
Ok(arr)
|
||||
}
|
||||
|
||||
/// Extract spatial-temporal features for CNN input
|
||||
pub fn extract_spatial_temporal(&self, features: &DebrisFeatures) -> MlResult<Array4<f32>> {
|
||||
let amp_len = features.amplitude_attenuation.len().min(self.num_subcarriers);
|
||||
let phase_len = features.phase_shifts.len().min(self.num_subcarriers);
|
||||
|
||||
// Create 4D tensor: [batch, channels, height, width]
|
||||
// channels: amplitude, phase
|
||||
// height: subcarriers
|
||||
// width: 1 (or temporal windows if available)
|
||||
let mut tensor = Array4::<f32>::zeros((1, 2, self.num_subcarriers, 1));
|
||||
|
||||
// Fill amplitude channel
|
||||
for (i, &v) in features.amplitude_attenuation.iter().take(amp_len).enumerate() {
|
||||
tensor[[0, 0, i, 0]] = v;
|
||||
}
|
||||
|
||||
// Fill phase channel
|
||||
for (i, &v) in features.phase_shifts.iter().take(phase_len).enumerate() {
|
||||
tensor[[0, 1, i, 0]] = v;
|
||||
}
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
/// ONNX-based debris penetration model
|
||||
pub struct DebrisModel {
|
||||
config: DebrisModelConfig,
|
||||
feature_extractor: DebrisFeatureExtractor,
|
||||
/// Material classification model weights (for rule-based fallback)
|
||||
material_weights: MaterialClassificationWeights,
|
||||
/// Whether ONNX model is loaded
|
||||
model_loaded: bool,
|
||||
/// Cached model session
|
||||
#[cfg(feature = "onnx")]
|
||||
session: Option<Arc<RwLock<OnnxSession>>>,
|
||||
}
|
||||
|
||||
/// Pre-computed weights for rule-based material classification
|
||||
struct MaterialClassificationWeights {
|
||||
/// Weights for attenuation features
|
||||
attenuation_weights: [f32; MaterialType::NUM_CLASSES],
|
||||
/// Weights for delay spread features
|
||||
delay_weights: [f32; MaterialType::NUM_CLASSES],
|
||||
/// Weights for coherence bandwidth
|
||||
coherence_weights: [f32; MaterialType::NUM_CLASSES],
|
||||
/// Bias terms
|
||||
biases: [f32; MaterialType::NUM_CLASSES],
|
||||
}
|
||||
|
||||
impl Default for MaterialClassificationWeights {
|
||||
fn default() -> Self {
|
||||
// Pre-computed weights based on material RF properties
|
||||
Self {
|
||||
attenuation_weights: [0.8, 0.3, 0.95, 0.1, 0.6, 0.15, 0.5, 0.4],
|
||||
delay_weights: [0.7, 0.2, 0.9, 0.1, 0.5, 0.1, 0.4, 0.3],
|
||||
coherence_weights: [0.3, 0.7, 0.1, 0.9, 0.4, 0.8, 0.5, 0.5],
|
||||
biases: [-0.5, 0.2, -0.8, 0.5, -0.3, 0.3, 0.0, 0.0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DebrisModel {
|
||||
/// Create a new debris model from ONNX file
|
||||
#[instrument(skip(path))]
|
||||
pub fn from_onnx<P: AsRef<Path>>(path: P, config: DebrisModelConfig) -> MlResult<Self> {
|
||||
let path_ref = path.as_ref();
|
||||
info!(?path_ref, "Loading debris model");
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
let session = if path_ref.exists() {
|
||||
let options = InferenceOptions {
|
||||
use_gpu: config.use_gpu,
|
||||
num_threads: config.num_threads,
|
||||
..Default::default()
|
||||
};
|
||||
match OnnxSession::from_file(path_ref, &options) {
|
||||
Ok(s) => {
|
||||
info!("ONNX debris model loaded successfully");
|
||||
Some(Arc::new(RwLock::new(s)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(?e, "Failed to load ONNX model, using rule-based fallback");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(?path_ref, "Model file not found, using rule-based fallback");
|
||||
None
|
||||
};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
let model_loaded = session.is_some();
|
||||
|
||||
#[cfg(not(feature = "onnx"))]
|
||||
let model_loaded = false;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
feature_extractor: DebrisFeatureExtractor::default(),
|
||||
material_weights: MaterialClassificationWeights::default(),
|
||||
model_loaded,
|
||||
#[cfg(feature = "onnx")]
|
||||
session,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with in-memory model bytes
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn from_bytes(bytes: &[u8], config: DebrisModelConfig) -> MlResult<Self> {
|
||||
let options = InferenceOptions {
|
||||
use_gpu: config.use_gpu,
|
||||
num_threads: config.num_threads,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let session = OnnxSession::from_bytes(bytes, &options)
|
||||
.map_err(|e| MlError::ModelLoad(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
feature_extractor: DebrisFeatureExtractor::default(),
|
||||
material_weights: MaterialClassificationWeights::default(),
|
||||
model_loaded: true,
|
||||
session: Some(Arc::new(RwLock::new(session))),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a rule-based model (no ONNX required)
|
||||
pub fn rule_based(config: DebrisModelConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
feature_extractor: DebrisFeatureExtractor::default(),
|
||||
material_weights: MaterialClassificationWeights::default(),
|
||||
model_loaded: false,
|
||||
#[cfg(feature = "onnx")]
|
||||
session: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if ONNX model is loaded
|
||||
pub fn is_loaded(&self) -> bool {
|
||||
self.model_loaded
|
||||
}
|
||||
|
||||
/// Classify material type from debris features
|
||||
#[instrument(skip(self, features))]
|
||||
pub async fn classify(&self, features: &DebrisFeatures) -> MlResult<DebrisClassification> {
|
||||
#[cfg(feature = "onnx")]
|
||||
if let Some(ref session) = self.session {
|
||||
return self.classify_onnx(features, session).await;
|
||||
}
|
||||
|
||||
// Fall back to rule-based classification
|
||||
self.classify_rules(features)
|
||||
}
|
||||
|
||||
/// ONNX-based classification
|
||||
#[cfg(feature = "onnx")]
|
||||
async fn classify_onnx(
|
||||
&self,
|
||||
features: &DebrisFeatures,
|
||||
session: &Arc<RwLock<OnnxSession>>,
|
||||
) -> MlResult<DebrisClassification> {
|
||||
let input_features = self.feature_extractor.extract(features)?;
|
||||
|
||||
// Prepare input tensor
|
||||
let input_array = Array4::from_shape_vec(
|
||||
(1, 1, 1, input_features.len()),
|
||||
input_features.iter().cloned().collect(),
|
||||
).map_err(|e| MlError::Inference(e.to_string()))?;
|
||||
|
||||
let input_tensor = Tensor::Float4D(input_array);
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("input".to_string(), input_tensor);
|
||||
|
||||
// Run inference
|
||||
let outputs = session.write().run(inputs)
|
||||
.map_err(|e| MlError::NeuralNetwork(e))?;
|
||||
|
||||
// Extract classification probabilities
|
||||
let probabilities = if let Some(output) = outputs.get("material_probs") {
|
||||
output.to_vec()
|
||||
.map_err(|e| MlError::Inference(e.to_string()))?
|
||||
} else {
|
||||
// Fallback to rule-based
|
||||
return self.classify_rules(features);
|
||||
};
|
||||
|
||||
// Ensure we have enough classes
|
||||
let mut probs = vec![0.0f32; MaterialType::NUM_CLASSES];
|
||||
for (i, &p) in probabilities.iter().take(MaterialType::NUM_CLASSES).enumerate() {
|
||||
probs[i] = p;
|
||||
}
|
||||
|
||||
// Apply softmax normalization
|
||||
let max_val = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = probs.iter().map(|&x| (x - max_val).exp()).sum();
|
||||
for p in &mut probs {
|
||||
*p = (*p - max_val).exp() / exp_sum;
|
||||
}
|
||||
|
||||
Ok(DebrisClassification::new(probs))
|
||||
}
|
||||
|
||||
/// Rule-based material classification (fallback)
|
||||
fn classify_rules(&self, features: &DebrisFeatures) -> MlResult<DebrisClassification> {
|
||||
let mut scores = [0.0f32; MaterialType::NUM_CLASSES];
|
||||
|
||||
// Normalize input features
|
||||
let attenuation_score = (features.snr_db.abs() / 30.0).min(1.0);
|
||||
let delay_score = (features.delay_spread / 200.0).min(1.0);
|
||||
let coherence_score = (features.coherence_bandwidth / 20.0).min(1.0);
|
||||
let stability_score = features.temporal_stability;
|
||||
|
||||
// Compute weighted scores for each material
|
||||
for i in 0..MaterialType::NUM_CLASSES {
|
||||
scores[i] = self.material_weights.attenuation_weights[i] * attenuation_score
|
||||
+ self.material_weights.delay_weights[i] * delay_score
|
||||
+ self.material_weights.coherence_weights[i] * (1.0 - coherence_score)
|
||||
+ self.material_weights.biases[i]
|
||||
+ 0.1 * stability_score;
|
||||
}
|
||||
|
||||
// Apply softmax
|
||||
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = scores.iter().map(|&s| (s - max_score).exp()).sum();
|
||||
let probabilities: Vec<f32> = scores.iter()
|
||||
.map(|&s| (s - max_score).exp() / exp_sum)
|
||||
.collect();
|
||||
|
||||
Ok(DebrisClassification::new(probabilities))
|
||||
}
|
||||
|
||||
/// Predict signal attenuation through debris
|
||||
#[instrument(skip(self, features))]
|
||||
pub async fn predict_attenuation(&self, features: &DebrisFeatures) -> MlResult<AttenuationPrediction> {
|
||||
// Get material classification first
|
||||
let classification = self.classify(features).await?;
|
||||
|
||||
// Base attenuation from material type
|
||||
let base_attenuation = classification.material_type.typical_attenuation();
|
||||
|
||||
// Adjust based on measured features
|
||||
let measured_factor = if features.snr_db < 0.0 {
|
||||
1.0 + (features.snr_db.abs() / 30.0).min(1.0)
|
||||
} else {
|
||||
1.0 - (features.snr_db / 30.0).min(0.5)
|
||||
};
|
||||
|
||||
// Layer factor
|
||||
let layer_factor = 1.0 + 0.2 * (classification.estimated_layers as f32 - 1.0);
|
||||
|
||||
// Composite factor
|
||||
let composite_factor = if classification.is_composite { 1.2 } else { 1.0 };
|
||||
|
||||
let total_attenuation = base_attenuation * measured_factor * layer_factor * composite_factor;
|
||||
|
||||
// Uncertainty estimation
|
||||
let uncertainty = if classification.is_composite {
|
||||
total_attenuation * 0.3 // Higher uncertainty for composite
|
||||
} else {
|
||||
total_attenuation * (1.0 - classification.confidence) * 0.5
|
||||
};
|
||||
|
||||
// Estimate depth (will be refined by depth estimation)
|
||||
let estimated_depth = self.estimate_depth_internal(features, total_attenuation);
|
||||
|
||||
Ok(AttenuationPrediction::new(total_attenuation, estimated_depth, uncertainty))
|
||||
}
|
||||
|
||||
/// Estimate penetration depth
|
||||
#[instrument(skip(self, features))]
|
||||
pub async fn estimate_depth(&self, features: &DebrisFeatures) -> MlResult<DepthEstimate> {
|
||||
// Get attenuation prediction
|
||||
let attenuation = self.predict_attenuation(features).await?;
|
||||
|
||||
// Estimate depth from attenuation and material properties
|
||||
let depth = self.estimate_depth_internal(features, attenuation.attenuation_db);
|
||||
|
||||
// Calculate uncertainty
|
||||
let uncertainty = self.calculate_depth_uncertainty(
|
||||
features,
|
||||
depth,
|
||||
attenuation.confidence,
|
||||
);
|
||||
|
||||
let confidence = (attenuation.confidence * features.temporal_stability).min(1.0);
|
||||
|
||||
Ok(DepthEstimate::new(depth, uncertainty, confidence))
|
||||
}
|
||||
|
||||
/// Internal depth estimation logic
|
||||
fn estimate_depth_internal(&self, features: &DebrisFeatures, attenuation_db: f32) -> f32 {
|
||||
// Use coherence bandwidth for depth estimation
|
||||
// Smaller coherence bandwidth suggests more multipath = deeper penetration
|
||||
let cb_depth = (20.0 - features.coherence_bandwidth) / 5.0;
|
||||
|
||||
// Use delay spread
|
||||
let ds_depth = features.delay_spread / 100.0;
|
||||
|
||||
// Use attenuation (assuming typical material)
|
||||
let att_depth = attenuation_db / 15.0;
|
||||
|
||||
// Combine estimates with weights
|
||||
let depth = 0.3 * cb_depth + 0.3 * ds_depth + 0.4 * att_depth;
|
||||
|
||||
// Clamp to reasonable range (0.1 - 10 meters)
|
||||
depth.clamp(0.1, 10.0)
|
||||
}
|
||||
|
||||
/// Calculate uncertainty in depth estimate
|
||||
fn calculate_depth_uncertainty(
|
||||
&self,
|
||||
features: &DebrisFeatures,
|
||||
depth: f32,
|
||||
confidence: f32,
|
||||
) -> f32 {
|
||||
// Base uncertainty proportional to depth
|
||||
let base_uncertainty = depth * 0.2;
|
||||
|
||||
// Adjust by temporal stability (less stable = more uncertain)
|
||||
let stability_factor = 1.0 + (1.0 - features.temporal_stability) * 0.5;
|
||||
|
||||
// Adjust by confidence (lower confidence = more uncertain)
|
||||
let confidence_factor = 1.0 + (1.0 - confidence) * 0.5;
|
||||
|
||||
// Adjust by multipath richness (more multipath = harder to estimate)
|
||||
let multipath_factor = 1.0 + features.multipath_richness * 0.3;
|
||||
|
||||
base_uncertainty * stability_factor * confidence_factor * multipath_factor
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::detection::CsiDataBuffer;
|
||||
|
||||
fn create_test_debris_features() -> DebrisFeatures {
|
||||
DebrisFeatures {
|
||||
amplitude_attenuation: vec![0.5; 64],
|
||||
phase_shifts: vec![0.1; 64],
|
||||
fading_profile: vec![0.8, 0.6, 0.4, 0.2, 0.1, 0.05, 0.02, 0.01],
|
||||
coherence_bandwidth: 5.0,
|
||||
delay_spread: 100.0,
|
||||
snr_db: 15.0,
|
||||
multipath_richness: 0.6,
|
||||
temporal_stability: 0.8,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_material_type() {
|
||||
assert_eq!(MaterialType::from_index(0), MaterialType::Concrete);
|
||||
assert_eq!(MaterialType::Concrete.to_index(), 0);
|
||||
assert!(MaterialType::Concrete.typical_attenuation() > MaterialType::Glass.typical_attenuation());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debris_classification() {
|
||||
let probs = vec![0.7, 0.1, 0.05, 0.05, 0.05, 0.02, 0.02, 0.01];
|
||||
let classification = DebrisClassification::new(probs);
|
||||
|
||||
assert_eq!(classification.material_type, MaterialType::Concrete);
|
||||
assert!(classification.confidence > 0.6);
|
||||
assert!(!classification.is_composite);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_composite_detection() {
|
||||
let probs = vec![0.4, 0.35, 0.1, 0.05, 0.05, 0.02, 0.02, 0.01];
|
||||
let classification = DebrisClassification::new(probs);
|
||||
|
||||
assert!(classification.is_composite);
|
||||
assert_eq!(classification.material_type, MaterialType::Mixed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attenuation_prediction() {
|
||||
let pred = AttenuationPrediction::new(25.0, 2.0, 3.0);
|
||||
assert_eq!(pred.attenuation_per_meter, 12.5);
|
||||
assert!(pred.confidence > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rule_based_classification() {
|
||||
let config = DebrisModelConfig::default();
|
||||
let model = DebrisModel::rule_based(config);
|
||||
|
||||
let features = create_test_debris_features();
|
||||
let result = model.classify(&features).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let classification = result.unwrap();
|
||||
assert!(classification.confidence > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_depth_estimation() {
|
||||
let config = DebrisModelConfig::default();
|
||||
let model = DebrisModel::rule_based(config);
|
||||
|
||||
let features = create_test_debris_features();
|
||||
let result = model.estimate_depth(&features).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let estimate = result.unwrap();
|
||||
assert!(estimate.depth_meters > 0.0);
|
||||
assert!(estimate.depth_meters < 10.0);
|
||||
assert!(estimate.uncertainty_meters > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feature_extractor() {
|
||||
let extractor = DebrisFeatureExtractor::default();
|
||||
let features = create_test_debris_features();
|
||||
|
||||
let result = extractor.extract(&features);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let arr = result.unwrap();
|
||||
assert_eq!(arr.shape()[0], 1);
|
||||
assert_eq!(arr.shape()[1], 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spatial_temporal_extraction() {
|
||||
let extractor = DebrisFeatureExtractor::new(64, 100);
|
||||
let features = create_test_debris_features();
|
||||
|
||||
let result = extractor.extract_spatial_temporal(&features);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let arr = result.unwrap();
|
||||
assert_eq!(arr.shape(), &[1, 2, 64, 1]);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,692 @@
|
||||
//! Machine Learning module for debris penetration pattern recognition.
|
||||
//!
|
||||
//! This module provides ML-based models for:
|
||||
//! - Debris material classification
|
||||
//! - Penetration depth prediction
|
||||
//! - Signal attenuation analysis
|
||||
//! - Vital signs classification with uncertainty estimation
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The ML subsystem integrates with the `wifi-densepose-nn` crate for ONNX inference
|
||||
//! and provides specialized models for disaster response scenarios.
|
||||
//!
|
||||
//! ```text
|
||||
//! CSI Data -> Feature Extraction -> Model Inference -> Predictions
|
||||
//! | | |
|
||||
//! v v v
|
||||
//! [Debris Features] [ONNX Models] [Classifications]
|
||||
//! [Signal Features] [Neural Nets] [Confidences]
|
||||
//! ```
|
||||
|
||||
mod debris_model;
|
||||
mod vital_signs_classifier;
|
||||
|
||||
pub use debris_model::{
|
||||
DebrisModel, DebrisModelConfig, DebrisFeatureExtractor,
|
||||
MaterialType, DebrisClassification, AttenuationPrediction,
|
||||
DebrisModelError,
|
||||
};
|
||||
|
||||
pub use vital_signs_classifier::{
|
||||
VitalSignsClassifier, VitalSignsClassifierConfig,
|
||||
BreathingClassification, HeartbeatClassification,
|
||||
UncertaintyEstimate, ClassifierOutput,
|
||||
};
|
||||
|
||||
use crate::detection::CsiDataBuffer;
|
||||
use crate::domain::{VitalSignsReading, BreathingPattern, HeartbeatSignature};
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur in ML operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MlError {
|
||||
/// Model loading error
|
||||
#[error("Failed to load model: {0}")]
|
||||
ModelLoad(String),
|
||||
|
||||
/// Inference error
|
||||
#[error("Inference failed: {0}")]
|
||||
Inference(String),
|
||||
|
||||
/// Feature extraction error
|
||||
#[error("Feature extraction failed: {0}")]
|
||||
FeatureExtraction(String),
|
||||
|
||||
/// Invalid input error
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
/// Model not initialized
|
||||
#[error("Model not initialized: {0}")]
|
||||
NotInitialized(String),
|
||||
|
||||
/// Configuration error
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
/// Integration error with wifi-densepose-nn
|
||||
#[error("Neural network error: {0}")]
|
||||
NeuralNetwork(#[from] wifi_densepose_nn::NnError),
|
||||
}
|
||||
|
||||
/// Result type for ML operations
|
||||
pub type MlResult<T> = Result<T, MlError>;
|
||||
|
||||
/// Trait for debris penetration models
|
||||
///
|
||||
/// This trait defines the interface for models that can predict
|
||||
/// material type and signal attenuation through debris layers.
|
||||
#[async_trait]
|
||||
pub trait DebrisPenetrationModel: Send + Sync {
|
||||
/// Classify the material type from CSI features
|
||||
async fn classify_material(&self, features: &DebrisFeatures) -> MlResult<MaterialType>;
|
||||
|
||||
/// Predict signal attenuation through debris
|
||||
async fn predict_attenuation(&self, features: &DebrisFeatures) -> MlResult<AttenuationPrediction>;
|
||||
|
||||
/// Estimate penetration depth in meters
|
||||
async fn estimate_depth(&self, features: &DebrisFeatures) -> MlResult<DepthEstimate>;
|
||||
|
||||
/// Get model confidence for the predictions
|
||||
fn model_confidence(&self) -> f32;
|
||||
|
||||
/// Check if the model is loaded and ready
|
||||
fn is_ready(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Features extracted from CSI data for debris analysis
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DebrisFeatures {
|
||||
/// Amplitude attenuation across subcarriers
|
||||
pub amplitude_attenuation: Vec<f32>,
|
||||
/// Phase shift patterns
|
||||
pub phase_shifts: Vec<f32>,
|
||||
/// Frequency-selective fading characteristics
|
||||
pub fading_profile: Vec<f32>,
|
||||
/// Coherence bandwidth estimate
|
||||
pub coherence_bandwidth: f32,
|
||||
/// RMS delay spread
|
||||
pub delay_spread: f32,
|
||||
/// Signal-to-noise ratio estimate
|
||||
pub snr_db: f32,
|
||||
/// Multipath richness indicator
|
||||
pub multipath_richness: f32,
|
||||
/// Temporal stability metric
|
||||
pub temporal_stability: f32,
|
||||
}
|
||||
|
||||
impl DebrisFeatures {
|
||||
/// Create new debris features from raw CSI data
|
||||
pub fn from_csi(buffer: &CsiDataBuffer) -> MlResult<Self> {
|
||||
if buffer.amplitudes.is_empty() {
|
||||
return Err(MlError::FeatureExtraction("Empty CSI buffer".into()));
|
||||
}
|
||||
|
||||
// Calculate amplitude attenuation
|
||||
let amplitude_attenuation = Self::compute_amplitude_features(&buffer.amplitudes);
|
||||
|
||||
// Calculate phase shifts
|
||||
let phase_shifts = Self::compute_phase_features(&buffer.phases);
|
||||
|
||||
// Compute fading profile
|
||||
let fading_profile = Self::compute_fading_profile(&buffer.amplitudes);
|
||||
|
||||
// Estimate coherence bandwidth from frequency correlation
|
||||
let coherence_bandwidth = Self::estimate_coherence_bandwidth(&buffer.amplitudes);
|
||||
|
||||
// Estimate delay spread
|
||||
let delay_spread = Self::estimate_delay_spread(&buffer.amplitudes);
|
||||
|
||||
// Estimate SNR
|
||||
let snr_db = Self::estimate_snr(&buffer.amplitudes);
|
||||
|
||||
// Multipath richness
|
||||
let multipath_richness = Self::compute_multipath_richness(&buffer.amplitudes);
|
||||
|
||||
// Temporal stability
|
||||
let temporal_stability = Self::compute_temporal_stability(&buffer.amplitudes);
|
||||
|
||||
Ok(Self {
|
||||
amplitude_attenuation,
|
||||
phase_shifts,
|
||||
fading_profile,
|
||||
coherence_bandwidth,
|
||||
delay_spread,
|
||||
snr_db,
|
||||
multipath_richness,
|
||||
temporal_stability,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute amplitude features
|
||||
fn compute_amplitude_features(amplitudes: &[f64]) -> Vec<f32> {
|
||||
if amplitudes.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
let variance = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / amplitudes.len() as f64;
|
||||
let std_dev = variance.sqrt();
|
||||
|
||||
// Normalize amplitudes
|
||||
amplitudes.iter()
|
||||
.map(|a| ((a - mean) / (std_dev + 1e-8)) as f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute phase features
|
||||
fn compute_phase_features(phases: &[f64]) -> Vec<f32> {
|
||||
if phases.len() < 2 {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Compute phase differences (unwrapped)
|
||||
phases.windows(2)
|
||||
.map(|w| {
|
||||
let diff = w[1] - w[0];
|
||||
// Unwrap phase
|
||||
let unwrapped = if diff > std::f64::consts::PI {
|
||||
diff - 2.0 * std::f64::consts::PI
|
||||
} else if diff < -std::f64::consts::PI {
|
||||
diff + 2.0 * std::f64::consts::PI
|
||||
} else {
|
||||
diff
|
||||
};
|
||||
unwrapped as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute fading profile (power spectral characteristics)
|
||||
fn compute_fading_profile(amplitudes: &[f64]) -> Vec<f32> {
|
||||
use rustfft::{FftPlanner, num_complex::Complex};
|
||||
|
||||
if amplitudes.len() < 16 {
|
||||
return vec![0.0; 8];
|
||||
}
|
||||
|
||||
// Take a subset for FFT
|
||||
let n = 64.min(amplitudes.len());
|
||||
let mut buffer: Vec<Complex<f64>> = amplitudes.iter()
|
||||
.take(n)
|
||||
.map(|&a| Complex::new(a, 0.0))
|
||||
.collect();
|
||||
|
||||
// Pad to power of 2
|
||||
while buffer.len() < 64 {
|
||||
buffer.push(Complex::new(0.0, 0.0));
|
||||
}
|
||||
|
||||
// Compute FFT
|
||||
let mut planner = FftPlanner::new();
|
||||
let fft = planner.plan_fft_forward(64);
|
||||
fft.process(&mut buffer);
|
||||
|
||||
// Extract power spectrum (first half)
|
||||
buffer.iter()
|
||||
.take(8)
|
||||
.map(|c| (c.norm() / n as f64) as f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Estimate coherence bandwidth from frequency correlation
|
||||
fn estimate_coherence_bandwidth(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute autocorrelation
|
||||
let n = amplitudes.len();
|
||||
let mean = amplitudes.iter().sum::<f64>() / n as f64;
|
||||
let variance: f64 = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / n as f64;
|
||||
|
||||
if variance < 1e-10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Find lag where correlation drops below 0.5
|
||||
let mut coherence_lag = n;
|
||||
for lag in 1..n / 2 {
|
||||
let correlation: f64 = amplitudes.iter()
|
||||
.take(n - lag)
|
||||
.zip(amplitudes.iter().skip(lag))
|
||||
.map(|(a, b)| (a - mean) * (b - mean))
|
||||
.sum::<f64>() / ((n - lag) as f64 * variance);
|
||||
|
||||
if correlation < 0.5 {
|
||||
coherence_lag = lag;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to bandwidth estimate (assuming 20 MHz channel)
|
||||
(20.0 / coherence_lag as f32).min(20.0)
|
||||
}
|
||||
|
||||
/// Estimate RMS delay spread
|
||||
fn estimate_delay_spread(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Use power delay profile approximation
|
||||
let power: Vec<f64> = amplitudes.iter().map(|a| a.powi(2)).collect();
|
||||
let total_power: f64 = power.iter().sum();
|
||||
|
||||
if total_power < 1e-10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Calculate mean delay
|
||||
let mean_delay: f64 = power.iter()
|
||||
.enumerate()
|
||||
.map(|(i, p)| i as f64 * p)
|
||||
.sum::<f64>() / total_power;
|
||||
|
||||
// Calculate RMS delay spread
|
||||
let variance: f64 = power.iter()
|
||||
.enumerate()
|
||||
.map(|(i, p)| (i as f64 - mean_delay).powi(2) * p)
|
||||
.sum::<f64>() / total_power;
|
||||
|
||||
// Convert to nanoseconds (assuming sample period)
|
||||
(variance.sqrt() * 50.0) as f32 // 50 ns per sample assumed
|
||||
}
|
||||
|
||||
/// Estimate SNR from amplitude variance
|
||||
fn estimate_snr(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
let variance = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / amplitudes.len() as f64;
|
||||
|
||||
if variance < 1e-10 {
|
||||
return 30.0; // High SNR assumed
|
||||
}
|
||||
|
||||
// SNR estimate based on signal power to noise power ratio
|
||||
let signal_power = mean.powi(2);
|
||||
let snr_linear = signal_power / variance;
|
||||
|
||||
(10.0 * snr_linear.log10()) as f32
|
||||
}
|
||||
|
||||
/// Compute multipath richness indicator
|
||||
fn compute_multipath_richness(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Calculate amplitude variance as multipath indicator
|
||||
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
let variance = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / amplitudes.len() as f64;
|
||||
|
||||
// Normalize to 0-1 range
|
||||
let std_dev = variance.sqrt();
|
||||
let normalized = std_dev / (mean.abs() + 1e-8);
|
||||
|
||||
(normalized.min(1.0)) as f32
|
||||
}
|
||||
|
||||
/// Compute temporal stability metric
|
||||
fn compute_temporal_stability(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 2 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
// Calculate coefficient of variation over time
|
||||
let differences: Vec<f64> = amplitudes.windows(2)
|
||||
.map(|w| (w[1] - w[0]).abs())
|
||||
.collect();
|
||||
|
||||
let mean_diff = differences.iter().sum::<f64>() / differences.len() as f64;
|
||||
let mean_amp = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
|
||||
// Stability is inverse of relative variation
|
||||
let variation = mean_diff / (mean_amp.abs() + 1e-8);
|
||||
|
||||
(1.0 - variation.min(1.0)) as f32
|
||||
}
|
||||
|
||||
/// Convert to feature vector for model input
|
||||
pub fn to_feature_vector(&self) -> Vec<f32> {
|
||||
let mut features = Vec::with_capacity(256);
|
||||
|
||||
// Add amplitude attenuation features (padded/truncated to 64)
|
||||
let amp_len = self.amplitude_attenuation.len().min(64);
|
||||
features.extend_from_slice(&self.amplitude_attenuation[..amp_len]);
|
||||
features.resize(64, 0.0);
|
||||
|
||||
// Add phase shift features (padded/truncated to 64)
|
||||
let phase_len = self.phase_shifts.len().min(64);
|
||||
features.extend_from_slice(&self.phase_shifts[..phase_len]);
|
||||
features.resize(128, 0.0);
|
||||
|
||||
// Add fading profile (padded to 16)
|
||||
let fading_len = self.fading_profile.len().min(16);
|
||||
features.extend_from_slice(&self.fading_profile[..fading_len]);
|
||||
features.resize(144, 0.0);
|
||||
|
||||
// Add scalar features
|
||||
features.push(self.coherence_bandwidth);
|
||||
features.push(self.delay_spread);
|
||||
features.push(self.snr_db);
|
||||
features.push(self.multipath_richness);
|
||||
features.push(self.temporal_stability);
|
||||
|
||||
// Pad to 256 for model input
|
||||
features.resize(256, 0.0);
|
||||
|
||||
features
|
||||
}
|
||||
}
|
||||
|
||||
/// Depth estimate with uncertainty
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DepthEstimate {
|
||||
/// Estimated depth in meters
|
||||
pub depth_meters: f32,
|
||||
/// Uncertainty (standard deviation) in meters
|
||||
pub uncertainty_meters: f32,
|
||||
/// Confidence in the estimate (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// Lower bound of 95% confidence interval
|
||||
pub lower_bound: f32,
|
||||
/// Upper bound of 95% confidence interval
|
||||
pub upper_bound: f32,
|
||||
}
|
||||
|
||||
impl DepthEstimate {
|
||||
/// Create a new depth estimate with uncertainty
|
||||
pub fn new(depth: f32, uncertainty: f32, confidence: f32) -> Self {
|
||||
Self {
|
||||
depth_meters: depth,
|
||||
uncertainty_meters: uncertainty,
|
||||
confidence,
|
||||
lower_bound: (depth - 1.96 * uncertainty).max(0.0),
|
||||
upper_bound: depth + 1.96 * uncertainty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the estimate is reliable (high confidence, low uncertainty)
|
||||
pub fn is_reliable(&self) -> bool {
|
||||
self.confidence > 0.7 && self.uncertainty_meters < self.depth_meters * 0.3
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the ML-enhanced detection pipeline
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct MlDetectionConfig {
|
||||
/// Enable ML-based debris classification
|
||||
pub enable_debris_classification: bool,
|
||||
/// Enable ML-based vital signs classification
|
||||
pub enable_vital_classification: bool,
|
||||
/// Path to debris model file
|
||||
pub debris_model_path: Option<String>,
|
||||
/// Path to vital signs model file
|
||||
pub vital_model_path: Option<String>,
|
||||
/// Minimum confidence threshold for ML predictions
|
||||
pub min_confidence: f32,
|
||||
/// Use GPU for inference
|
||||
pub use_gpu: bool,
|
||||
/// Number of inference threads
|
||||
pub num_threads: usize,
|
||||
}
|
||||
|
||||
impl Default for MlDetectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable_debris_classification: false,
|
||||
enable_vital_classification: false,
|
||||
debris_model_path: None,
|
||||
vital_model_path: None,
|
||||
min_confidence: 0.5,
|
||||
use_gpu: false,
|
||||
num_threads: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MlDetectionConfig {
|
||||
/// Create configuration for CPU inference
|
||||
pub fn cpu() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create configuration for GPU inference
|
||||
pub fn gpu() -> Self {
|
||||
Self {
|
||||
use_gpu: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable debris classification with model path
|
||||
pub fn with_debris_model<P: Into<String>>(mut self, path: P) -> Self {
|
||||
self.debris_model_path = Some(path.into());
|
||||
self.enable_debris_classification = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable vital signs classification with model path
|
||||
pub fn with_vital_model<P: Into<String>>(mut self, path: P) -> Self {
|
||||
self.vital_model_path = Some(path.into());
|
||||
self.enable_vital_classification = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set minimum confidence threshold
|
||||
pub fn with_min_confidence(mut self, confidence: f32) -> Self {
|
||||
self.min_confidence = confidence.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// ML-enhanced detection pipeline that combines traditional and ML-based detection
|
||||
pub struct MlDetectionPipeline {
|
||||
config: MlDetectionConfig,
|
||||
debris_model: Option<DebrisModel>,
|
||||
vital_classifier: Option<VitalSignsClassifier>,
|
||||
}
|
||||
|
||||
impl MlDetectionPipeline {
|
||||
/// Create a new ML detection pipeline
|
||||
pub fn new(config: MlDetectionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
debris_model: None,
|
||||
vital_classifier: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize models asynchronously
|
||||
pub async fn initialize(&mut self) -> MlResult<()> {
|
||||
if self.config.enable_debris_classification {
|
||||
if let Some(ref path) = self.config.debris_model_path {
|
||||
let debris_config = DebrisModelConfig {
|
||||
use_gpu: self.config.use_gpu,
|
||||
num_threads: self.config.num_threads,
|
||||
confidence_threshold: self.config.min_confidence,
|
||||
};
|
||||
self.debris_model = Some(DebrisModel::from_onnx(path, debris_config)?);
|
||||
}
|
||||
}
|
||||
|
||||
if self.config.enable_vital_classification {
|
||||
if let Some(ref path) = self.config.vital_model_path {
|
||||
let vital_config = VitalSignsClassifierConfig {
|
||||
use_gpu: self.config.use_gpu,
|
||||
num_threads: self.config.num_threads,
|
||||
min_confidence: self.config.min_confidence,
|
||||
enable_uncertainty: true,
|
||||
mc_samples: 10,
|
||||
dropout_rate: 0.1,
|
||||
};
|
||||
self.vital_classifier = Some(VitalSignsClassifier::from_onnx(path, vital_config)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process CSI data and return enhanced detection results
|
||||
pub async fn process(&self, buffer: &CsiDataBuffer) -> MlResult<MlDetectionResult> {
|
||||
let mut result = MlDetectionResult::default();
|
||||
|
||||
// Extract debris features and classify if enabled
|
||||
if let Some(ref model) = self.debris_model {
|
||||
let features = DebrisFeatures::from_csi(buffer)?;
|
||||
result.debris_classification = Some(model.classify(&features).await?);
|
||||
result.depth_estimate = Some(model.estimate_depth(&features).await?);
|
||||
}
|
||||
|
||||
// Classify vital signs if enabled
|
||||
if let Some(ref classifier) = self.vital_classifier {
|
||||
let features = classifier.extract_features(buffer)?;
|
||||
result.vital_classification = Some(classifier.classify(&features).await?);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Check if the pipeline is ready for inference
|
||||
pub fn is_ready(&self) -> bool {
|
||||
let debris_ready = !self.config.enable_debris_classification
|
||||
|| self.debris_model.as_ref().map_or(false, |m| m.is_loaded());
|
||||
let vital_ready = !self.config.enable_vital_classification
|
||||
|| self.vital_classifier.as_ref().map_or(false, |c| c.is_loaded());
|
||||
|
||||
debris_ready && vital_ready
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &MlDetectionConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Combined ML detection results
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MlDetectionResult {
|
||||
/// Debris classification result
|
||||
pub debris_classification: Option<DebrisClassification>,
|
||||
/// Depth estimate
|
||||
pub depth_estimate: Option<DepthEstimate>,
|
||||
/// Vital signs classification
|
||||
pub vital_classification: Option<ClassifierOutput>,
|
||||
}
|
||||
|
||||
impl MlDetectionResult {
|
||||
/// Check if any ML detection was performed
|
||||
pub fn has_results(&self) -> bool {
|
||||
self.debris_classification.is_some()
|
||||
|| self.depth_estimate.is_some()
|
||||
|| self.vital_classification.is_some()
|
||||
}
|
||||
|
||||
/// Get overall confidence
|
||||
pub fn overall_confidence(&self) -> f32 {
|
||||
let mut total = 0.0;
|
||||
let mut count = 0;
|
||||
|
||||
if let Some(ref debris) = self.debris_classification {
|
||||
total += debris.confidence;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if let Some(ref depth) = self.depth_estimate {
|
||||
total += depth.confidence;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if let Some(ref vital) = self.vital_classification {
|
||||
total += vital.overall_confidence;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
total / count as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_buffer() -> CsiDataBuffer {
|
||||
let mut buffer = CsiDataBuffer::new(1000.0);
|
||||
let amplitudes: Vec<f64> = (0..1000)
|
||||
.map(|i| {
|
||||
let t = i as f64 / 1000.0;
|
||||
0.5 + 0.1 * (2.0 * std::f64::consts::PI * 0.25 * t).sin()
|
||||
})
|
||||
.collect();
|
||||
let phases: Vec<f64> = (0..1000)
|
||||
.map(|i| {
|
||||
let t = i as f64 / 1000.0;
|
||||
(2.0 * std::f64::consts::PI * 0.25 * t).sin() * 0.3
|
||||
})
|
||||
.collect();
|
||||
buffer.add_samples(&litudes, &phases);
|
||||
buffer
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debris_features_extraction() {
|
||||
let buffer = create_test_buffer();
|
||||
let features = DebrisFeatures::from_csi(&buffer);
|
||||
assert!(features.is_ok());
|
||||
|
||||
let features = features.unwrap();
|
||||
assert!(!features.amplitude_attenuation.is_empty());
|
||||
assert!(!features.phase_shifts.is_empty());
|
||||
assert!(features.coherence_bandwidth >= 0.0);
|
||||
assert!(features.delay_spread >= 0.0);
|
||||
assert!(features.temporal_stability >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feature_vector_size() {
|
||||
let buffer = create_test_buffer();
|
||||
let features = DebrisFeatures::from_csi(&buffer).unwrap();
|
||||
let vector = features.to_feature_vector();
|
||||
assert_eq!(vector.len(), 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_estimate() {
|
||||
let estimate = DepthEstimate::new(2.5, 0.3, 0.85);
|
||||
assert!(estimate.is_reliable());
|
||||
assert!(estimate.lower_bound < estimate.depth_meters);
|
||||
assert!(estimate.upper_bound > estimate.depth_meters);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ml_config_builder() {
|
||||
let config = MlDetectionConfig::cpu()
|
||||
.with_debris_model("models/debris.onnx")
|
||||
.with_vital_model("models/vitals.onnx")
|
||||
.with_min_confidence(0.7);
|
||||
|
||||
assert!(config.enable_debris_classification);
|
||||
assert!(config.enable_vital_classification);
|
||||
assert_eq!(config.min_confidence, 0.7);
|
||||
assert!(!config.use_gpu);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,5 +3,61 @@ name = "wifi-densepose-wasm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description = "WebAssembly bindings for WiFi-DensePose"
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/ruvnet/wifi-densepose"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[features]
|
||||
default = ["console_error_panic_hook"]
|
||||
mat = ["wifi-densepose-mat"]
|
||||
|
||||
[dependencies]
|
||||
# WASM bindings
|
||||
wasm-bindgen = "0.2"
|
||||
wasm-bindgen-futures = "0.4"
|
||||
js-sys = "0.3"
|
||||
web-sys = { version = "0.3", features = [
|
||||
"console",
|
||||
"Window",
|
||||
"Document",
|
||||
"Element",
|
||||
"HtmlCanvasElement",
|
||||
"CanvasRenderingContext2d",
|
||||
"WebSocket",
|
||||
"MessageEvent",
|
||||
"ErrorEvent",
|
||||
"CloseEvent",
|
||||
"BinaryType",
|
||||
"Performance",
|
||||
] }
|
||||
|
||||
# Error handling and logging
|
||||
console_error_panic_hook = { version = "0.1", optional = true }
|
||||
wasm-logger = "0.2"
|
||||
log = "0.4"
|
||||
|
||||
# Serialization for JS interop
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
serde-wasm-bindgen = "0.6"
|
||||
|
||||
# Async runtime for WASM
|
||||
futures = "0.3"
|
||||
|
||||
# Time handling
|
||||
chrono = { version = "0.4", features = ["serde", "wasmbind"] }
|
||||
|
||||
# UUID generation (with JS random support)
|
||||
uuid = { version = "1.6", features = ["v4", "serde", "js"] }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
# Optional: wifi-densepose-mat integration
|
||||
wifi-densepose-mat = { path = "../wifi-densepose-mat", optional = true, features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = ["-O4", "--enable-mutable-globals"]
|
||||
|
||||
@@ -1 +1,132 @@
|
||||
//! WiFi-DensePose WebAssembly bindings (stub)
|
||||
//! WiFi-DensePose WebAssembly bindings
|
||||
//!
|
||||
//! This crate provides WebAssembly bindings for browser-based applications using
|
||||
//! WiFi-DensePose technology. It includes:
|
||||
//!
|
||||
//! - **mat**: WiFi-Mat disaster response dashboard module for browser integration
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - `mat` - Enable WiFi-Mat disaster detection WASM bindings
|
||||
//! - `console_error_panic_hook` - Better panic messages in browser console
|
||||
//!
|
||||
//! # Building for WASM
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Build with wasm-pack
|
||||
//! wasm-pack build --target web --features mat
|
||||
//!
|
||||
//! # Or with cargo
|
||||
//! cargo build --target wasm32-unknown-unknown --features mat
|
||||
//! ```
|
||||
//!
|
||||
//! # Example Usage (JavaScript)
|
||||
//!
|
||||
//! ```javascript
|
||||
//! import init, { MatDashboard, initLogging } from './wifi_densepose_wasm.js';
|
||||
//!
|
||||
//! async function main() {
|
||||
//! await init();
|
||||
//! initLogging('info');
|
||||
//!
|
||||
//! const dashboard = new MatDashboard();
|
||||
//!
|
||||
//! // Create a disaster event
|
||||
//! const eventId = dashboard.createEvent('earthquake', 37.7749, -122.4194, 'Bay Area Earthquake');
|
||||
//!
|
||||
//! // Add scan zones
|
||||
//! dashboard.addRectangleZone('Building A', 50, 50, 200, 150);
|
||||
//! dashboard.addCircleZone('Search Area B', 400, 200, 80);
|
||||
//!
|
||||
//! // Subscribe to events
|
||||
//! dashboard.onSurvivorDetected((survivor) => {
|
||||
//! console.log('Survivor detected:', survivor);
|
||||
//! updateUI(survivor);
|
||||
//! });
|
||||
//!
|
||||
//! dashboard.onAlertGenerated((alert) => {
|
||||
//! showNotification(alert);
|
||||
//! });
|
||||
//!
|
||||
//! // Render to canvas
|
||||
//! const canvas = document.getElementById('map');
|
||||
//! const ctx = canvas.getContext('2d');
|
||||
//!
|
||||
//! function render() {
|
||||
//! ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||
//! dashboard.renderZones(ctx);
|
||||
//! dashboard.renderSurvivors(ctx);
|
||||
//! requestAnimationFrame(render);
|
||||
//! }
|
||||
//! render();
|
||||
//! }
|
||||
//!
|
||||
//! main();
|
||||
//! ```
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
// WiFi-Mat module for disaster response dashboard
|
||||
pub mod mat;
|
||||
pub use mat::*;
|
||||
|
||||
/// Initialize the WASM module.
|
||||
/// Call this once at startup before using any other functions.
|
||||
#[wasm_bindgen(start)]
|
||||
pub fn init() {
|
||||
// Set panic hook for better error messages in browser console
|
||||
#[cfg(feature = "console_error_panic_hook")]
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
/// Initialize logging with specified level.
|
||||
///
|
||||
/// @param {string} level - Log level: "trace", "debug", "info", "warn", "error"
|
||||
#[wasm_bindgen(js_name = initLogging)]
|
||||
pub fn init_logging(level: &str) {
|
||||
let log_level = match level.to_lowercase().as_str() {
|
||||
"trace" => log::Level::Trace,
|
||||
"debug" => log::Level::Debug,
|
||||
"info" => log::Level::Info,
|
||||
"warn" => log::Level::Warn,
|
||||
"error" => log::Level::Error,
|
||||
_ => log::Level::Info,
|
||||
};
|
||||
|
||||
let _ = wasm_logger::init(wasm_logger::Config::new(log_level));
|
||||
log::info!("WiFi-DensePose WASM initialized with log level: {}", level);
|
||||
}
|
||||
|
||||
/// Get the library version.
|
||||
///
|
||||
/// @returns {string} Version string
|
||||
#[wasm_bindgen(js_name = getVersion)]
|
||||
pub fn get_version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Check if the MAT feature is enabled.
|
||||
///
|
||||
/// @returns {boolean} True if MAT module is available
|
||||
#[wasm_bindgen(js_name = isMatEnabled)]
|
||||
pub fn is_mat_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Get current timestamp in milliseconds (for performance measurements).
|
||||
///
|
||||
/// @returns {number} Timestamp in milliseconds
|
||||
#[wasm_bindgen(js_name = getTimestamp)]
|
||||
pub fn get_timestamp() -> f64 {
|
||||
let window = web_sys::window().expect("no global window");
|
||||
let performance = window.performance().expect("no performance object");
|
||||
performance.now()
|
||||
}
|
||||
|
||||
// Re-export all public types from mat module for easy access
|
||||
pub mod types {
|
||||
pub use super::mat::{
|
||||
JsAlert, JsAlertPriority, JsDashboardStats, JsDisasterType, JsScanZone, JsSurvivor,
|
||||
JsTriageStatus, JsZoneStatus,
|
||||
};
|
||||
}
|
||||
|
||||
1553
rust-port/wifi-densepose-rs/crates/wifi-densepose-wasm/src/mat.rs
Normal file
1553
rust-port/wifi-densepose-rs/crates/wifi-densepose-wasm/src/mat.rs
Normal file
File diff suppressed because it is too large
Load Diff
1082
rust-port/wifi-densepose-rs/examples/mat-dashboard.html
Normal file
1082
rust-port/wifi-densepose-rs/examples/mat-dashboard.html
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user