Compare commits
3 Commits
feat/windo
...
claude/tes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4a739402b | ||
|
|
2ca107c10c | ||
|
|
7c00482314 |
@@ -1,6 +1,6 @@
|
||||
//! Breathing pattern detection from CSI signals.
|
||||
|
||||
use crate::domain::{BreathingPattern, BreathingType, ConfidenceScore};
|
||||
use crate::domain::{BreathingPattern, BreathingType};
|
||||
|
||||
/// Configuration for breathing detection
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! This module provides both traditional signal-processing-based detection
|
||||
//! and optional ML-enhanced detection for improved accuracy.
|
||||
|
||||
use crate::domain::{ScanZone, VitalSignsReading, ConfidenceScore};
|
||||
use crate::domain::{ScanZone, VitalSignsReading};
|
||||
use crate::ml::{MlDetectionConfig, MlDetectionPipeline, MlDetectionResult};
|
||||
use crate::{DisasterConfig, MatError};
|
||||
use super::{
|
||||
|
||||
@@ -28,8 +28,6 @@ use chrono::{DateTime, Utc};
|
||||
use std::collections::VecDeque;
|
||||
use std::io::{BufReader, Read};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
/// Configuration for CSI receivers
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -16,13 +16,10 @@
|
||||
//! - Depth estimation head with uncertainty (mean + variance output)
|
||||
|
||||
use super::{DebrisFeatures, DepthEstimate, MlError, MlResult};
|
||||
use ndarray::{Array1, Array2, Array4, s};
|
||||
use std::collections::HashMap;
|
||||
use ndarray::{Array2, Array4};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::RwLock;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
use tracing::{info, instrument, warn};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
||||
|
||||
@@ -35,7 +35,6 @@ pub use vital_signs_classifier::{
|
||||
};
|
||||
|
||||
use crate::detection::CsiDataBuffer;
|
||||
use crate::domain::{VitalSignsReading, BreathingPattern, HeartbeatSignature};
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -27,12 +27,8 @@ use crate::domain::{
|
||||
BreathingPattern, BreathingType, HeartbeatSignature, MovementProfile,
|
||||
MovementType, SignalStrength, VitalSignsReading,
|
||||
};
|
||||
use ndarray::{Array1, Array2, Array4, s};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::RwLock;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
use tracing::{info, instrument, warn};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
||||
|
||||
@@ -252,7 +252,7 @@ impl DensePoseHead {
|
||||
})?;
|
||||
|
||||
let input_arr = input.as_array4()?;
|
||||
let (batch, _channels, height, width) = input_arr.dim();
|
||||
let (_batch, _channels, _height, _width) = input_arr.dim();
|
||||
|
||||
// Apply shared convolutions
|
||||
let mut current = input_arr.clone();
|
||||
|
||||
@@ -206,7 +206,7 @@ impl Backend for MockBackend {
|
||||
self.output_shapes.get(name).cloned()
|
||||
}
|
||||
|
||||
fn run(&self, inputs: HashMap<String, Tensor>) -> NnResult<HashMap<String, Tensor>> {
|
||||
fn run(&self, _inputs: HashMap<String, Tensor>) -> NnResult<HashMap<String, Tensor>> {
|
||||
let mut outputs = HashMap::new();
|
||||
|
||||
for (name, shape) in &self.output_shapes {
|
||||
|
||||
@@ -266,7 +266,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Apply softmax along axis
|
||||
pub fn softmax(&self, axis: usize) -> NnResult<Tensor> {
|
||||
pub fn softmax(&self, _axis: usize) -> NnResult<Tensor> {
|
||||
match self {
|
||||
Tensor::Float4D(a) => {
|
||||
let max = a.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
|
||||
|
||||
@@ -342,7 +342,7 @@ impl ModalityTranslator {
|
||||
})?;
|
||||
|
||||
let input_arr = input.as_array4()?;
|
||||
let (batch, _channels, height, width) = input_arr.dim();
|
||||
let (_batch, _channels, _height, _width) = input_arr.dim();
|
||||
|
||||
// Encode
|
||||
let mut encoder_outputs = Vec::new();
|
||||
@@ -461,7 +461,7 @@ impl ModalityTranslator {
|
||||
weights: &ConvBlockWeights,
|
||||
) -> NnResult<Array4<f32>> {
|
||||
let (batch, in_channels, in_height, in_width) = input.dim();
|
||||
let (out_channels, _, kernel_h, kernel_w) = weights.conv_weight.dim();
|
||||
let (out_channels, _, _kernel_h, _kernel_w) = weights.conv_weight.dim();
|
||||
|
||||
// Upsample 2x
|
||||
let out_height = in_height * 2;
|
||||
@@ -536,7 +536,7 @@ impl ModalityTranslator {
|
||||
fn apply_attention(
|
||||
&self,
|
||||
input: &Array4<f32>,
|
||||
weights: &AttentionWeights,
|
||||
_weights: &AttentionWeights,
|
||||
) -> NnResult<(Array4<f32>, Array4<f32>)> {
|
||||
let (batch, channels, height, width) = input.dim();
|
||||
let seq_len = height * width;
|
||||
|
||||
@@ -29,7 +29,7 @@ Author: WiFi-DensePose Team
|
||||
License: MIT
|
||||
"""
|
||||
|
||||
__version__ = "1.1.0"
|
||||
__version__ = "1.2.0"
|
||||
__author__ = "WiFi-DensePose Team"
|
||||
__email__ = "team@wifi-densepose.com"
|
||||
__license__ = "MIT"
|
||||
|
||||
@@ -5,9 +5,27 @@ Core package for WiFi-DensePose API
|
||||
from .csi_processor import CSIProcessor
|
||||
from .phase_sanitizer import PhaseSanitizer
|
||||
from .router_interface import RouterInterface
|
||||
from .vital_signs import (
|
||||
VitalSignsDetector,
|
||||
BreathingDetector,
|
||||
HeartbeatDetector,
|
||||
BreathingPattern,
|
||||
HeartbeatSignature,
|
||||
VitalSignsReading,
|
||||
BreathingType,
|
||||
SignalStrength,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'CSIProcessor',
|
||||
'PhaseSanitizer',
|
||||
'RouterInterface'
|
||||
'RouterInterface',
|
||||
'VitalSignsDetector',
|
||||
'BreathingDetector',
|
||||
'HeartbeatDetector',
|
||||
'BreathingPattern',
|
||||
'HeartbeatSignature',
|
||||
'VitalSignsReading',
|
||||
'BreathingType',
|
||||
'SignalStrength',
|
||||
]
|
||||
@@ -385,13 +385,69 @@ class CSIProcessor:
|
||||
return correlation_matrix
|
||||
|
||||
def _extract_doppler_features(self, csi_data: CSIData) -> tuple:
|
||||
"""Extract Doppler and frequency domain features."""
|
||||
# Simple Doppler estimation (would use history in real implementation)
|
||||
doppler_shift = np.random.rand(10) # Placeholder
|
||||
"""Extract Doppler and frequency domain features.
|
||||
|
||||
# Power spectral density
|
||||
Doppler shift estimation from CSI phase changes:
|
||||
- Phase change rate indicates velocity of moving objects
|
||||
- Frequency analysis reveals movement speed and direction
|
||||
|
||||
The Doppler frequency shift is: f_d = (2 * v * f_c) / c
|
||||
Where v = velocity, f_c = carrier frequency, c = speed of light
|
||||
"""
|
||||
# Power spectral density of amplitude
|
||||
psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2
|
||||
|
||||
# Doppler estimation from phase history
|
||||
if len(self.csi_history) < 2:
|
||||
# Not enough history, return zeros
|
||||
doppler_shift = np.zeros(min(csi_data.num_subcarriers, 10))
|
||||
return doppler_shift, psd
|
||||
|
||||
# Get phase from current and previous samples
|
||||
current_phase = csi_data.phase.flatten()
|
||||
prev_data = self.csi_history[-1]
|
||||
|
||||
# Handle if prev_data is tuple (CSIData, features) or just CSIData
|
||||
if isinstance(prev_data, tuple):
|
||||
prev_phase = prev_data[0].phase.flatten()
|
||||
time_delta = (csi_data.timestamp - prev_data[0].timestamp).total_seconds()
|
||||
else:
|
||||
prev_phase = prev_data.phase.flatten()
|
||||
time_delta = 1.0 / self.sampling_rate # Default to sampling interval
|
||||
|
||||
if time_delta <= 0:
|
||||
time_delta = 1.0 / self.sampling_rate
|
||||
|
||||
# Ensure same length
|
||||
min_len = min(len(current_phase), len(prev_phase))
|
||||
current_phase = current_phase[:min_len]
|
||||
prev_phase = prev_phase[:min_len]
|
||||
|
||||
# Calculate phase difference (unwrap to handle wrapping)
|
||||
phase_diff = np.unwrap(current_phase) - np.unwrap(prev_phase)
|
||||
|
||||
# Phase rate of change (rad/s)
|
||||
phase_rate = phase_diff / time_delta
|
||||
|
||||
# Convert to Doppler frequency (Hz)
|
||||
# f_d = (d_phi/dt) / (2 * pi)
|
||||
doppler_freq = phase_rate / (2 * np.pi)
|
||||
|
||||
# Aggregate Doppler per subcarrier group (reduce to ~10 values)
|
||||
num_groups = min(10, len(doppler_freq))
|
||||
group_size = max(1, len(doppler_freq) // num_groups)
|
||||
|
||||
doppler_shift = np.array([
|
||||
np.mean(doppler_freq[i*group_size:(i+1)*group_size])
|
||||
for i in range(num_groups)
|
||||
])
|
||||
|
||||
# Apply smoothing to reduce noise
|
||||
if len(doppler_shift) > 3:
|
||||
# Simple moving average
|
||||
kernel = np.ones(3) / 3
|
||||
doppler_shift = np.convolve(doppler_shift, kernel, mode='same')
|
||||
|
||||
return doppler_shift, psd
|
||||
|
||||
def _analyze_motion_patterns(self, features: CSIFeatures) -> float:
|
||||
|
||||
@@ -1,15 +1,27 @@
|
||||
"""
|
||||
Router interface for WiFi CSI data collection
|
||||
Router interface for WiFi CSI data collection.
|
||||
|
||||
Supports multiple router types:
|
||||
- OpenWRT routers with Atheros CSI Tool
|
||||
- DD-WRT routers with custom CSI extraction
|
||||
- Custom firmware routers with raw CSI access
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import struct
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import asyncssh
|
||||
HAS_ASYNCSSH = True
|
||||
except ImportError:
|
||||
HAS_ASYNCSSH = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -72,18 +84,34 @@ class RouterInterface:
|
||||
}
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the router."""
|
||||
"""Connect to the router via SSH."""
|
||||
if self.mock_mode:
|
||||
self.is_connected = True
|
||||
self.logger.info(f"Mock connection established to router {self.router_id}")
|
||||
return
|
||||
|
||||
if not HAS_ASYNCSSH:
|
||||
self.logger.warning("asyncssh not available, falling back to mock mode")
|
||||
self.mock_mode = True
|
||||
self._initialize_mock_generator()
|
||||
self.is_connected = True
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
||||
|
||||
# In a real implementation, this would establish SSH connection
|
||||
# For now, we'll simulate the connection
|
||||
await asyncio.sleep(0.1) # Simulate connection delay
|
||||
# Establish SSH connection
|
||||
self.connection = await asyncssh.connect(
|
||||
self.host,
|
||||
port=self.port,
|
||||
username=self.username,
|
||||
password=self.password if self.password else None,
|
||||
known_hosts=None, # Disable host key checking for embedded devices
|
||||
connect_timeout=10
|
||||
)
|
||||
|
||||
# Verify connection by checking router type
|
||||
await self._detect_router_type()
|
||||
|
||||
self.is_connected = True
|
||||
self.error_count = 0
|
||||
@@ -95,6 +123,42 @@ class RouterInterface:
|
||||
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _detect_router_type(self):
|
||||
"""Detect router firmware type and CSI capabilities."""
|
||||
if not self.connection:
|
||||
return
|
||||
|
||||
try:
|
||||
# Check for OpenWRT
|
||||
result = await self.connection.run('cat /etc/openwrt_release 2>/dev/null || echo ""', check=False)
|
||||
if 'OpenWrt' in result.stdout:
|
||||
self.router_type = 'openwrt'
|
||||
self.logger.info(f"Detected OpenWRT router: {self.router_id}")
|
||||
return
|
||||
|
||||
# Check for DD-WRT
|
||||
result = await self.connection.run('nvram get DD_BOARD 2>/dev/null || echo ""', check=False)
|
||||
if result.stdout.strip():
|
||||
self.router_type = 'ddwrt'
|
||||
self.logger.info(f"Detected DD-WRT router: {self.router_id}")
|
||||
return
|
||||
|
||||
# Check for Atheros CSI Tool
|
||||
result = await self.connection.run('which csi_tool 2>/dev/null || echo ""', check=False)
|
||||
if result.stdout.strip():
|
||||
self.csi_tool_path = result.stdout.strip()
|
||||
self.router_type = 'atheros_csi'
|
||||
self.logger.info(f"Detected Atheros CSI Tool on router: {self.router_id}")
|
||||
return
|
||||
|
||||
# Default to generic Linux
|
||||
self.router_type = 'generic'
|
||||
self.logger.info(f"Generic Linux router: {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not detect router type: {e}")
|
||||
self.router_type = 'unknown'
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the router."""
|
||||
try:
|
||||
@@ -195,10 +259,243 @@ class RouterInterface:
|
||||
return csi_data
|
||||
|
||||
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Collect real CSI data from router (placeholder implementation)."""
|
||||
# This would implement the actual CSI data collection
|
||||
# For now, return None to indicate no real implementation
|
||||
self.logger.warning("Real CSI data collection not implemented")
|
||||
"""Collect real CSI data from router via SSH.
|
||||
|
||||
Supports multiple CSI extraction methods:
|
||||
- Atheros CSI Tool (ath9k/ath10k)
|
||||
- Custom kernel module reading
|
||||
- Proc filesystem access
|
||||
- Raw device file reading
|
||||
|
||||
Returns:
|
||||
Numpy array of complex CSI values or None on failure
|
||||
"""
|
||||
if not self.connection:
|
||||
self.logger.error("No SSH connection available")
|
||||
return None
|
||||
|
||||
try:
|
||||
router_type = getattr(self, 'router_type', 'unknown')
|
||||
|
||||
if router_type == 'atheros_csi':
|
||||
return await self._collect_atheros_csi()
|
||||
elif router_type == 'openwrt':
|
||||
return await self._collect_openwrt_csi()
|
||||
else:
|
||||
return await self._collect_generic_csi()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error collecting CSI data: {e}")
|
||||
self.error_count += 1
|
||||
return None
|
||||
|
||||
async def _collect_atheros_csi(self) -> Optional[np.ndarray]:
|
||||
"""Collect CSI using Atheros CSI Tool."""
|
||||
csi_tool = getattr(self, 'csi_tool_path', '/usr/bin/csi_tool')
|
||||
|
||||
try:
|
||||
# Read single CSI sample
|
||||
result = await self.connection.run(
|
||||
f'{csi_tool} -i {self.interface} -c 1 -f /tmp/csi_sample.dat && '
|
||||
f'cat /tmp/csi_sample.dat | base64',
|
||||
check=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
# Decode base64 CSI data
|
||||
import base64
|
||||
csi_bytes = base64.b64decode(result.stdout.strip())
|
||||
|
||||
return self._parse_atheros_csi_bytes(csi_bytes)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Atheros CSI collection failed: {e}")
|
||||
return None
|
||||
|
||||
async def _collect_openwrt_csi(self) -> Optional[np.ndarray]:
|
||||
"""Collect CSI from OpenWRT with CSI support."""
|
||||
try:
|
||||
# Try reading from debugfs (common CSI location)
|
||||
result = await self.connection.run(
|
||||
f'cat /sys/kernel/debug/ieee80211/phy0/ath9k/csi 2>/dev/null | head -c 4096 | base64',
|
||||
check=False,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
import base64
|
||||
csi_bytes = base64.b64decode(result.stdout.strip())
|
||||
return self._parse_atheros_csi_bytes(csi_bytes)
|
||||
|
||||
# Try alternate location
|
||||
result = await self.connection.run(
|
||||
f'cat /proc/csi 2>/dev/null | head -c 4096 | base64',
|
||||
check=False,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
import base64
|
||||
csi_bytes = base64.b64decode(result.stdout.strip())
|
||||
return self._parse_generic_csi_bytes(csi_bytes)
|
||||
|
||||
self.logger.warning("No CSI data available from OpenWRT paths")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"OpenWRT CSI collection failed: {e}")
|
||||
return None
|
||||
|
||||
async def _collect_generic_csi(self) -> Optional[np.ndarray]:
|
||||
"""Collect CSI using generic Linux methods."""
|
||||
try:
|
||||
# Try iw command for station info (not real CSI but channel info)
|
||||
result = await self.connection.run(
|
||||
f'iw dev {self.interface} survey dump 2>/dev/null || echo ""',
|
||||
check=False,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
# Parse survey data for channel metrics
|
||||
return self._parse_survey_data(result.stdout)
|
||||
|
||||
self.logger.warning("No CSI data available via generic methods")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Generic CSI collection failed: {e}")
|
||||
return None
|
||||
|
||||
def _parse_atheros_csi_bytes(self, data: bytes) -> Optional[np.ndarray]:
|
||||
"""Parse Atheros CSI Tool binary format.
|
||||
|
||||
Format:
|
||||
- 4 bytes: magic (0x11111111)
|
||||
- 8 bytes: timestamp
|
||||
- 2 bytes: channel
|
||||
- 1 byte: bandwidth
|
||||
- 1 byte: num_rx_antennas
|
||||
- 1 byte: num_tx_antennas
|
||||
- 1 byte: num_tones
|
||||
- 2 bytes: RSSI
|
||||
- Remaining: CSI matrix as int16 I/Q pairs
|
||||
"""
|
||||
if len(data) < 20:
|
||||
return None
|
||||
|
||||
try:
|
||||
magic = struct.unpack('<I', data[0:4])[0]
|
||||
if magic != 0x11111111:
|
||||
# Try different offset or format
|
||||
return self._parse_generic_csi_bytes(data)
|
||||
|
||||
# Parse header
|
||||
timestamp = struct.unpack('<Q', data[4:12])[0]
|
||||
channel = struct.unpack('<H', data[12:14])[0]
|
||||
bw = struct.unpack('<B', data[14:15])[0]
|
||||
nr = struct.unpack('<B', data[15:16])[0]
|
||||
nc = struct.unpack('<B', data[16:17])[0]
|
||||
num_tones = struct.unpack('<B', data[17:18])[0]
|
||||
|
||||
if nr == 0 or num_tones == 0:
|
||||
return None
|
||||
|
||||
# Parse CSI matrix
|
||||
csi_data = data[20:]
|
||||
csi_matrix = np.zeros((nr, num_tones), dtype=complex)
|
||||
|
||||
for ant in range(nr):
|
||||
for tone in range(num_tones):
|
||||
offset = (ant * num_tones + tone) * 4
|
||||
if offset + 4 <= len(csi_data):
|
||||
real, imag = struct.unpack('<hh', csi_data[offset:offset+4])
|
||||
csi_matrix[ant, tone] = complex(real, imag)
|
||||
|
||||
return csi_matrix
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing Atheros CSI: {e}")
|
||||
return None
|
||||
|
||||
def _parse_generic_csi_bytes(self, data: bytes) -> Optional[np.ndarray]:
|
||||
"""Parse generic binary CSI format."""
|
||||
if len(data) < 8:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Assume simple format: int16 I/Q pairs
|
||||
num_samples = len(data) // 4
|
||||
if num_samples == 0:
|
||||
return None
|
||||
|
||||
# Default to 56 subcarriers (20MHz), adjust antennas
|
||||
num_tones = min(56, num_samples)
|
||||
num_antennas = max(1, num_samples // num_tones)
|
||||
|
||||
csi_matrix = np.zeros((num_antennas, num_tones), dtype=complex)
|
||||
|
||||
for i in range(min(num_samples, num_antennas * num_tones)):
|
||||
offset = i * 4
|
||||
if offset + 4 <= len(data):
|
||||
real, imag = struct.unpack('<hh', data[offset:offset+4])
|
||||
ant = i // num_tones
|
||||
tone = i % num_tones
|
||||
if ant < num_antennas and tone < num_tones:
|
||||
csi_matrix[ant, tone] = complex(real, imag)
|
||||
|
||||
return csi_matrix
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing generic CSI: {e}")
|
||||
return None
|
||||
|
||||
def _parse_survey_data(self, survey_output: str) -> Optional[np.ndarray]:
|
||||
"""Parse iw survey dump output to extract channel metrics.
|
||||
|
||||
This isn't true CSI but provides per-channel noise and activity data
|
||||
that can be used as a fallback.
|
||||
"""
|
||||
try:
|
||||
lines = survey_output.strip().split('\n')
|
||||
noise_values = []
|
||||
busy_values = []
|
||||
|
||||
for line in lines:
|
||||
if 'noise:' in line.lower():
|
||||
parts = line.split()
|
||||
for i, p in enumerate(parts):
|
||||
if p == 'dBm' and i > 0:
|
||||
try:
|
||||
noise_values.append(float(parts[i-1]))
|
||||
except ValueError:
|
||||
pass
|
||||
elif 'channel busy time:' in line.lower():
|
||||
parts = line.split()
|
||||
for i, p in enumerate(parts):
|
||||
if p == 'ms' and i > 0:
|
||||
try:
|
||||
busy_values.append(float(parts[i-1]))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if noise_values:
|
||||
# Create pseudo-CSI from noise measurements
|
||||
num_channels = len(noise_values)
|
||||
csi_matrix = np.zeros((1, max(56, num_channels)), dtype=complex)
|
||||
|
||||
for i, noise in enumerate(noise_values):
|
||||
# Convert noise dBm to amplitude (simplified)
|
||||
amplitude = 10 ** (noise / 20)
|
||||
phase = 0 if i >= len(busy_values) else busy_values[i] / 1000 * np.pi
|
||||
csi_matrix[0, i] = amplitude * np.exp(1j * phase)
|
||||
|
||||
return csi_matrix
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing survey data: {e}")
|
||||
return None
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
|
||||
566
v1/src/core/vital_signs.py
Normal file
566
v1/src/core/vital_signs.py
Normal file
@@ -0,0 +1,566 @@
|
||||
"""Vital signs detection from CSI signals.
|
||||
|
||||
This module provides breathing and heartbeat detection capabilities
|
||||
mirroring the Rust wifi-densepose-mat crate functionality.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime, timezone
|
||||
import scipy.signal
|
||||
import scipy.fft
|
||||
|
||||
|
||||
class BreathingType(Enum):
|
||||
"""Types of breathing patterns."""
|
||||
NORMAL = "normal"
|
||||
SHALLOW = "shallow"
|
||||
DEEP = "deep"
|
||||
RAPID = "rapid"
|
||||
IRREGULAR = "irregular"
|
||||
APNEA = "apnea"
|
||||
|
||||
|
||||
class SignalStrength(Enum):
|
||||
"""Signal strength classification."""
|
||||
STRONG = "strong"
|
||||
MODERATE = "moderate"
|
||||
WEAK = "weak"
|
||||
VERY_WEAK = "very_weak"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BreathingPattern:
|
||||
"""Detected breathing pattern."""
|
||||
rate_bpm: float
|
||||
amplitude: float
|
||||
regularity: float
|
||||
pattern_type: BreathingType
|
||||
confidence: float
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeartbeatSignature:
|
||||
"""Detected heartbeat signature."""
|
||||
rate_bpm: float
|
||||
signal_strength: SignalStrength
|
||||
hrv_estimate: Optional[float]
|
||||
confidence: float
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class VitalSignsReading:
|
||||
"""Combined vital signs reading."""
|
||||
breathing: Optional[BreathingPattern]
|
||||
heartbeat: Optional[HeartbeatSignature]
|
||||
motion_detected: bool
|
||||
overall_confidence: float
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class BreathingDetectorConfig:
|
||||
"""Configuration for breathing detection."""
|
||||
min_rate_bpm: float = 4.0 # Very slow breathing
|
||||
max_rate_bpm: float = 40.0 # Fast breathing (distressed)
|
||||
min_amplitude: float = 0.1
|
||||
window_size: int = 512
|
||||
window_overlap: float = 0.5
|
||||
confidence_threshold: float = 0.3
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeartbeatDetectorConfig:
|
||||
"""Configuration for heartbeat detection."""
|
||||
min_rate_bpm: float = 30.0 # Bradycardia
|
||||
max_rate_bpm: float = 200.0 # Extreme tachycardia
|
||||
min_signal_strength: float = 0.05
|
||||
window_size: int = 1024
|
||||
enhanced_processing: bool = True
|
||||
confidence_threshold: float = 0.4
|
||||
|
||||
|
||||
class BreathingDetector:
|
||||
"""Detector for breathing patterns in CSI signals.
|
||||
|
||||
Breathing causes periodic chest movement that modulates the WiFi signal.
|
||||
We detect this by looking for periodic variations in the 0.1-0.67 Hz range
|
||||
(corresponding to 6-40 breaths per minute).
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[BreathingDetectorConfig] = None):
|
||||
"""Initialize breathing detector.
|
||||
|
||||
Args:
|
||||
config: Detector configuration. Uses defaults if None.
|
||||
"""
|
||||
self.config = config or BreathingDetectorConfig()
|
||||
|
||||
def detect(self, csi_amplitudes: np.ndarray, sample_rate: float) -> Optional[BreathingPattern]:
|
||||
"""Detect breathing pattern from CSI amplitude variations.
|
||||
|
||||
Args:
|
||||
csi_amplitudes: Array of CSI amplitude values.
|
||||
sample_rate: Sampling rate in Hz.
|
||||
|
||||
Returns:
|
||||
Detected BreathingPattern or None if not detected.
|
||||
"""
|
||||
if len(csi_amplitudes) < self.config.window_size:
|
||||
return None
|
||||
|
||||
# Calculate the frequency spectrum
|
||||
spectrum = self._compute_spectrum(csi_amplitudes)
|
||||
|
||||
# Find the dominant frequency in the breathing range
|
||||
min_freq = self.config.min_rate_bpm / 60.0
|
||||
max_freq = self.config.max_rate_bpm / 60.0
|
||||
|
||||
result = self._find_dominant_frequency(
|
||||
spectrum, sample_rate, min_freq, max_freq
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
dominant_freq, amplitude = result
|
||||
|
||||
# Convert to BPM
|
||||
rate_bpm = dominant_freq * 60.0
|
||||
|
||||
# Check amplitude threshold
|
||||
if amplitude < self.config.min_amplitude:
|
||||
return None
|
||||
|
||||
# Calculate regularity
|
||||
regularity = self._calculate_regularity(spectrum, dominant_freq, sample_rate)
|
||||
|
||||
# Determine breathing type
|
||||
pattern_type = self._classify_pattern(rate_bpm, regularity)
|
||||
|
||||
# Calculate confidence
|
||||
confidence = self._calculate_confidence(amplitude, regularity)
|
||||
|
||||
if confidence < self.config.confidence_threshold:
|
||||
return None
|
||||
|
||||
return BreathingPattern(
|
||||
rate_bpm=rate_bpm,
|
||||
amplitude=amplitude,
|
||||
regularity=regularity,
|
||||
pattern_type=pattern_type,
|
||||
confidence=confidence,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
def _compute_spectrum(self, signal: np.ndarray) -> np.ndarray:
|
||||
"""Compute frequency spectrum using FFT."""
|
||||
# Apply window
|
||||
window = scipy.signal.windows.hamming(len(signal))
|
||||
windowed = signal * window
|
||||
|
||||
# Compute FFT
|
||||
spectrum = np.abs(scipy.fft.rfft(windowed))
|
||||
return spectrum
|
||||
|
||||
def _find_dominant_frequency(
|
||||
self,
|
||||
spectrum: np.ndarray,
|
||||
sample_rate: float,
|
||||
min_freq: float,
|
||||
max_freq: float
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""Find the dominant frequency in a given range."""
|
||||
# rfft output length is n//2 + 1 for input of length n
|
||||
# So original length n = (len(spectrum) - 1) * 2
|
||||
n = (len(spectrum) - 1) * 2
|
||||
freqs = scipy.fft.rfftfreq(n, 1.0 / sample_rate)
|
||||
|
||||
# Ensure freqs and spectrum have same length
|
||||
min_len = min(len(freqs), len(spectrum))
|
||||
freqs = freqs[:min_len]
|
||||
spectrum_trimmed = spectrum[:min_len]
|
||||
|
||||
# Find indices in the frequency range
|
||||
mask = (freqs >= min_freq) & (freqs <= max_freq)
|
||||
if not np.any(mask):
|
||||
return None
|
||||
|
||||
masked_spectrum = spectrum_trimmed.copy()
|
||||
masked_spectrum[~mask] = 0
|
||||
|
||||
# Find peak
|
||||
peak_idx = np.argmax(masked_spectrum)
|
||||
if masked_spectrum[peak_idx] == 0:
|
||||
return None
|
||||
|
||||
return freqs[peak_idx], spectrum_trimmed[peak_idx]
|
||||
|
||||
def _calculate_regularity(
|
||||
self,
|
||||
spectrum: np.ndarray,
|
||||
dominant_freq: float,
|
||||
sample_rate: float
|
||||
) -> float:
|
||||
"""Calculate how regular the breathing pattern is."""
|
||||
n = (len(spectrum) - 1) * 2
|
||||
freqs = scipy.fft.rfftfreq(n, 1.0 / sample_rate)
|
||||
|
||||
# Look at energy concentration around dominant frequency
|
||||
freq_resolution = freqs[1] - freqs[0] if len(freqs) > 1 else 1.0
|
||||
peak_idx = int(dominant_freq / freq_resolution) if freq_resolution > 0 else 0
|
||||
|
||||
# Calculate energy in narrow band around peak
|
||||
half_bandwidth = 3 # bins on each side
|
||||
start_idx = max(0, peak_idx - half_bandwidth)
|
||||
end_idx = min(len(spectrum), peak_idx + half_bandwidth + 1)
|
||||
|
||||
peak_energy = np.sum(spectrum[start_idx:end_idx] ** 2)
|
||||
total_energy = np.sum(spectrum ** 2) + 1e-10
|
||||
|
||||
regularity = float(peak_energy / total_energy)
|
||||
return min(1.0, regularity * 2.0) # Scale to 0-1
|
||||
|
||||
def _classify_pattern(self, rate_bpm: float, regularity: float) -> BreathingType:
|
||||
"""Classify breathing pattern based on rate and regularity."""
|
||||
if regularity < 0.3:
|
||||
return BreathingType.IRREGULAR
|
||||
|
||||
if rate_bpm < 6:
|
||||
return BreathingType.APNEA
|
||||
elif rate_bpm < 12:
|
||||
return BreathingType.SHALLOW
|
||||
elif rate_bpm <= 20:
|
||||
return BreathingType.NORMAL
|
||||
elif rate_bpm <= 25:
|
||||
return BreathingType.DEEP
|
||||
else:
|
||||
return BreathingType.RAPID
|
||||
|
||||
def _calculate_confidence(self, amplitude: float, regularity: float) -> float:
|
||||
"""Calculate detection confidence."""
|
||||
# Combine amplitude and regularity factors
|
||||
amp_factor = min(1.0, amplitude / 0.5)
|
||||
confidence = 0.6 * amp_factor + 0.4 * regularity
|
||||
return float(np.clip(confidence, 0.0, 1.0))
|
||||
|
||||
|
||||
class HeartbeatDetector:
|
||||
"""Detector for heartbeat signatures using micro-Doppler analysis.
|
||||
|
||||
Heartbeats cause very small chest wall movements (~0.5mm) that can be
|
||||
detected through careful analysis of CSI phase variations at higher
|
||||
frequencies than breathing (0.8-3.3 Hz for 48-200 BPM).
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[HeartbeatDetectorConfig] = None):
|
||||
"""Initialize heartbeat detector.
|
||||
|
||||
Args:
|
||||
config: Detector configuration. Uses defaults if None.
|
||||
"""
|
||||
self.config = config or HeartbeatDetectorConfig()
|
||||
|
||||
def detect(
|
||||
self,
|
||||
csi_phase: np.ndarray,
|
||||
sample_rate: float,
|
||||
breathing_rate: Optional[float] = None
|
||||
) -> Optional[HeartbeatSignature]:
|
||||
"""Detect heartbeat from CSI phase data.
|
||||
|
||||
Args:
|
||||
csi_phase: Array of CSI phase values in radians.
|
||||
sample_rate: Sampling rate in Hz.
|
||||
breathing_rate: Known breathing rate in Hz (optional).
|
||||
|
||||
Returns:
|
||||
Detected HeartbeatSignature or None if not detected.
|
||||
"""
|
||||
if len(csi_phase) < self.config.window_size:
|
||||
return None
|
||||
|
||||
# Remove breathing component if known
|
||||
if breathing_rate is not None:
|
||||
filtered = self._remove_breathing_component(csi_phase, sample_rate, breathing_rate)
|
||||
else:
|
||||
filtered = self._highpass_filter(csi_phase, sample_rate, 0.8)
|
||||
|
||||
# Compute micro-Doppler spectrum
|
||||
spectrum = self._compute_micro_doppler_spectrum(filtered, sample_rate)
|
||||
|
||||
# Find heartbeat frequency
|
||||
min_freq = self.config.min_rate_bpm / 60.0
|
||||
max_freq = self.config.max_rate_bpm / 60.0
|
||||
|
||||
result = self._find_heartbeat_frequency(
|
||||
spectrum, sample_rate, min_freq, max_freq
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
heart_freq, strength = result
|
||||
|
||||
if strength < self.config.min_signal_strength:
|
||||
return None
|
||||
|
||||
rate_bpm = heart_freq * 60.0
|
||||
|
||||
# Classify signal strength
|
||||
signal_strength = self._classify_signal_strength(strength)
|
||||
|
||||
# Estimate HRV if we have enough data
|
||||
hrv_estimate = self._estimate_hrv(csi_phase, sample_rate, heart_freq)
|
||||
|
||||
# Calculate confidence
|
||||
confidence = self._calculate_confidence(strength, signal_strength)
|
||||
|
||||
if confidence < self.config.confidence_threshold:
|
||||
return None
|
||||
|
||||
return HeartbeatSignature(
|
||||
rate_bpm=rate_bpm,
|
||||
signal_strength=signal_strength,
|
||||
hrv_estimate=hrv_estimate,
|
||||
confidence=confidence,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
def _remove_breathing_component(
|
||||
self,
|
||||
phase: np.ndarray,
|
||||
sample_rate: float,
|
||||
breathing_rate: float
|
||||
) -> np.ndarray:
|
||||
"""Remove breathing frequency component from phase signal."""
|
||||
# Design notch filter at breathing frequency
|
||||
quality_factor = 30.0
|
||||
b, a = scipy.signal.iirnotch(breathing_rate, quality_factor, sample_rate)
|
||||
|
||||
# Also remove harmonics (2x, 3x)
|
||||
filtered = scipy.signal.filtfilt(b, a, phase)
|
||||
|
||||
for harmonic in [2, 3]:
|
||||
notch_freq = breathing_rate * harmonic
|
||||
if notch_freq < sample_rate / 2:
|
||||
b, a = scipy.signal.iirnotch(notch_freq, quality_factor, sample_rate)
|
||||
filtered = scipy.signal.filtfilt(b, a, filtered)
|
||||
|
||||
return filtered
|
||||
|
||||
def _highpass_filter(
|
||||
self,
|
||||
signal: np.ndarray,
|
||||
sample_rate: float,
|
||||
cutoff: float
|
||||
) -> np.ndarray:
|
||||
"""Apply highpass filter to remove low-frequency components."""
|
||||
nyquist = sample_rate / 2
|
||||
if cutoff >= nyquist:
|
||||
return signal
|
||||
|
||||
b, a = scipy.signal.butter(4, cutoff / nyquist, btype='high')
|
||||
return scipy.signal.filtfilt(b, a, signal)
|
||||
|
||||
def _compute_micro_doppler_spectrum(
|
||||
self,
|
||||
signal: np.ndarray,
|
||||
sample_rate: float
|
||||
) -> np.ndarray:
|
||||
"""Compute micro-Doppler spectrum for heartbeat detection."""
|
||||
# Use shorter window for better time resolution
|
||||
window_size = min(len(signal), self.config.window_size)
|
||||
|
||||
if self.config.enhanced_processing:
|
||||
# Use STFT for better frequency resolution
|
||||
f, t, Zxx = scipy.signal.stft(
|
||||
signal,
|
||||
sample_rate,
|
||||
nperseg=window_size,
|
||||
noverlap=window_size // 2
|
||||
)
|
||||
# Average over time
|
||||
spectrum = np.mean(np.abs(Zxx), axis=1)
|
||||
else:
|
||||
# Simple FFT
|
||||
window = scipy.signal.windows.hamming(window_size)
|
||||
windowed = signal[:window_size] * window
|
||||
spectrum = np.abs(scipy.fft.rfft(windowed))
|
||||
|
||||
return spectrum
|
||||
|
||||
def _find_heartbeat_frequency(
|
||||
self,
|
||||
spectrum: np.ndarray,
|
||||
sample_rate: float,
|
||||
min_freq: float,
|
||||
max_freq: float
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""Find heartbeat frequency in the spectrum."""
|
||||
# rfft output length is n//2 + 1 for input of length n
|
||||
# So original length n = (len(spectrum) - 1) * 2
|
||||
n = (len(spectrum) - 1) * 2
|
||||
freqs = scipy.fft.rfftfreq(n, 1.0 / sample_rate)
|
||||
|
||||
# Ensure freqs and spectrum have same length
|
||||
min_len = min(len(freqs), len(spectrum))
|
||||
freqs = freqs[:min_len]
|
||||
spectrum_trimmed = spectrum[:min_len]
|
||||
|
||||
# Find indices in the frequency range
|
||||
mask = (freqs >= min_freq) & (freqs <= max_freq)
|
||||
if not np.any(mask):
|
||||
return None
|
||||
|
||||
masked_spectrum = spectrum_trimmed.copy()
|
||||
masked_spectrum[~mask] = 0
|
||||
|
||||
# Find peak
|
||||
peak_idx = np.argmax(masked_spectrum)
|
||||
if masked_spectrum[peak_idx] == 0:
|
||||
return None
|
||||
|
||||
return freqs[peak_idx], spectrum_trimmed[peak_idx]
|
||||
|
||||
def _classify_signal_strength(self, strength: float) -> SignalStrength:
|
||||
"""Classify signal strength level."""
|
||||
if strength > 0.3:
|
||||
return SignalStrength.STRONG
|
||||
elif strength > 0.15:
|
||||
return SignalStrength.MODERATE
|
||||
elif strength > 0.08:
|
||||
return SignalStrength.WEAK
|
||||
else:
|
||||
return SignalStrength.VERY_WEAK
|
||||
|
||||
def _estimate_hrv(
|
||||
self,
|
||||
phase: np.ndarray,
|
||||
sample_rate: float,
|
||||
heart_freq: float
|
||||
) -> Optional[float]:
|
||||
"""Estimate heart rate variability."""
|
||||
# Simple HRV estimation based on spectral width
|
||||
# In practice, would use peak detection and RR interval analysis
|
||||
n = len(phase)
|
||||
if n < self.config.window_size * 2:
|
||||
return None
|
||||
|
||||
# Placeholder - would require more sophisticated analysis
|
||||
return None
|
||||
|
||||
def _calculate_confidence(
|
||||
self,
|
||||
strength: float,
|
||||
signal_class: SignalStrength
|
||||
) -> float:
|
||||
"""Calculate detection confidence."""
|
||||
strength_factor = min(1.0, strength / 0.2)
|
||||
|
||||
class_weights = {
|
||||
SignalStrength.STRONG: 1.0,
|
||||
SignalStrength.MODERATE: 0.7,
|
||||
SignalStrength.WEAK: 0.4,
|
||||
SignalStrength.VERY_WEAK: 0.2,
|
||||
}
|
||||
class_factor = class_weights[signal_class]
|
||||
|
||||
confidence = 0.5 * strength_factor + 0.5 * class_factor
|
||||
return float(np.clip(confidence, 0.0, 1.0))
|
||||
|
||||
|
||||
class VitalSignsDetector:
|
||||
"""Combined vital signs detector for breathing and heartbeat."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
breathing_config: Optional[BreathingDetectorConfig] = None,
|
||||
heartbeat_config: Optional[HeartbeatDetectorConfig] = None
|
||||
):
|
||||
"""Initialize combined detector.
|
||||
|
||||
Args:
|
||||
breathing_config: Breathing detector configuration.
|
||||
heartbeat_config: Heartbeat detector configuration.
|
||||
"""
|
||||
self.breathing_detector = BreathingDetector(breathing_config)
|
||||
self.heartbeat_detector = HeartbeatDetector(heartbeat_config)
|
||||
self._motion_threshold = 0.5
|
||||
|
||||
def detect(
|
||||
self,
|
||||
csi_amplitude: np.ndarray,
|
||||
csi_phase: np.ndarray,
|
||||
sample_rate: float
|
||||
) -> VitalSignsReading:
|
||||
"""Detect vital signs from CSI data.
|
||||
|
||||
Args:
|
||||
csi_amplitude: CSI amplitude values.
|
||||
csi_phase: CSI phase values in radians.
|
||||
sample_rate: Sampling rate in Hz.
|
||||
|
||||
Returns:
|
||||
Combined VitalSignsReading.
|
||||
"""
|
||||
# Detect breathing
|
||||
breathing = self.breathing_detector.detect(csi_amplitude, sample_rate)
|
||||
|
||||
# Detect heartbeat (using breathing rate if available)
|
||||
breathing_rate = (breathing.rate_bpm / 60.0) if breathing else None
|
||||
heartbeat = self.heartbeat_detector.detect(csi_phase, sample_rate, breathing_rate)
|
||||
|
||||
# Detect motion
|
||||
motion_detected = self._detect_motion(csi_amplitude)
|
||||
|
||||
# Calculate overall confidence
|
||||
overall_confidence = self._calculate_overall_confidence(
|
||||
breathing, heartbeat, motion_detected
|
||||
)
|
||||
|
||||
return VitalSignsReading(
|
||||
breathing=breathing,
|
||||
heartbeat=heartbeat,
|
||||
motion_detected=motion_detected,
|
||||
overall_confidence=overall_confidence,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
def _detect_motion(self, amplitude: np.ndarray) -> bool:
|
||||
"""Detect significant motion from amplitude variance."""
|
||||
if len(amplitude) < 10:
|
||||
return False
|
||||
variance = np.var(amplitude)
|
||||
return variance > self._motion_threshold
|
||||
|
||||
def _calculate_overall_confidence(
|
||||
self,
|
||||
breathing: Optional[BreathingPattern],
|
||||
heartbeat: Optional[HeartbeatSignature],
|
||||
motion_detected: bool
|
||||
) -> float:
|
||||
"""Calculate overall detection confidence."""
|
||||
confidences = []
|
||||
|
||||
if breathing:
|
||||
confidences.append(breathing.confidence)
|
||||
if heartbeat:
|
||||
confidences.append(heartbeat.confidence)
|
||||
|
||||
if not confidences:
|
||||
return 0.0
|
||||
|
||||
base_confidence = np.mean(confidences)
|
||||
|
||||
# Motion can either help (confirms presence) or hurt (noise)
|
||||
if motion_detected:
|
||||
# Strong motion reduces confidence in subtle vital sign detection
|
||||
if base_confidence > 0.7:
|
||||
base_confidence *= 0.9
|
||||
|
||||
return float(np.clip(base_confidence, 0.0, 1.0))
|
||||
@@ -1,9 +1,10 @@
|
||||
"""CSI data extraction from WiFi hardware using Test-Driven Development approach."""
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, Callable, Protocol
|
||||
from typing import Dict, Any, Optional, Callable, Protocol, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
@@ -42,13 +43,28 @@ class CSIParser(Protocol):
|
||||
|
||||
|
||||
class ESP32CSIParser:
|
||||
"""Parser for ESP32 CSI data format."""
|
||||
"""Parser for ESP32 CSI data format.
|
||||
|
||||
ESP32 CSI data format (from esp-csi library):
|
||||
- Header: 'CSI_DATA:' prefix
|
||||
- Fields: timestamp,rssi,rate,sig_mode,mcs,bandwidth,smoothing,
|
||||
not_sounding,aggregation,stbc,fec_coding,sgi,noise_floor,
|
||||
ampdu_cnt,channel,secondary_channel,local_timestamp,
|
||||
ant,sig_len,rx_state,len,first_word,data[...]
|
||||
|
||||
The actual CSI data is in the 'data' field as complex I/Q values.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize ESP32 CSI parser with default configuration."""
|
||||
self.htltf_subcarriers = 56 # HT-LTF subcarriers for 20MHz
|
||||
self.antenna_count = 1 # Most ESP32 have 1 antenna
|
||||
|
||||
def parse(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse ESP32 CSI data format.
|
||||
|
||||
Args:
|
||||
raw_data: Raw bytes from ESP32
|
||||
raw_data: Raw bytes from ESP32 serial/network
|
||||
|
||||
Returns:
|
||||
Parsed CSI data
|
||||
@@ -60,12 +76,103 @@ class ESP32CSIParser:
|
||||
raise CSIParseError("Empty data received")
|
||||
|
||||
try:
|
||||
data_str = raw_data.decode('utf-8')
|
||||
if not data_str.startswith('CSI_DATA:'):
|
||||
data_str = raw_data.decode('utf-8').strip()
|
||||
|
||||
# Handle ESP-CSI library format
|
||||
if data_str.startswith('CSI_DATA,'):
|
||||
return self._parse_esp_csi_format(data_str)
|
||||
# Handle simplified format for testing
|
||||
elif data_str.startswith('CSI_DATA:'):
|
||||
return self._parse_simple_format(data_str)
|
||||
else:
|
||||
raise CSIParseError("Invalid ESP32 CSI data format")
|
||||
|
||||
# Parse ESP32 format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp],[phase]
|
||||
parts = data_str[9:].split(',') # Remove 'CSI_DATA:' prefix
|
||||
except UnicodeDecodeError:
|
||||
# Binary format - parse as raw bytes
|
||||
return self._parse_binary_format(raw_data)
|
||||
except (ValueError, IndexError) as e:
|
||||
raise CSIParseError(f"Failed to parse ESP32 data: {e}")
|
||||
|
||||
def _parse_esp_csi_format(self, data_str: str) -> CSIData:
|
||||
"""Parse ESP-CSI library CSV format.
|
||||
|
||||
Format: CSI_DATA,<mac>,<rssi>,<rate>,<sig_mode>,<mcs>,<bw>,<smoothing>,
|
||||
<not_sounding>,<aggregation>,<stbc>,<fec>,<sgi>,<noise>,
|
||||
<ampdu_cnt>,<channel>,<sec_chan>,<timestamp>,<ant>,<sig_len>,
|
||||
<rx_state>,<len>,[csi_data...]
|
||||
"""
|
||||
parts = data_str.split(',')
|
||||
|
||||
if len(parts) < 22:
|
||||
raise CSIParseError(f"Incomplete ESP-CSI data: expected >= 22 fields, got {len(parts)}")
|
||||
|
||||
# Extract metadata
|
||||
mac_addr = parts[1]
|
||||
rssi = int(parts[2])
|
||||
rate = int(parts[3])
|
||||
sig_mode = int(parts[4])
|
||||
mcs = int(parts[5])
|
||||
bandwidth = int(parts[6]) # 0=20MHz, 1=40MHz
|
||||
channel = int(parts[15])
|
||||
timestamp_us = int(parts[17])
|
||||
csi_len = int(parts[21])
|
||||
|
||||
# Parse CSI I/Q data (remaining fields are the CSI values)
|
||||
csi_raw = [int(x) for x in parts[22:22 + csi_len]]
|
||||
|
||||
# Convert I/Q pairs to complex numbers
|
||||
# ESP32 CSI format: [I0, Q0, I1, Q1, ...] as signed 8-bit integers
|
||||
amplitude, phase = self._iq_to_amplitude_phase(csi_raw)
|
||||
|
||||
# Determine frequency from channel
|
||||
if channel <= 14:
|
||||
frequency = 2.412e9 + (channel - 1) * 5e6 # 2.4 GHz band
|
||||
else:
|
||||
frequency = 5.0e9 + (channel - 36) * 5e6 # 5 GHz band
|
||||
|
||||
bw_hz = 20e6 if bandwidth == 0 else 40e6
|
||||
num_subcarriers = len(amplitude) // self.antenna_count
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp_us / 1e6, tz=timezone.utc),
|
||||
amplitude=amplitude.reshape(self.antenna_count, -1),
|
||||
phase=phase.reshape(self.antenna_count, -1),
|
||||
frequency=frequency,
|
||||
bandwidth=bw_hz,
|
||||
num_subcarriers=num_subcarriers,
|
||||
num_antennas=self.antenna_count,
|
||||
snr=float(rssi + 100), # Approximate SNR from RSSI
|
||||
metadata={
|
||||
'source': 'esp32',
|
||||
'mac': mac_addr,
|
||||
'rssi': rssi,
|
||||
'mcs': mcs,
|
||||
'channel': channel,
|
||||
'sig_mode': sig_mode,
|
||||
}
|
||||
)
|
||||
|
||||
def _parse_simple_format(self, data_str: str) -> CSIData:
|
||||
"""Parse simplified CSI format for testing/development.
|
||||
|
||||
Format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp_values],[phase_values]
|
||||
"""
|
||||
content = data_str[9:] # Remove 'CSI_DATA:' prefix
|
||||
|
||||
# Split the main fields and array data
|
||||
if '[' in content:
|
||||
main_part, arrays_part = content.split('[', 1)
|
||||
parts = main_part.rstrip(',').split(',')
|
||||
|
||||
# Parse amplitude and phase arrays
|
||||
arrays_str = '[' + arrays_part
|
||||
amp_str, phase_str = self._split_arrays(arrays_str)
|
||||
amplitude = np.array([float(x) for x in amp_str.strip('[]').split(',')])
|
||||
phase = np.array([float(x) for x in phase_str.strip('[]').split(',')])
|
||||
else:
|
||||
parts = content.split(',')
|
||||
# No array data provided, need to return error or minimal data
|
||||
raise CSIParseError("No CSI array data in simple format")
|
||||
|
||||
timestamp_ms = int(parts[0])
|
||||
num_antennas = int(parts[1])
|
||||
@@ -74,33 +181,141 @@ class ESP32CSIParser:
|
||||
bandwidth_mhz = float(parts[4])
|
||||
snr = float(parts[5])
|
||||
|
||||
# Convert to proper units
|
||||
frequency = frequency_mhz * 1e6 # MHz to Hz
|
||||
bandwidth = bandwidth_mhz * 1e6 # MHz to Hz
|
||||
|
||||
# Parse amplitude and phase arrays (simplified for now)
|
||||
# In real implementation, this would parse actual CSI matrix data
|
||||
amplitude = np.random.rand(num_antennas, num_subcarriers)
|
||||
phase = np.random.rand(num_antennas, num_subcarriers)
|
||||
# Reshape arrays
|
||||
expected_size = num_antennas * num_subcarriers
|
||||
if len(amplitude) != expected_size:
|
||||
# Interpolate or pad
|
||||
amplitude = np.interp(
|
||||
np.linspace(0, 1, expected_size),
|
||||
np.linspace(0, 1, len(amplitude)),
|
||||
amplitude
|
||||
)
|
||||
phase = np.interp(
|
||||
np.linspace(0, 1, expected_size),
|
||||
np.linspace(0, 1, len(phase)),
|
||||
phase
|
||||
)
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=frequency,
|
||||
bandwidth=bandwidth,
|
||||
amplitude=amplitude.reshape(num_antennas, num_subcarriers),
|
||||
phase=phase.reshape(num_antennas, num_subcarriers),
|
||||
frequency=frequency_mhz * 1e6,
|
||||
bandwidth=bandwidth_mhz * 1e6,
|
||||
num_subcarriers=num_subcarriers,
|
||||
num_antennas=num_antennas,
|
||||
snr=snr,
|
||||
metadata={'source': 'esp32', 'raw_length': len(raw_data)}
|
||||
metadata={'source': 'esp32', 'format': 'simple'}
|
||||
)
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
raise CSIParseError(f"Failed to parse ESP32 data: {e}")
|
||||
def _parse_binary_format(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse binary CSI format from ESP32.
|
||||
|
||||
Binary format (struct packed):
|
||||
- 4 bytes: timestamp (uint32)
|
||||
- 1 byte: num_antennas (uint8)
|
||||
- 1 byte: num_subcarriers (uint8)
|
||||
- 2 bytes: channel (uint16)
|
||||
- 4 bytes: frequency (float32)
|
||||
- 4 bytes: bandwidth (float32)
|
||||
- 4 bytes: snr (float32)
|
||||
- Remaining: CSI I/Q data as int8 pairs
|
||||
"""
|
||||
if len(raw_data) < 20:
|
||||
raise CSIParseError("Binary data too short")
|
||||
|
||||
header_fmt = '<IBBHfff'
|
||||
header_size = struct.calcsize(header_fmt)
|
||||
|
||||
timestamp, num_antennas, num_subcarriers, channel, freq, bw, snr = \
|
||||
struct.unpack(header_fmt, raw_data[:header_size])
|
||||
|
||||
# Parse I/Q data
|
||||
iq_data = raw_data[header_size:]
|
||||
csi_raw = list(struct.unpack(f'{len(iq_data)}b', iq_data))
|
||||
|
||||
amplitude, phase = self._iq_to_amplitude_phase(csi_raw)
|
||||
|
||||
# Adjust dimensions
|
||||
expected_size = num_antennas * num_subcarriers
|
||||
if len(amplitude) < expected_size:
|
||||
amplitude = np.pad(amplitude, (0, expected_size - len(amplitude)))
|
||||
phase = np.pad(phase, (0, expected_size - len(phase)))
|
||||
elif len(amplitude) > expected_size:
|
||||
amplitude = amplitude[:expected_size]
|
||||
phase = phase[:expected_size]
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc),
|
||||
amplitude=amplitude.reshape(num_antennas, num_subcarriers),
|
||||
phase=phase.reshape(num_antennas, num_subcarriers),
|
||||
frequency=float(freq),
|
||||
bandwidth=float(bw),
|
||||
num_subcarriers=num_subcarriers,
|
||||
num_antennas=num_antennas,
|
||||
snr=float(snr),
|
||||
metadata={'source': 'esp32', 'format': 'binary', 'channel': channel}
|
||||
)
|
||||
|
||||
def _iq_to_amplitude_phase(self, iq_data: List[int]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Convert I/Q pairs to amplitude and phase.
|
||||
|
||||
Args:
|
||||
iq_data: List of interleaved I, Q values (signed 8-bit)
|
||||
|
||||
Returns:
|
||||
Tuple of (amplitude, phase) arrays
|
||||
"""
|
||||
if len(iq_data) % 2 != 0:
|
||||
iq_data = iq_data[:-1] # Trim odd value
|
||||
|
||||
i_vals = np.array(iq_data[0::2], dtype=np.float64)
|
||||
q_vals = np.array(iq_data[1::2], dtype=np.float64)
|
||||
|
||||
# Calculate amplitude (magnitude) and phase
|
||||
complex_vals = i_vals + 1j * q_vals
|
||||
amplitude = np.abs(complex_vals)
|
||||
phase = np.angle(complex_vals)
|
||||
|
||||
# Normalize amplitude to [0, 1] range
|
||||
max_amp = np.max(amplitude)
|
||||
if max_amp > 0:
|
||||
amplitude = amplitude / max_amp
|
||||
|
||||
return amplitude, phase
|
||||
|
||||
def _split_arrays(self, arrays_str: str) -> Tuple[str, str]:
|
||||
"""Split concatenated array strings."""
|
||||
# Find the boundary between two arrays
|
||||
depth = 0
|
||||
split_idx = 0
|
||||
for i, c in enumerate(arrays_str):
|
||||
if c == '[':
|
||||
depth += 1
|
||||
elif c == ']':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
split_idx = i + 1
|
||||
break
|
||||
|
||||
amp_str = arrays_str[:split_idx]
|
||||
phase_str = arrays_str[split_idx:].lstrip(',')
|
||||
return amp_str, phase_str
|
||||
|
||||
|
||||
class RouterCSIParser:
|
||||
"""Parser for router CSI data format."""
|
||||
"""Parser for router CSI data formats (Atheros, Intel, etc.).
|
||||
|
||||
Supports:
|
||||
- Atheros CSI Tool format (ath9k/ath10k)
|
||||
- Intel 5300 CSI Tool format
|
||||
- Nexmon CSI format (Broadcom)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize router CSI parser."""
|
||||
self.default_subcarriers = 56 # 20MHz HT
|
||||
self.default_antennas = 3
|
||||
|
||||
def parse(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse router CSI data format.
|
||||
@@ -117,30 +332,289 @@ class RouterCSIParser:
|
||||
if not raw_data:
|
||||
raise CSIParseError("Empty data received")
|
||||
|
||||
# Handle different router formats
|
||||
# Try to decode as text first
|
||||
try:
|
||||
data_str = raw_data.decode('utf-8')
|
||||
|
||||
if data_str.startswith('ATHEROS_CSI:'):
|
||||
return self._parse_atheros_format(raw_data)
|
||||
else:
|
||||
return self._parse_atheros_text_format(data_str)
|
||||
elif data_str.startswith('INTEL_CSI:'):
|
||||
return self._parse_intel_text_format(data_str)
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
|
||||
# Binary format detection based on header
|
||||
if len(raw_data) >= 4:
|
||||
magic = struct.unpack('<I', raw_data[:4])[0]
|
||||
if magic == 0x11111111: # Atheros CSI Tool magic
|
||||
return self._parse_atheros_binary_format(raw_data)
|
||||
elif magic == 0xBB: # Intel 5300 magic byte pattern
|
||||
return self._parse_intel_binary_format(raw_data)
|
||||
|
||||
raise CSIParseError("Unknown router CSI format")
|
||||
|
||||
def _parse_atheros_format(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse Atheros CSI format (placeholder implementation)."""
|
||||
# This would implement actual Atheros CSI parsing
|
||||
# For now, return mock data for testing
|
||||
def _parse_atheros_text_format(self, data_str: str) -> CSIData:
|
||||
"""Parse Atheros CSI text format.
|
||||
|
||||
Format: ATHEROS_CSI:timestamp,rssi,rate,channel,bw,nr,nc,num_tones,[csi_data...]
|
||||
"""
|
||||
content = data_str[12:] # Remove 'ATHEROS_CSI:' prefix
|
||||
parts = content.split(',')
|
||||
|
||||
if len(parts) < 8:
|
||||
raise CSIParseError("Incomplete Atheros CSI data")
|
||||
|
||||
timestamp = int(parts[0])
|
||||
rssi = int(parts[1])
|
||||
rate = int(parts[2])
|
||||
channel = int(parts[3])
|
||||
bandwidth = int(parts[4]) # MHz
|
||||
nr = int(parts[5]) # Rx antennas
|
||||
nc = int(parts[6]) # Tx antennas (usually 1 for probe)
|
||||
num_tones = int(parts[7]) # Subcarriers
|
||||
|
||||
# Parse CSI matrix data
|
||||
csi_values = [float(x) for x in parts[8:] if x.strip()]
|
||||
|
||||
# CSI data is complex: [real, imag, real, imag, ...]
|
||||
amplitude, phase = self._parse_complex_csi(csi_values, nr, num_tones)
|
||||
|
||||
# Calculate frequency from channel
|
||||
if channel <= 14:
|
||||
frequency = 2.412e9 + (channel - 1) * 5e6
|
||||
else:
|
||||
frequency = 5.18e9 + (channel - 36) * 5e6
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=12.0,
|
||||
metadata={'source': 'atheros_router'}
|
||||
timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=frequency,
|
||||
bandwidth=bandwidth * 1e6,
|
||||
num_subcarriers=num_tones,
|
||||
num_antennas=nr,
|
||||
snr=float(rssi + 95),
|
||||
metadata={
|
||||
'source': 'atheros_router',
|
||||
'rssi': rssi,
|
||||
'rate': rate,
|
||||
'channel': channel,
|
||||
'tx_antennas': nc,
|
||||
}
|
||||
)
|
||||
|
||||
def _parse_atheros_binary_format(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse Atheros CSI Tool binary format.
|
||||
|
||||
Based on ath9k/ath10k CSI Tool structure:
|
||||
- 4 bytes: magic (0x11111111)
|
||||
- 8 bytes: timestamp
|
||||
- 2 bytes: channel
|
||||
- 1 byte: bandwidth (0=20MHz, 1=40MHz, 2=80MHz)
|
||||
- 1 byte: nr (rx antennas)
|
||||
- 1 byte: nc (tx antennas)
|
||||
- 1 byte: num_tones
|
||||
- 2 bytes: rssi
|
||||
- Remaining: CSI payload (complex int16 per subcarrier per antenna pair)
|
||||
"""
|
||||
if len(raw_data) < 20:
|
||||
raise CSIParseError("Atheros binary data too short")
|
||||
|
||||
header_fmt = '<IQHBBBBB' # Q is 8-byte timestamp
|
||||
header_size = struct.calcsize(header_fmt)
|
||||
|
||||
magic, timestamp, channel, bw, nr, nc, num_tones, rssi = \
|
||||
struct.unpack(header_fmt, raw_data[:header_size])
|
||||
|
||||
if magic != 0x11111111:
|
||||
raise CSIParseError("Invalid Atheros magic number")
|
||||
|
||||
# Parse CSI payload
|
||||
csi_data = raw_data[header_size:]
|
||||
|
||||
# Each subcarrier has complex value per antenna pair: int16 real + int16 imag
|
||||
expected_bytes = nr * nc * num_tones * 4
|
||||
if len(csi_data) < expected_bytes:
|
||||
# Adjust num_tones based on available data
|
||||
num_tones = len(csi_data) // (nr * nc * 4)
|
||||
|
||||
csi_complex = np.zeros((nr, num_tones), dtype=np.complex128)
|
||||
|
||||
for ant in range(nr):
|
||||
for tone in range(num_tones):
|
||||
offset = (ant * nc * num_tones + tone) * 4
|
||||
if offset + 4 <= len(csi_data):
|
||||
real, imag = struct.unpack('<hh', csi_data[offset:offset+4])
|
||||
csi_complex[ant, tone] = complex(real, imag)
|
||||
|
||||
amplitude = np.abs(csi_complex)
|
||||
phase = np.angle(csi_complex)
|
||||
|
||||
# Normalize amplitude
|
||||
max_amp = np.max(amplitude)
|
||||
if max_amp > 0:
|
||||
amplitude = amplitude / max_amp
|
||||
|
||||
# Calculate frequency
|
||||
if channel <= 14:
|
||||
frequency = 2.412e9 + (channel - 1) * 5e6
|
||||
else:
|
||||
frequency = 5.18e9 + (channel - 36) * 5e6
|
||||
|
||||
bandwidth_hz = [20e6, 40e6, 80e6][bw] if bw < 3 else 20e6
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp / 1e9, tz=timezone.utc),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=frequency,
|
||||
bandwidth=bandwidth_hz,
|
||||
num_subcarriers=num_tones,
|
||||
num_antennas=nr,
|
||||
snr=float(rssi),
|
||||
metadata={
|
||||
'source': 'atheros_router',
|
||||
'format': 'binary',
|
||||
'channel': channel,
|
||||
'tx_antennas': nc,
|
||||
}
|
||||
)
|
||||
|
||||
def _parse_intel_text_format(self, data_str: str) -> CSIData:
|
||||
"""Parse Intel 5300 CSI text format."""
|
||||
content = data_str[10:] # Remove 'INTEL_CSI:' prefix
|
||||
parts = content.split(',')
|
||||
|
||||
if len(parts) < 6:
|
||||
raise CSIParseError("Incomplete Intel CSI data")
|
||||
|
||||
timestamp = int(parts[0])
|
||||
rssi = int(parts[1])
|
||||
channel = int(parts[2])
|
||||
bandwidth = int(parts[3])
|
||||
num_antennas = int(parts[4])
|
||||
num_tones = int(parts[5])
|
||||
|
||||
csi_values = [float(x) for x in parts[6:] if x.strip()]
|
||||
amplitude, phase = self._parse_complex_csi(csi_values, num_antennas, num_tones)
|
||||
|
||||
frequency = 5.18e9 + (channel - 36) * 5e6 if channel > 14 else 2.412e9 + (channel - 1) * 5e6
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=frequency,
|
||||
bandwidth=bandwidth * 1e6,
|
||||
num_subcarriers=num_tones,
|
||||
num_antennas=num_antennas,
|
||||
snr=float(rssi + 95),
|
||||
metadata={'source': 'intel_5300', 'channel': channel}
|
||||
)
|
||||
|
||||
def _parse_intel_binary_format(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse Intel 5300 CSI Tool binary format."""
|
||||
# Intel format is more complex with BFEE (beamforming feedback) structure
|
||||
if len(raw_data) < 25:
|
||||
raise CSIParseError("Intel binary data too short")
|
||||
|
||||
# BFEE header structure
|
||||
timestamp = struct.unpack('<Q', raw_data[0:8])[0]
|
||||
rssi_a, rssi_b, rssi_c = struct.unpack('<bbb', raw_data[8:11])
|
||||
noise = struct.unpack('<b', raw_data[11:12])[0]
|
||||
agc = struct.unpack('<B', raw_data[12:13])[0]
|
||||
antenna_sel = struct.unpack('<B', raw_data[13:14])[0]
|
||||
perm = struct.unpack('<BBB', raw_data[14:17])
|
||||
num_tones = struct.unpack('<B', raw_data[17:18])[0]
|
||||
nc = struct.unpack('<B', raw_data[18:19])[0]
|
||||
nr = struct.unpack('<B', raw_data[19:20])[0]
|
||||
|
||||
# Parse CSI matrix
|
||||
csi_data = raw_data[20:]
|
||||
|
||||
# Intel stores CSI in a packed format with variable bit width
|
||||
csi_complex = self._unpack_intel_csi(csi_data, nr, nc, num_tones)
|
||||
|
||||
# Use first TX stream
|
||||
amplitude = np.abs(csi_complex[:, 0, :])
|
||||
phase = np.angle(csi_complex[:, 0, :])
|
||||
|
||||
# Normalize
|
||||
max_amp = np.max(amplitude)
|
||||
if max_amp > 0:
|
||||
amplitude = amplitude / max_amp
|
||||
|
||||
rssi_avg = (rssi_a + rssi_b + rssi_c) / 3
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp / 1e6, tz=timezone.utc),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=5.32e9, # Default Intel channel
|
||||
bandwidth=40e6,
|
||||
num_subcarriers=num_tones,
|
||||
num_antennas=nr,
|
||||
snr=float(rssi_avg - noise),
|
||||
metadata={
|
||||
'source': 'intel_5300',
|
||||
'format': 'binary',
|
||||
'noise_floor': noise,
|
||||
'agc': agc,
|
||||
}
|
||||
)
|
||||
|
||||
def _unpack_intel_csi(self, data: bytes, nr: int, nc: int, num_tones: int) -> np.ndarray:
|
||||
"""Unpack Intel CSI data with bit manipulation."""
|
||||
csi = np.zeros((nr, nc, num_tones), dtype=np.complex128)
|
||||
|
||||
# Intel uses packed 10-bit values
|
||||
bits_per_sample = 10
|
||||
samples_needed = nr * nc * num_tones * 2 # real + imag
|
||||
|
||||
# Simple unpacking (actual Intel format is more complex)
|
||||
idx = 0
|
||||
for tone in range(num_tones):
|
||||
for nc_idx in range(nc):
|
||||
for nr_idx in range(nr):
|
||||
if idx + 2 <= len(data):
|
||||
# Approximate unpacking
|
||||
real = int.from_bytes(data[idx:idx+1], 'little', signed=True)
|
||||
imag = int.from_bytes(data[idx+1:idx+2], 'little', signed=True)
|
||||
csi[nr_idx, nc_idx, tone] = complex(real, imag)
|
||||
idx += 2
|
||||
|
||||
return csi
|
||||
|
||||
def _parse_complex_csi(
|
||||
self,
|
||||
values: List[float],
|
||||
num_antennas: int,
|
||||
num_tones: int
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Parse complex CSI values from real/imag pairs."""
|
||||
expected_len = num_antennas * num_tones * 2
|
||||
|
||||
if len(values) < expected_len:
|
||||
# Pad with zeros
|
||||
values = values + [0.0] * (expected_len - len(values))
|
||||
|
||||
csi_complex = np.zeros((num_antennas, num_tones), dtype=np.complex128)
|
||||
|
||||
for ant in range(num_antennas):
|
||||
for tone in range(num_tones):
|
||||
idx = (ant * num_tones + tone) * 2
|
||||
if idx + 1 < len(values):
|
||||
csi_complex[ant, tone] = complex(values[idx], values[idx + 1])
|
||||
|
||||
amplitude = np.abs(csi_complex)
|
||||
phase = np.angle(csi_complex)
|
||||
|
||||
# Normalize
|
||||
max_amp = np.max(amplitude)
|
||||
if max_amp > 0:
|
||||
amplitude = amplitude / max_amp
|
||||
|
||||
return amplitude, phase
|
||||
|
||||
|
||||
class CSIExtractor:
|
||||
"""Main CSI data extractor supporting multiple hardware types."""
|
||||
@@ -169,24 +643,18 @@ class CSIExtractor:
|
||||
# State management
|
||||
self.is_connected = False
|
||||
self.is_streaming = False
|
||||
self._connection = None
|
||||
|
||||
# Create appropriate parser
|
||||
if self.hardware_type == 'esp32':
|
||||
self.parser = ESP32CSIParser()
|
||||
elif self.hardware_type == 'router':
|
||||
elif self.hardware_type in ('router', 'atheros', 'intel'):
|
||||
self.parser = RouterCSIParser()
|
||||
else:
|
||||
raise ValueError(f"Unsupported hardware type: {self.hardware_type}")
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Validate configuration parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
"""Validate configuration parameters."""
|
||||
required_fields = ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']
|
||||
missing_fields = [field for field in required_fields if field not in config]
|
||||
|
||||
@@ -203,11 +671,7 @@ class CSIExtractor:
|
||||
raise ValueError("timeout must be positive")
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Establish connection to CSI hardware.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
"""Establish connection to CSI hardware."""
|
||||
try:
|
||||
success = await self._establish_hardware_connection()
|
||||
self.is_connected = success
|
||||
@@ -224,18 +688,10 @@ class CSIExtractor:
|
||||
self.is_connected = False
|
||||
|
||||
async def extract_csi(self) -> CSIData:
|
||||
"""Extract CSI data from hardware.
|
||||
|
||||
Returns:
|
||||
Extracted CSI data
|
||||
|
||||
Raises:
|
||||
CSIParseError: If not connected or extraction fails
|
||||
"""
|
||||
"""Extract CSI data from hardware."""
|
||||
if not self.is_connected:
|
||||
raise CSIParseError("Not connected to hardware")
|
||||
|
||||
# Retry mechanism for temporary failures
|
||||
for attempt in range(self.retry_attempts):
|
||||
try:
|
||||
raw_data = await self._read_raw_data()
|
||||
@@ -249,22 +705,12 @@ class CSIExtractor:
|
||||
except ConnectionError as e:
|
||||
if attempt < self.retry_attempts - 1:
|
||||
self.logger.warning(f"Extraction attempt {attempt + 1} failed, retrying: {e}")
|
||||
await asyncio.sleep(0.1) # Brief delay before retry
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
raise CSIParseError(f"Extraction failed after {self.retry_attempts} attempts: {e}")
|
||||
|
||||
def validate_csi_data(self, csi_data: CSIData) -> bool:
|
||||
"""Validate CSI data structure and values.
|
||||
|
||||
Args:
|
||||
csi_data: CSI data to validate
|
||||
|
||||
Returns:
|
||||
True if valid
|
||||
|
||||
Raises:
|
||||
CSIValidationError: If data is invalid
|
||||
"""
|
||||
"""Validate CSI data structure and values."""
|
||||
if csi_data.amplitude.size == 0:
|
||||
raise CSIValidationError("Empty amplitude data")
|
||||
|
||||
@@ -283,17 +729,13 @@ class CSIExtractor:
|
||||
if csi_data.num_antennas <= 0:
|
||||
raise CSIValidationError("Invalid number of antennas")
|
||||
|
||||
if csi_data.snr < -50 or csi_data.snr > 50: # Reasonable SNR range
|
||||
if csi_data.snr < -50 or csi_data.snr > 100:
|
||||
raise CSIValidationError("Invalid SNR value")
|
||||
|
||||
return True
|
||||
|
||||
async def start_streaming(self, callback: Callable[[CSIData], None]) -> None:
|
||||
"""Start streaming CSI data.
|
||||
|
||||
Args:
|
||||
callback: Function to call with each CSI sample
|
||||
"""
|
||||
"""Start streaming CSI data."""
|
||||
self.is_streaming = True
|
||||
|
||||
try:
|
||||
@@ -311,16 +753,68 @@ class CSIExtractor:
|
||||
self.is_streaming = False
|
||||
|
||||
async def _establish_hardware_connection(self) -> bool:
|
||||
"""Establish connection to hardware (to be implemented by subclasses)."""
|
||||
# Placeholder implementation for testing
|
||||
"""Establish connection to hardware."""
|
||||
connection_config = self.config.get('connection', {})
|
||||
|
||||
if self.hardware_type == 'esp32':
|
||||
# Serial or network connection for ESP32
|
||||
port = connection_config.get('port', '/dev/ttyUSB0')
|
||||
baudrate = connection_config.get('baudrate', 115200)
|
||||
|
||||
try:
|
||||
import serial_asyncio
|
||||
reader, writer = await serial_asyncio.open_serial_connection(
|
||||
url=port, baudrate=baudrate
|
||||
)
|
||||
self._connection = (reader, writer)
|
||||
return True
|
||||
except ImportError:
|
||||
self.logger.warning("serial_asyncio not available, using mock connection")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Serial connection failed: {e}")
|
||||
return False
|
||||
|
||||
elif self.hardware_type in ('router', 'atheros', 'intel'):
|
||||
# Network connection for router
|
||||
host = connection_config.get('host', '192.168.1.1')
|
||||
port = connection_config.get('port', 5500)
|
||||
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(host, port)
|
||||
self._connection = (reader, writer)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Network connection failed: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
async def _close_hardware_connection(self) -> None:
|
||||
"""Close hardware connection (to be implemented by subclasses)."""
|
||||
# Placeholder implementation for testing
|
||||
pass
|
||||
"""Close hardware connection."""
|
||||
if self._connection:
|
||||
try:
|
||||
reader, writer = self._connection
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error closing connection: {e}")
|
||||
finally:
|
||||
self._connection = None
|
||||
|
||||
async def _read_raw_data(self) -> bytes:
|
||||
"""Read raw data from hardware (to be implemented by subclasses)."""
|
||||
# Placeholder implementation for testing
|
||||
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
"""Read raw data from hardware."""
|
||||
if self._connection:
|
||||
reader, writer = self._connection
|
||||
try:
|
||||
# Read until newline or buffer size
|
||||
data = await asyncio.wait_for(
|
||||
reader.readline(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
return data
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectionError("Read timeout")
|
||||
else:
|
||||
# Mock data for testing when no real connection
|
||||
raise ConnectionError("No active connection")
|
||||
@@ -265,24 +265,99 @@ class PoseService:
|
||||
self.logger.error(f"Error in pose estimation: {e}")
|
||||
return []
|
||||
|
||||
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Parse neural network outputs into pose detections."""
|
||||
def _parse_pose_outputs(self, outputs: Dict[str, torch.Tensor]) -> List[Dict[str, Any]]:
|
||||
"""Parse neural network outputs into pose detections.
|
||||
|
||||
The DensePose model outputs:
|
||||
- segmentation: (batch, num_parts+1, H, W) - body part segmentation
|
||||
- uv_coords: (batch, 2, H, W) - UV coordinates for surface mapping
|
||||
|
||||
Returns list of detected persons with keypoints and body parts.
|
||||
"""
|
||||
poses = []
|
||||
|
||||
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||
# For now, generate mock poses based on the output shape
|
||||
# Handle different output formats
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
# Simple tensor output - use legacy parsing
|
||||
return self._parse_simple_outputs(outputs)
|
||||
|
||||
# DensePose structured output
|
||||
segmentation = outputs.get('segmentation')
|
||||
uv_coords = outputs.get('uv_coords')
|
||||
|
||||
if segmentation is None:
|
||||
return []
|
||||
|
||||
batch_size = segmentation.shape[0]
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
# Get segmentation for this sample
|
||||
seg = segmentation[batch_idx] # (num_parts+1, H, W)
|
||||
|
||||
# Find persons by analyzing body part segmentation
|
||||
# Background is class 0, body parts are 1-24
|
||||
body_mask = seg[1:].sum(dim=0) > seg[0] # Any body part vs background
|
||||
|
||||
if not body_mask.any():
|
||||
continue
|
||||
|
||||
# Find connected components (persons)
|
||||
person_regions = self._find_person_regions(body_mask)
|
||||
|
||||
for person_idx, region in enumerate(person_regions):
|
||||
# Extract keypoints from body part segmentation
|
||||
keypoints = self._extract_keypoints_from_segmentation(seg, region)
|
||||
|
||||
# Calculate bounding box from region
|
||||
bbox = self._calculate_bounding_box(region)
|
||||
|
||||
# Calculate confidence from segmentation probabilities
|
||||
seg_probs = torch.softmax(seg, dim=0)
|
||||
region_mask = region['mask']
|
||||
confidence = float(seg_probs[1:, region_mask].max().item())
|
||||
|
||||
# Classify activity from pose keypoints
|
||||
activity = self._classify_activity_from_keypoints(keypoints)
|
||||
|
||||
pose = {
|
||||
"person_id": person_idx,
|
||||
"confidence": confidence,
|
||||
"keypoints": keypoints,
|
||||
"bounding_box": bbox,
|
||||
"activity": activity,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"body_parts": self._extract_body_parts(seg, region) if uv_coords is not None else None
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _parse_simple_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Parse simple tensor outputs (fallback for non-DensePose models)."""
|
||||
poses = []
|
||||
batch_size = outputs.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract pose information (mock implementation)
|
||||
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||
output = outputs[i]
|
||||
|
||||
# Extract confidence from first channel
|
||||
confidence = float(torch.sigmoid(output[0]).mean().item()) if output.numel() > 0 else 0.0
|
||||
|
||||
if confidence < 0.1:
|
||||
continue
|
||||
|
||||
# Try to extract keypoints from output tensor
|
||||
keypoints = self._extract_keypoints_from_tensor(output)
|
||||
bbox = self._estimate_bbox_from_keypoints(keypoints)
|
||||
activity = self._classify_activity_from_keypoints(keypoints)
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||
"keypoints": keypoints,
|
||||
"bounding_box": bbox,
|
||||
"activity": activity,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
@@ -290,6 +365,272 @@ class PoseService:
|
||||
|
||||
return poses
|
||||
|
||||
def _find_person_regions(self, body_mask: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Find distinct person regions in body mask using connected components."""
|
||||
# Convert to numpy for connected component analysis
|
||||
mask_np = body_mask.cpu().numpy().astype(np.uint8)
|
||||
|
||||
# Simple connected component labeling
|
||||
from scipy import ndimage
|
||||
labeled, num_features = ndimage.label(mask_np)
|
||||
|
||||
regions = []
|
||||
for label_id in range(1, num_features + 1):
|
||||
region_mask = labeled == label_id
|
||||
if region_mask.sum() < 100: # Minimum region size
|
||||
continue
|
||||
|
||||
# Find bounding coordinates
|
||||
coords = np.where(region_mask)
|
||||
regions.append({
|
||||
'mask': torch.from_numpy(region_mask),
|
||||
'y_min': int(coords[0].min()),
|
||||
'y_max': int(coords[0].max()),
|
||||
'x_min': int(coords[1].min()),
|
||||
'x_max': int(coords[1].max()),
|
||||
'area': int(region_mask.sum())
|
||||
})
|
||||
|
||||
return regions
|
||||
|
||||
def _extract_keypoints_from_segmentation(
|
||||
self, segmentation: torch.Tensor, region: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract keypoints from body part segmentation."""
|
||||
keypoint_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
# Mapping from body parts to keypoints
|
||||
# DensePose has 24 body parts, we map to COCO keypoints
|
||||
part_to_keypoint = {
|
||||
14: "nose", # Head -> nose
|
||||
10: "left_shoulder", 11: "right_shoulder",
|
||||
12: "left_elbow", 13: "right_elbow",
|
||||
2: "left_wrist", 3: "right_wrist", # Hands approximate wrists
|
||||
7: "left_hip", 6: "right_hip", # Upper legs
|
||||
9: "left_knee", 8: "right_knee", # Lower legs
|
||||
4: "left_ankle", 5: "right_ankle", # Feet approximate ankles
|
||||
}
|
||||
|
||||
h, w = segmentation.shape[1], segmentation.shape[2]
|
||||
keypoints = []
|
||||
|
||||
# Get softmax probabilities
|
||||
seg_probs = torch.softmax(segmentation, dim=0)
|
||||
|
||||
for kp_name in keypoint_names:
|
||||
# Find which body part corresponds to this keypoint
|
||||
part_idx = None
|
||||
for part, name in part_to_keypoint.items():
|
||||
if name == kp_name:
|
||||
part_idx = part
|
||||
break
|
||||
|
||||
if part_idx is not None and part_idx < seg_probs.shape[0]:
|
||||
# Get probability map for this part within the region
|
||||
part_prob = seg_probs[part_idx] * region['mask'].float()
|
||||
|
||||
if part_prob.max() > 0.1:
|
||||
# Find location of maximum probability
|
||||
max_idx = part_prob.argmax()
|
||||
y = int(max_idx // w)
|
||||
x = int(max_idx % w)
|
||||
|
||||
keypoints.append({
|
||||
"name": kp_name,
|
||||
"x": float(x) / w,
|
||||
"y": float(y) / h,
|
||||
"confidence": float(part_prob.max().item())
|
||||
})
|
||||
else:
|
||||
# Keypoint not visible
|
||||
keypoints.append({
|
||||
"name": kp_name,
|
||||
"x": 0.0,
|
||||
"y": 0.0,
|
||||
"confidence": 0.0
|
||||
})
|
||||
else:
|
||||
# Estimate position based on body region
|
||||
cx = (region['x_min'] + region['x_max']) / 2 / w
|
||||
cy = (region['y_min'] + region['y_max']) / 2 / h
|
||||
keypoints.append({
|
||||
"name": kp_name,
|
||||
"x": float(cx),
|
||||
"y": float(cy),
|
||||
"confidence": 0.1
|
||||
})
|
||||
|
||||
return keypoints
|
||||
|
||||
def _calculate_bounding_box(self, region: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Calculate normalized bounding box from region."""
|
||||
# Assume region contains mask shape info
|
||||
mask = region['mask']
|
||||
h, w = mask.shape
|
||||
|
||||
return {
|
||||
"x": float(region['x_min']) / w,
|
||||
"y": float(region['y_min']) / h,
|
||||
"width": float(region['x_max'] - region['x_min']) / w,
|
||||
"height": float(region['y_max'] - region['y_min']) / h
|
||||
}
|
||||
|
||||
def _extract_body_parts(
|
||||
self, segmentation: torch.Tensor, region: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Extract body part information from segmentation."""
|
||||
part_names = [
|
||||
"background", "torso", "right_hand", "left_hand", "left_foot", "right_foot",
|
||||
"upper_leg_right", "upper_leg_left", "lower_leg_right", "lower_leg_left",
|
||||
"upper_arm_left", "upper_arm_right", "lower_arm_left", "lower_arm_right", "head"
|
||||
]
|
||||
|
||||
seg_probs = torch.softmax(segmentation, dim=0)
|
||||
region_mask = region['mask']
|
||||
|
||||
parts = {}
|
||||
for i, name in enumerate(part_names):
|
||||
if i < seg_probs.shape[0]:
|
||||
part_prob = seg_probs[i] * region_mask.float()
|
||||
parts[name] = {
|
||||
"present": bool(part_prob.max() > 0.3),
|
||||
"confidence": float(part_prob.max().item()),
|
||||
"coverage": float((part_prob > 0.3).sum().item() / max(1, region_mask.sum().item()))
|
||||
}
|
||||
|
||||
return parts
|
||||
|
||||
def _extract_keypoints_from_tensor(self, output: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Extract keypoints from a generic output tensor."""
|
||||
keypoint_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
keypoints = []
|
||||
|
||||
# Try to interpret output as heatmaps
|
||||
if output.dim() >= 2:
|
||||
flat = output.flatten()
|
||||
num_kp = len(keypoint_names)
|
||||
|
||||
# Divide output evenly for each keypoint
|
||||
chunk_size = len(flat) // num_kp if num_kp > 0 else 1
|
||||
|
||||
for i, name in enumerate(keypoint_names):
|
||||
start = i * chunk_size
|
||||
end = min(start + chunk_size, len(flat))
|
||||
|
||||
if start < len(flat):
|
||||
chunk = flat[start:end]
|
||||
# Find max location in chunk
|
||||
max_val = chunk.max().item()
|
||||
max_idx = chunk.argmax().item()
|
||||
|
||||
# Convert to x, y (assume square spatial layout)
|
||||
side = int(np.sqrt(chunk_size))
|
||||
if side > 0:
|
||||
x = (max_idx % side) / side
|
||||
y = (max_idx // side) / side
|
||||
else:
|
||||
x, y = 0.5, 0.5
|
||||
|
||||
keypoints.append({
|
||||
"name": name,
|
||||
"x": float(x),
|
||||
"y": float(y),
|
||||
"confidence": float(torch.sigmoid(torch.tensor(max_val)).item())
|
||||
})
|
||||
else:
|
||||
keypoints.append({
|
||||
"name": name, "x": 0.5, "y": 0.5, "confidence": 0.0
|
||||
})
|
||||
else:
|
||||
# Fallback
|
||||
for name in keypoint_names:
|
||||
keypoints.append({"name": name, "x": 0.5, "y": 0.5, "confidence": 0.1})
|
||||
|
||||
return keypoints
|
||||
|
||||
def _estimate_bbox_from_keypoints(self, keypoints: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
"""Estimate bounding box from keypoint positions."""
|
||||
valid_kps = [kp for kp in keypoints if kp['confidence'] > 0.1]
|
||||
|
||||
if not valid_kps:
|
||||
return {"x": 0.3, "y": 0.2, "width": 0.4, "height": 0.6}
|
||||
|
||||
xs = [kp['x'] for kp in valid_kps]
|
||||
ys = [kp['y'] for kp in valid_kps]
|
||||
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
# Add padding
|
||||
padding = 0.05
|
||||
x_min = max(0, x_min - padding)
|
||||
y_min = max(0, y_min - padding)
|
||||
x_max = min(1, x_max + padding)
|
||||
y_max = min(1, y_max + padding)
|
||||
|
||||
return {
|
||||
"x": x_min,
|
||||
"y": y_min,
|
||||
"width": x_max - x_min,
|
||||
"height": y_max - y_min
|
||||
}
|
||||
|
||||
def _classify_activity_from_keypoints(self, keypoints: List[Dict[str, Any]]) -> str:
|
||||
"""Classify activity based on keypoint positions."""
|
||||
# Get key body parts
|
||||
kp_dict = {kp['name']: kp for kp in keypoints}
|
||||
|
||||
# Check if enough keypoints are detected
|
||||
valid_count = sum(1 for kp in keypoints if kp['confidence'] > 0.3)
|
||||
if valid_count < 5:
|
||||
return "unknown"
|
||||
|
||||
# Get relevant keypoints
|
||||
nose = kp_dict.get('nose', {})
|
||||
l_hip = kp_dict.get('left_hip', {})
|
||||
r_hip = kp_dict.get('right_hip', {})
|
||||
l_ankle = kp_dict.get('left_ankle', {})
|
||||
r_ankle = kp_dict.get('right_ankle', {})
|
||||
l_shoulder = kp_dict.get('left_shoulder', {})
|
||||
r_shoulder = kp_dict.get('right_shoulder', {})
|
||||
|
||||
# Calculate body metrics
|
||||
hip_y = (l_hip.get('y', 0.5) + r_hip.get('y', 0.5)) / 2
|
||||
ankle_y = (l_ankle.get('y', 0.8) + r_ankle.get('y', 0.8)) / 2
|
||||
shoulder_y = (l_shoulder.get('y', 0.3) + r_shoulder.get('y', 0.3)) / 2
|
||||
nose_y = nose.get('y', 0.2)
|
||||
|
||||
# Leg spread (horizontal distance between ankles)
|
||||
leg_spread = abs(l_ankle.get('x', 0.5) - r_ankle.get('x', 0.5))
|
||||
|
||||
# Vertical compression (how "tall" the pose is)
|
||||
vertical_span = ankle_y - nose_y if ankle_y > nose_y else 0.6
|
||||
|
||||
# Classification logic
|
||||
if vertical_span < 0.3:
|
||||
# Very compressed vertically - likely lying down
|
||||
return "lying"
|
||||
elif vertical_span < 0.45 and hip_y > 0.5:
|
||||
# Medium compression with low hips - sitting
|
||||
return "sitting"
|
||||
elif leg_spread > 0.15:
|
||||
# Legs apart - likely walking
|
||||
return "walking"
|
||||
else:
|
||||
# Default upright pose
|
||||
return "standing"
|
||||
|
||||
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||
"""Generate mock pose data for development."""
|
||||
import random
|
||||
|
||||
Reference in New Issue
Block a user