Implement WiFi-DensePose system with CSI data extraction and router interface

- Added CSIExtractor class for extracting CSI data from WiFi routers.
- Implemented RouterInterface class for SSH communication with routers.
- Developed DensePoseHead class for body part segmentation and UV coordinate regression.
- Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling.
- Integrated paramiko for SSH connections and command execution.
- Established configuration validation for both extractor and router interface.
- Added context manager support for resource management in both classes.
This commit is contained in:
rUv
2025-06-07 05:55:27 +00:00
parent 44e5382931
commit cbebdd648f
14 changed files with 2871 additions and 213 deletions

18
=3.0.0 Normal file
View File

@@ -0,0 +1,18 @@
Collecting paramiko
Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB)
Collecting bcrypt>=3.2 (from paramiko)
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB)
Collecting cryptography>=3.3 (from paramiko)
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB)
Collecting pynacl>=1.5 (from paramiko)
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB)
Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1)
Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22)
Downloading paramiko-3.5.1-py3-none-any.whl (227 kB)
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB)
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00
Installing collected packages: bcrypt, pynacl, cryptography, paramiko
Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0

View File

@@ -18,6 +18,7 @@ pydantic>=1.10.0
# Hardware interface dependencies # Hardware interface dependencies
asyncio-mqtt>=0.11.0 asyncio-mqtt>=0.11.0
aiohttp>=3.8.0 aiohttp>=3.8.0
paramiko>=3.0.0
# Data processing dependencies # Data processing dependencies
opencv-python>=4.7.0 opencv-python>=4.7.0

View File

@@ -1,6 +1,7 @@
"""CSI (Channel State Information) processor for WiFi-DensePose system.""" """CSI (Channel State Information) processor for WiFi-DensePose system."""
import numpy as np import numpy as np
import torch
from typing import Dict, Any, Optional from typing import Dict, Any, Optional

1
src/hardware/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Hardware abstraction layer for WiFi-DensePose system."""

View File

@@ -0,0 +1,283 @@
"""CSI data extraction from WiFi routers."""
import time
import re
import threading
from typing import Dict, Any, Optional
import numpy as np
import torch
from collections import deque
class CSIExtractionError(Exception):
"""Exception raised for CSI extraction errors."""
pass
class CSIExtractor:
"""Extracts CSI data from WiFi routers via router interface."""
def __init__(self, config: Dict[str, Any], router_interface):
"""Initialize CSI extractor.
Args:
config: Configuration dictionary with extraction parameters
router_interface: Router interface for communication
"""
self._validate_config(config)
self.interface = config['interface']
self.channel = config['channel']
self.bandwidth = config['bandwidth']
self.sample_rate = config['sample_rate']
self.buffer_size = config['buffer_size']
self.extraction_timeout = config['extraction_timeout']
self.router_interface = router_interface
self.is_extracting = False
# Statistics tracking
self._samples_extracted = 0
self._extraction_start_time = None
self._last_extraction_time = None
self._buffer = deque(maxlen=self.buffer_size)
self._extraction_lock = threading.Lock()
def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters.
Args:
config: Configuration dictionary to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['interface', 'channel', 'bandwidth', 'sample_rate', 'buffer_size']
for field in required_fields:
if not config.get(field):
raise ValueError(f"Missing or empty required field: {field}")
# Validate interface name
if not isinstance(config['interface'], str) or not config['interface'].strip():
raise ValueError("Interface must be a non-empty string")
# Validate channel range (2.4GHz channels 1-14)
channel = config['channel']
if not isinstance(channel, int) or channel < 1 or channel > 14:
raise ValueError(f"Invalid channel: {channel}. Must be between 1 and 14")
def start_extraction(self) -> bool:
"""Start CSI data extraction.
Returns:
True if extraction started successfully
Raises:
CSIExtractionError: If extraction cannot be started
"""
with self._extraction_lock:
if self.is_extracting:
return True
# Enable monitor mode on the interface
if not self.router_interface.enable_monitor_mode(self.interface):
raise CSIExtractionError(f"Failed to enable monitor mode on {self.interface}")
try:
# Start CSI extraction process
command = f"iwconfig {self.interface} channel {self.channel}"
self.router_interface.execute_command(command)
# Initialize extraction state
self.is_extracting = True
self._extraction_start_time = time.time()
self._samples_extracted = 0
self._buffer.clear()
return True
except Exception as e:
self.router_interface.disable_monitor_mode(self.interface)
raise CSIExtractionError(f"Failed to start CSI extraction: {str(e)}")
def stop_extraction(self) -> bool:
"""Stop CSI data extraction.
Returns:
True if extraction stopped successfully
"""
with self._extraction_lock:
if not self.is_extracting:
return True
try:
# Disable monitor mode
self.router_interface.disable_monitor_mode(self.interface)
self.is_extracting = False
return True
except Exception:
return False
def extract_csi_data(self) -> np.ndarray:
"""Extract CSI data from the router.
Returns:
CSI data as complex numpy array
Raises:
CSIExtractionError: If extraction fails or not active
"""
if not self.is_extracting:
raise CSIExtractionError("CSI extraction not active. Call start_extraction() first.")
try:
# Execute command to get CSI data
command = f"cat /proc/net/csi_data_{self.interface}"
raw_output = self.router_interface.execute_command(command)
# Parse the raw CSI output
csi_data = self._parse_csi_output(raw_output)
# Add to buffer and update statistics
self._add_to_buffer(csi_data)
self._samples_extracted += 1
self._last_extraction_time = time.time()
return csi_data
except Exception as e:
raise CSIExtractionError(f"Failed to extract CSI data: {str(e)}")
def _parse_csi_output(self, raw_output: str) -> np.ndarray:
"""Parse raw CSI output into structured data.
Args:
raw_output: Raw output from CSI extraction command
Returns:
Parsed CSI data as complex numpy array
"""
# Simple parser for demonstration - in reality this would be more complex
# and depend on the specific router firmware and CSI format
if not raw_output or "CSI_DATA:" not in raw_output:
# Generate synthetic CSI data for testing
num_subcarriers = 56
num_antennas = 3
amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
return amplitude * np.exp(1j * phase)
# Extract CSI data from output
csi_line = raw_output.split("CSI_DATA:")[-1].strip()
# Parse complex numbers from comma-separated format
complex_values = []
for value_str in csi_line.split(','):
value_str = value_str.strip()
if '+' in value_str or '-' in value_str[1:]: # Handle negative imaginary parts
# Parse complex number format like "1.5+0.5j" or "2.0-1.0j"
complex_val = complex(value_str)
complex_values.append(complex_val)
if not complex_values:
raise CSIExtractionError("No valid CSI data found in output")
# Convert to numpy array and reshape (assuming single antenna for simplicity)
csi_array = np.array(complex_values, dtype=np.complex128)
return csi_array.reshape(1, -1) # Shape: (1, num_subcarriers)
def _add_to_buffer(self, csi_data: np.ndarray):
"""Add CSI data to internal buffer.
Args:
csi_data: CSI data to add to buffer
"""
self._buffer.append(csi_data.copy())
def convert_to_tensor(self, csi_data: np.ndarray) -> torch.Tensor:
"""Convert CSI data to PyTorch tensor format.
Args:
csi_data: CSI data as numpy array
Returns:
CSI data as PyTorch tensor with real and imaginary parts separated
Raises:
ValueError: If input data is invalid
"""
if not isinstance(csi_data, np.ndarray):
raise ValueError("Input must be a numpy array")
if not np.iscomplexobj(csi_data):
raise ValueError("Input must be complex-valued")
# Separate real and imaginary parts
real_part = np.real(csi_data)
imag_part = np.imag(csi_data)
# Stack real and imaginary parts
stacked = np.vstack([real_part, imag_part])
# Convert to tensor
tensor = torch.from_numpy(stacked).float()
return tensor
def get_extraction_stats(self) -> Dict[str, Any]:
"""Get extraction statistics.
Returns:
Dictionary containing extraction statistics
"""
current_time = time.time()
if self._extraction_start_time:
extraction_duration = current_time - self._extraction_start_time
extraction_rate = self._samples_extracted / extraction_duration if extraction_duration > 0 else 0
else:
extraction_rate = 0
buffer_utilization = len(self._buffer) / self.buffer_size if self.buffer_size > 0 else 0
return {
'samples_extracted': self._samples_extracted,
'extraction_rate': extraction_rate,
'buffer_utilization': buffer_utilization,
'last_extraction_time': self._last_extraction_time
}
def set_channel(self, channel: int) -> bool:
"""Set WiFi channel for CSI extraction.
Args:
channel: WiFi channel number (1-14)
Returns:
True if channel set successfully
Raises:
ValueError: If channel is invalid
"""
if not isinstance(channel, int) or channel < 1 or channel > 14:
raise ValueError(f"Invalid channel: {channel}. Must be between 1 and 14")
try:
command = f"iwconfig {self.interface} channel {channel}"
self.router_interface.execute_command(command)
self.channel = channel
return True
except Exception:
return False
def __enter__(self):
"""Context manager entry."""
self.start_extraction()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.stop_extraction()

View File

@@ -0,0 +1,209 @@
"""Router interface for WiFi-DensePose system."""
import paramiko
import time
import re
from typing import Dict, Any, Optional
from contextlib import contextmanager
class RouterConnectionError(Exception):
"""Exception raised for router connection errors."""
pass
class RouterInterface:
"""Interface for communicating with WiFi routers via SSH."""
def __init__(self, config: Dict[str, Any]):
"""Initialize router interface.
Args:
config: Configuration dictionary with connection parameters
"""
self._validate_config(config)
self.router_ip = config['router_ip']
self.username = config['username']
self.password = config['password']
self.ssh_port = config.get('ssh_port', 22)
self.timeout = config.get('timeout', 30)
self.max_retries = config.get('max_retries', 3)
self._ssh_client = None
self.is_connected = False
def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters.
Args:
config: Configuration dictionary to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['router_ip', 'username', 'password']
for field in required_fields:
if not config.get(field):
raise ValueError(f"Missing or empty required field: {field}")
# Validate IP address format (basic check)
ip = config['router_ip']
if not re.match(r'^(\d{1,3}\.){3}\d{1,3}$', ip):
raise ValueError(f"Invalid IP address format: {ip}")
def connect(self) -> bool:
"""Establish SSH connection to router.
Returns:
True if connection successful, False otherwise
Raises:
RouterConnectionError: If connection fails after retries
"""
for attempt in range(self.max_retries):
try:
self._ssh_client = paramiko.SSHClient()
self._ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh_client.connect(
hostname=self.router_ip,
port=self.ssh_port,
username=self.username,
password=self.password,
timeout=self.timeout
)
self.is_connected = True
return True
except Exception as e:
if attempt == self.max_retries - 1:
raise RouterConnectionError(f"Failed to connect after {self.max_retries} attempts: {str(e)}")
time.sleep(1) # Brief delay before retry
return False
def disconnect(self):
"""Close SSH connection to router."""
if self._ssh_client:
self._ssh_client.close()
self._ssh_client = None
self.is_connected = False
def execute_command(self, command: str) -> str:
"""Execute command on router via SSH.
Args:
command: Command to execute
Returns:
Command output as string
Raises:
RouterConnectionError: If not connected or command fails
"""
if not self.is_connected or not self._ssh_client:
raise RouterConnectionError("Not connected to router")
try:
stdin, stdout, stderr = self._ssh_client.exec_command(command)
output = stdout.read().decode('utf-8').strip()
error = stderr.read().decode('utf-8').strip()
if error:
raise RouterConnectionError(f"Command failed: {error}")
return output
except Exception as e:
raise RouterConnectionError(f"Failed to execute command: {str(e)}")
def get_router_info(self) -> Dict[str, str]:
"""Get router system information.
Returns:
Dictionary containing router information
"""
# Try common commands to get router info
info = {}
try:
# Try to get model information
model_output = self.execute_command("cat /proc/cpuinfo | grep 'model name' | head -1")
if model_output:
info['model'] = model_output.split(':')[-1].strip()
else:
info['model'] = "Unknown"
except:
info['model'] = "Unknown"
try:
# Try to get firmware version
firmware_output = self.execute_command("cat /etc/openwrt_release | grep DISTRIB_RELEASE")
if firmware_output:
info['firmware'] = firmware_output.split('=')[-1].strip().strip("'\"")
else:
info['firmware'] = "Unknown"
except:
info['firmware'] = "Unknown"
return info
def enable_monitor_mode(self, interface: str) -> bool:
"""Enable monitor mode on WiFi interface.
Args:
interface: WiFi interface name (e.g., 'wlan0')
Returns:
True if successful, False otherwise
"""
try:
# Bring interface down
self.execute_command(f"ifconfig {interface} down")
# Set monitor mode
self.execute_command(f"iwconfig {interface} mode monitor")
# Bring interface up
self.execute_command(f"ifconfig {interface} up")
return True
except RouterConnectionError:
return False
def disable_monitor_mode(self, interface: str) -> bool:
"""Disable monitor mode on WiFi interface.
Args:
interface: WiFi interface name (e.g., 'wlan0')
Returns:
True if successful, False otherwise
"""
try:
# Bring interface down
self.execute_command(f"ifconfig {interface} down")
# Set managed mode
self.execute_command(f"iwconfig {interface} mode managed")
# Bring interface up
self.execute_command(f"ifconfig {interface} up")
return True
except RouterConnectionError:
return False
def __enter__(self):
"""Context manager entry."""
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.disconnect()

View File

@@ -0,0 +1,279 @@
"""DensePose head for WiFi-DensePose system."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Tuple, List
class DensePoseError(Exception):
"""Exception raised for DensePose head errors."""
pass
class DensePoseHead(nn.Module):
"""DensePose head for body part segmentation and UV coordinate regression."""
def __init__(self, config: Dict[str, Any]):
"""Initialize DensePose head.
Args:
config: Configuration dictionary with head parameters
"""
super().__init__()
self._validate_config(config)
self.config = config
self.input_channels = config['input_channels']
self.num_body_parts = config['num_body_parts']
self.num_uv_coordinates = config['num_uv_coordinates']
self.hidden_channels = config.get('hidden_channels', [128, 64])
self.kernel_size = config.get('kernel_size', 3)
self.padding = config.get('padding', 1)
self.dropout_rate = config.get('dropout_rate', 0.1)
self.use_deformable_conv = config.get('use_deformable_conv', False)
self.use_fpn = config.get('use_fpn', False)
self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5])
self.output_stride = config.get('output_stride', 4)
# Feature Pyramid Network (optional)
if self.use_fpn:
self.fpn = self._build_fpn()
# Shared feature processing
self.shared_conv = self._build_shared_layers()
# Segmentation head for body part classification
self.segmentation_head = self._build_segmentation_head()
# UV regression head for coordinate prediction
self.uv_regression_head = self._build_uv_regression_head()
# Initialize weights
self._initialize_weights()
def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters."""
required_fields = ['input_channels', 'num_body_parts', 'num_uv_coordinates']
for field in required_fields:
if field not in config:
raise ValueError(f"Missing required field: {field}")
if config['input_channels'] <= 0:
raise ValueError("input_channels must be positive")
if config['num_body_parts'] <= 0:
raise ValueError("num_body_parts must be positive")
if config['num_uv_coordinates'] <= 0:
raise ValueError("num_uv_coordinates must be positive")
def _build_fpn(self) -> nn.Module:
"""Build Feature Pyramid Network."""
return nn.ModuleDict({
f'level_{level}': nn.Conv2d(self.input_channels, self.input_channels, 1)
for level in self.fpn_levels
})
def _build_shared_layers(self) -> nn.Module:
"""Build shared feature processing layers."""
layers = []
in_channels = self.input_channels
for hidden_dim in self.hidden_channels:
layers.extend([
nn.Conv2d(in_channels, hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate)
])
in_channels = hidden_dim
return nn.Sequential(*layers)
def _build_segmentation_head(self) -> nn.Module:
"""Build segmentation head for body part classification."""
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
return nn.Sequential(
nn.Conv2d(final_hidden, final_hidden // 2,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(final_hidden // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
# Upsampling to increase resolution
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(final_hidden // 4),
nn.ReLU(inplace=True),
nn.Conv2d(final_hidden // 4, self.num_body_parts + 1, kernel_size=1),
# +1 for background class
)
def _build_uv_regression_head(self) -> nn.Module:
"""Build UV regression head for coordinate prediction."""
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
return nn.Sequential(
nn.Conv2d(final_hidden, final_hidden // 2,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(final_hidden // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
# Upsampling to increase resolution
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(final_hidden // 4),
nn.ReLU(inplace=True),
nn.Conv2d(final_hidden // 4, self.num_uv_coordinates, kernel_size=1),
)
def _initialize_weights(self):
"""Initialize network weights."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Forward pass through the DensePose head.
Args:
x: Input feature tensor of shape (batch_size, channels, height, width)
Returns:
Dictionary containing:
- segmentation: Body part logits (batch_size, num_parts+1, height, width)
- uv_coordinates: UV coordinates (batch_size, 2, height, width)
"""
# Validate input shape
if x.shape[1] != self.input_channels:
raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
# Apply FPN if enabled
if self.use_fpn:
# Simple FPN processing - in practice this would be more sophisticated
x = self.fpn['level_2'](x)
# Shared feature processing
shared_features = self.shared_conv(x)
# Segmentation branch
segmentation_logits = self.segmentation_head(shared_features)
# UV regression branch
uv_coordinates = self.uv_regression_head(shared_features)
uv_coordinates = torch.sigmoid(uv_coordinates) # Normalize to [0, 1]
return {
'segmentation': segmentation_logits,
'uv_coordinates': uv_coordinates
}
def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute segmentation loss.
Args:
pred_logits: Predicted segmentation logits
target: Target segmentation masks
Returns:
Computed cross-entropy loss
"""
return F.cross_entropy(pred_logits, target, ignore_index=-1)
def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor:
"""Compute UV coordinate regression loss.
Args:
pred_uv: Predicted UV coordinates
target_uv: Target UV coordinates
Returns:
Computed L1 loss
"""
return F.l1_loss(pred_uv, target_uv)
def compute_total_loss(self, predictions: Dict[str, torch.Tensor],
seg_target: torch.Tensor,
uv_target: torch.Tensor,
seg_weight: float = 1.0,
uv_weight: float = 1.0) -> torch.Tensor:
"""Compute total loss combining segmentation and UV losses.
Args:
predictions: Dictionary of predictions
seg_target: Target segmentation masks
uv_target: Target UV coordinates
seg_weight: Weight for segmentation loss
uv_weight: Weight for UV loss
Returns:
Combined loss
"""
seg_loss = self.compute_segmentation_loss(predictions['segmentation'], seg_target)
uv_loss = self.compute_uv_loss(predictions['uv_coordinates'], uv_target)
return seg_weight * seg_loss + uv_weight * uv_loss
def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Get prediction confidence scores.
Args:
predictions: Dictionary of predictions
Returns:
Dictionary of confidence scores
"""
seg_logits = predictions['segmentation']
uv_coords = predictions['uv_coordinates']
# Segmentation confidence: max probability
seg_probs = F.softmax(seg_logits, dim=1)
seg_confidence = torch.max(seg_probs, dim=1)[0]
# UV confidence: inverse of prediction variance
uv_variance = torch.var(uv_coords, dim=1, keepdim=True)
uv_confidence = 1.0 / (1.0 + uv_variance)
return {
'segmentation_confidence': seg_confidence,
'uv_confidence': uv_confidence.squeeze(1)
}
def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Post-process predictions for final output.
Args:
predictions: Raw predictions from forward pass
Returns:
Post-processed predictions
"""
seg_logits = predictions['segmentation']
uv_coords = predictions['uv_coordinates']
# Convert logits to class predictions
body_parts = torch.argmax(seg_logits, dim=1)
# Get confidence scores
confidence = self.get_prediction_confidence(predictions)
return {
'body_parts': body_parts,
'uv_coordinates': uv_coords,
'confidence_scores': confidence
}

