feat: Add wifi-mat disaster detection enhancements

Implement 6 optional enhancements for the wifi-mat module:

1. Hardware Integration (csi_receiver.rs + hardware_adapter.rs)
   - ESP32 CSI support via serial/UDP
   - Intel 5300 BFEE file parsing
   - Atheros CSI Tool integration
   - Live UDP packet streaming
   - PCAP replay capability

2. CLI Commands (wifi-densepose-cli/src/mat.rs)
   - `wifi-mat scan` - Run disaster detection scan
   - `wifi-mat status` - Check event status
   - `wifi-mat zones` - Manage scan zones
   - `wifi-mat survivors` - List detected survivors
   - `wifi-mat alerts` - View and acknowledge alerts
   - `wifi-mat export` - Export data in various formats

3. REST API (wifi-densepose-mat/src/api/)
   - Full CRUD for disaster events
   - Zone management endpoints
   - Survivor and alert queries
   - WebSocket streaming for real-time updates
   - Comprehensive DTOs and error handling

4. WASM Build (wifi-densepose-wasm/src/mat.rs)
   - Browser-based disaster dashboard
   - Real-time survivor tracking
   - Zone visualization
   - Alert management
   - JavaScript API bindings

5. Detection Benchmarks (benches/detection_bench.rs)
   - Single survivor detection
   - Multi-survivor detection
   - Full pipeline benchmarks
   - Signal processing benchmarks
   - Hardware adapter benchmarks

6. ML Models for Debris Penetration (ml/)
   - DebrisModel for material analysis
   - VitalSignsClassifier for triage
   - FFT-based feature extraction
   - Bandpass filtering
   - Monte Carlo dropout for uncertainty

All 134 unit tests pass. Compilation verified for:
- wifi-densepose-mat
- wifi-densepose-cli
- wifi-densepose-wasm (with mat feature)
This commit is contained in:
Claude
2026-01-13 18:23:03 +00:00
parent 8a43e8f355
commit 6b20ff0c14
25 changed files with 14452 additions and 60 deletions

View File

@@ -0,0 +1,765 @@
//! ONNX-based debris penetration model for material classification and depth prediction.
//!
//! This module provides neural network models for analyzing debris characteristics
//! from WiFi CSI signals. Key capabilities include:
//!
//! - Material type classification (concrete, wood, metal, etc.)
//! - Signal attenuation prediction based on material properties
//! - Penetration depth estimation with uncertainty quantification
//!
//! ## Model Architecture
//!
//! The debris model uses a multi-head architecture:
//! - Shared feature encoder (CNN-based)
//! - Material classification head (softmax output)
//! - Attenuation regression head (linear output)
//! - Depth estimation head with uncertainty (mean + variance output)
use super::{DebrisFeatures, DepthEstimate, MlError, MlResult};
use ndarray::{Array1, Array2, Array4, s};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use parking_lot::RwLock;
use thiserror::Error;
use tracing::{debug, info, instrument, warn};
#[cfg(feature = "onnx")]
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
/// Errors specific to debris model operations.
///
/// NOTE(review): the code visible in this module returns `MlError` rather
/// than this enum — confirm `DebrisModelError` is consumed elsewhere or is
/// intentionally reserved public API surface.
#[derive(Debug, Error)]
pub enum DebrisModelError {
    /// Model file not found
    #[error("Model file not found: {0}")]
    FileNotFound(String),
    /// Invalid model format
    #[error("Invalid model format: {0}")]
    InvalidFormat(String),
    /// Inference error
    #[error("Inference failed: {0}")]
    InferenceFailed(String),
    /// Feature extraction error
    #[error("Feature extraction failed: {0}")]
    FeatureExtractionFailed(String),
}
/// Types of materials that can be detected in debris.
///
/// Declaration order fixes the class index used by the classifier
/// (`Concrete` = 0 through `Unknown` = 7).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaterialType {
    /// Reinforced concrete (high attenuation)
    Concrete,
    /// Wood/timber (moderate attenuation)
    Wood,
    /// Metal/steel (very high attenuation, reflective)
    Metal,
    /// Glass (low attenuation)
    Glass,
    /// Brick/masonry (high attenuation)
    Brick,
    /// Drywall/plasterboard (low attenuation)
    Drywall,
    /// Mixed/composite materials
    Mixed,
    /// Unknown material type
    Unknown,
}

impl MaterialType {
    /// Number of material classes
    pub const NUM_CLASSES: usize = 8;

    /// All variants in canonical class-index order.
    const ALL: [MaterialType; MaterialType::NUM_CLASSES] = [
        MaterialType::Concrete,
        MaterialType::Wood,
        MaterialType::Metal,
        MaterialType::Glass,
        MaterialType::Brick,
        MaterialType::Drywall,
        MaterialType::Mixed,
        MaterialType::Unknown,
    ];

    /// Get typical attenuation coefficient (dB/m)
    pub fn typical_attenuation(&self) -> f32 {
        // Indexed by class order: Concrete, Wood, Metal, Glass, Brick,
        // Drywall, Mixed, Unknown.
        const DB_PER_METER: [f32; MaterialType::NUM_CLASSES] =
            [25.0, 8.0, 50.0, 3.0, 18.0, 4.0, 15.0, 12.0];
        DB_PER_METER[self.to_index()]
    }

    /// Get typical delay spread (nanoseconds)
    pub fn typical_delay_spread(&self) -> f32 {
        // Same class ordering as `DB_PER_METER` above.
        const DELAY_NS: [f32; MaterialType::NUM_CLASSES] =
            [150.0, 50.0, 200.0, 20.0, 100.0, 30.0, 80.0, 60.0];
        DELAY_NS[self.to_index()]
    }

    /// From class index; out-of-range indices map to `Unknown`.
    pub fn from_index(index: usize) -> Self {
        Self::ALL.get(index).copied().unwrap_or(MaterialType::Unknown)
    }

    /// To class index (the variant's declaration position).
    pub fn to_index(&self) -> usize {
        // Field-less enum with default discriminants: the cast yields the
        // declaration index 0..=7 directly.
        *self as usize
    }
}

impl std::fmt::Display for MaterialType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            MaterialType::Concrete => "Concrete",
            MaterialType::Wood => "Wood",
            MaterialType::Metal => "Metal",
            MaterialType::Glass => "Glass",
            MaterialType::Brick => "Brick",
            MaterialType::Drywall => "Drywall",
            MaterialType::Mixed => "Mixed",
            MaterialType::Unknown => "Unknown",
        };
        f.write_str(label)
    }
}
/// Result of debris material classification
#[derive(Debug, Clone)]
pub struct DebrisClassification {
    /// Primary material type detected
    pub material_type: MaterialType,
    /// Confidence score for the classification (0.0-1.0)
    pub confidence: f32,
    /// Per-class probabilities
    pub class_probabilities: Vec<f32>,
    /// Estimated layer count (1-5)
    pub estimated_layers: u8,
    /// Whether multiple materials detected
    pub is_composite: bool,
}

