Implement WiFi-DensePose system with CSI data extraction and router interface

- Added CSIExtractor class for extracting CSI data from WiFi routers.
- Implemented RouterInterface class for SSH communication with routers.
- Developed DensePoseHead class for body part segmentation and UV coordinate regression.
- Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling.
- Integrated paramiko for SSH connections and command execution.
- Established configuration validation for both extractor and router interface.
- Added context manager support for resource management in both classes.
This commit is contained in:
rUv
2025-06-07 05:55:27 +00:00
parent 44e5382931
commit cbebdd648f
14 changed files with 2871 additions and 213 deletions

View File

@@ -0,0 +1,264 @@
import pytest
import numpy as np
import torch
from unittest.mock import Mock, patch, MagicMock
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
class TestCSIExtractor:
    """Test suite for CSI Extractor following London School TDD principles.

    The extractor under test is exercised purely through a mocked router
    interface, so every expectation here is about the extractor's interaction
    contract (which router calls it makes, what state it keeps), not about
    real WiFi hardware.
    """

    @pytest.fixture
    def mock_config(self):
        """Configuration for CSI extractor"""
        return {
            'interface': 'wlan0',
            'channel': 6,
            'bandwidth': 20,
            'sample_rate': 1000,
            'buffer_size': 1024,
            'extraction_timeout': 5.0
        }

    @pytest.fixture
    def mock_router_interface(self):
        """Mock router interface for testing"""
        mock_router = Mock()
        # Pretend the SSH session is already established.
        mock_router.is_connected = True
        mock_router.execute_command = Mock()
        return mock_router

    @pytest.fixture
    def csi_extractor(self, mock_config, mock_router_interface):
        """Create CSI extractor instance for testing"""
        return CSIExtractor(mock_config, mock_router_interface)

    @pytest.fixture
    def mock_csi_data(self):
        """Generate synthetic CSI data for testing"""
        # Simulate CSI data: complex values for multiple subcarriers
        num_subcarriers = 56
        num_antennas = 3
        amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
        phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
        # amplitude * e^(j*phase) yields a complex128 (antennas, subcarriers) array.
        return amplitude * np.exp(1j * phase)

    def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
        """Test that CSI extractor initializes with correct configuration"""
        # Act
        extractor = CSIExtractor(mock_config, mock_router_interface)
        # Assert
        assert extractor is not None
        # Every config key is expected to be mirrored as an attribute.
        assert extractor.interface == mock_config['interface']
        assert extractor.channel == mock_config['channel']
        assert extractor.bandwidth == mock_config['bandwidth']
        assert extractor.sample_rate == mock_config['sample_rate']
        assert extractor.buffer_size == mock_config['buffer_size']
        assert extractor.extraction_timeout == mock_config['extraction_timeout']
        assert extractor.router_interface == mock_router_interface
        # Extraction must not start implicitly on construction.
        assert not extractor.is_extracting

    def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that start_extraction configures monitor mode"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        # Act
        result = csi_extractor.start_extraction()
        # Assert
        assert result is True
        assert csi_extractor.is_extracting is True
        # Monitor mode must be enabled exactly once, on the configured interface.
        mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
        """Test that start_extraction handles monitor mode configuration failure"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = False
        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.start_extraction()
        # A failed start must leave the extractor inactive.
        assert csi_extractor.is_extracting is False

    def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that stop_extraction disables monitor mode"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()
        # Act
        result = csi_extractor.stop_extraction()
        # Assert
        assert result is True
        assert csi_extractor.is_extracting is False
        mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that extract_csi_data returns data in valid format"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        # Mock the CSI data extraction
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            csi_extractor.start_extraction()
            # Act
            csi_data = csi_extractor.extract_csi_data()
            # Assert
            assert csi_data is not None
            assert isinstance(csi_data, np.ndarray)
            assert csi_data.dtype == np.complex128
            assert csi_data.shape == mock_csi_data.shape

    def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
        """Test that extract_csi_data requires active extraction"""
        # Act & Assert: extraction was never started, so this must raise.
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
        """Test that extract_csi_data handles extraction timeout"""
        # Arrange: first command call succeeds (start), second raises (extract).
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.side_effect = [
            "CSI extraction started",
            Exception("Timeout")
        ]
        csi_extractor.start_extraction()
        # Act & Assert: the raw exception must be wrapped in CSIExtractionError.
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
        """Test that convert_to_tensor produces correctly formatted tensor"""
        # Act
        tensor = csi_extractor.convert_to_tensor(mock_csi_data)
        # Assert
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        # Complex input is split into separate real/imaginary channels.
        assert tensor.shape[0] == mock_csi_data.shape[0] * 2  # Real and imaginary parts
        assert tensor.shape[1] == mock_csi_data.shape[1]

    def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
        """Test that convert_to_tensor handles invalid input"""
        # Arrange
        invalid_data = "not an array"
        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.convert_to_tensor(invalid_data)

    def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
        """Test that get_extraction_stats returns valid statistics"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()
        # Act
        stats = csi_extractor.get_extraction_stats()
        # Assert: stats dict exposes the four documented metrics.
        assert stats is not None
        assert isinstance(stats, dict)
        assert 'samples_extracted' in stats
        assert 'extraction_rate' in stats
        assert 'buffer_utilization' in stats
        assert 'last_extraction_time' in stats

    def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
        """Test that set_channel configures WiFi channel"""
        # Arrange
        new_channel = 11
        mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"
        # Act
        result = csi_extractor.set_channel(new_channel)
        # Assert
        assert result is True
        assert csi_extractor.channel == new_channel
        # The channel change must be pushed to the router.
        mock_router_interface.execute_command.assert_called()

    def test_set_channel_validates_channel_range(self, csi_extractor):
        """Test that set_channel validates channel range"""
        # NOTE(review): assumes the valid range is 2.4 GHz channels 1-14 —
        # confirm against the implementation before relying on these bounds.
        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.set_channel(0)  # Invalid channel
        with pytest.raises(ValueError):
            csi_extractor.set_channel(15)  # Invalid channel

    def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
        """Test that CSI extractor supports context manager protocol"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        # Act: __enter__ is expected to start extraction, __exit__ to stop it.
        with csi_extractor as extractor:
            # Assert
            assert extractor.is_extracting is True
        # Assert - extraction should be stopped after context
        assert csi_extractor.is_extracting is False

    def test_extractor_validates_configuration(self, mock_router_interface):
        """Test that CSI extractor validates configuration parameters"""
        # Arrange
        invalid_config = {
            'interface': '',  # Invalid interface
            'channel': 6,
            'bandwidth': 20
        }
        # Act & Assert
        with pytest.raises(ValueError):
            CSIExtractor(invalid_config, mock_router_interface)

    def test_parse_csi_output_processes_raw_data(self, csi_extractor):
        """Test that _parse_csi_output processes raw CSI data correctly"""
        # Arrange: router emits CSI as a comma-separated complex-literal line.
        raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"
        # Act
        parsed_data = csi_extractor._parse_csi_output(raw_output)
        # Assert
        assert parsed_data is not None
        assert isinstance(parsed_data, np.ndarray)
        assert parsed_data.dtype == np.complex128

    def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that buffer management handles overflow correctly"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            csi_extractor.start_extraction()
            # Fill buffer beyond capacity
            for _ in range(csi_extractor.buffer_size + 10):
                csi_extractor._add_to_buffer(mock_csi_data)
            # Act
            stats = csi_extractor.get_extraction_stats()
            # Assert
            assert stats['buffer_utilization'] <= 1.0  # Should not exceed 100%

View File

@@ -3,27 +3,27 @@ import torch
import torch.nn as nn
import numpy as np
from unittest.mock import Mock, patch
from src.models.densepose_head import DensePoseHead
from src.models.densepose_head import DensePoseHead, DensePoseError
class TestDensePoseHead:
"""Test suite for DensePose Head following London School TDD principles"""
@pytest.fixture
def mock_feature_input(self):
"""Generate synthetic feature input tensor for testing"""
# Batch size 2, 512 channels, 56 height, 100 width (from modality translation)
return torch.randn(2, 512, 56, 100)
@pytest.fixture
def mock_config(self):
"""Configuration for DensePose head"""
return {
'input_channels': 512,
'num_body_parts': 24, # Standard DensePose body parts
'num_uv_coordinates': 2, # U and V coordinates
'hidden_dim': 256,
'dropout_rate': 0.1
'input_channels': 256,
'num_body_parts': 24,
'num_uv_coordinates': 2,
'hidden_channels': [128, 64],
'kernel_size': 3,
'padding': 1,
'dropout_rate': 0.1,
'use_deformable_conv': False,
'use_fpn': True,
'fpn_levels': [2, 3, 4, 5],
'output_stride': 4
}
@pytest.fixture
@@ -31,6 +31,33 @@ class TestDensePoseHead:
"""Create DensePose head instance for testing"""
return DensePoseHead(mock_config)
@pytest.fixture
def mock_feature_input(self):
"""Generate mock feature input tensor"""
batch_size = 2
channels = 256
height = 56
width = 56
return torch.randn(batch_size, channels, height, width)
@pytest.fixture
def mock_target_masks(self):
"""Generate mock target segmentation masks"""
batch_size = 2
num_parts = 24
height = 224
width = 224
return torch.randint(0, num_parts + 1, (batch_size, height, width))
@pytest.fixture
def mock_target_uv(self):
"""Generate mock target UV coordinates"""
batch_size = 2
num_coords = 2
height = 224
width = 224
return torch.randn(batch_size, num_coords, height, width)
def test_head_initialization_creates_correct_architecture(self, mock_config):
"""Test that DensePose head initializes with correct architecture"""
# Act
@@ -39,135 +66,302 @@ class TestDensePoseHead:
# Assert
assert head is not None
assert isinstance(head, nn.Module)
assert hasattr(head, 'segmentation_head')
assert hasattr(head, 'uv_regression_head')
assert head.input_channels == mock_config['input_channels']
assert head.num_body_parts == mock_config['num_body_parts']
assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
assert head.use_fpn == mock_config['use_fpn']
assert hasattr(head, 'segmentation_head')
assert hasattr(head, 'uv_regression_head')
if mock_config['use_fpn']:
assert hasattr(head, 'fpn')
def test_forward_pass_produces_correct_output_shapes(self, densepose_head, mock_feature_input):
"""Test that forward pass produces correctly shaped outputs"""
def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
"""Test that forward pass produces correctly formatted output"""
# Act
with torch.no_grad():
segmentation, uv_coords = densepose_head(mock_feature_input)
output = densepose_head(mock_feature_input)
# Assert
assert segmentation is not None
assert uv_coords is not None
assert isinstance(segmentation, torch.Tensor)
assert isinstance(uv_coords, torch.Tensor)
assert output is not None
assert isinstance(output, dict)
assert 'segmentation' in output
assert 'uv_coordinates' in output
# Check segmentation output shape
assert segmentation.shape[0] == mock_feature_input.shape[0] # Batch size preserved
assert segmentation.shape[1] == densepose_head.num_body_parts # Correct number of body parts
assert segmentation.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
seg_output = output['segmentation']
uv_output = output['uv_coordinates']
# Check UV coordinates output shape
assert uv_coords.shape[0] == mock_feature_input.shape[0] # Batch size preserved
assert uv_coords.shape[1] == densepose_head.num_uv_coordinates # U and V coordinates
assert uv_coords.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
assert isinstance(seg_output, torch.Tensor)
assert isinstance(uv_output, torch.Tensor)
assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
def test_segmentation_output_has_valid_probabilities(self, densepose_head, mock_feature_input):
"""Test that segmentation output has valid probability distributions"""
def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
"""Test that segmentation head produces correct output shape"""
# Act
with torch.no_grad():
segmentation, _ = densepose_head(mock_feature_input)
output = densepose_head(mock_feature_input)
seg_output = output['segmentation']
# Assert
# After softmax, values should be between 0 and 1
assert torch.all(segmentation >= 0.0)
assert torch.all(segmentation <= 1.0)
# Sum across body parts dimension should be approximately 1
part_sums = torch.sum(segmentation, dim=1)
assert torch.allclose(part_sums, torch.ones_like(part_sums), atol=1e-5)
expected_channels = densepose_head.num_body_parts + 1 # +1 for background
assert seg_output.shape[1] == expected_channels
assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
def test_uv_coordinates_output_in_valid_range(self, densepose_head, mock_feature_input):
"""Test that UV coordinates are in valid range [0, 1]"""
def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
"""Test that UV regression head produces correct output shape"""
# Act
with torch.no_grad():
_, uv_coords = densepose_head(mock_feature_input)
output = densepose_head(mock_feature_input)
uv_output = output['uv_coordinates']
# Assert
# UV coordinates should be in range [0, 1] after sigmoid
assert torch.all(uv_coords >= 0.0)
assert torch.all(uv_coords <= 1.0)
assert uv_output.shape[1] == densepose_head.num_uv_coordinates
assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
def test_head_handles_different_batch_sizes(self, densepose_head):
"""Test that head handles different batch sizes correctly"""
def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
"""Test that compute_segmentation_loss measures pixel classification accuracy"""
# Arrange
batch_sizes = [1, 4, 8]
output = densepose_head(mock_feature_input)
seg_logits = output['segmentation']
for batch_size in batch_sizes:
input_tensor = torch.randn(batch_size, 512, 56, 100)
# Act
with torch.no_grad():
segmentation, uv_coords = densepose_head(input_tensor)
# Assert
assert segmentation.shape[0] == batch_size
assert uv_coords.shape[0] == batch_size
def test_head_is_trainable(self, densepose_head, mock_feature_input):
"""Test that head parameters are trainable"""
# Arrange
seg_criterion = nn.CrossEntropyLoss()
uv_criterion = nn.MSELoss()
# Create targets with correct shapes
seg_target = torch.randint(0, 24, (2, 56, 100)) # Class indices for segmentation
uv_target = torch.rand(2, 2, 56, 100) # UV coordinates target
# Resize target to match output
target_resized = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=seg_logits.shape[2:],
mode='nearest'
).squeeze(1).long()
# Act
segmentation, uv_coords = densepose_head(mock_feature_input)
seg_loss = seg_criterion(segmentation, seg_target)
uv_loss = uv_criterion(uv_coords, uv_target)
total_loss = seg_loss + uv_loss
total_loss.backward()
loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized)
# Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
    """Test that compute_uv_loss measures UV coordinate regression accuracy"""
    # Arrange: run a forward pass and pull out the UV predictions.
    predictions = densepose_head(mock_feature_input)['uv_coordinates']
    # Bring the target down/up to the prediction's spatial resolution.
    resized_target = torch.nn.functional.interpolate(
        mock_target_uv,
        size=predictions.shape[2:],
        mode='bilinear',
        align_corners=False,
    )
    # Act
    loss = densepose_head.compute_uv_loss(predictions, resized_target)
    # Assert: loss is a scalar, non-negative tensor.
    assert loss is not None
    assert isinstance(loss, torch.Tensor)
    assert loss.dim() == 0  # Scalar loss
    assert loss.item() >= 0  # Loss should be non-negative
def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
    """Test that compute_total_loss combines segmentation and UV losses"""
    # Arrange
    output = densepose_head(mock_feature_input)
    # Resize targets to match outputs
    # Nearest-neighbour keeps the integer class labels intact after resizing.
    seg_target = torch.nn.functional.interpolate(
        mock_target_masks.float().unsqueeze(1),
        size=output['segmentation'].shape[2:],
        mode='nearest'
    ).squeeze(1).long()
    uv_target = torch.nn.functional.interpolate(
        mock_target_uv,
        size=output['uv_coordinates'].shape[2:],
        mode='bilinear',
        align_corners=False
    )
    # Act
    total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
    seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target)
    uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target)
    # Assert
    assert total_loss is not None
    assert isinstance(total_loss, torch.Tensor)
    assert total_loss.item() > 0
    # Total loss should be combination of individual losses
    # NOTE(review): assumes an unweighted sum — confirm compute_total_loss
    # does not apply loss weights, otherwise allclose will fail.
    expected_total = seg_loss + uv_loss
    assert torch.allclose(total_loss, expected_total, atol=1e-6)
def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
    """Test that FPN integration enhances multi-scale feature processing"""
    # Arrange: identical configs except for the FPN switch.
    fpn_config = mock_config.copy()
    fpn_config['use_fpn'] = True
    plain_config = mock_config.copy()
    plain_config['use_fpn'] = False
    fpn_head = DensePoseHead(fpn_config)
    plain_head = DensePoseHead(plain_config)
    # Act
    fpn_output = fpn_head(mock_feature_input)
    plain_output = plain_head(mock_feature_input)
    # Assert: shapes agree regardless of FPN, values differ because of it.
    assert fpn_output['segmentation'].shape == plain_output['segmentation'].shape
    assert fpn_output['uv_coordinates'].shape == plain_output['uv_coordinates'].shape
    assert not torch.allclose(fpn_output['segmentation'], plain_output['segmentation'], atol=1e-6)
def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
    """Test that get_prediction_confidence provides uncertainty estimates"""
    # Arrange
    output = densepose_head(mock_feature_input)
    # Act
    confidence = densepose_head.get_prediction_confidence(output)
    # Assert: dict with one confidence tensor per prediction branch.
    assert confidence is not None
    assert isinstance(confidence, dict)
    assert 'segmentation_confidence' in confidence
    assert 'uv_confidence' in confidence
    seg_conf = confidence['segmentation_confidence']
    uv_conf = confidence['uv_confidence']
    assert isinstance(seg_conf, torch.Tensor)
    assert isinstance(uv_conf, torch.Tensor)
    # Per-sample confidence: leading dim matches the input batch size.
    assert seg_conf.shape[0] == mock_feature_input.shape[0]
    assert uv_conf.shape[0] == mock_feature_input.shape[0]
def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
    """Test that post_process_predictions formats output correctly"""
    # Arrange
    raw_output = densepose_head(mock_feature_input)
    # Act
    processed = densepose_head.post_process_predictions(raw_output)
    # Assert: a dict carrying all user-facing prediction fields.
    assert processed is not None
    assert isinstance(processed, dict)
    for expected_key in ('body_parts', 'uv_coordinates', 'confidence_scores'):
        assert expected_key in processed
def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
    """Test that training mode enables dropout for regularization"""
    # Arrange
    densepose_head.train()
    # Act: two forward passes on the same input.
    output1 = densepose_head(mock_feature_input)
    output2 = densepose_head(mock_feature_input)
    # Assert - outputs should be different due to dropout
    assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
    assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
    """Test that evaluation mode disables dropout for consistent inference"""
    # Arrange: switch to inference mode.
    densepose_head.eval()
    # Act: run the identical input through the head twice.
    first = densepose_head(mock_feature_input)
    second = densepose_head(mock_feature_input)
    # Assert: eval-mode inference is deterministic.
    assert torch.allclose(first['segmentation'], second['segmentation'], atol=1e-6)
    assert torch.allclose(first['uv_coordinates'], second['uv_coordinates'], atol=1e-6)
def test_head_validates_input_dimensions(self, densepose_head):
    """Test that head validates input dimensions"""
    # Arrange: channel count deliberately disagrees with the configured input_channels.
    bad_input = torch.randn(2, 128, 56, 56)  # Wrong number of channels
    # Act & Assert
    with pytest.raises(DensePoseError):
        densepose_head(bad_input)
def test_head_handles_different_input_sizes(self, densepose_head):
    """Test that head handles different input sizes"""
    # Arrange: same channel count, different spatial resolutions.
    small_input = torch.randn(1, 256, 28, 28)
    large_input = torch.randn(1, 256, 112, 112)
    # Act
    small_output = densepose_head(small_input)
    large_output = densepose_head(large_input)
    # Assert: spatial output size tracks spatial input size.
    assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:]
    assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:]
def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
    """Test that head supports gradient computation for training"""
    # Arrange
    densepose_head.train()
    optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)
    output = densepose_head(mock_feature_input)
    # Resize targets
    # Nearest-neighbour keeps the integer class labels intact.
    seg_target = torch.nn.functional.interpolate(
        mock_target_masks.float().unsqueeze(1),
        size=output['segmentation'].shape[2:],
        mode='nearest'
    ).squeeze(1).long()
    uv_target = torch.nn.functional.interpolate(
        mock_target_uv,
        size=output['uv_coordinates'].shape[2:],
        mode='bilinear',
        align_corners=False
    )
    # Act
    loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
    optimizer.zero_grad()
    loss.backward()
    # Assert
    # NOTE(review): strict — a parameter that legitimately receives a zero
    # gradient (e.g. an unused branch) would fail the allclose check below.
    for param in densepose_head.parameters():
        if param.requires_grad:
            assert param.grad is not None
            assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_head_handles_invalid_input_shape(self, densepose_head):
"""Test that head handles invalid input shapes gracefully"""
def test_head_configuration_validation(self):
"""Test that head validates configuration parameters"""
# Arrange
invalid_input = torch.randn(2, 256, 56, 100) # Wrong number of channels
invalid_config = {
'input_channels': 0, # Invalid
'num_body_parts': -1, # Invalid
'num_uv_coordinates': 2
}
# Act & Assert
with pytest.raises(RuntimeError):
densepose_head(invalid_input)
with pytest.raises(ValueError):
DensePoseHead(invalid_config)
def test_head_supports_evaluation_mode(self, densepose_head, mock_feature_input):
"""Test that head supports evaluation mode"""
# Act
densepose_head.eval()
def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
"""Test that model state can be saved and loaded"""
# Arrange
original_output = densepose_head(mock_feature_input)
with torch.no_grad():
seg1, uv1 = densepose_head(mock_feature_input)
seg2, uv2 = densepose_head(mock_feature_input)
# Act - Save state
state_dict = densepose_head.state_dict()
# Assert - In eval mode with same input, outputs should be identical
assert torch.allclose(seg1, seg2, atol=1e-6)
assert torch.allclose(uv1, uv2, atol=1e-6)
def test_head_output_quality(self, densepose_head, mock_feature_input):
"""Test that head produces meaningful outputs"""
# Act
with torch.no_grad():
segmentation, uv_coords = densepose_head(mock_feature_input)
# Create new head and load state
new_head = DensePoseHead(densepose_head.config)
new_head.load_state_dict(state_dict)
new_output = new_head(mock_feature_input)
# Assert
# Outputs should not contain NaN or Inf values
assert not torch.isnan(segmentation).any()
assert not torch.isinf(segmentation).any()
assert not torch.isnan(uv_coords).any()
assert not torch.isinf(uv_coords).any()
# Outputs should have reasonable variance (not all zeros or ones)
assert segmentation.std() > 0.01
assert uv_coords.std() > 0.01
assert torch.allclose(original_output['segmentation'], new_output['segmentation'], atol=1e-6)
assert torch.allclose(original_output['uv_coordinates'], new_output['uv_coordinates'], atol=1e-6)

View File

@@ -3,27 +3,27 @@ import torch
import torch.nn as nn
import numpy as np
from unittest.mock import Mock, patch
from src.models.modality_translation import ModalityTranslationNetwork
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
class TestModalityTranslationNetwork:
"""Test suite for Modality Translation Network following London School TDD principles"""
@pytest.fixture
def mock_csi_input(self):
"""Generate synthetic CSI input tensor for testing"""
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
return torch.randn(2, 3, 56, 100)
@pytest.fixture
def mock_config(self):
"""Configuration for modality translation network"""
return {
'input_channels': 3,
'hidden_dim': 256,
'output_dim': 512,
'num_layers': 3,
'dropout_rate': 0.1
'input_channels': 6, # Real and imaginary parts for 3 antennas
'hidden_channels': [64, 128, 256],
'output_channels': 256,
'kernel_size': 3,
'stride': 1,
'padding': 1,
'dropout_rate': 0.1,
'activation': 'relu',
'normalization': 'batch',
'use_attention': True,
'attention_heads': 8
}
@pytest.fixture
@@ -31,98 +31,263 @@ class TestModalityTranslationNetwork:
"""Create modality translation network instance for testing"""
return ModalityTranslationNetwork(mock_config)
@pytest.fixture
def mock_csi_input(self):
    """Generate mock CSI input tensor"""
    # (batch, channels, subcarriers, time samples);
    # 6 channels = real/imaginary pairs for 3 antennas.
    shape = (4, 6, 56, 100)
    return torch.randn(*shape)
@pytest.fixture
def mock_target_features(self):
    """Generate mock target feature tensor for training"""
    # Matches the translation network's output: 256 feature channels over the
    # 56x100 CSI spatial grid, for a batch of 4.
    return torch.randn(4, 256, 56, 100)
def test_network_initialization_creates_correct_architecture(self, mock_config):
"""Test that network initializes with correct architecture"""
"""Test that modality translation network initializes with correct architecture"""
# Act
network = ModalityTranslationNetwork(mock_config)
# Assert
assert network is not None
assert isinstance(network, nn.Module)
assert network.input_channels == mock_config['input_channels']
assert network.output_channels == mock_config['output_channels']
assert network.use_attention == mock_config['use_attention']
assert hasattr(network, 'encoder')
assert hasattr(network, 'decoder')
assert network.input_channels == mock_config['input_channels']
assert network.hidden_dim == mock_config['hidden_dim']
assert network.output_dim == mock_config['output_dim']
if mock_config['use_attention']:
assert hasattr(network, 'attention')
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
"""Test that forward pass produces correctly shaped output"""
# Act
with torch.no_grad():
output = translation_network(mock_csi_input)
output = translation_network(mock_csi_input)
# Assert
assert output is not None
assert isinstance(output, torch.Tensor)
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
assert output.shape[1] == translation_network.output_dim # Correct output dimension
assert len(output.shape) == 4 # Should maintain spatial dimensions
assert output.shape[1] == translation_network.output_channels # Correct output channels
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
def test_forward_pass_handles_different_batch_sizes(self, translation_network):
"""Test that network handles different batch sizes correctly"""
def test_forward_pass_handles_different_input_sizes(self, translation_network):
"""Test that forward pass handles different input sizes"""
# Arrange
batch_sizes = [1, 4, 8]
small_input = torch.randn(2, 6, 28, 50)
large_input = torch.randn(8, 6, 112, 200)
for batch_size in batch_sizes:
input_tensor = torch.randn(batch_size, 3, 56, 100)
# Act
with torch.no_grad():
output = translation_network(input_tensor)
# Assert
assert output.shape[0] == batch_size
assert output.shape[1] == translation_network.output_dim
# Act
small_output = translation_network(small_input)
large_output = translation_network(large_input)
# Assert
assert small_output.shape == (2, 256, 28, 50)
assert large_output.shape == (8, 256, 112, 200)
def test_network_is_trainable(self, translation_network, mock_csi_input):
"""Test that network parameters are trainable"""
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
    """Test that encoder extracts hierarchical features"""
    # Act
    feature_maps = translation_network.encode(mock_csi_input)
    # Assert: one feature map per encoder stage.
    assert feature_maps is not None
    assert isinstance(feature_maps, list)
    assert len(feature_maps) == len(translation_network.encoder)
    # Spatial resolution must be monotonically non-increasing with depth.
    for shallow, deep in zip(feature_maps, feature_maps[1:]):
        assert deep.shape[2] <= shallow.shape[2]  # Height decreases or stays same
        assert deep.shape[3] <= shallow.shape[3]  # Width decreases or stays same
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
    """Test that decoder reconstructs target feature representation"""
    # Arrange
    # NOTE(review): `criterion` is never used in this test — looks like a
    # leftover from an earlier revision; confirm and remove.
    criterion = nn.MSELoss()
    encoded_features = translation_network.encode(mock_csi_input)
    # Act
    decoded_output = translation_network.decode(encoded_features)
    # Assert
    assert decoded_output is not None
    assert isinstance(decoded_output, torch.Tensor)
    assert decoded_output.shape[1] == translation_network.output_channels
    # Decoder restores the encoder input's spatial resolution.
    assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
"""Test that attention mechanism enhances feature representation"""
# Arrange
config_with_attention = mock_config.copy()
config_with_attention['use_attention'] = True
config_without_attention = mock_config.copy()
config_without_attention['use_attention'] = False
network_with_attention = ModalityTranslationNetwork(config_with_attention)
network_without_attention = ModalityTranslationNetwork(config_without_attention)
# Act
output_with_attention = network_with_attention(mock_csi_input)
output_without_attention = network_without_attention(mock_csi_input)
# Assert
assert output_with_attention.shape == output_without_attention.shape
# Outputs should be different due to attention mechanism
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
"""Test that training mode enables dropout for regularization"""
# Arrange
translation_network.train()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be different due to dropout
assert not torch.allclose(output1, output2, atol=1e-6)
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
"""Test that evaluation mode disables dropout for consistent inference"""
# Arrange
translation_network.eval()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be identical in eval mode
assert torch.allclose(output1, output2, atol=1e-6)
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss measures feature alignment"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
# Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss handles different loss types"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
# Assert
assert mse_loss is not None
assert l1_loss is not None
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
"""Test that get_feature_statistics provides feature analysis"""
# Arrange
output = translation_network(mock_csi_input)
# Act
stats = translation_network.get_feature_statistics(output)
# Assert
assert stats is not None
assert isinstance(stats, dict)
assert 'mean' in stats
assert 'std' in stats
assert 'min' in stats
assert 'max' in stats
assert 'sparsity' in stats
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
"""Test that network supports gradient computation for training"""
# Arrange
translation_network.train()
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
# Act
output = translation_network(mock_csi_input)
# Create target with same shape as output
target = torch.randn_like(output)
loss = criterion(output, target)
loss = translation_network.compute_translation_loss(output, mock_target_features)
optimizer.zero_grad()
loss.backward()
# Assert
assert loss.item() > 0
# Check that gradients are computed
for param in translation_network.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_network_handles_invalid_input_shape(self, translation_network):
"""Test that network handles invalid input shapes gracefully"""
def test_network_validates_input_dimensions(self, translation_network):
"""Test that network validates input dimensions"""
# Arrange
invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
# Act & Assert
with pytest.raises(RuntimeError):
with pytest.raises(ModalityTranslationError):
translation_network(invalid_input)
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
"""Test that network supports evaluation mode"""
# Act
translation_network.eval()
def test_network_handles_batch_size_one(self, translation_network):
"""Test that network handles single sample inference"""
# Arrange
single_input = torch.randn(1, 6, 56, 100)
with torch.no_grad():
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - In eval mode with same input, outputs should be identical
assert torch.allclose(output1, output2, atol=1e-6)
def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
"""Test that network extracts meaningful features"""
# Act
with torch.no_grad():
output = translation_network(mock_csi_input)
output = translation_network(single_input)
# Assert
# Features should have reasonable statistics
assert not torch.isnan(output).any()
assert not torch.isinf(output).any()
assert output.std() > 0.01 # Features should have some variance
assert output.std() < 10.0 # But not be too extreme
assert output.shape == (1, 256, 56, 100)
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
"""Test that model state can be saved and loaded"""
# Arrange
original_output = translation_network(mock_csi_input)
# Act - Save state
state_dict = translation_network.state_dict()
# Create new network and load state
new_network = ModalityTranslationNetwork(translation_network.config)
new_network.load_state_dict(state_dict)
new_output = new_network(mock_csi_input)
# Assert
assert torch.allclose(original_output, new_output, atol=1e-6)
def test_network_configuration_validation(self):
"""Test that network validates configuration parameters"""
# Arrange
invalid_config = {
'input_channels': 0, # Invalid
'hidden_channels': [], # Invalid
'output_channels': 256
}
# Act & Assert
with pytest.raises(ValueError):
ModalityTranslationNetwork(invalid_config)
def test_feature_visualization_support(self, translation_network, mock_csi_input):
"""Test that network supports feature visualization"""
# Act
features = translation_network.get_intermediate_features(mock_csi_input)
# Assert
assert features is not None
assert isinstance(features, dict)
assert 'encoder_features' in features
assert 'decoder_features' in features
if translation_network.use_attention:
assert 'attention_weights' in features

# NOTE(review): diff-viewer residue ("View File" / "@@ -0,0 +1,244 @@") was
# pasted here verbatim and is not valid Python. The lines below begin a
# separate test module (tests for src.hardware.router_interface) that should
# live in its own file.
import pytest
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from src.hardware.router_interface import RouterInterface, RouterConnectionError
class TestRouterInterface:
    """Test suite for Router Interface following London School TDD principles"""

    @pytest.fixture
    def mock_config(self):
        """Configuration for router interface"""
        return {
            'router_ip': '192.168.1.1',
            'username': 'admin',
            'password': 'password',
            'ssh_port': 22,
            'timeout': 30,
            'max_retries': 3
        }

    @pytest.fixture
    def router_interface(self, mock_config):
        """Create router interface instance for testing"""
        return RouterInterface(mock_config)

    @pytest.fixture
    def mock_ssh_client(self):
        """Mock SSH client for testing"""
        ssh_stub = Mock()
        ssh_stub.connect = Mock()
        ssh_stub.exec_command = Mock()
        ssh_stub.close = Mock()
        return ssh_stub

    @staticmethod
    def _stub_command_output(client, stdout_bytes=b"", stderr_bytes=b""):
        """Wire client.exec_command to return the given stdout/stderr payloads.

        Mirrors paramiko's (stdin, stdout, stderr) triple; stdin is unused by
        the interface and is returned as None.
        """
        stdout_stream = Mock()
        stdout_stream.read.return_value = stdout_bytes
        stderr_stream = Mock()
        stderr_stream.read.return_value = stderr_bytes
        client.exec_command.return_value = (None, stdout_stream, stderr_stream)

    def test_interface_initialization_creates_correct_configuration(self, mock_config):
        """Test that router interface initializes with correct configuration"""
        # Act
        interface = RouterInterface(mock_config)
        # Assert - every configured field is mirrored onto the instance
        assert interface is not None
        assert interface.router_ip == mock_config['router_ip']
        assert interface.username == mock_config['username']
        assert interface.password == mock_config['password']
        assert interface.ssh_port == mock_config['ssh_port']
        assert interface.timeout == mock_config['timeout']
        assert interface.max_retries == mock_config['max_retries']
        # No connection is opened at construction time.
        assert not interface.is_connected

    @patch('paramiko.SSHClient')
    def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that connect method establishes SSH connection"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        # Act
        connected = router_interface.connect()
        # Assert - connection flag set and paramiko driven with our config
        assert connected is True
        assert router_interface.is_connected is True
        mock_ssh_client.set_missing_host_key_policy.assert_called_once()
        mock_ssh_client.connect.assert_called_once_with(
            hostname=router_interface.router_ip,
            port=router_interface.ssh_port,
            username=router_interface.username,
            password=router_interface.password,
            timeout=router_interface.timeout
        )

    @patch('paramiko.SSHClient')
    def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that connect method handles connection failures gracefully"""
        # Arrange - every attempt raises, so retries cannot save us
        mock_ssh_class.return_value = mock_ssh_client
        mock_ssh_client.connect.side_effect = Exception("Connection failed")
        # Act & Assert
        with pytest.raises(RouterConnectionError):
            router_interface.connect()
        assert router_interface.is_connected is False

    @patch('paramiko.SSHClient')
    def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that disconnect method closes SSH connection"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        router_interface.connect()
        # Act
        router_interface.disconnect()
        # Assert
        assert router_interface.is_connected is False
        mock_ssh_client.close.assert_called_once()

    @patch('paramiko.SSHClient')
    def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that execute_command runs SSH commands correctly"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        self._stub_command_output(mock_ssh_client, stdout_bytes=b"command output")
        router_interface.connect()
        # Act
        result = router_interface.execute_command("test command")
        # Assert - decoded stdout is returned, command passed through verbatim
        assert result == "command output"
        mock_ssh_client.exec_command.assert_called_with("test command")

    @patch('paramiko.SSHClient')
    def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that execute_command handles command errors"""
        # Arrange - empty stdout, non-empty stderr signals a failed command
        mock_ssh_class.return_value = mock_ssh_client
        self._stub_command_output(mock_ssh_client, stderr_bytes=b"command error")
        router_interface.connect()
        # Act & Assert
        with pytest.raises(RouterConnectionError):
            router_interface.execute_command("failing command")

    def test_execute_command_requires_connection(self, router_interface):
        """Test that execute_command requires active connection"""
        # Act & Assert - no connect() was issued first
        with pytest.raises(RouterConnectionError):
            router_interface.execute_command("test command")

    @patch('paramiko.SSHClient')
    def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that get_router_info retrieves router system information"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        self._stub_command_output(
            mock_ssh_client,
            stdout_bytes=b"Router Model: AC1900\nFirmware: 1.2.3"
        )
        router_interface.connect()
        # Act
        info = router_interface.get_router_info()
        # Assert - parsed into a dict with the expected keys
        assert info is not None
        assert isinstance(info, dict)
        assert 'model' in info
        assert 'firmware' in info

    @patch('paramiko.SSHClient')
    def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that enable_monitor_mode configures WiFi monitoring"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        self._stub_command_output(mock_ssh_client, stdout_bytes=b"Monitor mode enabled")
        router_interface.connect()
        # Act
        enabled = router_interface.enable_monitor_mode("wlan0")
        # Assert
        assert enabled is True
        mock_ssh_client.exec_command.assert_called()

    @patch('paramiko.SSHClient')
    def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that disable_monitor_mode disables WiFi monitoring"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        self._stub_command_output(mock_ssh_client, stdout_bytes=b"Monitor mode disabled")
        router_interface.connect()
        # Act
        disabled = router_interface.disable_monitor_mode("wlan0")
        # Assert
        assert disabled is True
        mock_ssh_client.exec_command.assert_called()

    @patch('paramiko.SSHClient')
    def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that router interface supports context manager protocol"""
        # Arrange
        mock_ssh_class.return_value = mock_ssh_client
        # Act
        with router_interface as interface:
            # Assert - entering the context establishes the connection
            assert interface.is_connected is True
        # Assert - connection should be closed after context
        assert router_interface.is_connected is False
        mock_ssh_client.close.assert_called_once()

    def test_interface_validates_configuration(self):
        """Test that router interface validates configuration parameters"""
        # Arrange
        incomplete_config = {
            'router_ip': '',  # Invalid IP
            'username': 'admin',
            'password': 'password'
        }
        # Act & Assert
        with pytest.raises(ValueError):
            RouterInterface(incomplete_config)

    @patch('paramiko.SSHClient')
    def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
        """Test that interface implements retry logic for failed operations"""
        # Arrange - first attempt fails, second succeeds
        mock_ssh_class.return_value = mock_ssh_client
        mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None]
        # Act
        connected = router_interface.connect()
        # Assert
        assert connected is True
        assert mock_ssh_client.connect.call_count == 2  # Should retry once