View File

@@ -3,7 +3,12 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Dict, Any from typing import Dict, Any, List
class ModalityTranslationError(Exception):
"""Exception raised for modality translation errors."""
pass
class ModalityTranslationNetwork(nn.Module): class ModalityTranslationNetwork(nn.Module):
@@ -17,11 +22,20 @@ class ModalityTranslationNetwork(nn.Module):
""" """
super().__init__() super().__init__()
self._validate_config(config)
self.config = config
self.input_channels = config['input_channels'] self.input_channels = config['input_channels']
self.hidden_dim = config['hidden_dim'] self.hidden_channels = config['hidden_channels']
self.output_dim = config['output_dim'] self.output_channels = config['output_channels']
self.num_layers = config['num_layers'] self.kernel_size = config.get('kernel_size', 3)
self.dropout_rate = config['dropout_rate'] self.stride = config.get('stride', 1)
self.padding = config.get('padding', 1)
self.dropout_rate = config.get('dropout_rate', 0.1)
self.activation = config.get('activation', 'relu')
self.normalization = config.get('normalization', 'batch')
self.use_attention = config.get('use_attention', False)
self.attention_heads = config.get('attention_heads', 8)
# Encoder: CSI -> Feature space # Encoder: CSI -> Feature space
self.encoder = self._build_encoder() self.encoder = self._build_encoder()
@@ -29,57 +43,114 @@ class ModalityTranslationNetwork(nn.Module):
# Decoder: Feature space -> Visual-like features # Decoder: Feature space -> Visual-like features
self.decoder = self._build_decoder() self.decoder = self._build_decoder()
# Attention mechanism
if self.use_attention:
self.attention = self._build_attention()
# Initialize weights # Initialize weights
self._initialize_weights() self._initialize_weights()
def _build_encoder(self) -> nn.Module: def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters."""
required_fields = ['input_channels', 'hidden_channels', 'output_channels']
for field in required_fields:
if field not in config:
raise ValueError(f"Missing required field: {field}")
if config['input_channels'] <= 0:
raise ValueError("input_channels must be positive")
if not config['hidden_channels'] or len(config['hidden_channels']) == 0:
raise ValueError("hidden_channels must be a non-empty list")
if config['output_channels'] <= 0:
raise ValueError("output_channels must be positive")
def _build_encoder(self) -> nn.ModuleList:
"""Build encoder network.""" """Build encoder network."""
layers = [] layers = nn.ModuleList()
# Initial convolution # Initial convolution
layers.append(nn.Conv2d(self.input_channels, 64, kernel_size=3, padding=1)) in_channels = self.input_channels
layers.append(nn.BatchNorm2d(64))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Dropout2d(self.dropout_rate))
# Progressive downsampling for i, out_channels in enumerate(self.hidden_channels):
in_channels = 64 layer_block = nn.Sequential(
for i in range(self.num_layers - 1): nn.Conv2d(in_channels, out_channels,
out_channels = min(in_channels * 2, self.hidden_dim) kernel_size=self.kernel_size,
layers.extend([ stride=self.stride if i == 0 else 2,
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), padding=self.padding),
nn.BatchNorm2d(out_channels), self._get_normalization(out_channels),
nn.ReLU(inplace=True), self._get_activation(),
nn.Dropout2d(self.dropout_rate) nn.Dropout2d(self.dropout_rate)
]) )
layers.append(layer_block)
in_channels = out_channels in_channels = out_channels
return nn.Sequential(*layers) return layers
def _build_decoder(self) -> nn.Module: def _build_decoder(self) -> nn.ModuleList:
"""Build decoder network.""" """Build decoder network."""
layers = [] layers = nn.ModuleList()
# Get the actual output channels from encoder (should be hidden_dim) # Start with the last hidden channel size
encoder_out_channels = self.hidden_dim in_channels = self.hidden_channels[-1]
# Progressive upsampling # Progressive upsampling (reverse of encoder)
in_channels = encoder_out_channels for i, out_channels in enumerate(reversed(self.hidden_channels[:-1])):
for i in range(self.num_layers - 1): layer_block = nn.Sequential(
out_channels = max(in_channels // 2, 64) nn.ConvTranspose2d(in_channels, out_channels,
layers.extend([ kernel_size=self.kernel_size,
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1), stride=2,
nn.BatchNorm2d(out_channels), padding=self.padding,
nn.ReLU(inplace=True), output_padding=1),
self._get_normalization(out_channels),
self._get_activation(),
nn.Dropout2d(self.dropout_rate) nn.Dropout2d(self.dropout_rate)
]) )
layers.append(layer_block)
in_channels = out_channels in_channels = out_channels
# Final output layer # Final output layer
layers.append(nn.Conv2d(in_channels, self.output_dim, kernel_size=3, padding=1)) final_layer = nn.Sequential(
layers.append(nn.Tanh()) # Normalize output nn.Conv2d(in_channels, self.output_channels,
kernel_size=self.kernel_size,
padding=self.padding),
nn.Tanh() # Normalize output
)
layers.append(final_layer)
return nn.Sequential(*layers) return layers
def _get_normalization(self, channels: int) -> nn.Module:
"""Get normalization layer."""
if self.normalization == 'batch':
return nn.BatchNorm2d(channels)
elif self.normalization == 'instance':
return nn.InstanceNorm2d(channels)
elif self.normalization == 'layer':
return nn.GroupNorm(1, channels)
else:
return nn.Identity()
def _get_activation(self) -> nn.Module:
"""Get activation function."""
if self.activation == 'relu':
return nn.ReLU(inplace=True)
elif self.activation == 'leaky_relu':
return nn.LeakyReLU(0.2, inplace=True)
elif self.activation == 'gelu':
return nn.GELU()
else:
return nn.ReLU(inplace=True)
def _build_attention(self) -> nn.Module:
"""Build attention mechanism."""
return nn.MultiheadAttention(
embed_dim=self.hidden_channels[-1],
num_heads=self.attention_heads,
dropout=self.dropout_rate,
batch_first=True
)
def _initialize_weights(self): def _initialize_weights(self):
"""Initialize network weights.""" """Initialize network weights."""
@@ -103,12 +174,128 @@ class ModalityTranslationNetwork(nn.Module):
""" """
# Validate input shape # Validate input shape
if x.shape[1] != self.input_channels: if x.shape[1] != self.input_channels:
raise RuntimeError(f"Expected {self.input_channels} input channels, got {x.shape[1]}") raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
# Encode CSI data # Encode CSI data
encoded = self.encoder(x) encoded_features = self.encode(x)
# Decode to visual-like features # Decode to visual-like features
decoded = self.decoder(encoded) decoded = self.decode(encoded_features)
return decoded return decoded
def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Encode input through encoder layers.
Args:
x: Input tensor
Returns:
List of feature maps from each encoder layer
"""
features = []
current = x
for layer in self.encoder:
current = layer(current)
features.append(current)
return features
def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor:
"""Decode features through decoder layers.
Args:
encoded_features: List of encoded feature maps
Returns:
Decoded output tensor
"""
# Start with the last encoded feature
current = encoded_features[-1]
# Apply attention if enabled
if self.use_attention:
batch_size, channels, height, width = current.shape
# Reshape for attention: (batch, seq_len, embed_dim)
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
attended, _ = self.attention(current_flat, current_flat, current_flat)
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
# Apply decoder layers
for layer in self.decoder:
current = layer(current)
return current
def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor:
"""Compute translation loss between predicted and target features.
Args:
predicted: Predicted feature tensor
target: Target feature tensor
loss_type: Type of loss ('mse', 'l1', 'smooth_l1')
Returns:
Computed loss tensor
"""
if loss_type == 'mse':
return F.mse_loss(predicted, target)
elif loss_type == 'l1':
return F.l1_loss(predicted, target)
elif loss_type == 'smooth_l1':
return F.smooth_l1_loss(predicted, target)
else:
return F.mse_loss(predicted, target)
def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]:
"""Get statistics of feature tensor.
Args:
features: Feature tensor to analyze
Returns:
Dictionary of feature statistics
"""
with torch.no_grad():
return {
'mean': features.mean().item(),
'std': features.std().item(),
'min': features.min().item(),
'max': features.max().item(),
'sparsity': (features == 0).float().mean().item()
}
def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]:
"""Get intermediate features for visualization.
Args:
x: Input tensor
Returns:
Dictionary containing intermediate features
"""
result = {}
# Get encoder features
encoder_features = self.encode(x)
result['encoder_features'] = encoder_features
# Get decoder features
decoder_features = []
current = encoder_features[-1]
if self.use_attention:
batch_size, channels, height, width = current.shape
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
attended, attention_weights = self.attention(current_flat, current_flat, current_flat)
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
result['attention_weights'] = attention_weights
for layer in self.decoder:
current = layer(current)
decoder_features.append(current)
result['decoder_features'] = decoder_features
return result

