feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
13
v1/src/core/__init__.py
Normal file
13
v1/src/core/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Core package for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from .csi_processor import CSIProcessor
|
||||
from .phase_sanitizer import PhaseSanitizer
|
||||
from .router_interface import RouterInterface
|
||||
|
||||
# Public API of the core package: re-exported classes only.
__all__ = [
    'CSIProcessor',
    'PhaseSanitizer',
    'RouterInterface'
]
|
||||
425
v1/src/core/csi_processor.py
Normal file
425
v1/src/core/csi_processor.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""CSI data processor for WiFi-DensePose system using TDD approach."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
import scipy.signal
|
||||
import scipy.fft
|
||||
|
||||
try:
|
||||
from ..hardware.csi_extractor import CSIData
|
||||
except ImportError:
|
||||
# Handle import for testing
|
||||
from src.hardware.csi_extractor import CSIData
|
||||
|
||||
|
||||
class CSIProcessingError(Exception):
    """Raised when any stage of the CSI processing pipeline fails."""
|
||||
|
||||
|
||||
@dataclass
class CSIFeatures:
    """Data structure for extracted CSI features."""
    # Per-subcarrier mean of amplitude across antennas (axis 0).
    amplitude_mean: np.ndarray
    # Per-subcarrier variance of amplitude across antennas (axis 0).
    amplitude_variance: np.ndarray
    # Mean phase difference between adjacent subcarriers.
    phase_difference: np.ndarray
    # Antenna-to-antenna amplitude correlation matrix (np.corrcoef).
    correlation_matrix: np.ndarray
    # Doppler shift estimate (currently a placeholder in the pipeline).
    doppler_shift: np.ndarray
    # Power spectral density of the flattened amplitude signal.
    power_spectral_density: np.ndarray
    # UTC timestamp of when the features were extracted.
    timestamp: datetime
    # Free-form context, e.g. the processing parameters used.
    metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