impl DebrisClassification {
    /// Create a new debris classification from per-class probabilities.
    ///
    /// Picks the highest-probability class, flags a composite when more than
    /// one class is strong (> 0.2) and the winner is weak (< 0.7), and
    /// estimates the layer count from the distribution's entropy. An empty
    /// probability vector yields `Unknown` with zero confidence.
    pub fn new(probabilities: Vec<f32>) -> Self {
        // NaN-tolerant max: treat incomparable (NaN) pairs as equal instead
        // of panicking as a bare `unwrap()` on `partial_cmp` would.
        let (max_idx, &max_prob) = probabilities
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((7, &0.0));
        // Check for composite materials (multiple high probabilities)
        let high_prob_count = probabilities.iter().filter(|&&p| p > 0.2).count();
        let is_composite = high_prob_count > 1 && max_prob < 0.7;
        let material_type = if is_composite {
            MaterialType::Mixed
        } else {
            MaterialType::from_index(max_idx)
        };
        // Estimate layer count from the probability distribution
        let estimated_layers = Self::estimate_layers(&probabilities);
        Self {
            material_type,
            confidence: max_prob,
            class_probabilities: probabilities,
            estimated_layers,
            is_composite,
        }
    }

    /// Estimate number of debris layers from probability distribution.
    ///
    /// A more uniform (higher-entropy) distribution suggests more layers.
    /// Always returns a value in 1..=5; degenerate inputs (0 or 1 classes,
    /// non-finite entropy) default to a single layer.
    fn estimate_layers(probabilities: &[f32]) -> u8 {
        let entropy: f32 = probabilities
            .iter()
            .filter(|&&p| p > 0.01)
            .map(|&p| -p * p.ln())
            .sum();
        let max_entropy = (probabilities.len() as f32).ln();
        // Guard the normalization: for len <= 1 max_entropy is <= 0 and the
        // ratio would be NaN/infinite (previously NaN cast to `u8` silently
        // produced 0 layers, outside the documented 1-5 range).
        if max_entropy <= 0.0 || !entropy.is_finite() {
            return 1;
        }
        // Clamp enforces the documented 1-5 mapping even if `probabilities`
        // is not a normalized distribution.
        let normalized_entropy = (entropy / max_entropy).clamp(0.0, 1.0);
        // Map entropy to layer count (1-5)
        (1.0 + normalized_entropy * 4.0).round() as u8
    }

    /// Get secondary material if composite, i.e. the strongest class other
    /// than the primary one. Returns `None` for non-composite results.
    pub fn secondary_material(&self) -> Option<MaterialType> {
        if !self.is_composite {
            return None;
        }
        let primary_idx = self.material_type.to_index();
        self.class_probabilities
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != primary_idx)
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| MaterialType::from_index(i))
    }
}
/// Signal attenuation prediction result
#[derive(Debug, Clone)]
pub struct AttenuationPrediction {
    /// Predicted attenuation in dB
    pub attenuation_db: f32,
    /// Attenuation per meter (dB/m)
    pub attenuation_per_meter: f32,
    /// Uncertainty in the prediction
    pub uncertainty_db: f32,
    /// Frequency-dependent attenuation profile
    pub frequency_profile: Vec<f32>,
    /// Confidence in the prediction
    pub confidence: f32,
}

impl AttenuationPrediction {
    /// Create a new attenuation prediction.
    ///
    /// `attenuation` is total dB over `depth` meters; a non-positive depth
    /// yields a per-meter rate of zero. Confidence shrinks as the
    /// uncertainty grows relative to the attenuation magnitude.
    pub fn new(attenuation: f32, depth: f32, uncertainty: f32) -> Self {
        let per_meter = if depth > 0.0 { attenuation / depth } else { 0.0 };
        // Relative uncertainty against at least 1 dB, floored at zero.
        let scale = attenuation.abs().max(1.0);
        let confidence = (1.0 - uncertainty / scale).max(0.0);
        Self {
            attenuation_db: attenuation,
            attenuation_per_meter: per_meter,
            uncertainty_db: uncertainty,
            frequency_profile: Vec::new(),
            confidence,
        }
    }

    /// Predict signal loss (negative dB) at the given depth.
    pub fn predict_signal_at_depth(&self, depth_m: f32) -> f32 {
        -(self.attenuation_per_meter * depth_m)
    }
}
/// Configuration for debris model
#[derive(Debug, Clone)]
pub struct DebrisModelConfig {
    /// Use GPU for inference
    pub use_gpu: bool,
    /// Number of inference threads
    pub num_threads: usize,
    /// Minimum confidence threshold
    ///
    /// NOTE(review): stored but not consulted by the inference paths visible
    /// in this module — confirm where (or whether) it is enforced.
    pub confidence_threshold: f32,
}

impl Default for DebrisModelConfig {
    /// Defaults: CPU inference, 4 threads, 0.5 confidence threshold.
    fn default() -> Self {
        Self {
            use_gpu: false,
            num_threads: 4,
            confidence_threshold: 0.5,
        }
    }
}
/// Feature extractor for debris classification
pub struct DebrisFeatureExtractor {
    /// Number of subcarriers to analyze
    num_subcarriers: usize,
    /// Window size for temporal analysis
    window_size: usize,
    /// Whether to use advanced features
    /// (NOTE(review): currently not read by the methods in this file)
    use_advanced_features: bool,
}

impl Default for DebrisFeatureExtractor {
    /// 64 subcarriers, 100-sample window, advanced features enabled.
    fn default() -> Self {
        Self::new(64, 100)
    }
}

impl DebrisFeatureExtractor {
    /// Create new feature extractor
    pub fn new(num_subcarriers: usize, window_size: usize) -> Self {
        Self {
            num_subcarriers,
            window_size,
            use_advanced_features: true,
        }
    }

    /// Flatten the debris features into a (1, N) row vector for model input.
    pub fn extract(&self, features: &DebrisFeatures) -> MlResult<Array2<f32>> {
        let vector = features.to_feature_vector();
        let width = vector.len();
        Array2::from_shape_vec((1, width), vector)
            .map_err(|e| MlError::FeatureExtraction(e.to_string()))
    }

