feat: Add wifi-Mat disaster detection enhancements
Implement 6 optional enhancements for the wifi-Mat module: 1. Hardware Integration (csi_receiver.rs + hardware_adapter.rs) - ESP32 CSI support via serial/UDP - Intel 5300 BFEE file parsing - Atheros CSI Tool integration - Live UDP packet streaming - PCAP replay capability 2. CLI Commands (wifi-densepose-cli/src/mat.rs) - `wifi-mat scan` - Run disaster detection scan - `wifi-mat status` - Check event status - `wifi-mat zones` - Manage scan zones - `wifi-mat survivors` - List detected survivors - `wifi-mat alerts` - View and acknowledge alerts - `wifi-mat export` - Export data in various formats 3. REST API (wifi-densepose-mat/src/api/) - Full CRUD for disaster events - Zone management endpoints - Survivor and alert queries - WebSocket streaming for real-time updates - Comprehensive DTOs and error handling 4. WASM Build (wifi-densepose-wasm/src/mat.rs) - Browser-based disaster dashboard - Real-time survivor tracking - Zone visualization - Alert management - JavaScript API bindings 5. Detection Benchmarks (benches/detection_bench.rs) - Single survivor detection - Multi-survivor detection - Full pipeline benchmarks - Signal processing benchmarks - Hardware adapter benchmarks 6. ML Models for Debris Penetration (ml/) - DebrisModel for material analysis - VitalSignsClassifier for triage - FFT-based feature extraction - Bandpass filtering - Monte Carlo dropout for uncertainty All 134 unit tests pass. Compilation verified for: - wifi-densepose-mat - wifi-densepose-cli - wifi-densepose-wasm (with mat feature)
This commit is contained in:
@@ -0,0 +1,765 @@
|
||||
//! ONNX-based debris penetration model for material classification and depth prediction.
|
||||
//!
|
||||
//! This module provides neural network models for analyzing debris characteristics
|
||||
//! from WiFi CSI signals. Key capabilities include:
|
||||
//!
|
||||
//! - Material type classification (concrete, wood, metal, etc.)
|
||||
//! - Signal attenuation prediction based on material properties
|
||||
//! - Penetration depth estimation with uncertainty quantification
|
||||
//!
|
||||
//! ## Model Architecture
|
||||
//!
|
||||
//! The debris model uses a multi-head architecture:
|
||||
//! - Shared feature encoder (CNN-based)
|
||||
//! - Material classification head (softmax output)
|
||||
//! - Attenuation regression head (linear output)
|
||||
//! - Depth estimation head with uncertainty (mean + variance output)
|
||||
|
||||
use super::{DebrisFeatures, DepthEstimate, MlError, MlResult};
|
||||
use ndarray::{Array1, Array2, Array4, s};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::RwLock;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
||||
|
||||
/// Errors specific to debris model operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DebrisModelError {
|
||||
/// Model file not found
|
||||
#[error("Model file not found: {0}")]
|
||||
FileNotFound(String),
|
||||
|
||||
/// Invalid model format
|
||||
#[error("Invalid model format: {0}")]
|
||||
InvalidFormat(String),
|
||||
|
||||
/// Inference error
|
||||
#[error("Inference failed: {0}")]
|
||||
InferenceFailed(String),
|
||||
|
||||
/// Feature extraction error
|
||||
#[error("Feature extraction failed: {0}")]
|
||||
FeatureExtractionFailed(String),
|
||||
}
|
||||
|
||||
/// Types of materials that can be detected in debris
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum MaterialType {
|
||||
/// Reinforced concrete (high attenuation)
|
||||
Concrete,
|
||||
/// Wood/timber (moderate attenuation)
|
||||
Wood,
|
||||
/// Metal/steel (very high attenuation, reflective)
|
||||
Metal,
|
||||
/// Glass (low attenuation)
|
||||
Glass,
|
||||
/// Brick/masonry (high attenuation)
|
||||
Brick,
|
||||
/// Drywall/plasterboard (low attenuation)
|
||||
Drywall,
|
||||
/// Mixed/composite materials
|
||||
Mixed,
|
||||
/// Unknown material type
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl MaterialType {
|
||||
/// Get typical attenuation coefficient (dB/m)
|
||||
pub fn typical_attenuation(&self) -> f32 {
|
||||
match self {
|
||||
MaterialType::Concrete => 25.0,
|
||||
MaterialType::Wood => 8.0,
|
||||
MaterialType::Metal => 50.0,
|
||||
MaterialType::Glass => 3.0,
|
||||
MaterialType::Brick => 18.0,
|
||||
MaterialType::Drywall => 4.0,
|
||||
MaterialType::Mixed => 15.0,
|
||||
MaterialType::Unknown => 12.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get typical delay spread (nanoseconds)
|
||||
pub fn typical_delay_spread(&self) -> f32 {
|
||||
match self {
|
||||
MaterialType::Concrete => 150.0,
|
||||
MaterialType::Wood => 50.0,
|
||||
MaterialType::Metal => 200.0,
|
||||
MaterialType::Glass => 20.0,
|
||||
MaterialType::Brick => 100.0,
|
||||
MaterialType::Drywall => 30.0,
|
||||
MaterialType::Mixed => 80.0,
|
||||
MaterialType::Unknown => 60.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// From class index
|
||||
pub fn from_index(index: usize) -> Self {
|
||||
match index {
|
||||
0 => MaterialType::Concrete,
|
||||
1 => MaterialType::Wood,
|
||||
2 => MaterialType::Metal,
|
||||
3 => MaterialType::Glass,
|
||||
4 => MaterialType::Brick,
|
||||
5 => MaterialType::Drywall,
|
||||
6 => MaterialType::Mixed,
|
||||
_ => MaterialType::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
/// To class index
|
||||
pub fn to_index(&self) -> usize {
|
||||
match self {
|
||||
MaterialType::Concrete => 0,
|
||||
MaterialType::Wood => 1,
|
||||
MaterialType::Metal => 2,
|
||||
MaterialType::Glass => 3,
|
||||
MaterialType::Brick => 4,
|
||||
MaterialType::Drywall => 5,
|
||||
MaterialType::Mixed => 6,
|
||||
MaterialType::Unknown => 7,
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of material classes
|
||||
pub const NUM_CLASSES: usize = 8;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MaterialType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MaterialType::Concrete => write!(f, "Concrete"),
|
||||
MaterialType::Wood => write!(f, "Wood"),
|
||||
MaterialType::Metal => write!(f, "Metal"),
|
||||
MaterialType::Glass => write!(f, "Glass"),
|
||||
MaterialType::Brick => write!(f, "Brick"),
|
||||
MaterialType::Drywall => write!(f, "Drywall"),
|
||||
MaterialType::Mixed => write!(f, "Mixed"),
|
||||
MaterialType::Unknown => write!(f, "Unknown"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of debris material classification
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DebrisClassification {
|
||||
/// Primary material type detected
|
||||
pub material_type: MaterialType,
|
||||
/// Confidence score for the classification (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// Per-class probabilities
|
||||
pub class_probabilities: Vec<f32>,
|
||||
/// Estimated layer count
|
||||
pub estimated_layers: u8,
|
||||
/// Whether multiple materials detected
|
||||
pub is_composite: bool,
|
||||
}
|
||||
|
||||
impl DebrisClassification {
|
||||
/// Create a new debris classification
|
||||
pub fn new(probabilities: Vec<f32>) -> Self {
|
||||
let (max_idx, &max_prob) = probabilities.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.unwrap_or((7, &0.0));
|
||||
|
||||
// Check for composite materials (multiple high probabilities)
|
||||
let high_prob_count = probabilities.iter()
|
||||
.filter(|&&p| p > 0.2)
|
||||
.count();
|
||||
|
||||
let is_composite = high_prob_count > 1 && max_prob < 0.7;
|
||||
let material_type = if is_composite {
|
||||
MaterialType::Mixed
|
||||
} else {
|
||||
MaterialType::from_index(max_idx)
|
||||
};
|
||||
|
||||
// Estimate layer count from delay spread characteristics
|
||||
let estimated_layers = Self::estimate_layers(&probabilities);
|
||||
|
||||
Self {
|
||||
material_type,
|
||||
confidence: max_prob,
|
||||
class_probabilities: probabilities,
|
||||
estimated_layers,
|
||||
is_composite,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate number of debris layers from probability distribution
|
||||
fn estimate_layers(probabilities: &[f32]) -> u8 {
|
||||
// More uniform distribution suggests more layers
|
||||
let entropy: f32 = probabilities.iter()
|
||||
.filter(|&&p| p > 0.01)
|
||||
.map(|&p| -p * p.ln())
|
||||
.sum();
|
||||
|
||||
let max_entropy = (probabilities.len() as f32).ln();
|
||||
let normalized_entropy = entropy / max_entropy;
|
||||
|
||||
// Map entropy to layer count (1-5)
|
||||
(1.0 + normalized_entropy * 4.0).round() as u8
|
||||
}
|
||||
|
||||
/// Get secondary material if composite
|
||||
pub fn secondary_material(&self) -> Option<MaterialType> {
|
||||
if !self.is_composite {
|
||||
return None;
|
||||
}
|
||||
|
||||
let primary_idx = self.material_type.to_index();
|
||||
self.class_probabilities.iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| *i != primary_idx)
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.map(|(i, _)| MaterialType::from_index(i))
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal attenuation prediction result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AttenuationPrediction {
|
||||
/// Predicted attenuation in dB
|
||||
pub attenuation_db: f32,
|
||||
/// Attenuation per meter (dB/m)
|
||||
pub attenuation_per_meter: f32,
|
||||
/// Uncertainty in the prediction
|
||||
pub uncertainty_db: f32,
|
||||
/// Frequency-dependent attenuation profile
|
||||
pub frequency_profile: Vec<f32>,
|
||||
/// Confidence in the prediction
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
impl AttenuationPrediction {
|
||||
/// Create new attenuation prediction
|
||||
pub fn new(attenuation: f32, depth: f32, uncertainty: f32) -> Self {
|
||||
let attenuation_per_meter = if depth > 0.0 {
|
||||
attenuation / depth
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Self {
|
||||
attenuation_db: attenuation,
|
||||
attenuation_per_meter,
|
||||
uncertainty_db: uncertainty,
|
||||
frequency_profile: vec![],
|
||||
confidence: (1.0 - uncertainty / attenuation.abs().max(1.0)).max(0.0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Predict signal at given depth
|
||||
pub fn predict_signal_at_depth(&self, depth_m: f32) -> f32 {
|
||||
-self.attenuation_per_meter * depth_m
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for debris model
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DebrisModelConfig {
|
||||
/// Use GPU for inference
|
||||
pub use_gpu: bool,
|
||||
/// Number of inference threads
|
||||
pub num_threads: usize,
|
||||
/// Minimum confidence threshold
|
||||
pub confidence_threshold: f32,
|
||||
}
|
||||
|
||||
impl Default for DebrisModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
use_gpu: false,
|
||||
num_threads: 4,
|
||||
confidence_threshold: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Feature extractor for debris classification
|
||||
pub struct DebrisFeatureExtractor {
|
||||
/// Number of subcarriers to analyze
|
||||
num_subcarriers: usize,
|
||||
/// Window size for temporal analysis
|
||||
window_size: usize,
|
||||
/// Whether to use advanced features
|
||||
use_advanced_features: bool,
|
||||
}
|
||||
|
||||
impl Default for DebrisFeatureExtractor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_subcarriers: 64,
|
||||
window_size: 100,
|
||||
use_advanced_features: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DebrisFeatureExtractor {
|
||||
/// Create new feature extractor
|
||||
pub fn new(num_subcarriers: usize, window_size: usize) -> Self {
|
||||
Self {
|
||||
num_subcarriers,
|
||||
window_size,
|
||||
use_advanced_features: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract features from debris features for model input
|
||||
pub fn extract(&self, features: &DebrisFeatures) -> MlResult<Array2<f32>> {
|
||||
let feature_vector = features.to_feature_vector();
|
||||
|
||||
// Reshape to 2D for model input (batch_size=1, features)
|
||||
let arr = Array2::from_shape_vec(
|
||||
(1, feature_vector.len()),
|
||||
feature_vector,
|
||||
).map_err(|e| MlError::FeatureExtraction(e.to_string()))?;
|
||||
|
||||
Ok(arr)
|
||||
}
|
||||
|
||||
/// Extract spatial-temporal features for CNN input
|
||||
pub fn extract_spatial_temporal(&self, features: &DebrisFeatures) -> MlResult<Array4<f32>> {
|
||||
let amp_len = features.amplitude_attenuation.len().min(self.num_subcarriers);
|
||||
let phase_len = features.phase_shifts.len().min(self.num_subcarriers);
|
||||
|
||||
// Create 4D tensor: [batch, channels, height, width]
|
||||
// channels: amplitude, phase
|
||||
// height: subcarriers
|
||||
// width: 1 (or temporal windows if available)
|
||||
let mut tensor = Array4::<f32>::zeros((1, 2, self.num_subcarriers, 1));
|
||||
|
||||
// Fill amplitude channel
|
||||
for (i, &v) in features.amplitude_attenuation.iter().take(amp_len).enumerate() {
|
||||
tensor[[0, 0, i, 0]] = v;
|
||||
}
|
||||
|
||||
// Fill phase channel
|
||||
for (i, &v) in features.phase_shifts.iter().take(phase_len).enumerate() {
|
||||
tensor[[0, 1, i, 0]] = v;
|
||||
}
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
/// ONNX-based debris penetration model
|
||||
pub struct DebrisModel {
|
||||
config: DebrisModelConfig,
|
||||
feature_extractor: DebrisFeatureExtractor,
|
||||
/// Material classification model weights (for rule-based fallback)
|
||||
material_weights: MaterialClassificationWeights,
|
||||
/// Whether ONNX model is loaded
|
||||
model_loaded: bool,
|
||||
/// Cached model session
|
||||
#[cfg(feature = "onnx")]
|
||||
session: Option<Arc<RwLock<OnnxSession>>>,
|
||||
}
|
||||
|
||||
/// Pre-computed weights for rule-based material classification
|
||||
struct MaterialClassificationWeights {
|
||||
/// Weights for attenuation features
|
||||
attenuation_weights: [f32; MaterialType::NUM_CLASSES],
|
||||
/// Weights for delay spread features
|
||||
delay_weights: [f32; MaterialType::NUM_CLASSES],
|
||||
/// Weights for coherence bandwidth
|
||||
coherence_weights: [f32; MaterialType::NUM_CLASSES],
|
||||
/// Bias terms
|
||||
biases: [f32; MaterialType::NUM_CLASSES],
|
||||
}
|
||||
|
||||
impl Default for MaterialClassificationWeights {
|
||||
fn default() -> Self {
|
||||
// Pre-computed weights based on material RF properties
|
||||
Self {
|
||||
attenuation_weights: [0.8, 0.3, 0.95, 0.1, 0.6, 0.15, 0.5, 0.4],
|
||||
delay_weights: [0.7, 0.2, 0.9, 0.1, 0.5, 0.1, 0.4, 0.3],
|
||||
coherence_weights: [0.3, 0.7, 0.1, 0.9, 0.4, 0.8, 0.5, 0.5],
|
||||
biases: [-0.5, 0.2, -0.8, 0.5, -0.3, 0.3, 0.0, 0.0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DebrisModel {
|
||||
/// Create a new debris model from ONNX file
|
||||
#[instrument(skip(path))]
|
||||
pub fn from_onnx<P: AsRef<Path>>(path: P, config: DebrisModelConfig) -> MlResult<Self> {
|
||||
let path_ref = path.as_ref();
|
||||
info!(?path_ref, "Loading debris model");
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
let session = if path_ref.exists() {
|
||||
let options = InferenceOptions {
|
||||
use_gpu: config.use_gpu,
|
||||
num_threads: config.num_threads,
|
||||
..Default::default()
|
||||
};
|
||||
match OnnxSession::from_file(path_ref, &options) {
|
||||
Ok(s) => {
|
||||
info!("ONNX debris model loaded successfully");
|
||||
Some(Arc::new(RwLock::new(s)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(?e, "Failed to load ONNX model, using rule-based fallback");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(?path_ref, "Model file not found, using rule-based fallback");
|
||||
None
|
||||
};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
let model_loaded = session.is_some();
|
||||
|
||||
#[cfg(not(feature = "onnx"))]
|
||||
let model_loaded = false;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
feature_extractor: DebrisFeatureExtractor::default(),
|
||||
material_weights: MaterialClassificationWeights::default(),
|
||||
model_loaded,
|
||||
#[cfg(feature = "onnx")]
|
||||
session,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with in-memory model bytes
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn from_bytes(bytes: &[u8], config: DebrisModelConfig) -> MlResult<Self> {
|
||||
let options = InferenceOptions {
|
||||
use_gpu: config.use_gpu,
|
||||
num_threads: config.num_threads,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let session = OnnxSession::from_bytes(bytes, &options)
|
||||
.map_err(|e| MlError::ModelLoad(e.to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
feature_extractor: DebrisFeatureExtractor::default(),
|
||||
material_weights: MaterialClassificationWeights::default(),
|
||||
model_loaded: true,
|
||||
session: Some(Arc::new(RwLock::new(session))),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a rule-based model (no ONNX required)
|
||||
pub fn rule_based(config: DebrisModelConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
feature_extractor: DebrisFeatureExtractor::default(),
|
||||
material_weights: MaterialClassificationWeights::default(),
|
||||
model_loaded: false,
|
||||
#[cfg(feature = "onnx")]
|
||||
session: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if ONNX model is loaded
|
||||
pub fn is_loaded(&self) -> bool {
|
||||
self.model_loaded
|
||||
}
|
||||
|
||||
/// Classify material type from debris features
|
||||
#[instrument(skip(self, features))]
|
||||
pub async fn classify(&self, features: &DebrisFeatures) -> MlResult<DebrisClassification> {
|
||||
#[cfg(feature = "onnx")]
|
||||
if let Some(ref session) = self.session {
|
||||
return self.classify_onnx(features, session).await;
|
||||
}
|
||||
|
||||
// Fall back to rule-based classification
|
||||
self.classify_rules(features)
|
||||
}
|
||||
|
||||
/// ONNX-based classification
|
||||
#[cfg(feature = "onnx")]
|
||||
async fn classify_onnx(
|
||||
&self,
|
||||
features: &DebrisFeatures,
|
||||
session: &Arc<RwLock<OnnxSession>>,
|
||||
) -> MlResult<DebrisClassification> {
|
||||
let input_features = self.feature_extractor.extract(features)?;
|
||||
|
||||
// Prepare input tensor
|
||||
let input_array = Array4::from_shape_vec(
|
||||
(1, 1, 1, input_features.len()),
|
||||
input_features.iter().cloned().collect(),
|
||||
).map_err(|e| MlError::Inference(e.to_string()))?;
|
||||
|
||||
let input_tensor = Tensor::Float4D(input_array);
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("input".to_string(), input_tensor);
|
||||
|
||||
// Run inference
|
||||
let outputs = session.write().run(inputs)
|
||||
.map_err(|e| MlError::NeuralNetwork(e))?;
|
||||
|
||||
// Extract classification probabilities
|
||||
let probabilities = if let Some(output) = outputs.get("material_probs") {
|
||||
output.to_vec()
|
||||
.map_err(|e| MlError::Inference(e.to_string()))?
|
||||
} else {
|
||||
// Fallback to rule-based
|
||||
return self.classify_rules(features);
|
||||
};
|
||||
|
||||
// Ensure we have enough classes
|
||||
let mut probs = vec![0.0f32; MaterialType::NUM_CLASSES];
|
||||
for (i, &p) in probabilities.iter().take(MaterialType::NUM_CLASSES).enumerate() {
|
||||
probs[i] = p;
|
||||
}
|
||||
|
||||
// Apply softmax normalization
|
||||
let max_val = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = probs.iter().map(|&x| (x - max_val).exp()).sum();
|
||||
for p in &mut probs {
|
||||
*p = (*p - max_val).exp() / exp_sum;
|
||||
}
|
||||
|
||||
Ok(DebrisClassification::new(probs))
|
||||
}
|
||||
|
||||
/// Rule-based material classification (fallback)
|
||||
fn classify_rules(&self, features: &DebrisFeatures) -> MlResult<DebrisClassification> {
|
||||
let mut scores = [0.0f32; MaterialType::NUM_CLASSES];
|
||||
|
||||
// Normalize input features
|
||||
let attenuation_score = (features.snr_db.abs() / 30.0).min(1.0);
|
||||
let delay_score = (features.delay_spread / 200.0).min(1.0);
|
||||
let coherence_score = (features.coherence_bandwidth / 20.0).min(1.0);
|
||||
let stability_score = features.temporal_stability;
|
||||
|
||||
// Compute weighted scores for each material
|
||||
for i in 0..MaterialType::NUM_CLASSES {
|
||||
scores[i] = self.material_weights.attenuation_weights[i] * attenuation_score
|
||||
+ self.material_weights.delay_weights[i] * delay_score
|
||||
+ self.material_weights.coherence_weights[i] * (1.0 - coherence_score)
|
||||
+ self.material_weights.biases[i]
|
||||
+ 0.1 * stability_score;
|
||||
}
|
||||
|
||||
// Apply softmax
|
||||
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = scores.iter().map(|&s| (s - max_score).exp()).sum();
|
||||
let probabilities: Vec<f32> = scores.iter()
|
||||
.map(|&s| (s - max_score).exp() / exp_sum)
|
||||
.collect();
|
||||
|
||||
Ok(DebrisClassification::new(probabilities))
|
||||
}
|
||||
|
||||
/// Predict signal attenuation through debris
|
||||
#[instrument(skip(self, features))]
|
||||
pub async fn predict_attenuation(&self, features: &DebrisFeatures) -> MlResult<AttenuationPrediction> {
|
||||
// Get material classification first
|
||||
let classification = self.classify(features).await?;
|
||||
|
||||
// Base attenuation from material type
|
||||
let base_attenuation = classification.material_type.typical_attenuation();
|
||||
|
||||
// Adjust based on measured features
|
||||
let measured_factor = if features.snr_db < 0.0 {
|
||||
1.0 + (features.snr_db.abs() / 30.0).min(1.0)
|
||||
} else {
|
||||
1.0 - (features.snr_db / 30.0).min(0.5)
|
||||
};
|
||||
|
||||
// Layer factor
|
||||
let layer_factor = 1.0 + 0.2 * (classification.estimated_layers as f32 - 1.0);
|
||||
|
||||
// Composite factor
|
||||
let composite_factor = if classification.is_composite { 1.2 } else { 1.0 };
|
||||
|
||||
let total_attenuation = base_attenuation * measured_factor * layer_factor * composite_factor;
|
||||
|
||||
// Uncertainty estimation
|
||||
let uncertainty = if classification.is_composite {
|
||||
total_attenuation * 0.3 // Higher uncertainty for composite
|
||||
} else {
|
||||
total_attenuation * (1.0 - classification.confidence) * 0.5
|
||||
};
|
||||
|
||||
// Estimate depth (will be refined by depth estimation)
|
||||
let estimated_depth = self.estimate_depth_internal(features, total_attenuation);
|
||||
|
||||
Ok(AttenuationPrediction::new(total_attenuation, estimated_depth, uncertainty))
|
||||
}
|
||||
|
||||
/// Estimate penetration depth
|
||||
#[instrument(skip(self, features))]
|
||||
pub async fn estimate_depth(&self, features: &DebrisFeatures) -> MlResult<DepthEstimate> {
|
||||
// Get attenuation prediction
|
||||
let attenuation = self.predict_attenuation(features).await?;
|
||||
|
||||
// Estimate depth from attenuation and material properties
|
||||
let depth = self.estimate_depth_internal(features, attenuation.attenuation_db);
|
||||
|
||||
// Calculate uncertainty
|
||||
let uncertainty = self.calculate_depth_uncertainty(
|
||||
features,
|
||||
depth,
|
||||
attenuation.confidence,
|
||||
);
|
||||
|
||||
let confidence = (attenuation.confidence * features.temporal_stability).min(1.0);
|
||||
|
||||
Ok(DepthEstimate::new(depth, uncertainty, confidence))
|
||||
}
|
||||
|
||||
/// Internal depth estimation logic
|
||||
fn estimate_depth_internal(&self, features: &DebrisFeatures, attenuation_db: f32) -> f32 {
|
||||
// Use coherence bandwidth for depth estimation
|
||||
// Smaller coherence bandwidth suggests more multipath = deeper penetration
|
||||
let cb_depth = (20.0 - features.coherence_bandwidth) / 5.0;
|
||||
|
||||
// Use delay spread
|
||||
let ds_depth = features.delay_spread / 100.0;
|
||||
|
||||
// Use attenuation (assuming typical material)
|
||||
let att_depth = attenuation_db / 15.0;
|
||||
|
||||
// Combine estimates with weights
|
||||
let depth = 0.3 * cb_depth + 0.3 * ds_depth + 0.4 * att_depth;
|
||||
|
||||
// Clamp to reasonable range (0.1 - 10 meters)
|
||||
depth.clamp(0.1, 10.0)
|
||||
}
|
||||
|
||||
/// Calculate uncertainty in depth estimate
|
||||
fn calculate_depth_uncertainty(
|
||||
&self,
|
||||
features: &DebrisFeatures,
|
||||
depth: f32,
|
||||
confidence: f32,
|
||||
) -> f32 {
|
||||
// Base uncertainty proportional to depth
|
||||
let base_uncertainty = depth * 0.2;
|
||||
|
||||
// Adjust by temporal stability (less stable = more uncertain)
|
||||
let stability_factor = 1.0 + (1.0 - features.temporal_stability) * 0.5;
|
||||
|
||||
// Adjust by confidence (lower confidence = more uncertain)
|
||||
let confidence_factor = 1.0 + (1.0 - confidence) * 0.5;
|
||||
|
||||
// Adjust by multipath richness (more multipath = harder to estimate)
|
||||
let multipath_factor = 1.0 + features.multipath_richness * 0.3;
|
||||
|
||||
base_uncertainty * stability_factor * confidence_factor * multipath_factor
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::detection::CsiDataBuffer;
|
||||
|
||||
fn create_test_debris_features() -> DebrisFeatures {
|
||||
DebrisFeatures {
|
||||
amplitude_attenuation: vec![0.5; 64],
|
||||
phase_shifts: vec![0.1; 64],
|
||||
fading_profile: vec![0.8, 0.6, 0.4, 0.2, 0.1, 0.05, 0.02, 0.01],
|
||||
coherence_bandwidth: 5.0,
|
||||
delay_spread: 100.0,
|
||||
snr_db: 15.0,
|
||||
multipath_richness: 0.6,
|
||||
temporal_stability: 0.8,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_material_type() {
|
||||
assert_eq!(MaterialType::from_index(0), MaterialType::Concrete);
|
||||
assert_eq!(MaterialType::Concrete.to_index(), 0);
|
||||
assert!(MaterialType::Concrete.typical_attenuation() > MaterialType::Glass.typical_attenuation());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debris_classification() {
|
||||
let probs = vec![0.7, 0.1, 0.05, 0.05, 0.05, 0.02, 0.02, 0.01];
|
||||
let classification = DebrisClassification::new(probs);
|
||||
|
||||
assert_eq!(classification.material_type, MaterialType::Concrete);
|
||||
assert!(classification.confidence > 0.6);
|
||||
assert!(!classification.is_composite);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_composite_detection() {
|
||||
let probs = vec![0.4, 0.35, 0.1, 0.05, 0.05, 0.02, 0.02, 0.01];
|
||||
let classification = DebrisClassification::new(probs);
|
||||
|
||||
assert!(classification.is_composite);
|
||||
assert_eq!(classification.material_type, MaterialType::Mixed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attenuation_prediction() {
|
||||
let pred = AttenuationPrediction::new(25.0, 2.0, 3.0);
|
||||
assert_eq!(pred.attenuation_per_meter, 12.5);
|
||||
assert!(pred.confidence > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rule_based_classification() {
|
||||
let config = DebrisModelConfig::default();
|
||||
let model = DebrisModel::rule_based(config);
|
||||
|
||||
let features = create_test_debris_features();
|
||||
let result = model.classify(&features).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let classification = result.unwrap();
|
||||
assert!(classification.confidence > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_depth_estimation() {
|
||||
let config = DebrisModelConfig::default();
|
||||
let model = DebrisModel::rule_based(config);
|
||||
|
||||
let features = create_test_debris_features();
|
||||
let result = model.estimate_depth(&features).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let estimate = result.unwrap();
|
||||
assert!(estimate.depth_meters > 0.0);
|
||||
assert!(estimate.depth_meters < 10.0);
|
||||
assert!(estimate.uncertainty_meters > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feature_extractor() {
|
||||
let extractor = DebrisFeatureExtractor::default();
|
||||
let features = create_test_debris_features();
|
||||
|
||||
let result = extractor.extract(&features);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let arr = result.unwrap();
|
||||
assert_eq!(arr.shape()[0], 1);
|
||||
assert_eq!(arr.shape()[1], 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spatial_temporal_extraction() {
|
||||
let extractor = DebrisFeatureExtractor::new(64, 100);
|
||||
let features = create_test_debris_features();
|
||||
|
||||
let result = extractor.extract_spatial_temporal(&features);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let arr = result.unwrap();
|
||||
assert_eq!(arr.shape(), &[1, 2, 64, 1]);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,692 @@
|
||||
//! Machine Learning module for debris penetration pattern recognition.
|
||||
//!
|
||||
//! This module provides ML-based models for:
|
||||
//! - Debris material classification
|
||||
//! - Penetration depth prediction
|
||||
//! - Signal attenuation analysis
|
||||
//! - Vital signs classification with uncertainty estimation
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The ML subsystem integrates with the `wifi-densepose-nn` crate for ONNX inference
|
||||
//! and provides specialized models for disaster response scenarios.
|
||||
//!
|
||||
//! ```text
|
||||
//! CSI Data -> Feature Extraction -> Model Inference -> Predictions
|
||||
//! | | |
|
||||
//! v v v
|
||||
//! [Debris Features] [ONNX Models] [Classifications]
|
||||
//! [Signal Features] [Neural Nets] [Confidences]
|
||||
//! ```
|
||||
|
||||
mod debris_model;
|
||||
mod vital_signs_classifier;
|
||||
|
||||
pub use debris_model::{
|
||||
DebrisModel, DebrisModelConfig, DebrisFeatureExtractor,
|
||||
MaterialType, DebrisClassification, AttenuationPrediction,
|
||||
DebrisModelError,
|
||||
};
|
||||
|
||||
pub use vital_signs_classifier::{
|
||||
VitalSignsClassifier, VitalSignsClassifierConfig,
|
||||
BreathingClassification, HeartbeatClassification,
|
||||
UncertaintyEstimate, ClassifierOutput,
|
||||
};
|
||||
|
||||
use crate::detection::CsiDataBuffer;
|
||||
use crate::domain::{VitalSignsReading, BreathingPattern, HeartbeatSignature};
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur in ML operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MlError {
|
||||
/// Model loading error
|
||||
#[error("Failed to load model: {0}")]
|
||||
ModelLoad(String),
|
||||
|
||||
/// Inference error
|
||||
#[error("Inference failed: {0}")]
|
||||
Inference(String),
|
||||
|
||||
/// Feature extraction error
|
||||
#[error("Feature extraction failed: {0}")]
|
||||
FeatureExtraction(String),
|
||||
|
||||
/// Invalid input error
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
/// Model not initialized
|
||||
#[error("Model not initialized: {0}")]
|
||||
NotInitialized(String),
|
||||
|
||||
/// Configuration error
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
/// Integration error with wifi-densepose-nn
|
||||
#[error("Neural network error: {0}")]
|
||||
NeuralNetwork(#[from] wifi_densepose_nn::NnError),
|
||||
}
|
||||
|
||||
/// Result type for ML operations
|
||||
pub type MlResult<T> = Result<T, MlError>;
|
||||
|
||||
/// Trait for debris penetration models
|
||||
///
|
||||
/// This trait defines the interface for models that can predict
|
||||
/// material type and signal attenuation through debris layers.
|
||||
#[async_trait]
|
||||
pub trait DebrisPenetrationModel: Send + Sync {
|
||||
/// Classify the material type from CSI features
|
||||
async fn classify_material(&self, features: &DebrisFeatures) -> MlResult<MaterialType>;
|
||||
|
||||
/// Predict signal attenuation through debris
|
||||
async fn predict_attenuation(&self, features: &DebrisFeatures) -> MlResult<AttenuationPrediction>;
|
||||
|
||||
/// Estimate penetration depth in meters
|
||||
async fn estimate_depth(&self, features: &DebrisFeatures) -> MlResult<DepthEstimate>;
|
||||
|
||||
/// Get model confidence for the predictions
|
||||
fn model_confidence(&self) -> f32;
|
||||
|
||||
/// Check if the model is loaded and ready
|
||||
fn is_ready(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Features extracted from CSI data for debris analysis
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DebrisFeatures {
|
||||
/// Amplitude attenuation across subcarriers
|
||||
pub amplitude_attenuation: Vec<f32>,
|
||||
/// Phase shift patterns
|
||||
pub phase_shifts: Vec<f32>,
|
||||
/// Frequency-selective fading characteristics
|
||||
pub fading_profile: Vec<f32>,
|
||||
/// Coherence bandwidth estimate
|
||||
pub coherence_bandwidth: f32,
|
||||
/// RMS delay spread
|
||||
pub delay_spread: f32,
|
||||
/// Signal-to-noise ratio estimate
|
||||
pub snr_db: f32,
|
||||
/// Multipath richness indicator
|
||||
pub multipath_richness: f32,
|
||||
/// Temporal stability metric
|
||||
pub temporal_stability: f32,
|
||||
}
|
||||
|
||||
impl DebrisFeatures {
|
||||
/// Create new debris features from raw CSI data
|
||||
pub fn from_csi(buffer: &CsiDataBuffer) -> MlResult<Self> {
|
||||
if buffer.amplitudes.is_empty() {
|
||||
return Err(MlError::FeatureExtraction("Empty CSI buffer".into()));
|
||||
}
|
||||
|
||||
// Calculate amplitude attenuation
|
||||
let amplitude_attenuation = Self::compute_amplitude_features(&buffer.amplitudes);
|
||||
|
||||
// Calculate phase shifts
|
||||
let phase_shifts = Self::compute_phase_features(&buffer.phases);
|
||||
|
||||
// Compute fading profile
|
||||
let fading_profile = Self::compute_fading_profile(&buffer.amplitudes);
|
||||
|
||||
// Estimate coherence bandwidth from frequency correlation
|
||||
let coherence_bandwidth = Self::estimate_coherence_bandwidth(&buffer.amplitudes);
|
||||
|
||||
// Estimate delay spread
|
||||
let delay_spread = Self::estimate_delay_spread(&buffer.amplitudes);
|
||||
|
||||
// Estimate SNR
|
||||
let snr_db = Self::estimate_snr(&buffer.amplitudes);
|
||||
|
||||
// Multipath richness
|
||||
let multipath_richness = Self::compute_multipath_richness(&buffer.amplitudes);
|
||||
|
||||
// Temporal stability
|
||||
let temporal_stability = Self::compute_temporal_stability(&buffer.amplitudes);
|
||||
|
||||
Ok(Self {
|
||||
amplitude_attenuation,
|
||||
phase_shifts,
|
||||
fading_profile,
|
||||
coherence_bandwidth,
|
||||
delay_spread,
|
||||
snr_db,
|
||||
multipath_richness,
|
||||
temporal_stability,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute amplitude features
|
||||
fn compute_amplitude_features(amplitudes: &[f64]) -> Vec<f32> {
|
||||
if amplitudes.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
let variance = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / amplitudes.len() as f64;
|
||||
let std_dev = variance.sqrt();
|
||||
|
||||
// Normalize amplitudes
|
||||
amplitudes.iter()
|
||||
.map(|a| ((a - mean) / (std_dev + 1e-8)) as f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute phase features
|
||||
fn compute_phase_features(phases: &[f64]) -> Vec<f32> {
|
||||
if phases.len() < 2 {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Compute phase differences (unwrapped)
|
||||
phases.windows(2)
|
||||
.map(|w| {
|
||||
let diff = w[1] - w[0];
|
||||
// Unwrap phase
|
||||
let unwrapped = if diff > std::f64::consts::PI {
|
||||
diff - 2.0 * std::f64::consts::PI
|
||||
} else if diff < -std::f64::consts::PI {
|
||||
diff + 2.0 * std::f64::consts::PI
|
||||
} else {
|
||||
diff
|
||||
};
|
||||
unwrapped as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute fading profile (power spectral characteristics)
|
||||
fn compute_fading_profile(amplitudes: &[f64]) -> Vec<f32> {
|
||||
use rustfft::{FftPlanner, num_complex::Complex};
|
||||
|
||||
if amplitudes.len() < 16 {
|
||||
return vec![0.0; 8];
|
||||
}
|
||||
|
||||
// Take a subset for FFT
|
||||
let n = 64.min(amplitudes.len());
|
||||
let mut buffer: Vec<Complex<f64>> = amplitudes.iter()
|
||||
.take(n)
|
||||
.map(|&a| Complex::new(a, 0.0))
|
||||
.collect();
|
||||
|
||||
// Pad to power of 2
|
||||
while buffer.len() < 64 {
|
||||
buffer.push(Complex::new(0.0, 0.0));
|
||||
}
|
||||
|
||||
// Compute FFT
|
||||
let mut planner = FftPlanner::new();
|
||||
let fft = planner.plan_fft_forward(64);
|
||||
fft.process(&mut buffer);
|
||||
|
||||
// Extract power spectrum (first half)
|
||||
buffer.iter()
|
||||
.take(8)
|
||||
.map(|c| (c.norm() / n as f64) as f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Estimate coherence bandwidth from frequency correlation
|
||||
fn estimate_coherence_bandwidth(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute autocorrelation
|
||||
let n = amplitudes.len();
|
||||
let mean = amplitudes.iter().sum::<f64>() / n as f64;
|
||||
let variance: f64 = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / n as f64;
|
||||
|
||||
if variance < 1e-10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Find lag where correlation drops below 0.5
|
||||
let mut coherence_lag = n;
|
||||
for lag in 1..n / 2 {
|
||||
let correlation: f64 = amplitudes.iter()
|
||||
.take(n - lag)
|
||||
.zip(amplitudes.iter().skip(lag))
|
||||
.map(|(a, b)| (a - mean) * (b - mean))
|
||||
.sum::<f64>() / ((n - lag) as f64 * variance);
|
||||
|
||||
if correlation < 0.5 {
|
||||
coherence_lag = lag;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to bandwidth estimate (assuming 20 MHz channel)
|
||||
(20.0 / coherence_lag as f32).min(20.0)
|
||||
}
|
||||
|
||||
/// Estimate RMS delay spread
|
||||
fn estimate_delay_spread(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Use power delay profile approximation
|
||||
let power: Vec<f64> = amplitudes.iter().map(|a| a.powi(2)).collect();
|
||||
let total_power: f64 = power.iter().sum();
|
||||
|
||||
if total_power < 1e-10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Calculate mean delay
|
||||
let mean_delay: f64 = power.iter()
|
||||
.enumerate()
|
||||
.map(|(i, p)| i as f64 * p)
|
||||
.sum::<f64>() / total_power;
|
||||
|
||||
// Calculate RMS delay spread
|
||||
let variance: f64 = power.iter()
|
||||
.enumerate()
|
||||
.map(|(i, p)| (i as f64 - mean_delay).powi(2) * p)
|
||||
.sum::<f64>() / total_power;
|
||||
|
||||
// Convert to nanoseconds (assuming sample period)
|
||||
(variance.sqrt() * 50.0) as f32 // 50 ns per sample assumed
|
||||
}
|
||||
|
||||
/// Estimate SNR from amplitude variance
|
||||
fn estimate_snr(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
let variance = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / amplitudes.len() as f64;
|
||||
|
||||
if variance < 1e-10 {
|
||||
return 30.0; // High SNR assumed
|
||||
}
|
||||
|
||||
// SNR estimate based on signal power to noise power ratio
|
||||
let signal_power = mean.powi(2);
|
||||
let snr_linear = signal_power / variance;
|
||||
|
||||
(10.0 * snr_linear.log10()) as f32
|
||||
}
|
||||
|
||||
/// Compute multipath richness indicator
|
||||
fn compute_multipath_richness(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 10 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Calculate amplitude variance as multipath indicator
|
||||
let mean = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
let variance = amplitudes.iter()
|
||||
.map(|a| (a - mean).powi(2))
|
||||
.sum::<f64>() / amplitudes.len() as f64;
|
||||
|
||||
// Normalize to 0-1 range
|
||||
let std_dev = variance.sqrt();
|
||||
let normalized = std_dev / (mean.abs() + 1e-8);
|
||||
|
||||
(normalized.min(1.0)) as f32
|
||||
}
|
||||
|
||||
/// Compute temporal stability metric
|
||||
fn compute_temporal_stability(amplitudes: &[f64]) -> f32 {
|
||||
if amplitudes.len() < 2 {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
// Calculate coefficient of variation over time
|
||||
let differences: Vec<f64> = amplitudes.windows(2)
|
||||
.map(|w| (w[1] - w[0]).abs())
|
||||
.collect();
|
||||
|
||||
let mean_diff = differences.iter().sum::<f64>() / differences.len() as f64;
|
||||
let mean_amp = amplitudes.iter().sum::<f64>() / amplitudes.len() as f64;
|
||||
|
||||
// Stability is inverse of relative variation
|
||||
let variation = mean_diff / (mean_amp.abs() + 1e-8);
|
||||
|
||||
(1.0 - variation.min(1.0)) as f32
|
||||
}
|
||||
|
||||
/// Convert to feature vector for model input
|
||||
pub fn to_feature_vector(&self) -> Vec<f32> {
|
||||
let mut features = Vec::with_capacity(256);
|
||||
|
||||
// Add amplitude attenuation features (padded/truncated to 64)
|
||||
let amp_len = self.amplitude_attenuation.len().min(64);
|
||||
features.extend_from_slice(&self.amplitude_attenuation[..amp_len]);
|
||||
features.resize(64, 0.0);
|
||||
|
||||
// Add phase shift features (padded/truncated to 64)
|
||||
let phase_len = self.phase_shifts.len().min(64);
|
||||
features.extend_from_slice(&self.phase_shifts[..phase_len]);
|
||||
features.resize(128, 0.0);
|
||||
|
||||
// Add fading profile (padded to 16)
|
||||
let fading_len = self.fading_profile.len().min(16);
|
||||
features.extend_from_slice(&self.fading_profile[..fading_len]);
|
||||
features.resize(144, 0.0);
|
||||
|
||||
// Add scalar features
|
||||
features.push(self.coherence_bandwidth);
|
||||
features.push(self.delay_spread);
|
||||
features.push(self.snr_db);
|
||||
features.push(self.multipath_richness);
|
||||
features.push(self.temporal_stability);
|
||||
|
||||
// Pad to 256 for model input
|
||||
features.resize(256, 0.0);
|
||||
|
||||
features
|
||||
}
|
||||
}
|
||||
|
||||
/// Depth estimate with uncertainty
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DepthEstimate {
|
||||
/// Estimated depth in meters
|
||||
pub depth_meters: f32,
|
||||
/// Uncertainty (standard deviation) in meters
|
||||
pub uncertainty_meters: f32,
|
||||
/// Confidence in the estimate (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// Lower bound of 95% confidence interval
|
||||
pub lower_bound: f32,
|
||||
/// Upper bound of 95% confidence interval
|
||||
pub upper_bound: f32,
|
||||
}
|
||||
|
||||
impl DepthEstimate {
|
||||
/// Create a new depth estimate with uncertainty
|
||||
pub fn new(depth: f32, uncertainty: f32, confidence: f32) -> Self {
|
||||
Self {
|
||||
depth_meters: depth,
|
||||
uncertainty_meters: uncertainty,
|
||||
confidence,
|
||||
lower_bound: (depth - 1.96 * uncertainty).max(0.0),
|
||||
upper_bound: depth + 1.96 * uncertainty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the estimate is reliable (high confidence, low uncertainty)
|
||||
pub fn is_reliable(&self) -> bool {
|
||||
self.confidence > 0.7 && self.uncertainty_meters < self.depth_meters * 0.3
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the ML-enhanced detection pipeline
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct MlDetectionConfig {
|
||||
/// Enable ML-based debris classification
|
||||
pub enable_debris_classification: bool,
|
||||
/// Enable ML-based vital signs classification
|
||||
pub enable_vital_classification: bool,
|
||||
/// Path to debris model file
|
||||
pub debris_model_path: Option<String>,
|
||||
/// Path to vital signs model file
|
||||
pub vital_model_path: Option<String>,
|
||||
/// Minimum confidence threshold for ML predictions
|
||||
pub min_confidence: f32,
|
||||
/// Use GPU for inference
|
||||
pub use_gpu: bool,
|
||||
/// Number of inference threads
|
||||
pub num_threads: usize,
|
||||
}
|
||||
|
||||
impl Default for MlDetectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable_debris_classification: false,
|
||||
enable_vital_classification: false,
|
||||
debris_model_path: None,
|
||||
vital_model_path: None,
|
||||
min_confidence: 0.5,
|
||||
use_gpu: false,
|
||||
num_threads: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MlDetectionConfig {
|
||||
/// Create configuration for CPU inference
|
||||
pub fn cpu() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create configuration for GPU inference
|
||||
pub fn gpu() -> Self {
|
||||
Self {
|
||||
use_gpu: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable debris classification with model path
|
||||
pub fn with_debris_model<P: Into<String>>(mut self, path: P) -> Self {
|
||||
self.debris_model_path = Some(path.into());
|
||||
self.enable_debris_classification = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable vital signs classification with model path
|
||||
pub fn with_vital_model<P: Into<String>>(mut self, path: P) -> Self {
|
||||
self.vital_model_path = Some(path.into());
|
||||
self.enable_vital_classification = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set minimum confidence threshold
|
||||
pub fn with_min_confidence(mut self, confidence: f32) -> Self {
|
||||
self.min_confidence = confidence.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// ML-enhanced detection pipeline that combines traditional and ML-based detection
|
||||
pub struct MlDetectionPipeline {
|
||||
config: MlDetectionConfig,
|
||||
debris_model: Option<DebrisModel>,
|
||||
vital_classifier: Option<VitalSignsClassifier>,
|
||||
}
|
||||
|
||||
impl MlDetectionPipeline {
|
||||
/// Create a new ML detection pipeline
|
||||
pub fn new(config: MlDetectionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
debris_model: None,
|
||||
vital_classifier: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize models asynchronously
|
||||
pub async fn initialize(&mut self) -> MlResult<()> {
|
||||
if self.config.enable_debris_classification {
|
||||
if let Some(ref path) = self.config.debris_model_path {
|
||||
let debris_config = DebrisModelConfig {
|
||||
use_gpu: self.config.use_gpu,
|
||||
num_threads: self.config.num_threads,
|
||||
confidence_threshold: self.config.min_confidence,
|
||||
};
|
||||
self.debris_model = Some(DebrisModel::from_onnx(path, debris_config)?);
|
||||
}
|
||||
}
|
||||
|
||||
if self.config.enable_vital_classification {
|
||||
if let Some(ref path) = self.config.vital_model_path {
|
||||
let vital_config = VitalSignsClassifierConfig {
|
||||
use_gpu: self.config.use_gpu,
|
||||
num_threads: self.config.num_threads,
|
||||
min_confidence: self.config.min_confidence,
|
||||
enable_uncertainty: true,
|
||||
mc_samples: 10,
|
||||
dropout_rate: 0.1,
|
||||
};
|
||||
self.vital_classifier = Some(VitalSignsClassifier::from_onnx(path, vital_config)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process CSI data and return enhanced detection results
|
||||
pub async fn process(&self, buffer: &CsiDataBuffer) -> MlResult<MlDetectionResult> {
|
||||
let mut result = MlDetectionResult::default();
|
||||
|
||||
// Extract debris features and classify if enabled
|
||||
if let Some(ref model) = self.debris_model {
|
||||
let features = DebrisFeatures::from_csi(buffer)?;
|
||||
result.debris_classification = Some(model.classify(&features).await?);
|
||||
result.depth_estimate = Some(model.estimate_depth(&features).await?);
|
||||
}
|
||||
|
||||
// Classify vital signs if enabled
|
||||
if let Some(ref classifier) = self.vital_classifier {
|
||||
let features = classifier.extract_features(buffer)?;
|
||||
result.vital_classification = Some(classifier.classify(&features).await?);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Check if the pipeline is ready for inference
|
||||
pub fn is_ready(&self) -> bool {
|
||||
let debris_ready = !self.config.enable_debris_classification
|
||||
|| self.debris_model.as_ref().map_or(false, |m| m.is_loaded());
|
||||
let vital_ready = !self.config.enable_vital_classification
|
||||
|| self.vital_classifier.as_ref().map_or(false, |c| c.is_loaded());
|
||||
|
||||
debris_ready && vital_ready
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &MlDetectionConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Combined ML detection results
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MlDetectionResult {
|
||||
/// Debris classification result
|
||||
pub debris_classification: Option<DebrisClassification>,
|
||||
/// Depth estimate
|
||||
pub depth_estimate: Option<DepthEstimate>,
|
||||
/// Vital signs classification
|
||||
pub vital_classification: Option<ClassifierOutput>,
|
||||
}
|
||||
|
||||
impl MlDetectionResult {
|
||||
/// Check if any ML detection was performed
|
||||
pub fn has_results(&self) -> bool {
|
||||
self.debris_classification.is_some()
|
||||
|| self.depth_estimate.is_some()
|
||||
|| self.vital_classification.is_some()
|
||||
}
|
||||
|
||||
/// Get overall confidence
|
||||
pub fn overall_confidence(&self) -> f32 {
|
||||
let mut total = 0.0;
|
||||
let mut count = 0;
|
||||
|
||||
if let Some(ref debris) = self.debris_classification {
|
||||
total += debris.confidence;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if let Some(ref depth) = self.depth_estimate {
|
||||
total += depth.confidence;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if let Some(ref vital) = self.vital_classification {
|
||||
total += vital.overall_confidence;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
total / count as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_buffer() -> CsiDataBuffer {
|
||||
let mut buffer = CsiDataBuffer::new(1000.0);
|
||||
let amplitudes: Vec<f64> = (0..1000)
|
||||
.map(|i| {
|
||||
let t = i as f64 / 1000.0;
|
||||
0.5 + 0.1 * (2.0 * std::f64::consts::PI * 0.25 * t).sin()
|
||||
})
|
||||
.collect();
|
||||
let phases: Vec<f64> = (0..1000)
|
||||
.map(|i| {
|
||||
let t = i as f64 / 1000.0;
|
||||
(2.0 * std::f64::consts::PI * 0.25 * t).sin() * 0.3
|
||||
})
|
||||
.collect();
|
||||
buffer.add_samples(&litudes, &phases);
|
||||
buffer
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debris_features_extraction() {
|
||||
let buffer = create_test_buffer();
|
||||
let features = DebrisFeatures::from_csi(&buffer);
|
||||
assert!(features.is_ok());
|
||||
|
||||
let features = features.unwrap();
|
||||
assert!(!features.amplitude_attenuation.is_empty());
|
||||
assert!(!features.phase_shifts.is_empty());
|
||||
assert!(features.coherence_bandwidth >= 0.0);
|
||||
assert!(features.delay_spread >= 0.0);
|
||||
assert!(features.temporal_stability >= 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feature_vector_size() {
|
||||
let buffer = create_test_buffer();
|
||||
let features = DebrisFeatures::from_csi(&buffer).unwrap();
|
||||
let vector = features.to_feature_vector();
|
||||
assert_eq!(vector.len(), 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth_estimate() {
|
||||
let estimate = DepthEstimate::new(2.5, 0.3, 0.85);
|
||||
assert!(estimate.is_reliable());
|
||||
assert!(estimate.lower_bound < estimate.depth_meters);
|
||||
assert!(estimate.upper_bound > estimate.depth_meters);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ml_config_builder() {
|
||||
let config = MlDetectionConfig::cpu()
|
||||
.with_debris_model("models/debris.onnx")
|
||||
.with_vital_model("models/vitals.onnx")
|
||||
.with_min_confidence(0.7);
|
||||
|
||||
assert!(config.enable_debris_classification);
|
||||
assert!(config.enable_vital_classification);
|
||||
assert_eq!(config.min_confidence, 0.7);
|
||||
assert!(!config.use_gpu);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user