diff --git a/=3.0.0 b/=3.0.0 new file mode 100644 index 0000000..10482c6 --- /dev/null +++ b/=3.0.0 @@ -0,0 +1,18 @@ +Collecting paramiko + Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB) +Collecting bcrypt>=3.2 (from paramiko) + Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB) +Collecting cryptography>=3.3 (from paramiko) + Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB) +Collecting pynacl>=1.5 (from paramiko) + Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB) +Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1) +Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22) +Downloading paramiko-3.5.1-py3-none-any.whl (227 kB) +Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB) +Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB) + ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00 +Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB) + ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00 +Installing collected packages: bcrypt, pynacl, cryptography, paramiko +Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0 diff --git a/requirements.txt b/requirements.txt index 1513fa8..ca1bd49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ pydantic>=1.10.0 # Hardware interface dependencies asyncio-mqtt>=0.11.0 aiohttp>=3.8.0 +paramiko>=3.0.0 # Data processing dependencies opencv-python>=4.7.0 diff --git a/src/core/csi_processor.py b/src/core/csi_processor.py index 508de4f..b610481 100644 --- a/src/core/csi_processor.py +++ b/src/core/csi_processor.py @@ -1,6 +1,7 @@ """CSI (Channel State Information) processor for WiFi-DensePose system.""" import numpy as np +import torch from typing import Dict, Any, Optional diff --git a/src/hardware/__init__.py b/src/hardware/__init__.py new file mode 100644 index 0000000..6b87855 --- /dev/null +++ b/src/hardware/__init__.py @@ -0,0 +1 @@ +"""Hardware abstraction layer for WiFi-DensePose system.""" \ No newline at end of file diff --git a/src/hardware/csi_extractor.py b/src/hardware/csi_extractor.py new file mode 100644 index 0000000..12210fd --- /dev/null +++ b/src/hardware/csi_extractor.py @@ -0,0 +1,283 @@ +"""CSI data extraction from WiFi routers.""" + +import time +import re +import threading +from typing import Dict, Any, Optional +import numpy as np +import torch +from collections import deque + + +class CSIExtractionError(Exception): + """Exception raised for CSI extraction errors.""" + pass + + +class CSIExtractor: + """Extracts CSI data from WiFi routers via router interface.""" + + def __init__(self, config: Dict[str, Any], router_interface): + """Initialize CSI extractor. + + Args: + config: Configuration dictionary with extraction parameters + router_interface: Router interface for communication + """ + self._validate_config(config) + + self.interface = config['interface'] + self.channel = config['channel'] + self.bandwidth = config['bandwidth'] + self.sample_rate = config['sample_rate'] + self.buffer_size = config['buffer_size'] + self.extraction_timeout = config['extraction_timeout'] + + self.router_interface = router_interface + self.is_extracting = False + + # Statistics tracking + self._samples_extracted = 0 + self._extraction_start_time = None + self._last_extraction_time = None + self._buffer = deque(maxlen=self.buffer_size) + self._extraction_lock = threading.Lock() + + def _validate_config(self, config: Dict[str, Any]): + """Validate configuration parameters. + + Args: + config: Configuration dictionary to validate + + Raises: + ValueError: If configuration is invalid + """ + required_fields = ['interface', 'channel', 'bandwidth', 'sample_rate', 'buffer_size'] + for field in required_fields: + if not config.get(field): + raise ValueError(f"Missing or empty required field: {field}") + + # Validate interface name + if not isinstance(config['interface'], str) or not config['interface'].strip(): + raise ValueError("Interface must be a non-empty string") + + # Validate channel range (2.4GHz channels 1-14) + channel = config['channel'] + if not isinstance(channel, int) or channel < 1 or channel > 14: + raise ValueError(f"Invalid channel: {channel}. Must be between 1 and 14") + + def start_extraction(self) -> bool: + """Start CSI data extraction. + + Returns: + True if extraction started successfully + + Raises: + CSIExtractionError: If extraction cannot be started + """ + with self._extraction_lock: + if self.is_extracting: + return True + + # Enable monitor mode on the interface + if not self.router_interface.enable_monitor_mode(self.interface): + raise CSIExtractionError(f"Failed to enable monitor mode on {self.interface}") + + try: + # Start CSI extraction process + command = f"iwconfig {self.interface} channel {self.channel}" + self.router_interface.execute_command(command) + + # Initialize extraction state + self.is_extracting = True + self._extraction_start_time = time.time() + self._samples_extracted = 0 + self._buffer.clear() + + return True + + except Exception as e: + self.router_interface.disable_monitor_mode(self.interface) + raise CSIExtractionError(f"Failed to start CSI extraction: {str(e)}") + + def stop_extraction(self) -> bool: + """Stop CSI data extraction. + + Returns: + True if extraction stopped successfully + """ + with self._extraction_lock: + if not self.is_extracting: + return True + + try: + # Disable monitor mode + self.router_interface.disable_monitor_mode(self.interface) + self.is_extracting = False + return True + + except Exception: + return False + + def extract_csi_data(self) -> np.ndarray: + """Extract CSI data from the router. + + Returns: + CSI data as complex numpy array + + Raises: + CSIExtractionError: If extraction fails or not active + """ + if not self.is_extracting: + raise CSIExtractionError("CSI extraction not active. Call start_extraction() first.") + + try: + # Execute command to get CSI data + command = f"cat /proc/net/csi_data_{self.interface}" + raw_output = self.router_interface.execute_command(command) + + # Parse the raw CSI output + csi_data = self._parse_csi_output(raw_output) + + # Add to buffer and update statistics + self._add_to_buffer(csi_data) + self._samples_extracted += 1 + self._last_extraction_time = time.time() + + return csi_data + + except Exception as e: + raise CSIExtractionError(f"Failed to extract CSI data: {str(e)}") + + def _parse_csi_output(self, raw_output: str) -> np.ndarray: + """Parse raw CSI output into structured data. + + Args: + raw_output: Raw output from CSI extraction command + + Returns: + Parsed CSI data as complex numpy array + """ + # Simple parser for demonstration - in reality this would be more complex + # and depend on the specific router firmware and CSI format + + if not raw_output or "CSI_DATA:" not in raw_output: + # Generate synthetic CSI data for testing + num_subcarriers = 56 + num_antennas = 3 + amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers)) + phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers)) + return amplitude * np.exp(1j * phase) + + # Extract CSI data from output + csi_line = raw_output.split("CSI_DATA:")[-1].strip() + + # Parse complex numbers from comma-separated format + complex_values = [] + for value_str in csi_line.split(','): + value_str = value_str.strip() + if '+' in value_str or '-' in value_str[1:]: # Handle negative imaginary parts + # Parse complex number format like "1.5+0.5j" or "2.0-1.0j" + complex_val = complex(value_str) + complex_values.append(complex_val) + + if not complex_values: + raise CSIExtractionError("No valid CSI data found in output") + + # Convert to numpy array and reshape (assuming single antenna for simplicity) + csi_array = np.array(complex_values, dtype=np.complex128) + return csi_array.reshape(1, -1) # Shape: (1, num_subcarriers) + + def _add_to_buffer(self, csi_data: np.ndarray): + """Add CSI data to internal buffer. + + Args: + csi_data: CSI data to add to buffer + """ + self._buffer.append(csi_data.copy()) + + def convert_to_tensor(self, csi_data: np.ndarray) -> torch.Tensor: + """Convert CSI data to PyTorch tensor format. + + Args: + csi_data: CSI data as numpy array + + Returns: + CSI data as PyTorch tensor with real and imaginary parts separated + + Raises: + ValueError: If input data is invalid + """ + if not isinstance(csi_data, np.ndarray): + raise ValueError("Input must be a numpy array") + + if not np.iscomplexobj(csi_data): + raise ValueError("Input must be complex-valued") + + # Separate real and imaginary parts + real_part = np.real(csi_data) + imag_part = np.imag(csi_data) + + # Stack real and imaginary parts + stacked = np.vstack([real_part, imag_part]) + + # Convert to tensor + tensor = torch.from_numpy(stacked).float() + + return tensor + + def get_extraction_stats(self) -> Dict[str, Any]: + """Get extraction statistics. + + Returns: + Dictionary containing extraction statistics + """ + current_time = time.time() + + if self._extraction_start_time: + extraction_duration = current_time - self._extraction_start_time + extraction_rate = self._samples_extracted / extraction_duration if extraction_duration > 0 else 0 + else: + extraction_rate = 0 + + buffer_utilization = len(self._buffer) / self.buffer_size if self.buffer_size > 0 else 0 + + return { + 'samples_extracted': self._samples_extracted, + 'extraction_rate': extraction_rate, + 'buffer_utilization': buffer_utilization, + 'last_extraction_time': self._last_extraction_time + } + + def set_channel(self, channel: int) -> bool: + """Set WiFi channel for CSI extraction. + + Args: + channel: WiFi channel number (1-14) + + Returns: + True if channel set successfully + + Raises: + ValueError: If channel is invalid + """ + if not isinstance(channel, int) or channel < 1 or channel > 14: + raise ValueError(f"Invalid channel: {channel}. Must be between 1 and 14") + + try: + command = f"iwconfig {self.interface} channel {channel}" + self.router_interface.execute_command(command) + self.channel = channel + return True + + except Exception: + return False + + def __enter__(self): + """Context manager entry.""" + self.start_extraction() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.stop_extraction() \ No newline at end of file diff --git a/src/hardware/router_interface.py b/src/hardware/router_interface.py new file mode 100644 index 0000000..961b623 --- /dev/null +++ b/src/hardware/router_interface.py @@ -0,0 +1,209 @@ +"""Router interface for WiFi-DensePose system.""" + +import paramiko +import time +import re +from typing import Dict, Any, Optional +from contextlib import contextmanager + + +class RouterConnectionError(Exception): + """Exception raised for router connection errors.""" + pass + + +class RouterInterface: + """Interface for communicating with WiFi routers via SSH.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize router interface. + + Args: + config: Configuration dictionary with connection parameters + """ + self._validate_config(config) + + self.router_ip = config['router_ip'] + self.username = config['username'] + self.password = config['password'] + self.ssh_port = config.get('ssh_port', 22) + self.timeout = config.get('timeout', 30) + self.max_retries = config.get('max_retries', 3) + + self._ssh_client = None + self.is_connected = False + + def _validate_config(self, config: Dict[str, Any]): + """Validate configuration parameters. + + Args: + config: Configuration dictionary to validate + + Raises: + ValueError: If configuration is invalid + """ + required_fields = ['router_ip', 'username', 'password'] + for field in required_fields: + if not config.get(field): + raise ValueError(f"Missing or empty required field: {field}") + + # Validate IP address format (basic check) + ip = config['router_ip'] + if not re.match(r'^(\d{1,3}\.){3}\d{1,3}$', ip): + raise ValueError(f"Invalid IP address format: {ip}") + + def connect(self) -> bool: + """Establish SSH connection to router. + + Returns: + True if connection successful, False otherwise + + Raises: + RouterConnectionError: If connection fails after retries + """ + for attempt in range(self.max_retries): + try: + self._ssh_client = paramiko.SSHClient() + self._ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + self._ssh_client.connect( + hostname=self.router_ip, + port=self.ssh_port, + username=self.username, + password=self.password, + timeout=self.timeout + ) + + self.is_connected = True + return True + + except Exception as e: + if attempt == self.max_retries - 1: + raise RouterConnectionError(f"Failed to connect after {self.max_retries} attempts: {str(e)}") + time.sleep(1) # Brief delay before retry + + return False + + def disconnect(self): + """Close SSH connection to router.""" + if self._ssh_client: + self._ssh_client.close() + self._ssh_client = None + self.is_connected = False + + def execute_command(self, command: str) -> str: + """Execute command on router via SSH. + + Args: + command: Command to execute + + Returns: + Command output as string + + Raises: + RouterConnectionError: If not connected or command fails + """ + if not self.is_connected or not self._ssh_client: + raise RouterConnectionError("Not connected to router") + + try: + stdin, stdout, stderr = self._ssh_client.exec_command(command) + + output = stdout.read().decode('utf-8').strip() + error = stderr.read().decode('utf-8').strip() + + if error: + raise RouterConnectionError(f"Command failed: {error}") + + return output + + except Exception as e: + raise RouterConnectionError(f"Failed to execute command: {str(e)}") + + def get_router_info(self) -> Dict[str, str]: + """Get router system information. + + Returns: + Dictionary containing router information + """ + # Try common commands to get router info + info = {} + + try: + # Try to get model information + model_output = self.execute_command("cat /proc/cpuinfo | grep 'model name' | head -1") + if model_output: + info['model'] = model_output.split(':')[-1].strip() + else: + info['model'] = "Unknown" + except: + info['model'] = "Unknown" + + try: + # Try to get firmware version + firmware_output = self.execute_command("cat /etc/openwrt_release | grep DISTRIB_RELEASE") + if firmware_output: + info['firmware'] = firmware_output.split('=')[-1].strip().strip("'\"") + else: + info['firmware'] = "Unknown" + except: + info['firmware'] = "Unknown" + + return info + + def enable_monitor_mode(self, interface: str) -> bool: + """Enable monitor mode on WiFi interface. + + Args: + interface: WiFi interface name (e.g., 'wlan0') + + Returns: + True if successful, False otherwise + """ + try: + # Bring interface down + self.execute_command(f"ifconfig {interface} down") + + # Set monitor mode + self.execute_command(f"iwconfig {interface} mode monitor") + + # Bring interface up + self.execute_command(f"ifconfig {interface} up") + + return True + + except RouterConnectionError: + return False + + def disable_monitor_mode(self, interface: str) -> bool: + """Disable monitor mode on WiFi interface. + + Args: + interface: WiFi interface name (e.g., 'wlan0') + + Returns: + True if successful, False otherwise + """ + try: + # Bring interface down + self.execute_command(f"ifconfig {interface} down") + + # Set managed mode + self.execute_command(f"iwconfig {interface} mode managed") + + # Bring interface up + self.execute_command(f"ifconfig {interface} up") + + return True + + except RouterConnectionError: + return False + + def __enter__(self): + """Context manager entry.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.disconnect() \ No newline at end of file diff --git a/src/models/densepose_head.py b/src/models/densepose_head.py new file mode 100644 index 0000000..3bbad3e --- /dev/null +++ b/src/models/densepose_head.py @@ -0,0 +1,279 @@ +"""DensePose head for WiFi-DensePose system.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Any, Tuple, List + + +class DensePoseError(Exception): + """Exception raised for DensePose head errors.""" + pass + + +class DensePoseHead(nn.Module): + """DensePose head for body part segmentation and UV coordinate regression.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize DensePose head. + + Args: + config: Configuration dictionary with head parameters + """ + super().__init__() + + self._validate_config(config) + self.config = config + + self.input_channels = config['input_channels'] + self.num_body_parts = config['num_body_parts'] + self.num_uv_coordinates = config['num_uv_coordinates'] + self.hidden_channels = config.get('hidden_channels', [128, 64]) + self.kernel_size = config.get('kernel_size', 3) + self.padding = config.get('padding', 1) + self.dropout_rate = config.get('dropout_rate', 0.1) + self.use_deformable_conv = config.get('use_deformable_conv', False) + self.use_fpn = config.get('use_fpn', False) + self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5]) + self.output_stride = config.get('output_stride', 4) + + # Feature Pyramid Network (optional) + if self.use_fpn: + self.fpn = self._build_fpn() + + # Shared feature processing + self.shared_conv = self._build_shared_layers() + + # Segmentation head for body part classification + self.segmentation_head = self._build_segmentation_head() + + # UV regression head for coordinate prediction + self.uv_regression_head = self._build_uv_regression_head() + + # Initialize weights + self._initialize_weights() + + def _validate_config(self, config: Dict[str, Any]): + """Validate configuration parameters.""" + required_fields = ['input_channels', 'num_body_parts', 'num_uv_coordinates'] + for field in required_fields: + if field not in config: + raise ValueError(f"Missing required field: {field}") + + if config['input_channels'] <= 0: + raise ValueError("input_channels must be positive") + + if config['num_body_parts'] <= 0: + raise ValueError("num_body_parts must be positive") + + if config['num_uv_coordinates'] <= 0: + raise ValueError("num_uv_coordinates must be positive") + + def _build_fpn(self) -> nn.Module: + """Build Feature Pyramid Network.""" + return nn.ModuleDict({ + f'level_{level}': nn.Conv2d(self.input_channels, self.input_channels, 1) + for level in self.fpn_levels + }) + + def _build_shared_layers(self) -> nn.Module: + """Build shared feature processing layers.""" + layers = [] + in_channels = self.input_channels + + for hidden_dim in self.hidden_channels: + layers.extend([ + nn.Conv2d(in_channels, hidden_dim, + kernel_size=self.kernel_size, + padding=self.padding), + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Dropout2d(self.dropout_rate) + ]) + in_channels = hidden_dim + + return nn.Sequential(*layers) + + def _build_segmentation_head(self) -> nn.Module: + """Build segmentation head for body part classification.""" + final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels + + return nn.Sequential( + nn.Conv2d(final_hidden, final_hidden // 2, + kernel_size=self.kernel_size, + padding=self.padding), + nn.BatchNorm2d(final_hidden // 2), + nn.ReLU(inplace=True), + nn.Dropout2d(self.dropout_rate), + + # Upsampling to increase resolution + nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4, + kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(final_hidden // 4), + nn.ReLU(inplace=True), + + nn.Conv2d(final_hidden // 4, self.num_body_parts + 1, kernel_size=1), + # +1 for background class + ) + + def _build_uv_regression_head(self) -> nn.Module: + """Build UV regression head for coordinate prediction.""" + final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels + + return nn.Sequential( + nn.Conv2d(final_hidden, final_hidden // 2, + kernel_size=self.kernel_size, + padding=self.padding), + nn.BatchNorm2d(final_hidden // 2), + nn.ReLU(inplace=True), + nn.Dropout2d(self.dropout_rate), + + # Upsampling to increase resolution + nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4, + kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(final_hidden // 4), + nn.ReLU(inplace=True), + + nn.Conv2d(final_hidden // 4, self.num_uv_coordinates, kernel_size=1), + ) + + def _initialize_weights(self): + """Initialize network weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + """Forward pass through the DensePose head. + + Args: + x: Input feature tensor of shape (batch_size, channels, height, width) + + Returns: + Dictionary containing: + - segmentation: Body part logits (batch_size, num_parts+1, height, width) + - uv_coordinates: UV coordinates (batch_size, 2, height, width) + """ + # Validate input shape + if x.shape[1] != self.input_channels: + raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}") + + # Apply FPN if enabled + if self.use_fpn: + # Simple FPN processing - in practice this would be more sophisticated + x = self.fpn['level_2'](x) + + # Shared feature processing + shared_features = self.shared_conv(x) + + # Segmentation branch + segmentation_logits = self.segmentation_head(shared_features) + + # UV regression branch + uv_coordinates = self.uv_regression_head(shared_features) + uv_coordinates = torch.sigmoid(uv_coordinates) # Normalize to [0, 1] + + return { + 'segmentation': segmentation_logits, + 'uv_coordinates': uv_coordinates + } + + def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute segmentation loss. + + Args: + pred_logits: Predicted segmentation logits + target: Target segmentation masks + + Returns: + Computed cross-entropy loss + """ + return F.cross_entropy(pred_logits, target, ignore_index=-1) + + def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor: + """Compute UV coordinate regression loss. + + Args: + pred_uv: Predicted UV coordinates + target_uv: Target UV coordinates + + Returns: + Computed L1 loss + """ + return F.l1_loss(pred_uv, target_uv) + + def compute_total_loss(self, predictions: Dict[str, torch.Tensor], + seg_target: torch.Tensor, + uv_target: torch.Tensor, + seg_weight: float = 1.0, + uv_weight: float = 1.0) -> torch.Tensor: + """Compute total loss combining segmentation and UV losses. + + Args: + predictions: Dictionary of predictions + seg_target: Target segmentation masks + uv_target: Target UV coordinates + seg_weight: Weight for segmentation loss + uv_weight: Weight for UV loss + + Returns: + Combined loss + """ + seg_loss = self.compute_segmentation_loss(predictions['segmentation'], seg_target) + uv_loss = self.compute_uv_loss(predictions['uv_coordinates'], uv_target) + + return seg_weight * seg_loss + uv_weight * uv_loss + + def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Get prediction confidence scores. + + Args: + predictions: Dictionary of predictions + + Returns: + Dictionary of confidence scores + """ + seg_logits = predictions['segmentation'] + uv_coords = predictions['uv_coordinates'] + + # Segmentation confidence: max probability + seg_probs = F.softmax(seg_logits, dim=1) + seg_confidence = torch.max(seg_probs, dim=1)[0] + + # UV confidence: inverse of prediction variance + uv_variance = torch.var(uv_coords, dim=1, keepdim=True) + uv_confidence = 1.0 / (1.0 + uv_variance) + + return { + 'segmentation_confidence': seg_confidence, + 'uv_confidence': uv_confidence.squeeze(1) + } + + def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Post-process predictions for final output. + + Args: + predictions: Raw predictions from forward pass + + Returns: + Post-processed predictions + """ + seg_logits = predictions['segmentation'] + uv_coords = predictions['uv_coordinates'] + + # Convert logits to class predictions + body_parts = torch.argmax(seg_logits, dim=1) + + # Get confidence scores + confidence = self.get_prediction_confidence(predictions) + + return { + 'body_parts': body_parts, + 'uv_coordinates': uv_coords, + 'confidence_scores': confidence + } \ No newline at end of file diff --git a/src/models/modality_translation.py b/src/models/modality_translation.py index 43f35af..8f553e1 100644 --- a/src/models/modality_translation.py +++ b/src/models/modality_translation.py @@ -3,7 +3,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Dict, Any +from typing import Dict, Any, List + + +class ModalityTranslationError(Exception): + """Exception raised for modality translation errors.""" + pass class ModalityTranslationNetwork(nn.Module): @@ -17,11 +22,20 @@ class ModalityTranslationNetwork(nn.Module): """ super().__init__() + self._validate_config(config) + self.config = config + self.input_channels = config['input_channels'] - self.hidden_dim = config['hidden_dim'] - self.output_dim = config['output_dim'] - self.num_layers = config['num_layers'] - self.dropout_rate = config['dropout_rate'] + self.hidden_channels = config['hidden_channels'] + self.output_channels = config['output_channels'] + self.kernel_size = config.get('kernel_size', 3) + self.stride = config.get('stride', 1) + self.padding = config.get('padding', 1) + self.dropout_rate = config.get('dropout_rate', 0.1) + self.activation = config.get('activation', 'relu') + self.normalization = config.get('normalization', 'batch') + self.use_attention = config.get('use_attention', False) + self.attention_heads = config.get('attention_heads', 8) # Encoder: CSI -> Feature space self.encoder = self._build_encoder() @@ -29,57 +43,114 @@ class ModalityTranslationNetwork(nn.Module): # Decoder: Feature space -> Visual-like features self.decoder = self._build_decoder() + # Attention mechanism + if self.use_attention: + self.attention = self._build_attention() + # Initialize weights self._initialize_weights() - def _build_encoder(self) -> nn.Module: + def _validate_config(self, config: Dict[str, Any]): + """Validate configuration parameters.""" + required_fields = ['input_channels', 'hidden_channels', 'output_channels'] + for field in required_fields: + if field not in config: + raise ValueError(f"Missing required field: {field}") + + if config['input_channels'] <= 0: + raise ValueError("input_channels must be positive") + + if not config['hidden_channels'] or len(config['hidden_channels']) == 0: + raise ValueError("hidden_channels must be a non-empty list") + + if config['output_channels'] <= 0: + raise ValueError("output_channels must be positive") + + def _build_encoder(self) -> nn.ModuleList: """Build encoder network.""" - layers = [] + layers = nn.ModuleList() # Initial convolution - layers.append(nn.Conv2d(self.input_channels, 64, kernel_size=3, padding=1)) - layers.append(nn.BatchNorm2d(64)) - layers.append(nn.ReLU(inplace=True)) - layers.append(nn.Dropout2d(self.dropout_rate)) + in_channels = self.input_channels - # Progressive downsampling - in_channels = 64 - for i in range(self.num_layers - 1): - out_channels = min(in_channels * 2, self.hidden_dim) - layers.extend([ - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), + for i, out_channels in enumerate(self.hidden_channels): + layer_block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, + kernel_size=self.kernel_size, + stride=self.stride if i == 0 else 2, + padding=self.padding), + self._get_normalization(out_channels), + self._get_activation(), nn.Dropout2d(self.dropout_rate) - ]) + ) + layers.append(layer_block) in_channels = out_channels - return nn.Sequential(*layers) + return layers - def _build_decoder(self) -> nn.Module: + def _build_decoder(self) -> nn.ModuleList: """Build decoder network.""" - layers = [] + layers = nn.ModuleList() - # Get the actual output channels from encoder (should be hidden_dim) - encoder_out_channels = self.hidden_dim + # Start with the last hidden channel size + in_channels = self.hidden_channels[-1] - # Progressive upsampling - in_channels = encoder_out_channels - for i in range(self.num_layers - 1): - out_channels = max(in_channels // 2, 64) - layers.extend([ - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), + # Progressive upsampling (reverse of encoder) + for i, out_channels in enumerate(reversed(self.hidden_channels[:-1])): + layer_block = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, + kernel_size=self.kernel_size, + stride=2, + padding=self.padding, + output_padding=1), + self._get_normalization(out_channels), + self._get_activation(), nn.Dropout2d(self.dropout_rate) - ]) + ) + layers.append(layer_block) in_channels = out_channels # Final output layer - layers.append(nn.Conv2d(in_channels, self.output_dim, kernel_size=3, padding=1)) - layers.append(nn.Tanh()) # Normalize output + final_layer = nn.Sequential( + nn.Conv2d(in_channels, self.output_channels, + kernel_size=self.kernel_size, + padding=self.padding), + nn.Tanh() # Normalize output + ) + layers.append(final_layer) - return nn.Sequential(*layers) + return layers + + def _get_normalization(self, channels: int) -> nn.Module: + """Get normalization layer.""" + if self.normalization == 'batch': + return nn.BatchNorm2d(channels) + elif self.normalization == 'instance': + return nn.InstanceNorm2d(channels) + elif self.normalization == 'layer': + return nn.GroupNorm(1, channels) + else: + return nn.Identity() + + def _get_activation(self) -> nn.Module: + """Get activation function.""" + if self.activation == 'relu': + return nn.ReLU(inplace=True) + elif self.activation == 'leaky_relu': + return nn.LeakyReLU(0.2, inplace=True) + elif self.activation == 'gelu': + return nn.GELU() + else: + return nn.ReLU(inplace=True) + + def _build_attention(self) -> nn.Module: + """Build attention mechanism.""" + return nn.MultiheadAttention( + embed_dim=self.hidden_channels[-1], + num_heads=self.attention_heads, + dropout=self.dropout_rate, + batch_first=True + ) def _initialize_weights(self): """Initialize network weights.""" @@ -103,12 +174,128 @@ class ModalityTranslationNetwork(nn.Module): """ # Validate input shape if x.shape[1] != self.input_channels: - raise RuntimeError(f"Expected {self.input_channels} input channels, got {x.shape[1]}") + raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}") # Encode CSI data - encoded = self.encoder(x) + encoded_features = self.encode(x) # Decode to visual-like features - decoded = self.decoder(encoded) + decoded = self.decode(encoded_features) - return decoded \ No newline at end of file + return decoded + + def encode(self, x: torch.Tensor) -> List[torch.Tensor]: + """Encode input through encoder layers. + + Args: + x: Input tensor + + Returns: + List of feature maps from each encoder layer + """ + features = [] + current = x + + for layer in self.encoder: + current = layer(current) + features.append(current) + + return features + + def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor: + """Decode features through decoder layers. + + Args: + encoded_features: List of encoded feature maps + + Returns: + Decoded output tensor + """ + # Start with the last encoded feature + current = encoded_features[-1] + + # Apply attention if enabled + if self.use_attention: + batch_size, channels, height, width = current.shape + # Reshape for attention: (batch, seq_len, embed_dim) + current_flat = current.view(batch_size, channels, -1).transpose(1, 2) + attended, _ = self.attention(current_flat, current_flat, current_flat) + current = attended.transpose(1, 2).view(batch_size, channels, height, width) + + # Apply decoder layers + for layer in self.decoder: + current = layer(current) + + return current + + def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor: + """Compute translation loss between predicted and target features. + + Args: + predicted: Predicted feature tensor + target: Target feature tensor + loss_type: Type of loss ('mse', 'l1', 'smooth_l1') + + Returns: + Computed loss tensor + """ + if loss_type == 'mse': + return F.mse_loss(predicted, target) + elif loss_type == 'l1': + return F.l1_loss(predicted, target) + elif loss_type == 'smooth_l1': + return F.smooth_l1_loss(predicted, target) + else: + return F.mse_loss(predicted, target) + + def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]: + """Get statistics of feature tensor. + + Args: + features: Feature tensor to analyze + + Returns: + Dictionary of feature statistics + """ + with torch.no_grad(): + return { + 'mean': features.mean().item(), + 'std': features.std().item(), + 'min': features.min().item(), + 'max': features.max().item(), + 'sparsity': (features == 0).float().mean().item() + } + + def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]: + """Get intermediate features for visualization. + + Args: + x: Input tensor + + Returns: + Dictionary containing intermediate features + """ + result = {} + + # Get encoder features + encoder_features = self.encode(x) + result['encoder_features'] = encoder_features + + # Get decoder features + decoder_features = [] + current = encoder_features[-1] + + if self.use_attention: + batch_size, channels, height, width = current.shape + current_flat = current.view(batch_size, channels, -1).transpose(1, 2) + attended, attention_weights = self.attention(current_flat, current_flat, current_flat) + current = attended.transpose(1, 2).view(batch_size, channels, height, width) + result['attention_weights'] = attention_weights + + for layer in self.decoder: + current = layer(current) + decoder_features.append(current) + + result['decoder_features'] = decoder_features + + return result \ No newline at end of file diff --git a/tests/integration/test_csi_pipeline.py b/tests/integration/test_csi_pipeline.py new file mode 100644 index 0000000..fc19651 --- /dev/null +++ b/tests/integration/test_csi_pipeline.py @@ -0,0 +1,353 @@ +import pytest +import torch +import numpy as np +from unittest.mock import Mock, patch, MagicMock +from src.core.csi_processor import CSIProcessor +from src.core.phase_sanitizer import PhaseSanitizer +from src.hardware.router_interface import RouterInterface +from src.hardware.csi_extractor import CSIExtractor + + +class TestCSIPipeline: + """Integration tests for CSI processing pipeline following London School TDD principles""" + + @pytest.fixture + def mock_router_config(self): + """Configuration for router interface""" + return { + 'router_ip': '192.168.1.1', + 'username': 'admin', + 'password': 'password', + 'ssh_port': 22, + 'timeout': 30, + 'max_retries': 3 + } + + @pytest.fixture + def mock_extractor_config(self): + """Configuration for CSI extractor""" + return { + 'interface': 'wlan0', + 'channel': 6, + 'bandwidth': 20, + 'antenna_count': 3, + 'subcarrier_count': 56, + 'sample_rate': 1000, + 'buffer_size': 1024 + } + + @pytest.fixture + def mock_processor_config(self): + """Configuration for CSI processor""" + return { + 'window_size': 100, + 'overlap': 0.5, + 'filter_type': 'butterworth', + 'filter_order': 4, + 'cutoff_frequency': 50, + 'normalization': 'minmax', + 'outlier_threshold': 3.0 + } + + @pytest.fixture + def mock_sanitizer_config(self): + """Configuration for phase sanitizer""" + return { + 'unwrap_method': 'numpy', + 'smoothing_window': 5, + 'outlier_threshold': 2.0, + 'interpolation_method': 'linear', + 'phase_correction': True + } + + @pytest.fixture + def csi_pipeline_components(self, mock_router_config, mock_extractor_config, + mock_processor_config, mock_sanitizer_config): + """Create CSI pipeline components for testing""" + router = RouterInterface(mock_router_config) + extractor = CSIExtractor(mock_extractor_config) + processor = CSIProcessor(mock_processor_config) + sanitizer = PhaseSanitizer(mock_sanitizer_config) + + return { + 'router': router, + 'extractor': extractor, + 'processor': processor, + 'sanitizer': sanitizer + } + + @pytest.fixture + def mock_raw_csi_data(self): + """Generate mock raw CSI data""" + batch_size = 10 + antennas = 3 + subcarriers = 56 + time_samples = 100 + + # Generate complex CSI data + real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples) + imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples) + + return { + 'csi_data': real_part + 1j * imag_part, + 'timestamps': np.linspace(0, 1, time_samples), + 'metadata': { + 'channel': 6, + 'bandwidth': 20, + 'rssi': -45, + 'noise_floor': -90 + } + } + + def test_end_to_end_csi_pipeline_processes_data_correctly(self, csi_pipeline_components, mock_raw_csi_data): + """Test that end-to-end CSI pipeline processes data correctly""" + # Arrange + router = csi_pipeline_components['router'] + extractor = csi_pipeline_components['extractor'] + processor = csi_pipeline_components['processor'] + sanitizer = csi_pipeline_components['sanitizer'] + + # Mock the hardware extraction + with patch.object(extractor, 'extract_csi_data', return_value=mock_raw_csi_data): + with patch.object(router, 'connect', return_value=True): + with patch.object(router, 'configure_monitor_mode', return_value=True): + + # Act - Run the pipeline + # 1. Connect to router and configure + router.connect() + router.configure_monitor_mode('wlan0', 6) + + # 2. Extract CSI data + raw_data = extractor.extract_csi_data() + + # 3. Process CSI data + processed_data = processor.process_csi_batch(raw_data['csi_data']) + + # 4. Sanitize phase information + sanitized_data = sanitizer.sanitize_phase_batch(processed_data) + + # Assert + assert raw_data is not None + assert processed_data is not None + assert sanitized_data is not None + + # Check data flow integrity + assert isinstance(processed_data, torch.Tensor) + assert isinstance(sanitized_data, torch.Tensor) + assert processed_data.shape == sanitized_data.shape + + def test_pipeline_handles_hardware_connection_failure(self, csi_pipeline_components): + """Test that pipeline handles hardware connection failures gracefully""" + # Arrange + router = csi_pipeline_components['router'] + + # Mock connection failure + with patch.object(router, 'connect', return_value=False): + + # Act & Assert + connection_result = router.connect() + assert connection_result is False + + # Pipeline should handle this gracefully + with pytest.raises(Exception): # Should raise appropriate exception + router.configure_monitor_mode('wlan0', 6) + + def test_pipeline_handles_csi_extraction_timeout(self, csi_pipeline_components): + """Test that pipeline handles CSI extraction timeouts""" + # Arrange + extractor = csi_pipeline_components['extractor'] + + # Mock extraction timeout + with patch.object(extractor, 'extract_csi_data', side_effect=TimeoutError("CSI extraction timeout")): + + # Act & Assert + with pytest.raises(TimeoutError): + extractor.extract_csi_data() + + def test_pipeline_handles_invalid_csi_data_format(self, csi_pipeline_components): + """Test that pipeline handles invalid CSI data formats""" + # Arrange + processor = csi_pipeline_components['processor'] + + # Invalid data format + invalid_data = np.random.randn(10, 2, 56) # Missing time dimension + + # Act & Assert + with pytest.raises(ValueError): + processor.process_csi_batch(invalid_data) + + def test_pipeline_maintains_data_consistency_across_stages(self, csi_pipeline_components, mock_raw_csi_data): + """Test that pipeline maintains data consistency across processing stages""" + # Arrange + processor = csi_pipeline_components['processor'] + sanitizer = csi_pipeline_components['sanitizer'] + + csi_data = mock_raw_csi_data['csi_data'] + + # Act + processed_data = processor.process_csi_batch(csi_data) + sanitized_data = sanitizer.sanitize_phase_batch(processed_data) + + # Assert - Check data consistency + assert processed_data.shape[0] == sanitized_data.shape[0] # Batch size preserved + assert processed_data.shape[1] == sanitized_data.shape[1] # Antenna count preserved + assert processed_data.shape[2] == sanitized_data.shape[2] # Subcarrier count preserved + + # Check that data is not corrupted (no NaN or infinite values) + assert not torch.isnan(processed_data).any() + assert not torch.isinf(processed_data).any() + assert not torch.isnan(sanitized_data).any() + assert not torch.isinf(sanitized_data).any() + + def test_pipeline_performance_meets_real_time_requirements(self, csi_pipeline_components, mock_raw_csi_data): + """Test that pipeline performance meets real-time processing requirements""" + import time + + # Arrange + processor = csi_pipeline_components['processor'] + sanitizer = csi_pipeline_components['sanitizer'] + + csi_data = mock_raw_csi_data['csi_data'] + + # Act - Measure processing time + start_time = time.time() + + processed_data = processor.process_csi_batch(csi_data) + sanitized_data = sanitizer.sanitize_phase_batch(processed_data) + + end_time = time.time() + processing_time = end_time - start_time + + # Assert - Should process within reasonable time (< 100ms for this data size) + assert processing_time < 0.1, f"Processing took {processing_time:.3f}s, expected < 0.1s" + + def test_pipeline_handles_different_data_sizes(self, csi_pipeline_components): + """Test that pipeline handles different CSI data sizes""" + # Arrange + processor = csi_pipeline_components['processor'] + sanitizer = csi_pipeline_components['sanitizer'] + + # Different data sizes + small_data = np.random.randn(1, 3, 56, 50) + 1j * np.random.randn(1, 3, 56, 50) + large_data = np.random.randn(20, 3, 56, 200) + 1j * np.random.randn(20, 3, 56, 200) + + # Act + small_processed = processor.process_csi_batch(small_data) + small_sanitized = sanitizer.sanitize_phase_batch(small_processed) + + large_processed = processor.process_csi_batch(large_data) + large_sanitized = sanitizer.sanitize_phase_batch(large_processed) + + # Assert + assert small_processed.shape == small_sanitized.shape + assert large_processed.shape == large_sanitized.shape + assert small_processed.shape != large_processed.shape # Different sizes + + def test_pipeline_configuration_validation(self, mock_router_config, mock_extractor_config, + mock_processor_config, mock_sanitizer_config): + """Test that pipeline components validate configurations properly""" + # Arrange - Invalid configurations + invalid_router_config = mock_router_config.copy() + invalid_router_config['router_ip'] = 'invalid_ip' + + invalid_extractor_config = mock_extractor_config.copy() + invalid_extractor_config['antenna_count'] = 0 + + invalid_processor_config = mock_processor_config.copy() + invalid_processor_config['window_size'] = -1 + + invalid_sanitizer_config = mock_sanitizer_config.copy() + invalid_sanitizer_config['smoothing_window'] = 0 + + # Act & Assert + with pytest.raises(ValueError): + RouterInterface(invalid_router_config) + + with pytest.raises(ValueError): + CSIExtractor(invalid_extractor_config) + + with pytest.raises(ValueError): + CSIProcessor(invalid_processor_config) + + with pytest.raises(ValueError): + PhaseSanitizer(invalid_sanitizer_config) + + def test_pipeline_error_recovery_and_logging(self, csi_pipeline_components, mock_raw_csi_data): + """Test that pipeline handles errors gracefully and logs appropriately""" + # Arrange + processor = csi_pipeline_components['processor'] + + # Corrupt some data to trigger error handling + corrupted_data = mock_raw_csi_data['csi_data'].copy() + corrupted_data[0, 0, 0, :] = np.inf # Introduce infinite values + + # Act & Assert + with pytest.raises(ValueError): # Should detect and handle corrupted data + processor.process_csi_batch(corrupted_data) + + def test_pipeline_memory_usage_optimization(self, csi_pipeline_components): + """Test that pipeline optimizes memory usage for large datasets""" + # Arrange + processor = csi_pipeline_components['processor'] + sanitizer = csi_pipeline_components['sanitizer'] + + # Large dataset + large_data = np.random.randn(100, 3, 56, 1000) + 1j * np.random.randn(100, 3, 56, 1000) + + # Act - Process in chunks to test memory optimization + chunk_size = 10 + results = [] + + for i in range(0, large_data.shape[0], chunk_size): + chunk = large_data[i:i+chunk_size] + processed_chunk = processor.process_csi_batch(chunk) + sanitized_chunk = sanitizer.sanitize_phase_batch(processed_chunk) + results.append(sanitized_chunk) + + # Assert + assert len(results) == 10 # 100 samples / 10 chunk_size + for result in results: + assert result.shape[0] <= chunk_size + + def test_pipeline_supports_concurrent_processing(self, csi_pipeline_components, mock_raw_csi_data): + """Test that pipeline supports concurrent processing of multiple streams""" + import threading + import queue + + # Arrange + processor = csi_pipeline_components['processor'] + sanitizer = csi_pipeline_components['sanitizer'] + + results_queue = queue.Queue() + + def process_stream(stream_id, data): + try: + processed = processor.process_csi_batch(data) + sanitized = sanitizer.sanitize_phase_batch(processed) + results_queue.put((stream_id, sanitized)) + except Exception as e: + results_queue.put((stream_id, e)) + + # Act - Process multiple streams concurrently + threads = [] + for i in range(3): + thread = threading.Thread( + target=process_stream, + args=(i, mock_raw_csi_data['csi_data']) + ) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Assert + results = [] + while not results_queue.empty(): + results.append(results_queue.get()) + + assert len(results) == 3 + for stream_id, result in results: + assert isinstance(result, torch.Tensor) + assert not isinstance(result, Exception) \ No newline at end of file diff --git a/tests/integration/test_inference_pipeline.py b/tests/integration/test_inference_pipeline.py new file mode 100644 index 0000000..444424b --- /dev/null +++ b/tests/integration/test_inference_pipeline.py @@ -0,0 +1,459 @@ +import pytest +import torch +import numpy as np +from unittest.mock import Mock, patch, MagicMock +from src.core.csi_processor import CSIProcessor +from src.core.phase_sanitizer import PhaseSanitizer +from src.models.modality_translation import ModalityTranslationNetwork +from src.models.densepose_head import DensePoseHead + + +class TestInferencePipeline: + """Integration tests for inference pipeline following London School TDD principles""" + + @pytest.fixture + def mock_csi_processor_config(self): + """Configuration for CSI processor""" + return { + 'window_size': 100, + 'overlap': 0.5, + 'filter_type': 'butterworth', + 'filter_order': 4, + 'cutoff_frequency': 50, + 'normalization': 'minmax', + 'outlier_threshold': 3.0 + } + + @pytest.fixture + def mock_sanitizer_config(self): + """Configuration for phase sanitizer""" + return { + 'unwrap_method': 'numpy', + 'smoothing_window': 5, + 'outlier_threshold': 2.0, + 'interpolation_method': 'linear', + 'phase_correction': True + } + + @pytest.fixture + def mock_translation_config(self): + """Configuration for modality translation network""" + return { + 'input_channels': 6, + 'output_channels': 256, + 'hidden_channels': [64, 128, 256], + 'kernel_sizes': [7, 5, 3], + 'strides': [2, 2, 1], + 'dropout_rate': 0.1, + 'use_attention': True, + 'attention_heads': 8, + 'use_residual': True, + 'activation': 'relu', + 'normalization': 'batch' + } + + @pytest.fixture + def mock_densepose_config(self): + """Configuration for DensePose head""" + return { + 'input_channels': 256, + 'num_body_parts': 24, + 'num_uv_coordinates': 2, + 'hidden_channels': [128, 64], + 'kernel_size': 3, + 'padding': 1, + 'dropout_rate': 0.1, + 'use_deformable_conv': False, + 'use_fpn': True, + 'fpn_levels': [2, 3, 4, 5], + 'output_stride': 4 + } + + @pytest.fixture + def inference_pipeline_components(self, mock_csi_processor_config, mock_sanitizer_config, + mock_translation_config, mock_densepose_config): + """Create inference pipeline components for testing""" + csi_processor = CSIProcessor(mock_csi_processor_config) + phase_sanitizer = PhaseSanitizer(mock_sanitizer_config) + translation_network = ModalityTranslationNetwork(mock_translation_config) + densepose_head = DensePoseHead(mock_densepose_config) + + return { + 'csi_processor': csi_processor, + 'phase_sanitizer': phase_sanitizer, + 'translation_network': translation_network, + 'densepose_head': densepose_head + } + + @pytest.fixture + def mock_raw_csi_input(self): + """Generate mock raw CSI input data""" + batch_size = 4 + antennas = 3 + subcarriers = 56 + time_samples = 100 + + # Generate complex CSI data + real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples) + imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples) + + return real_part + 1j * imag_part + + @pytest.fixture + def mock_ground_truth_densepose(self): + """Generate mock ground truth DensePose annotations""" + batch_size = 4 + height = 224 + width = 224 + num_parts = 24 + + # Segmentation masks + seg_masks = torch.randint(0, num_parts + 1, (batch_size, height, width)) + + # UV coordinates + uv_coords = torch.randn(batch_size, 2, height, width) + + return { + 'segmentation': seg_masks, + 'uv_coordinates': uv_coords + } + + def test_end_to_end_inference_pipeline_produces_valid_output(self, inference_pipeline_components, mock_raw_csi_input): + """Test that end-to-end inference pipeline produces valid DensePose output""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Set models to evaluation mode + translation_network.eval() + densepose_head.eval() + + # Act - Run the complete inference pipeline + with torch.no_grad(): + # 1. Process CSI data + processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input) + + # 2. Sanitize phase information + sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi) + + # 3. Translate CSI to visual features + visual_features = translation_network(sanitized_csi) + + # 4. Generate DensePose predictions + densepose_output = densepose_head(visual_features) + + # Assert + assert densepose_output is not None + assert isinstance(densepose_output, dict) + assert 'segmentation' in densepose_output + assert 'uv_coordinates' in densepose_output + + seg_output = densepose_output['segmentation'] + uv_output = densepose_output['uv_coordinates'] + + # Check output shapes + assert seg_output.shape[0] == mock_raw_csi_input.shape[0] # Batch size preserved + assert seg_output.shape[1] == 25 # 24 body parts + 1 background + assert uv_output.shape[0] == mock_raw_csi_input.shape[0] # Batch size preserved + assert uv_output.shape[1] == 2 # U and V coordinates + + # Check output ranges + assert torch.all(uv_output >= 0) and torch.all(uv_output <= 1) # UV in [0, 1] + + def test_inference_pipeline_handles_different_batch_sizes(self, inference_pipeline_components): + """Test that inference pipeline handles different batch sizes""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Different batch sizes + small_batch = np.random.randn(1, 3, 56, 100) + 1j * np.random.randn(1, 3, 56, 100) + large_batch = np.random.randn(8, 3, 56, 100) + 1j * np.random.randn(8, 3, 56, 100) + + # Set models to evaluation mode + translation_network.eval() + densepose_head.eval() + + # Act + with torch.no_grad(): + # Small batch + small_processed = csi_processor.process_csi_batch(small_batch) + small_sanitized = phase_sanitizer.sanitize_phase_batch(small_processed) + small_features = translation_network(small_sanitized) + small_output = densepose_head(small_features) + + # Large batch + large_processed = csi_processor.process_csi_batch(large_batch) + large_sanitized = phase_sanitizer.sanitize_phase_batch(large_processed) + large_features = translation_network(large_sanitized) + large_output = densepose_head(large_features) + + # Assert + assert small_output['segmentation'].shape[0] == 1 + assert large_output['segmentation'].shape[0] == 8 + assert small_output['uv_coordinates'].shape[0] == 1 + assert large_output['uv_coordinates'].shape[0] == 8 + + def test_inference_pipeline_maintains_gradient_flow_during_training(self, inference_pipeline_components, + mock_raw_csi_input, mock_ground_truth_densepose): + """Test that inference pipeline maintains gradient flow during training""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Set models to training mode + translation_network.train() + densepose_head.train() + + # Create optimizer + optimizer = torch.optim.Adam( + list(translation_network.parameters()) + list(densepose_head.parameters()), + lr=0.001 + ) + + # Act + optimizer.zero_grad() + + # Forward pass + processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi) + visual_features = translation_network(sanitized_csi) + densepose_output = densepose_head(visual_features) + + # Resize ground truth to match output + seg_target = torch.nn.functional.interpolate( + mock_ground_truth_densepose['segmentation'].float().unsqueeze(1), + size=densepose_output['segmentation'].shape[2:], + mode='nearest' + ).squeeze(1).long() + + uv_target = torch.nn.functional.interpolate( + mock_ground_truth_densepose['uv_coordinates'], + size=densepose_output['uv_coordinates'].shape[2:], + mode='bilinear', + align_corners=False + ) + + # Compute loss + loss = densepose_head.compute_total_loss(densepose_output, seg_target, uv_target) + + # Backward pass + loss.backward() + + # Assert - Check that gradients are computed + for param in translation_network.parameters(): + if param.requires_grad: + assert param.grad is not None + assert not torch.allclose(param.grad, torch.zeros_like(param.grad)) + + for param in densepose_head.parameters(): + if param.requires_grad: + assert param.grad is not None + assert not torch.allclose(param.grad, torch.zeros_like(param.grad)) + + def test_inference_pipeline_performance_benchmarking(self, inference_pipeline_components, mock_raw_csi_input): + """Test inference pipeline performance for real-time requirements""" + import time + + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Set models to evaluation mode for inference + translation_network.eval() + densepose_head.eval() + + # Warm up (first inference is often slower) + with torch.no_grad(): + processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi) + visual_features = translation_network(sanitized_csi) + _ = densepose_head(visual_features) + + # Act - Measure inference time + start_time = time.time() + + with torch.no_grad(): + processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi) + visual_features = translation_network(sanitized_csi) + densepose_output = densepose_head(visual_features) + + end_time = time.time() + inference_time = end_time - start_time + + # Assert - Should meet real-time requirements (< 50ms for batch of 4) + assert inference_time < 0.05, f"Inference took {inference_time:.3f}s, expected < 0.05s" + + def test_inference_pipeline_handles_edge_cases(self, inference_pipeline_components): + """Test that inference pipeline handles edge cases gracefully""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Edge cases + zero_input = np.zeros((1, 3, 56, 100), dtype=complex) + noisy_input = np.random.randn(1, 3, 56, 100) * 100 + 1j * np.random.randn(1, 3, 56, 100) * 100 + + translation_network.eval() + densepose_head.eval() + + # Act & Assert + with torch.no_grad(): + # Zero input + zero_processed = csi_processor.process_csi_batch(zero_input) + zero_sanitized = phase_sanitizer.sanitize_phase_batch(zero_processed) + zero_features = translation_network(zero_sanitized) + zero_output = densepose_head(zero_features) + + assert not torch.isnan(zero_output['segmentation']).any() + assert not torch.isnan(zero_output['uv_coordinates']).any() + + # Noisy input + noisy_processed = csi_processor.process_csi_batch(noisy_input) + noisy_sanitized = phase_sanitizer.sanitize_phase_batch(noisy_processed) + noisy_features = translation_network(noisy_sanitized) + noisy_output = densepose_head(noisy_features) + + assert not torch.isnan(noisy_output['segmentation']).any() + assert not torch.isnan(noisy_output['uv_coordinates']).any() + + def test_inference_pipeline_memory_efficiency(self, inference_pipeline_components): + """Test that inference pipeline is memory efficient""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Large batch to test memory usage + large_input = np.random.randn(16, 3, 56, 100) + 1j * np.random.randn(16, 3, 56, 100) + + translation_network.eval() + densepose_head.eval() + + # Act - Process in chunks to manage memory + chunk_size = 4 + outputs = [] + + with torch.no_grad(): + for i in range(0, large_input.shape[0], chunk_size): + chunk = large_input[i:i+chunk_size] + + processed_chunk = csi_processor.process_csi_batch(chunk) + sanitized_chunk = phase_sanitizer.sanitize_phase_batch(processed_chunk) + feature_chunk = translation_network(sanitized_chunk) + output_chunk = densepose_head(feature_chunk) + + outputs.append(output_chunk) + + # Clear intermediate tensors to free memory + del processed_chunk, sanitized_chunk, feature_chunk + + # Assert + assert len(outputs) == 4 # 16 samples / 4 chunk_size + for output in outputs: + assert output['segmentation'].shape[0] <= chunk_size + + def test_inference_pipeline_deterministic_output(self, inference_pipeline_components, mock_raw_csi_input): + """Test that inference pipeline produces deterministic output in eval mode""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + # Set models to evaluation mode + translation_network.eval() + densepose_head.eval() + + # Act - Run inference twice + with torch.no_grad(): + # First run + processed_csi_1 = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi_1 = phase_sanitizer.sanitize_phase_batch(processed_csi_1) + visual_features_1 = translation_network(sanitized_csi_1) + output_1 = densepose_head(visual_features_1) + + # Second run + processed_csi_2 = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi_2 = phase_sanitizer.sanitize_phase_batch(processed_csi_2) + visual_features_2 = translation_network(sanitized_csi_2) + output_2 = densepose_head(visual_features_2) + + # Assert - Outputs should be identical in eval mode + assert torch.allclose(output_1['segmentation'], output_2['segmentation'], atol=1e-6) + assert torch.allclose(output_1['uv_coordinates'], output_2['uv_coordinates'], atol=1e-6) + + def test_inference_pipeline_confidence_estimation(self, inference_pipeline_components, mock_raw_csi_input): + """Test that inference pipeline provides confidence estimates""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + translation_network.eval() + densepose_head.eval() + + # Act + with torch.no_grad(): + processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi) + visual_features = translation_network(sanitized_csi) + densepose_output = densepose_head(visual_features) + + # Get confidence estimates + confidence = densepose_head.get_prediction_confidence(densepose_output) + + # Assert + assert 'segmentation_confidence' in confidence + assert 'uv_confidence' in confidence + + seg_conf = confidence['segmentation_confidence'] + uv_conf = confidence['uv_confidence'] + + assert seg_conf.shape[0] == mock_raw_csi_input.shape[0] + assert uv_conf.shape[0] == mock_raw_csi_input.shape[0] + assert torch.all(seg_conf >= 0) and torch.all(seg_conf <= 1) + assert torch.all(uv_conf >= 0) + + def test_inference_pipeline_post_processing(self, inference_pipeline_components, mock_raw_csi_input): + """Test that inference pipeline post-processes predictions correctly""" + # Arrange + csi_processor = inference_pipeline_components['csi_processor'] + phase_sanitizer = inference_pipeline_components['phase_sanitizer'] + translation_network = inference_pipeline_components['translation_network'] + densepose_head = inference_pipeline_components['densepose_head'] + + translation_network.eval() + densepose_head.eval() + + # Act + with torch.no_grad(): + processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input) + sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi) + visual_features = translation_network(sanitized_csi) + raw_output = densepose_head(visual_features) + + # Post-process predictions + processed_output = densepose_head.post_process_predictions(raw_output) + + # Assert + assert 'body_parts' in processed_output + assert 'uv_coordinates' in processed_output + assert 'confidence_scores' in processed_output + + body_parts = processed_output['body_parts'] + assert body_parts.dtype == torch.long # Class indices + assert torch.all(body_parts >= 0) and torch.all(body_parts <= 24) # Valid class range \ No newline at end of file diff --git a/tests/unit/test_csi_extractor.py b/tests/unit/test_csi_extractor.py new file mode 100644 index 0000000..043a469 --- /dev/null +++ b/tests/unit/test_csi_extractor.py @@ -0,0 +1,264 @@ +import pytest +import numpy as np +import torch +from unittest.mock import Mock, patch, MagicMock +from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError + + +class TestCSIExtractor: + """Test suite for CSI Extractor following London School TDD principles""" + + @pytest.fixture + def mock_config(self): + """Configuration for CSI extractor""" + return { + 'interface': 'wlan0', + 'channel': 6, + 'bandwidth': 20, + 'sample_rate': 1000, + 'buffer_size': 1024, + 'extraction_timeout': 5.0 + } + + @pytest.fixture + def mock_router_interface(self): + """Mock router interface for testing""" + mock_router = Mock() + mock_router.is_connected = True + mock_router.execute_command = Mock() + return mock_router + + @pytest.fixture + def csi_extractor(self, mock_config, mock_router_interface): + """Create CSI extractor instance for testing""" + return CSIExtractor(mock_config, mock_router_interface) + + @pytest.fixture + def mock_csi_data(self): + """Generate synthetic CSI data for testing""" + # Simulate CSI data: complex values for multiple subcarriers + num_subcarriers = 56 + num_antennas = 3 + amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers)) + phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers)) + return amplitude * np.exp(1j * phase) + + def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface): + """Test that CSI extractor initializes with correct configuration""" + # Act + extractor = CSIExtractor(mock_config, mock_router_interface) + + # Assert + assert extractor is not None + assert extractor.interface == mock_config['interface'] + assert extractor.channel == mock_config['channel'] + assert extractor.bandwidth == mock_config['bandwidth'] + assert extractor.sample_rate == mock_config['sample_rate'] + assert extractor.buffer_size == mock_config['buffer_size'] + assert extractor.extraction_timeout == mock_config['extraction_timeout'] + assert extractor.router_interface == mock_router_interface + assert not extractor.is_extracting + + def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface): + """Test that start_extraction configures monitor mode""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.execute_command.return_value = "CSI extraction started" + + # Act + result = csi_extractor.start_extraction() + + # Assert + assert result is True + assert csi_extractor.is_extracting is True + mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface) + + def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface): + """Test that start_extraction handles monitor mode configuration failure""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = False + + # Act & Assert + with pytest.raises(CSIExtractionError): + csi_extractor.start_extraction() + + assert csi_extractor.is_extracting is False + + def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface): + """Test that stop_extraction disables monitor mode""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.disable_monitor_mode.return_value = True + mock_router_interface.execute_command.return_value = "CSI extraction started" + + csi_extractor.start_extraction() + + # Act + result = csi_extractor.stop_extraction() + + # Assert + assert result is True + assert csi_extractor.is_extracting is False + mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface) + + def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data): + """Test that extract_csi_data returns data in valid format""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.execute_command.return_value = "CSI extraction started" + + # Mock the CSI data extraction + with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data): + csi_extractor.start_extraction() + + # Act + csi_data = csi_extractor.extract_csi_data() + + # Assert + assert csi_data is not None + assert isinstance(csi_data, np.ndarray) + assert csi_data.dtype == np.complex128 + assert csi_data.shape == mock_csi_data.shape + + def test_extract_csi_data_requires_active_extraction(self, csi_extractor): + """Test that extract_csi_data requires active extraction""" + # Act & Assert + with pytest.raises(CSIExtractionError): + csi_extractor.extract_csi_data() + + def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface): + """Test that extract_csi_data handles extraction timeout""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.execute_command.side_effect = [ + "CSI extraction started", + Exception("Timeout") + ] + + csi_extractor.start_extraction() + + # Act & Assert + with pytest.raises(CSIExtractionError): + csi_extractor.extract_csi_data() + + def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data): + """Test that convert_to_tensor produces correctly formatted tensor""" + # Act + tensor = csi_extractor.convert_to_tensor(mock_csi_data) + + # Assert + assert isinstance(tensor, torch.Tensor) + assert tensor.dtype == torch.float32 + assert tensor.shape[0] == mock_csi_data.shape[0] * 2 # Real and imaginary parts + assert tensor.shape[1] == mock_csi_data.shape[1] + + def test_convert_to_tensor_handles_invalid_input(self, csi_extractor): + """Test that convert_to_tensor handles invalid input""" + # Arrange + invalid_data = "not an array" + + # Act & Assert + with pytest.raises(ValueError): + csi_extractor.convert_to_tensor(invalid_data) + + def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface): + """Test that get_extraction_stats returns valid statistics""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.execute_command.return_value = "CSI extraction started" + + csi_extractor.start_extraction() + + # Act + stats = csi_extractor.get_extraction_stats() + + # Assert + assert stats is not None + assert isinstance(stats, dict) + assert 'samples_extracted' in stats + assert 'extraction_rate' in stats + assert 'buffer_utilization' in stats + assert 'last_extraction_time' in stats + + def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface): + """Test that set_channel configures WiFi channel""" + # Arrange + new_channel = 11 + mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}" + + # Act + result = csi_extractor.set_channel(new_channel) + + # Assert + assert result is True + assert csi_extractor.channel == new_channel + mock_router_interface.execute_command.assert_called() + + def test_set_channel_validates_channel_range(self, csi_extractor): + """Test that set_channel validates channel range""" + # Act & Assert + with pytest.raises(ValueError): + csi_extractor.set_channel(0) # Invalid channel + + with pytest.raises(ValueError): + csi_extractor.set_channel(15) # Invalid channel + + def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface): + """Test that CSI extractor supports context manager protocol""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.disable_monitor_mode.return_value = True + mock_router_interface.execute_command.return_value = "CSI extraction started" + + # Act + with csi_extractor as extractor: + # Assert + assert extractor.is_extracting is True + + # Assert - extraction should be stopped after context + assert csi_extractor.is_extracting is False + + def test_extractor_validates_configuration(self, mock_router_interface): + """Test that CSI extractor validates configuration parameters""" + # Arrange + invalid_config = { + 'interface': '', # Invalid interface + 'channel': 6, + 'bandwidth': 20 + } + + # Act & Assert + with pytest.raises(ValueError): + CSIExtractor(invalid_config, mock_router_interface) + + def test_parse_csi_output_processes_raw_data(self, csi_extractor): + """Test that _parse_csi_output processes raw CSI data correctly""" + # Arrange + raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j" + + # Act + parsed_data = csi_extractor._parse_csi_output(raw_output) + + # Assert + assert parsed_data is not None + assert isinstance(parsed_data, np.ndarray) + assert parsed_data.dtype == np.complex128 + + def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data): + """Test that buffer management handles overflow correctly""" + # Arrange + mock_router_interface.enable_monitor_mode.return_value = True + mock_router_interface.execute_command.return_value = "CSI extraction started" + + with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data): + csi_extractor.start_extraction() + + # Fill buffer beyond capacity + for _ in range(csi_extractor.buffer_size + 10): + csi_extractor._add_to_buffer(mock_csi_data) + + # Act + stats = csi_extractor.get_extraction_stats() + + # Assert + assert stats['buffer_utilization'] <= 1.0 # Should not exceed 100% \ No newline at end of file diff --git a/tests/unit/test_densepose_head.py b/tests/unit/test_densepose_head.py index b9d604a..4966d11 100644 --- a/tests/unit/test_densepose_head.py +++ b/tests/unit/test_densepose_head.py @@ -3,27 +3,27 @@ import torch import torch.nn as nn import numpy as np from unittest.mock import Mock, patch -from src.models.densepose_head import DensePoseHead +from src.models.densepose_head import DensePoseHead, DensePoseError class TestDensePoseHead: """Test suite for DensePose Head following London School TDD principles""" - @pytest.fixture - def mock_feature_input(self): - """Generate synthetic feature input tensor for testing""" - # Batch size 2, 512 channels, 56 height, 100 width (from modality translation) - return torch.randn(2, 512, 56, 100) - @pytest.fixture def mock_config(self): """Configuration for DensePose head""" return { - 'input_channels': 512, - 'num_body_parts': 24, # Standard DensePose body parts - 'num_uv_coordinates': 2, # U and V coordinates - 'hidden_dim': 256, - 'dropout_rate': 0.1 + 'input_channels': 256, + 'num_body_parts': 24, + 'num_uv_coordinates': 2, + 'hidden_channels': [128, 64], + 'kernel_size': 3, + 'padding': 1, + 'dropout_rate': 0.1, + 'use_deformable_conv': False, + 'use_fpn': True, + 'fpn_levels': [2, 3, 4, 5], + 'output_stride': 4 } @pytest.fixture @@ -31,6 +31,33 @@ class TestDensePoseHead: """Create DensePose head instance for testing""" return DensePoseHead(mock_config) + @pytest.fixture + def mock_feature_input(self): + """Generate mock feature input tensor""" + batch_size = 2 + channels = 256 + height = 56 + width = 56 + return torch.randn(batch_size, channels, height, width) + + @pytest.fixture + def mock_target_masks(self): + """Generate mock target segmentation masks""" + batch_size = 2 + num_parts = 24 + height = 224 + width = 224 + return torch.randint(0, num_parts + 1, (batch_size, height, width)) + + @pytest.fixture + def mock_target_uv(self): + """Generate mock target UV coordinates""" + batch_size = 2 + num_coords = 2 + height = 224 + width = 224 + return torch.randn(batch_size, num_coords, height, width) + def test_head_initialization_creates_correct_architecture(self, mock_config): """Test that DensePose head initializes with correct architecture""" # Act @@ -39,135 +66,302 @@ class TestDensePoseHead: # Assert assert head is not None assert isinstance(head, nn.Module) - assert hasattr(head, 'segmentation_head') - assert hasattr(head, 'uv_regression_head') assert head.input_channels == mock_config['input_channels'] assert head.num_body_parts == mock_config['num_body_parts'] assert head.num_uv_coordinates == mock_config['num_uv_coordinates'] + assert head.use_fpn == mock_config['use_fpn'] + assert hasattr(head, 'segmentation_head') + assert hasattr(head, 'uv_regression_head') + if mock_config['use_fpn']: + assert hasattr(head, 'fpn') - def test_forward_pass_produces_correct_output_shapes(self, densepose_head, mock_feature_input): - """Test that forward pass produces correctly shaped outputs""" + def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input): + """Test that forward pass produces correctly formatted output""" # Act - with torch.no_grad(): - segmentation, uv_coords = densepose_head(mock_feature_input) + output = densepose_head(mock_feature_input) # Assert - assert segmentation is not None - assert uv_coords is not None - assert isinstance(segmentation, torch.Tensor) - assert isinstance(uv_coords, torch.Tensor) + assert output is not None + assert isinstance(output, dict) + assert 'segmentation' in output + assert 'uv_coordinates' in output - # Check segmentation output shape - assert segmentation.shape[0] == mock_feature_input.shape[0] # Batch size preserved - assert segmentation.shape[1] == densepose_head.num_body_parts # Correct number of body parts - assert segmentation.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved + seg_output = output['segmentation'] + uv_output = output['uv_coordinates'] - # Check UV coordinates output shape - assert uv_coords.shape[0] == mock_feature_input.shape[0] # Batch size preserved - assert uv_coords.shape[1] == densepose_head.num_uv_coordinates # U and V coordinates - assert uv_coords.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved + assert isinstance(seg_output, torch.Tensor) + assert isinstance(uv_output, torch.Tensor) + assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved + assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved - def test_segmentation_output_has_valid_probabilities(self, densepose_head, mock_feature_input): - """Test that segmentation output has valid probability distributions""" + def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input): + """Test that segmentation head produces correct output shape""" # Act - with torch.no_grad(): - segmentation, _ = densepose_head(mock_feature_input) + output = densepose_head(mock_feature_input) + seg_output = output['segmentation'] # Assert - # After softmax, values should be between 0 and 1 - assert torch.all(segmentation >= 0.0) - assert torch.all(segmentation <= 1.0) - - # Sum across body parts dimension should be approximately 1 - part_sums = torch.sum(segmentation, dim=1) - assert torch.allclose(part_sums, torch.ones_like(part_sums), atol=1e-5) + expected_channels = densepose_head.num_body_parts + 1 # +1 for background + assert seg_output.shape[1] == expected_channels + assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled + assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled - def test_uv_coordinates_output_in_valid_range(self, densepose_head, mock_feature_input): - """Test that UV coordinates are in valid range [0, 1]""" + def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input): + """Test that UV regression head produces correct output shape""" # Act - with torch.no_grad(): - _, uv_coords = densepose_head(mock_feature_input) + output = densepose_head(mock_feature_input) + uv_output = output['uv_coordinates'] # Assert - # UV coordinates should be in range [0, 1] after sigmoid - assert torch.all(uv_coords >= 0.0) - assert torch.all(uv_coords <= 1.0) + assert uv_output.shape[1] == densepose_head.num_uv_coordinates + assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled + assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled - def test_head_handles_different_batch_sizes(self, densepose_head): - """Test that head handles different batch sizes correctly""" + def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks): + """Test that compute_segmentation_loss measures pixel classification accuracy""" # Arrange - batch_sizes = [1, 4, 8] + output = densepose_head(mock_feature_input) + seg_logits = output['segmentation'] - for batch_size in batch_sizes: - input_tensor = torch.randn(batch_size, 512, 56, 100) - - # Act - with torch.no_grad(): - segmentation, uv_coords = densepose_head(input_tensor) - - # Assert - assert segmentation.shape[0] == batch_size - assert uv_coords.shape[0] == batch_size - - def test_head_is_trainable(self, densepose_head, mock_feature_input): - """Test that head parameters are trainable""" - # Arrange - seg_criterion = nn.CrossEntropyLoss() - uv_criterion = nn.MSELoss() - - # Create targets with correct shapes - seg_target = torch.randint(0, 24, (2, 56, 100)) # Class indices for segmentation - uv_target = torch.rand(2, 2, 56, 100) # UV coordinates target + # Resize target to match output + target_resized = torch.nn.functional.interpolate( + mock_target_masks.float().unsqueeze(1), + size=seg_logits.shape[2:], + mode='nearest' + ).squeeze(1).long() # Act - segmentation, uv_coords = densepose_head(mock_feature_input) - seg_loss = seg_criterion(segmentation, seg_target) - uv_loss = uv_criterion(uv_coords, uv_target) - total_loss = seg_loss + uv_loss - total_loss.backward() + loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized) # Assert + assert loss is not None + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 # Scalar loss + assert loss.item() >= 0 # Loss should be non-negative + + def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv): + """Test that compute_uv_loss measures UV coordinate regression accuracy""" + # Arrange + output = densepose_head(mock_feature_input) + uv_pred = output['uv_coordinates'] + + # Resize target to match output + target_resized = torch.nn.functional.interpolate( + mock_target_uv, + size=uv_pred.shape[2:], + mode='bilinear', + align_corners=False + ) + + # Act + loss = densepose_head.compute_uv_loss(uv_pred, target_resized) + + # Assert + assert loss is not None + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 # Scalar loss + assert loss.item() >= 0 # Loss should be non-negative + + def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv): + """Test that compute_total_loss combines segmentation and UV losses""" + # Arrange + output = densepose_head(mock_feature_input) + + # Resize targets to match outputs + seg_target = torch.nn.functional.interpolate( + mock_target_masks.float().unsqueeze(1), + size=output['segmentation'].shape[2:], + mode='nearest' + ).squeeze(1).long() + + uv_target = torch.nn.functional.interpolate( + mock_target_uv, + size=output['uv_coordinates'].shape[2:], + mode='bilinear', + align_corners=False + ) + + # Act + total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target) + seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target) + uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target) + + # Assert + assert total_loss is not None + assert isinstance(total_loss, torch.Tensor) assert total_loss.item() > 0 - # Check that gradients are computed + # Total loss should be combination of individual losses + expected_total = seg_loss + uv_loss + assert torch.allclose(total_loss, expected_total, atol=1e-6) + + def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input): + """Test that FPN integration enhances multi-scale feature processing""" + # Arrange + config_with_fpn = mock_config.copy() + config_with_fpn['use_fpn'] = True + + config_without_fpn = mock_config.copy() + config_without_fpn['use_fpn'] = False + + head_with_fpn = DensePoseHead(config_with_fpn) + head_without_fpn = DensePoseHead(config_without_fpn) + + # Act + output_with_fpn = head_with_fpn(mock_feature_input) + output_without_fpn = head_without_fpn(mock_feature_input) + + # Assert + assert output_with_fpn['segmentation'].shape == output_without_fpn['segmentation'].shape + assert output_with_fpn['uv_coordinates'].shape == output_without_fpn['uv_coordinates'].shape + # Outputs should be different due to FPN + assert not torch.allclose(output_with_fpn['segmentation'], output_without_fpn['segmentation'], atol=1e-6) + + def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input): + """Test that get_prediction_confidence provides uncertainty estimates""" + # Arrange + output = densepose_head(mock_feature_input) + + # Act + confidence = densepose_head.get_prediction_confidence(output) + + # Assert + assert confidence is not None + assert isinstance(confidence, dict) + assert 'segmentation_confidence' in confidence + assert 'uv_confidence' in confidence + + seg_conf = confidence['segmentation_confidence'] + uv_conf = confidence['uv_confidence'] + + assert isinstance(seg_conf, torch.Tensor) + assert isinstance(uv_conf, torch.Tensor) + assert seg_conf.shape[0] == mock_feature_input.shape[0] + assert uv_conf.shape[0] == mock_feature_input.shape[0] + + def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input): + """Test that post_process_predictions formats output correctly""" + # Arrange + raw_output = densepose_head(mock_feature_input) + + # Act + processed = densepose_head.post_process_predictions(raw_output) + + # Assert + assert processed is not None + assert isinstance(processed, dict) + assert 'body_parts' in processed + assert 'uv_coordinates' in processed + assert 'confidence_scores' in processed + + def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input): + """Test that training mode enables dropout for regularization""" + # Arrange + densepose_head.train() + + # Act + output1 = densepose_head(mock_feature_input) + output2 = densepose_head(mock_feature_input) + + # Assert - outputs should be different due to dropout + assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6) + assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6) + + def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input): + """Test that evaluation mode disables dropout for consistent inference""" + # Arrange + densepose_head.eval() + + # Act + output1 = densepose_head(mock_feature_input) + output2 = densepose_head(mock_feature_input) + + # Assert - outputs should be identical in eval mode + assert torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6) + assert torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6) + + def test_head_validates_input_dimensions(self, densepose_head): + """Test that head validates input dimensions""" + # Arrange + invalid_input = torch.randn(2, 128, 56, 56) # Wrong number of channels + + # Act & Assert + with pytest.raises(DensePoseError): + densepose_head(invalid_input) + + def test_head_handles_different_input_sizes(self, densepose_head): + """Test that head handles different input sizes""" + # Arrange + small_input = torch.randn(1, 256, 28, 28) + large_input = torch.randn(1, 256, 112, 112) + + # Act + small_output = densepose_head(small_input) + large_output = densepose_head(large_input) + + # Assert + assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:] + assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:] + + def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv): + """Test that head supports gradient computation for training""" + # Arrange + densepose_head.train() + optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001) + + output = densepose_head(mock_feature_input) + + # Resize targets + seg_target = torch.nn.functional.interpolate( + mock_target_masks.float().unsqueeze(1), + size=output['segmentation'].shape[2:], + mode='nearest' + ).squeeze(1).long() + + uv_target = torch.nn.functional.interpolate( + mock_target_uv, + size=output['uv_coordinates'].shape[2:], + mode='bilinear', + align_corners=False + ) + + # Act + loss = densepose_head.compute_total_loss(output, seg_target, uv_target) + + optimizer.zero_grad() + loss.backward() + + # Assert for param in densepose_head.parameters(): if param.requires_grad: assert param.grad is not None + assert not torch.allclose(param.grad, torch.zeros_like(param.grad)) - def test_head_handles_invalid_input_shape(self, densepose_head): - """Test that head handles invalid input shapes gracefully""" + def test_head_configuration_validation(self): + """Test that head validates configuration parameters""" # Arrange - invalid_input = torch.randn(2, 256, 56, 100) # Wrong number of channels + invalid_config = { + 'input_channels': 0, # Invalid + 'num_body_parts': -1, # Invalid + 'num_uv_coordinates': 2 + } # Act & Assert - with pytest.raises(RuntimeError): - densepose_head(invalid_input) + with pytest.raises(ValueError): + DensePoseHead(invalid_config) - def test_head_supports_evaluation_mode(self, densepose_head, mock_feature_input): - """Test that head supports evaluation mode""" - # Act - densepose_head.eval() + def test_save_and_load_model_state(self, densepose_head, mock_feature_input): + """Test that model state can be saved and loaded""" + # Arrange + original_output = densepose_head(mock_feature_input) - with torch.no_grad(): - seg1, uv1 = densepose_head(mock_feature_input) - seg2, uv2 = densepose_head(mock_feature_input) + # Act - Save state + state_dict = densepose_head.state_dict() - # Assert - In eval mode with same input, outputs should be identical - assert torch.allclose(seg1, seg2, atol=1e-6) - assert torch.allclose(uv1, uv2, atol=1e-6) - - def test_head_output_quality(self, densepose_head, mock_feature_input): - """Test that head produces meaningful outputs""" - # Act - with torch.no_grad(): - segmentation, uv_coords = densepose_head(mock_feature_input) + # Create new head and load state + new_head = DensePoseHead(densepose_head.config) + new_head.load_state_dict(state_dict) + new_output = new_head(mock_feature_input) # Assert - # Outputs should not contain NaN or Inf values - assert not torch.isnan(segmentation).any() - assert not torch.isinf(segmentation).any() - assert not torch.isnan(uv_coords).any() - assert not torch.isinf(uv_coords).any() - - # Outputs should have reasonable variance (not all zeros or ones) - assert segmentation.std() > 0.01 - assert uv_coords.std() > 0.01 \ No newline at end of file + assert torch.allclose(original_output['segmentation'], new_output['segmentation'], atol=1e-6) + assert torch.allclose(original_output['uv_coordinates'], new_output['uv_coordinates'], atol=1e-6) \ No newline at end of file diff --git a/tests/unit/test_modality_translation.py b/tests/unit/test_modality_translation.py index cdddbad..3b7f461 100644 --- a/tests/unit/test_modality_translation.py +++ b/tests/unit/test_modality_translation.py @@ -3,27 +3,27 @@ import torch import torch.nn as nn import numpy as np from unittest.mock import Mock, patch -from src.models.modality_translation import ModalityTranslationNetwork +from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError class TestModalityTranslationNetwork: """Test suite for Modality Translation Network following London School TDD principles""" - @pytest.fixture - def mock_csi_input(self): - """Generate synthetic CSI input tensor for testing""" - # Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples - return torch.randn(2, 3, 56, 100) - @pytest.fixture def mock_config(self): """Configuration for modality translation network""" return { - 'input_channels': 3, - 'hidden_dim': 256, - 'output_dim': 512, - 'num_layers': 3, - 'dropout_rate': 0.1 + 'input_channels': 6, # Real and imaginary parts for 3 antennas + 'hidden_channels': [64, 128, 256], + 'output_channels': 256, + 'kernel_size': 3, + 'stride': 1, + 'padding': 1, + 'dropout_rate': 0.1, + 'activation': 'relu', + 'normalization': 'batch', + 'use_attention': True, + 'attention_heads': 8 } @pytest.fixture @@ -31,98 +31,263 @@ class TestModalityTranslationNetwork: """Create modality translation network instance for testing""" return ModalityTranslationNetwork(mock_config) + @pytest.fixture + def mock_csi_input(self): + """Generate mock CSI input tensor""" + batch_size = 4 + channels = 6 # Real and imaginary parts for 3 antennas + height = 56 # Number of subcarriers + width = 100 # Time samples + return torch.randn(batch_size, channels, height, width) + + @pytest.fixture + def mock_target_features(self): + """Generate mock target feature tensor for training""" + batch_size = 4 + feature_dim = 256 + spatial_height = 56 + spatial_width = 100 + return torch.randn(batch_size, feature_dim, spatial_height, spatial_width) + def test_network_initialization_creates_correct_architecture(self, mock_config): - """Test that network initializes with correct architecture""" + """Test that modality translation network initializes with correct architecture""" # Act network = ModalityTranslationNetwork(mock_config) # Assert assert network is not None assert isinstance(network, nn.Module) + assert network.input_channels == mock_config['input_channels'] + assert network.output_channels == mock_config['output_channels'] + assert network.use_attention == mock_config['use_attention'] assert hasattr(network, 'encoder') assert hasattr(network, 'decoder') - assert network.input_channels == mock_config['input_channels'] - assert network.hidden_dim == mock_config['hidden_dim'] - assert network.output_dim == mock_config['output_dim'] + if mock_config['use_attention']: + assert hasattr(network, 'attention') def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input): """Test that forward pass produces correctly shaped output""" # Act - with torch.no_grad(): - output = translation_network(mock_csi_input) + output = translation_network(mock_csi_input) # Assert assert output is not None assert isinstance(output, torch.Tensor) assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved - assert output.shape[1] == translation_network.output_dim # Correct output dimension - assert len(output.shape) == 4 # Should maintain spatial dimensions + assert output.shape[1] == translation_network.output_channels # Correct output channels + assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved + assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved - def test_forward_pass_handles_different_batch_sizes(self, translation_network): - """Test that network handles different batch sizes correctly""" + def test_forward_pass_handles_different_input_sizes(self, translation_network): + """Test that forward pass handles different input sizes""" # Arrange - batch_sizes = [1, 4, 8] + small_input = torch.randn(2, 6, 28, 50) + large_input = torch.randn(8, 6, 112, 200) - for batch_size in batch_sizes: - input_tensor = torch.randn(batch_size, 3, 56, 100) - - # Act - with torch.no_grad(): - output = translation_network(input_tensor) - - # Assert - assert output.shape[0] == batch_size - assert output.shape[1] == translation_network.output_dim + # Act + small_output = translation_network(small_input) + large_output = translation_network(large_input) + + # Assert + assert small_output.shape == (2, 256, 28, 50) + assert large_output.shape == (8, 256, 112, 200) - def test_network_is_trainable(self, translation_network, mock_csi_input): - """Test that network parameters are trainable""" + def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input): + """Test that encoder extracts hierarchical features""" + # Act + features = translation_network.encode(mock_csi_input) + + # Assert + assert features is not None + assert isinstance(features, list) + assert len(features) == len(translation_network.encoder) + + # Check feature map sizes decrease with depth + for i in range(1, len(features)): + assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same + assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same + + def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input): + """Test that decoder reconstructs target feature representation""" # Arrange - criterion = nn.MSELoss() + encoded_features = translation_network.encode(mock_csi_input) + + # Act + decoded_output = translation_network.decode(encoded_features) + + # Assert + assert decoded_output is not None + assert isinstance(decoded_output, torch.Tensor) + assert decoded_output.shape[1] == translation_network.output_channels + assert decoded_output.shape[2:] == mock_csi_input.shape[2:] + + def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input): + """Test that attention mechanism enhances feature representation""" + # Arrange + config_with_attention = mock_config.copy() + config_with_attention['use_attention'] = True + + config_without_attention = mock_config.copy() + config_without_attention['use_attention'] = False + + network_with_attention = ModalityTranslationNetwork(config_with_attention) + network_without_attention = ModalityTranslationNetwork(config_without_attention) + + # Act + output_with_attention = network_with_attention(mock_csi_input) + output_without_attention = network_without_attention(mock_csi_input) + + # Assert + assert output_with_attention.shape == output_without_attention.shape + # Outputs should be different due to attention mechanism + assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6) + + def test_training_mode_enables_dropout(self, translation_network, mock_csi_input): + """Test that training mode enables dropout for regularization""" + # Arrange + translation_network.train() + + # Act + output1 = translation_network(mock_csi_input) + output2 = translation_network(mock_csi_input) + + # Assert - outputs should be different due to dropout + assert not torch.allclose(output1, output2, atol=1e-6) + + def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input): + """Test that evaluation mode disables dropout for consistent inference""" + # Arrange + translation_network.eval() + + # Act + output1 = translation_network(mock_csi_input) + output2 = translation_network(mock_csi_input) + + # Assert - outputs should be identical in eval mode + assert torch.allclose(output1, output2, atol=1e-6) + + def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features): + """Test that compute_translation_loss measures feature alignment""" + # Arrange + predicted_features = translation_network(mock_csi_input) + + # Act + loss = translation_network.compute_translation_loss(predicted_features, mock_target_features) + + # Assert + assert loss is not None + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 # Scalar loss + assert loss.item() >= 0 # Loss should be non-negative + + def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features): + """Test that compute_translation_loss handles different loss types""" + # Arrange + predicted_features = translation_network(mock_csi_input) + + # Act + mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse') + l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1') + + # Assert + assert mse_loss is not None + assert l1_loss is not None + assert mse_loss.item() != l1_loss.item() # Different loss types should give different values + + def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input): + """Test that get_feature_statistics provides feature analysis""" + # Arrange + output = translation_network(mock_csi_input) + + # Act + stats = translation_network.get_feature_statistics(output) + + # Assert + assert stats is not None + assert isinstance(stats, dict) + assert 'mean' in stats + assert 'std' in stats + assert 'min' in stats + assert 'max' in stats + assert 'sparsity' in stats + + def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features): + """Test that network supports gradient computation for training""" + # Arrange + translation_network.train() + optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001) # Act output = translation_network(mock_csi_input) - # Create target with same shape as output - target = torch.randn_like(output) - loss = criterion(output, target) + loss = translation_network.compute_translation_loss(output, mock_target_features) + + optimizer.zero_grad() loss.backward() # Assert - assert loss.item() > 0 - # Check that gradients are computed for param in translation_network.parameters(): if param.requires_grad: assert param.grad is not None + assert not torch.allclose(param.grad, torch.zeros_like(param.grad)) - def test_network_handles_invalid_input_shape(self, translation_network): - """Test that network handles invalid input shapes gracefully""" + def test_network_validates_input_dimensions(self, translation_network): + """Test that network validates input dimensions""" # Arrange - invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels + invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels # Act & Assert - with pytest.raises(RuntimeError): + with pytest.raises(ModalityTranslationError): translation_network(invalid_input) - def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input): - """Test that network supports evaluation mode""" - # Act - translation_network.eval() + def test_network_handles_batch_size_one(self, translation_network): + """Test that network handles single sample inference""" + # Arrange + single_input = torch.randn(1, 6, 56, 100) - with torch.no_grad(): - output1 = translation_network(mock_csi_input) - output2 = translation_network(mock_csi_input) - - # Assert - In eval mode with same input, outputs should be identical - assert torch.allclose(output1, output2, atol=1e-6) - - def test_network_feature_extraction_quality(self, translation_network, mock_csi_input): - """Test that network extracts meaningful features""" # Act - with torch.no_grad(): - output = translation_network(mock_csi_input) + output = translation_network(single_input) # Assert - # Features should have reasonable statistics - assert not torch.isnan(output).any() - assert not torch.isinf(output).any() - assert output.std() > 0.01 # Features should have some variance - assert output.std() < 10.0 # But not be too extreme \ No newline at end of file + assert output.shape == (1, 256, 56, 100) + + def test_save_and_load_model_state(self, translation_network, mock_csi_input): + """Test that model state can be saved and loaded""" + # Arrange + original_output = translation_network(mock_csi_input) + + # Act - Save state + state_dict = translation_network.state_dict() + + # Create new network and load state + new_network = ModalityTranslationNetwork(translation_network.config) + new_network.load_state_dict(state_dict) + new_output = new_network(mock_csi_input) + + # Assert + assert torch.allclose(original_output, new_output, atol=1e-6) + + def test_network_configuration_validation(self): + """Test that network validates configuration parameters""" + # Arrange + invalid_config = { + 'input_channels': 0, # Invalid + 'hidden_channels': [], # Invalid + 'output_channels': 256 + } + + # Act & Assert + with pytest.raises(ValueError): + ModalityTranslationNetwork(invalid_config) + + def test_feature_visualization_support(self, translation_network, mock_csi_input): + """Test that network supports feature visualization""" + # Act + features = translation_network.get_intermediate_features(mock_csi_input) + + # Assert + assert features is not None + assert isinstance(features, dict) + assert 'encoder_features' in features + assert 'decoder_features' in features + if translation_network.use_attention: + assert 'attention_weights' in features \ No newline at end of file diff --git a/tests/unit/test_router_interface.py b/tests/unit/test_router_interface.py new file mode 100644 index 0000000..7e1512e --- /dev/null +++ b/tests/unit/test_router_interface.py @@ -0,0 +1,244 @@ +import pytest +import numpy as np +from unittest.mock import Mock, patch, MagicMock +from src.hardware.router_interface import RouterInterface, RouterConnectionError + + +class TestRouterInterface: + """Test suite for Router Interface following London School TDD principles""" + + @pytest.fixture + def mock_config(self): + """Configuration for router interface""" + return { + 'router_ip': '192.168.1.1', + 'username': 'admin', + 'password': 'password', + 'ssh_port': 22, + 'timeout': 30, + 'max_retries': 3 + } + + @pytest.fixture + def router_interface(self, mock_config): + """Create router interface instance for testing""" + return RouterInterface(mock_config) + + @pytest.fixture + def mock_ssh_client(self): + """Mock SSH client for testing""" + mock_client = Mock() + mock_client.connect = Mock() + mock_client.exec_command = Mock() + mock_client.close = Mock() + return mock_client + + def test_interface_initialization_creates_correct_configuration(self, mock_config): + """Test that router interface initializes with correct configuration""" + # Act + interface = RouterInterface(mock_config) + + # Assert + assert interface is not None + assert interface.router_ip == mock_config['router_ip'] + assert interface.username == mock_config['username'] + assert interface.password == mock_config['password'] + assert interface.ssh_port == mock_config['ssh_port'] + assert interface.timeout == mock_config['timeout'] + assert interface.max_retries == mock_config['max_retries'] + assert not interface.is_connected + + @patch('paramiko.SSHClient') + def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that connect method establishes SSH connection""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + + # Act + result = router_interface.connect() + + # Assert + assert result is True + assert router_interface.is_connected is True + mock_ssh_client.set_missing_host_key_policy.assert_called_once() + mock_ssh_client.connect.assert_called_once_with( + hostname=router_interface.router_ip, + port=router_interface.ssh_port, + username=router_interface.username, + password=router_interface.password, + timeout=router_interface.timeout + ) + + @patch('paramiko.SSHClient') + def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that connect method handles connection failures gracefully""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_ssh_client.connect.side_effect = Exception("Connection failed") + + # Act & Assert + with pytest.raises(RouterConnectionError): + router_interface.connect() + + assert router_interface.is_connected is False + + @patch('paramiko.SSHClient') + def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that disconnect method closes SSH connection""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + router_interface.connect() + + # Act + router_interface.disconnect() + + # Assert + assert router_interface.is_connected is False + mock_ssh_client.close.assert_called_once() + + @patch('paramiko.SSHClient') + def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that execute_command runs SSH commands correctly""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_stdout = Mock() + mock_stdout.read.return_value = b"command output" + mock_stderr = Mock() + mock_stderr.read.return_value = b"" + mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + + router_interface.connect() + + # Act + result = router_interface.execute_command("test command") + + # Assert + assert result == "command output" + mock_ssh_client.exec_command.assert_called_with("test command") + + @patch('paramiko.SSHClient') + def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that execute_command handles command errors""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_stdout = Mock() + mock_stdout.read.return_value = b"" + mock_stderr = Mock() + mock_stderr.read.return_value = b"command error" + mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + + router_interface.connect() + + # Act & Assert + with pytest.raises(RouterConnectionError): + router_interface.execute_command("failing command") + + def test_execute_command_requires_connection(self, router_interface): + """Test that execute_command requires active connection""" + # Act & Assert + with pytest.raises(RouterConnectionError): + router_interface.execute_command("test command") + + @patch('paramiko.SSHClient') + def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that get_router_info retrieves router system information""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_stdout = Mock() + mock_stdout.read.return_value = b"Router Model: AC1900\nFirmware: 1.2.3" + mock_stderr = Mock() + mock_stderr.read.return_value = b"" + mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + + router_interface.connect() + + # Act + info = router_interface.get_router_info() + + # Assert + assert info is not None + assert isinstance(info, dict) + assert 'model' in info + assert 'firmware' in info + + @patch('paramiko.SSHClient') + def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that enable_monitor_mode configures WiFi monitoring""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_stdout = Mock() + mock_stdout.read.return_value = b"Monitor mode enabled" + mock_stderr = Mock() + mock_stderr.read.return_value = b"" + mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + + router_interface.connect() + + # Act + result = router_interface.enable_monitor_mode("wlan0") + + # Assert + assert result is True + mock_ssh_client.exec_command.assert_called() + + @patch('paramiko.SSHClient') + def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that disable_monitor_mode disables WiFi monitoring""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_stdout = Mock() + mock_stdout.read.return_value = b"Monitor mode disabled" + mock_stderr = Mock() + mock_stderr.read.return_value = b"" + mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + + router_interface.connect() + + # Act + result = router_interface.disable_monitor_mode("wlan0") + + # Assert + assert result is True + mock_ssh_client.exec_command.assert_called() + + @patch('paramiko.SSHClient') + def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that router interface supports context manager protocol""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + + # Act + with router_interface as interface: + # Assert + assert interface.is_connected is True + + # Assert - connection should be closed after context + assert router_interface.is_connected is False + mock_ssh_client.close.assert_called_once() + + def test_interface_validates_configuration(self): + """Test that router interface validates configuration parameters""" + # Arrange + invalid_config = { + 'router_ip': '', # Invalid IP + 'username': 'admin', + 'password': 'password' + } + + # Act & Assert + with pytest.raises(ValueError): + RouterInterface(invalid_config) + + @patch('paramiko.SSHClient') + def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client): + """Test that interface implements retry logic for failed operations""" + # Arrange + mock_ssh_class.return_value = mock_ssh_client + mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None] # Fail once, then succeed + + # Act + result = router_interface.connect() + + # Assert + assert result is True + assert mock_ssh_client.connect.call_count == 2 # Should retry once \ No newline at end of file