class HumanDetectionResult:
    """Data structure for human detection results."""
    # True when smoothed confidence met the detection threshold.
    human_detected: bool
    # Temporally smoothed detection confidence in [0, 1].
    confidence: float
    # Motion score derived from amplitude variance and antenna correlation.
    motion_score: float
    # UTC timestamp of the detection decision.
    timestamp: datetime
    # The feature set this decision was based on.
    features: CSIFeatures
    # Context such as the threshold used for the decision.
    metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class CSIProcessor:
    """Processes CSI data for human detection and pose estimation.

    Pipeline stages (each individually switchable via config flags):
    preprocessing (noise gating, Hamming windowing, amplitude normalization)
    -> feature extraction -> human-presence detection with temporal smoothing.

    NOTE: project types (CSIData, CSIFeatures, HumanDetectionResult) are
    referenced via string annotations so the class does not require them to
    be importable at definition time.
    """

    def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
        """Initialize CSI processor.

        Args:
            config: Configuration dictionary. Required keys: 'sampling_rate',
                'window_size', 'overlap', 'noise_threshold'.
            logger: Optional logger instance.

        Raises:
            ValueError: If configuration is invalid.
        """
        self._validate_config(config)

        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Processing parameters
        self.sampling_rate = config['sampling_rate']
        self.window_size = config['window_size']
        self.overlap = config['overlap']
        self.noise_threshold = config['noise_threshold']  # dB gate, see _remove_noise
        self.human_detection_threshold = config.get('human_detection_threshold', 0.8)
        self.smoothing_factor = config.get('smoothing_factor', 0.9)
        self.max_history_size = config.get('max_history_size', 500)

        # Feature extraction flags
        self.enable_preprocessing = config.get('enable_preprocessing', True)
        self.enable_feature_extraction = config.get('enable_feature_extraction', True)
        self.enable_human_detection = config.get('enable_human_detection', True)

        # Processing state; deque silently evicts the oldest entry once full.
        self.csi_history = deque(maxlen=self.max_history_size)
        self.previous_detection_confidence = 0.0

        # Statistics tracking
        self._total_processed = 0
        self._processing_errors = 0
        self._human_detections = 0

    def _validate_config(self, config: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Args:
            config: Configuration to validate

        Raises:
            ValueError: If required keys are missing or values out of range.
        """
        required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
        missing_fields = [field for field in required_fields if field not in config]

        if missing_fields:
            raise ValueError(f"Missing required configuration: {missing_fields}")

        if config['sampling_rate'] <= 0:
            raise ValueError("sampling_rate must be positive")

        if config['window_size'] <= 0:
            raise ValueError("window_size must be positive")

        if not 0 <= config['overlap'] < 1:
            raise ValueError("overlap must be between 0 and 1")

    def preprocess_csi_data(self, csi_data: "CSIData") -> "CSIData":
        """Preprocess CSI data for feature extraction.

        Applies noise gating, Hamming windowing and amplitude normalization.
        Returns the input unchanged when preprocessing is disabled.

        Args:
            csi_data: Raw CSI data

        Returns:
            Preprocessed CSI data

        Raises:
            CSIProcessingError: If preprocessing fails
        """
        if not self.enable_preprocessing:
            return csi_data

        try:
            # Remove noise from the signal
            cleaned_data = self._remove_noise(csi_data)

            # Apply windowing function
            windowed_data = self._apply_windowing(cleaned_data)

            # Normalize amplitude values
            return self._normalize_amplitude(windowed_data)

        except Exception as e:
            raise CSIProcessingError(f"Failed to preprocess CSI data: {e}") from e

    def extract_features(self, csi_data: "CSIData") -> Optional["CSIFeatures"]:
        """Extract features from CSI data.

        Args:
            csi_data: Preprocessed CSI data

        Returns:
            Extracted features, or None when feature extraction is disabled

        Raises:
            CSIProcessingError: If feature extraction fails
        """
        if not self.enable_feature_extraction:
            return None

        try:
            # Extract amplitude-based features
            amplitude_mean, amplitude_variance = self._extract_amplitude_features(csi_data)

            # Extract phase-based features
            phase_difference = self._extract_phase_features(csi_data)

            # Extract correlation features
            correlation_matrix = self._extract_correlation_features(csi_data)

            # Extract Doppler and frequency features
            doppler_shift, power_spectral_density = self._extract_doppler_features(csi_data)

            return CSIFeatures(
                amplitude_mean=amplitude_mean,
                amplitude_variance=amplitude_variance,
                phase_difference=phase_difference,
                correlation_matrix=correlation_matrix,
                doppler_shift=doppler_shift,
                power_spectral_density=power_spectral_density,
                timestamp=datetime.now(timezone.utc),
                metadata={'processing_params': self.config}
            )

        except Exception as e:
            raise CSIProcessingError(f"Failed to extract features: {e}") from e

    def detect_human_presence(self, features: "CSIFeatures") -> Optional["HumanDetectionResult"]:
        """Detect human presence from CSI features.

        Args:
            features: Extracted CSI features

        Returns:
            Detection result, or None when human detection is disabled

        Raises:
            CSIProcessingError: If detection fails (including when `features`
                is None because extraction was disabled upstream)
        """
        if not self.enable_human_detection:
            return None

        try:
            # Analyze motion patterns
            motion_score = self._analyze_motion_patterns(features)

            # Calculate detection confidence
            raw_confidence = self._calculate_detection_confidence(features, motion_score)

            # Apply temporal smoothing
            smoothed_confidence = self._apply_temporal_smoothing(raw_confidence)

            # Determine if human is detected
            human_detected = smoothed_confidence >= self.human_detection_threshold

            if human_detected:
                self._human_detections += 1

            return HumanDetectionResult(
                human_detected=human_detected,
                confidence=smoothed_confidence,
                motion_score=motion_score,
                timestamp=datetime.now(timezone.utc),
                features=features,
                metadata={'threshold': self.human_detection_threshold}
            )

        except Exception as e:
            raise CSIProcessingError(f"Failed to detect human presence: {e}") from e

    async def process_csi_data(self, csi_data: "CSIData") -> "HumanDetectionResult":
        """Process CSI data through the complete pipeline.

        Args:
            csi_data: Raw CSI data

        Returns:
            Human detection result (None when detection is disabled)

        Raises:
            CSIProcessingError: If any pipeline stage fails
        """
        try:
            self._total_processed += 1

            # Preprocess the data
            preprocessed_data = self.preprocess_csi_data(csi_data)

            # Extract features (None when extraction is disabled; detection
            # will then fail and be reported as a CSIProcessingError).
            features = self.extract_features(preprocessed_data)

            # Detect human presence
            detection_result = self.detect_human_presence(features)

            # Add the raw sample to history
            self.add_to_history(csi_data)

            return detection_result

        except Exception as e:
            self._processing_errors += 1
            raise CSIProcessingError(f"Pipeline processing failed: {e}") from e

    def add_to_history(self, csi_data: "CSIData") -> None:
        """Add CSI data to processing history.

        Args:
            csi_data: CSI data to add to history
        """
        self.csi_history.append(csi_data)

    def clear_history(self) -> None:
        """Clear the CSI data history."""
        self.csi_history.clear()

    def get_recent_history(self, count: int) -> List["CSIData"]:
        """Get recent CSI data from history.

        Args:
            count: Number of recent entries to return

        Returns:
            Up to `count` most-recent entries; empty list for count <= 0.
        """
        # BUGFIX: the previous `[-count:]` slice returned the FULL history
        # for count == 0; guard non-positive counts explicitly.
        if count <= 0:
            return []
        # A negative-start slice clamps, so count >= len returns everything.
        return list(self.csi_history)[-count:]

    def get_processing_statistics(self) -> Dict[str, Any]:
        """Get processing statistics.

        Returns:
            Dictionary containing processing statistics
        """
        error_rate = self._processing_errors / self._total_processed if self._total_processed > 0 else 0
        detection_rate = self._human_detections / self._total_processed if self._total_processed > 0 else 0

        return {
            'total_processed': self._total_processed,
            'processing_errors': self._processing_errors,
            'human_detections': self._human_detections,
            'error_rate': error_rate,
            'detection_rate': detection_rate,
            'history_size': len(self.csi_history)
        }

    def reset_statistics(self) -> None:
        """Reset processing statistics."""
        self._total_processed = 0
        self._processing_errors = 0
        self._human_detections = 0

    # Private processing methods
    def _remove_noise(self, csi_data: "CSIData") -> "CSIData":
        """Zero out amplitude samples below the configured dB threshold."""
        # Convert to dB; epsilon avoids log10(0).
        amplitude_db = 20 * np.log10(np.abs(csi_data.amplitude) + 1e-12)
        noise_mask = amplitude_db > self.noise_threshold

        filtered_amplitude = csi_data.amplitude.copy()
        filtered_amplitude[~noise_mask] = 0

        return CSIData(
            timestamp=csi_data.timestamp,
            amplitude=filtered_amplitude,
            phase=csi_data.phase,
            frequency=csi_data.frequency,
            bandwidth=csi_data.bandwidth,
            num_subcarriers=csi_data.num_subcarriers,
            num_antennas=csi_data.num_antennas,
            snr=csi_data.snr,
            metadata={**csi_data.metadata, 'noise_filtered': True}
        )

    def _apply_windowing(self, csi_data: "CSIData") -> "CSIData":
        """Apply a Hamming window across subcarriers to reduce spectral leakage."""
        # assumes amplitude is (num_antennas, num_subcarriers) — the window
        # broadcasts along axis 1. TODO confirm against CSIData producer.
        window = scipy.signal.windows.hamming(csi_data.num_subcarriers)
        windowed_amplitude = csi_data.amplitude * window[np.newaxis, :]

        return CSIData(
            timestamp=csi_data.timestamp,
            amplitude=windowed_amplitude,
            phase=csi_data.phase,
            frequency=csi_data.frequency,
            bandwidth=csi_data.bandwidth,
            num_subcarriers=csi_data.num_subcarriers,
            num_antennas=csi_data.num_antennas,
            snr=csi_data.snr,
            metadata={**csi_data.metadata, 'windowed': True}
        )

    def _normalize_amplitude(self, csi_data: "CSIData") -> "CSIData":
        """Scale amplitude to unit variance (epsilon guards constant input)."""
        normalized_amplitude = csi_data.amplitude / (np.std(csi_data.amplitude) + 1e-12)

        return CSIData(
            timestamp=csi_data.timestamp,
            amplitude=normalized_amplitude,
            phase=csi_data.phase,
            frequency=csi_data.frequency,
            bandwidth=csi_data.bandwidth,
            num_subcarriers=csi_data.num_subcarriers,
            num_antennas=csi_data.num_antennas,
            snr=csi_data.snr,
            metadata={**csi_data.metadata, 'normalized': True}
        )

    def _extract_amplitude_features(self, csi_data: "CSIData") -> tuple:
        """Return (mean, variance) of amplitude across antennas (axis 0)."""
        amplitude_mean = np.mean(csi_data.amplitude, axis=0)
        amplitude_variance = np.var(csi_data.amplitude, axis=0)
        return amplitude_mean, amplitude_variance

    def _extract_phase_features(self, csi_data: "CSIData") -> np.ndarray:
        """Return mean phase difference between adjacent subcarriers."""
        phase_diff = np.diff(csi_data.phase, axis=1)
        return np.mean(phase_diff, axis=0)

    def _extract_correlation_features(self, csi_data: "CSIData") -> np.ndarray:
        """Return the antenna-to-antenna amplitude correlation matrix."""
        return np.corrcoef(csi_data.amplitude)

    def _extract_doppler_features(self, csi_data: "CSIData") -> tuple:
        """Extract Doppler and frequency domain features."""
        # Placeholder Doppler estimate; a real implementation would use the
        # CSI history. Zeros replace the old np.random.rand(10) so output is
        # deterministic while keeping the same shape and dtype.
        doppler_shift = np.zeros(10)

        # Power spectral density via a 128-point FFT of the flattened amplitude.
        psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128)) ** 2

        return doppler_shift, psd

    def _analyze_motion_patterns(self, features: "CSIFeatures") -> float:
        """Score motion in [0, 1] from variance and off-diagonal correlation."""
        variance_score = np.mean(features.amplitude_variance)
        # Distance of the correlation matrix from identity: non-zero
        # off-diagonals indicate changing propagation paths (motion).
        correlation_score = np.mean(np.abs(features.correlation_matrix - np.eye(features.correlation_matrix.shape[0])))

        # Fixed weighting of the two indicators (simplified heuristic).
        motion_score = 0.6 * variance_score + 0.4 * correlation_score
        return np.clip(motion_score, 0.0, 1.0)

    def _calculate_detection_confidence(self, features: "CSIFeatures", motion_score: float) -> float:
        """Combine boolean indicators into a weighted confidence in [0, 1]."""
        # Each indicator is a bool (0/1) compared against a fixed heuristic
        # threshold; the weighted sum lands in [0, 1] by construction.
        amplitude_indicator = np.mean(features.amplitude_mean) > 0.1
        phase_indicator = np.std(features.phase_difference) > 0.05
        motion_indicator = motion_score > 0.3

        confidence = (0.4 * amplitude_indicator + 0.3 * phase_indicator + 0.3 * motion_indicator)
        return np.clip(confidence, 0.0, 1.0)

    def _apply_temporal_smoothing(self, raw_confidence: float) -> float:
        """Exponential moving average over successive confidence values."""
        smoothed_confidence = (self.smoothing_factor * self.previous_detection_confidence +
                               (1 - self.smoothing_factor) * raw_confidence)

        self.previous_detection_confidence = smoothed_confidence
        return smoothed_confidence
|
||||
347
v1/src/core/phase_sanitizer.py
Normal file
347
v1/src/core/phase_sanitizer.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Phase sanitization module for WiFi-DensePose system using TDD approach."""
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from scipy import signal
|
||||
|
||||
|
||||
class PhaseSanitizationError(Exception):
    """Raised when any stage of phase sanitization fails."""
|
||||
|
||||
|
||||
class PhaseSanitizer:
    """Sanitizes phase data from CSI signals for reliable processing.

    Pipeline (see sanitize_phase): unwrap -> z-score outlier interpolation ->
    centered moving-average smoothing -> optional Butterworth low-pass filter.
    Phase arrays are 2D; rows are antennas, columns are subcarriers.
    """

    def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
        """Initialize phase sanitizer.

        Args:
            config: Configuration dictionary. Required keys:
                'unwrapping_method', 'outlier_threshold', 'smoothing_window'.
            logger: Optional logger instance.

        Raises:
            ValueError: If configuration is invalid.
        """
        self._validate_config(config)

        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Processing parameters
        self.unwrapping_method = config['unwrapping_method']
        self.outlier_threshold = config['outlier_threshold']  # z-score cutoff
        self.smoothing_window = config['smoothing_window']

        # Optional parameters with defaults
        self.enable_outlier_removal = config.get('enable_outlier_removal', True)
        self.enable_smoothing = config.get('enable_smoothing', True)
        self.enable_noise_filtering = config.get('enable_noise_filtering', False)
        self.noise_threshold = config.get('noise_threshold', 0.05)
        self.phase_range = config.get('phase_range', (-np.pi, np.pi))

        # Statistics tracking
        self._total_processed = 0
        self._outliers_removed = 0
        self._sanitization_errors = 0

    def _validate_config(self, config: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Args:
            config: Configuration to validate

        Raises:
            ValueError: If required keys are missing or values out of range.
        """
        required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
        missing_fields = [field for field in required_fields if field not in config]

        if missing_fields:
            raise ValueError(f"Missing required configuration: {missing_fields}")

        # Validate unwrapping method
        valid_methods = ['numpy', 'scipy', 'custom']
        if config['unwrapping_method'] not in valid_methods:
            raise ValueError(f"Invalid unwrapping method: {config['unwrapping_method']}. Must be one of {valid_methods}")

        # Validate thresholds
        if config['outlier_threshold'] <= 0:
            raise ValueError("outlier_threshold must be positive")

        if config['smoothing_window'] <= 0:
            raise ValueError("smoothing_window must be positive")

    def unwrap_phase(self, phase_data: np.ndarray) -> np.ndarray:
        """Unwrap phase data to remove 2*pi discontinuities along subcarriers.

        Args:
            phase_data: Wrapped phase data (2D array)

        Returns:
            Unwrapped phase data

        Raises:
            PhaseSanitizationError: If unwrapping fails
        """
        try:
            if self.unwrapping_method == 'numpy':
                return self._unwrap_numpy(phase_data)
            elif self.unwrapping_method == 'scipy':
                return self._unwrap_scipy(phase_data)
            elif self.unwrapping_method == 'custom':
                return self._unwrap_custom(phase_data)
            else:
                raise ValueError(f"Unknown unwrapping method: {self.unwrapping_method}")

        except Exception as e:
            raise PhaseSanitizationError(f"Failed to unwrap phase: {e}") from e

    def _unwrap_numpy(self, phase_data: np.ndarray) -> np.ndarray:
        """Unwrap each row with numpy's vectorized unwrap."""
        if phase_data.size == 0:
            raise ValueError("Cannot unwrap empty phase data")
        return np.unwrap(phase_data, axis=1)

    def _unwrap_scipy(self, phase_data: np.ndarray) -> np.ndarray:
        """Unwrap using the 'scipy' strategy.

        NOTE(review): currently delegates to np.unwrap, identical to the
        'numpy' strategy — kept as a separate hook for a future scipy-based
        implementation.
        """
        if phase_data.size == 0:
            raise ValueError("Cannot unwrap empty phase data")
        return np.unwrap(phase_data, axis=1)

    def _unwrap_custom(self, phase_data: np.ndarray) -> np.ndarray:
        """Unwrap row-by-row (equivalent result; kept as the 'custom' hook)."""
        if phase_data.size == 0:
            raise ValueError("Cannot unwrap empty phase data")
        unwrapped = phase_data.copy()
        for i in range(phase_data.shape[0]):
            unwrapped[i, :] = np.unwrap(phase_data[i, :])
        return unwrapped

    def remove_outliers(self, phase_data: np.ndarray) -> np.ndarray:
        """Remove outliers from phase data via z-score detection + interpolation.

        Args:
            phase_data: Phase data (2D array)

        Returns:
            Phase data with outliers replaced by interpolated values
            (input returned unchanged when outlier removal is disabled)

        Raises:
            PhaseSanitizationError: If outlier removal fails
        """
        if not self.enable_outlier_removal:
            return phase_data

        try:
            # Detect outliers
            outlier_mask = self._detect_outliers(phase_data)

            # Interpolate outliers
            return self._interpolate_outliers(phase_data, outlier_mask)

        except Exception as e:
            raise PhaseSanitizationError(f"Failed to remove outliers: {e}") from e

    def _detect_outliers(self, phase_data: np.ndarray) -> np.ndarray:
        """Flag samples whose per-row z-score exceeds the configured threshold."""
        # Per-row z-scores; epsilon guards a zero standard deviation.
        z_scores = np.abs((phase_data - np.mean(phase_data, axis=1, keepdims=True)) /
                          (np.std(phase_data, axis=1, keepdims=True) + 1e-8))
        outlier_mask = z_scores > self.outlier_threshold

        # Track as a plain int (np.sum returns a numpy scalar).
        self._outliers_removed += int(np.sum(outlier_mask))

        return outlier_mask

    def _interpolate_outliers(self, phase_data: np.ndarray, outlier_mask: np.ndarray) -> np.ndarray:
        """Replace flagged samples with linear interpolation from valid neighbors."""
        clean_data = phase_data.copy()

        for i in range(phase_data.shape[0]):
            outliers = outlier_mask[i, :]
            if np.any(outliers):
                valid_indices = np.where(~outliers)[0]
                outlier_indices = np.where(outliers)[0]

                # Need at least two anchors for linear interpolation; rows
                # with fewer valid samples are left unchanged.
                if len(valid_indices) > 1:
                    clean_data[i, outlier_indices] = np.interp(
                        outlier_indices, valid_indices, phase_data[i, valid_indices]
                    )

        return clean_data

    def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
        """Smooth phase data with a centered moving average.

        Args:
            phase_data: Phase data (2D array)

        Returns:
            Smoothed phase data (input unchanged when smoothing is disabled)

        Raises:
            PhaseSanitizationError: If smoothing fails
        """
        if not self.enable_smoothing:
            return phase_data

        try:
            return self._apply_moving_average(phase_data, self.smoothing_window)

        except Exception as e:
            raise PhaseSanitizationError(f"Failed to smooth phase: {e}") from e

    def _apply_moving_average(self, phase_data: np.ndarray, window_size: int) -> np.ndarray:
        """Apply a centered moving average along the subcarrier axis.

        Edge samples within half a window of either end are left unchanged,
        matching the original per-sample loop.
        """
        smoothed_data = phase_data.copy()

        # Ensure window size is odd so the window is centered on each sample.
        if window_size % 2 == 0:
            window_size += 1

        half_window = window_size // 2
        n_cols = phase_data.shape[1]
        if n_cols < window_size:
            # Too short for any interior sample — nothing to smooth.
            return smoothed_data

        # Vectorized equivalent of the former O(rows*cols*window) Python
        # double loop: 'valid' convolution yields exactly the interior means.
        kernel = np.full(window_size, 1.0 / window_size)
        for i in range(phase_data.shape[0]):
            smoothed_data[i, half_window:n_cols - half_window] = np.convolve(
                phase_data[i, :], kernel, mode='valid'
            )

        return smoothed_data

    def filter_noise(self, phase_data: np.ndarray) -> np.ndarray:
        """Filter high-frequency noise from phase data.

        Disabled by default (enable_noise_filtering config flag).

        Args:
            phase_data: Phase data (2D array)

        Returns:
            Filtered phase data (input unchanged when filtering is disabled)

        Raises:
            PhaseSanitizationError: If noise filtering fails
        """
        if not self.enable_noise_filtering:
            return phase_data

        try:
            return self._apply_low_pass_filter(phase_data, self.noise_threshold)

        except Exception as e:
            raise PhaseSanitizationError(f"Failed to filter noise: {e}") from e

    def _apply_low_pass_filter(self, phase_data: np.ndarray, threshold: float) -> np.ndarray:
        """Apply a zero-phase Butterworth low-pass filter per antenna row."""
        filtered_data = phase_data.copy()

        # filtfilt needs a minimum signal length for an order-4 filter;
        # shorter rows are passed through untouched.
        min_filter_length = 18
        if phase_data.shape[1] < min_filter_length:
            return filtered_data

        # Cutoff expressed as a fraction of the Nyquist frequency.
        nyquist = 0.5
        cutoff = threshold * nyquist

        # Design filter once, apply forward-backward (zero phase) per row.
        b, a = signal.butter(4, cutoff, btype='low')
        for i in range(phase_data.shape[0]):
            filtered_data[i, :] = signal.filtfilt(b, a, phase_data[i, :])

        return filtered_data

    def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray:
        """Sanitize phase data through the complete pipeline.

        Args:
            phase_data: Raw phase data (2D array)

        Returns:
            Sanitized phase data

        Raises:
            PhaseSanitizationError: If validation or any stage fails
        """
        try:
            self._total_processed += 1

            # Validate input data
            self.validate_phase_data(phase_data)

            # Apply complete sanitization pipeline
            sanitized_data = self.unwrap_phase(phase_data)
            sanitized_data = self.remove_outliers(sanitized_data)
            sanitized_data = self.smooth_phase(sanitized_data)
            sanitized_data = self.filter_noise(sanitized_data)

            return sanitized_data

        except PhaseSanitizationError:
            # Already our domain error — count it and propagate unchanged.
            self._sanitization_errors += 1
            raise
        except Exception as e:
            self._sanitization_errors += 1
            raise PhaseSanitizationError(f"Sanitization pipeline failed: {e}") from e

    def validate_phase_data(self, phase_data: np.ndarray) -> bool:
        """Validate phase data format and values.

        Args:
            phase_data: Phase data to validate

        Returns:
            True if valid

        Raises:
            PhaseSanitizationError: If validation fails
        """
        # Check if data is 2D
        if phase_data.ndim != 2:
            raise PhaseSanitizationError("Phase data must be 2D array")

        # Check if data is not empty
        if phase_data.size == 0:
            raise PhaseSanitizationError("Phase data cannot be empty")

        # Check if values are within valid range
        min_val, max_val = self.phase_range
        if np.any(phase_data < min_val) or np.any(phase_data > max_val):
            raise PhaseSanitizationError(f"Phase values outside valid range [{min_val}, {max_val}]")

        return True

    def get_sanitization_statistics(self) -> Dict[str, Any]:
        """Get sanitization statistics.

        Returns:
            Dictionary containing sanitization statistics
        """
        outlier_rate = self._outliers_removed / self._total_processed if self._total_processed > 0 else 0
        error_rate = self._sanitization_errors / self._total_processed if self._total_processed > 0 else 0

        return {
            'total_processed': self._total_processed,
            'outliers_removed': self._outliers_removed,
            'sanitization_errors': self._sanitization_errors,
            'outlier_rate': outlier_rate,
            'error_rate': error_rate
        }

    def reset_statistics(self) -> None:
        """Reset sanitization statistics."""
        self._total_processed = 0
        self._outliers_removed = 0
        self._sanitization_errors = 0
|
||||
340
v1/src/core/router_interface.py
Normal file
340
v1/src/core/router_interface.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Router interface for WiFi CSI data collection
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouterInterface:
|
||||
"""Interface for connecting to WiFi routers and collecting CSI data."""
|
||||
|
||||
    def __init__(
        self,
        router_id: str,
        host: str,
        port: int = 22,
        username: str = "admin",
        password: str = "",
        interface: str = "wlan0",
        mock_mode: bool = False
    ):
        """Initialize router interface.

        Args:
            router_id: Unique identifier for the router
            host: Router IP address or hostname
            port: SSH port for connection
            username: SSH username
            password: SSH password
            interface: WiFi interface name
            mock_mode: Whether to use mock data instead of real connection
        """
        self.router_id = router_id
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.interface = interface
        self.mock_mode = mock_mode

        # Child logger named per router so log lines identify the source.
        self.logger = logging.getLogger(f"{__name__}.{router_id}")

        # Connection state
        self.is_connected = False
        self.connection = None  # placeholder for a future SSH session handle
        self.last_error = None  # message of the most recent failure

        # Data collection state
        self.last_data_time = None  # datetime of the last successful sample
        self.error_count = 0        # reset to 0 on successful connect / sample
        self.sample_count = 0

        # Mock data generation (state dict only populated in mock mode)
        self.mock_data_generator = None
        if mock_mode:
            self._initialize_mock_generator()
|
||||
|
||||
    def _initialize_mock_generator(self):
        """Initialize mock data generator state.

        The dict holds a running phase accumulator plus base amplitude,
        frequency and noise parameters consumed by _generate_mock_csi_data.
        """
        self.mock_data_generator = {
            'phase': 0,
            'amplitude_base': 1.0,
            'frequency': 0.1,
            'noise_level': 0.1
        }
|
||||
|
||||
    async def connect(self):
        """Connect to the router.

        In mock mode this only flips the connected flag. Otherwise the
        connection is currently SIMULATED (a short sleep) — no real SSH
        session is opened yet.

        Raises:
            Exception: Re-raises any error hit while connecting, after
                recording it in last_error / error_count.
        """
        if self.mock_mode:
            self.is_connected = True
            self.logger.info(f"Mock connection established to router {self.router_id}")
            return

        try:
            self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")

            # In a real implementation, this would establish SSH connection
            # For now, we'll simulate the connection
            await asyncio.sleep(0.1)  # Simulate connection delay

            self.is_connected = True
            self.error_count = 0  # successful connect clears the error streak
            self.logger.info(f"Connected to router {self.router_id}")

        except Exception as e:
            self.last_error = str(e)
            self.error_count += 1
            self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
            raise
|
||||
|
||||
    async def disconnect(self):
        """Disconnect from the router.

        Errors are logged rather than raised so teardown never propagates
        an exception to callers.
        """
        try:
            if self.connection:
                # Close SSH connection
                self.connection = None

            self.is_connected = False
            self.logger.info(f"Disconnected from router {self.router_id}")

        except Exception as e:
            self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
|
||||
|
||||
async def reconnect(self):
|
||||
"""Reconnect to the router."""
|
||||
await self.disconnect()
|
||||
await asyncio.sleep(1) # Wait before reconnecting
|
||||
await self.connect()
|
||||
|
||||
async def get_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Get CSI data from the router.
|
||||
|
||||
Returns:
|
||||
CSI data as numpy array, or None if no data available
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RuntimeError(f"Router {self.router_id} is not connected")
|
||||
|
||||
try:
|
||||
if self.mock_mode:
|
||||
csi_data = self._generate_mock_csi_data()
|
||||
else:
|
||||
csi_data = await self._collect_real_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
self.last_data_time = datetime.now()
|
||||
self.sample_count += 1
|
||||
self.error_count = 0
|
||||
|
||||
return csi_data
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
|
||||
return None
|
||||
|
||||
def _generate_mock_csi_data(self) -> np.ndarray:
|
||||
"""Generate mock CSI data for testing."""
|
||||
# Simulate CSI data with realistic characteristics
|
||||
num_subcarriers = 64
|
||||
num_antennas = 4
|
||||
num_samples = 100
|
||||
|
||||
# Update mock generator state
|
||||
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
|
||||
|
||||
# Generate amplitude and phase data
|
||||
time_axis = np.linspace(0, 1, num_samples)
|
||||
|
||||
# Create realistic CSI patterns
|
||||
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
|
||||
|
||||
for antenna in range(num_antennas):
|
||||
for subcarrier in range(num_subcarriers):
|
||||
# Base signal with some variation per antenna/subcarrier
|
||||
amplitude = (
|
||||
self.mock_data_generator['amplitude_base'] *
|
||||
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
|
||||
(1 + 0.1 * antenna)
|
||||
)
|
||||
|
||||
# Phase with spatial and frequency variation
|
||||
phase_offset = (
|
||||
self.mock_data_generator['phase'] +
|
||||
2 * np.pi * subcarrier / num_subcarriers +
|
||||
np.pi * antenna / num_antennas
|
||||
)
|
||||
|
||||
# Add some movement simulation
|
||||
movement_freq = 0.5 # Hz
|
||||
movement_amplitude = 0.3
|
||||
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
|
||||
|
||||
# Generate complex signal
|
||||
signal_amplitude = amplitude * (1 + movement)
|
||||
signal_phase = phase_offset + movement * 0.5
|
||||
|
||||
# Add noise
|
||||
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise = noise_real + 1j * noise_imag
|
||||
|
||||
# Create complex signal
|
||||
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
|
||||
csi_data[antenna, subcarrier, :] = signal
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Collect real CSI data from router (placeholder implementation)."""
|
||||
# This would implement the actual CSI data collection
|
||||
# For now, return None to indicate no real implementation
|
||||
self.logger.warning("Real CSI data collection not implemented")
|
||||
return None
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
"""Check if the router connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# In mock mode, always healthy
|
||||
if self.mock_mode:
|
||||
return True
|
||||
|
||||
# For real connections, we could ping the router or check SSH connection
|
||||
# For now, consider healthy if error count is low
|
||||
return self.error_count < 5
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get router status information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router status
|
||||
"""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode,
|
||||
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
|
||||
"error_count": self.error_count,
|
||||
"sample_count": self.sample_count,
|
||||
"last_error": self.last_error,
|
||||
"configuration": {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"username": self.username,
|
||||
"interface": self.interface
|
||||
}
|
||||
}
|
||||
|
||||
async def get_router_info(self) -> Dict[str, Any]:
|
||||
"""Get router hardware information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router information
|
||||
"""
|
||||
if self.mock_mode:
|
||||
return {
|
||||
"model": "Mock Router",
|
||||
"firmware": "1.0.0-mock",
|
||||
"wifi_standard": "802.11ac",
|
||||
"antennas": 4,
|
||||
"supported_bands": ["2.4GHz", "5GHz"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 4,
|
||||
"sampling_rate": 1000
|
||||
}
|
||||
}
|
||||
|
||||
# For real routers, this would query the actual hardware
|
||||
return {
|
||||
"model": "Unknown",
|
||||
"firmware": "Unknown",
|
||||
"wifi_standard": "Unknown",
|
||||
"antennas": 1,
|
||||
"supported_bands": ["Unknown"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 1,
|
||||
"sampling_rate": 100
|
||||
}
|
||||
}
|
||||
|
||||
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
|
||||
"""Configure CSI data collection parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self.mock_mode:
|
||||
# Update mock generator parameters
|
||||
if 'sampling_rate' in config:
|
||||
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
|
||||
|
||||
if 'noise_level' in config:
|
||||
self.mock_data_generator['noise_level'] = config['noise_level']
|
||||
|
||||
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
|
||||
return True
|
||||
|
||||
# For real routers, this would send configuration commands
|
||||
self.logger.warning("Real CSI configuration not implemented")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get router interface metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing metrics
|
||||
"""
|
||||
uptime = 0
|
||||
if self.last_data_time:
|
||||
uptime = (datetime.now() - self.last_data_time).total_seconds()
|
||||
|
||||
success_rate = 0
|
||||
if self.sample_count > 0:
|
||||
success_rate = (self.sample_count - self.error_count) / self.sample_count
|
||||
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"sample_count": self.sample_count,
|
||||
"error_count": self.error_count,
|
||||
"success_rate": success_rate,
|
||||
"uptime_seconds": uptime,
|
||||
"is_connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset statistics counters."""
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
self.last_error = None
|
||||
self.logger.info(f"Statistics reset for router {self.router_id}")
|
||||
Reference in New Issue
Block a user