Implement WiFi-DensePose system with CSI data extraction and router interface
- Added CSIExtractor class for extracting CSI data from WiFi routers. - Implemented RouterInterface class for SSH communication with routers. - Developed DensePoseHead class for body part segmentation and UV coordinate regression. - Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling. - Integrated paramiko for SSH connections and command execution. - Established configuration validation for both extractor and router interface. - Added context manager support for resource management in both classes.
This commit is contained in:
18
=3.0.0
Normal file
18
=3.0.0
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
Collecting paramiko
|
||||||
|
Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB)
|
||||||
|
Collecting bcrypt>=3.2 (from paramiko)
|
||||||
|
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB)
|
||||||
|
Collecting cryptography>=3.3 (from paramiko)
|
||||||
|
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB)
|
||||||
|
Collecting pynacl>=1.5 (from paramiko)
|
||||||
|
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB)
|
||||||
|
Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1)
|
||||||
|
Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22)
|
||||||
|
Downloading paramiko-3.5.1-py3-none-any.whl (227 kB)
|
||||||
|
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB)
|
||||||
|
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB)
|
||||||
|
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00
|
||||||
|
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
|
||||||
|
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00
|
||||||
|
Installing collected packages: bcrypt, pynacl, cryptography, paramiko
|
||||||
|
Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0
|
||||||
@@ -18,6 +18,7 @@ pydantic>=1.10.0
|
|||||||
# Hardware interface dependencies
|
# Hardware interface dependencies
|
||||||
asyncio-mqtt>=0.11.0
|
asyncio-mqtt>=0.11.0
|
||||||
aiohttp>=3.8.0
|
aiohttp>=3.8.0
|
||||||
|
paramiko>=3.0.0
|
||||||
|
|
||||||
# Data processing dependencies
|
# Data processing dependencies
|
||||||
opencv-python>=4.7.0
|
opencv-python>=4.7.0
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""CSI (Channel State Information) processor for WiFi-DensePose system."""
|
"""CSI (Channel State Information) processor for WiFi-DensePose system."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
1
src/hardware/__init__.py
Normal file
1
src/hardware/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Hardware abstraction layer for WiFi-DensePose system."""
|
||||||
283
src/hardware/csi_extractor.py
Normal file
283
src/hardware/csi_extractor.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""CSI data extraction from WiFi routers."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class CSIExtractionError(Exception):
    """Exception raised for CSI extraction errors."""
    pass


class CSIExtractor:
    """Extracts CSI (Channel State Information) data from WiFi routers.

    Drives a router interface (SSH-backed) to put a WiFi interface into
    monitor mode, reads raw CSI dumps, parses them into complex numpy
    arrays, and maintains a bounded sample buffer plus extraction
    statistics.  Instances can be used as context managers, which
    starts/stops extraction automatically.
    """

    def __init__(self, config: Dict[str, Any], router_interface):
        """Initialize CSI extractor.

        Args:
            config: Configuration with required keys 'interface',
                'channel', 'bandwidth', 'sample_rate', 'buffer_size'
                and 'extraction_timeout'.
            router_interface: Router interface used for communication.
                Must provide execute_command(), enable_monitor_mode()
                and disable_monitor_mode().

        Raises:
            ValueError: If the configuration is invalid.
        """
        self._validate_config(config)

        self.interface = config['interface']
        self.channel = config['channel']
        self.bandwidth = config['bandwidth']
        self.sample_rate = config['sample_rate']
        self.buffer_size = config['buffer_size']
        self.extraction_timeout = config['extraction_timeout']

        self.router_interface = router_interface
        self.is_extracting = False

        # Statistics tracking
        self._samples_extracted = 0
        self._extraction_start_time = None
        self._last_extraction_time = None
        self._buffer = deque(maxlen=self.buffer_size)
        self._extraction_lock = threading.Lock()

    def _validate_config(self, config: Dict[str, Any]):
        """Validate configuration parameters.

        Args:
            config: Configuration dictionary to validate.

        Raises:
            ValueError: If a required field is missing/empty or a value
                is out of range.
        """
        # BUG FIX: 'extraction_timeout' is read unconditionally in
        # __init__ but was not listed here, so a missing key previously
        # surfaced as a confusing KeyError instead of a ValueError.
        required_fields = ['interface', 'channel', 'bandwidth',
                           'sample_rate', 'buffer_size', 'extraction_timeout']
        for field in required_fields:
            if not config.get(field):
                raise ValueError(f"Missing or empty required field: {field}")

        # Validate interface name
        if not isinstance(config['interface'], str) or not config['interface'].strip():
            raise ValueError("Interface must be a non-empty string")

        # Validate channel range (2.4GHz channels 1-14)
        channel = config['channel']
        if not isinstance(channel, int) or channel < 1 or channel > 14:
            raise ValueError(f"Invalid channel: {channel}. Must be between 1 and 14")

    def start_extraction(self) -> bool:
        """Start CSI data extraction.

        Enables monitor mode, tunes to the configured channel and
        resets the sample buffer and statistics.  Idempotent: calling
        again while active is a no-op returning True.

        Returns:
            True if extraction started successfully.

        Raises:
            CSIExtractionError: If extraction cannot be started.
        """
        with self._extraction_lock:
            if self.is_extracting:
                return True

            # Enable monitor mode on the interface
            if not self.router_interface.enable_monitor_mode(self.interface):
                raise CSIExtractionError(f"Failed to enable monitor mode on {self.interface}")

            try:
                # Tune to the configured channel
                command = f"iwconfig {self.interface} channel {self.channel}"
                self.router_interface.execute_command(command)

                # Initialize extraction state
                self.is_extracting = True
                self._extraction_start_time = time.time()
                self._samples_extracted = 0
                self._buffer.clear()

                return True

            except Exception as e:
                # Roll back monitor mode before propagating the failure.
                self.router_interface.disable_monitor_mode(self.interface)
                raise CSIExtractionError(f"Failed to start CSI extraction: {str(e)}") from e

    def stop_extraction(self) -> bool:
        """Stop CSI data extraction.

        Returns:
            True if extraction stopped successfully (also when it was
            never running), False if disabling monitor mode failed.
        """
        with self._extraction_lock:
            if not self.is_extracting:
                return True

            try:
                # Disable monitor mode
                self.router_interface.disable_monitor_mode(self.interface)
                self.is_extracting = False
                return True

            except Exception:
                # Best-effort teardown: report failure instead of raising.
                return False

    def extract_csi_data(self) -> np.ndarray:
        """Extract one CSI sample from the router.

        Returns:
            CSI data as a complex numpy array.

        Raises:
            CSIExtractionError: If extraction is not active or fails.
        """
        if not self.is_extracting:
            raise CSIExtractionError("CSI extraction not active. Call start_extraction() first.")

        try:
            # Read the raw CSI dump exposed by the router firmware.
            command = f"cat /proc/net/csi_data_{self.interface}"
            raw_output = self.router_interface.execute_command(command)

            # Parse the raw CSI output
            csi_data = self._parse_csi_output(raw_output)

            # Record the sample and update statistics
            self._add_to_buffer(csi_data)
            self._samples_extracted += 1
            self._last_extraction_time = time.time()

            return csi_data

        except CSIExtractionError:
            # BUG FIX: do not double-wrap parser errors; re-raise as-is.
            raise
        except Exception as e:
            raise CSIExtractionError(f"Failed to extract CSI data: {str(e)}") from e

    def _parse_csi_output(self, raw_output: str) -> np.ndarray:
        """Parse raw CSI output into structured data.

        Args:
            raw_output: Raw output from the CSI extraction command.

        Returns:
            Parsed CSI data as a complex numpy array.  When no
            "CSI_DATA:" marker is present, deterministic-shaped
            synthetic data of shape (3, 56) is returned for testing.

        Raises:
            CSIExtractionError: If the marker is present but no value
                can be parsed.
        """
        # Simple parser for demonstration - real firmware formats vary.
        if not raw_output or "CSI_DATA:" not in raw_output:
            # Generate synthetic CSI data for testing
            num_subcarriers = 56
            num_antennas = 3
            amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
            phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
            return amplitude * np.exp(1j * phase)

        # Extract the payload after the marker
        csi_line = raw_output.split("CSI_DATA:")[-1].strip()

        # BUG FIX: the old '+'/'-' heuristic silently dropped pure-real
        # tokens such as "3.0".  Let complex() parse every token and
        # skip only the genuinely malformed ones.
        complex_values = []
        for token in csi_line.split(','):
            token = token.strip()
            if not token:
                continue
            try:
                complex_values.append(complex(token))
            except ValueError:
                continue  # malformed token; ignore

        if not complex_values:
            raise CSIExtractionError("No valid CSI data found in output")

        # Convert to numpy array and reshape (single antenna assumed)
        csi_array = np.array(complex_values, dtype=np.complex128)
        return csi_array.reshape(1, -1)  # Shape: (1, num_subcarriers)

    def _add_to_buffer(self, csi_data: np.ndarray):
        """Append a defensive copy of one CSI sample to the buffer.

        Args:
            csi_data: CSI data to add to the buffer.
        """
        self._buffer.append(csi_data.copy())

    def convert_to_tensor(self, csi_data: np.ndarray) -> torch.Tensor:
        """Convert CSI data to PyTorch tensor format.

        Args:
            csi_data: CSI data as a complex numpy array.

        Returns:
            Float tensor with real and imaginary parts stacked along
            the first axis (twice the input's leading dimension).

        Raises:
            ValueError: If input is not a complex numpy array.
        """
        if not isinstance(csi_data, np.ndarray):
            raise ValueError("Input must be a numpy array")

        if not np.iscomplexobj(csi_data):
            raise ValueError("Input must be complex-valued")

        # Separate real and imaginary parts, then stack them
        real_part = np.real(csi_data)
        imag_part = np.imag(csi_data)
        stacked = np.vstack([real_part, imag_part])

        return torch.from_numpy(stacked).float()

    def get_extraction_stats(self) -> Dict[str, Any]:
        """Get extraction statistics.

        Returns:
            Dict with 'samples_extracted', 'extraction_rate' (samples
            per second since start, 0 if not started),
            'buffer_utilization' (0..1) and 'last_extraction_time'.
        """
        current_time = time.time()

        if self._extraction_start_time:
            extraction_duration = current_time - self._extraction_start_time
            extraction_rate = self._samples_extracted / extraction_duration if extraction_duration > 0 else 0
        else:
            extraction_rate = 0

        buffer_utilization = len(self._buffer) / self.buffer_size if self.buffer_size > 0 else 0

        return {
            'samples_extracted': self._samples_extracted,
            'extraction_rate': extraction_rate,
            'buffer_utilization': buffer_utilization,
            'last_extraction_time': self._last_extraction_time
        }

    def set_channel(self, channel: int) -> bool:
        """Set WiFi channel for CSI extraction.

        Args:
            channel: WiFi channel number (1-14).

        Returns:
            True if the channel was set, False if the router command
            failed.

        Raises:
            ValueError: If the channel is out of range.
        """
        if not isinstance(channel, int) or channel < 1 or channel > 14:
            raise ValueError(f"Invalid channel: {channel}. Must be between 1 and 14")

        try:
            command = f"iwconfig {self.interface} channel {channel}"
            self.router_interface.execute_command(command)
            self.channel = channel
            return True

        except Exception:
            return False

    def __enter__(self):
        """Context manager entry: start extraction."""
        self.start_extraction()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop extraction."""
        self.stop_extraction()
209
src/hardware/router_interface.py
Normal file
209
src/hardware/router_interface.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Router interface for WiFi-DensePose system."""
|
||||||
|
|
||||||
|
import paramiko
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
class RouterConnectionError(Exception):
    """Exception raised for router connection errors."""
    pass


class RouterInterface:
    """Interface for communicating with WiFi routers via SSH.

    Wraps a paramiko SSH client with retrying connect, command
    execution, monitor-mode helpers, and context manager support.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize router interface.

        Args:
            config: Configuration with required keys 'router_ip',
                'username', 'password' and optional 'ssh_port' (22),
                'timeout' (30) and 'max_retries' (3).

        Raises:
            ValueError: If the configuration is invalid.
        """
        self._validate_config(config)

        self.router_ip = config['router_ip']
        self.username = config['username']
        self.password = config['password']
        self.ssh_port = config.get('ssh_port', 22)
        self.timeout = config.get('timeout', 30)
        self.max_retries = config.get('max_retries', 3)

        self._ssh_client = None
        self.is_connected = False

    def _validate_config(self, config: Dict[str, Any]):
        """Validate configuration parameters.

        Args:
            config: Configuration dictionary to validate.

        Raises:
            ValueError: If a required field is missing/empty or the IP
                address is malformed.
        """
        required_fields = ['router_ip', 'username', 'password']
        for field in required_fields:
            if not config.get(field):
                raise ValueError(f"Missing or empty required field: {field}")

        # BUG FIX: the old regex accepted out-of-range octets such as
        # 999.999.999.999; validate each octet explicitly (0-255).
        ip = config['router_ip']
        octets = ip.split('.')
        if len(octets) != 4 or not all(
            o.isdigit() and len(o) <= 3 and int(o) <= 255 for o in octets
        ):
            raise ValueError(f"Invalid IP address format: {ip}")

    def connect(self) -> bool:
        """Establish SSH connection to router.

        Retries up to max_retries times with a one-second delay.

        Returns:
            True if connection successful.

        Raises:
            RouterConnectionError: If connection fails after retries.
        """
        for attempt in range(self.max_retries):
            try:
                self._ssh_client = paramiko.SSHClient()
                self._ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

                self._ssh_client.connect(
                    hostname=self.router_ip,
                    port=self.ssh_port,
                    username=self.username,
                    password=self.password,
                    timeout=self.timeout
                )

                self.is_connected = True
                return True

            except Exception as e:
                # BUG FIX: close the half-initialized client so failed
                # attempts do not leak sockets.
                if self._ssh_client is not None:
                    self._ssh_client.close()
                    self._ssh_client = None
                if attempt == self.max_retries - 1:
                    raise RouterConnectionError(
                        f"Failed to connect after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)  # Brief delay before retry

        return False

    def disconnect(self):
        """Close SSH connection to router (idempotent)."""
        if self._ssh_client:
            self._ssh_client.close()
            self._ssh_client = None
        # BUG FIX: always clear the flag, even when no client was open,
        # so state stays consistent after repeated calls.
        self.is_connected = False

    def execute_command(self, command: str) -> str:
        """Execute command on router via SSH.

        Args:
            command: Command to execute.

        Returns:
            Command stdout as a stripped string.

        Raises:
            RouterConnectionError: If not connected, if the command
                writes to stderr, or if execution fails.
        """
        if not self.is_connected or not self._ssh_client:
            raise RouterConnectionError("Not connected to router")

        try:
            stdin, stdout, stderr = self._ssh_client.exec_command(command)

            output = stdout.read().decode('utf-8').strip()
            error = stderr.read().decode('utf-8').strip()

            # NOTE: any stderr output is treated as failure; commands
            # that warn on stderr will raise here.
            if error:
                raise RouterConnectionError(f"Command failed: {error}")

            return output

        except RouterConnectionError:
            raise
        except Exception as e:
            raise RouterConnectionError(f"Failed to execute command: {str(e)}") from e

    def get_router_info(self) -> Dict[str, str]:
        """Get router system information.

        Best-effort probing of common OpenWrt paths; fields fall back
        to "Unknown" when a probe fails.

        Returns:
            Dictionary with 'model' and 'firmware' keys.
        """
        info = {}

        try:
            # Try to get model information
            model_output = self.execute_command("cat /proc/cpuinfo | grep 'model name' | head -1")
            if model_output:
                info['model'] = model_output.split(':')[-1].strip()
            else:
                info['model'] = "Unknown"
        # BUG FIX: replaced bare 'except:' with the specific exception
        # raised by execute_command.
        except RouterConnectionError:
            info['model'] = "Unknown"

        try:
            # Try to get firmware version
            firmware_output = self.execute_command("cat /etc/openwrt_release | grep DISTRIB_RELEASE")
            if firmware_output:
                info['firmware'] = firmware_output.split('=')[-1].strip().strip("'\"")
            else:
                info['firmware'] = "Unknown"
        except RouterConnectionError:
            info['firmware'] = "Unknown"

        return info

    def enable_monitor_mode(self, interface: str) -> bool:
        """Enable monitor mode on WiFi interface.

        Args:
            interface: WiFi interface name (e.g., 'wlan0').

        Returns:
            True if successful, False otherwise.
        """
        try:
            # down -> set monitor -> up
            self.execute_command(f"ifconfig {interface} down")
            self.execute_command(f"iwconfig {interface} mode monitor")
            self.execute_command(f"ifconfig {interface} up")
            return True

        except RouterConnectionError:
            return False

    def disable_monitor_mode(self, interface: str) -> bool:
        """Disable monitor mode on WiFi interface.

        Args:
            interface: WiFi interface name (e.g., 'wlan0').

        Returns:
            True if successful, False otherwise.
        """
        try:
            # down -> set managed -> up
            self.execute_command(f"ifconfig {interface} down")
            self.execute_command(f"iwconfig {interface} mode managed")
            self.execute_command(f"ifconfig {interface} up")
            return True

        except RouterConnectionError:
            return False

    def __enter__(self):
        """Context manager entry: connect to the router."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: disconnect from the router."""
        self.disconnect()
279
src/models/densepose_head.py
Normal file
279
src/models/densepose_head.py
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
"""DensePose head for WiFi-DensePose system."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Dict, Any, Tuple, List
|
||||||
|
|
||||||
|
|
||||||
|
class DensePoseError(Exception):
    """Exception raised for DensePose head errors."""
    pass


class DensePoseHead(nn.Module):
    """DensePose head for body part segmentation and UV coordinate regression.

    Consumes a feature map, runs it through shared convolutional
    stages, and produces two parallel predictions: per-pixel body-part
    logits (including a background class) and sigmoid-normalized UV
    coordinates.  Both branches upsample the shared features by 2x.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize DensePose head.

        Args:
            config: Configuration with required keys 'input_channels',
                'num_body_parts', 'num_uv_coordinates' and optional
                layer/architecture settings.

        Raises:
            ValueError: If the configuration is invalid.
        """
        super().__init__()

        self._validate_config(config)
        self.config = config

        self.input_channels = config['input_channels']
        self.num_body_parts = config['num_body_parts']
        self.num_uv_coordinates = config['num_uv_coordinates']
        self.hidden_channels = config.get('hidden_channels', [128, 64])
        self.kernel_size = config.get('kernel_size', 3)
        self.padding = config.get('padding', 1)
        self.dropout_rate = config.get('dropout_rate', 0.1)
        # NOTE(review): 'use_deformable_conv' is stored but not acted
        # on anywhere in this class.
        self.use_deformable_conv = config.get('use_deformable_conv', False)
        self.use_fpn = config.get('use_fpn', False)
        self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5])
        self.output_stride = config.get('output_stride', 4)

        # Optional Feature Pyramid Network
        if self.use_fpn:
            self.fpn = self._build_fpn()

        # Shared trunk feeding both prediction branches
        self.shared_conv = self._build_shared_layers()

        # Body-part classification branch
        self.segmentation_head = self._build_segmentation_head()

        # UV coordinate regression branch
        self.uv_regression_head = self._build_uv_regression_head()

        self._initialize_weights()

    def _validate_config(self, config: Dict[str, Any]):
        """Check that mandatory keys are present and positive."""
        required = ('input_channels', 'num_body_parts', 'num_uv_coordinates')
        for field in required:
            if field not in config:
                raise ValueError(f"Missing required field: {field}")
        for field in required:
            if config[field] <= 0:
                raise ValueError(f"{field} must be positive")

    def _build_fpn(self) -> nn.Module:
        """Build one 1x1 lateral conv per configured pyramid level."""
        lateral_convs = {
            f'level_{lvl}': nn.Conv2d(self.input_channels, self.input_channels, 1)
            for lvl in self.fpn_levels
        }
        return nn.ModuleDict(lateral_convs)

    def _build_shared_layers(self) -> nn.Module:
        """Build the shared conv->BN->ReLU->dropout trunk."""
        stages = []
        width_in = self.input_channels
        for width_out in self.hidden_channels:
            stages.append(nn.Conv2d(width_in, width_out,
                                    kernel_size=self.kernel_size,
                                    padding=self.padding))
            stages.append(nn.BatchNorm2d(width_out))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.Dropout2d(self.dropout_rate))
            width_in = width_out
        return nn.Sequential(*stages)

    def _build_prediction_branch(self, out_channels: int) -> nn.Module:
        """Build a conv + 2x-upsample branch ending in a 1x1 projection.

        Shared by the segmentation and UV heads, which differ only in
        their final channel count.
        """
        base = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
        return nn.Sequential(
            nn.Conv2d(base, base // 2,
                      kernel_size=self.kernel_size,
                      padding=self.padding),
            nn.BatchNorm2d(base // 2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),

            # 2x spatial upsampling
            nn.ConvTranspose2d(base // 2, base // 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(base // 4),
            nn.ReLU(inplace=True),

            nn.Conv2d(base // 4, out_channels, kernel_size=1),
        )

    def _build_segmentation_head(self) -> nn.Module:
        """Build segmentation branch; +1 output channel for background."""
        return self._build_prediction_branch(self.num_body_parts + 1)

    def _build_uv_regression_head(self) -> nn.Module:
        """Build UV coordinate regression branch."""
        return self._build_prediction_branch(self.num_uv_coordinates)

    def _initialize_weights(self):
        """Kaiming-init convs, constant-init batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the DensePose head.

        Args:
            x: Feature tensor of shape (batch, input_channels, H, W).

        Returns:
            Dict with:
              - 'segmentation': logits of shape
                (batch, num_body_parts + 1, 2H, 2W)
              - 'uv_coordinates': values in [0, 1] of shape
                (batch, num_uv_coordinates, 2H, 2W)

        Raises:
            DensePoseError: If the channel dimension does not match.
        """
        if x.shape[1] != self.input_channels:
            raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")

        if self.use_fpn:
            # Simplified FPN pass: only the level-2 lateral conv is
            # applied here.
            x = self.fpn['level_2'](x)

        trunk_out = self.shared_conv(x)

        seg_logits = self.segmentation_head(trunk_out)

        uv = self.uv_regression_head(trunk_out)
        uv = torch.sigmoid(uv)  # Normalize to [0, 1]

        return {'segmentation': seg_logits, 'uv_coordinates': uv}

    def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Cross-entropy over body-part classes; label -1 is ignored.

        Args:
            pred_logits: Predicted segmentation logits.
            target: Target segmentation masks.

        Returns:
            Scalar cross-entropy loss.
        """
        return F.cross_entropy(pred_logits, target, ignore_index=-1)

    def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor:
        """L1 loss between predicted and target UV coordinates.

        Args:
            pred_uv: Predicted UV coordinates.
            target_uv: Target UV coordinates.

        Returns:
            Scalar L1 loss.
        """
        return F.l1_loss(pred_uv, target_uv)

    def compute_total_loss(self, predictions: Dict[str, torch.Tensor],
                           seg_target: torch.Tensor,
                           uv_target: torch.Tensor,
                           seg_weight: float = 1.0,
                           uv_weight: float = 1.0) -> torch.Tensor:
        """Weighted sum of segmentation and UV losses.

        Args:
            predictions: Dictionary of predictions from forward().
            seg_target: Target segmentation masks.
            uv_target: Target UV coordinates.
            seg_weight: Weight for the segmentation loss.
            uv_weight: Weight for the UV loss.

        Returns:
            Combined scalar loss.
        """
        loss_seg = self.compute_segmentation_loss(predictions['segmentation'], seg_target)
        loss_uv = self.compute_uv_loss(predictions['uv_coordinates'], uv_target)
        return seg_weight * loss_seg + uv_weight * loss_uv

    def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Derive confidence maps from raw predictions.

        Segmentation confidence is the max softmax probability per
        pixel; UV confidence is the inverse of per-pixel channel
        variance.

        Args:
            predictions: Dictionary of predictions from forward().

        Returns:
            Dict with 'segmentation_confidence' and 'uv_confidence'.
        """
        logits = predictions['segmentation']
        uv = predictions['uv_coordinates']

        probs = F.softmax(logits, dim=1)
        seg_conf = torch.max(probs, dim=1)[0]

        variance = torch.var(uv, dim=1, keepdim=True)
        uv_conf = 1.0 / (1.0 + variance)

        return {
            'segmentation_confidence': seg_conf,
            'uv_confidence': uv_conf.squeeze(1)
        }

    def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert raw predictions into final outputs.

        Args:
            predictions: Raw predictions from forward().

        Returns:
            Dict with hard 'body_parts' labels (argmax over classes),
            the 'uv_coordinates', and 'confidence_scores'.
        """
        logits = predictions['segmentation']
        uv = predictions['uv_coordinates']

        parts = torch.argmax(logits, dim=1)
        confidence = self.get_prediction_confidence(predictions)

        return {
            'body_parts': parts,
            'uv_coordinates': uv,
            'confidence_scores': confidence
        }
@@ -3,7 +3,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
|
||||||
|
class ModalityTranslationError(Exception):
|
||||||
|
"""Exception raised for modality translation errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModalityTranslationNetwork(nn.Module):
|
class ModalityTranslationNetwork(nn.Module):
|
||||||
@@ -17,11 +22,20 @@ class ModalityTranslationNetwork(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self._validate_config(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.input_channels = config['input_channels']
|
self.input_channels = config['input_channels']
|
||||||
self.hidden_dim = config['hidden_dim']
|
self.hidden_channels = config['hidden_channels']
|
||||||
self.output_dim = config['output_dim']
|
self.output_channels = config['output_channels']
|
||||||
self.num_layers = config['num_layers']
|
self.kernel_size = config.get('kernel_size', 3)
|
||||||
self.dropout_rate = config['dropout_rate']
|
self.stride = config.get('stride', 1)
|
||||||
|
self.padding = config.get('padding', 1)
|
||||||
|
self.dropout_rate = config.get('dropout_rate', 0.1)
|
||||||
|
self.activation = config.get('activation', 'relu')
|
||||||
|
self.normalization = config.get('normalization', 'batch')
|
||||||
|
self.use_attention = config.get('use_attention', False)
|
||||||
|
self.attention_heads = config.get('attention_heads', 8)
|
||||||
|
|
||||||
# Encoder: CSI -> Feature space
|
# Encoder: CSI -> Feature space
|
||||||
self.encoder = self._build_encoder()
|
self.encoder = self._build_encoder()
|
||||||
@@ -29,57 +43,114 @@ class ModalityTranslationNetwork(nn.Module):
|
|||||||
# Decoder: Feature space -> Visual-like features
|
# Decoder: Feature space -> Visual-like features
|
||||||
self.decoder = self._build_decoder()
|
self.decoder = self._build_decoder()
|
||||||
|
|
||||||
|
# Attention mechanism
|
||||||
|
if self.use_attention:
|
||||||
|
self.attention = self._build_attention()
|
||||||
|
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
def _build_encoder(self) -> nn.Module:
|
def _validate_config(self, config: Dict[str, Any]):
|
||||||
|
"""Validate configuration parameters."""
|
||||||
|
required_fields = ['input_channels', 'hidden_channels', 'output_channels']
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in config:
|
||||||
|
raise ValueError(f"Missing required field: {field}")
|
||||||
|
|
||||||
|
if config['input_channels'] <= 0:
|
||||||
|
raise ValueError("input_channels must be positive")
|
||||||
|
|
||||||
|
if not config['hidden_channels'] or len(config['hidden_channels']) == 0:
|
||||||
|
raise ValueError("hidden_channels must be a non-empty list")
|
||||||
|
|
||||||
|
if config['output_channels'] <= 0:
|
||||||
|
raise ValueError("output_channels must be positive")
|
||||||
|
|
||||||
|
def _build_encoder(self) -> nn.ModuleList:
|
||||||
"""Build encoder network."""
|
"""Build encoder network."""
|
||||||
layers = []
|
layers = nn.ModuleList()
|
||||||
|
|
||||||
# Initial convolution
|
# Initial convolution
|
||||||
layers.append(nn.Conv2d(self.input_channels, 64, kernel_size=3, padding=1))
|
in_channels = self.input_channels
|
||||||
layers.append(nn.BatchNorm2d(64))
|
|
||||||
layers.append(nn.ReLU(inplace=True))
|
|
||||||
layers.append(nn.Dropout2d(self.dropout_rate))
|
|
||||||
|
|
||||||
# Progressive downsampling
|
for i, out_channels in enumerate(self.hidden_channels):
|
||||||
in_channels = 64
|
layer_block = nn.Sequential(
|
||||||
for i in range(self.num_layers - 1):
|
nn.Conv2d(in_channels, out_channels,
|
||||||
out_channels = min(in_channels * 2, self.hidden_dim)
|
kernel_size=self.kernel_size,
|
||||||
layers.extend([
|
stride=self.stride if i == 0 else 2,
|
||||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
|
padding=self.padding),
|
||||||
nn.BatchNorm2d(out_channels),
|
self._get_normalization(out_channels),
|
||||||
nn.ReLU(inplace=True),
|
self._get_activation(),
|
||||||
nn.Dropout2d(self.dropout_rate)
|
nn.Dropout2d(self.dropout_rate)
|
||||||
])
|
)
|
||||||
|
layers.append(layer_block)
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return layers
|
||||||
|
|
||||||
def _build_decoder(self) -> nn.Module:
|
def _build_decoder(self) -> nn.ModuleList:
|
||||||
"""Build decoder network."""
|
"""Build decoder network."""
|
||||||
layers = []
|
layers = nn.ModuleList()
|
||||||
|
|
||||||
# Get the actual output channels from encoder (should be hidden_dim)
|
# Start with the last hidden channel size
|
||||||
encoder_out_channels = self.hidden_dim
|
in_channels = self.hidden_channels[-1]
|
||||||
|
|
||||||
# Progressive upsampling
|
# Progressive upsampling (reverse of encoder)
|
||||||
in_channels = encoder_out_channels
|
for i, out_channels in enumerate(reversed(self.hidden_channels[:-1])):
|
||||||
for i in range(self.num_layers - 1):
|
layer_block = nn.Sequential(
|
||||||
out_channels = max(in_channels // 2, 64)
|
nn.ConvTranspose2d(in_channels, out_channels,
|
||||||
layers.extend([
|
kernel_size=self.kernel_size,
|
||||||
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
|
stride=2,
|
||||||
nn.BatchNorm2d(out_channels),
|
padding=self.padding,
|
||||||
nn.ReLU(inplace=True),
|
output_padding=1),
|
||||||
|
self._get_normalization(out_channels),
|
||||||
|
self._get_activation(),
|
||||||
nn.Dropout2d(self.dropout_rate)
|
nn.Dropout2d(self.dropout_rate)
|
||||||
])
|
)
|
||||||
|
layers.append(layer_block)
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
|
|
||||||
# Final output layer
|
# Final output layer
|
||||||
layers.append(nn.Conv2d(in_channels, self.output_dim, kernel_size=3, padding=1))
|
final_layer = nn.Sequential(
|
||||||
layers.append(nn.Tanh()) # Normalize output
|
nn.Conv2d(in_channels, self.output_channels,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
padding=self.padding),
|
||||||
|
nn.Tanh() # Normalize output
|
||||||
|
)
|
||||||
|
layers.append(final_layer)
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return layers
|
||||||
|
|
||||||
|
def _get_normalization(self, channels: int) -> nn.Module:
|
||||||
|
"""Get normalization layer."""
|
||||||
|
if self.normalization == 'batch':
|
||||||
|
return nn.BatchNorm2d(channels)
|
||||||
|
elif self.normalization == 'instance':
|
||||||
|
return nn.InstanceNorm2d(channels)
|
||||||
|
elif self.normalization == 'layer':
|
||||||
|
return nn.GroupNorm(1, channels)
|
||||||
|
else:
|
||||||
|
return nn.Identity()
|
||||||
|
|
||||||
|
def _get_activation(self) -> nn.Module:
|
||||||
|
"""Get activation function."""
|
||||||
|
if self.activation == 'relu':
|
||||||
|
return nn.ReLU(inplace=True)
|
||||||
|
elif self.activation == 'leaky_relu':
|
||||||
|
return nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
elif self.activation == 'gelu':
|
||||||
|
return nn.GELU()
|
||||||
|
else:
|
||||||
|
return nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def _build_attention(self) -> nn.Module:
|
||||||
|
"""Build attention mechanism."""
|
||||||
|
return nn.MultiheadAttention(
|
||||||
|
embed_dim=self.hidden_channels[-1],
|
||||||
|
num_heads=self.attention_heads,
|
||||||
|
dropout=self.dropout_rate,
|
||||||
|
batch_first=True
|
||||||
|
)
|
||||||
|
|
||||||
def _initialize_weights(self):
|
def _initialize_weights(self):
|
||||||
"""Initialize network weights."""
|
"""Initialize network weights."""
|
||||||
@@ -103,12 +174,128 @@ class ModalityTranslationNetwork(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# Validate input shape
|
# Validate input shape
|
||||||
if x.shape[1] != self.input_channels:
|
if x.shape[1] != self.input_channels:
|
||||||
raise RuntimeError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
|
raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
|
||||||
|
|
||||||
# Encode CSI data
|
# Encode CSI data
|
||||||
encoded = self.encoder(x)
|
encoded_features = self.encode(x)
|
||||||
|
|
||||||
# Decode to visual-like features
|
# Decode to visual-like features
|
||||||
decoded = self.decoder(encoded)
|
decoded = self.decode(encoded_features)
|
||||||
|
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
|
def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
|
"""Encode input through encoder layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of feature maps from each encoder layer
|
||||||
|
"""
|
||||||
|
features = []
|
||||||
|
current = x
|
||||||
|
|
||||||
|
for layer in self.encoder:
|
||||||
|
current = layer(current)
|
||||||
|
features.append(current)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""Decode features through decoder layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoded_features: List of encoded feature maps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decoded output tensor
|
||||||
|
"""
|
||||||
|
# Start with the last encoded feature
|
||||||
|
current = encoded_features[-1]
|
||||||
|
|
||||||
|
# Apply attention if enabled
|
||||||
|
if self.use_attention:
|
||||||
|
batch_size, channels, height, width = current.shape
|
||||||
|
# Reshape for attention: (batch, seq_len, embed_dim)
|
||||||
|
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
|
||||||
|
attended, _ = self.attention(current_flat, current_flat, current_flat)
|
||||||
|
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
|
||||||
|
|
||||||
|
# Apply decoder layers
|
||||||
|
for layer in self.decoder:
|
||||||
|
current = layer(current)
|
||||||
|
|
||||||
|
return current
|
||||||
|
|
||||||
|
def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor:
|
||||||
|
"""Compute translation loss between predicted and target features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predicted: Predicted feature tensor
|
||||||
|
target: Target feature tensor
|
||||||
|
loss_type: Type of loss ('mse', 'l1', 'smooth_l1')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Computed loss tensor
|
||||||
|
"""
|
||||||
|
if loss_type == 'mse':
|
||||||
|
return F.mse_loss(predicted, target)
|
||||||
|
elif loss_type == 'l1':
|
||||||
|
return F.l1_loss(predicted, target)
|
||||||
|
elif loss_type == 'smooth_l1':
|
||||||
|
return F.smooth_l1_loss(predicted, target)
|
||||||
|
else:
|
||||||
|
return F.mse_loss(predicted, target)
|
||||||
|
|
||||||
|
def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]:
|
||||||
|
"""Get statistics of feature tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: Feature tensor to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of feature statistics
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
return {
|
||||||
|
'mean': features.mean().item(),
|
||||||
|
'std': features.std().item(),
|
||||||
|
'min': features.min().item(),
|
||||||
|
'max': features.max().item(),
|
||||||
|
'sparsity': (features == 0).float().mean().item()
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]:
|
||||||
|
"""Get intermediate features for visualization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing intermediate features
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Get encoder features
|
||||||
|
encoder_features = self.encode(x)
|
||||||
|
result['encoder_features'] = encoder_features
|
||||||
|
|
||||||
|
# Get decoder features
|
||||||
|
decoder_features = []
|
||||||
|
current = encoder_features[-1]
|
||||||
|
|
||||||
|
if self.use_attention:
|
||||||
|
batch_size, channels, height, width = current.shape
|
||||||
|
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
|
||||||
|
attended, attention_weights = self.attention(current_flat, current_flat, current_flat)
|
||||||
|
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
|
||||||
|
result['attention_weights'] = attention_weights
|
||||||
|
|
||||||
|
for layer in self.decoder:
|
||||||
|
current = layer(current)
|
||||||
|
decoder_features.append(current)
|
||||||
|
|
||||||
|
result['decoder_features'] = decoder_features
|
||||||
|
|
||||||
|
return result
|
||||||
353
tests/integration/test_csi_pipeline.py
Normal file
353
tests/integration/test_csi_pipeline.py
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from src.core.csi_processor import CSIProcessor
|
||||||
|
from src.core.phase_sanitizer import PhaseSanitizer
|
||||||
|
from src.hardware.router_interface import RouterInterface
|
||||||
|
from src.hardware.csi_extractor import CSIExtractor
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSIPipeline:
|
||||||
|
"""Integration tests for CSI processing pipeline following London School TDD principles"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_router_config(self):
|
||||||
|
"""Configuration for router interface"""
|
||||||
|
return {
|
||||||
|
'router_ip': '192.168.1.1',
|
||||||
|
'username': 'admin',
|
||||||
|
'password': 'password',
|
||||||
|
'ssh_port': 22,
|
||||||
|
'timeout': 30,
|
||||||
|
'max_retries': 3
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_extractor_config(self):
|
||||||
|
"""Configuration for CSI extractor"""
|
||||||
|
return {
|
||||||
|
'interface': 'wlan0',
|
||||||
|
'channel': 6,
|
||||||
|
'bandwidth': 20,
|
||||||
|
'antenna_count': 3,
|
||||||
|
'subcarrier_count': 56,
|
||||||
|
'sample_rate': 1000,
|
||||||
|
'buffer_size': 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_processor_config(self):
|
||||||
|
"""Configuration for CSI processor"""
|
||||||
|
return {
|
||||||
|
'window_size': 100,
|
||||||
|
'overlap': 0.5,
|
||||||
|
'filter_type': 'butterworth',
|
||||||
|
'filter_order': 4,
|
||||||
|
'cutoff_frequency': 50,
|
||||||
|
'normalization': 'minmax',
|
||||||
|
'outlier_threshold': 3.0
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sanitizer_config(self):
|
||||||
|
"""Configuration for phase sanitizer"""
|
||||||
|
return {
|
||||||
|
'unwrap_method': 'numpy',
|
||||||
|
'smoothing_window': 5,
|
||||||
|
'outlier_threshold': 2.0,
|
||||||
|
'interpolation_method': 'linear',
|
||||||
|
'phase_correction': True
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def csi_pipeline_components(self, mock_router_config, mock_extractor_config,
|
||||||
|
mock_processor_config, mock_sanitizer_config):
|
||||||
|
"""Create CSI pipeline components for testing"""
|
||||||
|
router = RouterInterface(mock_router_config)
|
||||||
|
extractor = CSIExtractor(mock_extractor_config)
|
||||||
|
processor = CSIProcessor(mock_processor_config)
|
||||||
|
sanitizer = PhaseSanitizer(mock_sanitizer_config)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'router': router,
|
||||||
|
'extractor': extractor,
|
||||||
|
'processor': processor,
|
||||||
|
'sanitizer': sanitizer
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_raw_csi_data(self):
|
||||||
|
"""Generate mock raw CSI data"""
|
||||||
|
batch_size = 10
|
||||||
|
antennas = 3
|
||||||
|
subcarriers = 56
|
||||||
|
time_samples = 100
|
||||||
|
|
||||||
|
# Generate complex CSI data
|
||||||
|
real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
|
||||||
|
imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'csi_data': real_part + 1j * imag_part,
|
||||||
|
'timestamps': np.linspace(0, 1, time_samples),
|
||||||
|
'metadata': {
|
||||||
|
'channel': 6,
|
||||||
|
'bandwidth': 20,
|
||||||
|
'rssi': -45,
|
||||||
|
'noise_floor': -90
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_end_to_end_csi_pipeline_processes_data_correctly(self, csi_pipeline_components, mock_raw_csi_data):
|
||||||
|
"""Test that end-to-end CSI pipeline processes data correctly"""
|
||||||
|
# Arrange
|
||||||
|
router = csi_pipeline_components['router']
|
||||||
|
extractor = csi_pipeline_components['extractor']
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
sanitizer = csi_pipeline_components['sanitizer']
|
||||||
|
|
||||||
|
# Mock the hardware extraction
|
||||||
|
with patch.object(extractor, 'extract_csi_data', return_value=mock_raw_csi_data):
|
||||||
|
with patch.object(router, 'connect', return_value=True):
|
||||||
|
with patch.object(router, 'configure_monitor_mode', return_value=True):
|
||||||
|
|
||||||
|
# Act - Run the pipeline
|
||||||
|
# 1. Connect to router and configure
|
||||||
|
router.connect()
|
||||||
|
router.configure_monitor_mode('wlan0', 6)
|
||||||
|
|
||||||
|
# 2. Extract CSI data
|
||||||
|
raw_data = extractor.extract_csi_data()
|
||||||
|
|
||||||
|
# 3. Process CSI data
|
||||||
|
processed_data = processor.process_csi_batch(raw_data['csi_data'])
|
||||||
|
|
||||||
|
# 4. Sanitize phase information
|
||||||
|
sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert raw_data is not None
|
||||||
|
assert processed_data is not None
|
||||||
|
assert sanitized_data is not None
|
||||||
|
|
||||||
|
# Check data flow integrity
|
||||||
|
assert isinstance(processed_data, torch.Tensor)
|
||||||
|
assert isinstance(sanitized_data, torch.Tensor)
|
||||||
|
assert processed_data.shape == sanitized_data.shape
|
||||||
|
|
||||||
|
def test_pipeline_handles_hardware_connection_failure(self, csi_pipeline_components):
|
||||||
|
"""Test that pipeline handles hardware connection failures gracefully"""
|
||||||
|
# Arrange
|
||||||
|
router = csi_pipeline_components['router']
|
||||||
|
|
||||||
|
# Mock connection failure
|
||||||
|
with patch.object(router, 'connect', return_value=False):
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
connection_result = router.connect()
|
||||||
|
assert connection_result is False
|
||||||
|
|
||||||
|
# Pipeline should handle this gracefully
|
||||||
|
with pytest.raises(Exception): # Should raise appropriate exception
|
||||||
|
router.configure_monitor_mode('wlan0', 6)
|
||||||
|
|
||||||
|
def test_pipeline_handles_csi_extraction_timeout(self, csi_pipeline_components):
|
||||||
|
"""Test that pipeline handles CSI extraction timeouts"""
|
||||||
|
# Arrange
|
||||||
|
extractor = csi_pipeline_components['extractor']
|
||||||
|
|
||||||
|
# Mock extraction timeout
|
||||||
|
with patch.object(extractor, 'extract_csi_data', side_effect=TimeoutError("CSI extraction timeout")):
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
extractor.extract_csi_data()
|
||||||
|
|
||||||
|
def test_pipeline_handles_invalid_csi_data_format(self, csi_pipeline_components):
|
||||||
|
"""Test that pipeline handles invalid CSI data formats"""
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
|
||||||
|
# Invalid data format
|
||||||
|
invalid_data = np.random.randn(10, 2, 56) # Missing time dimension
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
processor.process_csi_batch(invalid_data)
|
||||||
|
|
||||||
|
def test_pipeline_maintains_data_consistency_across_stages(self, csi_pipeline_components, mock_raw_csi_data):
|
||||||
|
"""Test that pipeline maintains data consistency across processing stages"""
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
sanitizer = csi_pipeline_components['sanitizer']
|
||||||
|
|
||||||
|
csi_data = mock_raw_csi_data['csi_data']
|
||||||
|
|
||||||
|
# Act
|
||||||
|
processed_data = processor.process_csi_batch(csi_data)
|
||||||
|
sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
|
||||||
|
|
||||||
|
# Assert - Check data consistency
|
||||||
|
assert processed_data.shape[0] == sanitized_data.shape[0] # Batch size preserved
|
||||||
|
assert processed_data.shape[1] == sanitized_data.shape[1] # Antenna count preserved
|
||||||
|
assert processed_data.shape[2] == sanitized_data.shape[2] # Subcarrier count preserved
|
||||||
|
|
||||||
|
# Check that data is not corrupted (no NaN or infinite values)
|
||||||
|
assert not torch.isnan(processed_data).any()
|
||||||
|
assert not torch.isinf(processed_data).any()
|
||||||
|
assert not torch.isnan(sanitized_data).any()
|
||||||
|
assert not torch.isinf(sanitized_data).any()
|
||||||
|
|
||||||
|
def test_pipeline_performance_meets_real_time_requirements(self, csi_pipeline_components, mock_raw_csi_data):
|
||||||
|
"""Test that pipeline performance meets real-time processing requirements"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
sanitizer = csi_pipeline_components['sanitizer']
|
||||||
|
|
||||||
|
csi_data = mock_raw_csi_data['csi_data']
|
||||||
|
|
||||||
|
# Act - Measure processing time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
processed_data = processor.process_csi_batch(csi_data)
|
||||||
|
sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
processing_time = end_time - start_time
|
||||||
|
|
||||||
|
# Assert - Should process within reasonable time (< 100ms for this data size)
|
||||||
|
assert processing_time < 0.1, f"Processing took {processing_time:.3f}s, expected < 0.1s"
|
||||||
|
|
||||||
|
def test_pipeline_handles_different_data_sizes(self, csi_pipeline_components):
|
||||||
|
"""Test that pipeline handles different CSI data sizes"""
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
sanitizer = csi_pipeline_components['sanitizer']
|
||||||
|
|
||||||
|
# Different data sizes
|
||||||
|
small_data = np.random.randn(1, 3, 56, 50) + 1j * np.random.randn(1, 3, 56, 50)
|
||||||
|
large_data = np.random.randn(20, 3, 56, 200) + 1j * np.random.randn(20, 3, 56, 200)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
small_processed = processor.process_csi_batch(small_data)
|
||||||
|
small_sanitized = sanitizer.sanitize_phase_batch(small_processed)
|
||||||
|
|
||||||
|
large_processed = processor.process_csi_batch(large_data)
|
||||||
|
large_sanitized = sanitizer.sanitize_phase_batch(large_processed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert small_processed.shape == small_sanitized.shape
|
||||||
|
assert large_processed.shape == large_sanitized.shape
|
||||||
|
assert small_processed.shape != large_processed.shape # Different sizes
|
||||||
|
|
||||||
|
def test_pipeline_configuration_validation(self, mock_router_config, mock_extractor_config,
|
||||||
|
mock_processor_config, mock_sanitizer_config):
|
||||||
|
"""Test that pipeline components validate configurations properly"""
|
||||||
|
# Arrange - Invalid configurations
|
||||||
|
invalid_router_config = mock_router_config.copy()
|
||||||
|
invalid_router_config['router_ip'] = 'invalid_ip'
|
||||||
|
|
||||||
|
invalid_extractor_config = mock_extractor_config.copy()
|
||||||
|
invalid_extractor_config['antenna_count'] = 0
|
||||||
|
|
||||||
|
invalid_processor_config = mock_processor_config.copy()
|
||||||
|
invalid_processor_config['window_size'] = -1
|
||||||
|
|
||||||
|
invalid_sanitizer_config = mock_sanitizer_config.copy()
|
||||||
|
invalid_sanitizer_config['smoothing_window'] = 0
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
RouterInterface(invalid_router_config)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
CSIExtractor(invalid_extractor_config)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
CSIProcessor(invalid_processor_config)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
PhaseSanitizer(invalid_sanitizer_config)
|
||||||
|
|
||||||
|
def test_pipeline_error_recovery_and_logging(self, csi_pipeline_components, mock_raw_csi_data):
|
||||||
|
"""Test that pipeline handles errors gracefully and logs appropriately"""
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
|
||||||
|
# Corrupt some data to trigger error handling
|
||||||
|
corrupted_data = mock_raw_csi_data['csi_data'].copy()
|
||||||
|
corrupted_data[0, 0, 0, :] = np.inf # Introduce infinite values
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError): # Should detect and handle corrupted data
|
||||||
|
processor.process_csi_batch(corrupted_data)
|
||||||
|
|
||||||
|
def test_pipeline_memory_usage_optimization(self, csi_pipeline_components):
|
||||||
|
"""Test that pipeline optimizes memory usage for large datasets"""
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
sanitizer = csi_pipeline_components['sanitizer']
|
||||||
|
|
||||||
|
# Large dataset
|
||||||
|
large_data = np.random.randn(100, 3, 56, 1000) + 1j * np.random.randn(100, 3, 56, 1000)
|
||||||
|
|
||||||
|
# Act - Process in chunks to test memory optimization
|
||||||
|
chunk_size = 10
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i in range(0, large_data.shape[0], chunk_size):
|
||||||
|
chunk = large_data[i:i+chunk_size]
|
||||||
|
processed_chunk = processor.process_csi_batch(chunk)
|
||||||
|
sanitized_chunk = sanitizer.sanitize_phase_batch(processed_chunk)
|
||||||
|
results.append(sanitized_chunk)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(results) == 10 # 100 samples / 10 chunk_size
|
||||||
|
for result in results:
|
||||||
|
assert result.shape[0] <= chunk_size
|
||||||
|
|
||||||
|
def test_pipeline_supports_concurrent_processing(self, csi_pipeline_components, mock_raw_csi_data):
|
||||||
|
"""Test that pipeline supports concurrent processing of multiple streams"""
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
|
||||||
|
# Arrange
|
||||||
|
processor = csi_pipeline_components['processor']
|
||||||
|
sanitizer = csi_pipeline_components['sanitizer']
|
||||||
|
|
||||||
|
results_queue = queue.Queue()
|
||||||
|
|
||||||
|
def process_stream(stream_id, data):
|
||||||
|
try:
|
||||||
|
processed = processor.process_csi_batch(data)
|
||||||
|
sanitized = sanitizer.sanitize_phase_batch(processed)
|
||||||
|
results_queue.put((stream_id, sanitized))
|
||||||
|
except Exception as e:
|
||||||
|
results_queue.put((stream_id, e))
|
||||||
|
|
||||||
|
# Act - Process multiple streams concurrently
|
||||||
|
threads = []
|
||||||
|
for i in range(3):
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=process_stream,
|
||||||
|
args=(i, mock_raw_csi_data['csi_data'])
|
||||||
|
)
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Wait for all threads to complete
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
results = []
|
||||||
|
while not results_queue.empty():
|
||||||
|
results.append(results_queue.get())
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
for stream_id, result in results:
|
||||||
|
assert isinstance(result, torch.Tensor)
|
||||||
|
assert not isinstance(result, Exception)
|
||||||
459
tests/integration/test_inference_pipeline.py
Normal file
459
tests/integration/test_inference_pipeline.py
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from src.core.csi_processor import CSIProcessor
|
||||||
|
from src.core.phase_sanitizer import PhaseSanitizer
|
||||||
|
from src.models.modality_translation import ModalityTranslationNetwork
|
||||||
|
from src.models.densepose_head import DensePoseHead
|
||||||
|
|
||||||
|
|
||||||
|
class TestInferencePipeline:
|
||||||
|
"""Integration tests for inference pipeline following London School TDD principles"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_csi_processor_config(self):
|
||||||
|
"""Configuration for CSI processor"""
|
||||||
|
return {
|
||||||
|
'window_size': 100,
|
||||||
|
'overlap': 0.5,
|
||||||
|
'filter_type': 'butterworth',
|
||||||
|
'filter_order': 4,
|
||||||
|
'cutoff_frequency': 50,
|
||||||
|
'normalization': 'minmax',
|
||||||
|
'outlier_threshold': 3.0
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sanitizer_config(self):
|
||||||
|
"""Configuration for phase sanitizer"""
|
||||||
|
return {
|
||||||
|
'unwrap_method': 'numpy',
|
||||||
|
'smoothing_window': 5,
|
||||||
|
'outlier_threshold': 2.0,
|
||||||
|
'interpolation_method': 'linear',
|
||||||
|
'phase_correction': True
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_translation_config(self):
|
||||||
|
"""Configuration for modality translation network"""
|
||||||
|
return {
|
||||||
|
'input_channels': 6,
|
||||||
|
'output_channels': 256,
|
||||||
|
'hidden_channels': [64, 128, 256],
|
||||||
|
'kernel_sizes': [7, 5, 3],
|
||||||
|
'strides': [2, 2, 1],
|
||||||
|
'dropout_rate': 0.1,
|
||||||
|
'use_attention': True,
|
||||||
|
'attention_heads': 8,
|
||||||
|
'use_residual': True,
|
||||||
|
'activation': 'relu',
|
||||||
|
'normalization': 'batch'
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_densepose_config(self):
|
||||||
|
"""Configuration for DensePose head"""
|
||||||
|
return {
|
||||||
|
'input_channels': 256,
|
||||||
|
'num_body_parts': 24,
|
||||||
|
'num_uv_coordinates': 2,
|
||||||
|
'hidden_channels': [128, 64],
|
||||||
|
'kernel_size': 3,
|
||||||
|
'padding': 1,
|
||||||
|
'dropout_rate': 0.1,
|
||||||
|
'use_deformable_conv': False,
|
||||||
|
'use_fpn': True,
|
||||||
|
'fpn_levels': [2, 3, 4, 5],
|
||||||
|
'output_stride': 4
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def inference_pipeline_components(self, mock_csi_processor_config, mock_sanitizer_config,
|
||||||
|
mock_translation_config, mock_densepose_config):
|
||||||
|
"""Create inference pipeline components for testing"""
|
||||||
|
csi_processor = CSIProcessor(mock_csi_processor_config)
|
||||||
|
phase_sanitizer = PhaseSanitizer(mock_sanitizer_config)
|
||||||
|
translation_network = ModalityTranslationNetwork(mock_translation_config)
|
||||||
|
densepose_head = DensePoseHead(mock_densepose_config)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'csi_processor': csi_processor,
|
||||||
|
'phase_sanitizer': phase_sanitizer,
|
||||||
|
'translation_network': translation_network,
|
||||||
|
'densepose_head': densepose_head
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_raw_csi_input(self):
|
||||||
|
"""Generate mock raw CSI input data"""
|
||||||
|
batch_size = 4
|
||||||
|
antennas = 3
|
||||||
|
subcarriers = 56
|
||||||
|
time_samples = 100
|
||||||
|
|
||||||
|
# Generate complex CSI data
|
||||||
|
real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
|
||||||
|
imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
|
||||||
|
|
||||||
|
return real_part + 1j * imag_part
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_ground_truth_densepose(self):
    """Generate mock ground truth DensePose annotations.

    Returns:
        dict with:
          'segmentation'  - (4, 224, 224) integer labels in [0, 24]
                            (24 body parts + background class 0)
          'uv_coordinates' - (4, 2, 224, 224) float UV targets

    Uses a dedicated seeded torch.Generator (no global torch.manual_seed
    side effect) so the annotations are identical on every run.
    """
    gen = torch.Generator().manual_seed(42)  # local, reproducible RNG
    batch_size, height, width, num_parts = 4, 224, 224, 24

    # Segmentation masks: labels 0..num_parts inclusive (0 = background).
    seg_masks = torch.randint(0, num_parts + 1, (batch_size, height, width),
                              generator=gen)
    # Continuous UV regression targets, one U and one V plane per sample.
    uv_coords = torch.randn(batch_size, 2, height, width, generator=gen)

    return {
        'segmentation': seg_masks,
        'uv_coordinates': uv_coords
    }
def test_end_to_end_inference_pipeline_produces_valid_output(self, inference_pipeline_components, mock_raw_csi_input):
    """Test that end-to-end inference pipeline produces valid DensePose output"""
    # Arrange: unpack the four pipeline stages from the fixture mapping.
    processor, sanitizer, translator, head = (
        inference_pipeline_components[key]
        for key in ('csi_processor', 'phase_sanitizer',
                    'translation_network', 'densepose_head')
    )

    # Inference must run with dropout/batch-norm frozen.
    translator.eval()
    head.eval()

    # Act - run the complete CSI -> sanitized -> features -> DensePose chain.
    with torch.no_grad():
        processed = processor.process_csi_batch(mock_raw_csi_input)
        sanitized = sanitizer.sanitize_phase_batch(processed)
        densepose_output = head(translator(sanitized))

    # Assert - container type and expected keys.
    assert densepose_output is not None
    assert isinstance(densepose_output, dict)
    assert 'segmentation' in densepose_output
    assert 'uv_coordinates' in densepose_output

    seg_output = densepose_output['segmentation']
    uv_output = densepose_output['uv_coordinates']

    # Shapes: batch preserved; 24 parts + 1 background channel; 2 UV planes.
    batch = mock_raw_csi_input.shape[0]
    assert seg_output.shape[0] == batch
    assert seg_output.shape[1] == 25
    assert uv_output.shape[0] == batch
    assert uv_output.shape[1] == 2

    # UV coordinates are normalized into [0, 1].
    assert torch.all(uv_output >= 0) and torch.all(uv_output <= 1)
def test_inference_pipeline_handles_different_batch_sizes(self, inference_pipeline_components):
    """Test that inference pipeline handles different batch sizes"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    # One tiny and one large batch of complex CSI input.
    batches = {
        size: np.random.randn(size, 3, 56, 100) + 1j * np.random.randn(size, 3, 56, 100)
        for size in (1, 8)
    }

    translator.eval()
    head.eval()

    # Act - push each batch through the full pipeline.
    outputs = {}
    with torch.no_grad():
        for size, batch in batches.items():
            sanitized = sanitizer.sanitize_phase_batch(processor.process_csi_batch(batch))
            outputs[size] = head(translator(sanitized))

    # Assert - batch dimension is preserved end to end for both sizes.
    for size in (1, 8):
        assert outputs[size]['segmentation'].shape[0] == size
        assert outputs[size]['uv_coordinates'].shape[0] == size