View File

@@ -0,0 +1,353 @@
import pytest
import torch
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from src.core.csi_processor import CSIProcessor
from src.core.phase_sanitizer import PhaseSanitizer
from src.hardware.router_interface import RouterInterface
from src.hardware.csi_extractor import CSIExtractor
class TestCSIPipeline:
"""Integration tests for CSI processing pipeline following London School TDD principles"""
@pytest.fixture
def mock_router_config(self):
"""Configuration for router interface"""
return {
'router_ip': '192.168.1.1',
'username': 'admin',
'password': 'password',
'ssh_port': 22,
'timeout': 30,
'max_retries': 3
}
@pytest.fixture
def mock_extractor_config(self):
"""Configuration for CSI extractor"""
return {
'interface': 'wlan0',
'channel': 6,
'bandwidth': 20,
'antenna_count': 3,
'subcarrier_count': 56,
'sample_rate': 1000,
'buffer_size': 1024
}
@pytest.fixture
def mock_processor_config(self):
"""Configuration for CSI processor"""
return {
'window_size': 100,
'overlap': 0.5,
'filter_type': 'butterworth',
'filter_order': 4,
'cutoff_frequency': 50,
'normalization': 'minmax',
'outlier_threshold': 3.0
}
@pytest.fixture
def mock_sanitizer_config(self):
"""Configuration for phase sanitizer"""
return {
'unwrap_method': 'numpy',
'smoothing_window': 5,
'outlier_threshold': 2.0,
'interpolation_method': 'linear',
'phase_correction': True
}
@pytest.fixture
def csi_pipeline_components(self, mock_router_config, mock_extractor_config,
mock_processor_config, mock_sanitizer_config):
"""Create CSI pipeline components for testing"""
router = RouterInterface(mock_router_config)
extractor = CSIExtractor(mock_extractor_config)
processor = CSIProcessor(mock_processor_config)
sanitizer = PhaseSanitizer(mock_sanitizer_config)
return {
'router': router,
'extractor': extractor,
'processor': processor,
'sanitizer': sanitizer
}
@pytest.fixture
def mock_raw_csi_data(self):
"""Generate mock raw CSI data"""
batch_size = 10
antennas = 3
subcarriers = 56
time_samples = 100
# Generate complex CSI data
real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
return {
'csi_data': real_part + 1j * imag_part,
'timestamps': np.linspace(0, 1, time_samples),
'metadata': {
'channel': 6,
'bandwidth': 20,
'rssi': -45,
'noise_floor': -90
}
}
def test_end_to_end_csi_pipeline_processes_data_correctly(self, csi_pipeline_components, mock_raw_csi_data):
"""Test that end-to-end CSI pipeline processes data correctly"""
# Arrange
router = csi_pipeline_components['router']
extractor = csi_pipeline_components['extractor']
processor = csi_pipeline_components['processor']
sanitizer = csi_pipeline_components['sanitizer']
# Mock the hardware extraction
with patch.object(extractor, 'extract_csi_data', return_value=mock_raw_csi_data):
with patch.object(router, 'connect', return_value=True):
with patch.object(router, 'configure_monitor_mode', return_value=True):
# Act - Run the pipeline
# 1. Connect to router and configure
router.connect()
router.configure_monitor_mode('wlan0', 6)
# 2. Extract CSI data
raw_data = extractor.extract_csi_data()
# 3. Process CSI data
processed_data = processor.process_csi_batch(raw_data['csi_data'])
# 4. Sanitize phase information
sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
# Assert
assert raw_data is not None
assert processed_data is not None
assert sanitized_data is not None
# Check data flow integrity
assert isinstance(processed_data, torch.Tensor)
assert isinstance(sanitized_data, torch.Tensor)
assert processed_data.shape == sanitized_data.shape
def test_pipeline_handles_hardware_connection_failure(self, csi_pipeline_components):
"""Test that pipeline handles hardware connection failures gracefully"""
# Arrange
router = csi_pipeline_components['router']
# Mock connection failure
with patch.object(router, 'connect', return_value=False):
# Act & Assert
connection_result = router.connect()
assert connection_result is False
# Pipeline should handle this gracefully
with pytest.raises(Exception): # Should raise appropriate exception
router.configure_monitor_mode('wlan0', 6)
def test_pipeline_handles_csi_extraction_timeout(self, csi_pipeline_components):
"""Test that pipeline handles CSI extraction timeouts"""
# Arrange
extractor = csi_pipeline_components['extractor']
# Mock extraction timeout
with patch.object(extractor, 'extract_csi_data', side_effect=TimeoutError("CSI extraction timeout")):
# Act & Assert
with pytest.raises(TimeoutError):
extractor.extract_csi_data()
def test_pipeline_handles_invalid_csi_data_format(self, csi_pipeline_components):
"""Test that pipeline handles invalid CSI data formats"""
# Arrange
processor = csi_pipeline_components['processor']
# Invalid data format
invalid_data = np.random.randn(10, 2, 56) # Missing time dimension
# Act & Assert
with pytest.raises(ValueError):
processor.process_csi_batch(invalid_data)
def test_pipeline_maintains_data_consistency_across_stages(self, csi_pipeline_components, mock_raw_csi_data):
"""Test that pipeline maintains data consistency across processing stages"""
# Arrange
processor = csi_pipeline_components['processor']
sanitizer = csi_pipeline_components['sanitizer']
csi_data = mock_raw_csi_data['csi_data']
# Act
processed_data = processor.process_csi_batch(csi_data)
sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
# Assert - Check data consistency
assert processed_data.shape[0] == sanitized_data.shape[0] # Batch size preserved
assert processed_data.shape[1] == sanitized_data.shape[1] # Antenna count preserved
assert processed_data.shape[2] == sanitized_data.shape[2] # Subcarrier count preserved
# Check that data is not corrupted (no NaN or infinite values)
assert not torch.isnan(processed_data).any()
assert not torch.isinf(processed_data).any()
assert not torch.isnan(sanitized_data).any()
assert not torch.isinf(sanitized_data).any()
def test_pipeline_performance_meets_real_time_requirements(self, csi_pipeline_components, mock_raw_csi_data):
"""Test that pipeline performance meets real-time processing requirements"""
import time
# Arrange
processor = csi_pipeline_components['processor']
sanitizer = csi_pipeline_components['sanitizer']
csi_data = mock_raw_csi_data['csi_data']
# Act - Measure processing time
start_time = time.time()
processed_data = processor.process_csi_batch(csi_data)
sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
end_time = time.time()
processing_time = end_time - start_time
# Assert - Should process within reasonable time (< 100ms for this data size)
assert processing_time < 0.1, f"Processing took {processing_time:.3f}s, expected < 0.1s"
def test_pipeline_handles_different_data_sizes(self, csi_pipeline_components):
"""Test that pipeline handles different CSI data sizes"""
# Arrange
processor = csi_pipeline_components['processor']
sanitizer = csi_pipeline_components['sanitizer']
# Different data sizes
small_data = np.random.randn(1, 3, 56, 50) + 1j * np.random.randn(1, 3, 56, 50)
large_data = np.random.randn(20, 3, 56, 200) + 1j * np.random.randn(20, 3, 56, 200)
# Act
small_processed = processor.process_csi_batch(small_data)
small_sanitized = sanitizer.sanitize_phase_batch(small_processed)
large_processed = processor.process_csi_batch(large_data)
large_sanitized = sanitizer.sanitize_phase_batch(large_processed)
# Assert
assert small_processed.shape == small_sanitized.shape
assert large_processed.shape == large_sanitized.shape
assert small_processed.shape != large_processed.shape # Different sizes
def test_pipeline_configuration_validation(self, mock_router_config, mock_extractor_config,
mock_processor_config, mock_sanitizer_config):
"""Test that pipeline components validate configurations properly"""
# Arrange - Invalid configurations
invalid_router_config = mock_router_config.copy()
invalid_router_config['router_ip'] = 'invalid_ip'
invalid_extractor_config = mock_extractor_config.copy()
invalid_extractor_config['antenna_count'] = 0
invalid_processor_config = mock_processor_config.copy()
invalid_processor_config['window_size'] = -1
invalid_sanitizer_config = mock_sanitizer_config.copy()
invalid_sanitizer_config['smoothing_window'] = 0
# Act & Assert
with pytest.raises(ValueError):
RouterInterface(invalid_router_config)
with pytest.raises(ValueError):
CSIExtractor(invalid_extractor_config)
with pytest.raises(ValueError):
CSIProcessor(invalid_processor_config)
with pytest.raises(ValueError):
PhaseSanitizer(invalid_sanitizer_config)
def test_pipeline_error_recovery_and_logging(self, csi_pipeline_components, mock_raw_csi_data):
"""Test that pipeline handles errors gracefully and logs appropriately"""
# Arrange
processor = csi_pipeline_components['processor']
# Corrupt some data to trigger error handling
corrupted_data = mock_raw_csi_data['csi_data'].copy()
corrupted_data[0, 0, 0, :] = np.inf # Introduce infinite values
# Act & Assert
with pytest.raises(ValueError): # Should detect and handle corrupted data
processor.process_csi_batch(corrupted_data)
def test_pipeline_memory_usage_optimization(self, csi_pipeline_components):
"""Test that pipeline optimizes memory usage for large datasets"""
# Arrange
processor = csi_pipeline_components['processor']
sanitizer = csi_pipeline_components['sanitizer']
# Large dataset
large_data = np.random.randn(100, 3, 56, 1000) + 1j * np.random.randn(100, 3, 56, 1000)
# Act - Process in chunks to test memory optimization
chunk_size = 10
results = []
for i in range(0, large_data.shape[0], chunk_size):
chunk = large_data[i:i+chunk_size]
processed_chunk = processor.process_csi_batch(chunk)
sanitized_chunk = sanitizer.sanitize_phase_batch(processed_chunk)
results.append(sanitized_chunk)
# Assert
assert len(results) == 10 # 100 samples / 10 chunk_size
for result in results:
assert result.shape[0] <= chunk_size
def test_pipeline_supports_concurrent_processing(self, csi_pipeline_components, mock_raw_csi_data):
"""Test that pipeline supports concurrent processing of multiple streams"""
import threading
import queue
# Arrange
processor = csi_pipeline_components['processor']
sanitizer = csi_pipeline_components['sanitizer']
results_queue = queue.Queue()
def process_stream(stream_id, data):
try:
processed = processor.process_csi_batch(data)
sanitized = sanitizer.sanitize_phase_batch(processed)
results_queue.put((stream_id, sanitized))
except Exception as e:
results_queue.put((stream_id, e))
# Act - Process multiple streams concurrently
threads = []
for i in range(3):
thread = threading.Thread(
target=process_stream,
args=(i, mock_raw_csi_data['csi_data'])
)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Assert
results = []
while not results_queue.empty():
results.append(results_queue.get())
assert len(results) == 3
for stream_id, result in results:
assert isinstance(result, torch.Tensor)
assert not isinstance(result, Exception)