    /// Build a CNN-style [batch=1, channels=2, subcarriers, 1] tensor.
    ///
    /// Channel 0 carries amplitude attenuation, channel 1 phase shifts;
    /// missing subcarriers stay zero-padded.
    pub fn extract_spatial_temporal(&self, features: &DebrisFeatures) -> MlResult<Array4<f32>> {
        let mut tensor = Array4::<f32>::zeros((1, 2, self.num_subcarriers, 1));
        let channels: [(usize, &[f32]); 2] = [
            (0, &features.amplitude_attenuation),
            (1, &features.phase_shifts),
        ];
        for (channel, values) in channels {
            for (row, &value) in values.iter().take(self.num_subcarriers).enumerate() {
                tensor[[0, channel, row, 0]] = value;
            }
        }
        Ok(tensor)
    }
}
/// ONNX-based debris penetration model
///
/// Holds an optional ONNX inference session; when the `onnx` feature is
/// disabled or no model was loaded, classification falls back to the
/// rule-based weights below.
pub struct DebrisModel {
    // Inference configuration (threads, GPU flag, confidence threshold).
    config: DebrisModelConfig,
    // Converts DebrisFeatures into model input tensors.
    feature_extractor: DebrisFeatureExtractor,
    /// Material classification model weights (for rule-based fallback)
    material_weights: MaterialClassificationWeights,
    /// Whether ONNX model is loaded
    model_loaded: bool,
    /// Cached model session
    #[cfg(feature = "onnx")]
    session: Option<Arc<RwLock<OnnxSession>>>,
}

/// Pre-computed weights for rule-based material classification
///
/// Each array holds one weight per material class, indexed by
/// `MaterialType::to_index()` (Concrete..Unknown).
struct MaterialClassificationWeights {
    /// Weights for attenuation features
    attenuation_weights: [f32; MaterialType::NUM_CLASSES],
    /// Weights for delay spread features
    delay_weights: [f32; MaterialType::NUM_CLASSES],
    /// Weights for coherence bandwidth
    coherence_weights: [f32; MaterialType::NUM_CLASSES],
    /// Bias terms
    biases: [f32; MaterialType::NUM_CLASSES],
}

impl Default for MaterialClassificationWeights {
    fn default() -> Self {
        // Pre-computed weights based on material RF properties.
        // Class order: Concrete, Wood, Metal, Glass, Brick, Drywall, Mixed, Unknown.
        Self {
            attenuation_weights: [0.8, 0.3, 0.95, 0.1, 0.6, 0.15, 0.5, 0.4],
            delay_weights: [0.7, 0.2, 0.9, 0.1, 0.5, 0.1, 0.4, 0.3],
            coherence_weights: [0.3, 0.7, 0.1, 0.9, 0.4, 0.8, 0.5, 0.5],
            biases: [-0.5, 0.2, -0.8, 0.5, -0.3, 0.3, 0.0, 0.0],
        }
    }
}
impl DebrisModel {
    /// Create a new debris model from ONNX file
    ///
    /// Deliberately never fails on a missing or unloadable model file: it
    /// logs a warning and falls back to the rule-based classifier (so the
    /// current implementation always returns `Ok`).
    #[instrument(skip(path))]
    pub fn from_onnx<P: AsRef<Path>>(path: P, config: DebrisModelConfig) -> MlResult<Self> {
        let path_ref = path.as_ref();
        info!(?path_ref, "Loading debris model");
        #[cfg(feature = "onnx")]
        let session = if path_ref.exists() {
            let options = InferenceOptions {
                use_gpu: config.use_gpu,
                num_threads: config.num_threads,
                ..Default::default()
            };
            match OnnxSession::from_file(path_ref, &options) {
                Ok(s) => {
                    info!("ONNX debris model loaded successfully");
                    Some(Arc::new(RwLock::new(s)))
                }
                Err(e) => {
                    warn!(?e, "Failed to load ONNX model, using rule-based fallback");
                    None
                }
            }
        } else {
            warn!(?path_ref, "Model file not found, using rule-based fallback");
            None
        };
        #[cfg(feature = "onnx")]
        let model_loaded = session.is_some();
        #[cfg(not(feature = "onnx"))]
        let model_loaded = false;
        Ok(Self {
            config,
            feature_extractor: DebrisFeatureExtractor::default(),
            material_weights: MaterialClassificationWeights::default(),
            model_loaded,
            #[cfg(feature = "onnx")]
            session,
        })
    }

    /// Create with in-memory model bytes
    ///
    /// Unlike [`from_onnx`], this DOES fail (with `MlError::ModelLoad`) when
    /// the bytes cannot be parsed into a session — no rule-based fallback.
    #[cfg(feature = "onnx")]
    pub fn from_bytes(bytes: &[u8], config: DebrisModelConfig) -> MlResult<Self> {
        let options = InferenceOptions {
            use_gpu: config.use_gpu,
            num_threads: config.num_threads,
            ..Default::default()
        };
        let session = OnnxSession::from_bytes(bytes, &options)
            .map_err(|e| MlError::ModelLoad(e.to_string()))?;
        Ok(Self {
            config,
            feature_extractor: DebrisFeatureExtractor::default(),
            material_weights: MaterialClassificationWeights::default(),
            model_loaded: true,
            session: Some(Arc::new(RwLock::new(session))),
        })
    }

    /// Create a rule-based model (no ONNX required)
    pub fn rule_based(config: DebrisModelConfig) -> Self {
        Self {
            config,
            feature_extractor: DebrisFeatureExtractor::default(),
            material_weights: MaterialClassificationWeights::default(),
            model_loaded: false,
            #[cfg(feature = "onnx")]
            session: None,
        }
    }

    /// Check if ONNX model is loaded
    pub fn is_loaded(&self) -> bool {
        self.model_loaded
    }

    /// Classify material type from debris features
    ///
    /// Prefers the ONNX session when one is loaded; otherwise (or when the
    /// model's output tensor is missing) uses the rule-based fallback.
    #[instrument(skip(self, features))]
    pub async fn classify(&self, features: &DebrisFeatures) -> MlResult<DebrisClassification> {
        #[cfg(feature = "onnx")]
        if let Some(ref session) = self.session {
            return self.classify_onnx(features, session).await;
        }
        // Fall back to rule-based classification
        self.classify_rules(features)
    }