def test_inference_pipeline_maintains_gradient_flow_during_training(self, inference_pipeline_components,
                                                                    mock_raw_csi_input, mock_ground_truth_densepose):
    """Test that inference pipeline maintains gradient flow during training"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    # Training mode so dropout/batch-norm behave as during optimization.
    translator.train()
    head.train()

    # Optimize both trainable modules jointly.
    trainable_params = list(translator.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(trainable_params, lr=0.001)
    optimizer.zero_grad()

    # Forward pass through the whole pipeline.
    sanitized = sanitizer.sanitize_phase_batch(processor.process_csi_batch(mock_raw_csi_input))
    densepose_output = head(translator(sanitized))

    # Resize targets to the network's output resolution:
    # nearest for discrete class labels, bilinear for continuous UV maps.
    seg_target = torch.nn.functional.interpolate(
        mock_ground_truth_densepose['segmentation'].float().unsqueeze(1),
        size=densepose_output['segmentation'].shape[2:],
        mode='nearest',
    ).squeeze(1).long()

    uv_target = torch.nn.functional.interpolate(
        mock_ground_truth_densepose['uv_coordinates'],
        size=densepose_output['uv_coordinates'].shape[2:],
        mode='bilinear',
        align_corners=False,
    )

    # Compute loss and backpropagate.
    loss = head.compute_total_loss(densepose_output, seg_target, uv_target)
    loss.backward()

    # Assert - every trainable parameter received a non-zero gradient.
    for module in (translator, head):
        for param in module.parameters():
            if param.requires_grad:
                assert param.grad is not None
                assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_inference_pipeline_performance_benchmarking(self, inference_pipeline_components, mock_raw_csi_input):
    """Test inference pipeline performance for real-time requirements.

    Fixes two timing defects in the original:
      * time.time() is a wall clock - coarse resolution and can jump
        backwards on NTP adjustment; time.perf_counter() is the monotonic,
        high-resolution clock intended for interval measurement.
      * a single timed run is flaky under scheduler noise; we take the
        minimum over three runs, which is the standard way to estimate the
        best-case cost of a deterministic computation.
    """
    import time

    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    # Evaluation mode for inference timing.
    translation_network.eval()
    densepose_head.eval()

    def _run_pipeline():
        # One complete forward pass: CSI -> sanitized -> features -> DensePose.
        with torch.no_grad():
            processed = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized = phase_sanitizer.sanitize_phase_batch(processed)
            features = translation_network(sanitized)
            return densepose_head(features)

    # Warm up (first inference pays one-time allocation costs).
    _run_pipeline()

    # Act - best of three timed runs.
    timings = []
    for _ in range(3):
        start = time.perf_counter()
        _run_pipeline()
        timings.append(time.perf_counter() - start)
    inference_time = min(timings)

    # Assert - should meet real-time requirements (< 50ms for batch of 4).
    assert inference_time < 0.05, f"Inference took {inference_time:.3f}s, expected < 0.05s"
def test_inference_pipeline_handles_edge_cases(self, inference_pipeline_components):
    """Test that inference pipeline handles edge cases gracefully"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    # Edge cases: all-zero CSI and extremely high-power noise.
    zero_input = np.zeros((1, 3, 56, 100), dtype=complex)
    noisy_input = np.random.randn(1, 3, 56, 100) * 100 + 1j * np.random.randn(1, 3, 56, 100) * 100

    translator.eval()
    head.eval()

    # Act & Assert - neither degenerate input may produce NaNs.
    with torch.no_grad():
        for edge_input in (zero_input, noisy_input):
            sanitized = sanitizer.sanitize_phase_batch(processor.process_csi_batch(edge_input))
            output = head(translator(sanitized))
            assert not torch.isnan(output['segmentation']).any()
            assert not torch.isnan(output['uv_coordinates']).any()
