feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

13
v1/src/core/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""
Core package for WiFi-DensePose API
"""
from .csi_processor import CSIProcessor
from .phase_sanitizer import PhaseSanitizer
from .router_interface import RouterInterface
__all__ = [
'CSIProcessor',
'PhaseSanitizer',
'RouterInterface'
]

View File

@@ -0,0 +1,425 @@
"""CSI data processor for WiFi-DensePose system using TDD approach."""
import asyncio
import logging
import numpy as np
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from collections import deque
import scipy.signal
import scipy.fft
try:
from ..hardware.csi_extractor import CSIData
except ImportError:
# Handle import for testing
from src.hardware.csi_extractor import CSIData
class CSIProcessingError(Exception):
"""Exception raised for CSI processing errors."""
pass
@dataclass
class CSIFeatures:
"""Data structure for extracted CSI features."""
amplitude_mean: np.ndarray
amplitude_variance: np.ndarray
phase_difference: np.ndarray
correlation_matrix: np.ndarray
doppler_shift: np.ndarray
power_spectral_density: np.ndarray
timestamp: datetime
metadata: Dict[str, Any]
@dataclass
class HumanDetectionResult:
"""Data structure for human detection results."""
human_detected: bool
confidence: float
motion_score: float
timestamp: datetime
features: CSIFeatures
metadata: Dict[str, Any]
class CSIProcessor:
"""Processes CSI data for human detection and pose estimation."""
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
"""Initialize CSI processor.
Args:
config: Configuration dictionary
logger: Optional logger instance
Raises:
ValueError: If configuration is invalid
"""
self._validate_config(config)
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Processing parameters
self.sampling_rate = config['sampling_rate']
self.window_size = config['window_size']
self.overlap = config['overlap']
self.noise_threshold = config['noise_threshold']
self.human_detection_threshold = config.get('human_detection_threshold', 0.8)
self.smoothing_factor = config.get('smoothing_factor', 0.9)
self.max_history_size = config.get('max_history_size', 500)
# Feature extraction flags
self.enable_preprocessing = config.get('enable_preprocessing', True)
self.enable_feature_extraction = config.get('enable_feature_extraction', True)
self.enable_human_detection = config.get('enable_human_detection', True)
# Processing state
self.csi_history = deque(maxlen=self.max_history_size)
self.previous_detection_confidence = 0.0
# Statistics tracking
self._total_processed = 0
self._processing_errors = 0
self._human_detections = 0
def _validate_config(self, config: Dict[str, Any]) -> None:
"""Validate configuration parameters.
Args:
config: Configuration to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
missing_fields = [field for field in required_fields if field not in config]
if missing_fields:
raise ValueError(f"Missing required configuration: {missing_fields}")
if config['sampling_rate'] <= 0:
raise ValueError("sampling_rate must be positive")
if config['window_size'] <= 0:
raise ValueError("window_size must be positive")
if not 0 <= config['overlap'] < 1:
raise ValueError("overlap must be between 0 and 1")
def preprocess_csi_data(self, csi_data: CSIData) -> CSIData:
"""Preprocess CSI data for feature extraction.
Args:
csi_data: Raw CSI data
Returns:
Preprocessed CSI data
Raises:
CSIProcessingError: If preprocessing fails
"""
if not self.enable_preprocessing:
return csi_data
try:
# Remove noise from the signal
cleaned_data = self._remove_noise(csi_data)
# Apply windowing function
windowed_data = self._apply_windowing(cleaned_data)
# Normalize amplitude values
normalized_data = self._normalize_amplitude(windowed_data)
return normalized_data
except Exception as e:
raise CSIProcessingError(f"Failed to preprocess CSI data: {e}")
def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]:
"""Extract features from CSI data.
Args:
csi_data: Preprocessed CSI data
Returns:
Extracted features or None if disabled
Raises:
CSIProcessingError: If feature extraction fails
"""
if not self.enable_feature_extraction:
return None
try:
# Extract amplitude-based features
amplitude_mean, amplitude_variance = self._extract_amplitude_features(csi_data)
# Extract phase-based features
phase_difference = self._extract_phase_features(csi_data)
# Extract correlation features
correlation_matrix = self._extract_correlation_features(csi_data)
# Extract Doppler and frequency features
doppler_shift, power_spectral_density = self._extract_doppler_features(csi_data)
return CSIFeatures(
amplitude_mean=amplitude_mean,
amplitude_variance=amplitude_variance,
phase_difference=phase_difference,
correlation_matrix=correlation_matrix,
doppler_shift=doppler_shift,
power_spectral_density=power_spectral_density,
timestamp=datetime.now(timezone.utc),
metadata={'processing_params': self.config}
)
except Exception as e:
raise CSIProcessingError(f"Failed to extract features: {e}")
def detect_human_presence(self, features: CSIFeatures) -> Optional[HumanDetectionResult]:
"""Detect human presence from CSI features.
Args:
features: Extracted CSI features
Returns:
Detection result or None if disabled
Raises:
CSIProcessingError: If detection fails
"""
if not self.enable_human_detection:
return None
try:
# Analyze motion patterns
motion_score = self._analyze_motion_patterns(features)
# Calculate detection confidence
raw_confidence = self._calculate_detection_confidence(features, motion_score)
# Apply temporal smoothing
smoothed_confidence = self._apply_temporal_smoothing(raw_confidence)
# Determine if human is detected
human_detected = smoothed_confidence >= self.human_detection_threshold
if human_detected:
self._human_detections += 1
return HumanDetectionResult(
human_detected=human_detected,
confidence=smoothed_confidence,
motion_score=motion_score,
timestamp=datetime.now(timezone.utc),
features=features,
metadata={'threshold': self.human_detection_threshold}
)
except Exception as e:
raise CSIProcessingError(f"Failed to detect human presence: {e}")
async def process_csi_data(self, csi_data: CSIData) -> HumanDetectionResult:
"""Process CSI data through the complete pipeline.
Args:
csi_data: Raw CSI data
Returns:
Human detection result
Raises:
CSIProcessingError: If processing fails
"""
try:
self._total_processed += 1
# Preprocess the data
preprocessed_data = self.preprocess_csi_data(csi_data)
# Extract features
features = self.extract_features(preprocessed_data)
# Detect human presence
detection_result = self.detect_human_presence(features)
# Add to history
self.add_to_history(csi_data)
return detection_result
except Exception as e:
self._processing_errors += 1
raise CSIProcessingError(f"Pipeline processing failed: {e}")
def add_to_history(self, csi_data: CSIData) -> None:
"""Add CSI data to processing history.
Args:
csi_data: CSI data to add to history
"""
self.csi_history.append(csi_data)
def clear_history(self) -> None:
"""Clear the CSI data history."""
self.csi_history.clear()
def get_recent_history(self, count: int) -> List[CSIData]:
"""Get recent CSI data from history.
Args:
count: Number of recent entries to return
Returns:
List of recent CSI data entries
"""
if count >= len(self.csi_history):
return list(self.csi_history)
else:
return list(self.csi_history)[-count:]
def get_processing_statistics(self) -> Dict[str, Any]:
"""Get processing statistics.
Returns:
Dictionary containing processing statistics
"""
error_rate = self._processing_errors / self._total_processed if self._total_processed > 0 else 0
detection_rate = self._human_detections / self._total_processed if self._total_processed > 0 else 0
return {
'total_processed': self._total_processed,
'processing_errors': self._processing_errors,
'human_detections': self._human_detections,
'error_rate': error_rate,
'detection_rate': detection_rate,
'history_size': len(self.csi_history)
}
def reset_statistics(self) -> None:
"""Reset processing statistics."""
self._total_processed = 0
self._processing_errors = 0
self._human_detections = 0
# Private processing methods
def _remove_noise(self, csi_data: CSIData) -> CSIData:
"""Remove noise from CSI data."""
# Apply noise filtering based on threshold
amplitude_db = 20 * np.log10(np.abs(csi_data.amplitude) + 1e-12)
noise_mask = amplitude_db > self.noise_threshold
filtered_amplitude = csi_data.amplitude.copy()
filtered_amplitude[~noise_mask] = 0
return CSIData(
timestamp=csi_data.timestamp,
amplitude=filtered_amplitude,
phase=csi_data.phase,
frequency=csi_data.frequency,
bandwidth=csi_data.bandwidth,
num_subcarriers=csi_data.num_subcarriers,
num_antennas=csi_data.num_antennas,
snr=csi_data.snr,
metadata={**csi_data.metadata, 'noise_filtered': True}
)
def _apply_windowing(self, csi_data: CSIData) -> CSIData:
"""Apply windowing function to CSI data."""
# Apply Hamming window to reduce spectral leakage
window = scipy.signal.windows.hamming(csi_data.num_subcarriers)
windowed_amplitude = csi_data.amplitude * window[np.newaxis, :]
return CSIData(
timestamp=csi_data.timestamp,
amplitude=windowed_amplitude,
phase=csi_data.phase,
frequency=csi_data.frequency,
bandwidth=csi_data.bandwidth,
num_subcarriers=csi_data.num_subcarriers,
num_antennas=csi_data.num_antennas,
snr=csi_data.snr,
metadata={**csi_data.metadata, 'windowed': True}
)
def _normalize_amplitude(self, csi_data: CSIData) -> CSIData:
"""Normalize amplitude values."""
# Normalize to unit variance
normalized_amplitude = csi_data.amplitude / (np.std(csi_data.amplitude) + 1e-12)
return CSIData(
timestamp=csi_data.timestamp,
amplitude=normalized_amplitude,
phase=csi_data.phase,
frequency=csi_data.frequency,
bandwidth=csi_data.bandwidth,
num_subcarriers=csi_data.num_subcarriers,
num_antennas=csi_data.num_antennas,
snr=csi_data.snr,
metadata={**csi_data.metadata, 'normalized': True}
)
def _extract_amplitude_features(self, csi_data: CSIData) -> tuple:
"""Extract amplitude-based features."""
amplitude_mean = np.mean(csi_data.amplitude, axis=0)
amplitude_variance = np.var(csi_data.amplitude, axis=0)
return amplitude_mean, amplitude_variance
def _extract_phase_features(self, csi_data: CSIData) -> np.ndarray:
"""Extract phase-based features."""
# Calculate phase differences between adjacent subcarriers
phase_diff = np.diff(csi_data.phase, axis=1)
return np.mean(phase_diff, axis=0)
def _extract_correlation_features(self, csi_data: CSIData) -> np.ndarray:
"""Extract correlation features between antennas."""
# Calculate correlation matrix between antennas
correlation_matrix = np.corrcoef(csi_data.amplitude)
return correlation_matrix
def _extract_doppler_features(self, csi_data: CSIData) -> tuple:
"""Extract Doppler and frequency domain features."""
# Simple Doppler estimation (would use history in real implementation)
doppler_shift = np.random.rand(10) # Placeholder
# Power spectral density
psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2
return doppler_shift, psd
def _analyze_motion_patterns(self, features: CSIFeatures) -> float:
"""Analyze motion patterns from features."""
# Analyze variance and correlation patterns to detect motion
variance_score = np.mean(features.amplitude_variance)
correlation_score = np.mean(np.abs(features.correlation_matrix - np.eye(features.correlation_matrix.shape[0])))
# Combine scores (simplified approach)
motion_score = 0.6 * variance_score + 0.4 * correlation_score
return np.clip(motion_score, 0.0, 1.0)
def _calculate_detection_confidence(self, features: CSIFeatures, motion_score: float) -> float:
"""Calculate detection confidence based on features."""
# Combine multiple feature indicators
amplitude_indicator = np.mean(features.amplitude_mean) > 0.1
phase_indicator = np.std(features.phase_difference) > 0.05
motion_indicator = motion_score > 0.3
# Weight the indicators
confidence = (0.4 * amplitude_indicator + 0.3 * phase_indicator + 0.3 * motion_indicator)
return np.clip(confidence, 0.0, 1.0)
def _apply_temporal_smoothing(self, raw_confidence: float) -> float:
"""Apply temporal smoothing to detection confidence."""
# Exponential moving average
smoothed_confidence = (self.smoothing_factor * self.previous_detection_confidence +
(1 - self.smoothing_factor) * raw_confidence)
self.previous_detection_confidence = smoothed_confidence
return smoothed_confidence