    /// ONNX-based classification
    #[cfg(feature = "onnx")]
    async fn classify_onnx(
        &self,
        features: &DebrisFeatures,
        session: &Arc<RwLock<OnnxSession>>,
    ) -> MlResult<DebrisClassification> {
        let input_features = self.feature_extractor.extract(features)?;
        // Prepare input tensor.
        // NOTE(review): `input_features` is a (1, N) Array2; `.len()` is the
        // total element count, flattened into a (1, 1, 1, N) tensor here —
        // confirm this matches the exported model's expected input shape.
        let input_array = Array4::from_shape_vec(
            (1, 1, 1, input_features.len()),
            input_features.iter().cloned().collect(),
        ).map_err(|e| MlError::Inference(e.to_string()))?;
        let input_tensor = Tensor::Float4D(input_array);
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), input_tensor);
        // Run inference (write lock: session mutation semantics are defined
        // by wifi_densepose_nn — presumably `run` needs &mut; verify).
        let outputs = session.write().run(inputs)
            .map_err(|e| MlError::NeuralNetwork(e))?;
        // Extract classification probabilities
        let probabilities = if let Some(output) = outputs.get("material_probs") {
            output.to_vec()
                .map_err(|e| MlError::Inference(e.to_string()))?
        } else {
            // Fallback to rule-based
            return self.classify_rules(features);
        };
        // Ensure we have enough classes (truncate extras, zero-fill missing)
        let mut probs = vec![0.0f32; MaterialType::NUM_CLASSES];
        for (i, &p) in probabilities.iter().take(MaterialType::NUM_CLASSES).enumerate() {
            probs[i] = p;
        }
        // Apply softmax normalization.
        // NOTE(review): softmax is applied unconditionally; if the model's
        // "material_probs" output is already a softmax distribution this
        // flattens it toward uniform — verify against the exported model.
        let max_val = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = probs.iter().map(|&x| (x - max_val).exp()).sum();
        for p in &mut probs {
            *p = (*p - max_val).exp() / exp_sum;
        }
        Ok(DebrisClassification::new(probs))
    }

    /// Rule-based material classification (fallback)
    ///
    /// Computes a linear score per material class from four normalized
    /// signal features, then softmaxes the scores into probabilities.
    fn classify_rules(&self, features: &DebrisFeatures) -> MlResult<DebrisClassification> {
        let mut scores = [0.0f32; MaterialType::NUM_CLASSES];
        // Normalize input features to roughly 0..1 ranges
        let attenuation_score = (features.snr_db.abs() / 30.0).min(1.0);
        let delay_score = (features.delay_spread / 200.0).min(1.0);
        let coherence_score = (features.coherence_bandwidth / 20.0).min(1.0);
        let stability_score = features.temporal_stability;
        // Compute weighted scores for each material; note coherence is
        // inverted — low coherence bandwidth favors dense materials.
        for i in 0..MaterialType::NUM_CLASSES {
            scores[i] = self.material_weights.attenuation_weights[i] * attenuation_score
                + self.material_weights.delay_weights[i] * delay_score
                + self.material_weights.coherence_weights[i] * (1.0 - coherence_score)
                + self.material_weights.biases[i]
                + 0.1 * stability_score;
        }
        // Apply softmax (max-shifted for numerical stability)
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = scores.iter().map(|&s| (s - max_score).exp()).sum();
        let probabilities: Vec<f32> = scores.iter()
            .map(|&s| (s - max_score).exp() / exp_sum)
            .collect();
        Ok(DebrisClassification::new(probabilities))
    }

    /// Predict signal attenuation through debris
    ///
    /// Scales the classified material's typical attenuation by SNR-derived,
    /// layer-count, and composite factors; uncertainty grows for composite
    /// or low-confidence classifications.
    #[instrument(skip(self, features))]
    pub async fn predict_attenuation(&self, features: &DebrisFeatures) -> MlResult<AttenuationPrediction> {
        // Get material classification first
        let classification = self.classify(features).await?;
        // Base attenuation from material type
        let base_attenuation = classification.material_type.typical_attenuation();
        // Adjust based on measured features: negative SNR amplifies (up to
        // 2x), positive SNR dampens (down to 0.5x).
        let measured_factor = if features.snr_db < 0.0 {
            1.0 + (features.snr_db.abs() / 30.0).min(1.0)
        } else {
            1.0 - (features.snr_db / 30.0).min(0.5)
        };
        // Layer factor: +20% per extra estimated layer
        let layer_factor = 1.0 + 0.2 * (classification.estimated_layers as f32 - 1.0);
        // Composite factor
        let composite_factor = if classification.is_composite { 1.2 } else { 1.0 };
        let total_attenuation = base_attenuation * measured_factor * layer_factor * composite_factor;
        // Uncertainty estimation
        let uncertainty = if classification.is_composite {
            total_attenuation * 0.3 // Higher uncertainty for composite
        } else {
            total_attenuation * (1.0 - classification.confidence) * 0.5
        };
        // Estimate depth (will be refined by depth estimation)
        let estimated_depth = self.estimate_depth_internal(features, total_attenuation);
        Ok(AttenuationPrediction::new(total_attenuation, estimated_depth, uncertainty))
    }

    /// Estimate penetration depth
    ///
    /// Combines the attenuation prediction with channel statistics; the
    /// returned estimate carries an uncertainty and an overall confidence
    /// (attenuation confidence damped by temporal stability).
    #[instrument(skip(self, features))]
    pub async fn estimate_depth(&self, features: &DebrisFeatures) -> MlResult<DepthEstimate> {
        // Get attenuation prediction
        let attenuation = self.predict_attenuation(features).await?;
        // Estimate depth from attenuation and material properties
        let depth = self.estimate_depth_internal(features, attenuation.attenuation_db);
        // Calculate uncertainty
        let uncertainty = self.calculate_depth_uncertainty(
            features,
            depth,
            attenuation.confidence,
        );
        let confidence = (attenuation.confidence * features.temporal_stability).min(1.0);
        Ok(DepthEstimate::new(depth, uncertainty, confidence))
    }

    /// Internal depth estimation logic: weighted blend of three heuristics,
    /// clamped to a plausible 0.1-10 m range.
    fn estimate_depth_internal(&self, features: &DebrisFeatures, attenuation_db: f32) -> f32 {
        // Use coherence bandwidth for depth estimation.
        // Smaller coherence bandwidth suggests more multipath = deeper penetration
        let cb_depth = (20.0 - features.coherence_bandwidth) / 5.0;
        // Use delay spread
        let ds_depth = features.delay_spread / 100.0;
        // Use attenuation (assuming typical material)
        let att_depth = attenuation_db / 15.0;
        // Combine estimates with weights
        let depth = 0.3 * cb_depth + 0.3 * ds_depth + 0.4 * att_depth;
        // Clamp to reasonable range (0.1 - 10 meters)
        depth.clamp(0.1, 10.0)
    }

    /// Calculate uncertainty in depth estimate.
    ///
    /// Starts at 20% of depth and inflates multiplicatively for instability,
    /// low confidence, and rich multipath.
    fn calculate_depth_uncertainty(
        &self,
        features: &DebrisFeatures,
        depth: f32,
        confidence: f32,
    ) -> f32 {
        // Base uncertainty proportional to depth
        let base_uncertainty = depth * 0.2;
        // Adjust by temporal stability (less stable = more uncertain)
        let stability_factor = 1.0 + (1.0 - features.temporal_stability) * 0.5;
        // Adjust by confidence (lower confidence = more uncertain)
        let confidence_factor = 1.0 + (1.0 - confidence) * 0.5;
        // Adjust by multipath richness (more multipath = harder to estimate)
        let multipath_factor = 1.0 + features.multipath_richness * 0.3;
        base_uncertainty * stability_factor * confidence_factor * multipath_factor
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Removed unused `use crate::detection::CsiDataBuffer;` — nothing in
    // this module references it, and it triggers an unused-import warning.

    /// Build a representative, well-conditioned feature set for the tests.
    fn create_test_debris_features() -> DebrisFeatures {
        DebrisFeatures {
            amplitude_attenuation: vec![0.5; 64],
            phase_shifts: vec![0.1; 64],
            fading_profile: vec![0.8, 0.6, 0.4, 0.2, 0.1, 0.05, 0.02, 0.01],
            coherence_bandwidth: 5.0,
            delay_spread: 100.0,
            snr_db: 15.0,
            multipath_richness: 0.6,
            temporal_stability: 0.8,
        }
    }

    #[test]
    fn test_material_type() {
        assert_eq!(MaterialType::from_index(0), MaterialType::Concrete);
        assert_eq!(MaterialType::Concrete.to_index(), 0);
        assert!(MaterialType::Concrete.typical_attenuation() > MaterialType::Glass.typical_attenuation());
    }

    #[test]
    fn test_debris_classification() {
        // Dominant first class -> confident, non-composite Concrete.
        let probs = vec![0.7, 0.1, 0.05, 0.05, 0.05, 0.02, 0.02, 0.01];
        let classification = DebrisClassification::new(probs);
        assert_eq!(classification.material_type, MaterialType::Concrete);
        assert!(classification.confidence > 0.6);
        assert!(!classification.is_composite);
    }

    #[test]
    fn test_composite_detection() {
        // Two strong classes with a weak winner -> composite / Mixed.
        let probs = vec![0.4, 0.35, 0.1, 0.05, 0.05, 0.02, 0.02, 0.01];
        let classification = DebrisClassification::new(probs);
        assert!(classification.is_composite);
        assert_eq!(classification.material_type, MaterialType::Mixed);
    }

    #[test]
    fn test_attenuation_prediction() {
        let pred = AttenuationPrediction::new(25.0, 2.0, 3.0);
        assert_eq!(pred.attenuation_per_meter, 12.5);
        assert!(pred.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_rule_based_classification() {
        let config = DebrisModelConfig::default();
        let model = DebrisModel::rule_based(config);
        let features = create_test_debris_features();
        let result = model.classify(&features).await;
        assert!(result.is_ok());
        let classification = result.unwrap();
        assert!(classification.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_depth_estimation() {
        let config = DebrisModelConfig::default();
        let model = DebrisModel::rule_based(config);
        let features = create_test_debris_features();
        let result = model.estimate_depth(&features).await;
        assert!(result.is_ok());
        let estimate = result.unwrap();
        assert!(estimate.depth_meters > 0.0);
        assert!(estimate.depth_meters < 10.0);
        assert!(estimate.uncertainty_meters > 0.0);
    }

    #[test]
    fn test_feature_extractor() {
        let extractor = DebrisFeatureExtractor::default();
        let features = create_test_debris_features();
        let result = extractor.extract(&features);
        assert!(result.is_ok());
        // to_feature_vector pads to the fixed 256-wide model input.
        let arr = result.unwrap();
        assert_eq!(arr.shape()[0], 1);
        assert_eq!(arr.shape()[1], 256);
    }

    #[test]
    fn test_spatial_temporal_extraction() {
        let extractor = DebrisFeatureExtractor::new(64, 100);
        let features = create_test_debris_features();
        let result = extractor.extract_spatial_temporal(&features);
        assert!(result.is_ok());
        let arr = result.unwrap();
        assert_eq!(arr.shape(), &[1, 2, 64, 1]);
    }
}

View File

@@ -0,0 +1,692 @@
//! Machine Learning module for debris penetration pattern recognition.
//!
//! This module provides ML-based models for:
//! - Debris material classification
//! - Penetration depth prediction
//! - Signal attenuation analysis
//! - Vital signs classification with uncertainty estimation
//!
//! ## Architecture
//!
//! The ML subsystem integrates with the `wifi-densepose-nn` crate for ONNX inference
//! and provides specialized models for disaster response scenarios.
//!
//! ```text
//! CSI Data -> Feature Extraction -> Model Inference -> Predictions
//! | | |
//! v v v
//! [Debris Features] [ONNX Models] [Classifications]
//! [Signal Features] [Neural Nets] [Confidences]
//! ```
mod debris_model;
mod vital_signs_classifier;
pub use debris_model::{
DebrisModel, DebrisModelConfig, DebrisFeatureExtractor,
MaterialType, DebrisClassification, AttenuationPrediction,
DebrisModelError,
};
pub use vital_signs_classifier::{
VitalSignsClassifier, VitalSignsClassifierConfig,
BreathingClassification, HeartbeatClassification,
UncertaintyEstimate, ClassifierOutput,
};
use crate::detection::CsiDataBuffer;
use crate::domain::{VitalSignsReading, BreathingPattern, HeartbeatSignature};
use async_trait::async_trait;
use std::path::Path;
use thiserror::Error;
/// Errors that can occur in ML operations
///
/// Returned (via [`MlResult`]) by every fallible operation in this module.
#[derive(Debug, Error)]
pub enum MlError {
    /// Model loading error
    #[error("Failed to load model: {0}")]
    ModelLoad(String),
    /// Inference error
    #[error("Inference failed: {0}")]
    Inference(String),
    /// Feature extraction error
    #[error("Feature extraction failed: {0}")]
    FeatureExtraction(String),
    /// Invalid input error
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// Model not initialized
    #[error("Model not initialized: {0}")]
    NotInitialized(String),
    /// Configuration error
    #[error("Configuration error: {0}")]
    Config(String),
    /// Integration error with wifi-densepose-nn
    /// (`#[from]` lets `?` convert `NnError` automatically)
    #[error("Neural network error: {0}")]
    NeuralNetwork(#[from] wifi_densepose_nn::NnError),
}

/// Result type for ML operations
pub type MlResult<T> = Result<T, MlError>;
/// Trait for debris penetration models
///
/// This trait defines the interface for models that can predict
/// material type and signal attenuation through debris layers.
/// `Send + Sync` bounds allow sharing an implementation across async tasks.
///
/// NOTE(review): no `impl DebrisPenetrationModel for DebrisModel` is visible
/// in this file, although `DebrisModel` exposes inherent methods with
/// matching shapes — confirm the trait impl lives elsewhere.
#[async_trait]
pub trait DebrisPenetrationModel: Send + Sync {
    /// Classify the material type from CSI features
    async fn classify_material(&self, features: &DebrisFeatures) -> MlResult<MaterialType>;
    /// Predict signal attenuation through debris
    async fn predict_attenuation(&self, features: &DebrisFeatures) -> MlResult<AttenuationPrediction>;
    /// Estimate penetration depth in meters
    async fn estimate_depth(&self, features: &DebrisFeatures) -> MlResult<DepthEstimate>;
    /// Get model confidence for the predictions
    fn model_confidence(&self) -> f32;
    /// Check if the model is loaded and ready
    fn is_ready(&self) -> bool;
}
/// Features extracted from CSI data for debris analysis
///
/// Produced by [`DebrisFeatures::from_csi`]; consumed by the debris model
/// either as a flat 256-wide vector (`to_feature_vector`) or per-channel.
#[derive(Debug, Clone)]
pub struct DebrisFeatures {
    /// Amplitude attenuation across subcarriers (z-score normalized)
    pub amplitude_attenuation: Vec<f32>,
    /// Phase shift patterns (unwrapped per-subcarrier differences)
    pub phase_shifts: Vec<f32>,
    /// Frequency-selective fading characteristics (first 8 FFT power bins)
    pub fading_profile: Vec<f32>,
    /// Coherence bandwidth estimate
    /// (units follow the 20-unit channel assumption in the estimator —
    /// presumably MHz; confirm)
    pub coherence_bandwidth: f32,
    /// RMS delay spread (nanoseconds, assuming 50 ns per sample)
    pub delay_spread: f32,
    /// Signal-to-noise ratio estimate (dB)
    pub snr_db: f32,
    /// Multipath richness indicator (0-1)
    pub multipath_richness: f32,
    /// Temporal stability metric (0-1, 1 = perfectly stable)
    pub temporal_stability: f32,
}
impl DebrisFeatures {
/// Create new debris features from raw CSI data
pub fn from_csi(buffer: &CsiDataBuffer) -> MlResult<Self> {
if buffer.amplitudes.is_empty() {
return Err(MlError::FeatureExtraction("Empty CSI buffer".into()));
}
// Calculate amplitude attenuation
let amplitude_attenuation = Self::compute_amplitude_features(&buffer.amplitudes);
// Calculate phase shifts
let phase_shifts = Self::compute_phase_features(&buffer.phases);
// Compute fading profile
let fading_profile = Self::compute_fading_profile(&buffer.amplitudes);
// Estimate coherence bandwidth from frequency correlation
let coherence_bandwidth = Self::estimate_coherence_bandwidth(&buffer.amplitudes);
// Estimate delay spread
let delay_spread = Self::estimate_delay_spread(&buffer.amplitudes);
// Estimate SNR
let snr_db = Self::estimate_snr(&buffer.amplitudes);
// Multipath richness
let multipath_richness = Self::compute_multipath_richness(&buffer.amplitudes);
// Temporal stability
let temporal_stability = Self::compute_temporal_stability(&buffer.amplitudes);
Ok(Self {
amplitude_attenuation,
phase_shifts,
fading_profile,
coherence_bandwidth,
delay_spread,
snr_db,
multipath_richness,
temporal_stability,
})
}
/// Compute amplitude features
fn compute_amplitude_features(amplitudes: &[f64]) -> Vec<f32> {
if amplitudes.is_empty() {
return vec![];
}
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
let variance = amplitudes.iter()
.map(|a| (a - mean).powi(2))
.sum::<f64>() / amplitudes.len() as f64;
let std_dev = variance.sqrt();
// Normalize amplitudes
amplitudes.iter()
.map(|a| ((a - mean) / (std_dev + 1e-8)) as f32)
.collect()
}
/// Compute phase features
fn compute_phase_features(phases: &[f64]) -> Vec<f32> {
if phases.len() < 2 {
return vec![];
}
// Compute phase differences (unwrapped)
phases.windows(2)
.map(|w| {
let diff = w[1] - w[0];
// Unwrap phase
let unwrapped = if diff > std::f64::consts::PI {
diff - 2.0 * std::f64::consts::PI
} else if diff < -std::f64::consts::PI {
diff + 2.0 * std::f64::consts::PI
} else {
diff
};
unwrapped as f32
})
.collect()
}
/// Compute fading profile (power spectral characteristics)
fn compute_fading_profile(amplitudes: &[f64]) -> Vec<f32> {
use rustfft::{FftPlanner, num_complex::Complex};
if amplitudes.len() < 16 {
return vec![0.0; 8];
}
// Take a subset for FFT
let n = 64.min(amplitudes.len());
let mut buffer: Vec<Complex<f64>> = amplitudes.iter()
.take(n)
.map(|&a| Complex::new(a, 0.0))
.collect();
// Pad to power of 2
while buffer.len() < 64 {
buffer.push(Complex::new(0.0, 0.0));
}
// Compute FFT
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(64);
fft.process(&mut buffer);
// Extract power spectrum (first half)
buffer.iter()
.take(8)
.map(|c| (c.norm() / n as f64) as f32)
.collect()
}
/// Estimate coherence bandwidth from frequency correlation
fn estimate_coherence_bandwidth(amplitudes: &[f64]) -> f32 {
if amplitudes.len() < 10 {
return 0.0;
}
// Compute autocorrelation
let n = amplitudes.len();
let mean = amplitudes.iter().sum::<f64>() / n as f64;
let variance: f64 = amplitudes.iter()
.map(|a| (a - mean).powi(2))
.sum::<f64>() / n as f64;
if variance < 1e-10 {
return 0.0;
}
// Find lag where correlation drops below 0.5
let mut coherence_lag = n;
for lag in 1..n / 2 {
let correlation: f64 = amplitudes.iter()
.take(n - lag)
.zip(amplitudes.iter().skip(lag))
.map(|(a, b)| (a - mean) * (b - mean))
.sum::<f64>() / ((n - lag) as f64 * variance);
if correlation < 0.5 {
coherence_lag = lag;
break;
}
}
// Convert to bandwidth estimate (assuming 20 MHz channel)
(20.0 / coherence_lag as f32).min(20.0)
}
/// Estimate RMS delay spread
fn estimate_delay_spread(amplitudes: &[f64]) -> f32 {
if amplitudes.len() < 10 {
return 0.0;
}
// Use power delay profile approximation
let power: Vec<f64> = amplitudes.iter().map(|a| a.powi(2)).collect();
let total_power: f64 = power.iter().sum();
if total_power < 1e-10 {
return 0.0;
}
// Calculate mean delay
let mean_delay: f64 = power.iter()
.enumerate()
.map(|(i, p)| i as f64 * p)
.sum::<f64>() / total_power;
// Calculate RMS delay spread
let variance: f64 = power.iter()
.enumerate()
.map(|(i, p)| (i as f64 - mean_delay).powi(2) * p)
.sum::<f64>() / total_power;
// Convert to nanoseconds (assuming sample period)
(variance.sqrt() * 50.0) as f32 // 50 ns per sample assumed
}
/// Estimate SNR from amplitude variance
fn estimate_snr(amplitudes: &[f64]) -> f32 {
if amplitudes.is_empty() {
return 0.0;
}
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
let variance = amplitudes.iter()
.map(|a| (a - mean).powi(2))
.sum::<f64>() / amplitudes.len() as f64;
if variance < 1e-10 {
return 30.0; // High SNR assumed
}
// SNR estimate based on signal power to noise power ratio
let signal_power = mean.powi(2);
let snr_linear = signal_power / variance;
(10.0 * snr_linear.log10()) as f32
}
/// Compute multipath richness indicator
fn compute_multipath_richness(amplitudes: &[f64]) -> f32 {
if amplitudes.len() < 10 {
return 0.0;
}
// Calculate amplitude variance as multipath indicator
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
let variance = amplitudes.iter()
.map(|a| (a - mean).powi(2))
.sum::<f64>() / amplitudes.len() as f64;
// Normalize to 0-1 range
let std_dev = variance.sqrt();
let normalized = std_dev / (mean.abs() + 1e-8);
(normalized.min(1.0)) as f32
}
/// Temporal stability metric in [0, 1]; 1.0 means perfectly stable.
///
/// Computes the mean absolute sample-to-sample change relative to the mean
/// amplitude, and returns its complement (clamped at zero). Slices with
/// fewer than two samples are trivially stable.
fn compute_temporal_stability(amplitudes: &[f64]) -> f32 {
    if amplitudes.len() < 2 {
        return 1.0; // Nothing to vary over: trivially stable.
    }
    // Mean absolute difference between consecutive samples.
    let mut diff_sum = 0.0;
    for pair in amplitudes.windows(2) {
        diff_sum += (pair[1] - pair[0]).abs();
    }
    let mean_diff = diff_sum / (amplitudes.len() - 1) as f64;
    let mean_amp = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
    // Stability is the complement of relative variation; the 1e-8 offset
    // guards against division by a zero mean.
    let variation = mean_diff / (mean_amp.abs() + 1e-8);
    (1.0 - variation.min(1.0)) as f32
}
/// Convert to the fixed 256-element feature vector expected by the model.
///
/// Layout: [0..64) amplitude attenuation, [64..128) phase shifts,
/// [128..144) fading profile, [144..149) scalar channel statistics,
/// [149..256) zero padding. Each segment is truncated or zero-padded to
/// its slot width.
pub fn to_feature_vector(&self) -> Vec<f32> {
    let mut features: Vec<f32> = Vec::with_capacity(256);
    // Amplitude attenuation: first 64 entries, zero-padded to offset 64.
    features.extend(self.amplitude_attenuation.iter().take(64).copied());
    features.resize(64, 0.0);
    // Phase shifts: next 64 slots.
    features.extend(self.phase_shifts.iter().take(64).copied());
    features.resize(128, 0.0);
    // Fading profile: 16 slots.
    features.extend(self.fading_profile.iter().take(16).copied());
    features.resize(144, 0.0);
    // Scalar channel statistics, in fixed order.
    features.extend([
        self.coherence_bandwidth,
        self.delay_spread,
        self.snr_db,
        self.multipath_richness,
        self.temporal_stability,
    ]);
    // Zero-pad out to the model's fixed input width.
    features.resize(256, 0.0);
    features
}
}
/// Depth estimate with uncertainty.
///
/// Carries a point estimate plus a Gaussian-style 95% confidence interval
/// derived from the standard deviation.
#[derive(Debug, Clone)]
pub struct DepthEstimate {
    /// Estimated depth in meters
    pub depth_meters: f32,
    /// Uncertainty (standard deviation) in meters
    pub uncertainty_meters: f32,
    /// Confidence in the estimate (0.0-1.0)
    pub confidence: f32,
    /// Lower bound of 95% confidence interval
    pub lower_bound: f32,
    /// Upper bound of 95% confidence interval
    pub upper_bound: f32,
}

impl DepthEstimate {
    /// Build an estimate, deriving the 95% confidence interval as
    /// ±1.96 standard deviations, with the lower bound floored at zero
    /// (depth cannot be negative).
    pub fn new(depth: f32, uncertainty: f32, confidence: f32) -> Self {
        let half_width = 1.96 * uncertainty;
        Self {
            depth_meters: depth,
            uncertainty_meters: uncertainty,
            confidence,
            lower_bound: (depth - half_width).max(0.0),
            upper_bound: depth + half_width,
        }
    }

    /// An estimate is reliable when confidence exceeds 0.7 and the
    /// uncertainty stays below 30% of the estimated depth.
    pub fn is_reliable(&self) -> bool {
        self.confidence > 0.7 && self.uncertainty_meters < self.depth_meters * 0.3
    }
}
/// Configuration for the ML-enhanced detection pipeline.
///
/// Build via [`Default`], [`MlDetectionConfig::cpu`], or
/// [`MlDetectionConfig::gpu`], then refine with the `with_*` builders.
#[derive(Debug, Clone, PartialEq)]
pub struct MlDetectionConfig {
    /// Enable ML-based debris classification
    pub enable_debris_classification: bool,
    /// Enable ML-based vital signs classification
    pub enable_vital_classification: bool,
    /// Path to debris model file
    pub debris_model_path: Option<String>,
    /// Path to vital signs model file
    pub vital_model_path: Option<String>,
    /// Minimum confidence threshold for ML predictions
    pub min_confidence: f32,
    /// Use GPU for inference
    pub use_gpu: bool,
    /// Number of inference threads
    pub num_threads: usize,
}

impl Default for MlDetectionConfig {
    /// CPU inference, both ML stages disabled, 0.5 confidence floor,
    /// 4 inference threads.
    fn default() -> Self {
        Self {
            enable_debris_classification: false,
            enable_vital_classification: false,
            debris_model_path: None,
            vital_model_path: None,
            min_confidence: 0.5,
            use_gpu: false,
            num_threads: 4,
        }
    }
}

impl MlDetectionConfig {
    /// Configuration for CPU-only inference (identical to `Default`).
    pub fn cpu() -> Self {
        Self::default()
    }

    /// Configuration for GPU-accelerated inference.
    pub fn gpu() -> Self {
        Self {
            use_gpu: true,
            ..Self::default()
        }
    }

    /// Enable debris classification backed by the model file at `path`.
    pub fn with_debris_model<P: Into<String>>(mut self, path: P) -> Self {
        self.enable_debris_classification = true;
        self.debris_model_path = Some(path.into());
        self
    }

    /// Enable vital-sign classification backed by the model file at `path`.
    pub fn with_vital_model<P: Into<String>>(mut self, path: P) -> Self {
        self.enable_vital_classification = true;
        self.vital_model_path = Some(path.into());
        self
    }

    /// Set the minimum confidence threshold, clamped to [0, 1].
    pub fn with_min_confidence(mut self, confidence: f32) -> Self {
        self.min_confidence = confidence.clamp(0.0, 1.0);
        self
    }
}
/// ML-enhanced detection pipeline that combines traditional and ML-based detection
///
/// Model handles are lazily initialized: both stay `None` until
/// `initialize` loads them according to the configuration's enable flags
/// and model paths.
pub struct MlDetectionPipeline {
    // Pipeline settings, fixed at construction time.
    config: MlDetectionConfig,
    // Set by `initialize` when debris classification is enabled and a path is given.
    debris_model: Option<DebrisModel>,
    // Set by `initialize` when vital-sign classification is enabled and a path is given.
    vital_classifier: Option<VitalSignsClassifier>,
}
impl MlDetectionPipeline {
/// Create a new ML detection pipeline
pub fn new(config: MlDetectionConfig) -> Self {
Self {
config,
debris_model: None,
vital_classifier: None,
}
}
/// Initialize models asynchronously
pub async fn initialize(&mut self) -> MlResult<()> {
if self.config.enable_debris_classification {
if let Some(ref path) = self.config.debris_model_path {
let debris_config = DebrisModelConfig {
use_gpu: self.config.use_gpu,
num_threads: self.config.num_threads,
confidence_threshold: self.config.min_confidence,
};
self.debris_model = Some(DebrisModel::from_onnx(path, debris_config)?);
}
}
if self.config.enable_vital_classification {
if let Some(ref path) = self.config.vital_model_path {
let vital_config = VitalSignsClassifierConfig {
use_gpu: self.config.use_gpu,
num_threads: self.config.num_threads,
min_confidence: self.config.min_confidence,
enable_uncertainty: true,
mc_samples: 10,
dropout_rate: 0.1,
};
self.vital_classifier = Some(VitalSignsClassifier::from_onnx(path, vital_config)?);
}
}
Ok(())
}
/// Process CSI data and return enhanced detection results
pub async fn process(&self, buffer: &CsiDataBuffer) -> MlResult<MlDetectionResult> {
let mut result = MlDetectionResult::default();
// Extract debris features and classify if enabled
if let Some(ref model) = self.debris_model {
let features = DebrisFeatures::from_csi(buffer)?;
result.debris_classification = Some(model.classify(&features).await?);
result.depth_estimate = Some(model.estimate_depth(&features).await?);
}
// Classify vital signs if enabled
if let Some(ref classifier) = self.vital_classifier {
let features = classifier.extract_features(buffer)?;
result.vital_classification = Some(classifier.classify(&features).await?);
}
Ok(result)
}
/// Check if the pipeline is ready for inference
pub fn is_ready(&self) -> bool {
let debris_ready = !self.config.enable_debris_classification
|| self.debris_model.as_ref().map_or(false, |m| m.is_loaded());
let vital_ready = !self.config.enable_vital_classification
|| self.vital_classifier.as_ref().map_or(false, |c| c.is_loaded());
debris_ready && vital_ready
}
/// Get configuration
pub fn config(&self) -> &MlDetectionConfig {
&self.config
}
}
/// Combined ML detection results
///
/// Each field is `None` when the corresponding pipeline stage was disabled
/// or its model was never loaded.
#[derive(Debug, Clone, Default)]
pub struct MlDetectionResult {
    /// Debris classification result
    pub debris_classification: Option<DebrisClassification>,
    /// Depth estimate
    pub depth_estimate: Option<DepthEstimate>,
    /// Vital signs classification
    pub vital_classification: Option<ClassifierOutput>,
}
impl MlDetectionResult {
/// Check if any ML detection was performed
pub fn has_results(&self) -> bool {
self.debris_classification.is_some()
|| self.depth_estimate.is_some()
|| self.vital_classification.is_some()
}
/// Get overall confidence
pub fn overall_confidence(&self) -> f32 {
let mut total = 0.0;
let mut count = 0;
if let Some(ref debris) = self.debris_classification {
total += debris.confidence;
count += 1;
}
if let Some(ref depth) = self.depth_estimate {
total += depth.confidence;
count += 1;
}
if let Some(ref vital) = self.vital_classification {
total += vital.overall_confidence;
count += 1;
}
if count > 0 {
total / count as f32
} else {
0.0
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a buffer (rate argument 1000.0) holding 1000 samples of a
    /// 0.25 Hz sinusoidal amplitude/phase pattern.
    fn create_test_buffer() -> CsiDataBuffer {
        let mut buffer = CsiDataBuffer::new(1000.0);
        let (amplitudes, phases): (Vec<f64>, Vec<f64>) = (0..1000)
            .map(|i| {
                let t = i as f64 / 1000.0;
                let tone = (2.0 * std::f64::consts::PI * 0.25 * t).sin();
                (0.5 + 0.1 * tone, tone * 0.3)
            })
            .unzip();
        buffer.add_samples(&amplitudes, &phases);
        buffer
    }

    #[test]
    fn test_debris_features_extraction() {
        let features = DebrisFeatures::from_csi(&create_test_buffer());
        assert!(features.is_ok());
        let features = features.unwrap();
        assert!(!features.amplitude_attenuation.is_empty());
        assert!(!features.phase_shifts.is_empty());
        assert!(features.coherence_bandwidth >= 0.0);
        assert!(features.delay_spread >= 0.0);
        assert!(features.temporal_stability >= 0.0);
    }

    #[test]
    fn test_feature_vector_size() {
        let features = DebrisFeatures::from_csi(&create_test_buffer()).unwrap();
        assert_eq!(features.to_feature_vector().len(), 256);
    }

    #[test]
    fn test_depth_estimate() {
        let estimate = DepthEstimate::new(2.5, 0.3, 0.85);
        assert!(estimate.is_reliable());
        assert!(estimate.lower_bound < estimate.depth_meters);
        assert!(estimate.upper_bound > estimate.depth_meters);
    }

    #[test]
    fn test_ml_config_builder() {
        let config = MlDetectionConfig::cpu()
            .with_debris_model("models/debris.onnx")
            .with_vital_model("models/vitals.onnx")
            .with_min_confidence(0.7);
        assert!(config.enable_debris_classification);
        assert!(config.enable_vital_classification);
        assert_eq!(config.min_confidence, 0.7);
        assert!(!config.use_gpu);
    }
}