- Added CSIExtractor class for extracting CSI data from WiFi routers. - Implemented RouterInterface class for SSH communication with routers. - Developed DensePoseHead class for body part segmentation and UV coordinate regression. - Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling. - Integrated paramiko for SSH connections and command execution. - Established configuration validation for both extractor and router interface. - Added context manager support for resource management in both classes.
293 lines
12 KiB
Python
293 lines
12 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from unittest.mock import Mock, patch
|
|
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
|
|
|
|
|
|
class TestModalityTranslationNetwork:
|
|
"""Test suite for Modality Translation Network following London School TDD principles"""
|
|
|
|
@pytest.fixture
|
|
def mock_config(self):
|
|
"""Configuration for modality translation network"""
|
|
return {
|
|
'input_channels': 6, # Real and imaginary parts for 3 antennas
|
|
'hidden_channels': [64, 128, 256],
|
|
'output_channels': 256,
|
|
'kernel_size': 3,
|
|
'stride': 1,
|
|
'padding': 1,
|
|
'dropout_rate': 0.1,
|
|
'activation': 'relu',
|
|
'normalization': 'batch',
|
|
'use_attention': True,
|
|
'attention_heads': 8
|
|
}
|
|
|
|
@pytest.fixture
|
|
def translation_network(self, mock_config):
|
|
"""Create modality translation network instance for testing"""
|
|
return ModalityTranslationNetwork(mock_config)
|
|
|
|
@pytest.fixture
|
|
def mock_csi_input(self):
|
|
"""Generate mock CSI input tensor"""
|
|
batch_size = 4
|
|
channels = 6 # Real and imaginary parts for 3 antennas
|
|
height = 56 # Number of subcarriers
|
|
width = 100 # Time samples
|
|
return torch.randn(batch_size, channels, height, width)
|
|
|
|
@pytest.fixture
|
|
def mock_target_features(self):
|
|
"""Generate mock target feature tensor for training"""
|
|
batch_size = 4
|
|
feature_dim = 256
|
|
spatial_height = 56
|
|
spatial_width = 100
|
|
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
|
|
|
|
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
|
"""Test that modality translation network initializes with correct architecture"""
|
|
# Act
|
|
network = ModalityTranslationNetwork(mock_config)
|
|
|
|
# Assert
|
|
assert network is not None
|
|
assert isinstance(network, nn.Module)
|
|
assert network.input_channels == mock_config['input_channels']
|
|
assert network.output_channels == mock_config['output_channels']
|
|
assert network.use_attention == mock_config['use_attention']
|
|
assert hasattr(network, 'encoder')
|
|
assert hasattr(network, 'decoder')
|
|
if mock_config['use_attention']:
|
|
assert hasattr(network, 'attention')
|
|
|
|
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
|
"""Test that forward pass produces correctly shaped output"""
|
|
# Act
|
|
output = translation_network(mock_csi_input)
|
|
|
|
# Assert
|
|
assert output is not None
|
|
assert isinstance(output, torch.Tensor)
|
|
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
|
assert output.shape[1] == translation_network.output_channels # Correct output channels
|
|
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
|
|
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
|
|
|
|
def test_forward_pass_handles_different_input_sizes(self, translation_network):
|
|
"""Test that forward pass handles different input sizes"""
|
|
# Arrange
|
|
small_input = torch.randn(2, 6, 28, 50)
|
|
large_input = torch.randn(8, 6, 112, 200)
|
|
|
|
# Act
|
|
small_output = translation_network(small_input)
|
|
large_output = translation_network(large_input)
|
|
|
|
# Assert
|
|
assert small_output.shape == (2, 256, 28, 50)
|
|
assert large_output.shape == (8, 256, 112, 200)
|
|
|
|
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
|
|
"""Test that encoder extracts hierarchical features"""
|
|
# Act
|
|
features = translation_network.encode(mock_csi_input)
|
|
|
|
# Assert
|
|
assert features is not None
|
|
assert isinstance(features, list)
|
|
assert len(features) == len(translation_network.encoder)
|
|
|
|
# Check feature map sizes decrease with depth
|
|
for i in range(1, len(features)):
|
|
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
|
|
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
|
|
|
|
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
|
|
"""Test that decoder reconstructs target feature representation"""
|
|
# Arrange
|
|
encoded_features = translation_network.encode(mock_csi_input)
|
|
|
|
# Act
|
|
decoded_output = translation_network.decode(encoded_features)
|
|
|
|
# Assert
|
|
assert decoded_output is not None
|
|
assert isinstance(decoded_output, torch.Tensor)
|
|
assert decoded_output.shape[1] == translation_network.output_channels
|
|
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
|
|
|
|
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
|
|
"""Test that attention mechanism enhances feature representation"""
|
|
# Arrange
|
|
config_with_attention = mock_config.copy()
|
|
config_with_attention['use_attention'] = True
|
|
|
|
config_without_attention = mock_config.copy()
|
|
config_without_attention['use_attention'] = False
|
|
|
|
network_with_attention = ModalityTranslationNetwork(config_with_attention)
|
|
network_without_attention = ModalityTranslationNetwork(config_without_attention)
|
|
|
|
# Act
|
|
output_with_attention = network_with_attention(mock_csi_input)
|
|
output_without_attention = network_without_attention(mock_csi_input)
|
|
|
|
# Assert
|
|
assert output_with_attention.shape == output_without_attention.shape
|
|
# Outputs should be different due to attention mechanism
|
|
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
|
|
|
|
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
|
|
"""Test that training mode enables dropout for regularization"""
|
|
# Arrange
|
|
translation_network.train()
|
|
|
|
# Act
|
|
output1 = translation_network(mock_csi_input)
|
|
output2 = translation_network(mock_csi_input)
|
|
|
|
# Assert - outputs should be different due to dropout
|
|
assert not torch.allclose(output1, output2, atol=1e-6)
|
|
|
|
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
|
|
"""Test that evaluation mode disables dropout for consistent inference"""
|
|
# Arrange
|
|
translation_network.eval()
|
|
|
|
# Act
|
|
output1 = translation_network(mock_csi_input)
|
|
output2 = translation_network(mock_csi_input)
|
|
|
|
# Assert - outputs should be identical in eval mode
|
|
assert torch.allclose(output1, output2, atol=1e-6)
|
|
|
|
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
|
|
"""Test that compute_translation_loss measures feature alignment"""
|
|
# Arrange
|
|
predicted_features = translation_network(mock_csi_input)
|
|
|
|
# Act
|
|
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
|
|
|
|
# Assert
|
|
assert loss is not None
|
|
assert isinstance(loss, torch.Tensor)
|
|
assert loss.dim() == 0 # Scalar loss
|
|
assert loss.item() >= 0 # Loss should be non-negative
|
|
|
|
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
|
|
"""Test that compute_translation_loss handles different loss types"""
|
|
# Arrange
|
|
predicted_features = translation_network(mock_csi_input)
|
|
|
|
# Act
|
|
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
|
|
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
|
|
|
|
# Assert
|
|
assert mse_loss is not None
|
|
assert l1_loss is not None
|
|
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
|
|
|
|
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
|
|
"""Test that get_feature_statistics provides feature analysis"""
|
|
# Arrange
|
|
output = translation_network(mock_csi_input)
|
|
|
|
# Act
|
|
stats = translation_network.get_feature_statistics(output)
|
|
|
|
# Assert
|
|
assert stats is not None
|
|
assert isinstance(stats, dict)
|
|
assert 'mean' in stats
|
|
assert 'std' in stats
|
|
assert 'min' in stats
|
|
assert 'max' in stats
|
|
assert 'sparsity' in stats
|
|
|
|
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
|
|
"""Test that network supports gradient computation for training"""
|
|
# Arrange
|
|
translation_network.train()
|
|
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
|
|
|
|
# Act
|
|
output = translation_network(mock_csi_input)
|
|
loss = translation_network.compute_translation_loss(output, mock_target_features)
|
|
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
|
|
# Assert
|
|
for param in translation_network.parameters():
|
|
if param.requires_grad:
|
|
assert param.grad is not None
|
|
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
|
|
|
def test_network_validates_input_dimensions(self, translation_network):
|
|
"""Test that network validates input dimensions"""
|
|
# Arrange
|
|
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
|
|
|
|
# Act & Assert
|
|
with pytest.raises(ModalityTranslationError):
|
|
translation_network(invalid_input)
|
|
|
|
def test_network_handles_batch_size_one(self, translation_network):
|
|
"""Test that network handles single sample inference"""
|
|
# Arrange
|
|
single_input = torch.randn(1, 6, 56, 100)
|
|
|
|
# Act
|
|
output = translation_network(single_input)
|
|
|
|
# Assert
|
|
assert output.shape == (1, 256, 56, 100)
|
|
|
|
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
|
|
"""Test that model state can be saved and loaded"""
|
|
# Arrange
|
|
original_output = translation_network(mock_csi_input)
|
|
|
|
# Act - Save state
|
|
state_dict = translation_network.state_dict()
|
|
|
|
# Create new network and load state
|
|
new_network = ModalityTranslationNetwork(translation_network.config)
|
|
new_network.load_state_dict(state_dict)
|
|
new_output = new_network(mock_csi_input)
|
|
|
|
# Assert
|
|
assert torch.allclose(original_output, new_output, atol=1e-6)
|
|
|
|
def test_network_configuration_validation(self):
|
|
"""Test that network validates configuration parameters"""
|
|
# Arrange
|
|
invalid_config = {
|
|
'input_channels': 0, # Invalid
|
|
'hidden_channels': [], # Invalid
|
|
'output_channels': 256
|
|
}
|
|
|
|
# Act & Assert
|
|
with pytest.raises(ValueError):
|
|
ModalityTranslationNetwork(invalid_config)
|
|
|
|
def test_feature_visualization_support(self, translation_network, mock_csi_input):
|
|
"""Test that network supports feature visualization"""
|
|
# Act
|
|
features = translation_network.get_intermediate_features(mock_csi_input)
|
|
|
|
# Assert
|
|
assert features is not None
|
|
assert isinstance(features, dict)
|
|
assert 'encoder_features' in features
|
|
assert 'decoder_features' in features
|
|
if translation_network.use_attention:
|
|
assert 'attention_weights' in features |