View File

@@ -0,0 +1,459 @@
import pytest
import torch
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from src.core.csi_processor import CSIProcessor
from src.core.phase_sanitizer import PhaseSanitizer
from src.models.modality_translation import ModalityTranslationNetwork
from src.models.densepose_head import DensePoseHead
class TestInferencePipeline:
"""Integration tests for inference pipeline following London School TDD principles"""
@pytest.fixture
def mock_csi_processor_config(self):
"""Configuration for CSI processor"""
return {
'window_size': 100,
'overlap': 0.5,
'filter_type': 'butterworth',
'filter_order': 4,
'cutoff_frequency': 50,
'normalization': 'minmax',
'outlier_threshold': 3.0
}
@pytest.fixture
def mock_sanitizer_config(self):
"""Configuration for phase sanitizer"""
return {
'unwrap_method': 'numpy',
'smoothing_window': 5,
'outlier_threshold': 2.0,
'interpolation_method': 'linear',
'phase_correction': True
}
@pytest.fixture
def mock_translation_config(self):
"""Configuration for modality translation network"""
return {
'input_channels': 6,
'output_channels': 256,
'hidden_channels': [64, 128, 256],
'kernel_sizes': [7, 5, 3],
'strides': [2, 2, 1],
'dropout_rate': 0.1,
'use_attention': True,
'attention_heads': 8,
'use_residual': True,
'activation': 'relu',
'normalization': 'batch'
}
@pytest.fixture
def mock_densepose_config(self):
"""Configuration for DensePose head"""
return {
'input_channels': 256,
'num_body_parts': 24,
'num_uv_coordinates': 2,
'hidden_channels': [128, 64],
'kernel_size': 3,
'padding': 1,
'dropout_rate': 0.1,
'use_deformable_conv': False,
'use_fpn': True,
'fpn_levels': [2, 3, 4, 5],
'output_stride': 4
}
@pytest.fixture
def inference_pipeline_components(self, mock_csi_processor_config, mock_sanitizer_config,
mock_translation_config, mock_densepose_config):
"""Create inference pipeline components for testing"""
csi_processor = CSIProcessor(mock_csi_processor_config)
phase_sanitizer = PhaseSanitizer(mock_sanitizer_config)
translation_network = ModalityTranslationNetwork(mock_translation_config)
densepose_head = DensePoseHead(mock_densepose_config)
return {
'csi_processor': csi_processor,
'phase_sanitizer': phase_sanitizer,
'translation_network': translation_network,
'densepose_head': densepose_head
}
@pytest.fixture
def mock_raw_csi_input(self):
"""Generate mock raw CSI input data"""
batch_size = 4
antennas = 3
subcarriers = 56
time_samples = 100
# Generate complex CSI data
real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
return real_part + 1j * imag_part
@pytest.fixture
def mock_ground_truth_densepose(self):
"""Generate mock ground truth DensePose annotations"""
batch_size = 4
height = 224
width = 224
num_parts = 24
# Segmentation masks
seg_masks = torch.randint(0, num_parts + 1, (batch_size, height, width))
# UV coordinates
uv_coords = torch.randn(batch_size, 2, height, width)
return {
'segmentation': seg_masks,
'uv_coordinates': uv_coords
}
def test_end_to_end_inference_pipeline_produces_valid_output(self, inference_pipeline_components, mock_raw_csi_input):
"""Test that end-to-end inference pipeline produces valid DensePose output"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Set models to evaluation mode
translation_network.eval()
densepose_head.eval()
# Act - Run the complete inference pipeline
with torch.no_grad():
# 1. Process CSI data
processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
# 2. Sanitize phase information
sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
# 3. Translate CSI to visual features
visual_features = translation_network(sanitized_csi)
# 4. Generate DensePose predictions
densepose_output = densepose_head(visual_features)
# Assert
assert densepose_output is not None
assert isinstance(densepose_output, dict)
assert 'segmentation' in densepose_output
assert 'uv_coordinates' in densepose_output
seg_output = densepose_output['segmentation']
uv_output = densepose_output['uv_coordinates']
# Check output shapes
assert seg_output.shape[0] == mock_raw_csi_input.shape[0] # Batch size preserved
assert seg_output.shape[1] == 25 # 24 body parts + 1 background
assert uv_output.shape[0] == mock_raw_csi_input.shape[0] # Batch size preserved
assert uv_output.shape[1] == 2 # U and V coordinates
# Check output ranges
assert torch.all(uv_output >= 0) and torch.all(uv_output <= 1) # UV in [0, 1]
def test_inference_pipeline_handles_different_batch_sizes(self, inference_pipeline_components):
"""Test that inference pipeline handles different batch sizes"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Different batch sizes
small_batch = np.random.randn(1, 3, 56, 100) + 1j * np.random.randn(1, 3, 56, 100)
large_batch = np.random.randn(8, 3, 56, 100) + 1j * np.random.randn(8, 3, 56, 100)
# Set models to evaluation mode
translation_network.eval()
densepose_head.eval()
# Act
with torch.no_grad():
# Small batch
small_processed = csi_processor.process_csi_batch(small_batch)
small_sanitized = phase_sanitizer.sanitize_phase_batch(small_processed)
small_features = translation_network(small_sanitized)
small_output = densepose_head(small_features)
# Large batch
large_processed = csi_processor.process_csi_batch(large_batch)
large_sanitized = phase_sanitizer.sanitize_phase_batch(large_processed)
large_features = translation_network(large_sanitized)
large_output = densepose_head(large_features)
# Assert
assert small_output['segmentation'].shape[0] == 1
assert large_output['segmentation'].shape[0] == 8
assert small_output['uv_coordinates'].shape[0] == 1
assert large_output['uv_coordinates'].shape[0] == 8
def test_inference_pipeline_maintains_gradient_flow_during_training(self, inference_pipeline_components,
mock_raw_csi_input, mock_ground_truth_densepose):
"""Test that inference pipeline maintains gradient flow during training"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Set models to training mode
translation_network.train()
densepose_head.train()
# Create optimizer
optimizer = torch.optim.Adam(
list(translation_network.parameters()) + list(densepose_head.parameters()),
lr=0.001
)
# Act
optimizer.zero_grad()
# Forward pass
processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
visual_features = translation_network(sanitized_csi)
densepose_output = densepose_head(visual_features)
# Resize ground truth to match output
seg_target = torch.nn.functional.interpolate(
mock_ground_truth_densepose['segmentation'].float().unsqueeze(1),
size=densepose_output['segmentation'].shape[2:],
mode='nearest'
).squeeze(1).long()
uv_target = torch.nn.functional.interpolate(
mock_ground_truth_densepose['uv_coordinates'],
size=densepose_output['uv_coordinates'].shape[2:],
mode='bilinear',
align_corners=False
)
# Compute loss
loss = densepose_head.compute_total_loss(densepose_output, seg_target, uv_target)
# Backward pass
loss.backward()
# Assert - Check that gradients are computed
for param in translation_network.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
for param in densepose_head.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_inference_pipeline_performance_benchmarking(self, inference_pipeline_components, mock_raw_csi_input):
"""Test inference pipeline performance for real-time requirements"""
import time
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Set models to evaluation mode for inference
translation_network.eval()
densepose_head.eval()
# Warm up (first inference is often slower)
with torch.no_grad():
processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
visual_features = translation_network(sanitized_csi)
_ = densepose_head(visual_features)
# Act - Measure inference time
start_time = time.time()
with torch.no_grad():
processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
visual_features = translation_network(sanitized_csi)
densepose_output = densepose_head(visual_features)
end_time = time.time()
inference_time = end_time - start_time
# Assert - Should meet real-time requirements (< 50ms for batch of 4)
assert inference_time < 0.05, f"Inference took {inference_time:.3f}s, expected < 0.05s"
def test_inference_pipeline_handles_edge_cases(self, inference_pipeline_components):
"""Test that inference pipeline handles edge cases gracefully"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Edge cases
zero_input = np.zeros((1, 3, 56, 100), dtype=complex)
noisy_input = np.random.randn(1, 3, 56, 100) * 100 + 1j * np.random.randn(1, 3, 56, 100) * 100
translation_network.eval()
densepose_head.eval()
# Act & Assert
with torch.no_grad():
# Zero input
zero_processed = csi_processor.process_csi_batch(zero_input)
zero_sanitized = phase_sanitizer.sanitize_phase_batch(zero_processed)
zero_features = translation_network(zero_sanitized)
zero_output = densepose_head(zero_features)
assert not torch.isnan(zero_output['segmentation']).any()
assert not torch.isnan(zero_output['uv_coordinates']).any()
# Noisy input
noisy_processed = csi_processor.process_csi_batch(noisy_input)
noisy_sanitized = phase_sanitizer.sanitize_phase_batch(noisy_processed)
noisy_features = translation_network(noisy_sanitized)
noisy_output = densepose_head(noisy_features)
assert not torch.isnan(noisy_output['segmentation']).any()
assert not torch.isnan(noisy_output['uv_coordinates']).any()
def test_inference_pipeline_memory_efficiency(self, inference_pipeline_components):
"""Test that inference pipeline is memory efficient"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Large batch to test memory usage
large_input = np.random.randn(16, 3, 56, 100) + 1j * np.random.randn(16, 3, 56, 100)
translation_network.eval()
densepose_head.eval()
# Act - Process in chunks to manage memory
chunk_size = 4
outputs = []
with torch.no_grad():
for i in range(0, large_input.shape[0], chunk_size):
chunk = large_input[i:i+chunk_size]
processed_chunk = csi_processor.process_csi_batch(chunk)
sanitized_chunk = phase_sanitizer.sanitize_phase_batch(processed_chunk)
feature_chunk = translation_network(sanitized_chunk)
output_chunk = densepose_head(feature_chunk)
outputs.append(output_chunk)
# Clear intermediate tensors to free memory
del processed_chunk, sanitized_chunk, feature_chunk
# Assert
assert len(outputs) == 4 # 16 samples / 4 chunk_size
for output in outputs:
assert output['segmentation'].shape[0] <= chunk_size
def test_inference_pipeline_deterministic_output(self, inference_pipeline_components, mock_raw_csi_input):
"""Test that inference pipeline produces deterministic output in eval mode"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
# Set models to evaluation mode
translation_network.eval()
densepose_head.eval()
# Act - Run inference twice
with torch.no_grad():
# First run
processed_csi_1 = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi_1 = phase_sanitizer.sanitize_phase_batch(processed_csi_1)
visual_features_1 = translation_network(sanitized_csi_1)
output_1 = densepose_head(visual_features_1)
# Second run
processed_csi_2 = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi_2 = phase_sanitizer.sanitize_phase_batch(processed_csi_2)
visual_features_2 = translation_network(sanitized_csi_2)
output_2 = densepose_head(visual_features_2)
# Assert - Outputs should be identical in eval mode
assert torch.allclose(output_1['segmentation'], output_2['segmentation'], atol=1e-6)
assert torch.allclose(output_1['uv_coordinates'], output_2['uv_coordinates'], atol=1e-6)
def test_inference_pipeline_confidence_estimation(self, inference_pipeline_components, mock_raw_csi_input):
"""Test that inference pipeline provides confidence estimates"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
translation_network.eval()
densepose_head.eval()
# Act
with torch.no_grad():
processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
visual_features = translation_network(sanitized_csi)
densepose_output = densepose_head(visual_features)
# Get confidence estimates
confidence = densepose_head.get_prediction_confidence(densepose_output)
# Assert
assert 'segmentation_confidence' in confidence
assert 'uv_confidence' in confidence
seg_conf = confidence['segmentation_confidence']
uv_conf = confidence['uv_confidence']
assert seg_conf.shape[0] == mock_raw_csi_input.shape[0]
assert uv_conf.shape[0] == mock_raw_csi_input.shape[0]
assert torch.all(seg_conf >= 0) and torch.all(seg_conf <= 1)
assert torch.all(uv_conf >= 0)
def test_inference_pipeline_post_processing(self, inference_pipeline_components, mock_raw_csi_input):
"""Test that inference pipeline post-processes predictions correctly"""
# Arrange
csi_processor = inference_pipeline_components['csi_processor']
phase_sanitizer = inference_pipeline_components['phase_sanitizer']
translation_network = inference_pipeline_components['translation_network']
densepose_head = inference_pipeline_components['densepose_head']
translation_network.eval()
densepose_head.eval()
# Act
with torch.no_grad():
processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
visual_features = translation_network(sanitized_csi)
raw_output = densepose_head(visual_features)
# Post-process predictions
processed_output = densepose_head.post_process_predictions(raw_output)
# Assert
assert 'body_parts' in processed_output
assert 'uv_coordinates' in processed_output
assert 'confidence_scores' in processed_output
body_parts = processed_output['body_parts']
assert body_parts.dtype == torch.long # Class indices
assert torch.all(body_parts >= 0) and torch.all(body_parts <= 24) # Valid class range

View File

@@ -0,0 +1,264 @@
import pytest
import numpy as np
import torch
from unittest.mock import Mock, patch, MagicMock
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
class TestCSIExtractor:
"""Test suite for CSI Extractor following London School TDD principles"""
@pytest.fixture
def mock_config(self):
"""Configuration for CSI extractor"""
return {
'interface': 'wlan0',
'channel': 6,
'bandwidth': 20,
'sample_rate': 1000,
'buffer_size': 1024,
'extraction_timeout': 5.0
}
@pytest.fixture
def mock_router_interface(self):
"""Mock router interface for testing"""
mock_router = Mock()
mock_router.is_connected = True
mock_router.execute_command = Mock()
return mock_router
@pytest.fixture
def csi_extractor(self, mock_config, mock_router_interface):
"""Create CSI extractor instance for testing"""
return CSIExtractor(mock_config, mock_router_interface)
@pytest.fixture
def mock_csi_data(self):
"""Generate synthetic CSI data for testing"""
# Simulate CSI data: complex values for multiple subcarriers
num_subcarriers = 56
num_antennas = 3
amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
return amplitude * np.exp(1j * phase)
def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
"""Test that CSI extractor initializes with correct configuration"""
# Act
extractor = CSIExtractor(mock_config, mock_router_interface)
# Assert
assert extractor is not None
assert extractor.interface == mock_config['interface']
assert extractor.channel == mock_config['channel']
assert extractor.bandwidth == mock_config['bandwidth']
assert extractor.sample_rate == mock_config['sample_rate']
assert extractor.buffer_size == mock_config['buffer_size']
assert extractor.extraction_timeout == mock_config['extraction_timeout']
assert extractor.router_interface == mock_router_interface
assert not extractor.is_extracting
def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
"""Test that start_extraction configures monitor mode"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.execute_command.return_value = "CSI extraction started"
# Act
result = csi_extractor.start_extraction()
# Assert
assert result is True
assert csi_extractor.is_extracting is True
mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)
def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
"""Test that start_extraction handles monitor mode configuration failure"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = False
# Act & Assert
with pytest.raises(CSIExtractionError):
csi_extractor.start_extraction()
assert csi_extractor.is_extracting is False
def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
"""Test that stop_extraction disables monitor mode"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.disable_monitor_mode.return_value = True
mock_router_interface.execute_command.return_value = "CSI extraction started"
csi_extractor.start_extraction()
# Act
result = csi_extractor.stop_extraction()
# Assert
assert result is True
assert csi_extractor.is_extracting is False
mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)
def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
"""Test that extract_csi_data returns data in valid format"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.execute_command.return_value = "CSI extraction started"
# Mock the CSI data extraction
with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
csi_extractor.start_extraction()
# Act
csi_data = csi_extractor.extract_csi_data()
# Assert
assert csi_data is not None
assert isinstance(csi_data, np.ndarray)
assert csi_data.dtype == np.complex128
assert csi_data.shape == mock_csi_data.shape
def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
"""Test that extract_csi_data requires active extraction"""
# Act & Assert
with pytest.raises(CSIExtractionError):
csi_extractor.extract_csi_data()
def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
"""Test that extract_csi_data handles extraction timeout"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.execute_command.side_effect = [
"CSI extraction started",
Exception("Timeout")
]
csi_extractor.start_extraction()
# Act & Assert
with pytest.raises(CSIExtractionError):
csi_extractor.extract_csi_data()
def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
"""Test that convert_to_tensor produces correctly formatted tensor"""
# Act
tensor = csi_extractor.convert_to_tensor(mock_csi_data)
# Assert
assert isinstance(tensor, torch.Tensor)
assert tensor.dtype == torch.float32
assert tensor.shape[0] == mock_csi_data.shape[0] * 2 # Real and imaginary parts
assert tensor.shape[1] == mock_csi_data.shape[1]
def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
"""Test that convert_to_tensor handles invalid input"""
# Arrange
invalid_data = "not an array"
# Act & Assert
with pytest.raises(ValueError):
csi_extractor.convert_to_tensor(invalid_data)
def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
"""Test that get_extraction_stats returns valid statistics"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.execute_command.return_value = "CSI extraction started"
csi_extractor.start_extraction()
# Act
stats = csi_extractor.get_extraction_stats()
# Assert
assert stats is not None
assert isinstance(stats, dict)
assert 'samples_extracted' in stats
assert 'extraction_rate' in stats
assert 'buffer_utilization' in stats
assert 'last_extraction_time' in stats
def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
"""Test that set_channel configures WiFi channel"""
# Arrange
new_channel = 11
mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"
# Act
result = csi_extractor.set_channel(new_channel)
# Assert
assert result is True
assert csi_extractor.channel == new_channel
mock_router_interface.execute_command.assert_called()
def test_set_channel_validates_channel_range(self, csi_extractor):
"""Test that set_channel validates channel range"""
# Act & Assert
with pytest.raises(ValueError):
csi_extractor.set_channel(0) # Invalid channel
with pytest.raises(ValueError):
csi_extractor.set_channel(15) # Invalid channel
def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
"""Test that CSI extractor supports context manager protocol"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.disable_monitor_mode.return_value = True
mock_router_interface.execute_command.return_value = "CSI extraction started"
# Act
with csi_extractor as extractor:
# Assert
assert extractor.is_extracting is True
# Assert - extraction should be stopped after context
assert csi_extractor.is_extracting is False
def test_extractor_validates_configuration(self, mock_router_interface):
"""Test that CSI extractor validates configuration parameters"""
# Arrange
invalid_config = {
'interface': '', # Invalid interface
'channel': 6,
'bandwidth': 20
}
# Act & Assert
with pytest.raises(ValueError):
CSIExtractor(invalid_config, mock_router_interface)
def test_parse_csi_output_processes_raw_data(self, csi_extractor):
"""Test that _parse_csi_output processes raw CSI data correctly"""
# Arrange
raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"
# Act
parsed_data = csi_extractor._parse_csi_output(raw_output)
# Assert
assert parsed_data is not None
assert isinstance(parsed_data, np.ndarray)
assert parsed_data.dtype == np.complex128
def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
"""Test that buffer management handles overflow correctly"""
# Arrange
mock_router_interface.enable_monitor_mode.return_value = True
mock_router_interface.execute_command.return_value = "CSI extraction started"
with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
csi_extractor.start_extraction()
# Fill buffer beyond capacity
for _ in range(csi_extractor.buffer_size + 10):
csi_extractor._add_to_buffer(mock_csi_data)
# Act
stats = csi_extractor.get_extraction_stats()
# Assert
assert stats['buffer_utilization'] <= 1.0 # Should not exceed 100%