def test_inference_pipeline_memory_efficiency(self, inference_pipeline_components):
    """Test that inference pipeline is memory efficient"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    # Large batch to exercise chunked processing.
    large_input = np.random.randn(16, 3, 56, 100) + 1j * np.random.randn(16, 3, 56, 100)

    translator.eval()
    head.eval()

    # Act - process in fixed-size chunks so peak memory stays bounded.
    chunk_size = 4
    outputs = []

    with torch.no_grad():
        for offset in range(0, large_input.shape[0], chunk_size):
            chunk = large_input[offset:offset + chunk_size]

            processed = processor.process_csi_batch(chunk)
            sanitized = sanitizer.sanitize_phase_batch(processed)
            features = translator(sanitized)
            outputs.append(head(features))

            # Drop intermediates eagerly to free their memory.
            del processed, sanitized, features

    # Assert - 16 samples / chunks of 4 -> 4 chunks, none over chunk_size.
    assert len(outputs) == 4
    for output in outputs:
        assert output['segmentation'].shape[0] <= chunk_size
def test_inference_pipeline_deterministic_output(self, inference_pipeline_components, mock_raw_csi_input):
    """Test that inference pipeline produces deterministic output in eval mode"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    # Eval mode disables stochastic layers (dropout etc.).
    translator.eval()
    head.eval()

    def _run_once():
        # One full forward pass over the identical input.
        sanitized = sanitizer.sanitize_phase_batch(processor.process_csi_batch(mock_raw_csi_input))
        return head(translator(sanitized))

    # Act - run the identical inference twice.
    with torch.no_grad():
        first = _run_once()
        second = _run_once()

    # Assert - repeated eval-mode runs must agree to numerical tolerance.
    assert torch.allclose(first['segmentation'], second['segmentation'], atol=1e-6)
    assert torch.allclose(first['uv_coordinates'], second['uv_coordinates'], atol=1e-6)