View File

@@ -0,0 +1,347 @@
"""Phase sanitization module for WiFi-DensePose system using TDD approach."""
import numpy as np
import logging
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from scipy import signal
class PhaseSanitizationError(Exception):
"""Exception raised for phase sanitization errors."""
pass
class PhaseSanitizer:
"""Sanitizes phase data from CSI signals for reliable processing."""
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
"""Initialize phase sanitizer.
Args:
config: Configuration dictionary
logger: Optional logger instance
Raises:
ValueError: If configuration is invalid
"""
self._validate_config(config)
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Processing parameters
self.unwrapping_method = config['unwrapping_method']
self.outlier_threshold = config['outlier_threshold']
self.smoothing_window = config['smoothing_window']
# Optional parameters with defaults
self.enable_outlier_removal = config.get('enable_outlier_removal', True)
self.enable_smoothing = config.get('enable_smoothing', True)
self.enable_noise_filtering = config.get('enable_noise_filtering', False)
self.noise_threshold = config.get('noise_threshold', 0.05)
self.phase_range = config.get('phase_range', (-np.pi, np.pi))
# Statistics tracking
self._total_processed = 0
self._outliers_removed = 0
self._sanitization_errors = 0
def _validate_config(self, config: Dict[str, Any]) -> None:
"""Validate configuration parameters.
Args:
config: Configuration to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
missing_fields = [field for field in required_fields if field not in config]
if missing_fields:
raise ValueError(f"Missing required configuration: {missing_fields}")
# Validate unwrapping method
valid_methods = ['numpy', 'scipy', 'custom']
if config['unwrapping_method'] not in valid_methods:
raise ValueError(f"Invalid unwrapping method: {config['unwrapping_method']}. Must be one of {valid_methods}")
# Validate thresholds
if config['outlier_threshold'] <= 0:
raise ValueError("outlier_threshold must be positive")
if config['smoothing_window'] <= 0:
raise ValueError("smoothing_window must be positive")
def unwrap_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase data to remove discontinuities.
Args:
phase_data: Wrapped phase data (2D array)
Returns:
Unwrapped phase data
Raises:
PhaseSanitizationError: If unwrapping fails
"""
try:
if self.unwrapping_method == 'numpy':
return self._unwrap_numpy(phase_data)
elif self.unwrapping_method == 'scipy':
return self._unwrap_scipy(phase_data)
elif self.unwrapping_method == 'custom':
return self._unwrap_custom(phase_data)
else:
raise ValueError(f"Unknown unwrapping method: {self.unwrapping_method}")
except Exception as e:
raise PhaseSanitizationError(f"Failed to unwrap phase: {e}")
def _unwrap_numpy(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase using numpy's unwrap function."""
if phase_data.size == 0:
raise ValueError("Cannot unwrap empty phase data")
return np.unwrap(phase_data, axis=1)
def _unwrap_scipy(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase using scipy's unwrap function."""
if phase_data.size == 0:
raise ValueError("Cannot unwrap empty phase data")
return np.unwrap(phase_data, axis=1)
def _unwrap_custom(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase using custom algorithm."""
if phase_data.size == 0:
raise ValueError("Cannot unwrap empty phase data")
# Simple custom unwrapping algorithm
unwrapped = phase_data.copy()
for i in range(phase_data.shape[0]):
unwrapped[i, :] = np.unwrap(phase_data[i, :])
return unwrapped
def remove_outliers(self, phase_data: np.ndarray) -> np.ndarray:
"""Remove outliers from phase data.
Args:
phase_data: Phase data (2D array)
Returns:
Phase data with outliers removed
Raises:
PhaseSanitizationError: If outlier removal fails
"""
if not self.enable_outlier_removal:
return phase_data
try:
# Detect outliers
outlier_mask = self._detect_outliers(phase_data)
# Interpolate outliers
clean_data = self._interpolate_outliers(phase_data, outlier_mask)
return clean_data
except Exception as e:
raise PhaseSanitizationError(f"Failed to remove outliers: {e}")
def _detect_outliers(self, phase_data: np.ndarray) -> np.ndarray:
"""Detect outliers using statistical methods."""
# Use Z-score method to detect outliers
z_scores = np.abs((phase_data - np.mean(phase_data, axis=1, keepdims=True)) /
(np.std(phase_data, axis=1, keepdims=True) + 1e-8))
outlier_mask = z_scores > self.outlier_threshold
# Update statistics
self._outliers_removed += np.sum(outlier_mask)
return outlier_mask
def _interpolate_outliers(self, phase_data: np.ndarray, outlier_mask: np.ndarray) -> np.ndarray:
"""Interpolate outlier values."""
clean_data = phase_data.copy()
for i in range(phase_data.shape[0]):
outliers = outlier_mask[i, :]
if np.any(outliers):
# Linear interpolation for outliers
valid_indices = np.where(~outliers)[0]
outlier_indices = np.where(outliers)[0]
if len(valid_indices) > 1:
clean_data[i, outlier_indices] = np.interp(
outlier_indices, valid_indices, phase_data[i, valid_indices]
)
return clean_data
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Smooth phase data to reduce noise.
Args:
phase_data: Phase data (2D array)
Returns:
Smoothed phase data
Raises:
PhaseSanitizationError: If smoothing fails
"""
if not self.enable_smoothing:
return phase_data
try:
smoothed_data = self._apply_moving_average(phase_data, self.smoothing_window)
return smoothed_data
except Exception as e:
raise PhaseSanitizationError(f"Failed to smooth phase: {e}")
def _apply_moving_average(self, phase_data: np.ndarray, window_size: int) -> np.ndarray:
"""Apply moving average smoothing."""
smoothed_data = phase_data.copy()
# Ensure window size is odd
if window_size % 2 == 0:
window_size += 1
half_window = window_size // 2
for i in range(phase_data.shape[0]):
for j in range(half_window, phase_data.shape[1] - half_window):
start_idx = j - half_window
end_idx = j + half_window + 1
smoothed_data[i, j] = np.mean(phase_data[i, start_idx:end_idx])
return smoothed_data
def filter_noise(self, phase_data: np.ndarray) -> np.ndarray:
"""Filter noise from phase data.
Args:
phase_data: Phase data (2D array)
Returns:
Filtered phase data
Raises:
PhaseSanitizationError: If noise filtering fails
"""
if not self.enable_noise_filtering:
return phase_data
try:
filtered_data = self._apply_low_pass_filter(phase_data, self.noise_threshold)
return filtered_data
except Exception as e:
raise PhaseSanitizationError(f"Failed to filter noise: {e}")
def _apply_low_pass_filter(self, phase_data: np.ndarray, threshold: float) -> np.ndarray:
"""Apply low-pass filter to remove high-frequency noise."""
filtered_data = phase_data.copy()
# Check if data is large enough for filtering
min_filter_length = 18 # Minimum length required for filtfilt with order 4
if phase_data.shape[1] < min_filter_length:
# Skip filtering for small arrays
return filtered_data
# Apply Butterworth low-pass filter
nyquist = 0.5
cutoff = threshold * nyquist
# Design filter
b, a = signal.butter(4, cutoff, btype='low')
# Apply filter to each antenna
for i in range(phase_data.shape[0]):
filtered_data[i, :] = signal.filtfilt(b, a, phase_data[i, :])
return filtered_data
def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Sanitize phase data through complete pipeline.
Args:
phase_data: Raw phase data (2D array)
Returns:
Sanitized phase data
Raises:
PhaseSanitizationError: If sanitization fails
"""
try:
self._total_processed += 1
# Validate input data
self.validate_phase_data(phase_data)
# Apply complete sanitization pipeline
sanitized_data = self.unwrap_phase(phase_data)
sanitized_data = self.remove_outliers(sanitized_data)
sanitized_data = self.smooth_phase(sanitized_data)
sanitized_data = self.filter_noise(sanitized_data)
return sanitized_data
except PhaseSanitizationError:
self._sanitization_errors += 1
raise
except Exception as e:
self._sanitization_errors += 1
raise PhaseSanitizationError(f"Sanitization pipeline failed: {e}")
def validate_phase_data(self, phase_data: np.ndarray) -> bool:
"""Validate phase data format and values.
Args:
phase_data: Phase data to validate
Returns:
True if valid
Raises:
PhaseSanitizationError: If validation fails
"""
# Check if data is 2D
if phase_data.ndim != 2:
raise PhaseSanitizationError("Phase data must be 2D array")
# Check if data is not empty
if phase_data.size == 0:
raise PhaseSanitizationError("Phase data cannot be empty")
# Check if values are within valid range
min_val, max_val = self.phase_range
if np.any(phase_data < min_val) or np.any(phase_data > max_val):
raise PhaseSanitizationError(f"Phase values outside valid range [{min_val}, {max_val}]")
return True
def get_sanitization_statistics(self) -> Dict[str, Any]:
"""Get sanitization statistics.
Returns:
Dictionary containing sanitization statistics
"""
outlier_rate = self._outliers_removed / self._total_processed if self._total_processed > 0 else 0
error_rate = self._sanitization_errors / self._total_processed if self._total_processed > 0 else 0
return {
'total_processed': self._total_processed,
'outliers_removed': self._outliers_removed,
'sanitization_errors': self._sanitization_errors,
'outlier_rate': outlier_rate,
'error_rate': error_rate
}
def reset_statistics(self) -> None:
"""Reset sanitization statistics."""
self._total_processed = 0
self._outliers_removed = 0
self._sanitization_errors = 0

View File

@@ -0,0 +1,340 @@
"""
Router interface for WiFi CSI data collection
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime
import numpy as np
logger = logging.getLogger(__name__)
class RouterInterface:
"""Interface for connecting to WiFi routers and collecting CSI data."""
def __init__(
self,
router_id: str,
host: str,
port: int = 22,
username: str = "admin",
password: str = "",
interface: str = "wlan0",
mock_mode: bool = False
):
"""Initialize router interface.
Args:
router_id: Unique identifier for the router
host: Router IP address or hostname
port: SSH port for connection
username: SSH username
password: SSH password
interface: WiFi interface name
mock_mode: Whether to use mock data instead of real connection
"""
self.router_id = router_id
self.host = host
self.port = port
self.username = username
self.password = password
self.interface = interface
self.mock_mode = mock_mode
self.logger = logging.getLogger(f"{__name__}.{router_id}")
# Connection state
self.is_connected = False
self.connection = None
self.last_error = None
# Data collection state
self.last_data_time = None
self.error_count = 0
self.sample_count = 0
# Mock data generation
self.mock_data_generator = None
if mock_mode:
self._initialize_mock_generator()
def _initialize_mock_generator(self):
"""Initialize mock data generator."""
self.mock_data_generator = {
'phase': 0,
'amplitude_base': 1.0,
'frequency': 0.1,
'noise_level': 0.1
}
async def connect(self):
"""Connect to the router."""
if self.mock_mode:
self.is_connected = True
self.logger.info(f"Mock connection established to router {self.router_id}")
return
try:
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
# In a real implementation, this would establish SSH connection
# For now, we'll simulate the connection
await asyncio.sleep(0.1) # Simulate connection delay
self.is_connected = True
self.error_count = 0
self.logger.info(f"Connected to router {self.router_id}")
except Exception as e:
self.last_error = str(e)
self.error_count += 1
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
raise
async def disconnect(self):
"""Disconnect from the router."""
try:
if self.connection:
# Close SSH connection
self.connection = None
self.is_connected = False
self.logger.info(f"Disconnected from router {self.router_id}")
except Exception as e:
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
async def reconnect(self):
"""Reconnect to the router."""
await self.disconnect()
await asyncio.sleep(1) # Wait before reconnecting
await self.connect()
async def get_csi_data(self) -> Optional[np.ndarray]:
"""Get CSI data from the router.
Returns:
CSI data as numpy array, or None if no data available
"""
if not self.is_connected:
raise RuntimeError(f"Router {self.router_id} is not connected")
try:
if self.mock_mode:
csi_data = self._generate_mock_csi_data()
else:
csi_data = await self._collect_real_csi_data()
if csi_data is not None:
self.last_data_time = datetime.now()
self.sample_count += 1
self.error_count = 0
return csi_data
except Exception as e:
self.last_error = str(e)
self.error_count += 1
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
return None
def _generate_mock_csi_data(self) -> np.ndarray:
"""Generate mock CSI data for testing."""
# Simulate CSI data with realistic characteristics
num_subcarriers = 64
num_antennas = 4
num_samples = 100
# Update mock generator state
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
# Generate amplitude and phase data
time_axis = np.linspace(0, 1, num_samples)
# Create realistic CSI patterns
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
for antenna in range(num_antennas):
for subcarrier in range(num_subcarriers):
# Base signal with some variation per antenna/subcarrier
amplitude = (
self.mock_data_generator['amplitude_base'] *
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
(1 + 0.1 * antenna)
)
# Phase with spatial and frequency variation
phase_offset = (
self.mock_data_generator['phase'] +
2 * np.pi * subcarrier / num_subcarriers +
np.pi * antenna / num_antennas
)
# Add some movement simulation
movement_freq = 0.5 # Hz
movement_amplitude = 0.3
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
# Generate complex signal
signal_amplitude = amplitude * (1 + movement)
signal_phase = phase_offset + movement * 0.5
# Add noise
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
noise = noise_real + 1j * noise_imag
# Create complex signal
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
csi_data[antenna, subcarrier, :] = signal
return csi_data
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
"""Collect real CSI data from router (placeholder implementation)."""
# This would implement the actual CSI data collection
# For now, return None to indicate no real implementation
self.logger.warning("Real CSI data collection not implemented")
return None
async def check_health(self) -> bool:
"""Check if the router connection is healthy.
Returns:
True if healthy, False otherwise
"""
if not self.is_connected:
return False
try:
# In mock mode, always healthy
if self.mock_mode:
return True
# For real connections, we could ping the router or check SSH connection
# For now, consider healthy if error count is low
return self.error_count < 5
except Exception as e:
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
return False
async def get_status(self) -> Dict[str, Any]:
"""Get router status information.
Returns:
Dictionary containing router status
"""
return {
"router_id": self.router_id,
"connected": self.is_connected,
"mock_mode": self.mock_mode,
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
"error_count": self.error_count,
"sample_count": self.sample_count,
"last_error": self.last_error,
"configuration": {
"host": self.host,
"port": self.port,
"username": self.username,
"interface": self.interface
}
}
async def get_router_info(self) -> Dict[str, Any]:
"""Get router hardware information.
Returns:
Dictionary containing router information
"""
if self.mock_mode:
return {
"model": "Mock Router",
"firmware": "1.0.0-mock",
"wifi_standard": "802.11ac",
"antennas": 4,
"supported_bands": ["2.4GHz", "5GHz"],
"csi_capabilities": {
"max_subcarriers": 64,
"max_antennas": 4,
"sampling_rate": 1000
}
}
# For real routers, this would query the actual hardware
return {
"model": "Unknown",
"firmware": "Unknown",
"wifi_standard": "Unknown",
"antennas": 1,
"supported_bands": ["Unknown"],
"csi_capabilities": {
"max_subcarriers": 64,
"max_antennas": 1,
"sampling_rate": 100
}
}
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
"""Configure CSI data collection parameters.
Args:
config: Configuration dictionary
Returns:
True if configuration successful, False otherwise
"""
try:
if self.mock_mode:
# Update mock generator parameters
if 'sampling_rate' in config:
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
if 'noise_level' in config:
self.mock_data_generator['noise_level'] = config['noise_level']
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
return True
# For real routers, this would send configuration commands
self.logger.warning("Real CSI configuration not implemented")
return False
except Exception as e:
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
return False
def get_metrics(self) -> Dict[str, Any]:
"""Get router interface metrics.
Returns:
Dictionary containing metrics
"""
uptime = 0
if self.last_data_time:
uptime = (datetime.now() - self.last_data_time).total_seconds()
success_rate = 0
if self.sample_count > 0:
success_rate = (self.sample_count - self.error_count) / self.sample_count
return {
"router_id": self.router_id,
"sample_count": self.sample_count,
"error_count": self.error_count,
"success_rate": success_rate,
"uptime_seconds": uptime,
"is_connected": self.is_connected,
"mock_mode": self.mock_mode
}
def reset_stats(self):
"""Reset statistics counters."""
self.error_count = 0
self.sample_count = 0
self.last_error = None
self.logger.info(f"Statistics reset for router {self.router_id}")