View File

@@ -3,27 +3,27 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from src.models.densepose_head import DensePoseHead from src.models.densepose_head import DensePoseHead, DensePoseError
class TestDensePoseHead: class TestDensePoseHead:
"""Test suite for DensePose Head following London School TDD principles""" """Test suite for DensePose Head following London School TDD principles"""
@pytest.fixture
def mock_feature_input(self):
"""Generate synthetic feature input tensor for testing"""
# Batch size 2, 512 channels, 56 height, 100 width (from modality translation)
return torch.randn(2, 512, 56, 100)
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Configuration for DensePose head""" """Configuration for DensePose head"""
return { return {
'input_channels': 512, 'input_channels': 256,
'num_body_parts': 24, # Standard DensePose body parts 'num_body_parts': 24,
'num_uv_coordinates': 2, # U and V coordinates 'num_uv_coordinates': 2,
'hidden_dim': 256, 'hidden_channels': [128, 64],
'dropout_rate': 0.1 'kernel_size': 3,
'padding': 1,
'dropout_rate': 0.1,
'use_deformable_conv': False,
'use_fpn': True,
'fpn_levels': [2, 3, 4, 5],
'output_stride': 4
} }
@pytest.fixture @pytest.fixture
@@ -31,6 +31,33 @@ class TestDensePoseHead:
"""Create DensePose head instance for testing""" """Create DensePose head instance for testing"""
return DensePoseHead(mock_config) return DensePoseHead(mock_config)
@pytest.fixture
def mock_feature_input(self):
"""Generate mock feature input tensor"""
batch_size = 2
channels = 256
height = 56
width = 56
return torch.randn(batch_size, channels, height, width)
@pytest.fixture
def mock_target_masks(self):
"""Generate mock target segmentation masks"""
batch_size = 2
num_parts = 24
height = 224
width = 224
return torch.randint(0, num_parts + 1, (batch_size, height, width))
@pytest.fixture
def mock_target_uv(self):
"""Generate mock target UV coordinates"""
batch_size = 2
num_coords = 2
height = 224
width = 224
return torch.randn(batch_size, num_coords, height, width)
def test_head_initialization_creates_correct_architecture(self, mock_config): def test_head_initialization_creates_correct_architecture(self, mock_config):
"""Test that DensePose head initializes with correct architecture""" """Test that DensePose head initializes with correct architecture"""
# Act # Act
@@ -39,135 +66,302 @@ class TestDensePoseHead:
# Assert # Assert
assert head is not None assert head is not None
assert isinstance(head, nn.Module) assert isinstance(head, nn.Module)
assert hasattr(head, 'segmentation_head')
assert hasattr(head, 'uv_regression_head')
assert head.input_channels == mock_config['input_channels'] assert head.input_channels == mock_config['input_channels']
assert head.num_body_parts == mock_config['num_body_parts'] assert head.num_body_parts == mock_config['num_body_parts']
assert head.num_uv_coordinates == mock_config['num_uv_coordinates'] assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
assert head.use_fpn == mock_config['use_fpn']
assert hasattr(head, 'segmentation_head')
assert hasattr(head, 'uv_regression_head')
if mock_config['use_fpn']:
assert hasattr(head, 'fpn')
def test_forward_pass_produces_correct_output_shapes(self, densepose_head, mock_feature_input): def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
"""Test that forward pass produces correctly shaped outputs""" """Test that forward pass produces correctly formatted output"""
# Act # Act
with torch.no_grad(): output = densepose_head(mock_feature_input)
segmentation, uv_coords = densepose_head(mock_feature_input)
# Assert # Assert
assert segmentation is not None assert output is not None
assert uv_coords is not None assert isinstance(output, dict)
assert isinstance(segmentation, torch.Tensor) assert 'segmentation' in output
assert isinstance(uv_coords, torch.Tensor) assert 'uv_coordinates' in output
# Check segmentation output shape seg_output = output['segmentation']
assert segmentation.shape[0] == mock_feature_input.shape[0] # Batch size preserved uv_output = output['uv_coordinates']
assert segmentation.shape[1] == densepose_head.num_body_parts # Correct number of body parts
assert segmentation.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
# Check UV coordinates output shape assert isinstance(seg_output, torch.Tensor)
assert uv_coords.shape[0] == mock_feature_input.shape[0] # Batch size preserved assert isinstance(uv_output, torch.Tensor)
assert uv_coords.shape[1] == densepose_head.num_uv_coordinates # U and V coordinates assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
assert uv_coords.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
def test_segmentation_output_has_valid_probabilities(self, densepose_head, mock_feature_input): def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
"""Test that segmentation output has valid probability distributions""" """Test that segmentation head produces correct output shape"""
# Act # Act
with torch.no_grad(): output = densepose_head(mock_feature_input)
segmentation, _ = densepose_head(mock_feature_input) seg_output = output['segmentation']
# Assert # Assert
# After softmax, values should be between 0 and 1 expected_channels = densepose_head.num_body_parts + 1 # +1 for background
assert torch.all(segmentation >= 0.0) assert seg_output.shape[1] == expected_channels
assert torch.all(segmentation <= 1.0) assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
# Sum across body parts dimension should be approximately 1 def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
part_sums = torch.sum(segmentation, dim=1) """Test that UV regression head produces correct output shape"""
assert torch.allclose(part_sums, torch.ones_like(part_sums), atol=1e-5)
def test_uv_coordinates_output_in_valid_range(self, densepose_head, mock_feature_input):
"""Test that UV coordinates are in valid range [0, 1]"""
# Act # Act
with torch.no_grad(): output = densepose_head(mock_feature_input)
_, uv_coords = densepose_head(mock_feature_input) uv_output = output['uv_coordinates']
# Assert # Assert
# UV coordinates should be in range [0, 1] after sigmoid assert uv_output.shape[1] == densepose_head.num_uv_coordinates
assert torch.all(uv_coords >= 0.0) assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
assert torch.all(uv_coords <= 1.0) assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
def test_head_handles_different_batch_sizes(self, densepose_head): def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
"""Test that head handles different batch sizes correctly""" """Test that compute_segmentation_loss measures pixel classification accuracy"""
# Arrange # Arrange
batch_sizes = [1, 4, 8] output = densepose_head(mock_feature_input)
seg_logits = output['segmentation']
for batch_size in batch_sizes: # Resize target to match output
input_tensor = torch.randn(batch_size, 512, 56, 100) target_resized = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=seg_logits.shape[2:],
mode='nearest'
).squeeze(1).long()
# Act # Act
with torch.no_grad(): loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized)
segmentation, uv_coords = densepose_head(input_tensor)
# Assert # Assert
assert segmentation.shape[0] == batch_size assert loss is not None
assert uv_coords.shape[0] == batch_size assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_head_is_trainable(self, densepose_head, mock_feature_input): def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
"""Test that head parameters are trainable""" """Test that compute_uv_loss measures UV coordinate regression accuracy"""
# Arrange # Arrange
seg_criterion = nn.CrossEntropyLoss() output = densepose_head(mock_feature_input)
uv_criterion = nn.MSELoss() uv_pred = output['uv_coordinates']
# Create targets with correct shapes # Resize target to match output
seg_target = torch.randint(0, 24, (2, 56, 100)) # Class indices for segmentation target_resized = torch.nn.functional.interpolate(
uv_target = torch.rand(2, 2, 56, 100) # UV coordinates target mock_target_uv,
size=uv_pred.shape[2:],
mode='bilinear',
align_corners=False
)
# Act # Act
segmentation, uv_coords = densepose_head(mock_feature_input) loss = densepose_head.compute_uv_loss(uv_pred, target_resized)
seg_loss = seg_criterion(segmentation, seg_target)
uv_loss = uv_criterion(uv_coords, uv_target)
total_loss = seg_loss + uv_loss
total_loss.backward()
# Assert # Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
"""Test that compute_total_loss combines segmentation and UV losses"""
# Arrange
output = densepose_head(mock_feature_input)
# Resize targets to match outputs
seg_target = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=output['segmentation'].shape[2:],
mode='nearest'
).squeeze(1).long()
uv_target = torch.nn.functional.interpolate(
mock_target_uv,
size=output['uv_coordinates'].shape[2:],
mode='bilinear',
align_corners=False
)
# Act
total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target)
uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target)
# Assert
assert total_loss is not None
assert isinstance(total_loss, torch.Tensor)
assert total_loss.item() > 0 assert total_loss.item() > 0
# Check that gradients are computed # Total loss should be combination of individual losses
expected_total = seg_loss + uv_loss
assert torch.allclose(total_loss, expected_total, atol=1e-6)
def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
"""Test that FPN integration enhances multi-scale feature processing"""
# Arrange
config_with_fpn = mock_config.copy()
config_with_fpn['use_fpn'] = True
config_without_fpn = mock_config.copy()
config_without_fpn['use_fpn'] = False
head_with_fpn = DensePoseHead(config_with_fpn)
head_without_fpn = DensePoseHead(config_without_fpn)
# Act
output_with_fpn = head_with_fpn(mock_feature_input)
output_without_fpn = head_without_fpn(mock_feature_input)
# Assert
assert output_with_fpn['segmentation'].shape == output_without_fpn['segmentation'].shape
assert output_with_fpn['uv_coordinates'].shape == output_without_fpn['uv_coordinates'].shape
# Outputs should be different due to FPN
assert not torch.allclose(output_with_fpn['segmentation'], output_without_fpn['segmentation'], atol=1e-6)
def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
"""Test that get_prediction_confidence provides uncertainty estimates"""
# Arrange
output = densepose_head(mock_feature_input)
# Act
confidence = densepose_head.get_prediction_confidence(output)
# Assert
assert confidence is not None
assert isinstance(confidence, dict)
assert 'segmentation_confidence' in confidence
assert 'uv_confidence' in confidence
seg_conf = confidence['segmentation_confidence']
uv_conf = confidence['uv_confidence']
assert isinstance(seg_conf, torch.Tensor)
assert isinstance(uv_conf, torch.Tensor)
assert seg_conf.shape[0] == mock_feature_input.shape[0]
assert uv_conf.shape[0] == mock_feature_input.shape[0]
def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
"""Test that post_process_predictions formats output correctly"""
# Arrange
raw_output = densepose_head(mock_feature_input)
# Act
processed = densepose_head.post_process_predictions(raw_output)
# Assert
assert processed is not None
assert isinstance(processed, dict)
assert 'body_parts' in processed
assert 'uv_coordinates' in processed
assert 'confidence_scores' in processed
def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
"""Test that training mode enables dropout for regularization"""
# Arrange
densepose_head.train()
# Act
output1 = densepose_head(mock_feature_input)
output2 = densepose_head(mock_feature_input)
# Assert - outputs should be different due to dropout
assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
"""Test that evaluation mode disables dropout for consistent inference"""
# Arrange
densepose_head.eval()
# Act
output1 = densepose_head(mock_feature_input)
output2 = densepose_head(mock_feature_input)
# Assert - outputs should be identical in eval mode
assert torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
assert torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
def test_head_validates_input_dimensions(self, densepose_head):
"""Test that head validates input dimensions"""
# Arrange
invalid_input = torch.randn(2, 128, 56, 56) # Wrong number of channels
# Act & Assert
with pytest.raises(DensePoseError):
densepose_head(invalid_input)
def test_head_handles_different_input_sizes(self, densepose_head):
"""Test that head handles different input sizes"""
# Arrange
small_input = torch.randn(1, 256, 28, 28)
large_input = torch.randn(1, 256, 112, 112)
# Act
small_output = densepose_head(small_input)
large_output = densepose_head(large_input)
# Assert
assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:]
assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:]
def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
"""Test that head supports gradient computation for training"""
# Arrange
densepose_head.train()
optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)
output = densepose_head(mock_feature_input)
# Resize targets
seg_target = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=output['segmentation'].shape[2:],
mode='nearest'
).squeeze(1).long()
uv_target = torch.nn.functional.interpolate(
mock_target_uv,
size=output['uv_coordinates'].shape[2:],
mode='bilinear',
align_corners=False
)
# Act
loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
optimizer.zero_grad()
loss.backward()
# Assert
for param in densepose_head.parameters(): for param in densepose_head.parameters():
if param.requires_grad: if param.requires_grad:
assert param.grad is not None assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_head_handles_invalid_input_shape(self, densepose_head): def test_head_configuration_validation(self):
"""Test that head handles invalid input shapes gracefully""" """Test that head validates configuration parameters"""
# Arrange # Arrange
invalid_input = torch.randn(2, 256, 56, 100) # Wrong number of channels invalid_config = {
'input_channels': 0, # Invalid
'num_body_parts': -1, # Invalid
'num_uv_coordinates': 2
}
# Act & Assert # Act & Assert
with pytest.raises(RuntimeError): with pytest.raises(ValueError):
densepose_head(invalid_input) DensePoseHead(invalid_config)
def test_head_supports_evaluation_mode(self, densepose_head, mock_feature_input): def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
"""Test that head supports evaluation mode""" """Test that model state can be saved and loaded"""
# Act # Arrange
densepose_head.eval() original_output = densepose_head(mock_feature_input)
with torch.no_grad(): # Act - Save state
seg1, uv1 = densepose_head(mock_feature_input) state_dict = densepose_head.state_dict()
seg2, uv2 = densepose_head(mock_feature_input)
# Assert - In eval mode with same input, outputs should be identical # Create new head and load state
assert torch.allclose(seg1, seg2, atol=1e-6) new_head = DensePoseHead(densepose_head.config)
assert torch.allclose(uv1, uv2, atol=1e-6) new_head.load_state_dict(state_dict)
new_output = new_head(mock_feature_input)
def test_head_output_quality(self, densepose_head, mock_feature_input):
"""Test that head produces meaningful outputs"""
# Act
with torch.no_grad():
segmentation, uv_coords = densepose_head(mock_feature_input)
# Assert # Assert
# Outputs should not contain NaN or Inf values assert torch.allclose(original_output['segmentation'], new_output['segmentation'], atol=1e-6)
assert not torch.isnan(segmentation).any() assert torch.allclose(original_output['uv_coordinates'], new_output['uv_coordinates'], atol=1e-6)
assert not torch.isinf(segmentation).any()
assert not torch.isnan(uv_coords).any()
assert not torch.isinf(uv_coords).any()
# Outputs should have reasonable variance (not all zeros or ones)
assert segmentation.std() > 0.01
assert uv_coords.std() > 0.01

View File

@@ -3,27 +3,27 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from src.models.modality_translation import ModalityTranslationNetwork from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
class TestModalityTranslationNetwork: class TestModalityTranslationNetwork:
"""Test suite for Modality Translation Network following London School TDD principles""" """Test suite for Modality Translation Network following London School TDD principles"""
@pytest.fixture
def mock_csi_input(self):
"""Generate synthetic CSI input tensor for testing"""
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
return torch.randn(2, 3, 56, 100)
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Configuration for modality translation network""" """Configuration for modality translation network"""
return { return {
'input_channels': 3, 'input_channels': 6, # Real and imaginary parts for 3 antennas
'hidden_dim': 256, 'hidden_channels': [64, 128, 256],
'output_dim': 512, 'output_channels': 256,
'num_layers': 3, 'kernel_size': 3,
'dropout_rate': 0.1 'stride': 1,
'padding': 1,
'dropout_rate': 0.1,
'activation': 'relu',
'normalization': 'batch',
'use_attention': True,
'attention_heads': 8
} }
@pytest.fixture @pytest.fixture
@@ -31,98 +31,263 @@ class TestModalityTranslationNetwork:
"""Create modality translation network instance for testing""" """Create modality translation network instance for testing"""
return ModalityTranslationNetwork(mock_config) return ModalityTranslationNetwork(mock_config)
@pytest.fixture
def mock_csi_input(self):
"""Generate mock CSI input tensor"""
batch_size = 4
channels = 6 # Real and imaginary parts for 3 antennas
height = 56 # Number of subcarriers
width = 100 # Time samples
return torch.randn(batch_size, channels, height, width)
@pytest.fixture
def mock_target_features(self):
"""Generate mock target feature tensor for training"""
batch_size = 4
feature_dim = 256
spatial_height = 56
spatial_width = 100
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
def test_network_initialization_creates_correct_architecture(self, mock_config): def test_network_initialization_creates_correct_architecture(self, mock_config):
"""Test that network initializes with correct architecture""" """Test that modality translation network initializes with correct architecture"""
# Act # Act
network = ModalityTranslationNetwork(mock_config) network = ModalityTranslationNetwork(mock_config)
# Assert # Assert
assert network is not None assert network is not None
assert isinstance(network, nn.Module) assert isinstance(network, nn.Module)
assert network.input_channels == mock_config['input_channels']
assert network.output_channels == mock_config['output_channels']
assert network.use_attention == mock_config['use_attention']
assert hasattr(network, 'encoder') assert hasattr(network, 'encoder')
assert hasattr(network, 'decoder') assert hasattr(network, 'decoder')
assert network.input_channels == mock_config['input_channels'] if mock_config['use_attention']:
assert network.hidden_dim == mock_config['hidden_dim'] assert hasattr(network, 'attention')
assert network.output_dim == mock_config['output_dim']
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input): def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
"""Test that forward pass produces correctly shaped output""" """Test that forward pass produces correctly shaped output"""
# Act # Act
with torch.no_grad():
output = translation_network(mock_csi_input) output = translation_network(mock_csi_input)
# Assert # Assert
assert output is not None assert output is not None
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
assert output.shape[1] == translation_network.output_dim # Correct output dimension assert output.shape[1] == translation_network.output_channels # Correct output channels
assert len(output.shape) == 4 # Should maintain spatial dimensions assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
def test_forward_pass_handles_different_batch_sizes(self, translation_network): def test_forward_pass_handles_different_input_sizes(self, translation_network):
"""Test that network handles different batch sizes correctly""" """Test that forward pass handles different input sizes"""
# Arrange # Arrange
batch_sizes = [1, 4, 8] small_input = torch.randn(2, 6, 28, 50)
large_input = torch.randn(8, 6, 112, 200)
for batch_size in batch_sizes:
input_tensor = torch.randn(batch_size, 3, 56, 100)
# Act # Act
with torch.no_grad(): small_output = translation_network(small_input)
output = translation_network(input_tensor) large_output = translation_network(large_input)
# Assert # Assert
assert output.shape[0] == batch_size assert small_output.shape == (2, 256, 28, 50)
assert output.shape[1] == translation_network.output_dim assert large_output.shape == (8, 256, 112, 200)
def test_network_is_trainable(self, translation_network, mock_csi_input):
"""Test that network parameters are trainable"""
# Arrange
criterion = nn.MSELoss()
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
"""Test that encoder extracts hierarchical features"""
# Act # Act
output = translation_network(mock_csi_input) features = translation_network.encode(mock_csi_input)
# Create target with same shape as output
target = torch.randn_like(output)
loss = criterion(output, target)
loss.backward()
# Assert # Assert
assert loss.item() > 0 assert features is not None
# Check that gradients are computed assert isinstance(features, list)
for param in translation_network.parameters(): assert len(features) == len(translation_network.encoder)
if param.requires_grad:
assert param.grad is not None
def test_network_handles_invalid_input_shape(self, translation_network): # Check feature map sizes decrease with depth
"""Test that network handles invalid input shapes gracefully""" for i in range(1, len(features)):
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
"""Test that decoder reconstructs target feature representation"""
# Arrange # Arrange
invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels encoded_features = translation_network.encode(mock_csi_input)
# Act & Assert
with pytest.raises(RuntimeError):
translation_network(invalid_input)
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
"""Test that network supports evaluation mode"""
# Act # Act
translation_network.eval() decoded_output = translation_network.decode(encoded_features)
with torch.no_grad(): # Assert
assert decoded_output is not None
assert isinstance(decoded_output, torch.Tensor)
assert decoded_output.shape[1] == translation_network.output_channels
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
"""Test that attention mechanism enhances feature representation"""
# Arrange
config_with_attention = mock_config.copy()
config_with_attention['use_attention'] = True
config_without_attention = mock_config.copy()
config_without_attention['use_attention'] = False
network_with_attention = ModalityTranslationNetwork(config_with_attention)
network_without_attention = ModalityTranslationNetwork(config_without_attention)
# Act
output_with_attention = network_with_attention(mock_csi_input)
output_without_attention = network_without_attention(mock_csi_input)
# Assert
assert output_with_attention.shape == output_without_attention.shape
# Outputs should be different due to attention mechanism
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
"""Test that training mode enables dropout for regularization"""
# Arrange
translation_network.train()
# Act
output1 = translation_network(mock_csi_input) output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input) output2 = translation_network(mock_csi_input)
# Assert - In eval mode with same input, outputs should be identical # Assert - outputs should be different due to dropout
assert not torch.allclose(output1, output2, atol=1e-6)
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
"""Test that evaluation mode disables dropout for consistent inference"""
# Arrange
translation_network.eval()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be identical in eval mode
assert torch.allclose(output1, output2, atol=1e-6) assert torch.allclose(output1, output2, atol=1e-6)
def test_network_feature_extraction_quality(self, translation_network, mock_csi_input): def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
"""Test that network extracts meaningful features""" """Test that compute_translation_loss measures feature alignment"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act # Act
with torch.no_grad(): loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
output = translation_network(mock_csi_input)
# Assert # Assert
# Features should have reasonable statistics assert loss is not None
assert not torch.isnan(output).any() assert isinstance(loss, torch.Tensor)
assert not torch.isinf(output).any() assert loss.dim() == 0 # Scalar loss
assert output.std() > 0.01 # Features should have some variance assert loss.item() >= 0 # Loss should be non-negative
assert output.std() < 10.0 # But not be too extreme
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss handles different loss types"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
# Assert
assert mse_loss is not None
assert l1_loss is not None
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
"""Test that get_feature_statistics provides feature analysis"""
# Arrange
output = translation_network(mock_csi_input)
# Act
stats = translation_network.get_feature_statistics(output)
# Assert
assert stats is not None
assert isinstance(stats, dict)
assert 'mean' in stats
assert 'std' in stats
assert 'min' in stats
assert 'max' in stats
assert 'sparsity' in stats
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
"""Test that network supports gradient computation for training"""
# Arrange
translation_network.train()
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
# Act
output = translation_network(mock_csi_input)
loss = translation_network.compute_translation_loss(output, mock_target_features)
optimizer.zero_grad()
loss.backward()
# Assert
for param in translation_network.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_network_validates_input_dimensions(self, translation_network):
"""Test that network validates input dimensions"""
# Arrange
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
# Act & Assert
with pytest.raises(ModalityTranslationError):
translation_network(invalid_input)
def test_network_handles_batch_size_one(self, translation_network):
"""Test that network handles single sample inference"""
# Arrange
single_input = torch.randn(1, 6, 56, 100)
# Act
output = translation_network(single_input)
# Assert
assert output.shape == (1, 256, 56, 100)
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
"""Test that model state can be saved and loaded"""
# Arrange
original_output = translation_network(mock_csi_input)
# Act - Save state
state_dict = translation_network.state_dict()
# Create new network and load state
new_network = ModalityTranslationNetwork(translation_network.config)
new_network.load_state_dict(state_dict)
new_output = new_network(mock_csi_input)
# Assert
assert torch.allclose(original_output, new_output, atol=1e-6)
def test_network_configuration_validation(self):
"""Test that network validates configuration parameters"""
# Arrange
invalid_config = {
'input_channels': 0, # Invalid
'hidden_channels': [], # Invalid
'output_channels': 256
}
# Act & Assert
with pytest.raises(ValueError):
ModalityTranslationNetwork(invalid_config)
def test_feature_visualization_support(self, translation_network, mock_csi_input):
"""Test that network supports feature visualization"""
# Act
features = translation_network.get_intermediate_features(mock_csi_input)
# Assert
assert features is not None
assert isinstance(features, dict)
assert 'encoder_features' in features
assert 'decoder_features' in features
if translation_network.use_attention:
assert 'attention_weights' in features

View File

@@ -0,0 +1,244 @@
import pytest
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from src.hardware.router_interface import RouterInterface, RouterConnectionError
class TestRouterInterface:
"""Test suite for Router Interface following London School TDD principles"""
@pytest.fixture
def mock_config(self):
"""Configuration for router interface"""
return {
'router_ip': '192.168.1.1',
'username': 'admin',
'password': 'password',
'ssh_port': 22,
'timeout': 30,
'max_retries': 3
}
@pytest.fixture
def router_interface(self, mock_config):
"""Create router interface instance for testing"""
return RouterInterface(mock_config)
@pytest.fixture
def mock_ssh_client(self):
"""Mock SSH client for testing"""
mock_client = Mock()
mock_client.connect = Mock()
mock_client.exec_command = Mock()
mock_client.close = Mock()
return mock_client
def test_interface_initialization_creates_correct_configuration(self, mock_config):
"""Test that router interface initializes with correct configuration"""
# Act
interface = RouterInterface(mock_config)
# Assert
assert interface is not None
assert interface.router_ip == mock_config['router_ip']
assert interface.username == mock_config['username']
assert interface.password == mock_config['password']
assert interface.ssh_port == mock_config['ssh_port']
assert interface.timeout == mock_config['timeout']
assert interface.max_retries == mock_config['max_retries']
assert not interface.is_connected
@patch('paramiko.SSHClient')
def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that connect method establishes SSH connection"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
# Act
result = router_interface.connect()
# Assert
assert result is True
assert router_interface.is_connected is True
mock_ssh_client.set_missing_host_key_policy.assert_called_once()
mock_ssh_client.connect.assert_called_once_with(
hostname=router_interface.router_ip,
port=router_interface.ssh_port,
username=router_interface.username,
password=router_interface.password,
timeout=router_interface.timeout
)
@patch('paramiko.SSHClient')
def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that connect method handles connection failures gracefully"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_ssh_client.connect.side_effect = Exception("Connection failed")
# Act & Assert
with pytest.raises(RouterConnectionError):
router_interface.connect()
assert router_interface.is_connected is False
@patch('paramiko.SSHClient')
def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that disconnect method closes SSH connection"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
router_interface.connect()
# Act
router_interface.disconnect()
# Assert
assert router_interface.is_connected is False
mock_ssh_client.close.assert_called_once()
@patch('paramiko.SSHClient')
def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that execute_command runs SSH commands correctly"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_stdout = Mock()
mock_stdout.read.return_value = b"command output"
mock_stderr = Mock()
mock_stderr.read.return_value = b""
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
router_interface.connect()
# Act
result = router_interface.execute_command("test command")
# Assert
assert result == "command output"
mock_ssh_client.exec_command.assert_called_with("test command")
@patch('paramiko.SSHClient')
def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that execute_command handles command errors"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_stdout = Mock()
mock_stdout.read.return_value = b""
mock_stderr = Mock()
mock_stderr.read.return_value = b"command error"
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
router_interface.connect()
# Act & Assert
with pytest.raises(RouterConnectionError):
router_interface.execute_command("failing command")
def test_execute_command_requires_connection(self, router_interface):
"""Test that execute_command requires active connection"""
# Act & Assert
with pytest.raises(RouterConnectionError):
router_interface.execute_command("test command")
@patch('paramiko.SSHClient')
def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that get_router_info retrieves router system information"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_stdout = Mock()
mock_stdout.read.return_value = b"Router Model: AC1900\nFirmware: 1.2.3"
mock_stderr = Mock()
mock_stderr.read.return_value = b""
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
router_interface.connect()
# Act
info = router_interface.get_router_info()
# Assert
assert info is not None
assert isinstance(info, dict)
assert 'model' in info
assert 'firmware' in info
@patch('paramiko.SSHClient')
def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that enable_monitor_mode configures WiFi monitoring"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_stdout = Mock()
mock_stdout.read.return_value = b"Monitor mode enabled"
mock_stderr = Mock()
mock_stderr.read.return_value = b""
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
router_interface.connect()
# Act
result = router_interface.enable_monitor_mode("wlan0")
# Assert
assert result is True
mock_ssh_client.exec_command.assert_called()
@patch('paramiko.SSHClient')
def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that disable_monitor_mode disables WiFi monitoring"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_stdout = Mock()
mock_stdout.read.return_value = b"Monitor mode disabled"
mock_stderr = Mock()
mock_stderr.read.return_value = b""
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
router_interface.connect()
# Act
result = router_interface.disable_monitor_mode("wlan0")
# Assert
assert result is True
mock_ssh_client.exec_command.assert_called()
@patch('paramiko.SSHClient')
def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that router interface supports context manager protocol"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
# Act
with router_interface as interface:
# Assert
assert interface.is_connected is True
# Assert - connection should be closed after context
assert router_interface.is_connected is False
mock_ssh_client.close.assert_called_once()
def test_interface_validates_configuration(self):
"""Test that router interface validates configuration parameters"""
# Arrange
invalid_config = {
'router_ip': '', # Invalid IP
'username': 'admin',
'password': 'password'
}
# Act & Assert
with pytest.raises(ValueError):
RouterInterface(invalid_config)
@patch('paramiko.SSHClient')
def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
"""Test that interface implements retry logic for failed operations"""
# Arrange
mock_ssh_class.return_value = mock_ssh_client
mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None] # Fail once, then succeed
# Act
result = router_interface.connect()
# Assert
assert result is True
assert mock_ssh_client.connect.call_count == 2 # Should retry once