def test_inference_pipeline_confidence_estimation(self, inference_pipeline_components, mock_raw_csi_input):
    """Test that inference pipeline provides confidence estimates"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    translator.eval()
    head.eval()

    # Act - run inference, then query the head for confidences.
    with torch.no_grad():
        sanitized = sanitizer.sanitize_phase_batch(processor.process_csi_batch(mock_raw_csi_input))
        densepose_output = head(translator(sanitized))
        confidence = head.get_prediction_confidence(densepose_output)

    # Assert - both confidence maps are present.
    assert 'segmentation_confidence' in confidence
    assert 'uv_confidence' in confidence

    seg_conf = confidence['segmentation_confidence']
    uv_conf = confidence['uv_confidence']

    # Batch dimension preserved for both maps.
    batch = mock_raw_csi_input.shape[0]
    assert seg_conf.shape[0] == batch
    assert uv_conf.shape[0] == batch
    # Segmentation confidence is probability-like; UV confidence non-negative.
    assert torch.all(seg_conf >= 0) and torch.all(seg_conf <= 1)
    assert torch.all(uv_conf >= 0)
def test_inference_pipeline_post_processing(self, inference_pipeline_components, mock_raw_csi_input):
    """Test that inference pipeline post-processes predictions correctly"""
    # Arrange
    processor = inference_pipeline_components['csi_processor']
    sanitizer = inference_pipeline_components['phase_sanitizer']
    translator = inference_pipeline_components['translation_network']
    head = inference_pipeline_components['densepose_head']

    translator.eval()
    head.eval()

    # Act - raw inference followed by the head's post-processing step.
    with torch.no_grad():
        sanitized = sanitizer.sanitize_phase_batch(processor.process_csi_batch(mock_raw_csi_input))
        raw_output = head(translator(sanitized))
        processed_output = head.post_process_predictions(raw_output)

    # Assert - post-processed result exposes all expected fields.
    for key in ('body_parts', 'uv_coordinates', 'confidence_scores'):
        assert key in processed_output

    body_parts = processed_output['body_parts']
    # Hard labels: integer class indices in [0, 24] (24 parts + background).
    assert body_parts.dtype == torch.long
    assert torch.all(body_parts >= 0) and torch.all(body_parts <= 24)
264
tests/unit/test_csi_extractor.py
Normal file
264
tests/unit/test_csi_extractor.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSIExtractor:
    """Test suite for CSI Extractor following London School TDD principles"""

    @pytest.fixture
    def mock_config(self):
        """Configuration for CSI extractor"""
        return dict(
            interface='wlan0',
            channel=6,
            bandwidth=20,
            sample_rate=1000,
            buffer_size=1024,
            extraction_timeout=5.0,
        )

    @pytest.fixture
    def mock_router_interface(self):
        """Mock router interface for testing"""
        router = Mock()
        router.is_connected = True
        router.execute_command = Mock()
        return router

    @pytest.fixture
    def csi_extractor(self, mock_config, mock_router_interface):
        """Create CSI extractor instance for testing"""
        return CSIExtractor(mock_config, mock_router_interface)

    @pytest.fixture
    def mock_csi_data(self):
        """Generate synthetic CSI data for testing"""
        # One complex channel estimate per (antenna, subcarrier) pair.
        shape = (3, 56)  # antennas x subcarriers
        magnitude = np.random.uniform(0.1, 2.0, shape)
        angle = np.random.uniform(-np.pi, np.pi, shape)
        return magnitude * np.exp(1j * angle)

    def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
        """Test that CSI extractor initializes with correct configuration"""
        # Act
        extractor = CSIExtractor(mock_config, mock_router_interface)

        # Assert - every config key is mirrored as an attribute.
        assert extractor is not None
        for key in ('interface', 'channel', 'bandwidth', 'sample_rate',
                    'buffer_size', 'extraction_timeout'):
            assert getattr(extractor, key) == mock_config[key]
        assert extractor.router_interface == mock_router_interface
        assert not extractor.is_extracting

    def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that start_extraction configures monitor mode"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"

        # Act
        started = csi_extractor.start_extraction()

        # Assert
        assert started is True
        assert csi_extractor.is_extracting is True
        mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
        """Test that start_extraction handles monitor mode configuration failure"""
        # Arrange - the router refuses to enter monitor mode.
        mock_router_interface.enable_monitor_mode.return_value = False

        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.start_extraction()

        assert csi_extractor.is_extracting is False

    def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that stop_extraction disables monitor mode"""
        # Arrange - start extraction first so there is something to stop.
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()

        # Act
        stopped = csi_extractor.stop_extraction()

        # Assert
        assert stopped is True
        assert csi_extractor.is_extracting is False
        mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that extract_csi_data returns data in valid format"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"

        # Stub the parser so the test stays hermetic.
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            csi_extractor.start_extraction()

            # Act
            csi_data = csi_extractor.extract_csi_data()

            # Assert
            assert csi_data is not None
            assert isinstance(csi_data, np.ndarray)
            assert csi_data.dtype == np.complex128
            assert csi_data.shape == mock_csi_data.shape

    def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
        """Test that extract_csi_data requires active extraction"""
        # Act & Assert - extracting before start_extraction must fail.
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
        """Test that extract_csi_data handles extraction timeout"""
        # Arrange - first command succeeds (start), second simulates a timeout.
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.side_effect = [
            "CSI extraction started",
            Exception("Timeout"),
        ]
        csi_extractor.start_extraction()

        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
        """Test that convert_to_tensor produces correctly formatted tensor"""
        # Act
        tensor = csi_extractor.convert_to_tensor(mock_csi_data)

        # Assert - complex input splits into stacked real/imaginary planes.
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.shape[0] == mock_csi_data.shape[0] * 2  # Real and imaginary parts
        assert tensor.shape[1] == mock_csi_data.shape[1]

    def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
        """Test that convert_to_tensor handles invalid input"""
        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.convert_to_tensor("not an array")

    def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
        """Test that get_extraction_stats returns valid statistics"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()

        # Act
        stats = csi_extractor.get_extraction_stats()

        # Assert - all expected counters are reported.
        assert stats is not None
        assert isinstance(stats, dict)
        for field in ('samples_extracted', 'extraction_rate',
                      'buffer_utilization', 'last_extraction_time'):
            assert field in stats

    def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
        """Test that set_channel configures WiFi channel"""
        # Arrange
        new_channel = 11
        mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"

        # Act
        changed = csi_extractor.set_channel(new_channel)

        # Assert
        assert changed is True
        assert csi_extractor.channel == new_channel
        mock_router_interface.execute_command.assert_called()

    def test_set_channel_validates_channel_range(self, csi_extractor):
        """Test that set_channel validates channel range"""
        # Act & Assert - channels outside the valid range are rejected.
        for bad_channel in (0, 15):
            with pytest.raises(ValueError):
                csi_extractor.set_channel(bad_channel)  # Invalid channel

    def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
        """Test that CSI extractor supports context manager protocol"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"

        # Act - entering starts extraction, leaving must stop it.
        with csi_extractor as extractor:
            assert extractor.is_extracting is True

        # Assert - extraction should be stopped after context
        assert csi_extractor.is_extracting is False

    def test_extractor_validates_configuration(self, mock_router_interface):
        """Test that CSI extractor validates configuration parameters"""
        # Arrange
        invalid_config = {
            'interface': '',  # Invalid interface
            'channel': 6,
            'bandwidth': 20,
        }

        # Act & Assert
        with pytest.raises(ValueError):
            CSIExtractor(invalid_config, mock_router_interface)

    def test_parse_csi_output_processes_raw_data(self, csi_extractor):
        """Test that _parse_csi_output processes raw CSI data correctly"""
        # Arrange
        raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"

        # Act
        parsed = csi_extractor._parse_csi_output(raw_output)

        # Assert
        assert parsed is not None
        assert isinstance(parsed, np.ndarray)
        assert parsed.dtype == np.complex128

    def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that buffer management handles overflow correctly"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"

        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            csi_extractor.start_extraction()

            # Push more samples than the buffer can hold.
            for _ in range(csi_extractor.buffer_size + 10):
                csi_extractor._add_to_buffer(mock_csi_data)

        # Act
        stats = csi_extractor.get_extraction_stats()

        # Assert - utilization is a ratio and must never exceed 100%.
        assert stats['buffer_utilization'] <= 1.0
