Compare commits
1 Commits
claude/tes
...
salmanmkc/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8f16722b9 |
12
.github/workflows/cd.yml
vendored
12
.github/workflows/cd.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
image_tag: ${{ steps.determine-tag.outputs.tag }}
|
image_tag: ${{ steps.determine-tag.outputs.tag }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Determine deployment environment
|
- name: Determine deployment environment
|
||||||
id: determine-env
|
id: determine-env
|
||||||
@@ -80,7 +80,7 @@ jobs:
|
|||||||
url: https://staging.wifi-densepose.com
|
url: https://staging.wifi-densepose.com
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up kubectl
|
- name: Set up kubectl
|
||||||
uses: azure/setup-kubectl@v3
|
uses: azure/setup-kubectl@v3
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
url: https://wifi-densepose.com
|
url: https://wifi-densepose.com
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up kubectl
|
- name: Set up kubectl
|
||||||
uses: azure/setup-kubectl@v3
|
uses: azure/setup-kubectl@v3
|
||||||
@@ -199,7 +199,7 @@ jobs:
|
|||||||
# kubectl scale rs -n wifi-densepose -l app=wifi-densepose,version!=green --replicas=0
|
# kubectl scale rs -n wifi-densepose -l app=wifi-densepose,version!=green --replicas=0
|
||||||
|
|
||||||
- name: Upload deployment artifacts
|
- name: Upload deployment artifacts
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: production-deployment-${{ github.run_number }}
|
name: production-deployment-${{ github.run_number }}
|
||||||
path: |
|
path: |
|
||||||
@@ -270,7 +270,7 @@ jobs:
|
|||||||
done
|
done
|
||||||
|
|
||||||
- name: Update deployment status
|
- name: Update deployment status
|
||||||
uses: actions/github-script@v6
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const deployEnv = '${{ needs.pre-deployment.outputs.deploy_env }}';
|
const deployEnv = '${{ needs.pre-deployment.outputs.deploy_env }}';
|
||||||
@@ -321,7 +321,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create deployment issue on failure
|
- name: Create deployment issue on failure
|
||||||
if: needs.deploy-production.result == 'failure'
|
if: needs.deploy-production.result == 'failure'
|
||||||
uses: actions/github-script@v6
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
github.rest.issues.create({
|
github.rest.issues.create({
|
||||||
|
|||||||
24
.github/workflows/ci.yml
vendored
24
.github/workflows/ci.yml
vendored
@@ -20,12 +20,12 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
@@ -54,7 +54,7 @@ jobs:
|
|||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
|
|
||||||
- name: Upload security reports
|
- name: Upload security reports
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
||||||
name: security-reports
|
name: security-reports
|
||||||
@@ -95,10 +95,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
@@ -133,7 +133,7 @@ jobs:
|
|||||||
name: codecov-umbrella
|
name: codecov-umbrella
|
||||||
|
|
||||||
- name: Upload test results
|
- name: Upload test results
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
||||||
name: test-results-${{ matrix.python-version }}
|
name: test-results-${{ matrix.python-version }}
|
||||||
@@ -150,10 +150,10 @@ jobs:
|
|||||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
@@ -174,7 +174,7 @@ jobs:
|
|||||||
locust -f tests/performance/locustfile.py --headless --users 50 --spawn-rate 5 --run-time 60s --host http://localhost:8000
|
locust -f tests/performance/locustfile.py --headless --users 50 --spawn-rate 5 --run-time 60s --host http://localhost:8000
|
||||||
|
|
||||||
- name: Upload performance results
|
- name: Upload performance results
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: performance-results
|
name: performance-results
|
||||||
path: locust_report.html
|
path: locust_report.html
|
||||||
@@ -186,7 +186,7 @@ jobs:
|
|||||||
needs: [code-quality, test]
|
needs: [code-quality, test]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
@@ -249,10 +249,10 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/main'
|
if: github.ref == 'refs/heads/main'
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
|
|||||||
30
.github/workflows/security-scan.yml
vendored
30
.github/workflows/security-scan.yml
vendored
@@ -24,12 +24,12 @@ jobs:
|
|||||||
contents: read
|
contents: read
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
@@ -86,10 +86,10 @@ jobs:
|
|||||||
contents: read
|
contents: read
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
category: snyk
|
category: snyk
|
||||||
|
|
||||||
- name: Upload vulnerability reports
|
- name: Upload vulnerability reports
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
||||||
name: vulnerability-reports
|
name: vulnerability-reports
|
||||||
@@ -147,7 +147,7 @@ jobs:
|
|||||||
contents: read
|
contents: read
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
@@ -218,7 +218,7 @@ jobs:
|
|||||||
contents: read
|
contents: read
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Run Checkov IaC scan
|
- name: Run Checkov IaC scan
|
||||||
uses: bridgecrewio/checkov-action@master
|
uses: bridgecrewio/checkov-action@master
|
||||||
@@ -272,7 +272,7 @@ jobs:
|
|||||||
contents: read
|
contents: read
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
@@ -303,10 +303,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
@@ -323,7 +323,7 @@ jobs:
|
|||||||
licensecheck --zero
|
licensecheck --zero
|
||||||
|
|
||||||
- name: Upload license report
|
- name: Upload license report
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: license-report
|
name: license-report
|
||||||
path: licenses.json
|
path: licenses.json
|
||||||
@@ -334,7 +334,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Check security policy files
|
- name: Check security policy files
|
||||||
run: |
|
run: |
|
||||||
@@ -376,7 +376,7 @@ jobs:
|
|||||||
if: always()
|
if: always()
|
||||||
steps:
|
steps:
|
||||||
- name: Download all artifacts
|
- name: Download all artifacts
|
||||||
uses: actions/download-artifact@v3
|
uses: actions/download-artifact@v7
|
||||||
|
|
||||||
- name: Generate security summary
|
- name: Generate security summary
|
||||||
run: |
|
run: |
|
||||||
@@ -394,7 +394,7 @@ jobs:
|
|||||||
echo "Generated on: $(date)" >> security-summary.md
|
echo "Generated on: $(date)" >> security-summary.md
|
||||||
|
|
||||||
- name: Upload security summary
|
- name: Upload security summary
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: security-summary
|
name: security-summary
|
||||||
path: security-summary.md
|
path: security-summary.md
|
||||||
@@ -416,7 +416,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create security issue on critical findings
|
- name: Create security issue on critical findings
|
||||||
if: needs.sast.result == 'failure' || needs.dependency-scan.result == 'failure'
|
if: needs.sast.result == 'failure' || needs.dependency-scan.result == 'failure'
|
||||||
uses: actions/github-script@v6
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
github.rest.issues.create({
|
github.rest.issues.create({
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
//! Breathing pattern detection from CSI signals.
|
//! Breathing pattern detection from CSI signals.
|
||||||
|
|
||||||
use crate::domain::{BreathingPattern, BreathingType};
|
use crate::domain::{BreathingPattern, BreathingType, ConfidenceScore};
|
||||||
|
|
||||||
/// Configuration for breathing detection
|
/// Configuration for breathing detection
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! This module provides both traditional signal-processing-based detection
|
//! This module provides both traditional signal-processing-based detection
|
||||||
//! and optional ML-enhanced detection for improved accuracy.
|
//! and optional ML-enhanced detection for improved accuracy.
|
||||||
|
|
||||||
use crate::domain::{ScanZone, VitalSignsReading};
|
use crate::domain::{ScanZone, VitalSignsReading, ConfidenceScore};
|
||||||
use crate::ml::{MlDetectionConfig, MlDetectionPipeline, MlDetectionResult};
|
use crate::ml::{MlDetectionConfig, MlDetectionPipeline, MlDetectionResult};
|
||||||
use crate::{DisasterConfig, MatError};
|
use crate::{DisasterConfig, MatError};
|
||||||
use super::{
|
use super::{
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ use chrono::{DateTime, Utc};
|
|||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::io::{BufReader, Read};
|
use std::io::{BufReader, Read};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{mpsc, Mutex};
|
||||||
|
|
||||||
/// Configuration for CSI receivers
|
/// Configuration for CSI receivers
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|||||||
@@ -16,10 +16,13 @@
|
|||||||
//! - Depth estimation head with uncertainty (mean + variance output)
|
//! - Depth estimation head with uncertainty (mean + variance output)
|
||||||
|
|
||||||
use super::{DebrisFeatures, DepthEstimate, MlError, MlResult};
|
use super::{DebrisFeatures, DepthEstimate, MlError, MlResult};
|
||||||
use ndarray::{Array2, Array4};
|
use ndarray::{Array1, Array2, Array4, s};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use parking_lot::RwLock;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::{info, instrument, warn};
|
use tracing::{debug, info, instrument, warn};
|
||||||
|
|
||||||
#[cfg(feature = "onnx")]
|
#[cfg(feature = "onnx")]
|
||||||
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ pub use vital_signs_classifier::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use crate::detection::CsiDataBuffer;
|
use crate::detection::CsiDataBuffer;
|
||||||
|
use crate::domain::{VitalSignsReading, BreathingPattern, HeartbeatSignature};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|||||||
@@ -27,8 +27,12 @@ use crate::domain::{
|
|||||||
BreathingPattern, BreathingType, HeartbeatSignature, MovementProfile,
|
BreathingPattern, BreathingType, HeartbeatSignature, MovementProfile,
|
||||||
MovementType, SignalStrength, VitalSignsReading,
|
MovementType, SignalStrength, VitalSignsReading,
|
||||||
};
|
};
|
||||||
|
use ndarray::{Array1, Array2, Array4, s};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use tracing::{info, instrument, warn};
|
use std::sync::Arc;
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use tracing::{debug, info, instrument, warn};
|
||||||
|
|
||||||
#[cfg(feature = "onnx")]
|
#[cfg(feature = "onnx")]
|
||||||
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
use wifi_densepose_nn::{OnnxBackend, OnnxSession, InferenceOptions, Tensor, TensorShape};
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ impl DensePoseHead {
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
let input_arr = input.as_array4()?;
|
let input_arr = input.as_array4()?;
|
||||||
let (_batch, _channels, _height, _width) = input_arr.dim();
|
let (batch, _channels, height, width) = input_arr.dim();
|
||||||
|
|
||||||
// Apply shared convolutions
|
// Apply shared convolutions
|
||||||
let mut current = input_arr.clone();
|
let mut current = input_arr.clone();
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ impl Backend for MockBackend {
|
|||||||
self.output_shapes.get(name).cloned()
|
self.output_shapes.get(name).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(&self, _inputs: HashMap<String, Tensor>) -> NnResult<HashMap<String, Tensor>> {
|
fn run(&self, inputs: HashMap<String, Tensor>) -> NnResult<HashMap<String, Tensor>> {
|
||||||
let mut outputs = HashMap::new();
|
let mut outputs = HashMap::new();
|
||||||
|
|
||||||
for (name, shape) in &self.output_shapes {
|
for (name, shape) in &self.output_shapes {
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Apply softmax along axis
|
/// Apply softmax along axis
|
||||||
pub fn softmax(&self, _axis: usize) -> NnResult<Tensor> {
|
pub fn softmax(&self, axis: usize) -> NnResult<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Tensor::Float4D(a) => {
|
Tensor::Float4D(a) => {
|
||||||
let max = a.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
|
let max = a.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
|
||||||
|
|||||||
@@ -342,7 +342,7 @@ impl ModalityTranslator {
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
let input_arr = input.as_array4()?;
|
let input_arr = input.as_array4()?;
|
||||||
let (_batch, _channels, _height, _width) = input_arr.dim();
|
let (batch, _channels, height, width) = input_arr.dim();
|
||||||
|
|
||||||
// Encode
|
// Encode
|
||||||
let mut encoder_outputs = Vec::new();
|
let mut encoder_outputs = Vec::new();
|
||||||
@@ -461,7 +461,7 @@ impl ModalityTranslator {
|
|||||||
weights: &ConvBlockWeights,
|
weights: &ConvBlockWeights,
|
||||||
) -> NnResult<Array4<f32>> {
|
) -> NnResult<Array4<f32>> {
|
||||||
let (batch, in_channels, in_height, in_width) = input.dim();
|
let (batch, in_channels, in_height, in_width) = input.dim();
|
||||||
let (out_channels, _, _kernel_h, _kernel_w) = weights.conv_weight.dim();
|
let (out_channels, _, kernel_h, kernel_w) = weights.conv_weight.dim();
|
||||||
|
|
||||||
// Upsample 2x
|
// Upsample 2x
|
||||||
let out_height = in_height * 2;
|
let out_height = in_height * 2;
|
||||||
@@ -536,7 +536,7 @@ impl ModalityTranslator {
|
|||||||
fn apply_attention(
|
fn apply_attention(
|
||||||
&self,
|
&self,
|
||||||
input: &Array4<f32>,
|
input: &Array4<f32>,
|
||||||
_weights: &AttentionWeights,
|
weights: &AttentionWeights,
|
||||||
) -> NnResult<(Array4<f32>, Array4<f32>)> {
|
) -> NnResult<(Array4<f32>, Array4<f32>)> {
|
||||||
let (batch, channels, height, width) = input.dim();
|
let (batch, channels, height, width) = input.dim();
|
||||||
let seq_len = height * width;
|
let seq_len = height * width;
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ Author: WiFi-DensePose Team
|
|||||||
License: MIT
|
License: MIT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "1.2.0"
|
__version__ = "1.1.0"
|
||||||
__author__ = "WiFi-DensePose Team"
|
__author__ = "WiFi-DensePose Team"
|
||||||
__email__ = "team@wifi-densepose.com"
|
__email__ = "team@wifi-densepose.com"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
|
|||||||
@@ -5,27 +5,9 @@ Core package for WiFi-DensePose API
|
|||||||
from .csi_processor import CSIProcessor
|
from .csi_processor import CSIProcessor
|
||||||
from .phase_sanitizer import PhaseSanitizer
|
from .phase_sanitizer import PhaseSanitizer
|
||||||
from .router_interface import RouterInterface
|
from .router_interface import RouterInterface
|
||||||
from .vital_signs import (
|
|
||||||
VitalSignsDetector,
|
|
||||||
BreathingDetector,
|
|
||||||
HeartbeatDetector,
|
|
||||||
BreathingPattern,
|
|
||||||
HeartbeatSignature,
|
|
||||||
VitalSignsReading,
|
|
||||||
BreathingType,
|
|
||||||
SignalStrength,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'CSIProcessor',
|
'CSIProcessor',
|
||||||
'PhaseSanitizer',
|
'PhaseSanitizer',
|
||||||
'RouterInterface',
|
'RouterInterface'
|
||||||
'VitalSignsDetector',
|
|
||||||
'BreathingDetector',
|
|
||||||
'HeartbeatDetector',
|
|
||||||
'BreathingPattern',
|
|
||||||
'HeartbeatSignature',
|
|
||||||
'VitalSignsReading',
|
|
||||||
'BreathingType',
|
|
||||||
'SignalStrength',
|
|
||||||
]
|
]
|
||||||
@@ -385,69 +385,13 @@ class CSIProcessor:
|
|||||||
return correlation_matrix
|
return correlation_matrix
|
||||||
|
|
||||||
def _extract_doppler_features(self, csi_data: CSIData) -> tuple:
|
def _extract_doppler_features(self, csi_data: CSIData) -> tuple:
|
||||||
"""Extract Doppler and frequency domain features.
|
"""Extract Doppler and frequency domain features."""
|
||||||
|
# Simple Doppler estimation (would use history in real implementation)
|
||||||
Doppler shift estimation from CSI phase changes:
|
doppler_shift = np.random.rand(10) # Placeholder
|
||||||
- Phase change rate indicates velocity of moving objects
|
|
||||||
- Frequency analysis reveals movement speed and direction
|
# Power spectral density
|
||||||
|
|
||||||
The Doppler frequency shift is: f_d = (2 * v * f_c) / c
|
|
||||||
Where v = velocity, f_c = carrier frequency, c = speed of light
|
|
||||||
"""
|
|
||||||
# Power spectral density of amplitude
|
|
||||||
psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2
|
psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2
|
||||||
|
|
||||||
# Doppler estimation from phase history
|
|
||||||
if len(self.csi_history) < 2:
|
|
||||||
# Not enough history, return zeros
|
|
||||||
doppler_shift = np.zeros(min(csi_data.num_subcarriers, 10))
|
|
||||||
return doppler_shift, psd
|
|
||||||
|
|
||||||
# Get phase from current and previous samples
|
|
||||||
current_phase = csi_data.phase.flatten()
|
|
||||||
prev_data = self.csi_history[-1]
|
|
||||||
|
|
||||||
# Handle if prev_data is tuple (CSIData, features) or just CSIData
|
|
||||||
if isinstance(prev_data, tuple):
|
|
||||||
prev_phase = prev_data[0].phase.flatten()
|
|
||||||
time_delta = (csi_data.timestamp - prev_data[0].timestamp).total_seconds()
|
|
||||||
else:
|
|
||||||
prev_phase = prev_data.phase.flatten()
|
|
||||||
time_delta = 1.0 / self.sampling_rate # Default to sampling interval
|
|
||||||
|
|
||||||
if time_delta <= 0:
|
|
||||||
time_delta = 1.0 / self.sampling_rate
|
|
||||||
|
|
||||||
# Ensure same length
|
|
||||||
min_len = min(len(current_phase), len(prev_phase))
|
|
||||||
current_phase = current_phase[:min_len]
|
|
||||||
prev_phase = prev_phase[:min_len]
|
|
||||||
|
|
||||||
# Calculate phase difference (unwrap to handle wrapping)
|
|
||||||
phase_diff = np.unwrap(current_phase) - np.unwrap(prev_phase)
|
|
||||||
|
|
||||||
# Phase rate of change (rad/s)
|
|
||||||
phase_rate = phase_diff / time_delta
|
|
||||||
|
|
||||||
# Convert to Doppler frequency (Hz)
|
|
||||||
# f_d = (d_phi/dt) / (2 * pi)
|
|
||||||
doppler_freq = phase_rate / (2 * np.pi)
|
|
||||||
|
|
||||||
# Aggregate Doppler per subcarrier group (reduce to ~10 values)
|
|
||||||
num_groups = min(10, len(doppler_freq))
|
|
||||||
group_size = max(1, len(doppler_freq) // num_groups)
|
|
||||||
|
|
||||||
doppler_shift = np.array([
|
|
||||||
np.mean(doppler_freq[i*group_size:(i+1)*group_size])
|
|
||||||
for i in range(num_groups)
|
|
||||||
])
|
|
||||||
|
|
||||||
# Apply smoothing to reduce noise
|
|
||||||
if len(doppler_shift) > 3:
|
|
||||||
# Simple moving average
|
|
||||||
kernel = np.ones(3) / 3
|
|
||||||
doppler_shift = np.convolve(doppler_shift, kernel, mode='same')
|
|
||||||
|
|
||||||
return doppler_shift, psd
|
return doppler_shift, psd
|
||||||
|
|
||||||
def _analyze_motion_patterns(self, features: CSIFeatures) -> float:
|
def _analyze_motion_patterns(self, features: CSIFeatures) -> float:
|
||||||
|
|||||||
@@ -1,27 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Router interface for WiFi CSI data collection.
|
Router interface for WiFi CSI data collection
|
||||||
|
|
||||||
Supports multiple router types:
|
|
||||||
- OpenWRT routers with Atheros CSI Tool
|
|
||||||
- DD-WRT routers with custom CSI extraction
|
|
||||||
- Custom firmware routers with raw CSI access
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import struct
|
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Any, Tuple
|
from typing import Dict, List, Optional, Any
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
try:
|
|
||||||
import asyncssh
|
|
||||||
HAS_ASYNCSSH = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_ASYNCSSH = False
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,80 +72,28 @@ class RouterInterface:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""Connect to the router via SSH."""
|
"""Connect to the router."""
|
||||||
if self.mock_mode:
|
if self.mock_mode:
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
self.logger.info(f"Mock connection established to router {self.router_id}")
|
self.logger.info(f"Mock connection established to router {self.router_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not HAS_ASYNCSSH:
|
|
||||||
self.logger.warning("asyncssh not available, falling back to mock mode")
|
|
||||||
self.mock_mode = True
|
|
||||||
self._initialize_mock_generator()
|
|
||||||
self.is_connected = True
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
||||||
|
|
||||||
# Establish SSH connection
|
# In a real implementation, this would establish SSH connection
|
||||||
self.connection = await asyncssh.connect(
|
# For now, we'll simulate the connection
|
||||||
self.host,
|
await asyncio.sleep(0.1) # Simulate connection delay
|
||||||
port=self.port,
|
|
||||||
username=self.username,
|
|
||||||
password=self.password if self.password else None,
|
|
||||||
known_hosts=None, # Disable host key checking for embedded devices
|
|
||||||
connect_timeout=10
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify connection by checking router type
|
|
||||||
await self._detect_router_type()
|
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
self.error_count = 0
|
self.error_count = 0
|
||||||
self.logger.info(f"Connected to router {self.router_id}")
|
self.logger.info(f"Connected to router {self.router_id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.last_error = str(e)
|
self.last_error = str(e)
|
||||||
self.error_count += 1
|
self.error_count += 1
|
||||||
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _detect_router_type(self):
|
|
||||||
"""Detect router firmware type and CSI capabilities."""
|
|
||||||
if not self.connection:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check for OpenWRT
|
|
||||||
result = await self.connection.run('cat /etc/openwrt_release 2>/dev/null || echo ""', check=False)
|
|
||||||
if 'OpenWrt' in result.stdout:
|
|
||||||
self.router_type = 'openwrt'
|
|
||||||
self.logger.info(f"Detected OpenWRT router: {self.router_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for DD-WRT
|
|
||||||
result = await self.connection.run('nvram get DD_BOARD 2>/dev/null || echo ""', check=False)
|
|
||||||
if result.stdout.strip():
|
|
||||||
self.router_type = 'ddwrt'
|
|
||||||
self.logger.info(f"Detected DD-WRT router: {self.router_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for Atheros CSI Tool
|
|
||||||
result = await self.connection.run('which csi_tool 2>/dev/null || echo ""', check=False)
|
|
||||||
if result.stdout.strip():
|
|
||||||
self.csi_tool_path = result.stdout.strip()
|
|
||||||
self.router_type = 'atheros_csi'
|
|
||||||
self.logger.info(f"Detected Atheros CSI Tool on router: {self.router_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Default to generic Linux
|
|
||||||
self.router_type = 'generic'
|
|
||||||
self.logger.info(f"Generic Linux router: {self.router_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Could not detect router type: {e}")
|
|
||||||
self.router_type = 'unknown'
|
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Disconnect from the router."""
|
"""Disconnect from the router."""
|
||||||
@@ -259,244 +195,11 @@ class RouterInterface:
|
|||||||
return csi_data
|
return csi_data
|
||||||
|
|
||||||
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||||
"""Collect real CSI data from router via SSH.
|
"""Collect real CSI data from router (placeholder implementation)."""
|
||||||
|
# This would implement the actual CSI data collection
|
||||||
Supports multiple CSI extraction methods:
|
# For now, return None to indicate no real implementation
|
||||||
- Atheros CSI Tool (ath9k/ath10k)
|
self.logger.warning("Real CSI data collection not implemented")
|
||||||
- Custom kernel module reading
|
return None
|
||||||
- Proc filesystem access
|
|
||||||
- Raw device file reading
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Numpy array of complex CSI values or None on failure
|
|
||||||
"""
|
|
||||||
if not self.connection:
|
|
||||||
self.logger.error("No SSH connection available")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
router_type = getattr(self, 'router_type', 'unknown')
|
|
||||||
|
|
||||||
if router_type == 'atheros_csi':
|
|
||||||
return await self._collect_atheros_csi()
|
|
||||||
elif router_type == 'openwrt':
|
|
||||||
return await self._collect_openwrt_csi()
|
|
||||||
else:
|
|
||||||
return await self._collect_generic_csi()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error collecting CSI data: {e}")
|
|
||||||
self.error_count += 1
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _collect_atheros_csi(self) -> Optional[np.ndarray]:
|
|
||||||
"""Collect CSI using Atheros CSI Tool."""
|
|
||||||
csi_tool = getattr(self, 'csi_tool_path', '/usr/bin/csi_tool')
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Read single CSI sample
|
|
||||||
result = await self.connection.run(
|
|
||||||
f'{csi_tool} -i {self.interface} -c 1 -f /tmp/csi_sample.dat && '
|
|
||||||
f'cat /tmp/csi_sample.dat | base64',
|
|
||||||
check=True,
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decode base64 CSI data
|
|
||||||
import base64
|
|
||||||
csi_bytes = base64.b64decode(result.stdout.strip())
|
|
||||||
|
|
||||||
return self._parse_atheros_csi_bytes(csi_bytes)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Atheros CSI collection failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _collect_openwrt_csi(self) -> Optional[np.ndarray]:
|
|
||||||
"""Collect CSI from OpenWRT with CSI support."""
|
|
||||||
try:
|
|
||||||
# Try reading from debugfs (common CSI location)
|
|
||||||
result = await self.connection.run(
|
|
||||||
f'cat /sys/kernel/debug/ieee80211/phy0/ath9k/csi 2>/dev/null | head -c 4096 | base64',
|
|
||||||
check=False,
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode == 0 and result.stdout.strip():
|
|
||||||
import base64
|
|
||||||
csi_bytes = base64.b64decode(result.stdout.strip())
|
|
||||||
return self._parse_atheros_csi_bytes(csi_bytes)
|
|
||||||
|
|
||||||
# Try alternate location
|
|
||||||
result = await self.connection.run(
|
|
||||||
f'cat /proc/csi 2>/dev/null | head -c 4096 | base64',
|
|
||||||
check=False,
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode == 0 and result.stdout.strip():
|
|
||||||
import base64
|
|
||||||
csi_bytes = base64.b64decode(result.stdout.strip())
|
|
||||||
return self._parse_generic_csi_bytes(csi_bytes)
|
|
||||||
|
|
||||||
self.logger.warning("No CSI data available from OpenWRT paths")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"OpenWRT CSI collection failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _collect_generic_csi(self) -> Optional[np.ndarray]:
|
|
||||||
"""Collect CSI using generic Linux methods."""
|
|
||||||
try:
|
|
||||||
# Try iw command for station info (not real CSI but channel info)
|
|
||||||
result = await self.connection.run(
|
|
||||||
f'iw dev {self.interface} survey dump 2>/dev/null || echo ""',
|
|
||||||
check=False,
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.stdout.strip():
|
|
||||||
# Parse survey data for channel metrics
|
|
||||||
return self._parse_survey_data(result.stdout)
|
|
||||||
|
|
||||||
self.logger.warning("No CSI data available via generic methods")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Generic CSI collection failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_atheros_csi_bytes(self, data: bytes) -> Optional[np.ndarray]:
|
|
||||||
"""Parse Atheros CSI Tool binary format.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
- 4 bytes: magic (0x11111111)
|
|
||||||
- 8 bytes: timestamp
|
|
||||||
- 2 bytes: channel
|
|
||||||
- 1 byte: bandwidth
|
|
||||||
- 1 byte: num_rx_antennas
|
|
||||||
- 1 byte: num_tx_antennas
|
|
||||||
- 1 byte: num_tones
|
|
||||||
- 2 bytes: RSSI
|
|
||||||
- Remaining: CSI matrix as int16 I/Q pairs
|
|
||||||
"""
|
|
||||||
if len(data) < 20:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
magic = struct.unpack('<I', data[0:4])[0]
|
|
||||||
if magic != 0x11111111:
|
|
||||||
# Try different offset or format
|
|
||||||
return self._parse_generic_csi_bytes(data)
|
|
||||||
|
|
||||||
# Parse header
|
|
||||||
timestamp = struct.unpack('<Q', data[4:12])[0]
|
|
||||||
channel = struct.unpack('<H', data[12:14])[0]
|
|
||||||
bw = struct.unpack('<B', data[14:15])[0]
|
|
||||||
nr = struct.unpack('<B', data[15:16])[0]
|
|
||||||
nc = struct.unpack('<B', data[16:17])[0]
|
|
||||||
num_tones = struct.unpack('<B', data[17:18])[0]
|
|
||||||
|
|
||||||
if nr == 0 or num_tones == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Parse CSI matrix
|
|
||||||
csi_data = data[20:]
|
|
||||||
csi_matrix = np.zeros((nr, num_tones), dtype=complex)
|
|
||||||
|
|
||||||
for ant in range(nr):
|
|
||||||
for tone in range(num_tones):
|
|
||||||
offset = (ant * num_tones + tone) * 4
|
|
||||||
if offset + 4 <= len(csi_data):
|
|
||||||
real, imag = struct.unpack('<hh', csi_data[offset:offset+4])
|
|
||||||
csi_matrix[ant, tone] = complex(real, imag)
|
|
||||||
|
|
||||||
return csi_matrix
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error parsing Atheros CSI: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_generic_csi_bytes(self, data: bytes) -> Optional[np.ndarray]:
|
|
||||||
"""Parse generic binary CSI format."""
|
|
||||||
if len(data) < 8:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Assume simple format: int16 I/Q pairs
|
|
||||||
num_samples = len(data) // 4
|
|
||||||
if num_samples == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Default to 56 subcarriers (20MHz), adjust antennas
|
|
||||||
num_tones = min(56, num_samples)
|
|
||||||
num_antennas = max(1, num_samples // num_tones)
|
|
||||||
|
|
||||||
csi_matrix = np.zeros((num_antennas, num_tones), dtype=complex)
|
|
||||||
|
|
||||||
for i in range(min(num_samples, num_antennas * num_tones)):
|
|
||||||
offset = i * 4
|
|
||||||
if offset + 4 <= len(data):
|
|
||||||
real, imag = struct.unpack('<hh', data[offset:offset+4])
|
|
||||||
ant = i // num_tones
|
|
||||||
tone = i % num_tones
|
|
||||||
if ant < num_antennas and tone < num_tones:
|
|
||||||
csi_matrix[ant, tone] = complex(real, imag)
|
|
||||||
|
|
||||||
return csi_matrix
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error parsing generic CSI: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_survey_data(self, survey_output: str) -> Optional[np.ndarray]:
|
|
||||||
"""Parse iw survey dump output to extract channel metrics.
|
|
||||||
|
|
||||||
This isn't true CSI but provides per-channel noise and activity data
|
|
||||||
that can be used as a fallback.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
lines = survey_output.strip().split('\n')
|
|
||||||
noise_values = []
|
|
||||||
busy_values = []
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
if 'noise:' in line.lower():
|
|
||||||
parts = line.split()
|
|
||||||
for i, p in enumerate(parts):
|
|
||||||
if p == 'dBm' and i > 0:
|
|
||||||
try:
|
|
||||||
noise_values.append(float(parts[i-1]))
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
elif 'channel busy time:' in line.lower():
|
|
||||||
parts = line.split()
|
|
||||||
for i, p in enumerate(parts):
|
|
||||||
if p == 'ms' and i > 0:
|
|
||||||
try:
|
|
||||||
busy_values.append(float(parts[i-1]))
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if noise_values:
|
|
||||||
# Create pseudo-CSI from noise measurements
|
|
||||||
num_channels = len(noise_values)
|
|
||||||
csi_matrix = np.zeros((1, max(56, num_channels)), dtype=complex)
|
|
||||||
|
|
||||||
for i, noise in enumerate(noise_values):
|
|
||||||
# Convert noise dBm to amplitude (simplified)
|
|
||||||
amplitude = 10 ** (noise / 20)
|
|
||||||
phase = 0 if i >= len(busy_values) else busy_values[i] / 1000 * np.pi
|
|
||||||
csi_matrix[0, i] = amplitude * np.exp(1j * phase)
|
|
||||||
|
|
||||||
return csi_matrix
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error parsing survey data: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def check_health(self) -> bool:
|
async def check_health(self) -> bool:
|
||||||
"""Check if the router connection is healthy.
|
"""Check if the router connection is healthy.
|
||||||
|
|||||||
@@ -1,566 +0,0 @@
|
|||||||
"""Vital signs detection from CSI signals.
|
|
||||||
|
|
||||||
This module provides breathing and heartbeat detection capabilities
|
|
||||||
mirroring the Rust wifi-densepose-mat crate functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
import scipy.signal
|
|
||||||
import scipy.fft
|
|
||||||
|
|
||||||
|
|
||||||
class BreathingType(Enum):
    """Types of breathing patterns.

    Values are the serialized labels assigned by the breathing
    classifier (see BreathingDetector._classify_pattern for the
    rate/regularity cutoffs).
    """
    NORMAL = "normal"        # regular breathing in the typical resting range
    SHALLOW = "shallow"      # low-rate breathing
    DEEP = "deep"            # moderately elevated rate classified as deep
    RAPID = "rapid"          # high-rate breathing
    IRREGULAR = "irregular"  # low regularity score; rate is disregarded
    APNEA = "apnea"          # extremely low detected rate
|
|
||||||
|
|
||||||
class SignalStrength(Enum):
    """Signal strength classification.

    Buckets for the raw cardiac spectral-peak magnitude; thresholds are
    applied in HeartbeatDetector._classify_signal_strength.
    """
    STRONG = "strong"        # peak magnitude > 0.3
    MODERATE = "moderate"    # peak magnitude > 0.15
    WEAK = "weak"            # peak magnitude > 0.08
    VERY_WEAK = "very_weak"  # everything at or below the WEAK threshold
|
|
||||||
|
|
||||||
@dataclass
class BreathingPattern:
    """Detected breathing pattern.

    Produced by BreathingDetector.detect(); the numeric fields are derived
    from the dominant spectral peak in the breathing band.
    """
    rate_bpm: float              # breathing rate, breaths per minute
    amplitude: float             # magnitude of the dominant spectral peak
    regularity: float            # 0-1 energy concentration around the peak
    pattern_type: BreathingType  # qualitative classification of the pattern
    confidence: float            # 0-1 detection confidence
    timestamp: datetime          # UTC time of the detection
|
|
||||||
|
|
||||||
@dataclass
class HeartbeatSignature:
    """Detected heartbeat signature.

    Produced by HeartbeatDetector.detect() from micro-Doppler analysis of
    CSI phase.
    """
    rate_bpm: float                  # heart rate, beats per minute
    signal_strength: SignalStrength  # bucketed strength of the cardiac peak
    hrv_estimate: Optional[float]    # HRV estimate; currently always None (placeholder)
    confidence: float                # 0-1 detection confidence
    timestamp: datetime              # UTC time of the detection
|
|
||||||
|
|
||||||
@dataclass
class VitalSignsReading:
    """Combined vital signs reading from one CSI window.

    Produced by VitalSignsDetector.detect(); either detection may be None
    when its confidence fell below the configured threshold.
    """
    breathing: Optional[BreathingPattern]    # breathing detection, if any
    heartbeat: Optional[HeartbeatSignature]  # heartbeat detection, if any
    motion_detected: bool                    # True when amplitude variance exceeds the motion threshold
    overall_confidence: float                # 0-1 blended confidence across detected signs
    timestamp: datetime                      # UTC time of the reading
|
|
||||||
|
|
||||||
@dataclass
class BreathingDetectorConfig:
    """Configuration for breathing detection."""
    min_rate_bpm: float = 4.0          # lower search bound (very slow breathing)
    max_rate_bpm: float = 40.0         # upper search bound (fast, distressed breathing)
    min_amplitude: float = 0.1         # minimum spectral peak magnitude to accept
    window_size: int = 512             # minimum number of samples required to run detection
    window_overlap: float = 0.5        # analysis window overlap fraction
    confidence_threshold: float = 0.3  # detections below this confidence are discarded
|
|
||||||
|
|
||||||
@dataclass
class HeartbeatDetectorConfig:
    """Configuration for heartbeat detection."""
    min_rate_bpm: float = 30.0         # lower search bound (bradycardia)
    max_rate_bpm: float = 200.0        # upper search bound (extreme tachycardia)
    min_signal_strength: float = 0.05  # minimum cardiac peak magnitude to accept
    window_size: int = 1024            # STFT segment length / minimum sample count
    enhanced_processing: bool = True   # True: averaged STFT; False: single windowed FFT
    confidence_threshold: float = 0.4  # detections below this confidence are discarded
|
|
||||||
|
|
||||||
class BreathingDetector:
    """Detects breathing patterns in CSI amplitude streams.

    Chest motion during breathing periodically modulates the WiFi channel,
    showing up as energy in the 0.1-0.67 Hz band (6-40 breaths per minute).
    The detector locates the dominant spectral peak in that band and derives
    rate, regularity, and a confidence score from it.
    """

    def __init__(self, config: Optional[BreathingDetectorConfig] = None):
        """Create a breathing detector.

        Args:
            config: Detector configuration; defaults are used when None.
        """
        self.config = config if config is not None else BreathingDetectorConfig()

    def detect(self, csi_amplitudes: np.ndarray, sample_rate: float) -> Optional[BreathingPattern]:
        """Search the amplitude stream for a breathing pattern.

        Args:
            csi_amplitudes: Array of CSI amplitude values.
            sample_rate: Sampling rate in Hz.

        Returns:
            A BreathingPattern when a sufficiently confident periodicity is
            found in the breathing band, otherwise None.
        """
        if len(csi_amplitudes) < self.config.window_size:
            return None

        spectrum = self._compute_spectrum(csi_amplitudes)

        # Restrict the search to the configured breathing band (Hz).
        low_hz = self.config.min_rate_bpm / 60.0
        high_hz = self.config.max_rate_bpm / 60.0
        peak = self._find_dominant_frequency(spectrum, sample_rate, low_hz, high_hz)
        if peak is None:
            return None

        peak_freq, peak_amplitude = peak
        breaths_per_min = peak_freq * 60.0

        if peak_amplitude < self.config.min_amplitude:
            return None

        regularity = self._calculate_regularity(spectrum, peak_freq, sample_rate)
        confidence = self._calculate_confidence(peak_amplitude, regularity)
        if confidence < self.config.confidence_threshold:
            return None

        return BreathingPattern(
            rate_bpm=breaths_per_min,
            amplitude=peak_amplitude,
            regularity=regularity,
            pattern_type=self._classify_pattern(breaths_per_min, regularity),
            confidence=confidence,
            timestamp=datetime.now(timezone.utc)
        )

    def _compute_spectrum(self, signal: np.ndarray) -> np.ndarray:
        """Return the magnitude spectrum of a Hamming-windowed signal."""
        tapered = signal * scipy.signal.windows.hamming(len(signal))
        return np.abs(scipy.fft.rfft(tapered))

    def _find_dominant_frequency(
        self,
        spectrum: np.ndarray,
        sample_rate: float,
        min_freq: float,
        max_freq: float
    ) -> Optional[Tuple[float, float]]:
        """Locate the strongest spectral component inside [min_freq, max_freq].

        Returns:
            (frequency_hz, magnitude) of the peak, or None when the band is
            empty or contains no energy.
        """
        # rfft of an n-sample signal has n//2 + 1 bins, so n = (bins - 1) * 2.
        fft_len = (len(spectrum) - 1) * 2
        bin_freqs = scipy.fft.rfftfreq(fft_len, 1.0 / sample_rate)

        # Guard against any length mismatch between axes.
        usable = min(len(bin_freqs), len(spectrum))
        bin_freqs = bin_freqs[:usable]
        magnitudes = spectrum[:usable]

        in_band = (bin_freqs >= min_freq) & (bin_freqs <= max_freq)
        if not np.any(in_band):
            return None

        banded = np.where(in_band, magnitudes, 0)
        best = int(np.argmax(banded))
        if banded[best] == 0:
            return None

        return bin_freqs[best], magnitudes[best]

    def _calculate_regularity(
        self,
        spectrum: np.ndarray,
        dominant_freq: float,
        sample_rate: float
    ) -> float:
        """Score regularity as energy concentration around the dominant peak."""
        fft_len = (len(spectrum) - 1) * 2
        bin_freqs = scipy.fft.rfftfreq(fft_len, 1.0 / sample_rate)

        resolution = bin_freqs[1] - bin_freqs[0] if len(bin_freqs) > 1 else 1.0
        center = int(dominant_freq / resolution) if resolution > 0 else 0

        # Energy within +/-3 bins of the peak relative to total energy.
        lo = max(0, center - 3)
        hi = min(len(spectrum), center + 3 + 1)
        narrow_energy = np.sum(spectrum[lo:hi] ** 2)
        total_energy = np.sum(spectrum ** 2) + 1e-10

        return min(1.0, float(narrow_energy / total_energy) * 2.0)  # scale to 0-1

    def _classify_pattern(self, rate_bpm: float, regularity: float) -> BreathingType:
        """Map rate and regularity to a BreathingType label."""
        if regularity < 0.3:
            return BreathingType.IRREGULAR
        if rate_bpm < 6:
            return BreathingType.APNEA
        if rate_bpm < 12:
            return BreathingType.SHALLOW
        if rate_bpm <= 20:
            return BreathingType.NORMAL
        if rate_bpm <= 25:
            return BreathingType.DEEP
        return BreathingType.RAPID

    def _calculate_confidence(self, amplitude: float, regularity: float) -> float:
        """Blend amplitude (60%) and regularity (40%) into a 0-1 confidence."""
        scaled_amp = min(1.0, amplitude / 0.5)
        return float(np.clip(0.6 * scaled_amp + 0.4 * regularity, 0.0, 1.0))
|
|
||||||
|
|
||||||
class HeartbeatDetector:
    """Detects heartbeat signatures via micro-Doppler analysis of CSI phase.

    Cardiac activity moves the chest wall only very slightly (~0.5mm), which
    leaves a faint periodic imprint on CSI phase at 0.8-3.3 Hz (48-200 BPM)
    -- above the breathing band, so respiration energy is filtered out first.
    """

    def __init__(self, config: Optional[HeartbeatDetectorConfig] = None):
        """Create a heartbeat detector.

        Args:
            config: Detector configuration; defaults are used when None.
        """
        self.config = config if config is not None else HeartbeatDetectorConfig()

    def detect(
        self,
        csi_phase: np.ndarray,
        sample_rate: float,
        breathing_rate: Optional[float] = None
    ) -> Optional[HeartbeatSignature]:
        """Search the phase stream for a heartbeat.

        Args:
            csi_phase: Array of CSI phase values in radians.
            sample_rate: Sampling rate in Hz.
            breathing_rate: Known breathing frequency in Hz, used to notch
                respiration out before the search (optional).

        Returns:
            A HeartbeatSignature when a sufficiently strong cardiac peak is
            found, otherwise None.
        """
        if len(csi_phase) < self.config.window_size:
            return None

        # Suppress respiration: notch it out when its rate is known,
        # otherwise remove everything below the cardiac band.
        if breathing_rate is None:
            cleaned = self._highpass_filter(csi_phase, sample_rate, 0.8)
        else:
            cleaned = self._remove_breathing_component(csi_phase, sample_rate, breathing_rate)

        spectrum = self._compute_micro_doppler_spectrum(cleaned, sample_rate)

        low_hz = self.config.min_rate_bpm / 60.0
        high_hz = self.config.max_rate_bpm / 60.0
        peak = self._find_heartbeat_frequency(spectrum, sample_rate, low_hz, high_hz)
        if peak is None:
            return None

        cardiac_freq, peak_strength = peak
        if peak_strength < self.config.min_signal_strength:
            return None

        strength_class = self._classify_signal_strength(peak_strength)
        confidence = self._calculate_confidence(peak_strength, strength_class)
        if confidence < self.config.confidence_threshold:
            return None

        return HeartbeatSignature(
            rate_bpm=cardiac_freq * 60.0,
            signal_strength=strength_class,
            hrv_estimate=self._estimate_hrv(csi_phase, sample_rate, cardiac_freq),
            confidence=confidence,
            timestamp=datetime.now(timezone.utc)
        )

    def _remove_breathing_component(
        self,
        phase: np.ndarray,
        sample_rate: float,
        breathing_rate: float
    ) -> np.ndarray:
        """Notch out the breathing fundamental and its 2nd/3rd harmonics."""
        q_factor = 30.0
        num, den = scipy.signal.iirnotch(breathing_rate, q_factor, sample_rate)
        cleaned = scipy.signal.filtfilt(num, den, phase)

        # Respiration is non-sinusoidal, so also notch the 2nd and 3rd
        # harmonics when they fall below Nyquist.
        for multiple in (2, 3):
            harmonic_hz = breathing_rate * multiple
            if harmonic_hz < sample_rate / 2:
                num, den = scipy.signal.iirnotch(harmonic_hz, q_factor, sample_rate)
                cleaned = scipy.signal.filtfilt(num, den, cleaned)

        return cleaned

    def _highpass_filter(
        self,
        signal: np.ndarray,
        sample_rate: float,
        cutoff: float
    ) -> np.ndarray:
        """4th-order Butterworth highpass; no-op when cutoff >= Nyquist."""
        nyquist = sample_rate / 2
        if cutoff >= nyquist:
            return signal
        num, den = scipy.signal.butter(4, cutoff / nyquist, btype='high')
        return scipy.signal.filtfilt(num, den, signal)

    def _compute_micro_doppler_spectrum(
        self,
        signal: np.ndarray,
        sample_rate: float
    ) -> np.ndarray:
        """Magnitude spectrum tuned for weak cardiac components.

        Enhanced mode averages STFT frames over time for better frequency
        resolution on weak periodic components; otherwise a single
        Hamming-windowed FFT is used.
        """
        frame_len = min(len(signal), self.config.window_size)

        if not self.config.enhanced_processing:
            taper = scipy.signal.windows.hamming(frame_len)
            return np.abs(scipy.fft.rfft(signal[:frame_len] * taper))

        _, _, stft_frames = scipy.signal.stft(
            signal,
            sample_rate,
            nperseg=frame_len,
            noverlap=frame_len // 2
        )
        # Average magnitude over time frames.
        return np.mean(np.abs(stft_frames), axis=1)

    def _find_heartbeat_frequency(
        self,
        spectrum: np.ndarray,
        sample_rate: float,
        min_freq: float,
        max_freq: float
    ) -> Optional[Tuple[float, float]]:
        """Locate the strongest component within the cardiac band.

        Returns:
            (frequency_hz, magnitude) of the peak, or None when the band is
            empty or carries no energy.
        """
        # rfft of an n-sample signal yields n//2 + 1 bins, so n = (bins - 1) * 2.
        n_time = (len(spectrum) - 1) * 2
        freq_axis = scipy.fft.rfftfreq(n_time, 1.0 / sample_rate)

        # Guard against any length mismatch between axes.
        limit = min(len(freq_axis), len(spectrum))
        freq_axis = freq_axis[:limit]
        mags = spectrum[:limit]

        selected = (freq_axis >= min_freq) & (freq_axis <= max_freq)
        if not np.any(selected):
            return None

        candidates = np.where(selected, mags, 0)
        winner = int(np.argmax(candidates))
        if candidates[winner] == 0:
            return None

        return freq_axis[winner], mags[winner]

    def _classify_signal_strength(self, strength: float) -> SignalStrength:
        """Bucket a raw peak magnitude into a SignalStrength level."""
        for threshold, label in (
            (0.3, SignalStrength.STRONG),
            (0.15, SignalStrength.MODERATE),
            (0.08, SignalStrength.WEAK),
        ):
            if strength > threshold:
                return label
        return SignalStrength.VERY_WEAK

    def _estimate_hrv(
        self,
        phase: np.ndarray,
        sample_rate: float,
        heart_freq: float
    ) -> Optional[float]:
        """Estimate heart rate variability (placeholder).

        A full implementation would detect individual beats and analyze
        RR intervals; this version always returns None.
        """
        if len(phase) < self.config.window_size * 2:
            return None
        # Placeholder - would require peak detection / RR-interval analysis.
        return None

    def _calculate_confidence(
        self,
        strength: float,
        signal_class: SignalStrength
    ) -> float:
        """Blend raw strength (50%) and its class weight (50%) into 0-1."""
        raw_factor = min(1.0, strength / 0.2)
        weight_by_class = {
            SignalStrength.STRONG: 1.0,
            SignalStrength.MODERATE: 0.7,
            SignalStrength.WEAK: 0.4,
            SignalStrength.VERY_WEAK: 0.2,
        }
        blended = 0.5 * raw_factor + 0.5 * weight_by_class[signal_class]
        return float(np.clip(blended, 0.0, 1.0))
|
|
||||||
|
|
||||||
class VitalSignsDetector:
    """Runs breathing and heartbeat detection together on one CSI window."""

    def __init__(
        self,
        breathing_config: Optional[BreathingDetectorConfig] = None,
        heartbeat_config: Optional[HeartbeatDetectorConfig] = None
    ):
        """Create the combined detector.

        Args:
            breathing_config: Breathing detector configuration.
            heartbeat_config: Heartbeat detector configuration.
        """
        self.breathing_detector = BreathingDetector(breathing_config)
        self.heartbeat_detector = HeartbeatDetector(heartbeat_config)
        self._motion_threshold = 0.5

    def detect(
        self,
        csi_amplitude: np.ndarray,
        csi_phase: np.ndarray,
        sample_rate: float
    ) -> VitalSignsReading:
        """Produce a combined vital-signs reading from one CSI window.

        Args:
            csi_amplitude: CSI amplitude values.
            csi_phase: CSI phase values in radians.
            sample_rate: Sampling rate in Hz.

        Returns:
            A VitalSignsReading with whichever signs were detected.
        """
        breathing = self.breathing_detector.detect(csi_amplitude, sample_rate)

        # Feed the breathing frequency (Hz) into the heartbeat stage so it
        # can notch respiration out of the phase signal.
        resp_hz = None if breathing is None else breathing.rate_bpm / 60.0
        heartbeat = self.heartbeat_detector.detect(csi_phase, sample_rate, resp_hz)

        moving = self._detect_motion(csi_amplitude)

        return VitalSignsReading(
            breathing=breathing,
            heartbeat=heartbeat,
            motion_detected=moving,
            overall_confidence=self._calculate_overall_confidence(breathing, heartbeat, moving),
            timestamp=datetime.now(timezone.utc)
        )

    def _detect_motion(self, amplitude: np.ndarray) -> bool:
        """Flag significant motion when amplitude variance exceeds the threshold."""
        if len(amplitude) < 10:
            return False
        return np.var(amplitude) > self._motion_threshold

    def _calculate_overall_confidence(
        self,
        breathing: Optional[BreathingPattern],
        heartbeat: Optional[HeartbeatSignature],
        motion_detected: bool
    ) -> float:
        """Average per-sign confidences, discounting strong-motion windows."""
        scores = [sign.confidence for sign in (breathing, heartbeat) if sign]
        if not scores:
            return 0.0

        combined = np.mean(scores)

        # Gross body motion adds noise that can masquerade as vital signs,
        # so temper very confident readings taken while moving.
        if motion_detected and combined > 0.7:
            combined *= 0.9

        return float(np.clip(combined, 0.0, 1.0))
@@ -1,10 +1,9 @@
|
|||||||
"""CSI data extraction from WiFi hardware using Test-Driven Development approach."""
|
"""CSI data extraction from WiFi hardware using Test-Driven Development approach."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import struct
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Dict, Any, Optional, Callable, Protocol, List, Tuple
|
from typing import Dict, Any, Optional, Callable, Protocol
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import logging
|
import logging
|
||||||
@@ -36,601 +35,128 @@ class CSIData:
|
|||||||
|
|
||||||
class CSIParser(Protocol):
|
class CSIParser(Protocol):
|
||||||
"""Protocol for CSI data parsers."""
|
"""Protocol for CSI data parsers."""
|
||||||
|
|
||||||
def parse(self, raw_data: bytes) -> CSIData:
|
def parse(self, raw_data: bytes) -> CSIData:
|
||||||
"""Parse raw CSI data into structured format."""
|
"""Parse raw CSI data into structured format."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class ESP32CSIParser:
|
class ESP32CSIParser:
|
||||||
"""Parser for ESP32 CSI data format.
|
"""Parser for ESP32 CSI data format."""
|
||||||
|
|
||||||
ESP32 CSI data format (from esp-csi library):
|
|
||||||
- Header: 'CSI_DATA:' prefix
|
|
||||||
- Fields: timestamp,rssi,rate,sig_mode,mcs,bandwidth,smoothing,
|
|
||||||
not_sounding,aggregation,stbc,fec_coding,sgi,noise_floor,
|
|
||||||
ampdu_cnt,channel,secondary_channel,local_timestamp,
|
|
||||||
ant,sig_len,rx_state,len,first_word,data[...]
|
|
||||||
|
|
||||||
The actual CSI data is in the 'data' field as complex I/Q values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize ESP32 CSI parser with default configuration."""
|
|
||||||
self.htltf_subcarriers = 56 # HT-LTF subcarriers for 20MHz
|
|
||||||
self.antenna_count = 1 # Most ESP32 have 1 antenna
|
|
||||||
|
|
||||||
def parse(self, raw_data: bytes) -> CSIData:
|
def parse(self, raw_data: bytes) -> CSIData:
|
||||||
"""Parse ESP32 CSI data format.
|
"""Parse ESP32 CSI data format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_data: Raw bytes from ESP32 serial/network
|
raw_data: Raw bytes from ESP32
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Parsed CSI data
|
Parsed CSI data
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
CSIParseError: If data format is invalid
|
CSIParseError: If data format is invalid
|
||||||
"""
|
"""
|
||||||
if not raw_data:
|
if not raw_data:
|
||||||
raise CSIParseError("Empty data received")
|
raise CSIParseError("Empty data received")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data_str = raw_data.decode('utf-8').strip()
|
data_str = raw_data.decode('utf-8')
|
||||||
|
if not data_str.startswith('CSI_DATA:'):
|
||||||
# Handle ESP-CSI library format
|
|
||||||
if data_str.startswith('CSI_DATA,'):
|
|
||||||
return self._parse_esp_csi_format(data_str)
|
|
||||||
# Handle simplified format for testing
|
|
||||||
elif data_str.startswith('CSI_DATA:'):
|
|
||||||
return self._parse_simple_format(data_str)
|
|
||||||
else:
|
|
||||||
raise CSIParseError("Invalid ESP32 CSI data format")
|
raise CSIParseError("Invalid ESP32 CSI data format")
|
||||||
|
|
||||||
except UnicodeDecodeError:
|
# Parse ESP32 format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp],[phase]
|
||||||
# Binary format - parse as raw bytes
|
parts = data_str[9:].split(',') # Remove 'CSI_DATA:' prefix
|
||||||
return self._parse_binary_format(raw_data)
|
|
||||||
|
timestamp_ms = int(parts[0])
|
||||||
|
num_antennas = int(parts[1])
|
||||||
|
num_subcarriers = int(parts[2])
|
||||||
|
frequency_mhz = float(parts[3])
|
||||||
|
bandwidth_mhz = float(parts[4])
|
||||||
|
snr = float(parts[5])
|
||||||
|
|
||||||
|
# Convert to proper units
|
||||||
|
frequency = frequency_mhz * 1e6 # MHz to Hz
|
||||||
|
bandwidth = bandwidth_mhz * 1e6 # MHz to Hz
|
||||||
|
|
||||||
|
# Parse amplitude and phase arrays (simplified for now)
|
||||||
|
# In real implementation, this would parse actual CSI matrix data
|
||||||
|
amplitude = np.random.rand(num_antennas, num_subcarriers)
|
||||||
|
phase = np.random.rand(num_antennas, num_subcarriers)
|
||||||
|
|
||||||
|
return CSIData(
|
||||||
|
timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc),
|
||||||
|
amplitude=amplitude,
|
||||||
|
phase=phase,
|
||||||
|
frequency=frequency,
|
||||||
|
bandwidth=bandwidth,
|
||||||
|
num_subcarriers=num_subcarriers,
|
||||||
|
num_antennas=num_antennas,
|
||||||
|
snr=snr,
|
||||||
|
metadata={'source': 'esp32', 'raw_length': len(raw_data)}
|
||||||
|
)
|
||||||
|
|
||||||
except (ValueError, IndexError) as e:
|
except (ValueError, IndexError) as e:
|
||||||
raise CSIParseError(f"Failed to parse ESP32 data: {e}")
|
raise CSIParseError(f"Failed to parse ESP32 data: {e}")
|
||||||
|
|
||||||
def _parse_esp_csi_format(self, data_str: str) -> CSIData:
    """Parse ESP-CSI library CSV format.

    Format: CSI_DATA,<mac>,<rssi>,<rate>,<sig_mode>,<mcs>,<bw>,<smoothing>,
            <not_sounding>,<aggregation>,<stbc>,<fec>,<sgi>,<noise>,
            <ampdu_cnt>,<channel>,<sec_chan>,<timestamp>,<ant>,<sig_len>,
            <rx_state>,<len>,[csi_data...]

    Args:
        data_str: Decoded CSV line beginning with 'CSI_DATA,'.

    Returns:
        Parsed CSI data.

    Raises:
        CSIParseError: If fewer than the 22 required header fields are present.
    """
    parts = data_str.split(',')

    if len(parts) < 22:
        raise CSIParseError(f"Incomplete ESP-CSI data: expected >= 22 fields, got {len(parts)}")

    # Extract metadata (parts[3] is the PHY rate, unused here)
    mac_addr = parts[1]
    rssi = int(parts[2])
    sig_mode = int(parts[4])
    mcs = int(parts[5])
    bandwidth = int(parts[6])  # 0=20MHz, 1=40MHz
    channel = int(parts[15])
    timestamp_us = int(parts[17])
    csi_len = int(parts[21])

    # Parse CSI I/Q data (remaining fields are the CSI values)
    csi_raw = [int(x) for x in parts[22:22 + csi_len]]

    # Convert I/Q pairs to complex numbers.
    # ESP32 CSI format: [I0, Q0, I1, Q1, ...] as signed 8-bit integers
    amplitude, phase = self._iq_to_amplitude_phase(csi_raw)

    # Determine center frequency from the Wi-Fi channel number
    if channel <= 14:
        frequency = 2.412e9 + (channel - 1) * 5e6  # 2.4 GHz band
    else:
        frequency = 5.0e9 + (channel - 36) * 5e6  # 5 GHz band

    bw_hz = 20e6 if bandwidth == 0 else 40e6
    num_subcarriers = len(amplitude) // self.antenna_count

    return CSIData(
        timestamp=datetime.fromtimestamp(timestamp_us / 1e6, tz=timezone.utc),
        amplitude=amplitude.reshape(self.antenna_count, -1),
        phase=phase.reshape(self.antenna_count, -1),
        frequency=frequency,
        bandwidth=bw_hz,
        num_subcarriers=num_subcarriers,
        num_antennas=self.antenna_count,
        snr=float(rssi + 100),  # Approximate SNR from RSSI
        metadata={
            'source': 'esp32',
            'mac': mac_addr,
            'rssi': rssi,
            'mcs': mcs,
            'channel': channel,
            'sig_mode': sig_mode,
        }
    )
def _parse_simple_format(self, data_str: str) -> CSIData:
    """Parse simplified CSI format for testing/development.

    Format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp_values],[phase_values]

    Args:
        data_str: Decoded line beginning with 'CSI_DATA:'.

    Returns:
        Parsed CSI data; amplitude/phase are linearly interpolated if the
        array sizes do not match antennas * subcarriers.

    Raises:
        CSIParseError: If no bracketed array data is present.
    """
    content = data_str[9:]  # Remove 'CSI_DATA:' prefix

    # Split the scalar header fields from the bracketed array data
    if '[' not in content:
        # No array data provided — cannot build a CSI matrix
        raise CSIParseError("No CSI array data in simple format")

    main_part, arrays_part = content.split('[', 1)
    parts = main_part.rstrip(',').split(',')

    # Parse amplitude and phase arrays
    arrays_str = '[' + arrays_part
    amp_str, phase_str = self._split_arrays(arrays_str)
    amplitude = np.array([float(x) for x in amp_str.strip('[]').split(',')])
    phase = np.array([float(x) for x in phase_str.strip('[]').split(',')])

    timestamp_ms = int(parts[0])
    num_antennas = int(parts[1])
    num_subcarriers = int(parts[2])
    frequency_mhz = float(parts[3])
    bandwidth_mhz = float(parts[4])
    snr = float(parts[5])

    # Reshape arrays: if the provided sample count differs from the declared
    # antenna/subcarrier geometry, resample by linear interpolation.
    expected_size = num_antennas * num_subcarriers
    if len(amplitude) != expected_size:
        amplitude = np.interp(
            np.linspace(0, 1, expected_size),
            np.linspace(0, 1, len(amplitude)),
            amplitude
        )
        phase = np.interp(
            np.linspace(0, 1, expected_size),
            np.linspace(0, 1, len(phase)),
            phase
        )

    return CSIData(
        timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc),
        amplitude=amplitude.reshape(num_antennas, num_subcarriers),
        phase=phase.reshape(num_antennas, num_subcarriers),
        frequency=frequency_mhz * 1e6,
        bandwidth=bandwidth_mhz * 1e6,
        num_subcarriers=num_subcarriers,
        num_antennas=num_antennas,
        snr=snr,
        metadata={'source': 'esp32', 'format': 'simple'}
    )
def _parse_binary_format(self, raw_data: bytes) -> CSIData:
    """Parse binary CSI format from ESP32.

    Binary format (struct packed, little-endian):
    - 4 bytes: timestamp in ms (uint32)
    - 1 byte: num_antennas (uint8)
    - 1 byte: num_subcarriers (uint8)
    - 2 bytes: channel (uint16)
    - 4 bytes: frequency (float32)
    - 4 bytes: bandwidth (float32)
    - 4 bytes: snr (float32)
    - Remaining: CSI I/Q data as int8 pairs

    Args:
        raw_data: Raw binary frame from the device.

    Returns:
        Parsed CSI data (payload padded/truncated to the declared geometry).

    Raises:
        CSIParseError: If the frame is shorter than the fixed header.
    """
    if len(raw_data) < 20:
        raise CSIParseError("Binary data too short")

    header_fmt = '<IBBHfff'
    header_size = struct.calcsize(header_fmt)

    timestamp, num_antennas, num_subcarriers, channel, freq, bw, snr = \
        struct.unpack(header_fmt, raw_data[:header_size])

    # Parse I/Q payload (signed 8-bit samples)
    iq_data = raw_data[header_size:]
    csi_raw = list(struct.unpack(f'{len(iq_data)}b', iq_data))

    amplitude, phase = self._iq_to_amplitude_phase(csi_raw)

    # Adjust dimensions: pad short payloads with zeros, truncate long ones
    expected_size = num_antennas * num_subcarriers
    if len(amplitude) < expected_size:
        amplitude = np.pad(amplitude, (0, expected_size - len(amplitude)))
        phase = np.pad(phase, (0, expected_size - len(phase)))
    elif len(amplitude) > expected_size:
        amplitude = amplitude[:expected_size]
        phase = phase[:expected_size]

    return CSIData(
        timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc),
        amplitude=amplitude.reshape(num_antennas, num_subcarriers),
        phase=phase.reshape(num_antennas, num_subcarriers),
        frequency=float(freq),
        bandwidth=float(bw),
        num_subcarriers=num_subcarriers,
        num_antennas=num_antennas,
        snr=float(snr),
        metadata={'source': 'esp32', 'format': 'binary', 'channel': channel}
    )
def _iq_to_amplitude_phase(self, iq_data: List[int]) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""Convert I/Q pairs to amplitude and phase.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
iq_data: List of interleaved I, Q values (signed 8-bit)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (amplitude, phase) arrays
|
|
||||||
"""
|
|
||||||
if len(iq_data) % 2 != 0:
|
|
||||||
iq_data = iq_data[:-1] # Trim odd value
|
|
||||||
|
|
||||||
i_vals = np.array(iq_data[0::2], dtype=np.float64)
|
|
||||||
q_vals = np.array(iq_data[1::2], dtype=np.float64)
|
|
||||||
|
|
||||||
# Calculate amplitude (magnitude) and phase
|
|
||||||
complex_vals = i_vals + 1j * q_vals
|
|
||||||
amplitude = np.abs(complex_vals)
|
|
||||||
phase = np.angle(complex_vals)
|
|
||||||
|
|
||||||
# Normalize amplitude to [0, 1] range
|
|
||||||
max_amp = np.max(amplitude)
|
|
||||||
if max_amp > 0:
|
|
||||||
amplitude = amplitude / max_amp
|
|
||||||
|
|
||||||
return amplitude, phase
|
|
||||||
|
|
||||||
def _split_arrays(self, arrays_str: str) -> Tuple[str, str]:
|
|
||||||
"""Split concatenated array strings."""
|
|
||||||
# Find the boundary between two arrays
|
|
||||||
depth = 0
|
|
||||||
split_idx = 0
|
|
||||||
for i, c in enumerate(arrays_str):
|
|
||||||
if c == '[':
|
|
||||||
depth += 1
|
|
||||||
elif c == ']':
|
|
||||||
depth -= 1
|
|
||||||
if depth == 0:
|
|
||||||
split_idx = i + 1
|
|
||||||
break
|
|
||||||
|
|
||||||
amp_str = arrays_str[:split_idx]
|
|
||||||
phase_str = arrays_str[split_idx:].lstrip(',')
|
|
||||||
return amp_str, phase_str
|
|
||||||
|
|
||||||
|
|
||||||
class RouterCSIParser:
    """Parser for router CSI data formats (Atheros, Intel, etc.).

    Supports:
    - Atheros CSI Tool format (ath9k/ath10k)
    - Intel 5300 CSI Tool format
    - Nexmon CSI format (Broadcom)
    """

    def __init__(self):
        """Initialize router CSI parser with 20MHz-HT defaults."""
        self.default_subcarriers = 56  # 20MHz HT
        self.default_antennas = 3

    def parse(self, raw_data: bytes) -> CSIData:
        """Parse router CSI data format.

        Tries UTF-8 text formats first (ATHEROS_CSI:/INTEL_CSI: prefixes),
        then falls back to binary detection via the 4-byte header magic.

        Args:
            raw_data: Raw bytes from router

        Returns:
            Parsed CSI data

        Raises:
            CSIParseError: If data is empty or no known format matches
        """
        if not raw_data:
            raise CSIParseError("Empty data received")

        # Try to decode as text first
        try:
            data_str = raw_data.decode('utf-8')
            if data_str.startswith('ATHEROS_CSI:'):
                return self._parse_atheros_text_format(data_str)
            elif data_str.startswith('INTEL_CSI:'):
                return self._parse_intel_text_format(data_str)
        except UnicodeDecodeError:
            pass

        # Binary format detection based on header magic
        if len(raw_data) >= 4:
            magic = struct.unpack('<I', raw_data[:4])[0]
            if magic == 0x11111111:  # Atheros CSI Tool magic
                return self._parse_atheros_binary_format(raw_data)
            elif magic == 0xBB:  # Intel 5300 magic byte pattern
                return self._parse_intel_binary_format(raw_data)

        raise CSIParseError("Unknown router CSI format")

    def _parse_atheros_text_format(self, data_str: str) -> CSIData:
        """Parse Atheros CSI text format.

        Format: ATHEROS_CSI:timestamp,rssi,rate,channel,bw,nr,nc,num_tones,[csi_data...]

        Raises:
            CSIParseError: If fewer than 8 header fields are present.
        """
        content = data_str[12:]  # Remove 'ATHEROS_CSI:' prefix
        parts = content.split(',')

        if len(parts) < 8:
            raise CSIParseError("Incomplete Atheros CSI data")

        timestamp = int(parts[0])
        rssi = int(parts[1])
        rate = int(parts[2])
        channel = int(parts[3])
        bandwidth = int(parts[4])  # MHz
        nr = int(parts[5])  # Rx antennas
        nc = int(parts[6])  # Tx antennas (usually 1 for probe)
        num_tones = int(parts[7])  # Subcarriers

        # Parse CSI matrix data: [real, imag, real, imag, ...]
        csi_values = [float(x) for x in parts[8:] if x.strip()]
        amplitude, phase = self._parse_complex_csi(csi_values, nr, num_tones)

        # Calculate center frequency from channel
        if channel <= 14:
            frequency = 2.412e9 + (channel - 1) * 5e6
        else:
            frequency = 5.18e9 + (channel - 36) * 5e6

        return CSIData(
            timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc),
            amplitude=amplitude,
            phase=phase,
            frequency=frequency,
            bandwidth=bandwidth * 1e6,
            num_subcarriers=num_tones,
            num_antennas=nr,
            snr=float(rssi + 95),  # Approximate SNR from RSSI
            metadata={
                'source': 'atheros_router',
                'rssi': rssi,
                'rate': rate,
                'channel': channel,
                'tx_antennas': nc,
            }
        )

    def _parse_atheros_binary_format(self, raw_data: bytes) -> CSIData:
        """Parse Atheros CSI Tool binary format.

        Based on ath9k/ath10k CSI Tool structure:
        - 4 bytes: magic (0x11111111)
        - 8 bytes: timestamp (ns)
        - 2 bytes: channel
        - 1 byte: bandwidth (0=20MHz, 1=40MHz, 2=80MHz)
        - 1 byte: nr (rx antennas)
        - 1 byte: nc (tx antennas)
        - 1 byte: num_tones
        - 2 bytes: rssi
        - Remaining: CSI payload (complex int16 per subcarrier per antenna pair)

        Raises:
            CSIParseError: If the frame is too short or the magic is wrong.
        """
        if len(raw_data) < 20:
            raise CSIParseError("Atheros binary data too short")

        header_fmt = '<IQHBBBBB'  # Q is the 8-byte timestamp
        header_size = struct.calcsize(header_fmt)

        magic, timestamp, channel, bw, nr, nc, num_tones, rssi = \
            struct.unpack(header_fmt, raw_data[:header_size])

        if magic != 0x11111111:
            raise CSIParseError("Invalid Atheros magic number")

        # Parse CSI payload
        csi_data = raw_data[header_size:]

        # Each subcarrier has a complex value per antenna pair: int16 real + int16 imag
        expected_bytes = nr * nc * num_tones * 4
        if len(csi_data) < expected_bytes:
            # Adjust num_tones down to what the payload actually contains
            num_tones = len(csi_data) // (nr * nc * 4)

        csi_complex = np.zeros((nr, num_tones), dtype=np.complex128)

        for ant in range(nr):
            for tone in range(num_tones):
                offset = (ant * nc * num_tones + tone) * 4
                if offset + 4 <= len(csi_data):
                    real, imag = struct.unpack('<hh', csi_data[offset:offset + 4])
                    csi_complex[ant, tone] = complex(real, imag)

        amplitude = np.abs(csi_complex)
        phase = np.angle(csi_complex)

        # Normalize amplitude to [0, 1]
        max_amp = np.max(amplitude)
        if max_amp > 0:
            amplitude = amplitude / max_amp

        # Calculate frequency from channel
        if channel <= 14:
            frequency = 2.412e9 + (channel - 1) * 5e6
        else:
            frequency = 5.18e9 + (channel - 36) * 5e6

        bandwidth_hz = [20e6, 40e6, 80e6][bw] if bw < 3 else 20e6

        return CSIData(
            timestamp=datetime.fromtimestamp(timestamp / 1e9, tz=timezone.utc),
            amplitude=amplitude,
            phase=phase,
            frequency=frequency,
            bandwidth=bandwidth_hz,
            num_subcarriers=num_tones,
            num_antennas=nr,
            snr=float(rssi),
            metadata={
                'source': 'atheros_router',
                'format': 'binary',
                'channel': channel,
                'tx_antennas': nc,
            }
        )

    def _parse_intel_text_format(self, data_str: str) -> CSIData:
        """Parse Intel 5300 CSI text format.

        Format: INTEL_CSI:timestamp,rssi,channel,bw,antennas,num_tones,[csi...]

        Raises:
            CSIParseError: If fewer than 6 header fields are present.
        """
        content = data_str[10:]  # Remove 'INTEL_CSI:' prefix
        parts = content.split(',')

        if len(parts) < 6:
            raise CSIParseError("Incomplete Intel CSI data")

        timestamp = int(parts[0])
        rssi = int(parts[1])
        channel = int(parts[2])
        bandwidth = int(parts[3])
        num_antennas = int(parts[4])
        num_tones = int(parts[5])

        csi_values = [float(x) for x in parts[6:] if x.strip()]
        amplitude, phase = self._parse_complex_csi(csi_values, num_antennas, num_tones)

        frequency = 5.18e9 + (channel - 36) * 5e6 if channel > 14 else 2.412e9 + (channel - 1) * 5e6

        return CSIData(
            timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc),
            amplitude=amplitude,
            phase=phase,
            frequency=frequency,
            bandwidth=bandwidth * 1e6,
            num_subcarriers=num_tones,
            num_antennas=num_antennas,
            snr=float(rssi + 95),
            metadata={'source': 'intel_5300', 'channel': channel}
        )

    def _parse_intel_binary_format(self, raw_data: bytes) -> CSIData:
        """Parse Intel 5300 CSI Tool binary format (BFEE structure).

        Raises:
            CSIParseError: If the frame is too short for the BFEE header.
        """
        # Intel format is more complex with BFEE (beamforming feedback) structure
        if len(raw_data) < 25:
            raise CSIParseError("Intel binary data too short")

        # BFEE header fields (antenna_sel/perm parsed to document the layout)
        timestamp = struct.unpack('<Q', raw_data[0:8])[0]
        rssi_a, rssi_b, rssi_c = struct.unpack('<bbb', raw_data[8:11])
        noise = struct.unpack('<b', raw_data[11:12])[0]
        agc = struct.unpack('<B', raw_data[12:13])[0]
        antenna_sel = struct.unpack('<B', raw_data[13:14])[0]
        perm = struct.unpack('<BBB', raw_data[14:17])
        num_tones = struct.unpack('<B', raw_data[17:18])[0]
        nc = struct.unpack('<B', raw_data[18:19])[0]
        nr = struct.unpack('<B', raw_data[19:20])[0]

        # Parse CSI matrix from the remainder of the frame
        csi_data = raw_data[20:]

        # Intel stores CSI in a packed format with variable bit width
        csi_complex = self._unpack_intel_csi(csi_data, nr, nc, num_tones)

        # Use first TX stream only
        amplitude = np.abs(csi_complex[:, 0, :])
        phase = np.angle(csi_complex[:, 0, :])

        # Normalize amplitude to [0, 1]
        max_amp = np.max(amplitude)
        if max_amp > 0:
            amplitude = amplitude / max_amp

        rssi_avg = (rssi_a + rssi_b + rssi_c) / 3

        return CSIData(
            timestamp=datetime.fromtimestamp(timestamp / 1e6, tz=timezone.utc),
            amplitude=amplitude,
            phase=phase,
            frequency=5.32e9,  # Default Intel channel
            bandwidth=40e6,
            num_subcarriers=num_tones,
            num_antennas=nr,
            snr=float(rssi_avg - noise),
            metadata={
                'source': 'intel_5300',
                'format': 'binary',
                'noise_floor': noise,
                'agc': agc,
            }
        )

    def _unpack_intel_csi(self, data: bytes, nr: int, nc: int, num_tones: int) -> np.ndarray:
        """Unpack Intel CSI payload into an (nr, nc, num_tones) complex array.

        NOTE: Intel packs 10-bit samples; this is a simplified byte-wise
        approximation of the real BFEE unpacking.
        """
        csi = np.zeros((nr, nc, num_tones), dtype=np.complex128)

        idx = 0
        for tone in range(num_tones):
            for nc_idx in range(nc):
                for nr_idx in range(nr):
                    if idx + 2 <= len(data):
                        real = int.from_bytes(data[idx:idx + 1], 'little', signed=True)
                        imag = int.from_bytes(data[idx + 1:idx + 2], 'little', signed=True)
                        csi[nr_idx, nc_idx, tone] = complex(real, imag)
                        idx += 2

        return csi

    def _parse_complex_csi(
        self,
        values: List[float],
        num_antennas: int,
        num_tones: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Parse complex CSI values from interleaved real/imag pairs.

        Short inputs are zero-padded to num_antennas * num_tones * 2 values;
        amplitude is normalized to [0, 1] unless all-zero.
        """
        expected_len = num_antennas * num_tones * 2

        if len(values) < expected_len:
            # Pad with zeros
            values = values + [0.0] * (expected_len - len(values))

        csi_complex = np.zeros((num_antennas, num_tones), dtype=np.complex128)

        for ant in range(num_antennas):
            for tone in range(num_tones):
                idx = (ant * num_tones + tone) * 2
                if idx + 1 < len(values):
                    csi_complex[ant, tone] = complex(values[idx], values[idx + 1])

        amplitude = np.abs(csi_complex)
        phase = np.angle(csi_complex)

        # Normalize
        max_amp = np.max(amplitude)
        if max_amp > 0:
            amplitude = amplitude / max_amp

        return amplitude, phase
class CSIExtractor:
|
class CSIExtractor:
|
||||||
"""Main CSI data extractor supporting multiple hardware types."""
|
"""Main CSI data extractor supporting multiple hardware types."""
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
|
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
|
||||||
"""Initialize CSI extractor.
|
"""Initialize CSI extractor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Configuration dictionary
|
config: Configuration dictionary
|
||||||
logger: Optional logger instance
|
logger: Optional logger instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If configuration is invalid
|
ValueError: If configuration is invalid
|
||||||
"""
|
"""
|
||||||
self._validate_config(config)
|
self._validate_config(config)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
self.hardware_type = config['hardware_type']
|
self.hardware_type = config['hardware_type']
|
||||||
@@ -639,39 +165,49 @@ class CSIExtractor:
|
|||||||
self.timeout = config['timeout']
|
self.timeout = config['timeout']
|
||||||
self.validation_enabled = config.get('validation_enabled', True)
|
self.validation_enabled = config.get('validation_enabled', True)
|
||||||
self.retry_attempts = config.get('retry_attempts', 3)
|
self.retry_attempts = config.get('retry_attempts', 3)
|
||||||
|
|
||||||
# State management
|
# State management
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self.is_streaming = False
|
self.is_streaming = False
|
||||||
self._connection = None
|
|
||||||
|
|
||||||
# Create appropriate parser
|
# Create appropriate parser
|
||||||
if self.hardware_type == 'esp32':
|
if self.hardware_type == 'esp32':
|
||||||
self.parser = ESP32CSIParser()
|
self.parser = ESP32CSIParser()
|
||||||
elif self.hardware_type in ('router', 'atheros', 'intel'):
|
elif self.hardware_type == 'router':
|
||||||
self.parser = RouterCSIParser()
|
self.parser = RouterCSIParser()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported hardware type: {self.hardware_type}")
|
raise ValueError(f"Unsupported hardware type: {self.hardware_type}")
|
||||||
|
|
||||||
def _validate_config(self, config: Dict[str, Any]) -> None:
|
def _validate_config(self, config: Dict[str, Any]) -> None:
|
||||||
"""Validate configuration parameters."""
|
"""Validate configuration parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If configuration is invalid
|
||||||
|
"""
|
||||||
required_fields = ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']
|
required_fields = ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']
|
||||||
missing_fields = [field for field in required_fields if field not in config]
|
missing_fields = [field for field in required_fields if field not in config]
|
||||||
|
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise ValueError(f"Missing required configuration: {missing_fields}")
|
raise ValueError(f"Missing required configuration: {missing_fields}")
|
||||||
|
|
||||||
if config['sampling_rate'] <= 0:
|
if config['sampling_rate'] <= 0:
|
||||||
raise ValueError("sampling_rate must be positive")
|
raise ValueError("sampling_rate must be positive")
|
||||||
|
|
||||||
if config['buffer_size'] <= 0:
|
if config['buffer_size'] <= 0:
|
||||||
raise ValueError("buffer_size must be positive")
|
raise ValueError("buffer_size must be positive")
|
||||||
|
|
||||||
if config['timeout'] <= 0:
|
if config['timeout'] <= 0:
|
||||||
raise ValueError("timeout must be positive")
|
raise ValueError("timeout must be positive")
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
"""Establish connection to CSI hardware."""
|
"""Establish connection to CSI hardware.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful, False otherwise
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
success = await self._establish_hardware_connection()
|
success = await self._establish_hardware_connection()
|
||||||
self.is_connected = success
|
self.is_connected = success
|
||||||
@@ -680,64 +216,86 @@ class CSIExtractor:
|
|||||||
self.logger.error(f"Failed to connect to hardware: {e}")
|
self.logger.error(f"Failed to connect to hardware: {e}")
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def disconnect(self) -> None:
    """Disconnect from CSI hardware.

    Closes the hardware connection (if any) and clears the connected flag.
    Safe to call when already disconnected.
    """
    if self.is_connected:
        await self._close_hardware_connection()
        self.is_connected = False
async def extract_csi(self) -> CSIData:
    """Extract CSI data from hardware.

    Reads a raw frame, parses it with the configured parser, and
    (optionally) validates it. Transient ConnectionErrors are retried
    up to self.retry_attempts times.

    Returns:
        Extracted CSI data

    Raises:
        CSIParseError: If not connected, or extraction still fails after
            all retry attempts.
    """
    if not self.is_connected:
        raise CSIParseError("Not connected to hardware")

    # Retry mechanism for temporary failures
    for attempt in range(self.retry_attempts):
        try:
            raw_data = await self._read_raw_data()
            csi_data = self.parser.parse(raw_data)

            if self.validation_enabled:
                self.validate_csi_data(csi_data)

            return csi_data

        except ConnectionError as e:
            if attempt < self.retry_attempts - 1:
                self.logger.warning(f"Extraction attempt {attempt + 1} failed, retrying: {e}")
                await asyncio.sleep(0.1)  # Brief delay before retry
            else:
                raise CSIParseError(f"Extraction failed after {self.retry_attempts} attempts: {e}")
def validate_csi_data(self, csi_data: CSIData) -> bool:
    """Validate CSI data structure and values.

    Args:
        csi_data: CSI data to validate

    Returns:
        True if valid

    Raises:
        CSIValidationError: If any field is empty, non-positive, or the
            SNR falls outside the accepted range.
    """
    if csi_data.amplitude.size == 0:
        raise CSIValidationError("Empty amplitude data")

    if csi_data.phase.size == 0:
        raise CSIValidationError("Empty phase data")

    if csi_data.frequency <= 0:
        raise CSIValidationError("Invalid frequency")

    if csi_data.bandwidth <= 0:
        raise CSIValidationError("Invalid bandwidth")

    if csi_data.num_subcarriers <= 0:
        raise CSIValidationError("Invalid number of subcarriers")

    if csi_data.num_antennas <= 0:
        raise CSIValidationError("Invalid number of antennas")

    # Upper bound is 100 (not 50) because the parsers approximate SNR as
    # RSSI + 95 / RSSI + 100, which can legitimately exceed 50 dB.
    if csi_data.snr < -50 or csi_data.snr > 100:
        raise CSIValidationError("Invalid SNR value")

    return True
async def start_streaming(self, callback: Callable[[CSIData], None]) -> None:
|
async def start_streaming(self, callback: Callable[[CSIData], None]) -> None:
|
||||||
"""Start streaming CSI data."""
|
"""Start streaming CSI data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function to call with each CSI sample
|
||||||
|
"""
|
||||||
self.is_streaming = True
|
self.is_streaming = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while self.is_streaming:
|
while self.is_streaming:
|
||||||
csi_data = await self.extract_csi()
|
csi_data = await self.extract_csi()
|
||||||
@@ -747,74 +305,22 @@ class CSIExtractor:
|
|||||||
self.logger.error(f"Streaming error: {e}")
|
self.logger.error(f"Streaming error: {e}")
|
||||||
finally:
|
finally:
|
||||||
self.is_streaming = False
|
self.is_streaming = False
|
||||||
|
|
||||||
def stop_streaming(self) -> None:
    """Stop streaming CSI data.

    Clears the streaming flag; the streaming loop observes it and exits.
    """
    self.is_streaming = False
async def _establish_hardware_connection(self) -> bool:
|
async def _establish_hardware_connection(self) -> bool:
|
||||||
"""Establish connection to hardware."""
|
"""Establish connection to hardware (to be implemented by subclasses)."""
|
||||||
connection_config = self.config.get('connection', {})
|
# Placeholder implementation for testing
|
||||||
|
return True
|
||||||
if self.hardware_type == 'esp32':
|
|
||||||
# Serial or network connection for ESP32
|
|
||||||
port = connection_config.get('port', '/dev/ttyUSB0')
|
|
||||||
baudrate = connection_config.get('baudrate', 115200)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import serial_asyncio
|
|
||||||
reader, writer = await serial_asyncio.open_serial_connection(
|
|
||||||
url=port, baudrate=baudrate
|
|
||||||
)
|
|
||||||
self._connection = (reader, writer)
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
self.logger.warning("serial_asyncio not available, using mock connection")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Serial connection failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
elif self.hardware_type in ('router', 'atheros', 'intel'):
|
|
||||||
# Network connection for router
|
|
||||||
host = connection_config.get('host', '192.168.1.1')
|
|
||||||
port = connection_config.get('port', 5500)
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader, writer = await asyncio.open_connection(host, port)
|
|
||||||
self._connection = (reader, writer)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Network connection failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _close_hardware_connection(self) -> None:
|
async def _close_hardware_connection(self) -> None:
|
||||||
"""Close hardware connection."""
|
"""Close hardware connection (to be implemented by subclasses)."""
|
||||||
if self._connection:
|
# Placeholder implementation for testing
|
||||||
try:
|
pass
|
||||||
reader, writer = self._connection
|
|
||||||
writer.close()
|
|
||||||
await writer.wait_closed()
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error closing connection: {e}")
|
|
||||||
finally:
|
|
||||||
self._connection = None
|
|
||||||
|
|
||||||
async def _read_raw_data(self) -> bytes:
|
async def _read_raw_data(self) -> bytes:
|
||||||
"""Read raw data from hardware."""
|
"""Read raw data from hardware (to be implemented by subclasses)."""
|
||||||
if self._connection:
|
# Placeholder implementation for testing
|
||||||
reader, writer = self._connection
|
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||||
try:
|
|
||||||
# Read until newline or buffer size
|
|
||||||
data = await asyncio.wait_for(
|
|
||||||
reader.readline(),
|
|
||||||
timeout=self.timeout
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise ConnectionError("Read timeout")
|
|
||||||
else:
|
|
||||||
# Mock data for testing when no real connection
|
|
||||||
raise ConnectionError("No active connection")
|
|
||||||
@@ -265,371 +265,30 @@ class PoseService:
|
|||||||
self.logger.error(f"Error in pose estimation: {e}")
|
self.logger.error(f"Error in pose estimation: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _parse_pose_outputs(self, outputs: Dict[str, torch.Tensor]) -> List[Dict[str, Any]]:
|
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||||
"""Parse neural network outputs into pose detections.
|
"""Parse neural network outputs into pose detections."""
|
||||||
|
|
||||||
The DensePose model outputs:
|
|
||||||
- segmentation: (batch, num_parts+1, H, W) - body part segmentation
|
|
||||||
- uv_coords: (batch, 2, H, W) - UV coordinates for surface mapping
|
|
||||||
|
|
||||||
Returns list of detected persons with keypoints and body parts.
|
|
||||||
"""
|
|
||||||
poses = []
|
|
||||||
|
|
||||||
# Handle different output formats
|
|
||||||
if isinstance(outputs, torch.Tensor):
|
|
||||||
# Simple tensor output - use legacy parsing
|
|
||||||
return self._parse_simple_outputs(outputs)
|
|
||||||
|
|
||||||
# DensePose structured output
|
|
||||||
segmentation = outputs.get('segmentation')
|
|
||||||
uv_coords = outputs.get('uv_coords')
|
|
||||||
|
|
||||||
if segmentation is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
batch_size = segmentation.shape[0]
|
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
|
||||||
# Get segmentation for this sample
|
|
||||||
seg = segmentation[batch_idx] # (num_parts+1, H, W)
|
|
||||||
|
|
||||||
# Find persons by analyzing body part segmentation
|
|
||||||
# Background is class 0, body parts are 1-24
|
|
||||||
body_mask = seg[1:].sum(dim=0) > seg[0] # Any body part vs background
|
|
||||||
|
|
||||||
if not body_mask.any():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find connected components (persons)
|
|
||||||
person_regions = self._find_person_regions(body_mask)
|
|
||||||
|
|
||||||
for person_idx, region in enumerate(person_regions):
|
|
||||||
# Extract keypoints from body part segmentation
|
|
||||||
keypoints = self._extract_keypoints_from_segmentation(seg, region)
|
|
||||||
|
|
||||||
# Calculate bounding box from region
|
|
||||||
bbox = self._calculate_bounding_box(region)
|
|
||||||
|
|
||||||
# Calculate confidence from segmentation probabilities
|
|
||||||
seg_probs = torch.softmax(seg, dim=0)
|
|
||||||
region_mask = region['mask']
|
|
||||||
confidence = float(seg_probs[1:, region_mask].max().item())
|
|
||||||
|
|
||||||
# Classify activity from pose keypoints
|
|
||||||
activity = self._classify_activity_from_keypoints(keypoints)
|
|
||||||
|
|
||||||
pose = {
|
|
||||||
"person_id": person_idx,
|
|
||||||
"confidence": confidence,
|
|
||||||
"keypoints": keypoints,
|
|
||||||
"bounding_box": bbox,
|
|
||||||
"activity": activity,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"body_parts": self._extract_body_parts(seg, region) if uv_coords is not None else None
|
|
||||||
}
|
|
||||||
|
|
||||||
poses.append(pose)
|
|
||||||
|
|
||||||
return poses
|
|
||||||
|
|
||||||
def _parse_simple_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
|
||||||
"""Parse simple tensor outputs (fallback for non-DensePose models)."""
|
|
||||||
poses = []
|
poses = []
|
||||||
|
|
||||||
|
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||||
|
# For now, generate mock poses based on the output shape
|
||||||
batch_size = outputs.shape[0]
|
batch_size = outputs.shape[0]
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
output = outputs[i]
|
# Extract pose information (mock implementation)
|
||||||
|
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||||
# Extract confidence from first channel
|
|
||||||
confidence = float(torch.sigmoid(output[0]).mean().item()) if output.numel() > 0 else 0.0
|
|
||||||
|
|
||||||
if confidence < 0.1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try to extract keypoints from output tensor
|
|
||||||
keypoints = self._extract_keypoints_from_tensor(output)
|
|
||||||
bbox = self._estimate_bbox_from_keypoints(keypoints)
|
|
||||||
activity = self._classify_activity_from_keypoints(keypoints)
|
|
||||||
|
|
||||||
pose = {
|
pose = {
|
||||||
"person_id": i,
|
"person_id": i,
|
||||||
"confidence": confidence,
|
"confidence": confidence,
|
||||||
"keypoints": keypoints,
|
"keypoints": self._generate_keypoints(),
|
||||||
"bounding_box": bbox,
|
"bounding_box": self._generate_bounding_box(),
|
||||||
"activity": activity,
|
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
poses.append(pose)
|
poses.append(pose)
|
||||||
|
|
||||||
return poses
|
return poses
|
||||||
|
|
||||||
def _find_person_regions(self, body_mask: torch.Tensor) -> List[Dict[str, Any]]:
|
|
||||||
"""Find distinct person regions in body mask using connected components."""
|
|
||||||
# Convert to numpy for connected component analysis
|
|
||||||
mask_np = body_mask.cpu().numpy().astype(np.uint8)
|
|
||||||
|
|
||||||
# Simple connected component labeling
|
|
||||||
from scipy import ndimage
|
|
||||||
labeled, num_features = ndimage.label(mask_np)
|
|
||||||
|
|
||||||
regions = []
|
|
||||||
for label_id in range(1, num_features + 1):
|
|
||||||
region_mask = labeled == label_id
|
|
||||||
if region_mask.sum() < 100: # Minimum region size
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find bounding coordinates
|
|
||||||
coords = np.where(region_mask)
|
|
||||||
regions.append({
|
|
||||||
'mask': torch.from_numpy(region_mask),
|
|
||||||
'y_min': int(coords[0].min()),
|
|
||||||
'y_max': int(coords[0].max()),
|
|
||||||
'x_min': int(coords[1].min()),
|
|
||||||
'x_max': int(coords[1].max()),
|
|
||||||
'area': int(region_mask.sum())
|
|
||||||
})
|
|
||||||
|
|
||||||
return regions
|
|
||||||
|
|
||||||
def _extract_keypoints_from_segmentation(
|
|
||||||
self, segmentation: torch.Tensor, region: Dict[str, Any]
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""Extract keypoints from body part segmentation."""
|
|
||||||
keypoint_names = [
|
|
||||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
|
||||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
|
||||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
|
||||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Mapping from body parts to keypoints
|
|
||||||
# DensePose has 24 body parts, we map to COCO keypoints
|
|
||||||
part_to_keypoint = {
|
|
||||||
14: "nose", # Head -> nose
|
|
||||||
10: "left_shoulder", 11: "right_shoulder",
|
|
||||||
12: "left_elbow", 13: "right_elbow",
|
|
||||||
2: "left_wrist", 3: "right_wrist", # Hands approximate wrists
|
|
||||||
7: "left_hip", 6: "right_hip", # Upper legs
|
|
||||||
9: "left_knee", 8: "right_knee", # Lower legs
|
|
||||||
4: "left_ankle", 5: "right_ankle", # Feet approximate ankles
|
|
||||||
}
|
|
||||||
|
|
||||||
h, w = segmentation.shape[1], segmentation.shape[2]
|
|
||||||
keypoints = []
|
|
||||||
|
|
||||||
# Get softmax probabilities
|
|
||||||
seg_probs = torch.softmax(segmentation, dim=0)
|
|
||||||
|
|
||||||
for kp_name in keypoint_names:
|
|
||||||
# Find which body part corresponds to this keypoint
|
|
||||||
part_idx = None
|
|
||||||
for part, name in part_to_keypoint.items():
|
|
||||||
if name == kp_name:
|
|
||||||
part_idx = part
|
|
||||||
break
|
|
||||||
|
|
||||||
if part_idx is not None and part_idx < seg_probs.shape[0]:
|
|
||||||
# Get probability map for this part within the region
|
|
||||||
part_prob = seg_probs[part_idx] * region['mask'].float()
|
|
||||||
|
|
||||||
if part_prob.max() > 0.1:
|
|
||||||
# Find location of maximum probability
|
|
||||||
max_idx = part_prob.argmax()
|
|
||||||
y = int(max_idx // w)
|
|
||||||
x = int(max_idx % w)
|
|
||||||
|
|
||||||
keypoints.append({
|
|
||||||
"name": kp_name,
|
|
||||||
"x": float(x) / w,
|
|
||||||
"y": float(y) / h,
|
|
||||||
"confidence": float(part_prob.max().item())
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
# Keypoint not visible
|
|
||||||
keypoints.append({
|
|
||||||
"name": kp_name,
|
|
||||||
"x": 0.0,
|
|
||||||
"y": 0.0,
|
|
||||||
"confidence": 0.0
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
# Estimate position based on body region
|
|
||||||
cx = (region['x_min'] + region['x_max']) / 2 / w
|
|
||||||
cy = (region['y_min'] + region['y_max']) / 2 / h
|
|
||||||
keypoints.append({
|
|
||||||
"name": kp_name,
|
|
||||||
"x": float(cx),
|
|
||||||
"y": float(cy),
|
|
||||||
"confidence": 0.1
|
|
||||||
})
|
|
||||||
|
|
||||||
return keypoints
|
|
||||||
|
|
||||||
def _calculate_bounding_box(self, region: Dict[str, Any]) -> Dict[str, float]:
|
|
||||||
"""Calculate normalized bounding box from region."""
|
|
||||||
# Assume region contains mask shape info
|
|
||||||
mask = region['mask']
|
|
||||||
h, w = mask.shape
|
|
||||||
|
|
||||||
return {
|
|
||||||
"x": float(region['x_min']) / w,
|
|
||||||
"y": float(region['y_min']) / h,
|
|
||||||
"width": float(region['x_max'] - region['x_min']) / w,
|
|
||||||
"height": float(region['y_max'] - region['y_min']) / h
|
|
||||||
}
|
|
||||||
|
|
||||||
def _extract_body_parts(
|
|
||||||
self, segmentation: torch.Tensor, region: Dict[str, Any]
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Extract body part information from segmentation."""
|
|
||||||
part_names = [
|
|
||||||
"background", "torso", "right_hand", "left_hand", "left_foot", "right_foot",
|
|
||||||
"upper_leg_right", "upper_leg_left", "lower_leg_right", "lower_leg_left",
|
|
||||||
"upper_arm_left", "upper_arm_right", "lower_arm_left", "lower_arm_right", "head"
|
|
||||||
]
|
|
||||||
|
|
||||||
seg_probs = torch.softmax(segmentation, dim=0)
|
|
||||||
region_mask = region['mask']
|
|
||||||
|
|
||||||
parts = {}
|
|
||||||
for i, name in enumerate(part_names):
|
|
||||||
if i < seg_probs.shape[0]:
|
|
||||||
part_prob = seg_probs[i] * region_mask.float()
|
|
||||||
parts[name] = {
|
|
||||||
"present": bool(part_prob.max() > 0.3),
|
|
||||||
"confidence": float(part_prob.max().item()),
|
|
||||||
"coverage": float((part_prob > 0.3).sum().item() / max(1, region_mask.sum().item()))
|
|
||||||
}
|
|
||||||
|
|
||||||
return parts
|
|
||||||
|
|
||||||
def _extract_keypoints_from_tensor(self, output: torch.Tensor) -> List[Dict[str, Any]]:
|
|
||||||
"""Extract keypoints from a generic output tensor."""
|
|
||||||
keypoint_names = [
|
|
||||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
|
||||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
|
||||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
|
||||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
|
||||||
]
|
|
||||||
|
|
||||||
keypoints = []
|
|
||||||
|
|
||||||
# Try to interpret output as heatmaps
|
|
||||||
if output.dim() >= 2:
|
|
||||||
flat = output.flatten()
|
|
||||||
num_kp = len(keypoint_names)
|
|
||||||
|
|
||||||
# Divide output evenly for each keypoint
|
|
||||||
chunk_size = len(flat) // num_kp if num_kp > 0 else 1
|
|
||||||
|
|
||||||
for i, name in enumerate(keypoint_names):
|
|
||||||
start = i * chunk_size
|
|
||||||
end = min(start + chunk_size, len(flat))
|
|
||||||
|
|
||||||
if start < len(flat):
|
|
||||||
chunk = flat[start:end]
|
|
||||||
# Find max location in chunk
|
|
||||||
max_val = chunk.max().item()
|
|
||||||
max_idx = chunk.argmax().item()
|
|
||||||
|
|
||||||
# Convert to x, y (assume square spatial layout)
|
|
||||||
side = int(np.sqrt(chunk_size))
|
|
||||||
if side > 0:
|
|
||||||
x = (max_idx % side) / side
|
|
||||||
y = (max_idx // side) / side
|
|
||||||
else:
|
|
||||||
x, y = 0.5, 0.5
|
|
||||||
|
|
||||||
keypoints.append({
|
|
||||||
"name": name,
|
|
||||||
"x": float(x),
|
|
||||||
"y": float(y),
|
|
||||||
"confidence": float(torch.sigmoid(torch.tensor(max_val)).item())
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
keypoints.append({
|
|
||||||
"name": name, "x": 0.5, "y": 0.5, "confidence": 0.0
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
# Fallback
|
|
||||||
for name in keypoint_names:
|
|
||||||
keypoints.append({"name": name, "x": 0.5, "y": 0.5, "confidence": 0.1})
|
|
||||||
|
|
||||||
return keypoints
|
|
||||||
|
|
||||||
def _estimate_bbox_from_keypoints(self, keypoints: List[Dict[str, Any]]) -> Dict[str, float]:
|
|
||||||
"""Estimate bounding box from keypoint positions."""
|
|
||||||
valid_kps = [kp for kp in keypoints if kp['confidence'] > 0.1]
|
|
||||||
|
|
||||||
if not valid_kps:
|
|
||||||
return {"x": 0.3, "y": 0.2, "width": 0.4, "height": 0.6}
|
|
||||||
|
|
||||||
xs = [kp['x'] for kp in valid_kps]
|
|
||||||
ys = [kp['y'] for kp in valid_kps]
|
|
||||||
|
|
||||||
x_min, x_max = min(xs), max(xs)
|
|
||||||
y_min, y_max = min(ys), max(ys)
|
|
||||||
|
|
||||||
# Add padding
|
|
||||||
padding = 0.05
|
|
||||||
x_min = max(0, x_min - padding)
|
|
||||||
y_min = max(0, y_min - padding)
|
|
||||||
x_max = min(1, x_max + padding)
|
|
||||||
y_max = min(1, y_max + padding)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"x": x_min,
|
|
||||||
"y": y_min,
|
|
||||||
"width": x_max - x_min,
|
|
||||||
"height": y_max - y_min
|
|
||||||
}
|
|
||||||
|
|
||||||
def _classify_activity_from_keypoints(self, keypoints: List[Dict[str, Any]]) -> str:
|
|
||||||
"""Classify activity based on keypoint positions."""
|
|
||||||
# Get key body parts
|
|
||||||
kp_dict = {kp['name']: kp for kp in keypoints}
|
|
||||||
|
|
||||||
# Check if enough keypoints are detected
|
|
||||||
valid_count = sum(1 for kp in keypoints if kp['confidence'] > 0.3)
|
|
||||||
if valid_count < 5:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
# Get relevant keypoints
|
|
||||||
nose = kp_dict.get('nose', {})
|
|
||||||
l_hip = kp_dict.get('left_hip', {})
|
|
||||||
r_hip = kp_dict.get('right_hip', {})
|
|
||||||
l_ankle = kp_dict.get('left_ankle', {})
|
|
||||||
r_ankle = kp_dict.get('right_ankle', {})
|
|
||||||
l_shoulder = kp_dict.get('left_shoulder', {})
|
|
||||||
r_shoulder = kp_dict.get('right_shoulder', {})
|
|
||||||
|
|
||||||
# Calculate body metrics
|
|
||||||
hip_y = (l_hip.get('y', 0.5) + r_hip.get('y', 0.5)) / 2
|
|
||||||
ankle_y = (l_ankle.get('y', 0.8) + r_ankle.get('y', 0.8)) / 2
|
|
||||||
shoulder_y = (l_shoulder.get('y', 0.3) + r_shoulder.get('y', 0.3)) / 2
|
|
||||||
nose_y = nose.get('y', 0.2)
|
|
||||||
|
|
||||||
# Leg spread (horizontal distance between ankles)
|
|
||||||
leg_spread = abs(l_ankle.get('x', 0.5) - r_ankle.get('x', 0.5))
|
|
||||||
|
|
||||||
# Vertical compression (how "tall" the pose is)
|
|
||||||
vertical_span = ankle_y - nose_y if ankle_y > nose_y else 0.6
|
|
||||||
|
|
||||||
# Classification logic
|
|
||||||
if vertical_span < 0.3:
|
|
||||||
# Very compressed vertically - likely lying down
|
|
||||||
return "lying"
|
|
||||||
elif vertical_span < 0.45 and hip_y > 0.5:
|
|
||||||
# Medium compression with low hips - sitting
|
|
||||||
return "sitting"
|
|
||||||
elif leg_spread > 0.15:
|
|
||||||
# Legs apart - likely walking
|
|
||||||
return "walking"
|
|
||||||
else:
|
|
||||||
# Default upright pose
|
|
||||||
return "standing"
|
|
||||||
|
|
||||||
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||||
"""Generate mock pose data for development."""
|
"""Generate mock pose data for development."""
|
||||||
|
|||||||
Reference in New Issue
Block a user