@@ -3,27 +3,27 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from src.models.densepose_head import DensePoseHead
|
from src.models.densepose_head import DensePoseHead, DensePoseError
|
||||||
|
|
||||||
|
|
||||||
class TestDensePoseHead:
|
class TestDensePoseHead:
|
||||||
"""Test suite for DensePose Head following London School TDD principles"""
|
"""Test suite for DensePose Head following London School TDD principles"""
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_feature_input(self):
|
|
||||||
"""Generate synthetic feature input tensor for testing"""
|
|
||||||
# Batch size 2, 512 channels, 56 height, 100 width (from modality translation)
|
|
||||||
return torch.randn(2, 512, 56, 100)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Configuration for DensePose head"""
|
"""Configuration for DensePose head"""
|
||||||
return {
|
return {
|
||||||
'input_channels': 512,
|
'input_channels': 256,
|
||||||
'num_body_parts': 24, # Standard DensePose body parts
|
'num_body_parts': 24,
|
||||||
'num_uv_coordinates': 2, # U and V coordinates
|
'num_uv_coordinates': 2,
|
||||||
'hidden_dim': 256,
|
'hidden_channels': [128, 64],
|
||||||
'dropout_rate': 0.1
|
'kernel_size': 3,
|
||||||
|
'padding': 1,
|
||||||
|
'dropout_rate': 0.1,
|
||||||
|
'use_deformable_conv': False,
|
||||||
|
'use_fpn': True,
|
||||||
|
'fpn_levels': [2, 3, 4, 5],
|
||||||
|
'output_stride': 4
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -31,6 +31,33 @@ class TestDensePoseHead:
|
|||||||
"""Create DensePose head instance for testing"""
|
"""Create DensePose head instance for testing"""
|
||||||
return DensePoseHead(mock_config)
|
return DensePoseHead(mock_config)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_feature_input(self):
|
||||||
|
"""Generate mock feature input tensor"""
|
||||||
|
batch_size = 2
|
||||||
|
channels = 256
|
||||||
|
height = 56
|
||||||
|
width = 56
|
||||||
|
return torch.randn(batch_size, channels, height, width)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_target_masks(self):
|
||||||
|
"""Generate mock target segmentation masks"""
|
||||||
|
batch_size = 2
|
||||||
|
num_parts = 24
|
||||||
|
height = 224
|
||||||
|
width = 224
|
||||||
|
return torch.randint(0, num_parts + 1, (batch_size, height, width))
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_target_uv(self):
|
||||||
|
"""Generate mock target UV coordinates"""
|
||||||
|
batch_size = 2
|
||||||
|
num_coords = 2
|
||||||
|
height = 224
|
||||||
|
width = 224
|
||||||
|
return torch.randn(batch_size, num_coords, height, width)
|
||||||
|
|
||||||
def test_head_initialization_creates_correct_architecture(self, mock_config):
|
def test_head_initialization_creates_correct_architecture(self, mock_config):
|
||||||
"""Test that DensePose head initializes with correct architecture"""
|
"""Test that DensePose head initializes with correct architecture"""
|
||||||
# Act
|
# Act
|
||||||
@@ -39,135 +66,302 @@ class TestDensePoseHead:
|
|||||||
# Assert
|
# Assert
|
||||||
assert head is not None
|
assert head is not None
|
||||||
assert isinstance(head, nn.Module)
|
assert isinstance(head, nn.Module)
|
||||||
assert hasattr(head, 'segmentation_head')
|
|
||||||
assert hasattr(head, 'uv_regression_head')
|
|
||||||
assert head.input_channels == mock_config['input_channels']
|
assert head.input_channels == mock_config['input_channels']
|
||||||
assert head.num_body_parts == mock_config['num_body_parts']
|
assert head.num_body_parts == mock_config['num_body_parts']
|
||||||
assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
|
assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
|
||||||
|
assert head.use_fpn == mock_config['use_fpn']
|
||||||
|
assert hasattr(head, 'segmentation_head')
|
||||||
|
assert hasattr(head, 'uv_regression_head')
|
||||||
|
if mock_config['use_fpn']:
|
||||||
|
assert hasattr(head, 'fpn')
|
||||||
|
|
||||||
def test_forward_pass_produces_correct_output_shapes(self, densepose_head, mock_feature_input):
|
def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
|
||||||
"""Test that forward pass produces correctly shaped outputs"""
|
"""Test that forward pass produces correctly formatted output"""
|
||||||
# Act
|
# Act
|
||||||
with torch.no_grad():
|
output = densepose_head(mock_feature_input)
|
||||||
segmentation, uv_coords = densepose_head(mock_feature_input)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert segmentation is not None
|
assert output is not None
|
||||||
assert uv_coords is not None
|
assert isinstance(output, dict)
|
||||||
assert isinstance(segmentation, torch.Tensor)
|
assert 'segmentation' in output
|
||||||
assert isinstance(uv_coords, torch.Tensor)
|
assert 'uv_coordinates' in output
|
||||||
|
|
||||||
# Check segmentation output shape
|
seg_output = output['segmentation']
|
||||||
assert segmentation.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
uv_output = output['uv_coordinates']
|
||||||
assert segmentation.shape[1] == densepose_head.num_body_parts # Correct number of body parts
|
|
||||||
assert segmentation.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
|
|
||||||
|
|
||||||
# Check UV coordinates output shape
|
assert isinstance(seg_output, torch.Tensor)
|
||||||
assert uv_coords.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
assert isinstance(uv_output, torch.Tensor)
|
||||||
assert uv_coords.shape[1] == densepose_head.num_uv_coordinates # U and V coordinates
|
assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||||
assert uv_coords.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
|
assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||||
|
|
||||||
def test_segmentation_output_has_valid_probabilities(self, densepose_head, mock_feature_input):
|
def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
|
||||||
"""Test that segmentation output has valid probability distributions"""
|
"""Test that segmentation head produces correct output shape"""
|
||||||
# Act
|
# Act
|
||||||
with torch.no_grad():
|
output = densepose_head(mock_feature_input)
|
||||||
segmentation, _ = densepose_head(mock_feature_input)
|
seg_output = output['segmentation']
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# After softmax, values should be between 0 and 1
|
expected_channels = densepose_head.num_body_parts + 1 # +1 for background
|
||||||
assert torch.all(segmentation >= 0.0)
|
assert seg_output.shape[1] == expected_channels
|
||||||
assert torch.all(segmentation <= 1.0)
|
assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
|
||||||
|
assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
|
||||||
# Sum across body parts dimension should be approximately 1
|
|
||||||
part_sums = torch.sum(segmentation, dim=1)
|
|
||||||
assert torch.allclose(part_sums, torch.ones_like(part_sums), atol=1e-5)
|
|
||||||
|
|
||||||
def test_uv_coordinates_output_in_valid_range(self, densepose_head, mock_feature_input):
|
def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
|
||||||
"""Test that UV coordinates are in valid range [0, 1]"""
|
"""Test that UV regression head produces correct output shape"""
|
||||||
# Act
|
# Act
|
||||||
with torch.no_grad():
|
output = densepose_head(mock_feature_input)
|
||||||
_, uv_coords = densepose_head(mock_feature_input)
|
uv_output = output['uv_coordinates']
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# UV coordinates should be in range [0, 1] after sigmoid
|
assert uv_output.shape[1] == densepose_head.num_uv_coordinates
|
||||||
assert torch.all(uv_coords >= 0.0)
|
assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
|
||||||
assert torch.all(uv_coords <= 1.0)
|
assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
|
||||||
|
|
||||||
def test_head_handles_different_batch_sizes(self, densepose_head):
|
def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
|
||||||
"""Test that head handles different batch sizes correctly"""
|
"""Test that compute_segmentation_loss measures pixel classification accuracy"""
|
||||||
# Arrange
|
# Arrange
|
||||||
batch_sizes = [1, 4, 8]
|
output = densepose_head(mock_feature_input)
|
||||||
|
seg_logits = output['segmentation']
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
# Resize target to match output
|
||||||
input_tensor = torch.randn(batch_size, 512, 56, 100)
|
target_resized = torch.nn.functional.interpolate(
|
||||||
|
mock_target_masks.float().unsqueeze(1),
|
||||||
# Act
|
size=seg_logits.shape[2:],
|
||||||
with torch.no_grad():
|
mode='nearest'
|
||||||
segmentation, uv_coords = densepose_head(input_tensor)
|
).squeeze(1).long()
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert segmentation.shape[0] == batch_size
|
|
||||||
assert uv_coords.shape[0] == batch_size
|
|
||||||
|
|
||||||
def test_head_is_trainable(self, densepose_head, mock_feature_input):
|
|
||||||
"""Test that head parameters are trainable"""
|
|
||||||
# Arrange
|
|
||||||
seg_criterion = nn.CrossEntropyLoss()
|
|
||||||
uv_criterion = nn.MSELoss()
|
|
||||||
|
|
||||||
# Create targets with correct shapes
|
|
||||||
seg_target = torch.randint(0, 24, (2, 56, 100)) # Class indices for segmentation
|
|
||||||
uv_target = torch.rand(2, 2, 56, 100) # UV coordinates target
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
segmentation, uv_coords = densepose_head(mock_feature_input)
|
loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized)
|
||||||
seg_loss = seg_criterion(segmentation, seg_target)
|
|
||||||
uv_loss = uv_criterion(uv_coords, uv_target)
|
|
||||||
total_loss = seg_loss + uv_loss
|
|
||||||
total_loss.backward()
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
assert loss is not None
|
||||||
|
assert isinstance(loss, torch.Tensor)
|
||||||
|
assert loss.dim() == 0 # Scalar loss
|
||||||
|
assert loss.item() >= 0 # Loss should be non-negative
|
||||||
|
|
||||||
|
def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
|
||||||
|
"""Test that compute_uv_loss measures UV coordinate regression accuracy"""
|
||||||
|
# Arrange
|
||||||
|
output = densepose_head(mock_feature_input)
|
||||||
|
uv_pred = output['uv_coordinates']
|
||||||
|
|
||||||
|
# Resize target to match output
|
||||||
|
target_resized = torch.nn.functional.interpolate(
|
||||||
|
mock_target_uv,
|
||||||
|
size=uv_pred.shape[2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
loss = densepose_head.compute_uv_loss(uv_pred, target_resized)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert loss is not None
|
||||||
|
assert isinstance(loss, torch.Tensor)
|
||||||
|
assert loss.dim() == 0 # Scalar loss
|
||||||
|
assert loss.item() >= 0 # Loss should be non-negative
|
||||||
|
|
||||||
|
def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
|
||||||
|
"""Test that compute_total_loss combines segmentation and UV losses"""
|
||||||
|
# Arrange
|
||||||
|
output = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
|
# Resize targets to match outputs
|
||||||
|
seg_target = torch.nn.functional.interpolate(
|
||||||
|
mock_target_masks.float().unsqueeze(1),
|
||||||
|
size=output['segmentation'].shape[2:],
|
||||||
|
mode='nearest'
|
||||||
|
).squeeze(1).long()
|
||||||
|
|
||||||
|
uv_target = torch.nn.functional.interpolate(
|
||||||
|
mock_target_uv,
|
||||||
|
size=output['uv_coordinates'].shape[2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
|
||||||
|
seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target)
|
||||||
|
uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert total_loss is not None
|
||||||
|
assert isinstance(total_loss, torch.Tensor)
|
||||||
assert total_loss.item() > 0
|
assert total_loss.item() > 0
|
||||||
# Check that gradients are computed
|
# Total loss should be combination of individual losses
|
||||||
|
expected_total = seg_loss + uv_loss
|
||||||
|
assert torch.allclose(total_loss, expected_total, atol=1e-6)
|
||||||
|
|
||||||
|
def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
|
||||||
|
"""Test that FPN integration enhances multi-scale feature processing"""
|
||||||
|
# Arrange
|
||||||
|
config_with_fpn = mock_config.copy()
|
||||||
|
config_with_fpn['use_fpn'] = True
|
||||||
|
|
||||||
|
config_without_fpn = mock_config.copy()
|
||||||
|
config_without_fpn['use_fpn'] = False
|
||||||
|
|
||||||
|
head_with_fpn = DensePoseHead(config_with_fpn)
|
||||||
|
head_without_fpn = DensePoseHead(config_without_fpn)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
output_with_fpn = head_with_fpn(mock_feature_input)
|
||||||
|
output_without_fpn = head_without_fpn(mock_feature_input)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert output_with_fpn['segmentation'].shape == output_without_fpn['segmentation'].shape
|
||||||
|
assert output_with_fpn['uv_coordinates'].shape == output_without_fpn['uv_coordinates'].shape
|
||||||
|
# Outputs should be different due to FPN
|
||||||
|
assert not torch.allclose(output_with_fpn['segmentation'], output_without_fpn['segmentation'], atol=1e-6)
|
||||||
|
|
||||||
|
def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
|
||||||
|
"""Test that get_prediction_confidence provides uncertainty estimates"""
|
||||||
|
# Arrange
|
||||||
|
output = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
confidence = densepose_head.get_prediction_confidence(output)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert confidence is not None
|
||||||
|
assert isinstance(confidence, dict)
|
||||||
|
assert 'segmentation_confidence' in confidence
|
||||||
|
assert 'uv_confidence' in confidence
|
||||||
|
|
||||||
|
seg_conf = confidence['segmentation_confidence']
|
||||||
|
uv_conf = confidence['uv_confidence']
|
||||||
|
|
||||||
|
assert isinstance(seg_conf, torch.Tensor)
|
||||||
|
assert isinstance(uv_conf, torch.Tensor)
|
||||||
|
assert seg_conf.shape[0] == mock_feature_input.shape[0]
|
||||||
|
assert uv_conf.shape[0] == mock_feature_input.shape[0]
|
||||||
|
|
||||||
|
def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
|
||||||
|
"""Test that post_process_predictions formats output correctly"""
|
||||||
|
# Arrange
|
||||||
|
raw_output = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
processed = densepose_head.post_process_predictions(raw_output)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert processed is not None
|
||||||
|
assert isinstance(processed, dict)
|
||||||
|
assert 'body_parts' in processed
|
||||||
|
assert 'uv_coordinates' in processed
|
||||||
|
assert 'confidence_scores' in processed
|
||||||
|
|
||||||
|
def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
|
||||||
|
"""Test that training mode enables dropout for regularization"""
|
||||||
|
# Arrange
|
||||||
|
densepose_head.train()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
output1 = densepose_head(mock_feature_input)
|
||||||
|
output2 = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
|
# Assert - outputs should be different due to dropout
|
||||||
|
assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
|
||||||
|
assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
|
||||||
|
|
||||||
|
def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
|
||||||
|
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||||
|
# Arrange
|
||||||
|
densepose_head.eval()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
output1 = densepose_head(mock_feature_input)
|
||||||
|
output2 = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
|
# Assert - outputs should be identical in eval mode
|
||||||
|
assert torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
|
||||||
|
assert torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
|
||||||
|
|
||||||
|
def test_head_validates_input_dimensions(self, densepose_head):
|
||||||
|
"""Test that head validates input dimensions"""
|
||||||
|
# Arrange
|
||||||
|
invalid_input = torch.randn(2, 128, 56, 56) # Wrong number of channels
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(DensePoseError):
|
||||||
|
densepose_head(invalid_input)
|
||||||
|
|
||||||
|
def test_head_handles_different_input_sizes(self, densepose_head):
|
||||||
|
"""Test that head handles different input sizes"""
|
||||||
|
# Arrange
|
||||||
|
small_input = torch.randn(1, 256, 28, 28)
|
||||||
|
large_input = torch.randn(1, 256, 112, 112)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
small_output = densepose_head(small_input)
|
||||||
|
large_output = densepose_head(large_input)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:]
|
||||||
|
assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:]
|
||||||
|
|
||||||
|
def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
|
||||||
|
"""Test that head supports gradient computation for training"""
|
||||||
|
# Arrange
|
||||||
|
densepose_head.train()
|
||||||
|
optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
output = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
|
# Resize targets
|
||||||
|
seg_target = torch.nn.functional.interpolate(
|
||||||
|
mock_target_masks.float().unsqueeze(1),
|
||||||
|
size=output['segmentation'].shape[2:],
|
||||||
|
mode='nearest'
|
||||||
|
).squeeze(1).long()
|
||||||
|
|
||||||
|
uv_target = torch.nn.functional.interpolate(
|
||||||
|
mock_target_uv,
|
||||||
|
size=output['uv_coordinates'].shape[2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Assert
|
||||||
for param in densepose_head.parameters():
|
for param in densepose_head.parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
assert param.grad is not None
|
assert param.grad is not None
|
||||||
|
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||||
|
|
||||||
def test_head_handles_invalid_input_shape(self, densepose_head):
|
def test_head_configuration_validation(self):
|
||||||
"""Test that head handles invalid input shapes gracefully"""
|
"""Test that head validates configuration parameters"""
|
||||||
# Arrange
|
# Arrange
|
||||||
invalid_input = torch.randn(2, 256, 56, 100) # Wrong number of channels
|
invalid_config = {
|
||||||
|
'input_channels': 0, # Invalid
|
||||||
|
'num_body_parts': -1, # Invalid
|
||||||
|
'num_uv_coordinates': 2
|
||||||
|
}
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(ValueError):
|
||||||
densepose_head(invalid_input)
|
DensePoseHead(invalid_config)
|
||||||
|
|
||||||
def test_head_supports_evaluation_mode(self, densepose_head, mock_feature_input):
|
def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
|
||||||
"""Test that head supports evaluation mode"""
|
"""Test that model state can be saved and loaded"""
|
||||||
# Act
|
# Arrange
|
||||||
densepose_head.eval()
|
original_output = densepose_head(mock_feature_input)
|
||||||
|
|
||||||
with torch.no_grad():
|
# Act - Save state
|
||||||
seg1, uv1 = densepose_head(mock_feature_input)
|
state_dict = densepose_head.state_dict()
|
||||||
seg2, uv2 = densepose_head(mock_feature_input)
|
|
||||||
|
|
||||||
# Assert - In eval mode with same input, outputs should be identical
|
# Create new head and load state
|
||||||
assert torch.allclose(seg1, seg2, atol=1e-6)
|
new_head = DensePoseHead(densepose_head.config)
|
||||||
assert torch.allclose(uv1, uv2, atol=1e-6)
|
new_head.load_state_dict(state_dict)
|
||||||
|
new_output = new_head(mock_feature_input)
|
||||||
def test_head_output_quality(self, densepose_head, mock_feature_input):
|
|
||||||
"""Test that head produces meaningful outputs"""
|
|
||||||
# Act
|
|
||||||
with torch.no_grad():
|
|
||||||
segmentation, uv_coords = densepose_head(mock_feature_input)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Outputs should not contain NaN or Inf values
|
assert torch.allclose(original_output['segmentation'], new_output['segmentation'], atol=1e-6)
|
||||||
assert not torch.isnan(segmentation).any()
|
assert torch.allclose(original_output['uv_coordinates'], new_output['uv_coordinates'], atol=1e-6)
|
||||||
assert not torch.isinf(segmentation).any()
|
|
||||||
assert not torch.isnan(uv_coords).any()
|
|
||||||
assert not torch.isinf(uv_coords).any()
|
|
||||||
|
|
||||||
# Outputs should have reasonable variance (not all zeros or ones)
|
|
||||||
assert segmentation.std() > 0.01
|
|
||||||
assert uv_coords.std() > 0.01
|
|
||||||
@@ -3,27 +3,27 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from src.models.modality_translation import ModalityTranslationNetwork
|
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
|
||||||
|
|
||||||
|
|
||||||
class TestModalityTranslationNetwork:
|
class TestModalityTranslationNetwork:
|
||||||
"""Test suite for Modality Translation Network following London School TDD principles"""
|
"""Test suite for Modality Translation Network following London School TDD principles"""
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_csi_input(self):
|
|
||||||
"""Generate synthetic CSI input tensor for testing"""
|
|
||||||
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
|
|
||||||
return torch.randn(2, 3, 56, 100)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Configuration for modality translation network"""
|
"""Configuration for modality translation network"""
|
||||||
return {
|
return {
|
||||||
'input_channels': 3,
|
'input_channels': 6, # Real and imaginary parts for 3 antennas
|
||||||
'hidden_dim': 256,
|
'hidden_channels': [64, 128, 256],
|
||||||
'output_dim': 512,
|
'output_channels': 256,
|
||||||
'num_layers': 3,
|
'kernel_size': 3,
|
||||||
'dropout_rate': 0.1
|
'stride': 1,
|
||||||
|
'padding': 1,
|
||||||
|
'dropout_rate': 0.1,
|
||||||
|
'activation': 'relu',
|
||||||
|
'normalization': 'batch',
|
||||||
|
'use_attention': True,
|
||||||
|
'attention_heads': 8
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -31,98 +31,263 @@ class TestModalityTranslationNetwork:
|
|||||||
"""Create modality translation network instance for testing"""
|
"""Create modality translation network instance for testing"""
|
||||||
return ModalityTranslationNetwork(mock_config)
|
return ModalityTranslationNetwork(mock_config)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_csi_input(self):
|
||||||
|
"""Generate mock CSI input tensor"""
|
||||||
|
batch_size = 4
|
||||||
|
channels = 6 # Real and imaginary parts for 3 antennas
|
||||||
|
height = 56 # Number of subcarriers
|
||||||
|
width = 100 # Time samples
|
||||||
|
return torch.randn(batch_size, channels, height, width)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_target_features(self):
|
||||||
|
"""Generate mock target feature tensor for training"""
|
||||||
|
batch_size = 4
|
||||||
|
feature_dim = 256
|
||||||
|
spatial_height = 56
|
||||||
|
spatial_width = 100
|
||||||
|
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
|
||||||
|
|
||||||
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
||||||
"""Test that network initializes with correct architecture"""
|
"""Test that modality translation network initializes with correct architecture"""
|
||||||
# Act
|
# Act
|
||||||
network = ModalityTranslationNetwork(mock_config)
|
network = ModalityTranslationNetwork(mock_config)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert network is not None
|
assert network is not None
|
||||||
assert isinstance(network, nn.Module)
|
assert isinstance(network, nn.Module)
|
||||||
|
assert network.input_channels == mock_config['input_channels']
|
||||||
|
assert network.output_channels == mock_config['output_channels']
|
||||||
|
assert network.use_attention == mock_config['use_attention']
|
||||||
assert hasattr(network, 'encoder')
|
assert hasattr(network, 'encoder')
|
||||||
assert hasattr(network, 'decoder')
|
assert hasattr(network, 'decoder')
|
||||||
assert network.input_channels == mock_config['input_channels']
|
if mock_config['use_attention']:
|
||||||
assert network.hidden_dim == mock_config['hidden_dim']
|
assert hasattr(network, 'attention')
|
||||||
assert network.output_dim == mock_config['output_dim']
|
|
||||||
|
|
||||||
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
||||||
"""Test that forward pass produces correctly shaped output"""
|
"""Test that forward pass produces correctly shaped output"""
|
||||||
# Act
|
# Act
|
||||||
with torch.no_grad():
|
output = translation_network(mock_csi_input)
|
||||||
output = translation_network(mock_csi_input)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert isinstance(output, torch.Tensor)
|
assert isinstance(output, torch.Tensor)
|
||||||
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
||||||
assert output.shape[1] == translation_network.output_dim # Correct output dimension
|
assert output.shape[1] == translation_network.output_channels # Correct output channels
|
||||||
assert len(output.shape) == 4 # Should maintain spatial dimensions
|
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
|
||||||
|
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
|
||||||
|
|
||||||
def test_forward_pass_handles_different_batch_sizes(self, translation_network):
|
def test_forward_pass_handles_different_input_sizes(self, translation_network):
|
||||||
"""Test that network handles different batch sizes correctly"""
|
"""Test that forward pass handles different input sizes"""
|
||||||
# Arrange
|
# Arrange
|
||||||
batch_sizes = [1, 4, 8]
|
small_input = torch.randn(2, 6, 28, 50)
|
||||||
|
large_input = torch.randn(8, 6, 112, 200)
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
# Act
|
||||||
input_tensor = torch.randn(batch_size, 3, 56, 100)
|
small_output = translation_network(small_input)
|
||||||
|
large_output = translation_network(large_input)
|
||||||
# Act
|
|
||||||
with torch.no_grad():
|
# Assert
|
||||||
output = translation_network(input_tensor)
|
assert small_output.shape == (2, 256, 28, 50)
|
||||||
|
assert large_output.shape == (8, 256, 112, 200)
|
||||||
# Assert
|
|
||||||
assert output.shape[0] == batch_size
|
|
||||||
assert output.shape[1] == translation_network.output_dim
|
|
||||||
|
|
||||||
def test_network_is_trainable(self, translation_network, mock_csi_input):
|
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
|
||||||
"""Test that network parameters are trainable"""
|
"""Test that encoder extracts hierarchical features"""
|
||||||
|
# Act
|
||||||
|
features = translation_network.encode(mock_csi_input)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert features is not None
|
||||||
|
assert isinstance(features, list)
|
||||||
|
assert len(features) == len(translation_network.encoder)
|
||||||
|
|
||||||
|
# Check feature map sizes decrease with depth
|
||||||
|
for i in range(1, len(features)):
|
||||||
|
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
|
||||||
|
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
|
||||||
|
|
||||||
|
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
|
||||||
|
"""Test that decoder reconstructs target feature representation"""
|
||||||
# Arrange
|
# Arrange
|
||||||
criterion = nn.MSELoss()
|
encoded_features = translation_network.encode(mock_csi_input)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
decoded_output = translation_network.decode(encoded_features)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert decoded_output is not None
|
||||||
|
assert isinstance(decoded_output, torch.Tensor)
|
||||||
|
assert decoded_output.shape[1] == translation_network.output_channels
|
||||||
|
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
|
||||||
|
|
||||||
|
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
|
||||||
|
"""Test that attention mechanism enhances feature representation"""
|
||||||
|
# Arrange
|
||||||
|
config_with_attention = mock_config.copy()
|
||||||
|
config_with_attention['use_attention'] = True
|
||||||
|
|
||||||
|
config_without_attention = mock_config.copy()
|
||||||
|
config_without_attention['use_attention'] = False
|
||||||
|
|
||||||
|
network_with_attention = ModalityTranslationNetwork(config_with_attention)
|
||||||
|
network_without_attention = ModalityTranslationNetwork(config_without_attention)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
output_with_attention = network_with_attention(mock_csi_input)
|
||||||
|
output_without_attention = network_without_attention(mock_csi_input)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert output_with_attention.shape == output_without_attention.shape
|
||||||
|
# Outputs should be different due to attention mechanism
|
||||||
|
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
|
||||||
|
|
||||||
|
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
|
||||||
|
"""Test that training mode enables dropout for regularization"""
|
||||||
|
# Arrange
|
||||||
|
translation_network.train()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
output1 = translation_network(mock_csi_input)
|
||||||
|
output2 = translation_network(mock_csi_input)
|
||||||
|
|
||||||
|
# Assert - outputs should be different due to dropout
|
||||||
|
assert not torch.allclose(output1, output2, atol=1e-6)
|
||||||
|
|
||||||
|
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
|
||||||
|
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||||
|
# Arrange
|
||||||
|
translation_network.eval()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
output1 = translation_network(mock_csi_input)
|
||||||
|
output2 = translation_network(mock_csi_input)
|
||||||
|
|
||||||
|
# Assert - outputs should be identical in eval mode
|
||||||
|
assert torch.allclose(output1, output2, atol=1e-6)
|
||||||
|
|
||||||
|
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
|
||||||
|
"""Test that compute_translation_loss measures feature alignment"""
|
||||||
|
# Arrange
|
||||||
|
predicted_features = translation_network(mock_csi_input)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert loss is not None
|
||||||
|
assert isinstance(loss, torch.Tensor)
|
||||||
|
assert loss.dim() == 0 # Scalar loss
|
||||||
|
assert loss.item() >= 0 # Loss should be non-negative
|
||||||
|
|
||||||
|
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
|
||||||
|
"""Test that compute_translation_loss handles different loss types"""
|
||||||
|
# Arrange
|
||||||
|
predicted_features = translation_network(mock_csi_input)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
|
||||||
|
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert mse_loss is not None
|
||||||
|
assert l1_loss is not None
|
||||||
|
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
|
||||||
|
|
||||||
|
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
|
||||||
|
"""Test that get_feature_statistics provides feature analysis"""
|
||||||
|
# Arrange
|
||||||
|
output = translation_network(mock_csi_input)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
stats = translation_network.get_feature_statistics(output)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert stats is not None
|
||||||
|
assert isinstance(stats, dict)
|
||||||
|
assert 'mean' in stats
|
||||||
|
assert 'std' in stats
|
||||||
|
assert 'min' in stats
|
||||||
|
assert 'max' in stats
|
||||||
|
assert 'sparsity' in stats
|
||||||
|
|
||||||
|
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
    """Test that network supports gradient computation for training.

    Runs one forward/backward step and checks that every trainable
    parameter received a non-zero gradient.
    """
    # Arrange - put the network in training mode with a standard optimizer
    translation_network.train()
    optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)

    # Act - one forward/backward pass against the target features
    output = translation_network(mock_csi_input)
    loss = translation_network.compute_translation_loss(output, mock_target_features)
    optimizer.zero_grad()
    loss.backward()

    # Assert
    assert loss.item() > 0
    # Check that gradients are computed and are not identically zero
    for param in translation_network.parameters():
        if param.requires_grad:
            assert param.grad is not None
            assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_network_validates_input_dimensions(self, translation_network):
    """Test that network validates input dimensions.

    Feeding a tensor with the wrong channel count must raise the
    project's translation error rather than fail silently.
    """
    # Arrange
    invalid_input = torch.randn(4, 3, 56, 100)  # Wrong number of channels

    # Act & Assert
    with pytest.raises(ModalityTranslationError):
        translation_network(invalid_input)
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
    """Test that network supports evaluation mode.

    In eval mode (dropout/batch-norm frozen) two forward passes on the
    same input must produce identical outputs.
    """
    # Act
    translation_network.eval()
    with torch.no_grad():
        output1 = translation_network(mock_csi_input)
        output2 = translation_network(mock_csi_input)

    # Assert - In eval mode with same input, outputs should be identical
    assert torch.allclose(output1, output2, atol=1e-6)

def test_network_handles_batch_size_one(self, translation_network):
    """Test that network handles single sample inference"""
    # Arrange
    single_input = torch.randn(1, 6, 56, 100)

    # Act
    with torch.no_grad():
        output = translation_network(single_input)

    # Assert - output keeps batch size 1 and contains finite values
    assert output.shape == (1, 256, 56, 100)
    assert not torch.isnan(output).any()
    assert not torch.isinf(output).any()

def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
    """Test that network extracts meaningful features"""
    # Act
    with torch.no_grad():
        output = translation_network(mock_csi_input)

    # Assert - features should have reasonable statistics
    assert output.std() > 0.01  # Features should have some variance
    assert output.std() < 10.0  # But not be too extreme

def test_save_and_load_model_state(self, translation_network, mock_csi_input):
    """Test that model state can be saved and loaded"""
    # Arrange
    original_output = translation_network(mock_csi_input)

    # Act - Save state
    state_dict = translation_network.state_dict()

    # Create new network and load state
    new_network = ModalityTranslationNetwork(translation_network.config)
    new_network.load_state_dict(state_dict)
    new_output = new_network(mock_csi_input)

    # Assert - the restored network reproduces the original output
    assert torch.allclose(original_output, new_output, atol=1e-6)
def test_network_configuration_validation(self):
    """Test that network validates configuration parameters.

    A zero channel count and an empty hidden-layer list are both
    invalid and must be rejected at construction time.
    """
    # Arrange
    invalid_config = {
        'input_channels': 0,    # Invalid
        'hidden_channels': [],  # Invalid
        'output_channels': 256
    }

    # Act & Assert
    with pytest.raises(ValueError):
        ModalityTranslationNetwork(invalid_config)
def test_feature_visualization_support(self, translation_network, mock_csi_input):
    """Test that network supports feature visualization.

    get_intermediate_features must expose encoder/decoder activations,
    plus attention weights when the attention branch is enabled.
    """
    # Act
    features = translation_network.get_intermediate_features(mock_csi_input)

    # Assert
    assert features is not None
    assert isinstance(features, dict)
    assert 'encoder_features' in features
    assert 'decoder_features' in features
    # Attention weights are only produced when the network uses attention
    if translation_network.use_attention:
        assert 'attention_weights' in features
244
tests/unit/test_router_interface.py
Normal file
244
tests/unit/test_router_interface.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from src.hardware.router_interface import RouterInterface, RouterConnectionError
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouterInterface:
    """Test suite for Router Interface following London School TDD principles.

    All SSH interaction is mocked out via ``paramiko.SSHClient`` patches so
    the tests exercise RouterInterface's behavior without a real router.
    """

    @pytest.fixture
    def mock_config(self):
        """Configuration for router interface"""
        return {
            'router_ip': '192.168.1.1',
            'username': 'admin',
            'password': 'password',
            'ssh_port': 22,
            'timeout': 30,
            'max_retries': 3
        }

    @pytest.fixture
    def router_interface(self, mock_config):
        """Create router interface instance for testing"""
        return RouterInterface(mock_config)

    @pytest.fixture
    def mock_ssh_client(self):
        """Mock SSH client for testing"""
        mock_client = Mock()
        mock_client.connect = Mock()
        mock_client.exec_command = Mock()
        mock_client.close = Mock()
        return mock_client

    def test_interface_initialization_creates_correct_configuration(self, mock_config):
        """Test that router interface initializes with correct configuration"""
        # Act
        interface = RouterInterface(mock_config)

        # Assert - every config field is mirrored onto the instance
        assert interface is not None
        assert interface.router_ip == mock_config['router_ip']
        assert interface.username == mock_config['username']
        assert interface.password == mock_config['password']
        assert interface.ssh_port == mock_config['ssh_port']
        assert interface.timeout == mock_config['timeout']
        assert interface.max_retries == mock_config['max_retries']
        assert not interface.is_connected

    @patch('paramiko.SSHClient')
    def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that connect method establishes SSH connection"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client

        # Act
        result = router_interface.connect()

        # Assert - connection succeeded and paramiko was driven correctly
        assert result is True
        assert router_interface.is_connected is True
        mock_ssh_client.set_missing_host_key_policy.assert_called_once()
        mock_ssh_client.connect.assert_called_once_with(
            hostname=router_interface.router_ip,
            port=router_interface.ssh_port,
            username=router_interface.username,
            password=router_interface.password,
            timeout=router_interface.timeout
        )

    @patch('paramiko.SSHClient')
    def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that connect method handles connection failures gracefully"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        mock_ssh_client.connect.side_effect = Exception("Connection failed")

        # Act & Assert
        with pytest.raises(RouterConnectionError):
            router_interface.connect()

        assert router_interface.is_connected is False

    @patch('paramiko.SSHClient')
    def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that disconnect method closes SSH connection"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        router_interface.connect()

        # Act
        router_interface.disconnect()

        # Assert
        assert router_interface.is_connected is False
        mock_ssh_client.close.assert_called_once()

    @patch('paramiko.SSHClient')
    def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that execute_command runs SSH commands correctly"""
        # Arrange - exec_command returns (stdin, stdout, stderr) streams
        mock_ssh_class.return_value = mock_ssh_client
        mock_stdout = Mock()
        mock_stdout.read.return_value = b"command output"
        mock_stderr = Mock()
        mock_stderr.read.return_value = b""
        mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)

        router_interface.connect()

        # Act
        result = router_interface.execute_command("test command")

        # Assert - stdout is decoded and returned verbatim
        assert result == "command output"
        mock_ssh_client.exec_command.assert_called_with("test command")

    @patch('paramiko.SSHClient')
    def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that execute_command handles command errors"""
        # Arrange - non-empty stderr signals a failed command
        mock_ssh_class.return_value = mock_ssh_client
        mock_stdout = Mock()
        mock_stdout.read.return_value = b""
        mock_stderr = Mock()
        mock_stderr.read.return_value = b"command error"
        mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)

        router_interface.connect()

        # Act & Assert
        with pytest.raises(RouterConnectionError):
            router_interface.execute_command("failing command")

    def test_execute_command_requires_connection(self, router_interface):
        """Test that execute_command requires active connection"""
        # Act & Assert - no connect() call was made, so this must fail
        with pytest.raises(RouterConnectionError):
            router_interface.execute_command("test command")

    @patch('paramiko.SSHClient')
    def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that get_router_info retrieves router system information"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        mock_stdout = Mock()
        mock_stdout.read.return_value = b"Router Model: AC1900\nFirmware: 1.2.3"
        mock_stderr = Mock()
        mock_stderr.read.return_value = b""
        mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)

        router_interface.connect()

        # Act
        info = router_interface.get_router_info()

        # Assert - output is parsed into a structured dict
        assert info is not None
        assert isinstance(info, dict)
        assert 'model' in info
        assert 'firmware' in info

    @patch('paramiko.SSHClient')
    def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that enable_monitor_mode configures WiFi monitoring"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        mock_stdout = Mock()
        mock_stdout.read.return_value = b"Monitor mode enabled"
        mock_stderr = Mock()
        mock_stderr.read.return_value = b""
        mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)

        router_interface.connect()

        # Act
        result = router_interface.enable_monitor_mode("wlan0")

        # Assert
        assert result is True
        mock_ssh_client.exec_command.assert_called()

    @patch('paramiko.SSHClient')
    def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that disable_monitor_mode disables WiFi monitoring"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        mock_stdout = Mock()
        mock_stdout.read.return_value = b"Monitor mode disabled"
        mock_stderr = Mock()
        mock_stderr.read.return_value = b""
        mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)

        router_interface.connect()

        # Act
        result = router_interface.disable_monitor_mode("wlan0")

        # Assert
        assert result is True
        mock_ssh_client.exec_command.assert_called()

    @patch('paramiko.SSHClient')
    def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that router interface supports context manager protocol"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client

        # Act
        with router_interface as interface:
            # Assert - __enter__ established the connection
            assert interface.is_connected is True

        # Assert - connection should be closed after context
        assert router_interface.is_connected is False
        mock_ssh_client.close.assert_called_once()

    def test_interface_validates_configuration(self):
        """Test that router interface validates configuration parameters"""
        # Arrange
        invalid_config = {
            'router_ip': '',  # Invalid IP
            'username': 'admin',
            'password': 'password'
        }

        # Act & Assert
        with pytest.raises(ValueError):
            RouterInterface(invalid_config)

    @patch('paramiko.SSHClient')
    def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that interface implements retry logic for failed operations"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None]  # Fail once, then succeed

        # Act
        result = router_interface.connect()

        # Assert
        assert result is True
        assert mock_ssh_client.connect.call_count == 2  # Should retry once
|
||||||
Reference in New Issue
Block a user